[FRONTEND] Return allocated registers and spilled registers for users (#541)
This commit is contained in:
@@ -418,7 +418,7 @@ typedef std::map<std::string, py::object> asm_map_t;
|
||||
// ---------------------------------------
|
||||
|
||||
// CUDA
|
||||
std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
// load assembly
|
||||
std::string assembly;
|
||||
if(asm_map.find("cubin") != asm_map.end())
|
||||
@@ -430,24 +430,27 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t
|
||||
CUmodule mod;
|
||||
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
|
||||
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
|
||||
// get allocated registers and spilled registers from the function
|
||||
int n_regs = 0;
|
||||
int n_spills = 0;
|
||||
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
||||
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
||||
n_spills /= 4;
|
||||
// set dynamic shared memory if necessary
|
||||
int shared_optin;
|
||||
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
|
||||
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
||||
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
|
||||
int shared_total, shared_static;
|
||||
int n_spills, n_reg;
|
||||
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
|
||||
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
|
||||
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
||||
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
||||
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
||||
}
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
|
||||
}
|
||||
|
||||
// ROCM
|
||||
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
py::bytes _assembly = asm_map["hsaco"];
|
||||
std::string assembly = py::cast<std::string>(_assembly);
|
||||
// HSA-CO -> hipModule
|
||||
@@ -456,7 +459,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
|
||||
hipFunction_t fun;
|
||||
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
||||
// record asm
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun, 0, 0);
|
||||
}
|
||||
|
||||
// ---------------------------------------
|
||||
|
@@ -769,16 +769,18 @@ class Binary:
|
||||
|
||||
class LoadedBinary:
|
||||
def __init__(self, device: int, bin: Binary):
|
||||
module, kernel = _triton.code_gen.load_binary(bin.backend,
|
||||
bin.name,
|
||||
bin.asm,
|
||||
bin.shared_mem,
|
||||
device)
|
||||
module, kernel, n_regs, n_spills = _triton.code_gen.load_binary(bin.backend,
|
||||
bin.name,
|
||||
bin.asm,
|
||||
bin.shared_mem,
|
||||
device)
|
||||
self.bin = bin
|
||||
self.asm = bin.asm
|
||||
self.sass = ''
|
||||
self.module = module
|
||||
self.kernel = kernel
|
||||
self.n_regs = n_regs
|
||||
self.n_spills = n_spills
|
||||
self.device = device
|
||||
self.shared_mem = bin.shared_mem
|
||||
|
||||
|
Reference in New Issue
Block a user