[FRONTEND] Added on-disk cache for compiled kernels (#287)

Author: Philippe Tillet
Date: 2021-09-18 22:48:26 -07:00 (committed by GitHub)
Parent: bd855ac13d
Commit: 6e5b0b4301
5 changed files with 235 additions and 81 deletions


@@ -10,6 +10,7 @@ namespace driver{
 void init_llvm();
 std::string llir_to_ptx(llvm::Module* module, int cc, int version);
+std::string ptx_to_cubin(const std::string& ptx, int cc);
 CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
 std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
 hipModule_t amdgpu_to_hipmodule(const std::string& path);


@@ -157,6 +157,39 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
   return result;
 }
 
+std::string ptx_to_cubin(const std::string& ptx, int cc) {
+  std::string ptxas = "ptxas";
+  std::string version;
+  int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
+  if(!use_system_ptxas)
+    return "";
+  // compile ptx with ptxas
+  char _fsrc[] = "/tmp/triton_k_XXXXXX";
+  char _flog[] = "/tmp/triton_l_XXXXXX";
+  mkstemp(_fsrc);
+  mkstemp(_flog);
+  std::string fsrc = _fsrc;
+  std::string flog = _flog;
+  std::string fbin = fsrc + ".o";
+  const char* _fbin = fbin.c_str();
+  std::ofstream ofs(fsrc);
+  ofs << ptx;
+  ofs.close();
+  std::string cmd;
+  int err;
+  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+  err = system(cmd.c_str());
+  CUmodule ret;
+  std::ifstream _cubin(_fbin, std::ios::binary);
+  std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
+  _cubin.close();
+  dispatch::cuModuleLoadData(&ret, cubin.c_str());
+  unlink(_fsrc);
+  unlink(_flog);
+  unlink(_fbin);
+  return cubin;
+}
+
 CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   // JIT compile source-code
@@ -175,6 +208,8 @@ CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   mkstemp(_flog);
   std::string fsrc = _fsrc;
   std::string flog = _flog;
+  std::string fbin = fsrc + ".o";
+  const char* _fbin = fbin.c_str();
   std::ofstream ofs(fsrc);
   ofs << ptx;
   ofs.close();
@@ -183,9 +218,13 @@ CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
   cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
   err = system(cmd.c_str());
   CUmodule ret;
-  dispatch::cuModuleLoad(&ret, (fsrc + ".o").c_str());
+  std::ifstream _cubin(_fbin, std::ios::binary);
+  std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
+  _cubin.close();
+  dispatch::cuModuleLoadData(&ret, cubin.c_str());
   unlink(_fsrc);
   unlink(_flog);
+  unlink(_fbin);
   return ret;
 }
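The new ptx_to_cubin helper shells out to the system ptxas, reads the generated cubin back into a string, and returns an empty string when ptxas is not installed, so callers can still fall back to shipping PTX. The following is a minimal Python sketch of the same round-trip, not code from this commit; the function name and the use of shutil/subprocess/tempfile are assumptions made for illustration.

import os
import shutil
import subprocess
import tempfile

def ptx_to_cubin_sketch(ptx: str, cc: int) -> bytes:
    # mirror the C++ fallback: return an empty result when no system ptxas exists
    if shutil.which("ptxas") is None:
        return b""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "kernel.ptx")
        out = os.path.join(tmp, "kernel.o")
        with open(src, "w") as f:
            f.write(ptx)
        # same invocation as the C++ code: ptxas -v --gpu-name=sm_<cc> <src> -o <out>
        subprocess.run(["ptxas", "-v", f"--gpu-name=sm_{cc}", src, "-o", out], check=True)
        with open(out, "rb") as f:
            return f.read()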


@@ -127,7 +127,7 @@ setup(
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
     packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
-    install_requires=["torch"],
+    install_requires=["torch", "filelock"],
     package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],


@@ -148,19 +148,26 @@ void init_triton_runtime(py::module &&m) {
 /*****************************************************************************/
 /* Python bindings for triton::codegen                                       */
 /*****************************************************************************/
-typedef std::map<std::string, std::string> asm_map_t;
+typedef std::map<std::string, py::object> asm_map_t;
 
-std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n_shared_bytes, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map, int cc, int version){
-  // LLVM-IR -> PTX
-  std::string ptx = drv::llir_to_ptx(llvm, cc, version);
-  asm_map["ptx"] = ptx;
-  // PTX -> Binary
-  CUmodule mod = drv::ptx_to_cumodule(ptx, cc);
-  // Handle to the kernel
+// ---------------------------------------
+// Load provided assembly code into driver
+// ---------------------------------------
+
+// CUDA
+std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+  // load assembly
+  std::string assembly;
+  if(asm_map.find("cubin") != asm_map.end())
+    assembly = py::cast<std::string>(asm_map["cubin"]);
+  else
+    assembly = py::cast<std::string>(asm_map["ptx"]);
+  // create driver handles
   CUfunction fun;
+  CUmodule mod;
+  drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
   drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-  // Dynamic shared memory
+  // set dynamic shared memory if necessary
   int shared_optin;
   drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
   if(n_shared_bytes > 49152 && shared_optin > 49152){
@@ -173,16 +180,15 @@ std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n
     drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
     drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
   }
-  // record asm
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }
 
-std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map){
-  // LLVM-IR -> HSA-CO
-  std::string path = drv::llir_to_amdgpu(llvm, "gfx908");
+// ROCM
+std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+  py::bytes _assembly = asm_map["hsaco"];
+  std::string assembly = py::cast<std::string>(_assembly);
   // HSA-CO -> hipModule
-  hipModule_t mod = drv::amdgpu_to_hipmodule(path);
+  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
   // Handle to the kernel
   hipFunction_t fun;
   drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
@@ -190,17 +196,15 @@ std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::M
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }
 
-void init_triton_codegen(py::module &&m) {
-  m.def(
-    "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
-      std::string name = ir.get_function_list()[0]->get_name();
-      // record asm as we generate
-      asm_map_t asm_map;
-      std::ostringstream ttir;
-      ir::print(ir, ttir);
-      asm_map["ttir"] = ttir.str();
+// ---------------------------------------
+// Compile Triton-IR to assembly
+// ---------------------------------------
+
+// CUDA
+std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
+                                                        uint64_t device, int num_warps, int num_stages,
+                                                        bool force_nc_cache, asm_map_t &asm_map){
   llvm::LLVMContext ctx;
-      if(backend == CUDA){
   // device properties
   CUdevice dev = (CUdevice)device;
   size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
@@ -212,29 +216,64 @@ void init_triton_codegen(py::module &&m) {
   triton::codegen::nvidia_cu_target target(cc);
   int n_shared_bytes;
   auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
-  llvm::raw_string_ostream llir(asm_map["llir"]);
+  std::string tmp;
+  llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
   llir.flush();
-  // LLVM-IR -> Bin
-  uint64_t mod, fun;
-  std::tie(mod, fun) = cu_compile_llir(name, n_shared_bytes, &*llvm, device, asm_map, cc, version);
-  return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
-  }
-  if(backend == ROCM){
+  asm_map["llir"] = py::cast(tmp);
+  // LLVM-IR -> PTX
+  std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
+  asm_map["ptx"] = py::cast(ptx);
+  // PTX -> Binary
+  std::string cubin = drv::ptx_to_cubin(ptx, cc);
+  if(!cubin.empty()){
+    py::bytes bytes(cubin);
+    asm_map["cubin"] = bytes;
+  }
+  return std::make_tuple(name, asm_map, n_shared_bytes);
+}
+
+// HIP
+std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
+                                                         uint64_t device, int num_warps, int num_stages,
+                                                         bool force_nc_cache, asm_map_t &asm_map){
+  llvm::LLVMContext ctx;
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::amd_cl_target target;
   int n_shared_bytes;
   auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
-  llvm::raw_string_ostream llir(asm_map["llir"]);
+  std::string tmp;
+  llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
   llir.flush();
-  // LLVM-IR -> Bin
-  uint64_t mod, fun;
-  std::tie(mod, fun) = hip_compile_llir(name, &*llvm, device, asm_map);
-  return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
-  }
-  },
-  py::return_value_policy::take_ownership);
+  asm_map["llir"] = py::cast(tmp);
+  // LLVM-IR -> HSA-CO
+  std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
+  asm_map["hsaco"] = py::cast(path);
+  return std::make_tuple(name, asm_map, n_shared_bytes);
+}
+
+void init_triton_codegen(py::module &&m) {
+  m.def(
+    "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
+      std::string name = ir.get_function_list()[0]->get_name();
+      // record asm as we generate
+      asm_map_t asm_map;
+      std::ostringstream ttir;
+      ir::print(ir, ttir);
+      asm_map["ttir"] = py::cast(ttir.str());
+      llvm::LLVMContext ctx;
+      if(backend == CUDA)
+        return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+      if(backend == ROCM)
+        return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+    }, py::return_value_policy::take_ownership);
+  m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+    if(backend == CUDA)
+      return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+    if(backend == ROCM)
+      return hip_load_binary(name, asm_map, n_shared_bytes, dev);
+  }, py::return_value_policy::take_ownership);
 }
 
 /*****************************************************************************/
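After this refactor, compile_ttir returns (kernel_name, asm_map, n_shared_bytes), where asm_map maps stage names ("ttir", "llir", "ptx", plus "cubin" when a system ptxas is available, or "hsaco" on ROCm) to Python objects, and the new load_binary binding turns that stored assembly back into driver handles. A minimal sketch of how a caller chains the two, assuming _triton is the compiled libtriton extension as imported by code_gen.py; the helper name is made up for illustration.

def compile_then_load(backend, ir_module, device):
    # one-time, cacheable step: Triton-IR -> {"ttir", "llir", "ptx", maybe "cubin"/"hsaco"}
    # positional arguments: num_warps=4, num_stages=2, force_nc_cache=False
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, ir_module, device, 4, 2, False)
    # per-process step: turn the stored assembly back into module/kernel handles
    module, kernel = _triton.code_gen.load_binary(backend, name, asm, shared_mem, device)
    return module, kernel, asm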


@@ -5,6 +5,11 @@ import struct
 import sys
 import tempfile
 import textwrap
+import hashlib
+import atexit
+import os
+import shelve
+from filelock import FileLock
 
 import torch
 import triton
@@ -411,23 +416,31 @@ class CodeGenerator(ast.NodeVisitor):
 class Binary:
-    def __init__(self, backend, module, kernel, asm, num_warps, num_stages, force_nc_cache, shared_mem):
-        # cache ir asm
+    def __init__(self, backend, name, asm, shared_mem, num_warps):
+        self.backend = backend
+        self.name = name
         self.asm = asm
-        self.module = module
-        self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
-        self.num_stages = num_stages
-        self.force_nc_cache = force_nc_cache
-        self.sass = None
-        self.backend = backend
+
+
+class LoadedBinary:
+    def __init__(self, device: int, bin: Binary):
+        module, kernel = _triton.code_gen.load_binary(bin.backend,
+                                                      bin.name,
+                                                      bin.asm,
+                                                      bin.shared_mem,
+                                                      device)
+        self.bin = bin
+        self.asm = bin.asm
+        self.module = module
+        self.kernel = kernel
+        self.device = device
 
     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
-        _triton.runtime.enqueue(self.backend, stream, self.kernel,
+        _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
                                 grid_0, grid_1, grid_2,
-                                self.num_warps * 32, 1, 1,
-                                args, self.shared_mem)
+                                self.bin.num_warps * 32, 1, 1,
+                                args, self.bin.shared_mem)
 
 
 class CompilationError(Exception):
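The split above is what makes on-disk caching possible: Binary now carries only plain data (backend, kernel name, the asm dict, shared-memory size, warp count) and no driver handles, so it can be pickled into a shelve database, while LoadedBinary materializes the device handles at load time. A small illustration of that property, assuming LoadedBinary is imported from triton.code_gen and binary is an instance produced by Kernel._compile; the function and path names below are made up for the example.

import shelve

def store_and_reload(binary, db_path, key, device):
    # Binary holds no device handles, so shelve can pickle it transparently
    with shelve.open(db_path) as db:
        db[key] = binary
    with shelve.open(db_path) as db:
        restored = db[key]
    # CUmodule/CUfunction handles only come into existence here
    return LoadedBinary(device, restored)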
@@ -536,11 +549,11 @@ class Kernel:
             backend = _triton.runtime.backend.CUDA
         else:
             backend = _triton.runtime.backend.ROCM
-        mod, ker, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
+        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
         max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
         if shared_mem > max_shared_memory:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
-        return Binary(backend, mod, ker, asm, num_warps, num_stages, force_nc_cache, shared_mem)
+        return Binary(backend, name, asm, shared_mem, num_warps)
 
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
@@ -579,29 +592,43 @@ class Kernel:
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
-        # determine if we need to re-compile
+        # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
         attr_key = frozenset(attributes.items())
         meta_key = frozenset(meta.items())
         const_key = frozenset(constants.items())
         key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
-        cache = self.fn.cache
-        if key not in cache:
-            # compile and cache configuration if necessary
-            cache[key] = self._compile(
+        key = repr(key)
+        # get cached binary
+        drv_cache = self.fn.drv_cache
+        bin_cache_path = self.fn.bin_cache_path
+        bin_lock_path = self.fn.bin_lock_path
+        if key not in drv_cache:
+            binary = None
+            if bin_lock_path:
+                with FileLock(bin_lock_path):
+                    with shelve.open(bin_cache_path) as db:
+                        binary = db.get(key, None)
+            if binary is None:
+                binary = self._compile(
                     *wargs, device=device_idx, attributes=attributes,
                     num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                     constants=constants, **meta
                 )
+                if bin_lock_path:
+                    with FileLock(bin_lock_path):
+                        with shelve.open(bin_cache_path) as db:
+                            db[key] = binary
+            drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
-        binary = cache[key]
+        callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
         grid = grid(meta) if hasattr(grid, '__call__') else grid
-        binary(stream, params, *grid)
-        return binary
+        callable(stream, params, *grid)
+        return callable
 
 
 class Launcher:
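The lookup above follows a read-then-compile-then-write pattern: take the file lock, probe the per-kernel shelve database for the key, compile only on a miss, write the fresh Binary back under the same lock, and keep a LoadedBinary in the in-memory drv_cache. A condensed, self-contained sketch of that pattern under the same assumptions; the names cached_compile and compile_fn are made up for illustration.

import shelve
from filelock import FileLock

def cached_compile(db_base, key, compile_fn):
    # read phase: probe the on-disk cache under the lock
    with FileLock(db_base + ".lock"):
        with shelve.open(db_base) as db:
            binary = db.get(key, None)
    if binary is None:
        # miss: compile outside the lock, then publish the result
        binary = compile_fn()
        with FileLock(db_base + ".lock"):
            with shelve.open(db_base) as db:
                db[key] = binary
    return binary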
@@ -662,17 +689,59 @@ class Autotuner:
 class JITFunction:
+
+    # clear cache if the db is older than either the frontend or the backend
+    def _clear_cache(self):
+        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
+        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
+        with FileLock(self.bin_lock_path):
+            cache_mtime = os.path.getmtime(self.db_path)
+            if frontend_mtime > cache_mtime or backend_mtime > cache_mtime:
+                os.remove(self.db_path)
+
+    def _init_cache_paths(self):
+        # fetch cache directory path
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if not cache_dir:
+            self.bin_cache_path = None
+            self.db_path = None
+            self.bin_lock_path = None
+            return
+        # create cache directory
+        if not os.path.exists(cache_dir):
+            os.makedirs(cache_dir)
+        # create md5 hash of src
+        md5 = hashlib.md5()
+        md5.update(self.src.encode('utf-8'))
+        md5_hash = md5.hexdigest()
+        # load dbm file in cache_dir for md5_hash
+        self.bin_cache_path = os.path.join(cache_dir, md5_hash)
+        self.db_path = self.bin_cache_path + '.db'
+        self.bin_lock_path = self.bin_cache_path + '.lock'
+        # if bin_cache_path exists
+        if os.path.exists(self.db_path):
+            self._clear_cache()
 
     def __init__(self, fn):
+        # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
-        self.cache = dict()
-        self.kernel_decorators = []
         self.src = textwrap.dedent(inspect.getsource(fn))
+        # cache for callable driver objects (e.g. CUkernel)
+        self.drv_cache = dict()
+        # on-disk paths for the binary cache and corresponding
+        # file-lock
+        self._init_cache_paths()
+        # JITFunction can be instantiated as kernel
+        # when called with a grid using __getitem__
+        self.kernel_decorators = []
         self.kernel = None
+        # forward docs
         self.__doc__ = fn.__doc__
 
-        # we do not parse in the constructor because
+        # we do not parse `src` in the constructor because
         # the user might want to monkey-patch self.src dynamically.
         # Some unit tests do this, for example.
     def parse(self):
@@ -699,10 +768,16 @@ class JITFunction:
                 raise e
             raise CompilationError(self.src, node, e)
 
+    # - when `.src` attribute is set, cache path needs
+    #   to be reinitialized
+    # - when kernel decorators change, cached kernel
+    #   needs to be cleared
     def __setattr__(self, name, value):
         if name == 'kernel_decorators':
             self.kernel = None
         super(JITFunction, self).__setattr__(name, value)
+        if name == 'src':
+            self._init_cache_paths()
 
     def _init_kernel(self):
         if self.kernel is None:
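For reference, the cache lives under TRITON_CACHE_DIR (default /tmp/triton/), with one shelve database per kernel named after the MD5 of its source, and _clear_cache drops the database whenever code_gen.py or the compiled libtriton is newer than the cache file. A self-contained sketch of that staleness rule, with db_path, frontend_file, and backend_file as placeholders for the paths used above.

import os

def is_stale(db_path: str, frontend_file: str, backend_file: str) -> bool:
    # cached binaries are only trusted if the cache file is newer than both
    # the Python frontend and the compiled backend that produced them
    cache_mtime = os.path.getmtime(db_path)
    return (os.path.getmtime(frontend_file) > cache_mtime
            or os.path.getmtime(backend_file) > cache_mtime)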