diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index 8e118fc00..4ae539c07 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -1,3 +1,6 @@
+# version
+__version__ = '1.0.0'
+
 # TODO: torch needs to be imported first
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
@@ -8,5 +11,3 @@ from . import language
 from . import code_gen
 from . import testing
 from . import ops
-# version
-__version__ = '1.0.0'
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index ef3145a40..857bfd4cf 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -1,17 +1,16 @@
 import ast
 import builtins
+import functools
 import inspect
 import struct
 import sys
-import tempfile
 import textwrap
 import hashlib
 import os
 import shelve
-import shutil
+import subprocess
 import os
 from .tools.disasm import extract
-import tempfile
 import torch
 import triton
 import triton._C.libtriton.triton as _triton
@@ -574,7 +573,6 @@ class Kernel:
                              " Only CUDA is supported at the moment")
         device = torch.device('cuda', torch.cuda.current_device())
-        device_ty = device.type
         device_idx = device.index
         if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
             # try to enable P2P communication
@@ -595,10 +593,12 @@ class Kernel:
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
         # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
-        attr_key = frozenset(attributes.items())
-        meta_key = frozenset(meta.items())
-        const_key = frozenset(constants.items())
-        key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
+        attr_key = tuple(attributes.items())
+        meta_key = tuple(sorted(meta.items()))
+        const_key = tuple(constants.items())
+        compute_capability = torch.cuda.get_device_capability(device)
+
+        key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
         key = repr(key)
         # get cached binary
         drv_cache = self.fn.drv_cache
@@ -616,18 +616,12 @@ class Kernel:
         exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'], 'dbm.dumb': ['.dir', '.dat']}[dbtype]
         db_paths = [bin_cache_path + ext for ext in exts]
-        # check if the cache is stale
-        frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
-        backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
-        cache_mtime = max([os.path.getmtime(db) for db in db_paths])
-        is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime
         # check if the cache is corrupted
         is_corrupted = os.path.exists(bin_mut_path)
-        # delete the cache if stale or corrupted
-        if is_stale or is_corrupted:
+        # delete the cache if corrupted
+        if is_corrupted:
             for db in db_paths:
                 os.remove(db)
-        if is_corrupted:
             os.remove(bin_mut_path)
         # read the cache, creating if needed
         with shelve.open(bin_cache_path) as db:
@@ -714,7 +708,26 @@ class Autotuner:
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
 
 
+@functools.lru_cache()
+def version_key():
+    with open(triton.code_gen.__file__, "rb") as f:
+        frontend_contents = hashlib.md5(f.read()).hexdigest()
+    with open(triton._C.libtriton.__file__, "rb") as f:
+        backend_contents = hashlib.md5(f.read()).hexdigest()
+    try:
+        nvcc_version = hashlib.md5(subprocess.check_output(["nvcc", "--version"])).hexdigest()
+    except Exception:
+        nvcc_version = None
+    try:
+        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
+    except Exception:
+        ptxas_version = None
+
+    return (
+        triton.__version__, frontend_contents, backend_contents,
+        nvcc_version, ptxas_version
+    )
 
 
 class JITFunction:
@@ -728,26 +741,27 @@ class JITFunction:
         # create cache directory
         if not os.path.exists(cache_dir):
             os.makedirs(cache_dir, exist_ok=True)
-        # create md5 hash of src
-        md5 = hashlib.md5()
-        md5.update(self.src.encode('utf-8'))
-        md5_hash = md5.hexdigest()
-        # load dbm file in cache_dir for md5_hash
-        self.bin_cache_path = os.path.join(cache_dir, md5_hash)
-        self.bin_lock_path = self.bin_cache_path + '.lock'
+        # paths for dbm file in cache_dir
+        cache_key = (self.version, self.src) + version_key()
+        cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
+        self.bin_cache_path = os.path.join(cache_dir, cache_key_str)
+        self.bin_lock_path = self.bin_cache_path + '.lock'
         self.bin_mut_path = self.bin_cache_path + '.mutating'
 
-    def __init__(self, fn):
+    def __init__(self, fn, version=None):
         # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
-        self.src = textwrap.dedent(inspect.getsource(fn))
+        self.version = version
+        self.src = textwrap.dedent(inspect.getsource(fn))
         # cache for callable driver objects (e.g. CUkernel)
         self.drv_cache = dict()
+        # on-disk paths for the binary cache and corresponding file-lock
+        self._init_cache_paths()
+
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
         self.kernel_decorators = []
@@ -804,6 +818,9 @@ class JITFunction:
     def __getitem__(self, grid):
         return Launcher(self._init_kernel(), grid)
 
+    def __repr__(self):
+        return f"JITFunction({self.module}:{self.fn.__name__})"
+
 
 class Config:
     """
@@ -899,7 +916,7 @@ def heuristics(values):
     return decorator
 
 
-def jit(fn):
+def jit(*args, **kwargs):
     """
     Decorator for JIT-compiling a function using the Triton compiler.
 
@@ -911,11 +928,18 @@ def jit(fn):
            * objects within the triton.language package,
            * arguments to this function,
            * other jit'd functions
-    
+
     :param fn: the function to be jit-compiled
     :type fn: Callable
     """
-    return JITFunction(fn)
+    if args:
+        assert len(args) == 1
+        assert callable(args[0])
+        return JITFunction(args[0], **kwargs)
+    else:
+        def decorator(fn):
+            return JITFunction(fn, **kwargs)
+        return decorator
 
 
 ######
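
With this patch, triton.jit accepts both the bare decorator form and a parameterized form, and the optional version argument is folded, together with the kernel source and version_key(), into the md5 that names the on-disk binary cache. A minimal usage sketch follows; the kernel names, the empty bodies, and the "v2" string are hypothetical placeholders, not part of the patch:

    import triton

    @triton.jit                    # bare form: jit(fn) returns JITFunction(fn)
    def kernel_v1(X, **meta):
        pass                       # placeholder body; a real kernel would use triton.language ops

    @triton.jit(version="v2")      # parameterized form: jit(**kwargs) returns a decorator
    def kernel_v2(X, **meta):      # that builds JITFunction(fn, version="v2")
        pass

Because the cache path is now derived from hashes of the kernel source, the version argument, Triton's own frontend/backend files, and the nvcc/ptxas version strings rather than from file modification times, bumping version (or changing any of those components) lands in a fresh cache entry, which is why the stale-cache mtime check is removed above.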