[FRONTEND] Added on-disk cache for compiled kernels (#287)

This commit is contained in:
Philippe Tillet
2021-09-18 22:48:26 -07:00
committed by GitHub
parent bd855ac13d
commit 6e5b0b4301
5 changed files with 235 additions and 81 deletions

View File

@@ -148,19 +148,26 @@ void init_triton_runtime(py::module &&m) {
/*****************************************************************************/
/* Python bindings for triton::codegen */
/*****************************************************************************/
typedef std::map<std::string, std::string> asm_map_t;
typedef std::map<std::string, py::object> asm_map_t;
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n_shared_bytes, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map, int cc, int version){
// LLVM-IR -> PTX
std::string ptx = drv::llir_to_ptx(llvm, cc, version);
asm_map["ptx"] = ptx;
// PTX -> Binary
CUmodule mod = drv::ptx_to_cumodule(ptx, cc);
// Handle to the kernel
// CUDA
std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
// load assembly
std::string assembly;
if(asm_map.find("cubin") != asm_map.end())
assembly = py::cast<std::string>(asm_map["cubin"]);
else
assembly = py::cast<std::string>(asm_map["ptx"]);
// create driver handles
CUfunction fun;
CUmodule mod;
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// Dynamic shared memory
// set dynamic shared memory if necessary
int shared_optin;
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
if(n_shared_bytes > 49152 && shared_optin > 49152){
@@ -173,16 +180,15 @@ std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
}
// record asm
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map){
// LLVM-IR -> HSA-CO
std::string path = drv::llir_to_amdgpu(llvm, "gfx908");
// ROCM
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
py::bytes _assembly = asm_map["hsaco"];
std::string assembly = py::cast<std::string>(_assembly);
// HSA-CO -> hipModule
hipModule_t mod = drv::amdgpu_to_hipmodule(path);
hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
// Handle to the kernel
hipFunction_t fun;
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
@@ -190,6 +196,63 @@ std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::M
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
// ---------------------------------------
// Compile Triton-IR to assembly
// ---------------------------------------
// CUDA
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
bool force_nc_cache, asm_map_t &asm_map){
llvm::LLVMContext ctx;
// device properties
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
size_t cc = major*10 + minor;
int version;
drv::dispatch::cuDriverGetVersion(&version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
llir.flush();
asm_map["llir"] = py::cast(tmp);
// LLVM-IR -> PTX
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
asm_map["ptx"] = py::cast(ptx);
// PTX -> Binary
std::string cubin = drv::ptx_to_cubin(ptx, cc);
if(!cubin.empty()){
py::bytes bytes(cubin);
asm_map["cubin"] = bytes;
}
return std::make_tuple(name, asm_map, n_shared_bytes);
}
// HIP
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
bool force_nc_cache, asm_map_t &asm_map){
llvm::LLVMContext ctx;
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::amd_cl_target target;
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
llir.flush();
asm_map["llir"] = py::cast(tmp);
// LLVM-IR -> HSA-CO
std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
asm_map["hsaco"] = py::cast(path);
return std::make_tuple(name, asm_map, n_shared_bytes);
}
void init_triton_codegen(py::module &&m) {
m.def(
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
@@ -198,43 +261,19 @@ void init_triton_codegen(py::module &&m) {
asm_map_t asm_map;
std::ostringstream ttir;
ir::print(ir, ttir);
asm_map["ttir"] = ttir.str();
asm_map["ttir"] = py::cast(ttir.str());
llvm::LLVMContext ctx;
if(backend == CUDA){
// device properties
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
size_t cc = major*10 + minor;
int version;
drv::dispatch::cuDriverGetVersion(&version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
llvm::raw_string_ostream llir(asm_map["llir"]);
llir << *llvm;
llir.flush();
// LLVM-IR -> Bin
uint64_t mod, fun;
std::tie(mod, fun) = cu_compile_llir(name, n_shared_bytes, &*llvm, device, asm_map, cc, version);
return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
}
if(backend == ROCM){
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::amd_cl_target target;
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
llvm::raw_string_ostream llir(asm_map["llir"]);
llir << *llvm;
llir.flush();
// LLVM-IR -> Bin
uint64_t mod, fun;
std::tie(mod, fun) = hip_compile_llir(name, &*llvm, device, asm_map);
return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
}
},
py::return_value_policy::take_ownership);
if(backend == CUDA)
return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
if(backend == ROCM)
return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
}, py::return_value_policy::take_ownership);
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
if(backend == CUDA)
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
if(backend == ROCM)
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
}, py::return_value_policy::take_ownership);
}
/*****************************************************************************/