try to load binary

This commit is contained in:
Michael Melesse
2022-10-25 20:29:43 +00:00
parent da5c24ffcb
commit 61c85c18b2
2 changed files with 54 additions and 6 deletions

View File

@@ -554,7 +554,7 @@ void init_triton_codegen(py::module &&m) {
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
m.def("load_binary_cubin", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
py::gil_scoped_release allow_threads;
// create driver handles
CUfunction fun;
@@ -581,6 +581,45 @@ void init_triton_codegen(py::module &&m) {
},
py::return_value_policy::take_ownership
);
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary_hipmodule", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
std::cout << "triton.cc: load_binary_hipmodule:" << std::endl;
std::cout << "\tname:" << name << std::endl;
std::cout << "\tdata:" << data << std::endl;
std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
std::cout << "\tdevice:" << device << std::endl;
py::gil_scoped_release allow_threads;
// create driver handles
hipFunction_t fun;
hipModule_t mod;
drv::dispatch::hipModuleLoadData(&mod, data.c_str());
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
hipFuncAttributes attr;
drv::dispatch::hipFuncGetAttributes(&attr, fun);
drv::dispatch::hipFuncGetAttributes(&attr, fun);
n_regs = attr.numRegs;
n_spills = attr.localSizeBytes / 4;
// set dynamic shared memory if necessary
int shared_optin;
drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
if(n_shared_bytes > 49152 && shared_optin > 49152){
drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
int shared_total, shared_static;
drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
drv::dispatch::hipFuncGetAttributes(&attr, fun);
shared_total = attr.sharedSizeBytes;
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
},
py::return_value_policy::take_ownership
);
struct InstanceDescriptor

View File

@@ -1353,16 +1353,25 @@ class CompiledKernel:
self.num_stages = metadata["num_stages"]
# initialize asm dict
self.asm = dict()
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
if torch.version.hip is not None:
with open(os.path.join(cache_dir, f"{fn_name}.hipmodule"), "rb") as f:
self.asm["hipmodule"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
self.asm["amdgpu"] = f.read()
else:
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
self.asm["llir"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
self.asm["ttir"] = f.read()
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
if torch.version.hip is not None:
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hipmodule(metadata["name"], self.asm["hipmodule"], self.shared, device)
else:
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
self.fn_name = fn_name
self.cu_module = mod
self.cu_function = func