diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 2eec6248d..598dd04a9 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -6,16 +6,17 @@ import sys import tempfile import textwrap import hashlib -import atexit import os import shelve -from filelock import FileLock - +import shutil +import os +from .tools.disasm import extract +import tempfile import torch import triton import triton._C.libtriton.triton as _triton - -from .tools.disasm import extract +from filelock import FileLock +import dbm class CodeGenerator(ast.NodeVisitor): @@ -617,8 +618,31 @@ class Kernel: ) if bin_lock_path: with FileLock(bin_lock_path): - with shelve.open(bin_cache_path) as db: - db[key] = binary + dbtype = dbm.whichdb(bin_cache_path) + # extension of file(s) created by db + ext = {'dbm.gnu': '', 'dbm.ndbm': '.db'}[dbtype] + dbpath = bin_cache_path + ext + # create temporary file(s) + try: + dbdir = os.path.dirname(os.path.abspath(dbpath)) + tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext, dir=dbdir) + tmp.close() + # move data-base to temporary file(s) + # do not copy as it can be expensive + # so it's probably preferrable to have the whole + # cache wiped out in the rare event that + # the process is killed while updating it + os.rename(dbpath, tmp.name) + # write data to temporary file + with shelve.open(os.path.splitext(tmp.name)[0]) as db: + db[key] = binary + # move temporary file(s) to db + os.rename(tmp.name, dbpath) + finally: + if os.path.exists(tmp.name): + os.remove(tmp.name) + + drv_cache[key] = LoadedBinary(device_idx, binary) # pack arguments fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)]) @@ -688,6 +712,8 @@ class Autotuner: return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta) + + class JITFunction: # clear cache if the db is older than either the frontend or the backend