[FRONTEND] Add do_not_specialize to triton.jit to prevent specialization of kernel argument (#309)
				
					
				
			This commit is contained in:
		@@ -588,7 +588,8 @@ class Kernel:
 | 
			
		||||
        torch.cuda.set_device(device_idx)
 | 
			
		||||
        # attributes
 | 
			
		||||
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
 | 
			
		||||
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
 | 
			
		||||
        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
 | 
			
		||||
                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
 | 
			
		||||
        # transforms ints whose value is one into constants for just-in-time compilation
 | 
			
		||||
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
 | 
			
		||||
        # compute hash for caching this kernel
 | 
			
		||||
@@ -740,13 +741,15 @@ class JITFunction:
 | 
			
		||||
    def _set_cache_key(self):
 | 
			
		||||
        self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, fn, version=None):
 | 
			
		||||
    def __init__(self, fn, version=None, do_not_specialize=None):
 | 
			
		||||
        # information of wrapped function
 | 
			
		||||
        self.fn = fn
 | 
			
		||||
        self.module = fn.__module__
 | 
			
		||||
        self.arg_names = inspect.getfullargspec(fn).args
 | 
			
		||||
        self.version = version
 | 
			
		||||
        self.src = textwrap.dedent(inspect.getsource(fn))
 | 
			
		||||
        self.do_not_specialize = [] if do_not_specialize is None else\
 | 
			
		||||
                                 [self.arg_names.index(arg) for arg in do_not_specialize]
 | 
			
		||||
        # cache for callable driver objects (e.g. CUkernel)
 | 
			
		||||
        self.drv_cache = dict()
 | 
			
		||||
        # cache for binaries (on-disk)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user