[FRONTEND] updated TensorWrapper (#299)

This commit is contained in:
Philippe Tillet
2021-09-22 13:53:27 -07:00
committed by GitHub
parent 2849e7a773
commit 5211f23a63

View File

@@ -936,16 +936,16 @@ def next_power_of_2(n):
######
class TensorWrapper:
def __init__(self, data_ptr, dtype, device):
self._data_ptr = data_ptr
def __init__(self, base, dtype):
self.dtype = dtype
self.device = device
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self._data_ptr
return self.base.data_ptr()
def reinterpret(tensor, dtype):
return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)
return TensorWrapper(tensor, dtype)