[FRONTEND] Return allocated registers and spilled registers for users (#541)

This commit is contained in:
Keren Zhou
2022-06-07 18:37:12 -07:00
committed by GitHub
parent 2cdc6d35c4
commit 38573d1261
2 changed files with 17 additions and 12 deletions

View File

@@ -418,7 +418,7 @@ typedef std::map<std::string, py::object> asm_map_t;
// --------------------------------------- // ---------------------------------------
// CUDA // CUDA
std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
// load assembly // load assembly
std::string assembly; std::string assembly;
if(asm_map.find("cubin") != asm_map.end()) if(asm_map.find("cubin") != asm_map.end())
@@ -430,24 +430,27 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t
CUmodule mod; CUmodule mod;
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str()); drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str()); drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
n_spills /= 4;
// set dynamic shared memory if necessary // set dynamic shared memory if necessary
int shared_optin; int shared_optin;
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev); drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
if(n_shared_bytes > 49152 && shared_optin > 49152){ if(n_shared_bytes > 49152 && shared_optin > 49152){
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED); drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static; int shared_total, shared_static;
int n_spills, n_reg;
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev); drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun); drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static); drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
} }
return std::make_tuple((uint64_t)mod, (uint64_t)fun); return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
} }
// ROCM // ROCM
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
py::bytes _assembly = asm_map["hsaco"]; py::bytes _assembly = asm_map["hsaco"];
std::string assembly = py::cast<std::string>(_assembly); std::string assembly = py::cast<std::string>(_assembly);
// HSA-CO -> hipModule // HSA-CO -> hipModule
@@ -456,7 +459,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
hipFunction_t fun; hipFunction_t fun;
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str()); drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
// record asm // record asm
return std::make_tuple((uint64_t)mod, (uint64_t)fun); return std::make_tuple((uint64_t)mod, (uint64_t)fun, 0, 0);
} }
// --------------------------------------- // ---------------------------------------

View File

@@ -769,16 +769,18 @@ class Binary:
class LoadedBinary: class LoadedBinary:
def __init__(self, device: int, bin: Binary): def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend, module, kernel, n_regs, n_spills = _triton.code_gen.load_binary(bin.backend,
bin.name, bin.name,
bin.asm, bin.asm,
bin.shared_mem, bin.shared_mem,
device) device)
self.bin = bin self.bin = bin
self.asm = bin.asm self.asm = bin.asm
self.sass = '' self.sass = ''
self.module = module self.module = module
self.kernel = kernel self.kernel = kernel
self.n_regs = n_regs
self.n_spills = n_spills
self.device = device self.device = device
self.shared_mem = bin.shared_mem self.shared_mem = bin.shared_mem