From ac57812bdc4bb2c60ca675512867c60cd607309c Mon Sep 17 00:00:00 2001 From: Szymon Sidor Date: Fri, 7 May 2021 23:51:02 -0700 Subject: [PATCH] [PYTHON] Update TensorWrapper with device attribute (#102) bugfix `triton.reinterpret` --- python/triton/code_gen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 897ec31e4..04df5d842 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -711,15 +711,15 @@ def cdiv(x, y): ###### - class TensorWrapper: - def __init__(self, data_ptr, dtype): + def __init__(self, data_ptr, dtype, device): self._data_ptr = data_ptr self.dtype = dtype + self.device = device def data_ptr(self): return self._data_ptr def reinterpret(tensor, dtype): - return TensorWrapper(tensor.data_ptr(), dtype) \ No newline at end of file + return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)