[FRONTEND] Compute cache now supports atomic writes (#294)
Note that killing a Triton process while it updates the cache will result in the cache being wiped out. This is because copying a whole `db` to a temporary file can be quite expensive on some systems.
This commit is contained in:
@@ -6,16 +6,17 @@ import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import hashlib
|
||||
import atexit
|
||||
import os
|
||||
import shelve
|
||||
from filelock import FileLock
|
||||
|
||||
import shutil
|
||||
import os
|
||||
from .tools.disasm import extract
|
||||
import tempfile
|
||||
import torch
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
from .tools.disasm import extract
|
||||
from filelock import FileLock
|
||||
import dbm
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
@@ -617,8 +618,31 @@ class Kernel:
|
||||
)
|
||||
if bin_lock_path:
|
||||
with FileLock(bin_lock_path):
|
||||
with shelve.open(bin_cache_path) as db:
|
||||
dbtype = dbm.whichdb(bin_cache_path)
|
||||
# extension of file(s) created by db
|
||||
ext = {'dbm.gnu': '', 'dbm.ndbm': '.db'}[dbtype]
|
||||
dbpath = bin_cache_path + ext
|
||||
# create temporary file(s)
|
||||
try:
|
||||
dbdir = os.path.dirname(os.path.abspath(dbpath))
|
||||
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext, dir=dbdir)
|
||||
tmp.close()
|
||||
# move data-base to temporary file(s)
|
||||
# do not copy as it can be expensive
|
||||
# so it's probably preferrable to have the whole
|
||||
# cache wiped out in the rare event that
|
||||
# the process is killed while updating it
|
||||
os.rename(dbpath, tmp.name)
|
||||
# write data to temporary file
|
||||
with shelve.open(os.path.splitext(tmp.name)[0]) as db:
|
||||
db[key] = binary
|
||||
# move temporary file(s) to db
|
||||
os.rename(tmp.name, dbpath)
|
||||
finally:
|
||||
if os.path.exists(tmp.name):
|
||||
os.remove(tmp.name)
|
||||
|
||||
|
||||
drv_cache[key] = LoadedBinary(device_idx, binary)
|
||||
# pack arguments
|
||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||
@@ -688,6 +712,8 @@ class Autotuner:
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
|
||||
|
||||
|
||||
|
||||
|
||||
class JITFunction:
|
||||
|
||||
# clear cache if the db is older than either the frontend or the backend
|
||||
|
Reference in New Issue
Block a user