[FRONTEND] Alignment fix-up (#428)
This commit is contained in:
@@ -674,9 +674,17 @@ class Kernel:
|
||||
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
# attributes
|
||||
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
||||
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
|
||||
if isinstance(a, int) and i not in self.fn.do_not_specialize}
|
||||
attributes = dict()
|
||||
for i, arg in enumerate(wargs):
|
||||
if i in self.fn.do_not_specialize:
|
||||
continue
|
||||
if isinstance(arg, int):
|
||||
attributes[i] = Kernel.pow2_divisor(arg)
|
||||
elif i in tensor_idxs:
|
||||
addr = arg.data_ptr()
|
||||
range_size = _triton.runtime.get_pointer_range_size(addr)
|
||||
attributes[i] = min(Kernel.pow2_divisor(addr),
|
||||
Kernel.pow2_divisor(range_size))
|
||||
|
||||
# transforms ints whose value is one into constants for just-in-time compilation
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
|
||||
|
Reference in New Issue
Block a user