[FRONTEND] single file caches (#306)

Co-authored-by: hauntsaninja <>
This commit is contained in:
Shantanu
2021-09-23 20:21:19 -07:00
committed by GitHub
parent e22d92c63c
commit 2066ccd87e

View File

@@ -7,7 +7,7 @@ import sys
 import textwrap
 import hashlib
 import os
-import shelve
+import pickle
 import subprocess
 import os
 from .tools.disasm import extract
@@ -598,53 +598,55 @@ class Kernel:
         const_key = tuple(constants.items())
         compute_capability = torch.cuda.get_device_capability(device)
-        key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
-        key = repr(key)
+        key = (
+            self.fn.cache_key, version_key(), compute_capability,
+            types_key, attr_key, num_warps, num_stages, meta_key, const_key
+        )
+        key_str = repr(key)
         # get cached binary
         drv_cache = self.fn.drv_cache
-        bin_mut_path = self.fn.bin_mut_path
-        bin_cache_path = self.fn.bin_cache_path
-        bin_lock_path = self.fn.bin_lock_path
-        if key not in drv_cache:
+        if key_str not in drv_cache:
+            hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest()
+            # create cache directory
+            cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+            if cache_dir and not os.path.exists(cache_dir):
+                os.makedirs(cache_dir, exist_ok=True)
+            if cache_dir:
+                bin_cache_path = os.path.join(cache_dir, hashed_key)
+                bin_lock_path = bin_cache_path + ".lock"
+            else:
+                bin_cache_path = None
+                bin_lock_path = None
             binary = None
-            if bin_lock_path:
-                assert bin_lock_path is not None
+            if bin_cache_path and os.path.exists(bin_cache_path):
                 with FileLock(bin_lock_path):
-                    dbtype = dbm.whichdb(bin_cache_path)
-                    # handle stale/corrupted cache if it exists
-                    if dbtype is not None:
-                        # some db types can create multiple files
-                        exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
-                                'dbm.dumb': ['.dir', '.dat']}[dbtype]
-                        db_paths = [bin_cache_path + ext for ext in exts]
-                        # check if the cache is corrupted
-                        is_corrupted = os.path.exists(bin_mut_path)
-                        # delete the cache if corrupted
-                        if is_corrupted:
-                            for db in db_paths:
-                                os.remove(db)
-                            os.remove(bin_mut_path)
-                    # read the cache, creating if needed
-                    with shelve.open(bin_cache_path) as db:
-                        binary = db.get(key, None)
+                    with open(bin_cache_path, 'rb') as f:
+                        binary = pickle.load(f)["binary"]
             if binary is None:
                 binary = self._compile(
                     *wargs, device=device_idx, attributes=attributes,
                     num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                     constants=constants, **meta
                 )
-                if bin_lock_path:
-                    assert bin_lock_path is not None
+                if bin_cache_path:
                     with FileLock(bin_lock_path):
-                        open(bin_mut_path, 'a').close()
-                        with shelve.open(bin_cache_path) as db:
-                            db[key] = binary
-                        os.remove(bin_mut_path)
+                        with open(bin_cache_path + ".tmp", "wb") as f:
+                            pickle.dump({"binary": binary, "key": key}, f)
+                        os.rename(bin_cache_path + ".tmp", bin_cache_path)
-            drv_cache[key] = LoadedBinary(device_idx, binary)
+            drv_cache[key_str] = LoadedBinary(device_idx, binary)
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
-        callable = drv_cache[key]
+        callable = drv_cache[key_str]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
         grid = grid(meta) if hasattr(grid, '__call__') else grid
         callable(stream, params, *grid)
@@ -730,23 +732,8 @@ def version_key():
     )
 class JITFunction:
-    def _init_cache_paths(self):
-        # fetch cache directory path
-        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
-        if not cache_dir:
-            self.bin_cache_path = None
-            self.bin_lock_path = None
-            return
-        # create cache directory
-        if not os.path.exists(cache_dir):
-            os.makedirs(cache_dir, exist_ok=True)
-        # paths for dbm file in cache_dir
-        cache_key = (self.version, self.src) + version_key()
-        cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
-        self.bin_cache_path = os.path.join(cache_dir, cache_key_str)
-        self.bin_lock_path = self.bin_cache_path + '.lock'
-        self.bin_mut_path = self.bin_cache_path + '.mutating'
+    def _set_cache_key(self):
+        self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
     def __init__(self, fn, version=None):
         # information of wrapped function
@@ -757,10 +744,7 @@ class JITFunction:
         self.src = textwrap.dedent(inspect.getsource(fn))
         # cache for callable driver objects (e.g. CUkernel)
         self.drv_cache = dict()
-        # on-disk paths for the binary cache and corresponding
-        # file-lock
-        self._init_cache_paths()
+        self._set_cache_key()
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
@@ -806,7 +790,7 @@ class JITFunction:
         self.kernel = None
         super(JITFunction, self).__setattr__(name, value)
         if name == 'src':
-            self._init_cache_paths()
+            self._set_cache_key()
     def _init_kernel(self):
         if self.kernel is None: