[FRONTEND] Added on-disk cache for compiled kernels (#287)
@@ -10,6 +10,7 @@ namespace driver{
 void init_llvm();
 std::string llir_to_ptx(llvm::Module* module, int cc, int version);
+std::string ptx_to_cubin(const std::string& ptx, int cc);
 CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
 std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
 hipModule_t amdgpu_to_hipmodule(const std::string& path);
@@ -157,6 +157,39 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
   return result;
 }
 
+std::string ptx_to_cubin(const std::string& ptx, int cc) {
+  std::string ptxas = "ptxas";
+  std::string version;
+  int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
+  if(!use_system_ptxas)
+    return "";
+
+  // compile ptx with ptxas
+  char _fsrc[] = "/tmp/triton_k_XXXXXX";
+  char _flog[] = "/tmp/triton_l_XXXXXX";
+  mkstemp(_fsrc);
+  mkstemp(_flog);
+  std::string fsrc = _fsrc;
+  std::string flog = _flog;
+  std::string fbin = fsrc + ".o";
+  const char* _fbin = fbin.c_str();
+  std::ofstream ofs(fsrc);
+  ofs << ptx;
+  ofs.close();
+  std::string cmd;
+  int err;
+  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+  err = system(cmd.c_str());
+  CUmodule ret;
+  std::ifstream _cubin(_fbin, std::ios::binary );
+  std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
+  _cubin.close();
+  dispatch::cuModuleLoadData(&ret, cubin.c_str());
+  unlink(_fsrc);
+  unlink(_flog);
+  unlink(_fbin);
+  return cubin;
+}
+
 CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   // JIT compile source-code
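Note on the hunk above: the new `ptx_to_cubin` only takes effect when a system `ptxas` is available; it writes the PTX to a temporary file, runs `ptxas -v --gpu-name=sm_<cc>`, reads the resulting cubin bytes back, and returns an empty string as the "no ptxas, fall back to driver JIT" signal. A rough Python sketch of the same flow, for illustration only (the helper name and the use of `subprocess`/`tempfile` are assumptions, not part of the commit):

    import os
    import subprocess
    import tempfile

    def ptx_to_cubin_sketch(ptx: str, cc: int) -> bytes:
        """Compile PTX with the system ptxas; empty bytes means 'no ptxas found'."""
        try:
            subprocess.run(["ptxas", "--version"], capture_output=True, check=True)
        except (OSError, subprocess.CalledProcessError):
            return b""
        with tempfile.TemporaryDirectory() as tmp:
            src = os.path.join(tmp, "kernel.ptx")
            out = os.path.join(tmp, "kernel.o")
            with open(src, "w") as f:
                f.write(ptx)
            # mirror the C++ helper: verbose output, explicit SM target
            subprocess.run(["ptxas", "-v", f"--gpu-name=sm_{cc}", src, "-o", out], check=True)
            with open(out, "rb") as f:
                return f.read()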
@@ -175,6 +208,8 @@ CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   mkstemp(_flog);
   std::string fsrc = _fsrc;
   std::string flog = _flog;
+  std::string fbin = fsrc + ".o";
+  const char* _fbin = fbin.c_str();
   std::ofstream ofs(fsrc);
   ofs << ptx;
   ofs.close();
@@ -183,9 +218,13 @@ CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
   err = system(cmd.c_str());
   CUmodule ret;
-  dispatch::cuModuleLoad(&ret, (fsrc + ".o").c_str());
+  std::ifstream _cubin(_fbin, std::ios::binary );
+  std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
+  _cubin.close();
+  dispatch::cuModuleLoadData(&ret, cubin.c_str());
   unlink(_fsrc);
   unlink(_flog);
+  unlink(_fbin);
   return ret;
 }
 
@@ -127,7 +127,7 @@ setup(
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
     packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
-    install_requires=["torch"],
+    install_requires=["torch", "filelock"],
     package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],
@@ -148,19 +148,26 @@ void init_triton_runtime(py::module &&m) {
 /*****************************************************************************/
 /* Python bindings for triton::codegen */
 /*****************************************************************************/
-typedef std::map<std::string, std::string> asm_map_t;
+typedef std::map<std::string, py::object> asm_map_t;
 
+// ---------------------------------------
+// Load provided assembly code into driver
+// ---------------------------------------
 
-std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n_shared_bytes, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map, int cc, int version){
-  // LLVM-IR -> PTX
-  std::string ptx = drv::llir_to_ptx(llvm, cc, version);
-  asm_map["ptx"] = ptx;
-  // PTX -> Binary
-  CUmodule mod = drv::ptx_to_cumodule(ptx, cc);
-  // Handle to the kernel
+// CUDA
+std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+  // load assembly
+  std::string assembly;
+  if(asm_map.find("cubin") != asm_map.end())
+    assembly = py::cast<std::string>(asm_map["cubin"]);
+  else
+    assembly = py::cast<std::string>(asm_map["ptx"]);
+  // create driver handles
   CUfunction fun;
+  CUmodule mod;
+  drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
   drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-  // Dynamic shared memory
+  // set dynamic shared memory if necessary
   int shared_optin;
   drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
   if(n_shared_bytes > 49152 && shared_optin > 49152){
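`cu_load_binary` consumes whatever the compile step cached: if `asm_map` carries a `cubin` entry it is loaded directly through `cuModuleLoadData`, otherwise the cached PTX is handed to the driver, which JIT-compiles it at load time. A short Python paraphrase of that choice (illustration only; `pick_assembly` is not a real API):

    def pick_assembly(asm_map: dict):
        # prefer the ptxas-produced cubin; otherwise let the driver JIT the PTX
        # (this mirrors the branch on asm_map.find("cubin") in the hunk above)
        return asm_map.get("cubin", asm_map["ptx"])

    assert pick_assembly({"ptx": "ptx-text"}) == "ptx-text"
    assert pick_assembly({"ptx": "ptx-text", "cubin": b"cubin-bytes"}) == b"cubin-bytes"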
@@ -173,16 +180,15 @@ std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n
     drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
     drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
   }
-
-  // record asm
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }
 
-std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map){
-  // LLVM-IR -> HSA-CO
-  std::string path = drv::llir_to_amdgpu(llvm, "gfx908");
+// ROCM
+std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+  py::bytes _assembly = asm_map["hsaco"];
+  std::string assembly = py::cast<std::string>(_assembly);
   // HSA-CO -> hipModule
-  hipModule_t mod = drv::amdgpu_to_hipmodule(path);
+  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
   // Handle to the kernel
   hipFunction_t fun;
   drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
@@ -190,17 +196,15 @@ std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::M
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }
 
-void init_triton_codegen(py::module &&m) {
-  m.def(
-    "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
-      std::string name = ir.get_function_list()[0]->get_name();
-      // record asm as we generate
-      asm_map_t asm_map;
-      std::ostringstream ttir;
-      ir::print(ir, ttir);
-      asm_map["ttir"] = ttir.str();
+// ---------------------------------------
+// Compile Triton-IR to assembly
+// ---------------------------------------
+
+// CUDA
+std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
+                                                        uint64_t device, int num_warps, int num_stages,
+                                                        bool force_nc_cache, asm_map_t &asm_map){
   llvm::LLVMContext ctx;
-      if(backend == CUDA){
   // device properties
   CUdevice dev = (CUdevice)device;
   size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
@@ -212,29 +216,64 @@ void init_triton_codegen(py::module &&m) {
   triton::codegen::nvidia_cu_target target(cc);
   int n_shared_bytes;
   auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
-  llvm::raw_string_ostream llir(asm_map["llir"]);
+  std::string tmp;
+  llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
   llir.flush();
-  // LLVM-IR -> Bin
-  uint64_t mod, fun;
-  std::tie(mod, fun) = cu_compile_llir(name, n_shared_bytes, &*llvm, device, asm_map, cc, version);
-  return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
+  asm_map["llir"] = py::cast(tmp);
+  // LLVM-IR -> PTX
+  std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
+  asm_map["ptx"] = py::cast(ptx);
+  // PTX -> Binary
+  std::string cubin = drv::ptx_to_cubin(ptx, cc);
+  if(!cubin.empty()){
+    py::bytes bytes(cubin);
+    asm_map["cubin"] = bytes;
   }
-  if(backend == ROCM){
+  return std::make_tuple(name, asm_map, n_shared_bytes);
+}
+
+// HIP
+std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
+                                                         uint64_t device, int num_warps, int num_stages,
+                                                         bool force_nc_cache, asm_map_t &asm_map){
+  llvm::LLVMContext ctx;
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::amd_cl_target target;
   int n_shared_bytes;
   auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
-  llvm::raw_string_ostream llir(asm_map["llir"]);
+  std::string tmp;
+  llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
   llir.flush();
-  // LLVM-IR -> Bin
-  uint64_t mod, fun;
-  std::tie(mod, fun) = hip_compile_llir(name, &*llvm, device, asm_map);
-  return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
+  asm_map["llir"] = py::cast(tmp);
+  // LLVM-IR -> HSA-CO
+  std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
+  asm_map["hsaco"] = py::cast(path);
+  return std::make_tuple(name, asm_map, n_shared_bytes);
 }
-  },
-  py::return_value_policy::take_ownership);
+
+void init_triton_codegen(py::module &&m) {
+  m.def(
+    "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
+      std::string name = ir.get_function_list()[0]->get_name();
+      // record asm as we generate
+      asm_map_t asm_map;
+      std::ostringstream ttir;
+      ir::print(ir, ttir);
+      asm_map["ttir"] = py::cast(ttir.str());
+      llvm::LLVMContext ctx;
+      if(backend == CUDA)
+        return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+      if(backend == ROCM)
+        return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+    }, py::return_value_policy::take_ownership);
+  m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+    if(backend == CUDA)
+      return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+    if(backend == ROCM)
+      return hip_load_binary(name, asm_map, n_shared_bytes, dev);
+  }, py::return_value_policy::take_ownership);
 }
 
 /*****************************************************************************/
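Compilation and loading are now two separate bindings: `compile_ttir` returns `(name, asm_map, n_shared_bytes)` containing only serializable artifacts (ttir, llir, ptx, and optionally cubin or hsaco), while `load_binary` turns those artifacts into driver handles on a concrete device. A hedged sketch of the intended call sequence, assuming the extension module is imported the way the frontend imports it (the helper name and defaults are illustrative):

    import triton._C.libtriton.triton as _triton  # as assumed from triton/code_gen.py

    def compile_then_load(backend, ir_module, device, num_warps=4, num_stages=2):
        # step 1: pure compilation; everything in `asm` is picklable and cacheable
        name, asm, shared_mem = _triton.code_gen.compile_ttir(
            backend, ir_module, device, num_warps, num_stages, False)
        # step 2: turn the cached artifacts into driver handles on this device
        module, kernel = _triton.code_gen.load_binary(backend, name, asm, shared_mem, device)
        return module, kernel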
@@ -5,6 +5,11 @@ import struct
 import sys
 import tempfile
 import textwrap
+import hashlib
+import atexit
+import os
+import shelve
+from filelock import FileLock
 
 import torch
 import triton
@@ -411,23 +416,31 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, backend, module, kernel, asm, num_warps, num_stages, force_nc_cache, shared_mem):
-        # cache ir asm
+    def __init__(self, backend, name, asm, shared_mem, num_warps):
+        self.backend = backend
+        self.name = name
         self.asm = asm
-        self.module = module
-        self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
-        self.num_stages = num_stages
-        self.force_nc_cache = force_nc_cache
-        self.sass = None
-        self.backend = backend
+
+class LoadedBinary:
+    def __init__(self, device: int, bin: Binary):
+        module, kernel = _triton.code_gen.load_binary(bin.backend,
+                                                      bin.name,
+                                                      bin.asm,
+                                                      bin.shared_mem,
+                                                      device)
+        self.bin = bin
+        self.asm = bin.asm
+        self.module = module
+        self.kernel = kernel
+        self.device = device
 
     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
-        _triton.runtime.enqueue(self.backend, stream, self.kernel,
+        _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
                                 grid_0, grid_1, grid_2,
-                                self.num_warps * 32, 1, 1,
-                                args, self.shared_mem)
+                                self.bin.num_warps * 32, 1, 1,
+                                args, self.bin.shared_mem)
 
 
 class CompilationError(Exception):
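On the Python side, `Binary` is reduced to plain data (backend, kernel name, the `asm` dict, shared-memory size, warp count), which is what makes it safe to pickle into the on-disk cache; `LoadedBinary` is the per-device wrapper that owns driver handles and is rebuilt on every process start. A minimal sketch of the round-trip property, using a stand-in dataclass rather than the real class (names and the demo path are arbitrary):

    import shelve
    from dataclasses import dataclass

    @dataclass
    class BinarySketch:
        # stand-in mirroring the fields kept by code_gen.Binary after this change
        backend: int
        name: str
        asm: dict        # e.g. {"ttir": ..., "llir": ..., "ptx": ..., "cubin": ...}
        shared_mem: int
        num_warps: int

    b = BinarySketch(backend=0, name="add_kernel", asm={"ptx": "..."}, shared_mem=0, num_warps=4)
    with shelve.open("/tmp/triton_bin_cache_demo") as db:
        db["example-key"] = b              # plain data: pickles cleanly
    with shelve.open("/tmp/triton_bin_cache_demo") as db:
        print(db["example-key"].name)      # driver handles are recreated later via LoadedBinary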
@@ -536,11 +549,11 @@ class Kernel:
             backend = _triton.runtime.backend.CUDA
         else:
             backend = _triton.runtime.backend.ROCM
-        mod, ker, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
+        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
         max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
         if shared_mem > max_shared_memory:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
-        return Binary(backend, mod, ker, asm, num_warps, num_stages, force_nc_cache, shared_mem)
+        return Binary(backend, name, asm, shared_mem, num_warps)
 
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
@@ -579,29 +592,43 @@ class Kernel:
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
-        # determine if we need to re-compile
+        # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
         attr_key = frozenset(attributes.items())
         meta_key = frozenset(meta.items())
         const_key = frozenset(constants.items())
         key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
-        cache = self.fn.cache
-        if key not in cache:
-            # compile and cache configuration if necessary
-            cache[key] = self._compile(
+        key = repr(key)
+        # get cached binary
+        drv_cache = self.fn.drv_cache
+        bin_cache_path = self.fn.bin_cache_path
+        bin_lock_path = self.fn.bin_lock_path
+        if key not in drv_cache:
+            binary = None
+            if bin_lock_path:
+                with FileLock(bin_lock_path):
+                    with shelve.open(bin_cache_path) as db:
+                        binary = db.get(key, None)
+            if binary is None:
+                binary = self._compile(
                 *wargs, device=device_idx, attributes=attributes,
                 num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                 constants=constants, **meta
             )
+                if bin_lock_path:
+                    with FileLock(bin_lock_path):
+                        with shelve.open(bin_cache_path) as db:
+                            db[key] = binary
+            drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
-        binary = cache[key]
+        callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
         grid = grid(meta) if hasattr(grid, '__call__') else grid
-        binary(stream, params, *grid)
-        return binary
+        callable(stream, params, *grid)
+        return callable
 
 
 class Launcher:
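The lookup in `Kernel.__call__` follows a read-compile-write pattern: probe the shelve database under the file lock, compile only on a miss, write the result back under the same lock, and keep a per-process `drv_cache` of `LoadedBinary` objects keyed by `repr(key)`. A stripped-down sketch of that pattern (function and argument names are placeholders, not part of the commit):

    import shelve
    from filelock import FileLock

    def get_or_compile(key, compile_fn, bin_cache_path, bin_lock_path):
        # return a cached binary for `key`, compiling and persisting it on a miss
        binary = None
        if bin_lock_path:
            with FileLock(bin_lock_path):
                with shelve.open(bin_cache_path) as db:
                    binary = db.get(key, None)
        if binary is None:
            binary = compile_fn()
            if bin_lock_path:
                with FileLock(bin_lock_path):
                    with shelve.open(bin_cache_path) as db:
                        db[key] = binary
        return binary

    # usage sketch: the second call is served from disk, the lambda is never invoked
    print(get_or_compile("k", lambda: "fake-binary", "/tmp/demo_cache", "/tmp/demo_cache.lock"))
    print(get_or_compile("k", lambda: "never-called", "/tmp/demo_cache", "/tmp/demo_cache.lock"))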
@@ -662,17 +689,59 @@ class Autotuner:
 
 
 class JITFunction:
+
+    # clear cache if the db is older than either the frontend or the backend
+    def _clear_cache(self):
+        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
+        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
+        with FileLock(self.bin_lock_path):
+            cache_mtime = os.path.getmtime(self.db_path)
+            if frontend_mtime > cache_mtime or backend_mtime > cache_mtime:
+                os.remove(self.db_path)
+
+    def _init_cache_paths(self):
+        # fetch cache directory path
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if not cache_dir:
+            self.bin_cache_path = None
+            self.db_path = None
+            self.bin_lock_path = None
+            return
+        # create cache directory
+        if not os.path.exists(cache_dir):
+            os.makedirs(cache_dir)
+        # create md5 hash of src
+        md5 = hashlib.md5()
+        md5.update(self.src.encode('utf-8'))
+        md5_hash = md5.hexdigest()
+        # load dbm file in cache_dir for md5_hash
+        self.bin_cache_path = os.path.join(cache_dir, md5_hash)
+        self.db_path = self.bin_cache_path + '.db'
+        self.bin_lock_path = self.bin_cache_path + '.lock'
+        # if bin_cache_path exists
+        if os.path.exists(self.db_path):
+            self._clear_cache()
+
     def __init__(self, fn):
+        # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
-        self.cache = dict()
-        self.kernel_decorators = []
         self.src = textwrap.dedent(inspect.getsource(fn))
+        # cache for callable driver objects (e.g. CUkernel)
+        self.drv_cache = dict()
+        # on-disk paths for the binary cache and corresponding
+        # file-lock
+        self._init_cache_paths()
+        # JITFunction can be instantiated as kernel
+        # when called with a grid using __getitem__
+        self.kernel_decorators = []
         self.kernel = None
+        # forward docs
         self.__doc__ = fn.__doc__
 
-    # we do not parse in the constructor because
+    # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
     # Some unit tests do this, for example.
     def parse(self):
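Cache location and invalidation both live in `JITFunction`: the shelve database sits under `TRITON_CACHE_DIR` (default `/tmp/triton/`), named by the md5 of the kernel source, and `_clear_cache` drops it whenever `triton/code_gen.py` or the compiled `libtriton` extension is newer than the database. For example, redirecting the cache to a per-user directory is just an environment variable (the directory below is arbitrary):

    import hashlib
    import os

    # must be set before the @triton.jit functions are defined, since the paths
    # are computed in JITFunction.__init__
    os.environ["TRITON_CACHE_DIR"] = os.path.expanduser("~/.cache/triton")

    # the on-disk names are derived from the kernel source text
    src = "def add(x, y): return x + y"   # stand-in for inspect.getsource(fn)
    stem = hashlib.md5(src.encode("utf-8")).hexdigest()
    print(stem + ".db", stem + ".lock")   # shelve database and its file lock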
@@ -699,10 +768,16 @@ class JITFunction:
             raise e
             raise CompilationError(self.src, node, e)
 
+    # - when `.src` attribute is set, cache path needs
+    # to be reinitialized
+    # - when kernel decorators change, cached kernel
+    # needs to be cleared
     def __setattr__(self, name, value):
         if name == 'kernel_decorators':
             self.kernel = None
         super(JITFunction, self).__setattr__(name, value)
+        if name == 'src':
+            self._init_cache_paths()
 
     def _init_kernel(self):
         if self.kernel is None: