[FRONTEND] Compute cache now supports atomic writes (#294)
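Note that killing a Triton process while it updates the cache will result in the cache being wiped out. This is because the update moves (renames) the existing `db` to a temporary file, writes the new entry there, and then moves it back, rather than working on a copy: copying a whole `db` to a temporary file can be quite expensive on some systems.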
This commit is contained in:
@@ -6,16 +6,17 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import textwrap
|
import textwrap
|
||||||
import hashlib
|
import hashlib
|
||||||
import atexit
|
|
||||||
import os
|
import os
|
||||||
import shelve
|
import shelve
|
||||||
from filelock import FileLock
|
import shutil
|
||||||
|
import os
|
||||||
|
from .tools.disasm import extract
|
||||||
|
import tempfile
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
|
from filelock import FileLock
|
||||||
from .tools.disasm import extract
|
import dbm
|
||||||
|
|
||||||
|
|
||||||
class CodeGenerator(ast.NodeVisitor):
|
class CodeGenerator(ast.NodeVisitor):
|
||||||
@@ -617,8 +618,31 @@ class Kernel:
|
|||||||
)
|
)
|
||||||
if bin_lock_path:
|
if bin_lock_path:
|
||||||
with FileLock(bin_lock_path):
|
with FileLock(bin_lock_path):
|
||||||
with shelve.open(bin_cache_path) as db:
|
dbtype = dbm.whichdb(bin_cache_path)
|
||||||
db[key] = binary
|
# extension of file(s) created by db
|
||||||
|
ext = {'dbm.gnu': '', 'dbm.ndbm': '.db'}[dbtype]
|
||||||
|
dbpath = bin_cache_path + ext
|
||||||
|
# create temporary file(s)
|
||||||
|
try:
|
||||||
|
dbdir = os.path.dirname(os.path.abspath(dbpath))
|
||||||
|
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext, dir=dbdir)
|
||||||
|
tmp.close()
|
||||||
|
# move data-base to temporary file(s)
|
||||||
|
# do not copy as it can be expensive
|
||||||
|
# so it's probably preferable to have the whole
|
||||||
|
# cache wiped out in the rare event that
|
||||||
|
# the process is killed while updating it
|
||||||
|
os.rename(dbpath, tmp.name)
|
||||||
|
# write data to temporary file
|
||||||
|
with shelve.open(os.path.splitext(tmp.name)[0]) as db:
|
||||||
|
db[key] = binary
|
||||||
|
# move temporary file(s) to db
|
||||||
|
os.rename(tmp.name, dbpath)
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp.name):
|
||||||
|
os.remove(tmp.name)
|
||||||
|
|
||||||
|
|
||||||
drv_cache[key] = LoadedBinary(device_idx, binary)
|
drv_cache[key] = LoadedBinary(device_idx, binary)
|
||||||
# pack arguments
|
# pack arguments
|
||||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||||
@@ -688,6 +712,8 @@ class Autotuner:
|
|||||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
|
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class JITFunction:
|
class JITFunction:
|
||||||
|
|
||||||
# clear cache if the db is older than either the frontend or the backend
|
# clear cache if the db is older than either the frontend or the backend
|
||||||
|
Reference in New Issue
Block a user