[FRONTEND] Add `do_not_specialize` to `triton.jit` to prevent specialization of kernel arguments (#309)

This commit is contained in:
Philippe Tillet
2021-09-24 20:27:10 -07:00
committed by GitHub
parent 83da3febf2
commit c3756d1c33

View File

@@ -588,7 +588,8 @@ class Kernel:
torch.cuda.set_device(device_idx) torch.cuda.set_device(device_idx)
# attributes # attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)} attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation # transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
# compute hash for caching this kernel # compute hash for caching this kernel
@@ -740,13 +741,15 @@ class JITFunction:
def _set_cache_key(self): def _set_cache_key(self):
self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version) self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
def __init__(self, fn, version=None): def __init__(self, fn, version=None, do_not_specialize=None):
# information of wrapped function # information of wrapped function
self.fn = fn self.fn = fn
self.module = fn.__module__ self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args self.arg_names = inspect.getfullargspec(fn).args
self.version = version self.version = version
self.src = textwrap.dedent(inspect.getsource(fn)) self.src = textwrap.dedent(inspect.getsource(fn))
self.do_not_specialize = [] if do_not_specialize is None else\
[self.arg_names.index(arg) for arg in do_not_specialize]
# cache for callable driver objects (e.g. CUkernel) # cache for callable driver objects (e.g. CUkernel)
self.drv_cache = dict() self.drv_cache = dict()
# cache for binaries (on-disk) # cache for binaries (on-disk)