diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 598dd04a9..7d9c66542 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -602,12 +602,34 @@ class Kernel: key = repr(key) # get cached binary drv_cache = self.fn.drv_cache + bin_mut_path = self.fn.bin_mut_path bin_cache_path = self.fn.bin_cache_path bin_lock_path = self.fn.bin_lock_path if key not in drv_cache: binary = None if bin_lock_path: with FileLock(bin_lock_path): + dbtype = dbm.whichdb(bin_cache_path) + # handle stale/corrupted cache if it exists + if dbtype is not None: + # some db types can create multiple files + exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'], + 'dbm.dumb': ['dir', 'dat']}[dbtype] + db_paths = [bin_cache_path + ext for ext in exts] + # check if the cache is stale + frontend_mtime = os.path.getmtime(triton.code_gen.__file__) + backend_mtime = os.path.getmtime(triton._C.libtriton.__file__) + cache_mtime = max([os.path.getmtime(db) for db in db_paths]) + is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime + # check if the cache is corrupted + is_corrupted = os.path.exists(bin_mut_path) + # delete the cache if stale or corrupted + if is_stale or is_corrupted: + for db in db_paths: + os.remove(db) + if is_corrupted: + os.remove(bin_mut_path) + # read the cache, creating if needed with shelve.open(bin_cache_path) as db: binary = db.get(key, None) if binary is None: @@ -618,30 +640,10 @@ class Kernel: ) if bin_lock_path: with FileLock(bin_lock_path): - dbtype = dbm.whichdb(bin_cache_path) - # extension of file(s) created by db - ext = {'dbm.gnu': '', 'dbm.ndbm': '.db'}[dbtype] - dbpath = bin_cache_path + ext - # create temporary file(s) - try: - dbdir = os.path.dirname(os.path.abspath(dbpath)) - tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext, dir=dbdir) - tmp.close() - # move data-base to temporary file(s) - # do not copy as it can be expensive - # so it's probably preferrable to have the whole - # cache wiped out in the rare event that - # the process is killed while updating it - os.rename(dbpath, tmp.name) - # write data to temporary file - with shelve.open(os.path.splitext(tmp.name)[0]) as db: - db[key] = binary - # move temporary file(s) to db - os.rename(tmp.name, dbpath) - finally: - if os.path.exists(tmp.name): - os.remove(tmp.name) - + open(bin_mut_path, 'a').close() + with shelve.open(bin_cache_path) as db: + db[key] = binary + os.remove(bin_mut_path) drv_cache[key] = LoadedBinary(device_idx, binary) # pack arguments @@ -716,21 +718,11 @@ class Autotuner: class JITFunction: - # clear cache if the db is older than either the frontend or the backend - def _clear_cache(self): - frontend_mtime = os.path.getmtime(triton.code_gen.__file__) - backend_mtime = os.path.getmtime(triton._C.libtriton.__file__) - with FileLock(self.bin_lock_path): - cache_mtime = os.path.getmtime(self.db_path) - if frontend_mtime > cache_mtime or backend_mtime > cache_mtime: - os.remove(self.db_path) - def _init_cache_paths(self): # fetch cache directory path cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') if not cache_dir: self.bin_cache_path = None - self.db_path = None self.bin_lock_path = None return # create cache directory @@ -742,11 +734,8 @@ class JITFunction: md5_hash = md5.hexdigest() # load dbm file in cache_dir for md5_hash self.bin_cache_path = os.path.join(cache_dir, md5_hash) - self.db_path = self.bin_cache_path + '.db' self.bin_lock_path = self.bin_cache_path + '.lock' - # if bin_cache_path exists - if os.path.exists(self.db_path): - self._clear_cache() + self.bin_mut_path = self.bin_cache_path + '.mutating' def __init__(self, fn): # information of wrapped function