[PYTHON] Added frontend to print sass using turingas disasm.py (#109)
This commit is contained in:
committed by
Philippe Tillet
parent
c91dd56a92
commit
288b4f7f58
@@ -80,9 +80,16 @@ public:
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
||||
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
||||
static CUresult cuModuleUnload(CUmodule hmod);
|
||||
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||
|
||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
||||
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
||||
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
||||
static CUresult cuLinkDestroy(CUlinkState state);
|
||||
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
static CUresult cuDeviceGetCount(int *count);
|
||||
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
|
||||
@@ -146,6 +153,11 @@ private:
|
||||
static void* cuLaunchKernel_;
|
||||
static void* cuModuleUnload_;
|
||||
static void* cuModuleLoadDataEx_;
|
||||
static void* cuLinkAddData_v2_;
|
||||
static void* cuLinkCreate_v2_;
|
||||
static void* cuLinkDestroy_;
|
||||
static void* cuModuleLoadData_;
|
||||
static void* cuLinkComplete_;
|
||||
static void* cuDeviceGetAttribute_;
|
||||
static void* cuDeviceGetCount_;
|
||||
static void* cuMemcpyHtoD_v2_;
|
||||
|
@@ -68,9 +68,11 @@ public:
|
||||
std::unique_ptr<buffer> symbol(const char * name) const;
|
||||
std::string llir() const { return llir_; }
|
||||
const std::string& ptx() const { return ptx_; }
|
||||
const std::string& cubin() const { return cubin_; }
|
||||
|
||||
private:
|
||||
std::string ptx_;
|
||||
std::string cubin_;
|
||||
std::string llir_;
|
||||
};
|
||||
|
||||
|
@@ -140,11 +140,16 @@ CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice)
|
||||
CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*)
|
||||
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
||||
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
|
||||
|
||||
CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*);
|
||||
CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream)
|
||||
CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *)
|
||||
CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **)
|
||||
CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule)
|
||||
CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *)
|
||||
CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **)
|
||||
CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice)
|
||||
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *)
|
||||
@@ -211,6 +216,12 @@ void* dispatch::cuDeviceGetName_;
|
||||
void* dispatch::cuDeviceGetPCIBusId_;
|
||||
void* dispatch::cuModuleGetGlobal_v2_;
|
||||
|
||||
void* dispatch::cuLinkAddData_v2_;
|
||||
void* dispatch::cuLinkCreate_v2_;
|
||||
void* dispatch::cuLinkDestroy_;
|
||||
void* dispatch::cuModuleLoadData_;
|
||||
void* dispatch::cuLinkComplete_;
|
||||
|
||||
void* dispatch::cuMemcpyHtoDAsync_v2_;
|
||||
void* dispatch::cuModuleLoad_;
|
||||
void* dispatch::cuLaunchKernel_;
|
||||
|
@@ -319,24 +319,15 @@ void cu_module::init_from_ptx(const std::string& ptx) {
|
||||
// std::cout << log << std::endl;
|
||||
// std::cout << ptx_ << std::endl;
|
||||
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
|
||||
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
|
||||
CU_JIT_LOG_VERBOSE};
|
||||
unsigned int errbufsize = 8192;
|
||||
unsigned int logbufsize = 8192;
|
||||
char _err[errbufsize];
|
||||
char _log[logbufsize];
|
||||
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
|
||||
dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
|
||||
std::string err(_err);
|
||||
std::string log(_log);
|
||||
// std::smatch match;
|
||||
// std::regex expr ("\\b([0-9]+) bytes spill");
|
||||
// spilled_ = 0;
|
||||
// while (std::regex_search(log,match,expr)){
|
||||
// spilled_ += std::stoi(match[1]);
|
||||
// log = match.suffix();
|
||||
// }
|
||||
CUlinkState link_state;
|
||||
dispatch::cuLinkCreate_v2(0, 0, 0, &link_state);
|
||||
dispatch::cuLinkAddData_v2(link_state, CU_JIT_INPUT_PTX, (void*)ptx_.data(), ptx_.size(), 0, 0, 0, 0);
|
||||
size_t cubin_size;
|
||||
void *cubin;
|
||||
dispatch::cuLinkComplete(link_state, &cubin, &cubin_size);
|
||||
dispatch::cuModuleLoadData(&*cu_, cubin);
|
||||
cubin_ = std::string((const char*)cubin, cubin_size);
|
||||
dispatch::cuLinkDestroy(link_state);
|
||||
}
|
||||
catch(exception::cuda::invalid_ptx const &){
|
||||
//#ifdef TRITON_LOG_PTX_ERROR
|
||||
|
@@ -96,7 +96,7 @@ setup(
|
||||
author_email="phil@openai.com",
|
||||
description="A language and compiler for custom Deep Learning operations",
|
||||
long_description="",
|
||||
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
|
||||
packages=["triton", "triton/_C", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
|
||||
install_requires=["numpy", "torch"],
|
||||
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
|
||||
include_package_data=True,
|
||||
|
@@ -63,6 +63,7 @@ void init_triton_driver(py::module &&m) {
|
||||
|
||||
py::class_<drv::cu_module, drv::module>(m, "cu_module")
|
||||
.def("ptx", &drv::cu_module::ptx)
|
||||
.def("cubin", [](drv::cu_module *self) { return py::bytes(self->cubin()); })
|
||||
.def("llir", &drv::cu_module::llir);
|
||||
|
||||
py::class_<drv::kernel>(m, "kernel");
|
||||
|
@@ -5,6 +5,8 @@ import types
|
||||
import torch
|
||||
import ast
|
||||
import builtins
|
||||
import tempfile
|
||||
from .tools.disasm import extract
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton
|
||||
import sys
|
||||
@@ -413,12 +415,24 @@ class Binary:
|
||||
self.kernel = kernel
|
||||
self.shared_mem = shared_mem
|
||||
self.num_warps = num_warps
|
||||
self.sass = None
|
||||
|
||||
def asm(self, mode):
|
||||
if mode == 'ttir':
|
||||
return self.ir_asm
|
||||
if mode == 'ptx':
|
||||
return self.module.ptx()
|
||||
if mode == 'sass':
|
||||
if self.sass is None:
|
||||
cubin = self.module.cubin()
|
||||
# get a temporary file name
|
||||
fd, path = tempfile.mkstemp(suffix='.cubin')
|
||||
f = open(path, 'wb')
|
||||
f.write(cubin)
|
||||
f.close()
|
||||
# extract SASS from cubin
|
||||
self.sass = extract(path, None)
|
||||
return self.sass
|
||||
if mode == 'llir':
|
||||
return self.module.llir()
|
||||
raise ValueError('Unsupported mode ' + mode)
|
||||
@@ -711,6 +725,7 @@ def cdiv(x, y):
|
||||
|
||||
######
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, data_ptr, dtype, device):
|
||||
self._data_ptr = data_ptr
|
||||
|
0
python/triton/tools/__init__.py
Normal file
0
python/triton/tools/__init__.py
Normal file
123
python/triton/tools/disasm.py
Normal file
123
python/triton/tools/disasm.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2020 Da Yan @ HKUST
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import re
|
||||
|
||||
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
|
||||
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
|
||||
FNAME_RE = re.compile(r'\s*Function : (\w+)\s*')
|
||||
BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);')
|
||||
|
||||
|
||||
def parseCtrl(sline):
|
||||
enc = int(SLINE_RE.match(sline).group(1), 16)
|
||||
stall = (enc >> 41) & 0xf
|
||||
yld = (enc >> 45) & 0x1
|
||||
wrtdb = (enc >> 46) & 0x7
|
||||
readb = (enc >> 49) & 0x7
|
||||
watdb = (enc >> 52) & 0x3f
|
||||
|
||||
yld_str = 'Y' if yld == 0 else '-'
|
||||
wrtdb_str = '-' if wrtdb == 7 else str(wrtdb)
|
||||
readb_str = '-' if readb == 7 else str(readb)
|
||||
watdb_str = '--' if watdb == 0 else f'{watdb:02d}'
|
||||
return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}'
|
||||
|
||||
|
||||
def processSassLines(fline, sline, labels):
|
||||
asm = FLINE_RE.match(fline).group(1)
|
||||
# Remove tailing space
|
||||
if asm.endswith(" ;"):
|
||||
asm = asm[:-2] + ";"
|
||||
ctrl = parseCtrl(sline)
|
||||
# BRA target address
|
||||
if BRA_RE.match(asm) != None:
|
||||
target = int(BRA_RE.match(asm).group(2), 16)
|
||||
if target in labels:
|
||||
pass
|
||||
else:
|
||||
labels[target] = len(labels)
|
||||
return (f'{ctrl}', f'{asm}')
|
||||
|
||||
|
||||
def extract(file_path, fun):
|
||||
if fun == None:
|
||||
sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
|
||||
else:
|
||||
sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
|
||||
sass_lines = sass_str.splitlines()
|
||||
line_idx = 0
|
||||
while line_idx < len(sass_lines):
|
||||
line = sass_lines[line_idx].decode()
|
||||
# format:
|
||||
# function : <function_name>
|
||||
# .headerflags: ...
|
||||
# /*0000*/ asmstr /*0x...*/
|
||||
# /*0x...*/
|
||||
fname_match = FNAME_RE.match(line)
|
||||
# Looking for new function header (function: <name>)
|
||||
while FNAME_RE.match(line) == None:
|
||||
line_idx += 1
|
||||
if line_idx < len(sass_lines):
|
||||
line = sass_lines[line_idx].decode()
|
||||
else:
|
||||
return
|
||||
|
||||
fname = FNAME_RE.match(line).group(1)
|
||||
ret = ''
|
||||
ret += f'Function:{fname}\n'
|
||||
line_idx += 2 # bypass .headerflags
|
||||
line = sass_lines[line_idx].decode()
|
||||
# Remapping address to label
|
||||
labels = {} # address -> label_idx
|
||||
# store sass asm in buffer and them print them (for labels)
|
||||
# (ctrl, asm)
|
||||
asm_buffer = []
|
||||
while FLINE_RE.match(line) != None:
|
||||
# First line (Offset ASM Encoding)
|
||||
fline = sass_lines[line_idx].decode()
|
||||
line_idx += 1
|
||||
# Second line (Encoding)
|
||||
sline = sass_lines[line_idx].decode()
|
||||
line_idx += 1
|
||||
asm_buffer.append(processSassLines(fline, sline, labels))
|
||||
# peek the next line
|
||||
line = sass_lines[line_idx].decode()
|
||||
# Print sass
|
||||
# label naming convension: LBB#i
|
||||
for idx, (ctrl, asm) in enumerate(asm_buffer):
|
||||
# Print label if this is BRA target
|
||||
offset = idx * 16
|
||||
if offset in labels:
|
||||
label_name = f'LBB{labels[offset]}'
|
||||
ret += f'{label_name}:\n'
|
||||
ret += ctrl + '\t'
|
||||
# if this is BRA, remap offset to label
|
||||
if BRA_RE.match(asm):
|
||||
target = int(BRA_RE.match(asm).group(2), 16)
|
||||
target_name = f'LBB{labels[target]}'
|
||||
asm = BRA_RE.sub(rf'\1{target_name};', asm)
|
||||
ret += asm + '\n'
|
||||
ret += '\n'
|
||||
return ret
|
Reference in New Issue
Block a user