[FRONTEND] Backport new runtime from master (#706)
This PR merges the new runtime back into the `triton-mlir` branch. It adds caching and just-in-time compilation functionality to the triton-mlir project and paves the way for reusing tests from the master branch.
@@ -544,50 +544,6 @@ void init_triton_runtime(py::module &&m) {
 /*****************************************************************************/
-typedef std::map<std::string, py::object> asm_map_t;
-
-// ---------------------------------------
-// Load provided assembly code into driver
-// ---------------------------------------
-
-// CUDA
-std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
-                                              asm_map_t &asm_map,
-                                              size_t n_shared_bytes,
-                                              uint64_t dev) {
-  // load assembly
-  std::string assembly;
-  if (asm_map.find("cubin") != asm_map.end())
-    assembly = py::cast<std::string>(asm_map["cubin"]);
-  else
-    assembly = py::cast<std::string>(asm_map["ptx"]);
-  // create driver handles
-  CUfunction fun;
-  CUmodule mod;
-  drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
-  drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-  // set dynamic shared memory if necessary
-  int shared_optin;
-  drv::dispatch::cuDeviceGetAttribute(
-      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
-      dev);
-  if (n_shared_bytes > 49152 && shared_optin > 49152) {
-    drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
-    int shared_total, shared_static;
-    int n_spills, n_reg;
-    drv::dispatch::cuDeviceGetAttribute(
-        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
-        dev);
-    drv::dispatch::cuFuncGetAttribute(&shared_static,
-                                      CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_spills,
-                                      CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
-    drv::dispatch::cuFuncSetAttribute(
-        fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-        shared_optin - shared_static);
-  }
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
-}
-
 /*****************************************************************************/
 /* Python bindings for triton::ir */
 /*****************************************************************************/
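For context on the removed helper: `cuModuleLoadData` accepts either a compiled `cubin` image or PTX text and JIT-compiles the latter, which is why `cu_load_binary` preferred the `cubin` entry of the asm map when present. Below is a minimal standalone sketch of the same load sequence, with the driver-API error checking that the `drv::dispatch` calls omit; the `check` and `load_kernel` names are illustrative, not from the PR.

```cpp
// Illustrative sketch only, not part of the commit: the removed helper's
// module-load sequence, with CUDA driver-API error checking added.
#include <cuda.h>
#include <stdexcept>
#include <string>

static void check(CUresult res) {
  if (res != CUDA_SUCCESS) {
    const char *msg = nullptr;
    cuGetErrorString(res, &msg);
    throw std::runtime_error(msg ? msg : "unknown CUDA driver error");
  }
}

// cuModuleLoadData accepts either a compiled cubin or PTX text; PTX is
// JIT-compiled by the driver, so a cubin is preferred when available.
CUfunction load_kernel(const std::string &name, const std::string &image) {
  CUmodule mod;
  CUfunction fun;
  check(cuModuleLoadData(&mod, image.c_str()));
  check(cuModuleGetFunction(&fun, mod, name.c_str()));
  return fun;
}
```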
@@ -1728,11 +1684,41 @@ void init_triton_translation(py::module &m) {
 
   m.def(
       "load_binary",
-      [](backend_t backend, const std::string &name, asm_map_t &asm_map,
-         size_t n_shared_bytes, uint64_t dev) {
+      [](const std::string &name, const std::string &data,
+         size_t n_shared_bytes, uint64_t device) {
         py::gil_scoped_release allow_threads;
-        assert(backend == CUDA); // Only CUDA is supported now.
-        return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+        // create driver handles
+        CUfunction fun;
+        CUmodule mod;
+        drv::dispatch::cuModuleLoadData(&mod, data.c_str());
+        drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
+        // get allocated registers and spilled registers from the function
+        int n_regs = 0;
+        int n_spills = 0;
+        drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
+                                          fun);
+        drv::dispatch::cuFuncGetAttribute(
+            &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+        n_spills /= 4;
+        // set dynamic shared memory if necessary
+        int shared_optin;
+        drv::dispatch::cuDeviceGetAttribute(
+            &shared_optin,
+            CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
+        if (n_shared_bytes > 49152 && shared_optin > 49152) {
+          drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
+          int shared_total, shared_static;
+          drv::dispatch::cuDeviceGetAttribute(
+              &shared_total,
+              CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
+          drv::dispatch::cuFuncGetAttribute(
+              &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
+          drv::dispatch::cuFuncSetAttribute(
+              fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+              shared_optin - shared_static);
+        }
+        return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
+                               (uint64_t)n_spills);
       },
       py::return_value_policy::take_ownership);
 }
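Two constants in these hunks are worth unpacking. 49152 bytes is the 48 KB default per-block shared memory limit on CUDA devices; kernels that need more dynamic shared memory must opt in per function, and because the opt-in cap covers static plus dynamic usage, the code subtracts the kernel's static share (`shared_optin - shared_static`). The division `n_spills /= 4` converts `CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES`, the per-thread local memory footprint in bytes, into an approximate count of spilled 32-bit registers. Below is a minimal sketch of the opt-in pattern; the `opt_in_shared_memory` helper name is illustrative, not from the PR.

```cpp
// Illustrative sketch only, not part of the commit: the shared-memory
// opt-in pattern used in both hunks of this diff.
#include <cuda.h>
#include <cstddef>

void opt_in_shared_memory(CUfunction fun, CUdevice dev, size_t n_shared_bytes) {
  // Maximum shared memory per block this device allows after opting in.
  int shared_optin = 0;
  cuDeviceGetAttribute(&shared_optin,
                       CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
                       dev);
  // 49152 bytes = the 48 KB default per-block limit; only opt in when the
  // kernel needs more and the device can actually provide more.
  if (n_shared_bytes > 49152 && shared_optin > 49152) {
    int shared_static = 0;
    cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
                       fun);
    // The opt-in cap covers static + dynamic shared memory, so subtract
    // the kernel's static usage to get the dynamic budget.
    cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                       shared_optin - shared_static);
  }
}
```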