[FRONTEND] Add do_not_specialize to triton.jit to prevent specialization of kernel argument (#309)
@@ -588,7 +588,8 @@ class Kernel:
         torch.cuda.set_device(device_idx)
         # attributes
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
-        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
+        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
+                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
         # compute hash for caching this kernel
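
For context, Kernel.pow2_divisor (defined elsewhere in this file, not shown in this hunk) is the divisibility hint that specialization bakes into the compiled kernel. A minimal sketch of the idea, assumed rather than taken from this diff:

    # Assumed sketch: the largest power of two (capped at 16) dividing N.
    def pow2_divisor(N):
        for d in (16, 8, 4, 2):
            if N % d == 0:
                return d
        return 1

Because the attributes dict feeds into the kernel cache key, two launches whose integer argument differs only in divisibility (say n=1024 vs. n=1023) would otherwise compile two distinct binaries; listing that argument's index in self.fn.do_not_specialize drops it from attributes so one binary is reused.
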
@@ -740,13 +741,15 @@ class JITFunction:
     def _set_cache_key(self):
         self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
 
-    def __init__(self, fn, version=None):
+    def __init__(self, fn, version=None, do_not_specialize=None):
         # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
         self.arg_names = inspect.getfullargspec(fn).args
         self.version = version
         self.src = textwrap.dedent(inspect.getsource(fn))
+        self.do_not_specialize = [] if do_not_specialize is None else\
+            [self.arg_names.index(arg) for arg in do_not_specialize]
         # cache for callable driver objects (e.g. CUkernel)
         self.drv_cache = dict()
         # cache for binaries (on-disk)
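
A hedged usage sketch: assuming triton.jit forwards its keyword arguments through to JITFunction (that plumbing is not part of this diff), an argument is opted out of specialization by name. The kernel body below is illustrative and written in later Triton style; only the do_not_specialize keyword comes from this commit:

    import triton
    import triton.language as tl

    # "n_elements" is excluded from divisibility specialization, so calls
    # that vary only in n_elements reuse the same compiled binary.
    @triton.jit(do_not_specialize=["n_elements"])
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

Note that JITFunction resolves names to positions with self.arg_names.index(arg), so a name that is not in the kernel signature raises ValueError at decoration time.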