[FRONTEND] Add do_not_specialize to triton.jit to prevent specialization of kernel argument (#309)
				
					
				
			This commit is contained in:
		@@ -588,7 +588,8 @@ class Kernel:
 | 
				
			|||||||
        torch.cuda.set_device(device_idx)
 | 
					        torch.cuda.set_device(device_idx)
 | 
				
			||||||
        # attributes
 | 
					        # attributes
 | 
				
			||||||
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
 | 
					        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
 | 
				
			||||||
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
 | 
					        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
 | 
				
			||||||
 | 
					                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
 | 
				
			||||||
        # transforms ints whose value is one into constants for just-in-time compilation
 | 
					        # transforms ints whose value is one into constants for just-in-time compilation
 | 
				
			||||||
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
 | 
					        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
 | 
				
			||||||
        # compute hash for caching this kernel
 | 
					        # compute hash for caching this kernel
 | 
				
			||||||
@@ -740,13 +741,15 @@ class JITFunction:
 | 
				
			|||||||
    def _set_cache_key(self):
 | 
					    def _set_cache_key(self):
 | 
				
			||||||
        self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
 | 
					        self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, fn, version=None):
 | 
					    def __init__(self, fn, version=None, do_not_specialize=None):
 | 
				
			||||||
        # information of wrapped function
 | 
					        # information of wrapped function
 | 
				
			||||||
        self.fn = fn
 | 
					        self.fn = fn
 | 
				
			||||||
        self.module = fn.__module__
 | 
					        self.module = fn.__module__
 | 
				
			||||||
        self.arg_names = inspect.getfullargspec(fn).args
 | 
					        self.arg_names = inspect.getfullargspec(fn).args
 | 
				
			||||||
        self.version = version
 | 
					        self.version = version
 | 
				
			||||||
        self.src = textwrap.dedent(inspect.getsource(fn))
 | 
					        self.src = textwrap.dedent(inspect.getsource(fn))
 | 
				
			||||||
 | 
					        self.do_not_specialize = [] if do_not_specialize is None else\
 | 
				
			||||||
 | 
					                                 [self.arg_names.index(arg) for arg in do_not_specialize]
 | 
				
			||||||
        # cache for callable driver objects (e.g. CUkernel)
 | 
					        # cache for callable driver objects (e.g. CUkernel)
 | 
				
			||||||
        self.drv_cache = dict()
 | 
					        self.drv_cache = dict()
 | 
				
			||||||
        # cache for binaries (on-disk)
 | 
					        # cache for binaries (on-disk)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user