[FRONTEND] Added on-disk cache for compiled kernels (#287)
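Editor's note: this commit gives each @triton.jit function a per-source shelve database on disk (under TRITON_CACHE_DIR, default /tmp/triton/), guarded by a filelock.FileLock, so compiled kernels survive across processes. The compile-or-load pattern that the diff threads through Kernel.__call__ can be extracted as a self-contained sketch; cache_and_load and compile_fn are illustrative names of mine, not the commit's:

    import os
    import shelve
    import tempfile
    from filelock import FileLock

    def cache_and_load(db_path, lock_path, key, compile_fn):
        # Probe the on-disk cache under the file lock; on a miss, call
        # compile_fn and write its (picklable) result back, again under
        # the lock -- the same shape as Kernel.__call__ below.
        with FileLock(lock_path):
            with shelve.open(db_path) as db:
                binary = db.get(key, None)
        if binary is None:
            binary = compile_fn()
            with FileLock(lock_path):
                with shelve.open(db_path) as db:
                    db[key] = binary
        return binary

    path = os.path.join(tempfile.mkdtemp(), 'cache')
    first = cache_and_load(path, path + '.lock', 'k', lambda: b'compiled')
    again = cache_and_load(path, path + '.lock', 'k', lambda: b'recompiled')
    assert first == again == b'compiled'   # second call hits the db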
@@ -5,6 +5,11 @@ import struct
 import sys
 import tempfile
 import textwrap
+import hashlib
+import atexit
+import os
+import shelve
+from filelock import FileLock

 import torch
 import triton
@@ -411,23 +416,31 @@ class CodeGenerator(ast.NodeVisitor):


 class Binary:
-    def __init__(self, backend, module, kernel, asm, num_warps, num_stages, force_nc_cache, shared_mem):
-        # cache ir asm
+    def __init__(self, backend, name, asm, shared_mem, num_warps):
+        self.backend = backend
+        self.name = name
         self.asm = asm
-        self.module = module
-        self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
-        self.num_stages = num_stages
-        self.force_nc_cache = force_nc_cache
-        self.sass = None
-        self.backend = backend
+
+class LoadedBinary:
+    def __init__(self, device: int, bin: Binary):
+        module, kernel = _triton.code_gen.load_binary(bin.backend,
+                                                      bin.name,
+                                                      bin.asm,
+                                                      bin.shared_mem,
+                                                      device)
+        self.bin = bin
+        self.asm = bin.asm
+        self.module = module
+        self.kernel = kernel
+        self.device = device

     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
-        _triton.runtime.enqueue(self.backend, stream, self.kernel,
+        _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
                                 grid_0, grid_1, grid_2,
-                                self.num_warps * 32, 1, 1,
-                                args, self.shared_mem)
+                                self.bin.num_warps * 32, 1, 1,
+                                args, self.bin.shared_mem)


 class CompilationError(Exception):
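Editor's note on the split: shelve pickles its values, so whatever goes into the on-disk cache must hold only plain data. The old Binary carried live module/kernel driver handles, which do not pickle; the new one keeps just (backend, name, asm, shared_mem, num_warps), and LoadedBinary re-creates the per-device handles at load time via load_binary. A freestanding toy of the same split (ToyBinary and ToyLoadedBinary are mine, not the commit's):

    import pickle

    class ToyBinary:
        # plain data only, so it survives a pickle round-trip
        def __init__(self, name, asm):
            self.name = name
            self.asm = asm

    class ToyLoadedBinary:
        # rebuilds non-picklable state (a driver handle, here faked) on load
        def __init__(self, device, bin):
            self.bin = bin
            self.kernel = ('handle', device, bin.name)  # stand-in handle

    restored = pickle.loads(pickle.dumps(ToyBinary('add', b'...ptx...')))
    loaded = ToyLoadedBinary(0, restored)
    assert loaded.kernel == ('handle', 0, 'add')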
@@ -536,11 +549,11 @@ class Kernel:
             backend = _triton.runtime.backend.CUDA
         else:
             backend = _triton.runtime.backend.ROCM
-        mod, ker, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
+        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
         max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
         if shared_mem > max_shared_memory:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
-        return Binary(backend, mod, ker, asm, num_warps, num_stages, force_nc_cache, shared_mem)
+        return Binary(backend, name, asm, shared_mem, num_warps)

     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
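Editor's note: compile_ttir now returns a kernel name rather than live module/kernel handles, which is what makes the returned Binary picklable. The shared-memory guard is unchanged in spirit; for intuition, a freestanding sketch with made-up numbers (OutOfResourcesToy is hypothetical, and the real limit comes from _triton.runtime.max_shared_memory):

    class OutOfResourcesToy(Exception):
        def __init__(self, required, limit, name):
            super().__init__(f"out of resources: {name} "
                             f"(required {required}, limit {limit})")

    max_shared_memory = 48 * 1024   # made-up stand-in for the driver query
    shared_mem = 64 * 1024          # what the compiled kernel would need
    try:
        if shared_mem > max_shared_memory:
            raise OutOfResourcesToy(shared_mem, max_shared_memory, "shared memory")
    except OutOfResourcesToy as e:
        print(e)  # out of resources: shared memory (required 65536, limit 49152)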
@@ -579,29 +592,43 @@ class Kernel:
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
-        # determine if we need to re-compile
+        # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
         attr_key = frozenset(attributes.items())
         meta_key = frozenset(meta.items())
         const_key = frozenset(constants.items())
         key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
-        cache = self.fn.cache
-        if key not in cache:
-            # compile and cache configuration if necessary
-            cache[key] = self._compile(
-                *wargs, device=device_idx, attributes=attributes,
-                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
-                constants=constants, **meta
-            )
+        key = repr(key)
+        # get cached binary
+        drv_cache = self.fn.drv_cache
+        bin_cache_path = self.fn.bin_cache_path
+        bin_lock_path = self.fn.bin_lock_path
+        if key not in drv_cache:
+            binary = None
+            if bin_lock_path:
+                with FileLock(bin_lock_path):
+                    with shelve.open(bin_cache_path) as db:
+                        binary = db.get(key, None)
+            if binary is None:
+                binary = self._compile(
+                    *wargs, device=device_idx, attributes=attributes,
+                    num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                    constants=constants, **meta
+                )
+                if bin_lock_path:
+                    with FileLock(bin_lock_path):
+                        with shelve.open(bin_cache_path) as db:
+                            db[key] = binary
+            drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
-        binary = cache[key]
+        callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
         grid = grid(meta) if hasattr(grid, '__call__') else grid
-        binary(stream, params, *grid)
-        return binary
+        callable(stream, params, *grid)
+        return callable


 class Launcher:
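Editor's note: the only change to the key itself is the final repr() call, because dbm-backed shelve databases require str keys while the natural key here is a tuple of frozensets. A freestanding illustration with made-up values (one caveat worth noting: frozenset repr order follows element hashes, so with Python's string-hash randomization the string form of the same logical key can differ between processes):

    attributes = {2: 16}          # made-up: argument 2 divisible by 16
    meta = {'BLOCK': 1024}        # made-up meta-parameter
    constants = {}
    key = ('cuda', 0, ('f32*', 'f32*', 'i32'), frozenset(attributes.items()),
           4, 2, frozenset(meta.items()), frozenset(constants.items()))
    key = repr(key)               # shelve/dbm keys must be strings
    assert isinstance(key, str)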
@@ -662,17 +689,59 @@ class Autotuner:


 class JITFunction:
+
+    # clear cache if the db is older than either the frontend or the backend
+    def _clear_cache(self):
+        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
+        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
+        with FileLock(self.bin_lock_path):
+            cache_mtime = os.path.getmtime(self.db_path)
+            if frontend_mtime > cache_mtime or backend_mtime > cache_mtime:
+                os.remove(self.db_path)
+
+    def _init_cache_paths(self):
+        # fetch cache directory path
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if not cache_dir:
+            self.bin_cache_path = None
+            self.db_path = None
+            self.bin_lock_path = None
+            return
+        # create cache directory
+        if not os.path.exists(cache_dir):
+            os.makedirs(cache_dir)
+        # create md5 hash of src
+        md5 = hashlib.md5()
+        md5.update(self.src.encode('utf-8'))
+        md5_hash = md5.hexdigest()
+        # load dbm file in cache_dir for md5_hash
+        self.bin_cache_path = os.path.join(cache_dir, md5_hash)
+        self.db_path = self.bin_cache_path + '.db'
+        self.bin_lock_path = self.bin_cache_path + '.lock'
+        # if bin_cache_path exists
+        if os.path.exists(self.db_path):
+            self._clear_cache()
+
     def __init__(self, fn):
+        # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
-        self.cache = dict()
+        self.src = textwrap.dedent(inspect.getsource(fn))
+        # cache for callable driver objects (e.g. CUkernel)
+        self.drv_cache = dict()
+        # on-disk paths for the binary cache and corresponding
+        # file-lock
+        self._init_cache_paths()
+        # JITFunction can be instantiated as kernel
+        # when called with a grid using __getitem__
         self.kernel_decorators = []
-        self.src = textwrap.dedent(inspect.getsource(fn))
         self.kernel = None
+        # forward docs
         self.__doc__ = fn.__doc__

-    # we do not parse in the constructor because
+    # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
     # Some unit tests do this, for example.
     def parse(self):
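Editor's note on the path layout: each JITFunction hashes its dedented source with MD5 and derives three sibling paths from the digest. A self-contained snippet mirroring _init_cache_paths above (the src value is illustrative):

    import hashlib
    import os

    src = 'def kernel(X, **meta): ...'                  # illustrative source
    cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
    md5_hash = hashlib.md5(src.encode('utf-8')).hexdigest()
    bin_cache_path = os.path.join(cache_dir, md5_hash)  # shelve base path
    db_path = bin_cache_path + '.db'                    # dbm file on disk
    bin_lock_path = bin_cache_path + '.lock'            # FileLock guard
    print(db_path, bin_lock_path)

The companion _clear_cache drops the db whenever the Triton frontend or the backend shared object is newer than it, so stale binaries are never loaded after an upgrade.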
@@ -699,10 +768,16 @@ class JITFunction:
                 raise e
             raise CompilationError(self.src, node, e)

+    # - when `.src` attribute is set, cache path needs
+    #   to be reinitialized
+    # - when kernel decorators change, cached kernel
+    #   needs to be cleared
     def __setattr__(self, name, value):
         if name == 'kernel_decorators':
             self.kernel = None
         super(JITFunction, self).__setattr__(name, value)
+        if name == 'src':
+            self._init_cache_paths()

     def _init_kernel(self):
         if self.kernel is None: