try to load binary
@@ -554,7 +554,7 @@ void init_triton_codegen(py::module &&m) {
   // ---------------------------------------
   // Load provided assembly code into driver
   // ---------------------------------------
-  m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
+  m.def("load_binary_cubin", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
       py::gil_scoped_release allow_threads;
       // create driver handles
       CUfunction fun;
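
The rename is a breaking change for anyone calling the binding by its old name; the Python hunk further down updates the one call site in this tree. A hypothetical out-of-tree caller would migrate like this (same signature, new name):

    # before
    mod, func, n_regs, n_spills = _triton.code_gen.load_binary(name, cubin, shared, device)
    # after
    mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(name, cubin, shared, device)
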
@@ -581,6 +581,45 @@ void init_triton_codegen(py::module &&m) {
     },
     py::return_value_policy::take_ownership
   );

+  // ---------------------------------------
+  // Load provided assembly code into driver
+  // ---------------------------------------
+  m.def("load_binary_hipmodule", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
+      std::cout << "triton.cc: load_binary_hipmodule:" << std::endl;
+      std::cout << "\tname:" << name << std::endl;
+      std::cout << "\tdata:" << data << std::endl;
+      std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
+      std::cout << "\tdevice:" << device << std::endl;
+      py::gil_scoped_release allow_threads;
+      // create driver handles
+      hipFunction_t fun;
+      hipModule_t mod;
+      drv::dispatch::hipModuleLoadData(&mod, data.c_str());
+      drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
+      // get allocated registers and spilled registers from the function
+      // (a single hipFuncGetAttributes call fills the whole struct)
+      int n_regs = 0;
+      int n_spills = 0;
+      hipFuncAttributes attr;
+      drv::dispatch::hipFuncGetAttributes(&attr, fun);
+      n_regs = attr.numRegs;
+      n_spills = attr.localSizeBytes / 4;
+      // set dynamic shared memory if necessary
+      int shared_optin;
+      drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
+      if(n_shared_bytes > 49152 && shared_optin > 49152){
+        drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
+        int shared_total, shared_static;
+        drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
+        drv::dispatch::hipFuncGetAttributes(&attr, fun);
+        shared_static = attr.sharedSizeBytes;
+        // drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
+      }
+      return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
+    },
+    py::return_value_policy::take_ownership
+  );
+

 struct InstanceDescriptor
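
For reference, the new binding mirrors the CUDA path: it takes the kernel name, the raw code-object bytes, the dynamic shared-memory size in bytes, and a device ordinal, and returns (module, function, n_regs, n_spills) as integers. A minimal sketch of exercising it by hand, assuming the _triton binding module used elsewhere in this tree; the file name and kernel symbol are placeholders, not values from this commit:

    import triton._C.libtriton.triton as _triton

    # placeholder: an AMD code object compiled earlier, plus its kernel symbol
    with open("kernel.hipmodule", "rb") as f:
        image = f.read()

    mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hipmodule(
        "add_kernel",  # kernel name inside the module
        image,         # raw code-object bytes
        0,             # dynamic shared memory in bytes
        0,             # device ordinal
    )
    print(f"n_regs={n_regs} n_spills={n_spills}")
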
@@ -1353,16 +1353,25 @@ class CompiledKernel:
         self.num_stages = metadata["num_stages"]
         # initialize asm dict
         self.asm = dict()
-        with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
-            self.asm["cubin"] = f.read()
-        with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
-            self.asm["ptx"] = f.read()
+        if torch.version.hip is not None:
+            with open(os.path.join(cache_dir, f"{fn_name}.hipmodule"), "rb") as f:
+                self.asm["hipmodule"] = f.read()
+            with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
+                self.asm["amdgpu"] = f.read()
+        else:
+            with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
+                self.asm["cubin"] = f.read()
+            with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
+                self.asm["ptx"] = f.read()
         with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
             self.asm["llir"] = f.read()
         with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
             self.asm["ttir"] = f.read()

-        mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
+        if torch.version.hip is not None:
+            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hipmodule(metadata["name"], self.asm["hipmodule"], self.shared, device)
+        else:
+            mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
         self.fn_name = fn_name
         self.cu_module = mod
         self.cu_function = func
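
Taken together, the two branches leave self.asm with backend-specific keys. A sketch of what a caller can expect after construction, with key names taken from the hunk above; kernel stands in for a constructed CompiledKernel:

    import torch

    if torch.version.hip is not None:
        binary = kernel.asm["hipmodule"]  # raw AMD code object (bytes)
        text = kernel.asm["amdgpu"]       # GCN assembly (str)
    else:
        binary = kernel.asm["cubin"]      # raw NVIDIA cubin (bytes)
        text = kernel.asm["ptx"]          # PTX assembly (str)

    # both backends also populate the IR dumps
    llir = kernel.asm["llir"]
    ttir = kernel.asm["ttir"]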