diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h index 3be6c0f7a..a6a5d3309 100755 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -80,9 +80,16 @@ public: static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name); static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); static CUresult cuModuleLoad(CUmodule *module, const char *fname); + static CUresult cuModuleLoadData(CUmodule* module, const void* image); static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra); static CUresult cuModuleUnload(CUmodule hmod); static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues); + + static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); + static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); + static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); + static CUresult cuLinkDestroy(CUlinkState state); + static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); static CUresult cuDeviceGetCount(int *count); static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); @@ -146,6 +153,11 @@ private: static void* cuLaunchKernel_; static void* cuModuleUnload_; static void* cuModuleLoadDataEx_; + static void* cuLinkAddData_v2_; + static void* cuLinkCreate_v2_; + static void* cuLinkDestroy_; + static void* cuModuleLoadData_; + static void* cuLinkComplete_; static void* cuDeviceGetAttribute_; static void* cuDeviceGetCount_; static void* cuMemcpyHtoD_v2_; diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h index b31cf6f8a..ddfdef380 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -68,9 +68,11 @@ public: std::unique_ptr symbol(const char * name) const; std::string llir() const { return llir_; } const std::string& ptx() const { return ptx_; } + const std::string& cubin() const { return cubin_; } private: std::string ptx_; + std::string cubin_; std::string llir_; }; diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index df6f14ddb..a6e93db1a 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -140,11 +140,16 @@ CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *) CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice) CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice) CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*) +CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**); +CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*); +CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState); +CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*); CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream) CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *) CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **) CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule) +CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *) CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **) CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice) CUDA_DEFINE1(CUresult, cuDeviceGetCount, int *) @@ -211,6 +216,12 @@ void* dispatch::cuDeviceGetName_; void* dispatch::cuDeviceGetPCIBusId_; void* dispatch::cuModuleGetGlobal_v2_; +void* dispatch::cuLinkAddData_v2_; +void* dispatch::cuLinkCreate_v2_; +void* dispatch::cuLinkDestroy_; +void* dispatch::cuModuleLoadData_; +void* dispatch::cuLinkComplete_; + void* dispatch::cuMemcpyHtoDAsync_v2_; void* dispatch::cuModuleLoad_; void* dispatch::cuLaunchKernel_; diff --git a/lib/driver/module.cc b/lib/driver/module.cc index ed984ad43..b1a054e85 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -319,24 +319,15 @@ void cu_module::init_from_ptx(const std::string& ptx) { // std::cout << log << std::endl; // std::cout << ptx_ << std::endl; - CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, - CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, - CU_JIT_LOG_VERBOSE}; - unsigned int errbufsize = 8192; - unsigned int logbufsize = 8192; - char _err[errbufsize]; - char _log[logbufsize]; - void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; - dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval); - std::string err(_err); - std::string log(_log); -// std::smatch match; -// std::regex expr ("\\b([0-9]+) bytes spill"); -// spilled_ = 0; -// while (std::regex_search(log,match,expr)){ -// spilled_ += std::stoi(match[1]); -// log = match.suffix(); -// } + CUlinkState link_state; + dispatch::cuLinkCreate_v2(0, 0, 0, &link_state); + dispatch::cuLinkAddData_v2(link_state, CU_JIT_INPUT_PTX, (void*)ptx_.data(), ptx_.size(), 0, 0, 0, 0); + size_t cubin_size; + void *cubin; + dispatch::cuLinkComplete(link_state, &cubin, &cubin_size); + dispatch::cuModuleLoadData(&*cu_, cubin); + cubin_ = std::string((const char*)cubin, cubin_size); + dispatch::cuLinkDestroy(link_state); } catch(exception::cuda::invalid_ptx const &){ //#ifdef TRITON_LOG_PTX_ERROR diff --git a/python/setup.py b/python/setup.py index 587e2dce9..a11b81ebf 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,7 +96,7 @@ setup( author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", long_description="", - packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"], + packages=["triton", "triton/_C", "triton/tools", "triton/ops", "triton/ops/blocksparse"], install_requires=["numpy", "torch"], package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, include_package_data=True, diff --git a/python/src/triton.cc b/python/src/triton.cc index aa551a94b..45d0e2704 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -63,6 +63,7 @@ void init_triton_driver(py::module &&m) { py::class_(m, "cu_module") .def("ptx", &drv::cu_module::ptx) + .def("cubin", [](drv::cu_module *self) { return py::bytes(self->cubin()); }) .def("llir", &drv::cu_module::llir); py::class_(m, "kernel"); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 04df5d842..9bdec6129 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -5,6 +5,8 @@ import types import torch import ast import builtins +import tempfile +from .tools.disasm import extract import triton._C.libtriton.triton as _triton import triton import sys @@ -413,12 +415,24 @@ class Binary: self.kernel = kernel self.shared_mem = shared_mem self.num_warps = num_warps + self.sass = None def asm(self, mode): if mode == 'ttir': return self.ir_asm if mode == 'ptx': return self.module.ptx() + if mode == 'sass': + if self.sass is None: + cubin = self.module.cubin() + # get a temporary file name + fd, path = tempfile.mkstemp(suffix='.cubin') + f = open(path, 'wb') + f.write(cubin) + f.close() + # extract SASS from cubin + self.sass = extract(path, None) + return self.sass if mode == 'llir': return self.module.llir() raise ValueError('Unsupported mode ' + mode) @@ -711,6 +725,7 @@ def cdiv(x, y): ###### + class TensorWrapper: def __init__(self, data_ptr, dtype, device): self._data_ptr = data_ptr diff --git a/python/triton/tools/__init__.py b/python/triton/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py new file mode 100644 index 000000000..fbbfa6d0b --- /dev/null +++ b/python/triton/tools/disasm.py @@ -0,0 +1,123 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import argparse +import subprocess +import re + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) != None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +def extract(file_path, fun): + if fun == None: + sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path]) + else: + sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + fname_match = FNAME_RE.match(line) + # Looking for new function header (function: ) + while FNAME_RE.match(line) == None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) != None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convension: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret