diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 6353ddbd0..ef3145a40 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -936,16 +936,16 @@ def next_power_of_2(n): ###### - class TensorWrapper: - def __init__(self, data_ptr, dtype, device): - self._data_ptr = data_ptr + def __init__(self, base, dtype): self.dtype = dtype - self.device = device - + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + def data_ptr(self): - return self._data_ptr + return self.base.data_ptr() def reinterpret(tensor, dtype): - return TensorWrapper(tensor.data_ptr(), dtype, tensor.device) + return TensorWrapper(tensor, dtype)