try to load binary
This commit is contained in:
@@ -554,7 +554,7 @@ void init_triton_codegen(py::module &&m) {
|
||||
// ---------------------------------------
|
||||
// Load provided assembly code into driver
|
||||
// ---------------------------------------
|
||||
m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
|
||||
m.def("load_binary_cubin", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
|
||||
py::gil_scoped_release allow_threads;
|
||||
// create driver handles
|
||||
CUfunction fun;
|
||||
@@ -581,6 +581,45 @@ void init_triton_codegen(py::module &&m) {
|
||||
},
|
||||
py::return_value_policy::take_ownership
|
||||
);
|
||||
|
||||
// ---------------------------------------
|
||||
// Load provided assembly code into driver
|
||||
// ---------------------------------------
|
||||
m.def("load_binary_hipmodule", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
|
||||
std::cout << "triton.cc: load_binary_hipmodule:" << std::endl;
|
||||
std::cout << "\tname:" << name << std::endl;
|
||||
std::cout << "\tdata:" << data << std::endl;
|
||||
std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
|
||||
std::cout << "\tdevice:" << device << std::endl;
|
||||
py::gil_scoped_release allow_threads;
|
||||
// create driver handles
|
||||
hipFunction_t fun;
|
||||
hipModule_t mod;
|
||||
drv::dispatch::hipModuleLoadData(&mod, data.c_str());
|
||||
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
||||
// get allocated registers and spilled registers from the function
|
||||
int n_regs = 0;
|
||||
int n_spills = 0;
|
||||
hipFuncAttributes attr;
|
||||
drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
||||
drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
||||
n_regs = attr.numRegs;
|
||||
n_spills = attr.localSizeBytes / 4;
|
||||
// set dynamic shared memory if necessary
|
||||
int shared_optin;
|
||||
drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
|
||||
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
||||
drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
|
||||
int shared_total, shared_static;
|
||||
drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
|
||||
drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
||||
shared_total = attr.sharedSizeBytes;
|
||||
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
|
||||
}
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
|
||||
},
|
||||
py::return_value_policy::take_ownership
|
||||
);
|
||||
|
||||
|
||||
struct InstanceDescriptor
|
||||
|
@@ -1353,16 +1353,25 @@ class CompiledKernel:
|
||||
self.num_stages = metadata["num_stages"]
|
||||
# initialize asm dict
|
||||
self.asm = dict()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
if torch.version.hip is not None:
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.hipmodule"), "rb") as f:
|
||||
self.asm["hipmodule"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
|
||||
self.asm["amdgpu"] = f.read()
|
||||
else:
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
|
||||
self.asm["llir"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
|
||||
self.asm["ttir"] = f.read()
|
||||
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
if torch.version.hip is not None:
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hipmodule(metadata["name"], self.asm["hipmodule"], self.shared, device)
|
||||
else:
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
self.fn_name = fn_name
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
|
Reference in New Issue
Block a user