From 2066ccd87e85d7192f26fb8c840ebfb7b4eccc22 Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Thu, 23 Sep 2021 20:21:19 -0700 Subject: [PATCH] [FRONTEND] single file caches (#306) Co-authored-by: hauntsaninja <> --- python/triton/code_gen.py | 92 ++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 54 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 857bfd4cf..93b6d4681 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -7,7 +7,7 @@ import sys import textwrap import hashlib import os -import shelve +import pickle import subprocess import os from .tools.disasm import extract @@ -598,53 +598,55 @@ class Kernel: const_key = tuple(constants.items()) compute_capability = torch.cuda.get_device_capability(device) - key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key) - key = repr(key) + key = ( + self.fn.cache_key, version_key(), compute_capability, + types_key, attr_key, num_warps, num_stages, meta_key, const_key + ) + key_str = repr(key) + # get cached binary drv_cache = self.fn.drv_cache - bin_mut_path = self.fn.bin_mut_path - bin_cache_path = self.fn.bin_cache_path - bin_lock_path = self.fn.bin_lock_path - if key not in drv_cache: + + if key_str not in drv_cache: + hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir and not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: + bin_cache_path = None + bin_lock_path = None + binary = None - if bin_lock_path: + if bin_cache_path and os.path.exists(bin_cache_path): + assert bin_lock_path is not None with FileLock(bin_lock_path): - dbtype = dbm.whichdb(bin_cache_path) - # handle stale/corrupted cache if it exists - if dbtype is not None: - # some db types can create multiple files - exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'], - 'dbm.dumb': ['.dir', '.dat']}[dbtype] - db_paths = [bin_cache_path + ext for ext in exts] - # check if the cache is corrupted - is_corrupted = os.path.exists(bin_mut_path) - # delete the cache if corrupted - if is_corrupted: - for db in db_paths: - os.remove(db) - os.remove(bin_mut_path) - # read the cache, creating if needed - with shelve.open(bin_cache_path) as db: - binary = db.get(key, None) + with open(bin_cache_path, 'rb') as f: + binary = pickle.load(f)["binary"] if binary is None: binary = self._compile( *wargs, device=device_idx, attributes=attributes, num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache, constants=constants, **meta ) - if bin_lock_path: + if bin_cache_path: + assert bin_lock_path is not None with FileLock(bin_lock_path): - open(bin_mut_path, 'a').close() - with shelve.open(bin_cache_path) as db: - db[key] = binary - os.remove(bin_mut_path) + with open(bin_cache_path + ".tmp", "wb") as f: + pickle.dump({"binary": binary, "key": key}, f) + os.rename(bin_cache_path + ".tmp", bin_cache_path) - drv_cache[key] = LoadedBinary(device_idx, binary) + drv_cache[key_str] = LoadedBinary(device_idx, binary) # pack arguments fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)]) params = struct.pack(fmt, *args) # enqueue cached function into stream - callable = drv_cache[key] + callable = drv_cache[key_str] stream = torch.cuda.current_stream(device_idx).cuda_stream grid = grid(meta) if hasattr(grid, '__call__') else grid callable(stream, params, *grid) @@ -730,23 +732,8 @@ def version_key(): ) class JITFunction: - - def _init_cache_paths(self): - # fetch cache directory path - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if not cache_dir: - self.bin_cache_path = None - self.bin_lock_path = None - return - # create cache directory - if not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - # paths for dbm file in cache_dir - cache_key = (self.version, self.src) + version_key() - cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest() - self.bin_cache_path = os.path.join(cache_dir, cache_key_str) - self.bin_lock_path = self.bin_cache_path + '.lock' - self.bin_mut_path = self.bin_cache_path + '.mutating' + def _set_cache_key(self): + self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version) def __init__(self, fn, version=None): # information of wrapped function @@ -757,10 +744,7 @@ class JITFunction: self.src = textwrap.dedent(inspect.getsource(fn)) # cache for callable driver objects (e.g. CUkernel) self.drv_cache = dict() - - # on-disk paths for the binary cache and corresponding - # file-lock - self._init_cache_paths() + self._set_cache_key() # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ @@ -806,7 +790,7 @@ class JITFunction: self.kernel = None super(JITFunction, self).__setattr__(name, value) if name == 'src': - self._init_cache_paths() + self._set_cache_key() def _init_kernel(self): if self.kernel is None: