[FRONTEND] Simplified detection of corrupted cache (#295)
@@ -602,12 +602,34 @@ class Kernel:
         key = repr(key)
         # get cached binary
         drv_cache = self.fn.drv_cache
+        bin_mut_path = self.fn.bin_mut_path
         bin_cache_path = self.fn.bin_cache_path
         bin_lock_path = self.fn.bin_lock_path
         if key not in drv_cache:
             binary = None
             if bin_lock_path:
                 with FileLock(bin_lock_path):
+                    dbtype = dbm.whichdb(bin_cache_path)
+                    # handle stale/corrupted cache if it exists
+                    if dbtype is not None:
+                        # some db types can create multiple files
+                        exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
+                                'dbm.dumb': ['.dir', '.dat']}[dbtype]
+                        db_paths = [bin_cache_path + ext for ext in exts]
+                        # check if the cache is stale
+                        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
+                        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
+                        cache_mtime = max([os.path.getmtime(db) for db in db_paths])
+                        is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime
+                        # check if the cache is corrupted
+                        is_corrupted = os.path.exists(bin_mut_path)
+                        # delete the cache if stale or corrupted
+                        if is_stale or is_corrupted:
+                            for db in db_paths:
+                                os.remove(db)
+                            if is_corrupted:
+                                os.remove(bin_mut_path)
+                    # read the cache, creating if needed
                     with shelve.open(bin_cache_path) as db:
                         binary = db.get(key, None)
             if binary is None:
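Note: the lookup path above now does all invalidation inline: under the file lock it asks dbm.whichdb what kind of database (if any) is on disk, compares the cache mtime against code_gen.py and libtriton, and treats a leftover .mutating marker as proof of an interrupted write. A rough standalone sketch of the same check follows; the function name and the source_files parameter are illustrative only, not part of the patch.

import dbm
import os


def cache_is_stale_or_corrupted(bin_cache_path, bin_mut_path, source_files):
    """Mirror of the check above: True if the dbm cache at `bin_cache_path`
    is older than any of `source_files`, or if a leftover `.mutating`
    marker indicates an interrupted write."""
    dbtype = dbm.whichdb(bin_cache_path)
    if dbtype is None:
        # nothing on disk yet, so nothing to invalidate
        return False
    # some dbm backends split the database across several files
    exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'], 'dbm.dumb': ['.dir', '.dat']}[dbtype]
    cache_mtime = max(os.path.getmtime(bin_cache_path + ext) for ext in exts)
    is_stale = any(os.path.getmtime(f) > cache_mtime for f in source_files)
    is_corrupted = os.path.exists(bin_mut_path)
    return is_stale or is_corrupted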
@@ -618,30 +640,10 @@ class Kernel:
                 )
             if bin_lock_path:
                 with FileLock(bin_lock_path):
-                    dbtype = dbm.whichdb(bin_cache_path)
-                    # extension of file(s) created by db
-                    ext = {'dbm.gnu': '', 'dbm.ndbm': '.db'}[dbtype]
-                    dbpath = bin_cache_path + ext
-                    # create temporary file(s)
-                    try:
-                        dbdir = os.path.dirname(os.path.abspath(dbpath))
-                        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext, dir=dbdir)
-                        tmp.close()
-                        # move data-base to temporary file(s)
-                        # do not copy as it can be expensive
-                        # so it's probably preferrable to have the whole
-                        # cache wiped out in the rare event that
-                        # the process is killed while updating it
-                        os.rename(dbpath, tmp.name)
-                        # write data to temporary file
-                        with shelve.open(os.path.splitext(tmp.name)[0]) as db:
-                            db[key] = binary
-                        # move temporary file(s) to db
-                        os.rename(tmp.name, dbpath)
-                    finally:
-                        if os.path.exists(tmp.name):
-                            os.remove(tmp.name)
-
+                    open(bin_mut_path, 'a').close()
+                    with shelve.open(bin_cache_path) as db:
+                        db[key] = binary
+                    os.remove(bin_mut_path)
 
             drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
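Note: the write path above replaces the old write-to-temporary-then-rename scheme with a simple sentinel: an empty .mutating file is created before the shelve is opened for writing and removed only after the write succeeds, so a crash mid-update leaves the marker behind and the next read (first hunk) wipes the cache. A minimal sketch of that protocol follows; the helper name and arguments are placeholders, and FileLock is assumed to come from the third-party filelock package.

import os
import shelve

from filelock import FileLock  # assumption: the filelock package provides the FileLock used here


def store_binary(bin_cache_path, bin_lock_path, bin_mut_path, key, binary):
    """Sketch of the sentinel protocol: the `.mutating` file exists only
    while a write is in flight, so a leftover marker means the shelve may
    be half-written and should be discarded by the next reader."""
    with FileLock(bin_lock_path):
        open(bin_mut_path, 'a').close()   # mark the cache as being mutated
        with shelve.open(bin_cache_path) as db:
            db[key] = binary              # the actual write
        os.remove(bin_mut_path)           # clear the marker only on success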
@@ -716,21 +718,11 @@ class Autotuner:
 
 class JITFunction:
 
-    # clear cache if the db is older than either the frontend or the backend
-    def _clear_cache(self):
-        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
-        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
-        with FileLock(self.bin_lock_path):
-            cache_mtime = os.path.getmtime(self.db_path)
-            if frontend_mtime > cache_mtime or backend_mtime > cache_mtime:
-                os.remove(self.db_path)
-
     def _init_cache_paths(self):
         # fetch cache directory path
         cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
         if not cache_dir:
             self.bin_cache_path = None
-            self.db_path = None
             self.bin_lock_path = None
             return
         # create cache directory
@@ -742,11 +734,8 @@ class JITFunction:
         md5_hash = md5.hexdigest()
         # load dbm file in cache_dir for md5_hash
         self.bin_cache_path = os.path.join(cache_dir, md5_hash)
-        self.db_path = self.bin_cache_path + '.db'
         self.bin_lock_path = self.bin_cache_path + '.lock'
-        # if bin_cache_path exists
-        if os.path.exists(self.db_path):
-            self._clear_cache()
+        self.bin_mut_path = self.bin_cache_path + '.mutating'
 
     def __init__(self, fn):
         # information of wrapped function
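Note: with db_path and _clear_cache removed, each JITFunction now derives three sibling paths from the md5 of its source, and all dbm-extension handling lives in the Kernel lookup shown in the first hunk. Roughly, the resulting layout looks like the sketch below; the hashed bytes are a stand-in, not the exact content Triton hashes.

import hashlib
import os

cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
md5_hash = hashlib.md5(b'<kernel source and version info>').hexdigest()

bin_cache_path = os.path.join(cache_dir, md5_hash)    # base path handed to shelve/dbm
bin_lock_path = bin_cache_path + '.lock'              # cross-process FileLock
bin_mut_path = bin_cache_path + '.mutating'           # exists only while a write is in flight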