[PYTHON] Update TensorWrapper with device attribute (#102)

bugfix `triton.reinterpret`
This commit is contained in:
Szymon Sidor
2021-05-07 23:51:02 -07:00
committed by Philippe Tillet
parent 840140bf26
commit ac57812bdc

View File

@@ -711,15 +711,15 @@ def cdiv(x, y):
######
class TensorWrapper:
def __init__(self, data_ptr, dtype):
def __init__(self, data_ptr, dtype, device):
self._data_ptr = data_ptr
self.dtype = dtype
self.device = device
def data_ptr(self):
return self._data_ptr
def reinterpret(tensor, dtype):
return TensorWrapper(tensor.data_ptr(), dtype)
return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)