[FRONTEND] Backport new runtime from master (#706)

This PR merges the new runtime back into the `triton-mlir` branch. This
adds caching and just-in-time compilation functionality to the
triton-mlir project, and paves the way for re-using tests from the
master branch.
Philippe Tillet committed 2022-09-23 16:09:43 -07:00 (committed by GitHub)
parent ecd1bc33df, commit 22ec22c257
13 changed files with 790 additions and 419 deletions

@@ -544,50 +544,6 @@ void init_triton_runtime(py::module &&m) {
 /*****************************************************************************/
 typedef std::map<std::string, py::object> asm_map_t;
-// ---------------------------------------
-// Load provided assembly code into driver
-// ---------------------------------------
-// CUDA
-std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
-                                              asm_map_t &asm_map,
-                                              size_t n_shared_bytes,
-                                              uint64_t dev) {
-  // load assembly
-  std::string assembly;
-  if (asm_map.find("cubin") != asm_map.end())
-    assembly = py::cast<std::string>(asm_map["cubin"]);
-  else
-    assembly = py::cast<std::string>(asm_map["ptx"]);
-  // create driver handles
-  CUfunction fun;
-  CUmodule mod;
-  drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
-  drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-  // set dynamic shared memory if necessary
-  int shared_optin;
-  drv::dispatch::cuDeviceGetAttribute(
-      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
-      dev);
-  if (n_shared_bytes > 49152 && shared_optin > 49152) {
-    drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
-    int shared_total, shared_static;
-    int n_spills, n_reg;
-    drv::dispatch::cuDeviceGetAttribute(
-        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
-        dev);
-    drv::dispatch::cuFuncGetAttribute(&shared_static,
-                                      CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_spills,
-                                      CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
-    drv::dispatch::cuFuncSetAttribute(
-        fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-        shared_optin - shared_static);
-  }
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
-}
 /*****************************************************************************/
 /* Python bindings for triton::ir */
 /*****************************************************************************/
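
For context: the deleted cu_load_binary above is the stock CUDA driver loading sequence. A kernel that wants more than the default 48 KB (49152 bytes) of shared memory per block must opt in through CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, capped by what the device advertises. Below is a minimal compilable sketch of the same sequence against the raw driver API, outside drv::dispatch; the load_cubin name and the CU_CHECK macro are illustrative and not part of this commit.

#include <cuda.h>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <tuple>

// Illustrative only: abort on any driver-API failure.
#define CU_CHECK(call)                                                         \
  do {                                                                         \
    CUresult err_ = (call);                                                    \
    if (err_ != CUDA_SUCCESS) {                                                \
      std::fprintf(stderr, "CUDA driver error %d\n", (int)err_);               \
      std::exit(1);                                                            \
    }                                                                          \
  } while (0)

std::tuple<CUmodule, CUfunction> load_cubin(const std::string &name,
                                            const std::string &cubin,
                                            size_t n_shared_bytes,
                                            CUdevice dev) {
  // Load the compiled image and look up the kernel symbol.
  CUmodule mod;
  CUfunction fun;
  CU_CHECK(cuModuleLoadData(&mod, cubin.c_str()));
  CU_CHECK(cuModuleGetFunction(&fun, mod, name.c_str()));
  // 49152 B = 48 KB, the default per-block shared-memory limit; larger
  // requests need the opt-in maximum the device advertises.
  int shared_optin = 0;
  CU_CHECK(cuDeviceGetAttribute(
      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      dev));
  if (n_shared_bytes > 49152 && shared_optin > 49152) {
    CU_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
    int shared_static = 0;
    CU_CHECK(cuFuncGetAttribute(&shared_static,
                                CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
    // Dynamic budget = opt-in maximum minus the kernel's static usage.
    CU_CHECK(cuFuncSetAttribute(
        fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        shared_optin - shared_static));
  }
  return {mod, fun};
}
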
@@ -1728,11 +1684,41 @@ void init_triton_translation(py::module &m) {
   m.def(
       "load_binary",
-      [](backend_t backend, const std::string &name, asm_map_t &asm_map,
-         size_t n_shared_bytes, uint64_t dev) {
+      [](const std::string &name, const std::string &data,
+         size_t n_shared_bytes, uint64_t device) {
         py::gil_scoped_release allow_threads;
-        assert(backend == CUDA); // Only CUDA is supported now.
-        return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+        // create driver handles
+        CUfunction fun;
+        CUmodule mod;
+        drv::dispatch::cuModuleLoadData(&mod, data.c_str());
+        drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
+        // get allocated registers and spilled registers from the function
+        int n_regs = 0;
+        int n_spills = 0;
+        drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
+                                          fun);
+        drv::dispatch::cuFuncGetAttribute(
+            &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+        n_spills /= 4;
+        // set dynamic shared memory if necessary
+        int shared_optin;
+        drv::dispatch::cuDeviceGetAttribute(
+            &shared_optin,
+            CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
+        if (n_shared_bytes > 49152 && shared_optin > 49152) {
+          drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
+          int shared_total, shared_static;
+          drv::dispatch::cuDeviceGetAttribute(
+              &shared_total,
+              CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
+          drv::dispatch::cuFuncGetAttribute(
+              &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
+          drv::dispatch::cuFuncSetAttribute(
+              fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+              shared_optin - shared_static);
+        }
+        return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
+                               (uint64_t)n_spills);
       },
       py::return_value_policy::take_ownership);
 }
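
For context: CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES reports per-thread local memory in bytes, so n_spills /= 4 converts it into 32-bit words, i.e. an approximate count of spilled registers for the runtime to surface. The handles come back as plain uint64_t, which keeps the pybind11 surface free of driver types. Below is a hedged sketch of how a host-side caller might unpack the tuple and launch; the launch wrapper and its parameters are assumptions, not from this commit.

#include <cuda.h>
#include <cstdint>
#include <tuple>

// Hypothetical consumer of load_binary's return value: unpack the opaque
// handles and launch. n_regs / n_spills are informational (e.g. for
// occupancy or autotuning heuristics); only the function handle is needed
// to launch.
void launch(const std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> &bin,
            unsigned grid_x, unsigned block_x, unsigned n_shared_bytes,
            CUstream stream, void **params) {
  CUfunction fun = reinterpret_cast<CUfunction>(std::get<1>(bin));
  cuLaunchKernel(fun, grid_x, 1, 1,  // grid dimensions
                 block_x, 1, 1,      // block dimensions
                 n_shared_bytes,     // dynamic shared memory per block
                 stream, params, /*extra=*/nullptr);
}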