[FRONTEND] Add cache_version to triton.jit (#301)
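triton.jit now accepts keyword arguments and forwards them to JITFunction, the most important one being an optional cache version. The on-disk binary cache is keyed on the kernel source, that version, and content hashes of the frontend (code_gen.py), the compiled backend (libtriton), and the output of nvcc --version / ptxas --version when those tools are found, replacing the old md5-of-source file name and the mtime-based staleness check. The in-memory kernel cache key likewise drops the device type/index in favour of the CUDA compute capability and switches from frozensets to deterministic tuples. A usage sketch is included after the last hunk below.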
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -1,3 +1,6 @@
+# version
+__version__ = '1.0.0'
+
 # TODO: torch needs to be imported first
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
@@ -8,5 +11,3 @@ from . import language
 from . import code_gen
 from . import testing
 from . import ops
-# version
-__version__ = '1.0.0'
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -1,17 +1,16 @@
 import ast
 import builtins
+import functools
 import inspect
 import struct
 import sys
-import tempfile
 import textwrap
 import hashlib
 import os
 import shelve
-import shutil
+import subprocess
 import os
 from .tools.disasm import extract
-import tempfile
 import torch
 import triton
 import triton._C.libtriton.triton as _triton
@@ -574,7 +573,6 @@ class Kernel:
                              " Only CUDA is supported at the moment")
 
         device = torch.device('cuda', torch.cuda.current_device())
-        device_ty = device.type
         device_idx = device.index
         if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
             # try to enable P2P communication
@@ -595,10 +593,12 @@ class Kernel:
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
         # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
-        attr_key = frozenset(attributes.items())
-        meta_key = frozenset(meta.items())
-        const_key = frozenset(constants.items())
-        key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
+        attr_key = tuple(attributes.items())
+        meta_key = tuple(sorted(meta.items()))
+        const_key = tuple(constants.items())
+        compute_capability = torch.cuda.get_device_capability(device)
+
+        key = (compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
         key = repr(key)
         # get cached binary
         drv_cache = self.fn.drv_cache
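Note on the hunk above: the key is serialized with repr(), and a frozenset's repr depends on set iteration order, which can differ between interpreter runs once string hash randomization kicks in, while tuples (sorted where the order is not already fixed) give a stable string; that determinism is presumably the motivation for the switch. Keying on the compute capability rather than the device type/index also lets identical GPUs map to the same entry. A small sketch of the determinism point, not taken from the commit:

# Sketch (not from the commit): repr() of a frozenset follows set iteration
# order, which may change between runs when str hashing is randomized;
# repr() of a sorted tuple is always the same string.
meta = {'BLOCK': 128, 'TM': 64}
print(repr(frozenset(meta.items())))      # element order may vary across runs
print(repr(tuple(sorted(meta.items()))))  # always (('BLOCK', 128), ('TM', 64))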
@@ -616,18 +616,12 @@
             exts = {'dbm.gnu': [''], 'dbm.ndbm': ['.db'],
                     'dbm.dumb': ['.dir', '.dat']}[dbtype]
             db_paths = [bin_cache_path + ext for ext in exts]
-            # check if the cache is stale
-            frontend_mtime = os.path.getmtime(triton.code_gen.__file__)
-            backend_mtime = os.path.getmtime(triton._C.libtriton.__file__)
-            cache_mtime = max([os.path.getmtime(db) for db in db_paths])
-            is_stale = frontend_mtime > cache_mtime or backend_mtime > cache_mtime
             # check if the cache is corrupted
             is_corrupted = os.path.exists(bin_mut_path)
-            # delete the cache if stale or corrupted
-            if is_stale or is_corrupted:
+            # delete the cache if corrupted
+            if is_corrupted:
                 for db in db_paths:
                     os.remove(db)
-            if is_corrupted:
                 os.remove(bin_mut_path)
             # read the cache, creating if needed
             with shelve.open(bin_cache_path) as db:
@@ -714,7 +708,26 @@ class Autotuner:
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
 
 
+@functools.lru_cache()
+def version_key():
+    with open(triton.code_gen.__file__, "rb") as f:
+        frontend_contents = hashlib.md5(f.read()).hexdigest()
+    with open(triton._C.libtriton.__file__, "rb") as f:
+        backend_contents = hashlib.md5(f.read()).hexdigest()
+
+    try:
+        nvcc_version = hashlib.md5(subprocess.check_output(["nvcc", "--version"])).hexdigest()
+    except Exception:
+        nvcc_version = None
+    try:
+        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
+    except Exception:
+        ptxas_version = None
+
+    return (
+        triton.__version__, frontend_contents, backend_contents,
+        nvcc_version, ptxas_version
+    )
+
+
 class JITFunction:
 
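version_key() is the replacement for the deleted mtime comparison: rather than asking whether code_gen.py or libtriton changed after the cache file was written, their current contents (and, when available, the nvcc/ptxas version strings) are hashed into every cache key, so a toolchain change simply maps to a fresh cache entry; the functools.lru_cache decorator means the files and subprocesses are only inspected once per process. The per-file hashing pattern, as a standalone sketch rather than the commit's code:

import hashlib

def file_md5(path):
    # Hash the file's bytes rather than its mtime: `touch` alone does not
    # produce a new cache key, but a rebuilt libtriton or edited frontend does.
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()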
@@ -728,26 +741,27 @@ class JITFunction:
         # create cache directory
         if not os.path.exists(cache_dir):
             os.makedirs(cache_dir, exist_ok=True)
-        # create md5 hash of src
-        md5 = hashlib.md5()
-        md5.update(self.src.encode('utf-8'))
-        md5_hash = md5.hexdigest()
-        # load dbm file in cache_dir for md5_hash
-        self.bin_cache_path = os.path.join(cache_dir, md5_hash)
-        self.bin_lock_path = self.bin_cache_path + '.lock'
+        # paths for dbm file in cache_dir
+        cache_key = (self.version, self.src) + version_key()
+        cache_key_str = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
+        self.bin_cache_path = os.path.join(cache_dir, cache_key_str)
+        self.bin_lock_path = self.bin_cache_path + '.lock'
         self.bin_mut_path = self.bin_cache_path + '.mutating'
 
-    def __init__(self, fn):
+    def __init__(self, fn, version=None):
         # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
-        self.src = textwrap.dedent(inspect.getsource(fn))
+        self.version = version
+        self.src = textwrap.dedent(inspect.getsource(fn))
         # cache for callable driver objects (e.g. CUkernel)
         self.drv_cache = dict()
 
         # on-disk paths for the binary cache and corresponding
         # file-lock
         self._init_cache_paths()
 
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
         self.kernel_decorators = []
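_init_cache_paths now derives the dbm file name from (self.version, self.src) + version_key(), so the same source compiled with a different version argument, a different Triton build, or a different CUDA toolchain lands in a separate cache file instead of reusing an old one. A self-contained sketch of that derivation with stand-in values (the helper below is hypothetical, not part of the commit):

import hashlib
import os

def cache_path(cache_dir, version, src, vkey):
    # Mirrors the scheme above: md5 of the repr of the full cache key.
    cache_key = (version, src) + vkey
    name = hashlib.md5(repr(cache_key).encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, name)

vkey = ('1.0.0', 'md5-of-code_gen', 'md5-of-libtriton', None, None)  # stand-in for version_key()
p0 = cache_path('/tmp/triton-cache', None, 'def kernel(): ...', vkey)
p1 = cache_path('/tmp/triton-cache', 1, 'def kernel(): ...', vkey)
assert p0 != p1  # bumping `version` forces a fresh binary cache entry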
@@ -804,6 +818,9 @@ class JITFunction:
     def __getitem__(self, grid):
         return Launcher(self._init_kernel(), grid)
 
+    def __repr__(self):
+        return f"JITFunction({self.module}:{self.fn.__name__})"
+
 
 class Config:
     """
@@ -899,7 +916,7 @@ def heuristics(values):
     return decorator
 
 
-def jit(fn):
+def jit(*args, **kwargs):
     """
     Decorator for JIT-compiling a function using the Triton compiler.
 
@@ -911,11 +928,18 @@ def jit(fn):
     * objects within the triton.language package,
     * arguments to this function,
     * other jit'd functions
 
     :param fn: the function to be jit-compiled
     :type fn: Callable
     """
-    return JITFunction(fn)
+    if args:
+        assert len(args) == 1
+        assert callable(args[0])
+        return JITFunction(args[0], **kwargs)
+    else:
+        def decorator(fn):
+            return JITFunction(fn, **kwargs)
+        return decorator
 
 
 ######
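Both decorator forms now work: bare @triton.jit, or @triton.jit(...) with keyword arguments forwarded to JITFunction. A usage sketch under the 1.x meta-parameter API; only the decorator line reflects this commit, the kernel and launch code are illustrative:

import torch
import triton
import triton.language as tl

@triton.jit(version=1)  # bump `version` to invalidate previously cached binaries
def add_kernel(X, Y, Z, N, **meta):
    BLOCK = meta['BLOCK']
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    y = tl.load(Y + offs, mask=mask)
    tl.store(Z + offs, x + y, mask=mask)

# The bare form is unchanged:
# @triton.jit
# def other_kernel(...): ...

x = torch.rand(4096, device='cuda')
y = torch.rand(4096, device='cuda')
z = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(4096, meta['BLOCK']),)
add_kernel[grid](x, y, z, 4096, BLOCK=1024)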