From 8c3d4d57495108e64ece8e1d26469b1e61237457 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 22 Sep 2022 16:44:22 -0700
Subject: [PATCH] [RUNTIME] now decoupling entry point from cubin (#696)

---
 python/src/triton.cc         | 214 +++++++++++-------------------
 python/triton/compiler.py    | 243 ++++++++++++++++-------------------
 python/triton/runtime/jit.py |   4 +-
 3 files changed, 188 insertions(+), 273 deletions(-)

diff --git a/python/src/triton.cc b/python/src/triton.cc
index 8bfb076c3..31bc0445f 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -430,150 +430,90 @@ void init_triton_runtime(py::module &&m) {
 /*****************************************************************************/
 typedef std::map<std::string, py::object> asm_map_t;
 
-// ---------------------------------------
-// Load provided assembly code into driver
-// ---------------------------------------
-
-// CUDA
-std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
-  // load assembly
-  std::string assembly;
-  if(asm_map.find("cubin") != asm_map.end())
-    assembly = py::cast<std::string>(asm_map["cubin"]);
-  else
-    assembly = py::cast<std::string>(asm_map["ptx"]);
-  // create driver handles
-  CUfunction fun;
-  CUmodule mod;
-  drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
-  drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-  // get allocated registers and spilled registers from the function
-  int n_regs = 0;
-  int n_spills = 0;
-  drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
-  drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
-  n_spills /= 4;
-  // set dynamic shared memory if necessary
-  int shared_optin;
-  drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
-  if(n_shared_bytes > 49152 && shared_optin > 49152){
-    drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
-    int shared_total, shared_static;
-    drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
-    drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
-  }
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
-}
-
-// ROCM
-std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
-  py::bytes _assembly = asm_map["hsaco"];
-  std::string assembly = py::cast<std::string>(_assembly);
-  // HSA-CO -> hipModule
-  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
-  // Handle to the kernel
-  hipFunction_t fun;
-  drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
-  // record asm
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun, 0, 0);
-}
-
 // ---------------------------------------
 // Compile Triton-IR to assembly
 // ---------------------------------------
