diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 897ec31e4..04df5d842 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -711,15 +711,15 @@ def cdiv(x, y): ###### - class TensorWrapper: - def __init__(self, data_ptr, dtype): + def __init__(self, data_ptr, dtype, device): self._data_ptr = data_ptr self.dtype = dtype + self.device = device def data_ptr(self): return self._data_ptr def reinterpret(tensor, dtype): - return TensorWrapper(tensor.data_ptr(), dtype) \ No newline at end of file + return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)