diff --git a/python/src/triton.cc b/python/src/triton.cc
index ac2bedebf..7f9e7e752 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -418,7 +418,7 @@ typedef std::map<std::string, py::object> asm_map_t;
 // ---------------------------------------
 
 // CUDA
-std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
   // load assembly
   std::string assembly;
   if(asm_map.find("cubin") != asm_map.end())
@@ -430,24 +430,27 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t
   CUmodule mod;
   drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
   drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
+  // get allocated registers and spilled registers from the function
+  int n_regs = 0;
+  int n_spills = 0;
+  drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
+  drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+  n_spills /= 4;
   // set dynamic shared memory if necessary
   int shared_optin;
   drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
   if(n_shared_bytes > 49152 && shared_optin > 49152){
     drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
     int shared_total, shared_static;
-    int n_spills, n_reg;
     drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
     drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
     drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
   }
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
+  return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
 }
 
 // ROCM
-std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
   py::bytes _assembly = asm_map["hsaco"];
   std::string assembly = py::cast<std::string>(_assembly);
   // HSA-CO -> hipModule
@@ -456,7 +459,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
   hipFunction_t fun;
   drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
   // record asm
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
+  return std::make_tuple((uint64_t)mod, (uint64_t)fun, 0, 0);
 }
 
 // ---------------------------------------
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 01cd1b5ed..60feb1740 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -769,16 +769,18 @@ class Binary:
 
 
 class LoadedBinary:
     def __init__(self, device: int, bin: Binary):
-        module, kernel = _triton.code_gen.load_binary(bin.backend,
-                                                      bin.name,
-                                                      bin.asm,
-                                                      bin.shared_mem,
-                                                      device)
+        module, kernel, n_regs, n_spills = _triton.code_gen.load_binary(bin.backend,
+                                                                        bin.name,
+                                                                        bin.asm,
+                                                                        bin.shared_mem,
+                                                                        device)
         self.bin = bin
         self.asm = bin.asm
         self.sass = ''
         self.module = module
         self.kernel = kernel
+        self.n_regs = n_regs
+        self.n_spills = n_spills
         self.device = device
         self.shared_mem = bin.shared_mem
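
For reference, a minimal sketch of how the `n_regs` / `n_spills` fields surfaced by this change could be consumed on the Python side. The `report_kernel_resources` helper and its spill warning are illustrative additions, not part of the patch:

import warnings

def report_kernel_resources(loaded):
    # `loaded` is assumed to be a triton.code_gen.LoadedBinary instance.
    # n_regs comes from CU_FUNC_ATTRIBUTE_NUM_REGS; n_spills is
    # CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES / 4, i.e. per-thread local-memory
    # bytes converted to 32-bit register slots.
    print(f"{loaded.bin.name}: {loaded.n_regs} registers, {loaded.n_spills} spills")
    if loaded.n_spills > 0:
        warnings.warn(f"kernel {loaded.bin.name} spills registers; "
                      f"consider lowering num_warps or tile sizes")

Note that dividing CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES by 4 treats every 4-byte slot of per-thread local memory as one spilled 32-bit register, so the count is an approximation: local memory also holds stack frames and register-indexed arrays. On the ROCm path the two extra tuple entries are simply zero.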