-// CUDA
-std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
-    const std::string &name, ir::module &ir, uint64_t device, int num_warps,
-    int num_stages, asm_map_t &asm_map,
-    const triton::codegen::ExternLibMap &extern_lib_map) {
-  py::gil_scoped_release allow_threads;
-  llvm::LLVMContext ctx;
-  // device properties
-  CUdevice dev = (CUdevice)device;
-  size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
-  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
-  size_t cc = major*10 + minor;
-  int version;
-  std::string ptxas_path = drv::path_to_ptxas(version);
-  // Triton-IR -> NVPTX LLVM-IR
-  triton::codegen::nvidia_cu_target target(cc);
-  int n_shared_bytes;
-  auto llvm = triton::codegen::add_passes_to_emit_bin(
-      ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
-  std::string tmp;
-  llvm::raw_string_ostream llir(tmp);
-  llir << *llvm;
-  llir.flush();
-  asm_map["llir"] = py::cast(tmp);
-  // LLVM-IR -> PTX
-  std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
-  asm_map["ptx"] = py::cast(ptx);
-  // PTX -> Binary
-  std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
-  if(!cubin.empty()){
-    py::bytes bytes(cubin);
-    asm_map["cubin"] = bytes;
-  }
-  return std::make_tuple(name, asm_map, n_shared_bytes);
-}
-
-// HIP
-std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
-    const std::string &name, ir::module &ir, uint64_t device, int num_warps,
-    int num_stages, asm_map_t &asm_map,
-    const triton::codegen::ExternLibMap &extern_lib_map) {
-  llvm::LLVMContext ctx;
-  // Triton-IR -> NVPTX LLVM-IR
-  triton::codegen::amd_cl_target target;
-  int n_shared_bytes;
-  auto llvm = triton::codegen::add_passes_to_emit_bin(
-      ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
-  std::string tmp;
-  llvm::raw_string_ostream llir(tmp);
-  llir << *llvm;
-  llir.flush();
-  asm_map["llir"] = py::cast(tmp);
-  // LLVM-IR -> HSA-CO
-  std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
-  asm_map["hsaco"] = py::cast(path);
-  return std::make_tuple(name, asm_map, n_shared_bytes);
-}
-
 void init_triton_codegen(py::module &&m) {
-  m.def(
-    "compile_ttir",
-    [](backend_t backend, ir::module &ir, uint64_t device, int num_warps,
-       int num_stages, py::dict& extern_libs) {
-      std::string name = ir.get_function_list()[0]->get_name();
-      // record asm as we generate
-      asm_map_t asm_map;
-      std::ostringstream ttir;
-      ir.print(ttir);
-      asm_map["ttir"] = py::cast(ttir.str());
-      llvm::LLVMContext ctx;
-      // construct extern lib map
-      triton::codegen::ExternLibMap extern_lib_map;
-      for (auto item : extern_libs) {
-        auto name = item.first.cast<std::string>();
-        auto path = item.second.cast<std::string>();
-        extern_lib_map.emplace(
-            name, triton::codegen::create_extern_lib(name, path));
-      }
-      if(backend == CUDA)
-        return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
-      assert(backend == ROCM);
-      return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
+  m.def("compile_ttir",
+    [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs) {
+      py::gil_scoped_release allow_threads;
+      std::string name = ir.get_function_list()[0]->get_name();
+      // record asm as we generate
+      asm_map_t asm_map;
+      std::ostringstream ttir;
+      ir.print(ttir);
+      asm_map["ttir"] = py::cast(ttir.str());
+      llvm::LLVMContext ctx;
+      // construct extern lib map
+      triton::codegen::ExternLibMap extern_lib_map;
+      for (auto item : extern_libs) {
+        auto name = item.first.cast<std::string>();
+        auto path = item.second.cast<std::string>();
+        extern_lib_map.emplace(
+            name, triton::codegen::create_extern_lib(name, path));
+      }
+      // device properties
+      CUdevice dev = (CUdevice)device;
+      size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+      size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+      size_t cc = major*10 + minor;
+      int version;
+      std::string ptxas_path = drv::path_to_ptxas(version);
+      // Triton-IR -> NVPTX LLVM-IR
+      triton::codegen::nvidia_cu_target target(cc);
+      int n_shared_bytes;
+      auto llvm = triton::codegen::add_passes_to_emit_bin(
+          ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
+      std::string tmp;
+      llvm::raw_string_ostream llir(tmp);
+      llir << *llvm;
+      llir.flush();
+      asm_map["llir"] = py::cast(tmp);
+      // LLVM-IR -> PTX
+      std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
+      asm_map["ptx"] = py::cast(ptx);
+      // PTX -> Binary
+      std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
+      if(!cubin.empty()){
+        py::bytes bytes(cubin);
+        asm_map["cubin"] = bytes;
+      }
+      return std::make_tuple(name, asm_map, n_shared_bytes);
     },
     py::return_value_policy::take_ownership);
-  m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
-    py::gil_scoped_release allow_threads;
-    if(backend == CUDA)
-      return cu_load_binary(name, asm_map, n_shared_bytes, dev);
-    assert(backend == ROCM);
-    return hip_load_binary(name, asm_map, n_shared_bytes, dev);
-  }, py::return_value_policy::take_ownership);
+
+
+  // ---------------------------------------
+  // Load provided assembly code into driver
+  // ---------------------------------------
+  m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
+    py::gil_scoped_release allow_threads;
+    // create driver handles
+    CUfunction fun;
+    CUmodule mod;
+    drv::dispatch::cuModuleLoadData(&mod, data.c_str());
+    drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
+    // get allocated registers and spilled registers from the function
+    int n_regs = 0;
+    int n_spills = 0;
+    drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
+    drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+    n_spills /= 4;
+    // set dynamic shared memory if necessary
+    int shared_optin;
+    drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
+    if(n_shared_bytes > 49152 && shared_optin > 49152){
+      drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
+      int shared_total, shared_static;
+      drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
+      drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
+      drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
+    }
+    return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
+  },
+  py::return_value_policy::take_ownership
+  );
 
 
 struct InstanceDescriptor
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 535f323da..c6523f016 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -5,6 +5,7 @@ import contextlib
 import functools
 import hashlib
 import io
+import json
 import os
 import shutil
 import subprocess
@@ -926,23 +927,7 @@ def binary_name_to_header_name(name):
     return f"{name}.h"
 
 
-def generate_torch_glue(kernel_name, constants, signature, num_warps, binaries, tmpdir):
-    headers = dict()
-
-    # write all cubins to header files
-    assert len(binaries) == 1, "AoT compilation not yet supported"
-
-    for bin, shmem_size, name in binaries:
-        assert len(name) < 1024
-        initializer = f"""
-const char* {name}_ptx = R"({bin["ptx"]})";
-unsigned char {name}_bin[] = {{ {','.join(map(hex, bin["cubin"]))} }};
-unsigned int {name}_shmem = {shmem_size};"""
-        headers[name] = os.path.join(tmpdir, binary_name_to_header_name(name))
-        with open(headers[name], "w") as f:
-            f.write(initializer)
-
-    func_init = '\n '.join(f"init_function(\"{name}\", {name}_bin, {name}_shmem, device);" for _, _, name in binaries)
+def generate_launcher(identifier, constants, signature):
     arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
 
     def _extracted_type(ty):
@@ -970,13 +955,10 @@
             "int64_t": "L",
         }[ty]
 
