@@ -7,7 +7,7 @@ import sys
|
|||||||
import textwrap
|
import textwrap
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import shelve
|
import pickle
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
from .tools.disasm import extract
|
from .tools.disasm import extract
|
||||||
@@ -598,53 +598,55 @@ class Kernel:
|
|||||||
const_key = tuple(constants.items())
|
const_key = tuple(constants.items())
|
||||||
compute_capability = torch.cuda.get_device_capability(device)
|
compute_capability = torch.cuda.get_device_capability(device)
|
||||||
|
|
||||||
key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
|
key = (
|
||||||
key = repr(key)
|
self.fn.cache_key, version_key(), compute_capability,
|
||||||
|
types_key, attr_key, num_warps, num_stages, meta_key, const_key
|
||||||
|
)
|
||||||
|
key_str = repr(key)
|
||||||
|
|
||||||
# get cached binary
|
# get cached binary
|
||||||
drv_cache = self.fn.drv_cache
|
drv_cache = self.fn.drv_cache
|
||||||
bin_mut_path = self.fn.bin_mut_path
|
|
||||||
bin_cache_path = self.fn.bin_cache_path
|
if key_str not in drv_cache:
|
||||||
bin_lock_path = self.fn.bin_lock_path
|
hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest()
|
||||||
if key not in drv_cache:
|
|
||||||
|
# create cache directory
|
||||||
|
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
|
||||||
|
if cache_dir and not os.path.exists(cache_dir):
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if cache_dir:
|
||||||
|
bin_cache_path = os.path.join(cache_dir, hashed_key)
|
||||||
|
bin_lock_path = bin_cache_path + ".lock"
|
||||||
|
else:
|
||||||
|
bin_cache_path = None
|
||||||
|
bin_lock_path = None
|
||||||
|
|
||||||
binary = None
|
binary = None
|
||||||
if bin_lock_path:
|
if bin_cache_path and os.path.exists(bin_cache_path):
|
||||||
|
assert bin_lock_path is not None
|
||||||
with FileLock(bin_lock_path):
|
with FileLock(bin_lock_path):
|
||||||
dbtype = dbm.whichdb(bin_cache_path)
|
with open(bin_cache_path, 'rb') as f:
|
||||||
# handle stale/corrupted cache if it exists
|
binary = pickle.load(f)["binary"]
|
||||||
if dbtype is not None:
|
|
||||||
# some db types can create multiple files
|
|
||||||
exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
|
|
||||||
'dbm.dumb': ['.dir', '.dat']}[dbtype]
|
|
||||||
db_paths = [bin_cache_path + ext for ext in exts]
|
|
||||||
# check if the cache is corrupted
|
|
||||||
is_corrupted = os.path.exists(bin_mut_path)
|
|
||||||
# delete the cache if corrupted
|
|
||||||
if is_corrupted:
|
|
||||||
for db in db_paths:
|
|
||||||
os.remove(db)
|
|
||||||
os.remove(bin_mut_path)
|
|
||||||
# read the cache, creating if needed
|
|
||||||
with shelve.open(bin_cache_path) as db:
|
|
||||||
binary = db.get(key, None)
|
|
||||||
if binary is None:
|
if binary is None:
|
||||||
binary = self._compile(
|
binary = self._compile(
|
||||||
*wargs, device=device_idx, attributes=attributes,
|
*wargs, device=device_idx, attributes=attributes,
|
||||||
num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
|
num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
|
||||||
constants=constants, **meta
|
constants=constants, **meta
|
||||||
)
|
)
|
||||||
if bin_lock_path:
|
if bin_cache_path:
|
||||||
|
assert bin_lock_path is not None
|
||||||
with FileLock(bin_lock_path):
|
with FileLock(bin_lock_path):
|
||||||
open(bin_mut_path, 'a').close()
|
with open(bin_cache_path + ".tmp", "wb") as f:
|
||||||
with shelve.open(bin_cache_path) as db:
|
pickle.dump({"binary": binary, "key": key}, f)
|
||||||
db[key] = binary
|
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||||
os.remove(bin_mut_path)
|
|
||||||
|
|
||||||
drv_cache[key] = LoadedBinary(device_idx, binary)
|
drv_cache[key_str] = LoadedBinary(device_idx, binary)
|
||||||
# pack arguments
|
# pack arguments
|
||||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||||
params = struct.pack(fmt, *args)
|
params = struct.pack(fmt, *args)
|
||||||
# enqueue cached function into stream
|
# enqueue cached function into stream
|
||||||
callable = drv_cache[key]
|
callable = drv_cache[key_str]
|
||||||
stream = torch.cuda.current_stream(device_idx).cuda_stream
|
stream = torch.cuda.current_stream(device_idx).cuda_stream
|
||||||
grid = grid(meta) if hasattr(grid, '__call__') else grid
|
grid = grid(meta) if hasattr(grid, '__call__') else grid
|
||||||
callable(stream, params, *grid)
|
callable(stream, params, *grid)
|
||||||
@@ -730,23 +732,8 @@ def version_key():
|
|||||||
)
|
)
|
||||||
|
|
||||||
class JITFunction:
|
class JITFunction:
|
||||||
|
def _set_cache_key(self):
|
||||||
def _init_cache_paths(self):
|
self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
|
||||||
# fetch cache directory path
|
|
||||||
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
|
|
||||||
if not cache_dir:
|
|
||||||
self.bin_cache_path = None
|
|
||||||
self.bin_lock_path = None
|
|
||||||
return
|
|
||||||
# create cache directory
|
|
||||||
if not os.path.exists(cache_dir):
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
# paths for dbm file in cache_dir
|
|
||||||
cache_key = (self.version, self.src) + version_key()
|
|
||||||
cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
|
|
||||||
self.bin_cache_path = os.path.join(cache_dir, cache_key_str)
|
|
||||||
self.bin_lock_path = self.bin_cache_path + '.lock'
|
|
||||||
self.bin_mut_path = self.bin_cache_path + '.mutating'
|
|
||||||
|
|
||||||
def __init__(self, fn, version=None):
|
def __init__(self, fn, version=None):
|
||||||
# information of wrapped function
|
# information of wrapped function
|
||||||
@@ -757,10 +744,7 @@ class JITFunction:
|
|||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
# cache for callable driver objects (e.g. CUkernel)
|
# cache for callable driver objects (e.g. CUkernel)
|
||||||
self.drv_cache = dict()
|
self.drv_cache = dict()
|
||||||
|
self._set_cache_key()
|
||||||
# on-disk paths for the binary cache and corresponding
|
|
||||||
# file-lock
|
|
||||||
self._init_cache_paths()
|
|
||||||
|
|
||||||
# JITFunction can be instantiated as kernel
|
# JITFunction can be instantiated as kernel
|
||||||
# when called with a grid using __getitem__
|
# when called with a grid using __getitem__
|
||||||
@@ -806,7 +790,7 @@ class JITFunction:
|
|||||||
self.kernel = None
|
self.kernel = None
|
||||||
super(JITFunction, self).__setattr__(name, value)
|
super(JITFunction, self).__setattr__(name, value)
|
||||||
if name == 'src':
|
if name == 'src':
|
||||||
self._init_cache_paths()
|
self._set_cache_key()
|
||||||
|
|
||||||
def _init_kernel(self):
|
def _init_kernel(self):
|
||||||
if self.kernel is None:
|
if self.kernel is None:
|
||||||
|
Reference in New Issue
Block a user