[FRONTEND] Add cache_version to triton.jit (#301)
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
# version
|
||||
__version__ = '1.0.0'
|
||||
|
||||
# TODO: torch needs to be imported first
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch
|
||||
@@ -8,5 +11,3 @@ from . import language
|
||||
from . import code_gen
|
||||
from . import testing
|
||||
from . import ops
|
||||
# version
|
||||
__version__ = '1.0.0'
|
||||
|
@@ -1,17 +1,16 @@
|
||||
import ast
|
||||
import builtins
|
||||
import functools
|
||||
import inspect
|
||||
import struct
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import hashlib
|
||||
import os
|
||||
import shelve
|
||||
import shutil
|
||||
import subprocess
|
||||
import os
|
||||
from .tools.disasm import extract
|
||||
import tempfile
|
||||
import torch
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
@@ -574,7 +573,6 @@ class Kernel:
|
||||
" Only CUDA is supported at the moment")
|
||||
|
||||
device = torch.device('cuda', torch.cuda.current_device())
|
||||
device_ty = device.type
|
||||
device_idx = device.index
|
||||
if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
|
||||
# try to enable P2P communication
|
||||
@@ -595,10 +593,12 @@ class Kernel:
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
|
||||
# compute hash for caching this kernel
|
||||
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
|
||||
attr_key = frozenset(attributes.items())
|
||||
meta_key = frozenset(meta.items())
|
||||
const_key = frozenset(constants.items())
|
||||
key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
|
||||
attr_key = tuple(attributes.items())
|
||||
meta_key = tuple(sorted(meta.items()))
|
||||
const_key = tuple(constants.items())
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
|
||||
key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
|
||||
key = repr(key)
|
||||
# get cached binary
|
||||
drv_cache = self.fn.drv_cache
|
||||
@@ -616,18 +616,12 @@ class Kernel:
|
||||
exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
|
||||
'dbm.dumb': ['.dir', '.dat']}[dbtype]
|
||||
db_paths = [bin_cache_path + ext for ext in exts]
|
||||
# check if the cache is stale
|
||||
frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
|
||||
backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
|
||||
cache_mtime = max([os.path.getmtime(db) for db in db_paths])
|
||||
is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime
|
||||
# check if the cache is corrupted
|
||||
is_corrupted = os.path.exists(bin_mut_path)
|
||||
# delete the cache if stale or corrupted
|
||||
if is_stale or is_corrupted:
|
||||
# delete the cache if corrupted
|
||||
if is_corrupted:
|
||||
for db in db_paths:
|
||||
os.remove(db)
|
||||
if is_corrupted:
|
||||
os.remove(bin_mut_path)
|
||||
# read the cache, creating if needed
|
||||
with shelve.open(bin_cache_path) as db:
|
||||
@@ -714,7 +708,26 @@ class Autotuner:
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
|
||||
|
||||
|
||||
def _md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*.

    The file is hashed in fixed-size chunks so that a large binary does not
    have to be loaded into memory in one piece just to be fingerprinted.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _md5_of_tool_version(tool):
    """Return the hex MD5 digest of ``<tool> --version`` output, or ``None``.

    ``None`` is returned when the tool cannot be launched or exits with a
    non-zero status; a missing tool is not an error for the caller.
    """
    try:
        output = subprocess.check_output([tool, "--version"])
    except (OSError, subprocess.SubprocessError):
        # OSError: executable missing / not runnable.
        # SubprocessError: e.g. CalledProcessError on a non-zero exit code.
        return None
    return hashlib.md5(output).hexdigest()


@functools.lru_cache()
def version_key():
    """Build a tuple fingerprinting the compiler stack currently in use.

    The tuple contains, in order:

    * ``triton.__version__``
    * MD5 of the ``triton.code_gen`` source file (frontend)
    * MD5 of the ``triton._C.libtriton`` file (backend)
    * MD5 of ``nvcc --version`` output, or ``None`` if unavailable
    * MD5 of ``ptxas --version`` output, or ``None`` if unavailable

    The result is memoized with ``functools.lru_cache`` so the files are
    hashed and the external tools are invoked at most once per process.

    :return: a 5-tuple suitable for concatenation into a cache key
    :rtype: tuple
    """
    return (
        triton.__version__,
        _md5_of_file(triton.code_gen.__file__),
        _md5_of_file(triton._C.libtriton.__file__),
        _md5_of_tool_version("nvcc"),
        _md5_of_tool_version("ptxas"),
    )
|
||||
|
||||
class JITFunction:
|
||||
|
||||
@@ -728,26 +741,27 @@ class JITFunction:
|
||||
# create cache directory
|
||||
if not os.path.exists(cache_dir):
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
# create md5 hash of src
|
||||
md5 = hashlib.md5()
|
||||
md5.update(self.src.encode('utf-8'))
|
||||
md5_hash = md5.hexdigest()
|
||||
# load dbm file in cache_dir for md5_hash
|
||||
self.bin_cache_path = os.path.join(cache_dir, md5_hash)
|
||||
self.bin_lock_path = self.bin_cache_path + '.lock'
|
||||
# paths for dbm file in cache_dir
|
||||
cache_key = (self.version, self.src) + version_key()
|
||||
cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
|
||||
self.bin_cache_path = os.path.join(cache_dir, cache_key_str)
|
||||
self.bin_lock_path = self.bin_cache_path + '.lock'
|
||||
self.bin_mut_path = self.bin_cache_path + '.mutating'
|
||||
|
||||
def __init__(self, fn):
|
||||
def __init__(self, fn, version=None):
|
||||
# information of wrapped function
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.arg_names = inspect.getfullargspec(fn).args
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.version = version
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
# cache for callable driver objects (e.g. CUkernel)
|
||||
self.drv_cache = dict()
|
||||
|
||||
# on-disk paths for the binary cache and corresponding
|
||||
# file-lock
|
||||
self._init_cache_paths()
|
||||
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
@@ -804,6 +818,9 @@ class JITFunction:
|
||||
def __getitem__(self, grid):
    """Return a ``Launcher`` pairing a kernel from ``_init_kernel()`` with *grid*."""
    kernel = self._init_kernel()
    return Launcher(kernel, grid)
|
||||
|
||||
def __repr__(self):
|
||||
return f"JITFunction({self.module}:{self.fn.__name__})"
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
@@ -899,7 +916,7 @@ def heuristics(values):
|
||||
return decorator
|
||||
|
||||
|
||||
def jit(fn):
|
||||
def jit(*args, **kwargs):
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
|
||||
@@ -911,11 +928,18 @@ def jit(fn):
|
||||
* objects within the triton.language package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
return JITFunction(fn)
|
||||
if args:
|
||||
assert len(args) == 1
|
||||
assert callable(args[0])
|
||||
return JITFunction(args[0], **kwargs)
|
||||
else:
|
||||
def decorator(fn):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
||||
|
||||
|
||||
######
|
||||
|
Reference in New Issue
Block a user