[FRONTEND] Add cache_version to triton.jit (#301)

This commit is contained in:
Shantanu
2021-09-23 16:45:54 -07:00
committed by GitHub
parent 5211f23a63
commit d253eb8719
2 changed files with 55 additions and 30 deletions

View File

@@ -1,3 +1,6 @@
# version
__version__ = '1.0.0'
# TODO: torch needs to be imported first # TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer` # or pybind11 shows `munmap_chunk(): invalid pointer`
import torch import torch
@@ -8,5 +11,3 @@ from . import language
from . import code_gen from . import code_gen
from . import testing from . import testing
from . import ops from . import ops
# version
__version__ = '1.0.0'

View File

@@ -1,17 +1,16 @@
import ast import ast
import builtins import builtins
import functools
import inspect import inspect
import struct import struct
import sys import sys
import tempfile
import textwrap import textwrap
import hashlib import hashlib
import os import os
import shelve import shelve
import shutil import subprocess
import os import os
from .tools.disasm import extract from .tools.disasm import extract
import tempfile
import torch import torch
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
@@ -574,7 +573,6 @@ class Kernel:
" Only CUDA is supported at the moment") " Only CUDA is supported at the moment")
device = torch.device('cuda', torch.cuda.current_device()) device = torch.device('cuda', torch.cuda.current_device())
device_ty = device.type
device_idx = device.index device_idx = device.index
if len(set(device_ids)) != 1 or device_ids[0] != device_idx: if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
# try to enable P2P communication # try to enable P2P communication
@@ -595,10 +593,12 @@ class Kernel:
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
# compute hash for caching this kernel # compute hash for caching this kernel
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = frozenset(attributes.items()) attr_key = tuple(attributes.items())
meta_key = frozenset(meta.items()) meta_key = tuple(sorted(meta.items()))
const_key = frozenset(constants.items()) const_key = tuple(constants.items())
key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key) compute_capability = torch.cuda.get_device_capability(device)
key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
key = repr(key) key = repr(key)
# get cached binary # get cached binary
drv_cache = self.fn.drv_cache drv_cache = self.fn.drv_cache
@@ -616,18 +616,12 @@ class Kernel:
exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'], exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
'dbm.dumb': ['.dir', '.dat']}[dbtype] 'dbm.dumb': ['.dir', '.dat']}[dbtype]
db_paths = [bin_cache_path + ext for ext in exts] db_paths = [bin_cache_path + ext for ext in exts]
# check if the cache is stale
frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
cache_mtime = max([os.path.getmtime(db) for db in db_paths])
is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime
# check if the cache is corrupted # check if the cache is corrupted
is_corrupted = os.path.exists(bin_mut_path) is_corrupted = os.path.exists(bin_mut_path)
# delete the cache if stale or corrupted # delete the cache if corrupted
if is_stale or is_corrupted: if is_corrupted:
for db in db_paths: for db in db_paths:
os.remove(db) os.remove(db)
if is_corrupted:
os.remove(bin_mut_path) os.remove(bin_mut_path)
# read the cache, creating if needed # read the cache, creating if needed
with shelve.open(bin_cache_path) as db: with shelve.open(bin_cache_path) as db:
@@ -714,7 +708,26 @@ class Autotuner:
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
@functools.lru_cache()
def version_key():
    """Return a tuple identifying the current Triton toolchain.

    The tuple is folded into the on-disk kernel-cache key (see
    ``cache_key = (self.version, self.src) + version_key()``) so cached
    binaries are invalidated whenever the frontend, backend, or CUDA
    tools change.  Memoized with ``lru_cache`` since none of the inputs
    change within a process.

    :return: ``(triton.__version__, frontend_md5, backend_md5,
        nvcc_md5_or_None, ptxas_md5_or_None)``
    """
    # MD5 of the Python frontend source file (code_gen.py).
    with open(triton.code_gen.__file__, "rb") as f:
        frontend_contents = hashlib.md5(f.read()).hexdigest()
    # MD5 of the compiled libtriton extension module.
    with open(triton._C.libtriton.__file__, "rb") as f:
        backend_contents = hashlib.md5(f.read()).hexdigest()
    # CUDA toolchain fingerprints are best-effort: if the tool is missing
    # or fails, record None rather than erroring out of kernel launch.
    try:
        nvcc_version = hashlib.md5(subprocess.check_output(["nvcc", "--version"])).hexdigest()
    except Exception:
        nvcc_version = None
    try:
        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
    except Exception:
        ptxas_version = None
    return (
        triton.__version__, frontend_contents, backend_contents,
        nvcc_version, ptxas_version
    )
class JITFunction: class JITFunction:
@@ -728,26 +741,27 @@ class JITFunction:
# create cache directory # create cache directory
if not os.path.exists(cache_dir): if not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
# create md5 hash of src # paths for dbm file in cache_dir
md5 = hashlib.md5() cache_key = (self.version, self.src) + version_key()
md5.update(self.src.encode('utf-8')) cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
md5_hash = md5.hexdigest() self.bin_cache_path = os.path.join(cache_dir, cache_key_str)
# load dbm file in cache_dir for md5_hash self.bin_lock_path = self.bin_cache_path + '.lock'
self.bin_cache_path = os.path.join(cache_dir, md5_hash)
self.bin_lock_path = self.bin_cache_path + '.lock'
self.bin_mut_path = self.bin_cache_path + '.mutating' self.bin_mut_path = self.bin_cache_path + '.mutating'
def __init__(self, fn): def __init__(self, fn, version=None):
# information of wrapped function # information of wrapped function
self.fn = fn self.fn = fn
self.module = fn.__module__ self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args self.arg_names = inspect.getfullargspec(fn).args
self.src = textwrap.dedent(inspect.getsource(fn)) self.version = version
self.src = textwrap.dedent(inspect.getsource(fn))
# cache for callable driver objects (e.g. CUkernel) # cache for callable driver objects (e.g. CUkernel)
self.drv_cache = dict() self.drv_cache = dict()
# on-disk paths for the binary cache and corresponding # on-disk paths for the binary cache and corresponding
# file-lock # file-lock
self._init_cache_paths() self._init_cache_paths()
# JITFunction can be instantiated as kernel # JITFunction can be instantiated as kernel
# when called with a grid using __getitem__ # when called with a grid using __getitem__
self.kernel_decorators = [] self.kernel_decorators = []
@@ -804,6 +818,9 @@ class JITFunction:
def __getitem__(self, grid): def __getitem__(self, grid):
return Launcher(self._init_kernel(), grid) return Launcher(self._init_kernel(), grid)
def __repr__(self):
    # Identify the wrapped kernel by defining module and function name,
    # e.g. "JITFunction(my_pkg.kernels:add_kernel)".
    return f"JITFunction({self.module}:{self.fn.__name__})"
class Config: class Config:
""" """
@@ -899,7 +916,7 @@ def heuristics(values):
return decorator return decorator
def jit(fn): def jit(*args, **kwargs):
""" """
Decorator for JIT-compiling a function using the Triton compiler. Decorator for JIT-compiling a function using the Triton compiler.
@@ -911,11 +928,18 @@ def jit(fn):
* objects within the triton.language package, * objects within the triton.language package,
* arguments to this function, * arguments to this function,
* other jit'd functions * other jit'd functions
:param fn: the function to be jit-compiled :param fn: the function to be jit-compiled
:type fn: Callable :type fn: Callable
""" """
return JITFunction(fn) if args:
assert len(args) == 1
assert callable(args[0])
return JITFunction(args[0], **kwargs)
else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
###### ######