diff --git a/include/triton/driver/llvm.h b/include/triton/driver/llvm.h
index 70e9069bd..89dc98169 100644
--- a/include/triton/driver/llvm.h
+++ b/include/triton/driver/llvm.h
@@ -10,6 +10,7 @@ namespace driver{
 
 void init_llvm();
 std::string llir_to_ptx(llvm::Module* module, int cc, int version);
+std::string ptx_to_cubin(const std::string& ptx, int cc);
 CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
 std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
 hipModule_t amdgpu_to_hipmodule(const std::string& path);
diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc
index d99e9383c..f3c76ce77 100644
--- a/lib/driver/llvm.cc
+++ b/lib/driver/llvm.cc
@@ -157,6 +157,39 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
   return result;
 }
 
+std::string ptx_to_cubin(const std::string& ptx, int cc) {
+  std::string ptxas = "ptxas";
+  std::string version;
+  int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
+  if(!use_system_ptxas)
+    return "";
+
+  // compile ptx with ptxas
+  char _fsrc[] = "/tmp/triton_k_XXXXXX";
+  char _flog[] = "/tmp/triton_l_XXXXXX";
+  mkstemp(_fsrc);
+  mkstemp(_flog);
+  std::string fsrc = _fsrc;
+  std::string flog = _flog;
+  std::string fbin = fsrc + ".o";
+  const char* _fbin = fbin.c_str();
+  std::ofstream ofs(fsrc);
+  ofs << ptx;
+  ofs.close();
+  std::string cmd;
+  int err;
+  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+  err = system(cmd.c_str());
+  CUmodule ret;
+  std::ifstream _cubin(_fbin, std::ios::binary);
+  std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
+  _cubin.close();
+  dispatch::cuModuleLoadData(&ret, cubin.c_str());
+  unlink(_fsrc);
+  unlink(_flog);
+  unlink(_fbin);
+  return cubin;
+}
 
 CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   // JIT compile source-code
@@ -175,6 +208,8 @@ CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   mkstemp(_flog);
   std::string fsrc = _fsrc;
   std::string flog = _flog;
+  std::string fbin = fsrc + ".o";
+  const char* _fbin = fbin.c_str();
   std::ofstream ofs(fsrc);
   ofs << ptx;
   ofs.close();
@@ -183,9 +218,13 @@ CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
   err = system(cmd.c_str());
   CUmodule ret;
-  dispatch::cuModuleLoad(&ret, (fsrc + ".o").c_str());
+  std::ifstream _cubin(_fbin, std::ios::binary);
+  std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
+  _cubin.close();
+  dispatch::cuModuleLoadData(&ret, cubin.c_str());
   unlink(_fsrc);
   unlink(_flog);
+  unlink(_fbin);
   return ret;
 }
diff --git a/python/setup.py b/python/setup.py
index 2965f167b..52f324412 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -127,7 +127,7 @@ setup(
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
     packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
-    install_requires=["torch"],
+    install_requires=["torch", "filelock"],
     package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],
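For reference, the new ptx_to_cubin helper above amounts to: write the PTX to a temporary file, run ptxas for the target architecture, and read the resulting cubin back as bytes. A rough Python sketch of the same idea, assuming ptxas is on PATH; the function name is illustrative and the real implementation is the C++ shown above.

import os
import subprocess
import tempfile

def compile_ptx_with_ptxas(ptx: str, cc: int) -> bytes:
    """Compile a PTX string into a cubin for sm_<cc> by shelling out to ptxas."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "kernel.ptx")
        out = os.path.join(tmp, "kernel.cubin")
        with open(src, "w") as f:
            f.write(ptx)
        # same flags as the C++ helper: verbose output, explicit SM target
        subprocess.check_call(["ptxas", "-v", f"--gpu-name=sm_{cc}", src, "-o", out])
        with open(out, "rb") as f:
            return f.read()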
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 111170af0..0140a1362 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -148,19 +148,26 @@ void init_triton_runtime(py::module &&m) {
 
 /*****************************************************************************/
 /* Python bindings for triton::codegen                                       */
 /*****************************************************************************/
-typedef std::map<std::string, std::string> asm_map_t;
+typedef std::map<std::string, py::object> asm_map_t;
+// ---------------------------------------
+// Load provided assembly code into driver
+// ---------------------------------------
 
-std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n_shared_bytes, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map, int cc, int version){
-  // LLVM-IR -> PTX
-  std::string ptx = drv::llir_to_ptx(llvm, cc, version);
-  asm_map["ptx"] = ptx;
-  // PTX -> Binary
-  CUmodule mod = drv::ptx_to_cumodule(ptx, cc);
-  // Handle to the kernel
+// CUDA
+std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+  // load assembly
+  std::string assembly;
+  if(asm_map.find("cubin") != asm_map.end())
+    assembly = py::cast<std::string>(asm_map["cubin"]);
+  else
+    assembly = py::cast<std::string>(asm_map["ptx"]);
+  // create driver handles
   CUfunction fun;
+  CUmodule mod;
+  drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
   drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-  // Dynamic shared memory
+  // set dynamic shared memory if necessary
   int shared_optin;
   drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
   if(n_shared_bytes > 49152 && shared_optin > 49152){
@@ -173,16 +180,15 @@ std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n
     drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
     drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
   }
-
-  // record asm
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }
 
-std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map){
-  // LLVM-IR -> HSA-CO
-  std::string path = drv::llir_to_amdgpu(llvm, "gfx908");
+// ROCM
+std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+  py::bytes _assembly = asm_map["hsaco"];
+  std::string assembly = py::cast<std::string>(_assembly);
   // HSA-CO -> hipModule
-  hipModule_t mod = drv::amdgpu_to_hipmodule(path);
+  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
   // Handle to the kernel
   hipFunction_t fun;
   drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
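The cu_load_binary function above prefers the ptxas-produced cubin and falls back to PTX, which the CUDA driver can still JIT when cuModuleLoadData receives PTX text. A minimal sketch of that selection, assuming an asm dict shaped like the asm_map built later in this diff (the helper name is illustrative):

def select_assembly(asm: dict) -> bytes:
    # prefer the offline-compiled cubin; otherwise let the driver JIT the PTX
    blob = asm.get("cubin", asm["ptx"])
    return blob if isinstance(blob, bytes) else blob.encode("utf-8")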
@@ -190,6 +196,63 @@ std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::M
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }
 
+// ---------------------------------------
+// Compile Triton-IR to assembly
+// ---------------------------------------
+
+// CUDA
+std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
+                                                        uint64_t device, int num_warps, int num_stages,
+                                                        bool force_nc_cache, asm_map_t &asm_map){
+  llvm::LLVMContext ctx;
+  // device properties
+  CUdevice dev = (CUdevice)device;
+  size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+  size_t cc = major*10 + minor;
+  int version;
+  drv::dispatch::cuDriverGetVersion(&version);
+  // Triton-IR -> NVPTX LLVM-IR
+  triton::codegen::nvidia_cu_target target(cc);
+  int n_shared_bytes;
+  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
+  std::string tmp;
+  llvm::raw_string_ostream llir(tmp);
+  llir << *llvm;
+  llir.flush();
+  asm_map["llir"] = py::cast(tmp);
+  // LLVM-IR -> PTX
+  std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
+  asm_map["ptx"] = py::cast(ptx);
+  // PTX -> Binary
+  std::string cubin = drv::ptx_to_cubin(ptx, cc);
+  if(!cubin.empty()){
+    py::bytes bytes(cubin);
+    asm_map["cubin"] = bytes;
+  }
+  return std::make_tuple(name, asm_map, n_shared_bytes);
+}
+
+// HIP
+std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
+                                                         uint64_t device, int num_warps, int num_stages,
+                                                         bool force_nc_cache, asm_map_t &asm_map){
+  llvm::LLVMContext ctx;
+  // Triton-IR -> NVPTX LLVM-IR
+  triton::codegen::amd_cl_target target;
+  int n_shared_bytes;
+  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
+  std::string tmp;
+  llvm::raw_string_ostream llir(tmp);
+  llir << *llvm;
+  llir.flush();
+  asm_map["llir"] = py::cast(tmp);
+  // LLVM-IR -> HSA-CO
+  std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
+  asm_map["hsaco"] = py::cast(path);
+  return std::make_tuple(name, asm_map, n_shared_bytes);
+}
+
 void init_triton_codegen(py::module &&m) {
   m.def(
       "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
@@ -198,43 +261,19 @@ void init_triton_codegen(py::module &&m) {
         asm_map_t asm_map;
         std::ostringstream ttir;
         ir::print(ir, ttir);
-        asm_map["ttir"] = ttir.str();
+        asm_map["ttir"] = py::cast(ttir.str());
         llvm::LLVMContext ctx;
-        if(backend == CUDA){
-          // device properties
-          CUdevice dev = (CUdevice)device;
-          size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
-          size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
-          size_t cc = major*10 + minor;
-          int version;
-          drv::dispatch::cuDriverGetVersion(&version);
-          // Triton-IR -> NVPTX LLVM-IR
-          triton::codegen::nvidia_cu_target target(cc);
-          int n_shared_bytes;
-          auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
-          llvm::raw_string_ostream llir(asm_map["llir"]);
-          llir << *llvm;
-          llir.flush();
-          // LLVM-IR -> Bin
-          uint64_t mod, fun;
-          std::tie(mod, fun) = cu_compile_llir(name, n_shared_bytes, &*llvm, device, asm_map, cc, version);
-          return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
-        }
-        if(backend == ROCM){
-          // Triton-IR -> NVPTX LLVM-IR
-          triton::codegen::amd_cl_target target;
-          int n_shared_bytes;
-          auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
-          llvm::raw_string_ostream llir(asm_map["llir"]);
-          llir << *llvm;
-          llir.flush();
-          // LLVM-IR -> Bin
-          uint64_t mod, fun;
-          std::tie(mod, fun) = hip_compile_llir(name, &*llvm, device, asm_map);
-          return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
-        }
-      },
-      py::return_value_policy::take_ownership);
+        if(backend == CUDA)
+          return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+        if(backend == ROCM)
+          return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+      }, py::return_value_policy::take_ownership);
+  m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+        if(backend == CUDA)
+          return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+        if(backend == ROCM)
+          return hip_load_binary(name, asm_map, n_shared_bytes, dev);
+      }, py::return_value_policy::take_ownership);
 }
 
 /*****************************************************************************/
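The bindings above split kernel creation into two steps: compile_ttir lowers Triton-IR and returns only serializable assembly (ttir, llir, ptx, plus cubin or hsaco), while load_binary hands that assembly to the driver on a concrete device. A sketch of the resulting flow as seen from Python; the wrapper function and its arguments are placeholders, not part of this diff.

import triton._C.libtriton.triton as _triton

def compile_then_load(backend, module, device, num_warps=4, num_stages=2, force_nc_cache=False):
    # step 1: pure compilation -- no driver handles, so the result can be pickled and cached
    name, asm, shared_mem = _triton.code_gen.compile_ttir(
        backend, module, device, num_warps, num_stages, force_nc_cache)
    # step 2: create driver handles (CUmodule/CUfunction or hipModule/hipFunction) on `device`
    mod, fun = _triton.code_gen.load_binary(backend, name, asm, shared_mem, device)
    return mod, fun, asm, shared_mem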
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 33d1a53af..10561898f 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -5,6 +5,11 @@ import struct
 import sys
 import tempfile
 import textwrap
+import hashlib
+import atexit
+import os
+import shelve
+from filelock import FileLock
 
 import torch
 
 import triton
@@ -411,23 +416,31 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, backend, module, kernel, asm, num_warps, num_stages, force_nc_cache, shared_mem):
-        # cache ir asm
+    def __init__(self, backend, name, asm, shared_mem, num_warps):
+        self.backend = backend
+        self.name = name
         self.asm = asm
-        self.module = module
-        self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
-        self.num_stages = num_stages
-        self.force_nc_cache = force_nc_cache
-        self.sass = None
-        self.backend = backend
+
+class LoadedBinary:
+    def __init__(self, device: int, bin: Binary):
+        module, kernel = _triton.code_gen.load_binary(bin.backend,
+                                                      bin.name,
+                                                      bin.asm,
+                                                      bin.shared_mem,
+                                                      device)
+        self.bin = bin
+        self.asm = bin.asm
+        self.module = module
+        self.kernel = kernel
+        self.device = device
 
     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
-        _triton.runtime.enqueue(self.backend, stream, self.kernel,
+        _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
                                 grid_0, grid_1, grid_2,
-                                self.num_warps * 32, 1, 1,
-                                args, self.shared_mem)
+                                self.bin.num_warps * 32, 1, 1,
+                                args, self.bin.shared_mem)
 
 
 class CompilationError(Exception):
@@ -536,11 +549,11 @@ class Kernel:
             backend = _triton.runtime.backend.CUDA
         else:
             backend = _triton.runtime.backend.ROCM
-        mod, ker, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
+        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
         max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
         if shared_mem > max_shared_memory:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
-        return Binary(backend, mod, ker, asm, num_warps, num_stages, force_nc_cache, shared_mem)
+        return Binary(backend, name, asm, shared_mem, num_warps)
 
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
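The caching added to Kernel.__call__ in the next hunk follows a get-or-compute pattern around an on-disk shelve database guarded by a file lock, so concurrent processes neither recompile needlessly nor corrupt the database. A standalone sketch of that pattern; cache_path, lock_path and compile_fn are illustrative placeholders, not names from this diff.

import shelve
from filelock import FileLock

def get_or_compile(key, cache_path, lock_path, compile_fn):
    # look for a previously compiled binary on disk
    with FileLock(lock_path):
        with shelve.open(cache_path) as db:
            binary = db.get(key, None)
    if binary is None:
        # miss: compile outside the lock, then store the result under the lock
        binary = compile_fn()
        with FileLock(lock_path):
            with shelve.open(cache_path) as db:
                db[key] = binary
    return binary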
@@ -579,29 +592,43 @@ class Kernel:
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
-        # determine if we need to re-compile
+        # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
         attr_key = frozenset(attributes.items())
         meta_key = frozenset(meta.items())
         const_key = frozenset(constants.items())
         key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
-        cache = self.fn.cache
-        if key not in cache:
-            # compile and cache configuration if necessary
-            cache[key] = self._compile(
-                *wargs, device=device_idx, attributes=attributes,
-                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
-                constants=constants, **meta
-            )
+        key = repr(key)
+        # get cached binary
+        drv_cache = self.fn.drv_cache
+        bin_cache_path = self.fn.bin_cache_path
+        bin_lock_path = self.fn.bin_lock_path
+        if key not in drv_cache:
+            binary = None
+            if bin_lock_path:
+                with FileLock(bin_lock_path):
+                    with shelve.open(bin_cache_path) as db:
+                        binary = db.get(key, None)
+            if binary is None:
+                binary = self._compile(
+                    *wargs, device=device_idx, attributes=attributes,
+                    num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                    constants=constants, **meta
+                )
+                if bin_lock_path:
+                    with FileLock(bin_lock_path):
+                        with shelve.open(bin_cache_path) as db:
+                            db[key] = binary
+            drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
-        binary = cache[key]
+        callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
         grid = grid(meta) if hasattr(grid, '__call__') else grid
-        binary(stream, params, *grid)
-        return binary
+        callable(stream, params, *grid)
+        return callable
 
 
 class Launcher:
@@ -662,17 +689,59 @@ class Autotuner:
 
 
 class JITFunction:
+
+    # clear cache if the db is older than either the frontend or the backend
+    def _clear_cache(self):
+        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
+        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
+        with FileLock(self.bin_lock_path):
+            cache_mtime = os.path.getmtime(self.db_path)
+            if frontend_mtime > cache_mtime or backend_mtime > cache_mtime:
+                os.remove(self.db_path)
+
+    def _init_cache_paths(self):
+        # fetch cache directory path
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if not cache_dir:
+            self.bin_cache_path = None
+            self.db_path = None
+            self.bin_lock_path = None
+            return
+        # create cache directory
+        if not os.path.exists(cache_dir):
+            os.makedirs(cache_dir)
+        # create md5 hash of src
+        md5 = hashlib.md5()
+        md5.update(self.src.encode('utf-8'))
+        md5_hash = md5.hexdigest()
+        # load dbm file in cache_dir for md5_hash
+        self.bin_cache_path = os.path.join(cache_dir, md5_hash)
+        self.db_path = self.bin_cache_path + '.db'
+        self.bin_lock_path = self.bin_cache_path + '.lock'
+        # if bin_cache_path exists
+        if os.path.exists(self.db_path):
+            self._clear_cache()
+
     def __init__(self, fn):
+        # information of wrapped function
         self.fn = fn
        self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
-        self.cache = dict()
+        self.src = textwrap.dedent(inspect.getsource(fn))
+        # cache for callable driver objects (e.g. CUkernel)
+        self.drv_cache = dict()
+        # on-disk paths for the binary cache and corresponding
+        # file-lock
+        self._init_cache_paths()
+        # JITFunction can be instantiated as kernel
+        # when called with a grid using __getitem__
         self.kernel_decorators = []
-        self.src = textwrap.dedent(inspect.getsource(fn))
         self.kernel = None
+        # forward docs
         self.__doc__ = fn.__doc__
 
-    # we do not parse in the constructor because
+
+    # we do not parse `src` in the constructor because
+    # the user might want to monkey-patch self.src dynamically.
+    # Some unit tests do this, for example.
     def parse(self):
@@ -699,10 +768,16 @@ class JITFunction:
                 raise e
             raise CompilationError(self.src, node, e)
 
+    # - when `.src` attribute is set, cache path needs
+    #   to be reinitialized
+    # - when kernel decorators change, cached kernel
+    #   needs to be cleared
     def __setattr__(self, name, value):
         if name == 'kernel_decorators':
             self.kernel = None
         super(JITFunction, self).__setattr__(name, value)
+        if name == 'src':
+            self._init_cache_paths()
 
     def _init_kernel(self):
         if self.kernel is None:
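For completeness, the cache-location logic introduced in _init_cache_paths above can be summarized as: hash the kernel source with md5 and use the digest to name a shelve database and a lock file inside TRITON_CACHE_DIR (default /tmp/triton/); _clear_cache then drops the database whenever code_gen.py or libtriton has a newer mtime than it. A condensed sketch of the path derivation, with an illustrative function name:

import hashlib
import os

def cache_paths_for(src: str):
    # per-function cache location: <TRITON_CACHE_DIR>/<md5(src)>{,.db,.lock}
    cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
    os.makedirs(cache_dir, exist_ok=True)
    md5_hash = hashlib.md5(src.encode('utf-8')).hexdigest()
    base = os.path.join(cache_dir, md5_hash)
    return base, base + '.db', base + '.lock'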