diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index aedb2051f..699487232 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -588,7 +588,8 @@ class Kernel:
             torch.cuda.set_device(device_idx)
         # attributes
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
-        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
+        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
+                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
         # compute hash for caching this kernel
@@ -740,13 +741,15 @@ class JITFunction:
     def _set_cache_key(self):
         self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
 
-    def __init__(self, fn, version=None):
+    def __init__(self, fn, version=None, do_not_specialize=None):
         # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
         self.version = version
         self.src = textwrap.dedent(inspect.getsource(fn))
+        self.do_not_specialize = [] if do_not_specialize is None else\
+            [self.arg_names.index(arg) for arg in do_not_specialize]
         # cache for callable driver objects (e.g. CUkernel)
         self.drv_cache = dict()
         # cache for binaries (on-disk)