[PYTHON] Update TensorWrapper with device attribute (#102)
bugfix `triton.reinterpret`
This commit is contained in:
committed by
Philippe Tillet
parent
840140bf26
commit
ac57812bdc
@@ -711,15 +711,15 @@ def cdiv(x, y):
|
||||
|
||||
######
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, data_ptr, dtype):
|
||||
def __init__(self, data_ptr, dtype, device):
|
||||
self._data_ptr = data_ptr
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def data_ptr(self):
|
||||
return self._data_ptr
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
return TensorWrapper(tensor.data_ptr(), dtype)
|
||||
return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)
|
||||
|
Reference in New Issue
Block a user