-    format = "iiiK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
+    format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
 
     # generate glue code
-    src = ""
-    for bin, shmem_size, name in binaries:
-        src += f"#include \"{headers[name]}\"\n"
-    src += f"""
+    src = f"""
 #include \"cuda.h\"
 #include <Python.h>
 
@@ -995,50 +977,16 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
 }}
 
 #define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
 
-static CUmodule module = 0;
-static CUfunction function = 0;
-static inline void init_function(const char* name, const unsigned char* src, size_t n_shared_bytes, int64_t device){{
-  CUmodule mod;
-  CUfunction fun;
-  CUDA_CHECK(cuModuleLoadData(&mod, src));
-  CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
-  // set dynamic shared memory if necessary
-  int shared_optin;
-  CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
-  if (n_shared_bytes > 49152 && shared_optin > 49152) {{
-    CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
-    int shared_total, shared_static;
-    int n_spills, n_reg;
-    CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
-    CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
-    CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
-    CUDA_CHECK(cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
-    CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
-  }}
-  module = mod;
-  function = fun;
-}}
-
-static inline void init_module(CUdevice device) {{
-  {func_init}
-}}
-
-
-void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{
-  // TODO: machine may have heterogeneous devices
-  if(function == 0){{
-    CUdevice device;
-    CUDA_CHECK(cuCtxGetDevice(&device));
-    init_module(device);
-  }}
+void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
   void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
   if(gridX*gridY*gridZ > 0){{
-    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*{num_warps}, 1, 1, {name}_shmem, stream, params, 0));
+    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
   }}
 }}
 
