[FRONTEND] Return allocated registers and spilled registers for users (#541)
This commit is contained in:
@@ -418,7 +418,7 @@ typedef std::map<std::string, py::object> asm_map_t;
|
|||||||
// ---------------------------------------
|
// ---------------------------------------
|
||||||
|
|
||||||
// CUDA
|
// CUDA
|
||||||
std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||||
// load assembly
|
// load assembly
|
||||||
std::string assembly;
|
std::string assembly;
|
||||||
if(asm_map.find("cubin") != asm_map.end())
|
if(asm_map.find("cubin") != asm_map.end())
|
||||||
@@ -430,24 +430,27 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t
|
|||||||
CUmodule mod;
|
CUmodule mod;
|
||||||
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
|
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
|
||||||
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
|
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
|
||||||
|
// query the function's allocated register count and its local-memory size in bytes; the byte size is divided by 4 below to report spills as a count of 4-byte words
|
||||||
|
int n_regs = 0;
|
||||||
|
int n_spills = 0;
|
||||||
|
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
||||||
|
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
||||||
|
n_spills /= 4;
|
||||||
// set dynamic shared memory if necessary
|
// set dynamic shared memory if necessary
|
||||||
int shared_optin;
|
int shared_optin;
|
||||||
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
|
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
|
||||||
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
||||||
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
|
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
|
||||||
int shared_total, shared_static;
|
int shared_total, shared_static;
|
||||||
int n_spills, n_reg;
|
|
||||||
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
|
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
|
||||||
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
|
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
|
||||||
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
|
||||||
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
|
||||||
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
||||||
}
|
}
|
||||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ROCM
|
// ROCM
|
||||||
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||||
py::bytes _assembly = asm_map["hsaco"];
|
py::bytes _assembly = asm_map["hsaco"];
|
||||||
std::string assembly = py::cast<std::string>(_assembly);
|
std::string assembly = py::cast<std::string>(_assembly);
|
||||||
// HSA-CO -> hipModule
|
// HSA-CO -> hipModule
|
||||||
@@ -456,7 +459,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
|
|||||||
hipFunction_t fun;
|
hipFunction_t fun;
|
||||||
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
||||||
// record asm
|
// record asm
|
||||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
return std::make_tuple((uint64_t)mod, (uint64_t)fun, 0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------
|
// ---------------------------------------
|
||||||
|
@@ -769,16 +769,18 @@ class Binary:
|
|||||||
|
|
||||||
class LoadedBinary:
|
class LoadedBinary:
|
||||||
def __init__(self, device: int, bin: Binary):
|
def __init__(self, device: int, bin: Binary):
|
||||||
module, kernel = _triton.code_gen.load_binary(bin.backend,
|
module, kernel, n_regs, n_spills = _triton.code_gen.load_binary(bin.backend,
|
||||||
bin.name,
|
bin.name,
|
||||||
bin.asm,
|
bin.asm,
|
||||||
bin.shared_mem,
|
bin.shared_mem,
|
||||||
device)
|
device)
|
||||||
self.bin = bin
|
self.bin = bin
|
||||||
self.asm = bin.asm
|
self.asm = bin.asm
|
||||||
self.sass = ''
|
self.sass = ''
|
||||||
self.module = module
|
self.module = module
|
||||||
self.kernel = kernel
|
self.kernel = kernel
|
||||||
|
self.n_regs = n_regs
|
||||||
|
self.n_spills = n_spills
|
||||||
self.device = device
|
self.device = device
|
||||||
self.shared_mem = bin.shared_mem
|
self.shared_mem = bin.shared_mem
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user