[FRONTEND] Add do_not_specialize to triton.jit to prevent specialization of kernel argument (#309)
This commit is contained in:
@@ -588,7 +588,8 @@ class Kernel:
|
|||||||
torch.cuda.set_device(device_idx)
|
torch.cuda.set_device(device_idx)
|
||||||
# attributes
|
# attributes
|
||||||
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
||||||
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
|
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
|
||||||
|
if isinstance(a, int) and i not in self.fn.do_not_specialize}
|
||||||
# transforms ints whose value is one into constants for just-in-time compilation
|
# transforms ints whose value is one into constants for just-in-time compilation
|
||||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
|
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
|
||||||
# compute hash for caching this kernel
|
# compute hash for caching this kernel
|
||||||
@@ -740,13 +741,15 @@ class JITFunction:
|
|||||||
def _set_cache_key(self):
|
def _set_cache_key(self):
|
||||||
self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
|
self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
|
||||||
|
|
||||||
def __init__(self, fn, version=None):
|
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||||
# information of wrapped function
|
# information of wrapped function
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.module = fn.__module__
|
self.module = fn.__module__
|
||||||
self.arg_names = inspect.getfullargspec(fn).args
|
self.arg_names = inspect.getfullargspec(fn).args
|
||||||
self.version = version
|
self.version = version
|
||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
|
self.do_not_specialize = [] if do_not_specialize is None else\
|
||||||
|
[self.arg_names.index(arg) for arg in do_not_specialize]
|
||||||
# cache for callable driver objects (e.g. CUkernel)
|
# cache for callable driver objects (e.g. CUkernel)
|
||||||
self.drv_cache = dict()
|
self.drv_cache = dict()
|
||||||
# cache for binaries (on-disk)
|
# cache for binaries (on-disk)
|
||||||
|
Reference in New Issue
Block a user