From 61c85c18b2d2b5f6761799de9ffee0a42b94c425 Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Tue, 25 Oct 2022 20:29:43 +0000
Subject: [PATCH] try to load binary

---
 python/src/triton.cc      | 41 ++++++++++++++++++++++++++++++++++++++-
 python/triton/compiler.py | 19 +++++++++++++-----
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/python/src/triton.cc b/python/src/triton.cc
index 52867f9fe..8b596cb30 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -554,7 +554,7 @@ void init_triton_codegen(py::module &&m) {
   // ---------------------------------------
   // Load provided assembly code into driver
   // ---------------------------------------
-  m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
+  m.def("load_binary_cubin", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
       py::gil_scoped_release allow_threads;
       // create driver handles
       CUfunction fun;
@@ -581,6 +581,45 @@ void init_triton_codegen(py::module &&m) {
       },
       py::return_value_policy::take_ownership
   );
+
+  // ---------------------------------------
+  // Load provided HIP module binary into driver
+  // ---------------------------------------
+  m.def("load_binary_hipmodule", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
+      std::cout << "triton.cc: load_binary_hipmodule:" << std::endl;
+      std::cout << "\tname:" << name << std::endl;
+      std::cout << "\tdata size:" << data.size() << std::endl;
+      std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
+      std::cout << "\tdevice:" << device << std::endl;
+      py::gil_scoped_release allow_threads;
+      // create driver handles
+      hipFunction_t fun;
+      hipModule_t mod;
+      drv::dispatch::hipModuleLoadData(&mod, data.c_str());
+      drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
+      // get allocated registers and spilled registers from the function
+      int n_regs = 0;
+      int n_spills = 0;
+      hipFuncAttributes attr;
+      drv::dispatch::hipFuncGetAttributes(&attr, fun);
+      // NOTE(review): a single hipFuncGetAttributes query suffices; attr is reused below
+      n_regs = attr.numRegs;
+      n_spills = attr.localSizeBytes / 4;
+      // set dynamic shared memory if necessary
+      int shared_optin;
+      drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
+      if(n_shared_bytes > 49152 && shared_optin > 49152){
+        drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
+        int shared_total, shared_static;
+        drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
+        drv::dispatch::hipFuncGetAttributes(&attr, fun);
+        shared_static = attr.sharedSizeBytes;
+        // drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
+      }
+      return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
+      },
+      py::return_value_policy::take_ownership
+  );
 
 
 struct InstanceDescriptor
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index de250ad91..e03cc7f16 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1353,16 +1353,25 @@ class CompiledKernel:
         self.num_stages = metadata["num_stages"]
         # initialize asm dict
         self.asm = dict()
-        with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
-            self.asm["cubin"] = f.read()
-        with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
-            self.asm["ptx"] = f.read()
+        if torch.version.hip is not None:
+            with open(os.path.join(cache_dir, f"{fn_name}.hipmodule"), "rb") as f:
+                self.asm["hipmodule"] = f.read()
+            with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
+                self.asm["amdgpu"] = f.read()
+        else:
+            with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
+                self.asm["cubin"] = f.read()
+            with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
+                self.asm["ptx"] = f.read()
         with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
             self.asm["llir"] = f.read()
         with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
             self.asm["ttir"] = f.read()
-        mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
+        if torch.version.hip is not None:
+            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hipmodule(metadata["name"], self.asm["hipmodule"], self.shared, device)
+        else:
+            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
         self.fn_name = fn_name
         self.cu_module = mod
         self.cu_function = func