[FRONTEND] Add cache_version to triton.jit (#301)

This commit is contained in:
Shantanu
2021-09-23 16:45:54 -07:00
committed by GitHub
parent 5211f23a63
commit d253eb8719
2 changed files with 55 additions and 30 deletions

View File

@@ -1,3 +1,6 @@
# version
__version__ = '1.0.0'
# TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
@@ -8,5 +11,3 @@ from . import language
from . import code_gen
from . import testing
from . import ops
# version
__version__ = '1.0.0'

View File

@@ -1,17 +1,16 @@
import ast
import builtins
import functools
import inspect
import struct
import sys
import tempfile
import textwrap
import hashlib
import os
import shelve
import shutil
import subprocess
import os
from .tools.disasm import extract
import tempfile
import torch
import triton
import triton._C.libtriton.triton as _triton
@@ -574,7 +573,6 @@ class Kernel:
" Only CUDA is supported at the moment")
device = torch.device('cuda', torch.cuda.current_device())
device_ty = device.type
device_idx = device.index
if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
# try to enable P2P communication
@@ -595,10 +593,12 @@ class Kernel:
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
# compute hash for caching this kernel
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = frozenset(attributes.items())
meta_key = frozenset(meta.items())
const_key = frozenset(constants.items())
key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
attr_key = tuple(attributes.items())
meta_key = tuple(sorted(meta.items()))
const_key = tuple(constants.items())
compute_capability = torch.cuda.get_device_capability(device)
key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
key = repr(key)
# get cached binary
drv_cache = self.fn.drv_cache
@@ -616,18 +616,12 @@ class Kernel:
exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
'dbm.dumb': ['.dir', '.dat']}[dbtype]
db_paths = [bin_cache_path + ext for ext in exts]
# check if the cache is stale
frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
cache_mtime = max([os.path.getmtime(db) for db in db_paths])
is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime
# check if the cache is corrupted
is_corrupted = os.path.exists(bin_mut_path)
# delete the cache if stale or corrupted
if is_stale or is_corrupted:
# delete the cache if corrupted
if is_corrupted:
for db in db_paths:
os.remove(db)
if is_corrupted:
os.remove(bin_mut_path)
# read the cache, creating if needed
with shelve.open(bin_cache_path) as db:
@@ -714,7 +708,26 @@ class Autotuner:
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
@functools.lru_cache()
def version_key():
    """Build a cache-key tuple identifying the current Triton toolchain.

    Combines the Triton package version, MD5 digests of the Python
    frontend source and the compiled backend extension, and (when the
    tools can be run) digests of the ``nvcc`` and ``ptxas`` version
    banners.  A missing tool is recorded as ``None``.  The result is
    memoized so the hashing work happens at most once per process.
    """
    def _file_digest(path):
        # MD5 over the raw bytes of a file on disk.
        with open(path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()

    def _tool_digest(cmd):
        # Best effort: the tool may be absent from PATH or fail to run.
        try:
            return hashlib.md5(subprocess.check_output(cmd)).hexdigest()
        except Exception:
            return None

    frontend_contents = _file_digest(triton.code_gen.__file__)
    backend_contents = _file_digest(triton._C.libtriton.__file__)
    nvcc_version = _tool_digest(["nvcc", "--version"])
    ptxas_version = _tool_digest(["ptxas", "--version"])
    return (
        triton.__version__, frontend_contents, backend_contents,
        nvcc_version, ptxas_version
    )
class JITFunction:
@@ -728,26 +741,27 @@ class JITFunction:
# create cache directory
if not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)
# create md5 hash of src
md5 = hashlib.md5()
md5.update(self.src.encode('utf-8'))
md5_hash = md5.hexdigest()
# load dbm file in cache_dir for md5_hash
self.bin_cache_path = os.path.join(cache_dir, md5_hash)
self.bin_lock_path = self.bin_cache_path + '.lock'
# paths for dbm file in cache_dir
cache_key = (self.version, self.src) + version_key()
cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
self.bin_cache_path = os.path.join(cache_dir, cache_key_str)
self.bin_lock_path = self.bin_cache_path + '.lock'
self.bin_mut_path = self.bin_cache_path + '.mutating'
def __init__(self, fn):
def __init__(self, fn, version=None):
# information of wrapped function
self.fn = fn
self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args
self.src = textwrap.dedent(inspect.getsource(fn))
self.version = version
self.src = textwrap.dedent(inspect.getsource(fn))
# cache for callable driver objects (e.g. CUkernel)
self.drv_cache = dict()
# on-disk paths for the binary cache and corresponding
# file-lock
self._init_cache_paths()
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
@@ -804,6 +818,9 @@ class JITFunction:
def __getitem__(self, grid):
    """Instantiate the wrapped kernel and bind it to *grid* for launching."""
    kernel = self._init_kernel()
    return Launcher(kernel, grid)
def __repr__(self):
    """Identify the wrapper by its module and wrapped function name."""
    return "JITFunction({}:{})".format(self.module, self.fn.__name__)
class Config:
"""
@@ -899,7 +916,7 @@ def heuristics(values):
return decorator
def jit(*args, **kwargs):
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    Supports both the bare form ``@triton.jit`` and the parameterized
    form ``@triton.jit(version=...)``; either way the target function is
    wrapped in a :class:`JITFunction`.  Keyword arguments (e.g.
    ``version``, used to key the on-disk binary cache) are forwarded to
    the :class:`JITFunction` constructor.

    :note: When a kernel is decorated with this function, it can only
        access:

        * objects within the triton.language package,
        * arguments to this function,
        * other jit'd functions
    """
    if args:
        # Bare usage: @triton.jit — the function is the sole positional arg.
        assert len(args) == 1
        assert callable(args[0])
        return JITFunction(args[0], **kwargs)
    else:
        # Parameterized usage: @triton.jit(...) must return a decorator
        # that closes over the supplied keyword arguments.
        def decorator(fn):
            return JITFunction(fn, **kwargs)
        return decorator
######