[FRONTEND] Compute cache now supports atomic writes (#294)

Note that killing a Triton process while it updates the cache will result in the cache being wiped out. This is because copying a whole `db` to a temporary file can be quite expensive on some systems.
This commit is contained in:
Philippe Tillet
2021-09-21 14:10:02 -07:00
committed by GitHub
parent b53f5f3803
commit e96edc16ff

View File

@@ -6,16 +6,17 @@ import sys
import tempfile import tempfile
import textwrap import textwrap
import hashlib import hashlib
import atexit
import os import os
import shelve import shelve
from filelock import FileLock import shutil
import os
from .tools.disasm import extract
import tempfile
import torch import torch
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from filelock import FileLock
from .tools.disasm import extract import dbm
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
@@ -617,8 +618,31 @@ class Kernel:
) )
if bin_lock_path: if bin_lock_path:
with FileLock(bin_lock_path): with FileLock(bin_lock_path):
with shelve.open(bin_cache_path) as db: dbtype = dbm.whichdb(bin_cache_path)
db[key] = binary # extension of file(s) created by db
ext = {'dbm.gnu': '', 'dbm.ndbm': '.db'}[dbtype]
dbpath = bin_cache_path + ext
# create temporary file(s)
try:
dbdir = os.path.dirname(os.path.abspath(dbpath))
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext, dir=dbdir)
tmp.close()
# move data-base to temporary file(s)
# do not copy as it can be expensive
# so it's probably preferrable to have the whole
# cache wiped out in the rare event that
# the process is killed while updating it
os.rename(dbpath, tmp.name)
# write data to temporary file
with shelve.open(os.path.splitext(tmp.name)[0]) as db:
db[key] = binary
# move temporary file(s) to db
os.rename(tmp.name, dbpath)
finally:
if os.path.exists(tmp.name):
os.remove(tmp.name)
drv_cache[key] = LoadedBinary(device_idx, binary) drv_cache[key] = LoadedBinary(device_idx, binary)
# pack arguments # pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)]) fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
@@ -688,6 +712,8 @@ class Autotuner:
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
class JITFunction: class JITFunction:
# clear cache if the db is older than either the frontend or the backend # clear cache if the db is older than either the frontend or the backend