-CUdeviceptr getPointer(PyObject *obj, int idx) {{
+
+static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
   if (PyLong_Check(obj)) {{
     return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
   }}
@@ -1061,15 +1009,18 @@ CUdeviceptr getPointer(PyObject *obj, int idx) {{
 }}
 
 
-static PyObject* {kernel_name}(PyObject* self, PyObject* args) {{
+static PyObject* launch(PyObject* self, PyObject* args) {{
   int gridX, gridY, gridZ;
-  uint64_t stream;
+  uint64_t _stream;
+  uint64_t _function;
+  int num_warps;
+  int shared_memory;
   {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
-  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &stream, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
     return NULL;
   }}
-  _{kernel_name}(gridX, gridY, gridZ, (CUstream)stream, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
+  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
 
 
   if(PyErr_Occurred()) {{
@@ -1081,38 +1032,26 @@ static PyObject* {kernel_name}(PyObject* self, PyObject* args) {{
 }}
 
 static PyMethodDef ModuleMethods[] = {{
-  {{"{kernel_name}", {kernel_name}, METH_VARARGS, "Call {kernel_name} kernel"}},
+  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
   {{NULL, NULL, 0, NULL}} // sentinel
 }};
 
 static struct PyModuleDef ModuleDef = {{
   PyModuleDef_HEAD_INIT,
-  \"{kernel_name}\",
+  \"launcher\",
   NULL, //documentation
   -1, //size
   ModuleMethods
 }};
 
-PyMODINIT_FUNC PyInit_{kernel_name}(void) {{
+PyMODINIT_FUNC PyInit_launcher(void) {{
   PyObject *m = PyModule_Create(&ModuleDef);
   if(m == NULL) {{
     return NULL;
   }}
   PyModule_AddFunctions(m, ModuleMethods);
-  PyObject *ptx = PyDict_New();
-"""
-
-    for _, _, name in binaries:
-        src += f"""
-    PyObject *py_{name}_ptx = PyUnicode_FromString({name}_ptx);
-    PyDict_SetItemString(ptx, "{name}", py_{name}_ptx);
-    Py_DECREF(py_{name}_ptx);
-"""
-
-    src += """
-    PyModule_AddObject(m, "ptx", ptx);
   return m;
-}
+}}
 """
     return src
 
@@ -1126,35 +1065,34 @@ class CacheManager:
 
     def __init__(self, key):
         self.key = key
-        self.bin_path = None
         self.lock_path = None
-        # if caching is enabled, get the lock and bin path
+        # create cache directory if it doesn't exist
         self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
         if self.cache_dir:
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
             os.makedirs(self.cache_dir, exist_ok=True)
-        if self.cache_dir:
-            self.bin_path = os.path.join(self.cache_dir, self.key + ".so")
-            self.lock_path = self.bin_path + ".lock"
 
-    def has_file(self):
-        return self.bin_path and os.path.exists(self.bin_path)
+    def _make_path(self, filename):
+        return os.path.join(self.cache_dir, filename)
 
-    def put(self, binary):
-        if self.bin_path:
-            assert self.lock_path is not None
-            with FileLock(self.lock_path):
-                with open(self.bin_path + ".tmp", "wb") as f:
-                    f.write(binary)
-                os.rename(self.bin_path + ".tmp", self.bin_path)
+    def has_file(self, filename):
+        if not self.cache_dir:
+            return False
+        return os.path.exists(self._make_path(filename))
 
+    def put(self, data, filename, binary=True):
+        if not self.cache_dir:
+            return
+        assert self.lock_path is not None
+        filepath = self._make_path(filename)
+        with FileLock(self.lock_path):
+            # use tempfile to be robust against program interruptions
+            mode = "wb" if binary else "w"
+            with open(filepath + ".tmp", mode) as f:
+                f.write(data)
+            os.rename(filepath + ".tmp", filepath)
 
-def make_cache_key(fn, signature, configs, constants, num_warps, num_stages):
-    # Get unique key for the compiled code
-    get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
-    configs_key = [get_conf_key(conf) for conf in configs]
-    key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
-    key = hashlib.md5(key.encode("utf-8")).hexdigest()
-    return key
 
 # utilties for generating and compiling C wrappers
@@ -1224,54 +1162,91 @@ def _build(name, src, srcdir):
     return so
 
 
+def make_so_cache_key(signature, constants):
+    # Get unique key for the compiled code
+    signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
+    key = f"{''.join(signature.values())}{constants}"
+    key = hashlib.md5(key.encode("utf-8")).hexdigest()
+    return key
+
+
+def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
+    # Get unique key for the compiled code
+    get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
+    configs_key = [get_conf_key(conf) for conf in configs]
+    key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
+    key = hashlib.md5(key.encode("utf-8")).hexdigest()
+    return key
+
+
 def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
     # we get the kernel, i.e. the first function generated in the module
-    if configs is None:
-        assert False, "automatic specialization is not supported yet"
-        ref, _ = make_triton_ir(fn, signature, _triton.code_gen.instance_descriptor(), constants)
-        fns = ref.get_functions()
-        configs = _triton.infer_specialization_configs(fns[0])
     assert len(configs) == 1
     # cache manager
-    cache_key = make_cache_key(fn, signature, configs, constants, num_warps, num_stages)
-    cache_manager = CacheManager(cache_key)
-    # retrieve cached shared object if it exists
-    if cache_manager.has_file():
-        return CompiledKernel(fn.__name__, cache_manager.bin_path)
-    # compile all the configs
-    binaries = []
-    for config in configs:
-        binaries.append(_compile(fn, signature, device, constants, config, num_warps, num_stages, extern_libs, "cubin"))
-    # generate and compile glue code into shared object
-    with tempfile.TemporaryDirectory() as tmpdir:
-        all_constants = set(constants.keys())
-        all_constants.update(configs[0].equal_to_1)
-        src = generate_torch_glue(fn.__name__, constants, signature, num_warps, binaries, tmpdir)
-        src_path = os.path.join(tmpdir, "main.c")
-        with open(src_path, "w") as f:
-            f.write(src)
-        so = _build(fn.__name__, src_path, tmpdir)
-        with open(so, "rb") as f:
-            cache_manager.put(f.read())
+    name = fn.__name__
+    # name of files that are cached
+    so_cache_key = make_so_cache_key(signature, constants)
+    so_cache_manager = CacheManager(so_cache_key)
+    so_name = f"{name}.so"
+    # retrieve stub from cache if it exists
+    if not so_cache_manager.has_file(so_name):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            src = generate_launcher(name, constants, signature)
+            src_path = os.path.join(tmpdir, "main.c")
+            with open(src_path, "w") as f:
+                f.write(src)
+            so = _build(fn.__name__, src_path, tmpdir)
+            with open(so, "rb") as f:
+                so_cache_manager.put(f.read(), so_name, binary=True)
 
-    return CompiledKernel(fn.__name__, cache_manager.bin_path)
+    # retrieve cached shared object if it exists
+    fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
+    fn_cache_manager = CacheManager(fn_cache_key)
+    ptx_name = f"{name}.ptx"
+    cubin_name = f"{name}.cubin"
+    data_name = f"{name}.json"
+    if not fn_cache_manager.has_file(cubin_name) or \
+       not fn_cache_manager.has_file(data_name) or \
+       not fn_cache_manager.has_file(ptx_name):
+        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
"num_stages": num_stages} + fn_cache_manager.put(asm["cubin"], cubin_name) + fn_cache_manager.put(asm["ptx"], ptx_name, binary=False) + fn_cache_manager.put(json.dumps(metadata), data_name, binary=False) + + return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir) class CompiledKernel: - def __init__(self, fn_name, data_path): + def __init__(self, fn_name, so_path, cache_dir): + # initialize launcher import importlib.util - spec = importlib.util.spec_from_file_location(fn_name, data_path) + spec = importlib.util.spec_from_file_location("launcher", so_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) - self.c_wrapper = getattr(mod, fn_name) - ptx = getattr(mod, "ptx") - if len(ptx) == 1: - self.asm = {"ptx": list(ptx.values())[0]} + self.c_wrapper = getattr(mod, "launch") + # initialize metadata + with open(os.path.join(cache_dir, f"{fn_name}.json")) as f: + metadata = json.load(f) + self.shared = metadata["shared"] + self.num_warps = metadata["num_warps"] + self.num_stages = metadata["num_stages"] + # initialize asm dict + self.asm = dict() + with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f: + self.asm["cubin"] = f.read() + with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f: + self.asm["ptx"] = f.read() + + device = torch.cuda.current_device() + mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device) + self.cu_module = mod + self.cu_function = func def __getitem__(self, grid): def runner(*args, stream=None): if stream is None: stream = torch.cuda.current_stream().cuda_stream - self.c_wrapper(grid[0], grid[1], grid[2], stream, *args) + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args) return runner diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 0187a7faa..5e1bc544b 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -253,7 +253,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage try: bin = cache[key] if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args}) return bin # kernel not cached -- compile except KeyError: @@ -274,7 +274,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args) self.cache[key] = bin return bin return None