[PYTHON] Update TensorWrapper with device attribute (#102)
bugfix `triton.reinterpret`
This commit is contained in:
committed by
Philippe Tillet
parent
840140bf26
commit
ac57812bdc
@@ -711,15 +711,15 @@ def cdiv(x, y):
|
|||||||
|
|
||||||
######
|
######
|
||||||
|
|
||||||
|
|
||||||
class TensorWrapper:
|
class TensorWrapper:
|
||||||
def __init__(self, data_ptr, dtype):
|
def __init__(self, data_ptr, dtype, device):
|
||||||
self._data_ptr = data_ptr
|
self._data_ptr = data_ptr
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def data_ptr(self):
|
def data_ptr(self):
|
||||||
return self._data_ptr
|
return self._data_ptr
|
||||||
|
|
||||||
|
|
||||||
def reinterpret(tensor, dtype):
|
def reinterpret(tensor, dtype):
|
||||||
return TensorWrapper(tensor.data_ptr(), dtype)
|
return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)
|
||||||
|
Reference in New Issue
Block a user