[FRONTEND] updated TensorWrapper (#299)
This commit is contained in:
@@ -936,16 +936,16 @@ def next_power_of_2(n):
|
|||||||
|
|
||||||
######
|
######
|
||||||
|
|
||||||
|
|
||||||
class TensorWrapper:
|
class TensorWrapper:
|
||||||
def __init__(self, data_ptr, dtype, device):
|
def __init__(self, base, dtype):
|
||||||
self._data_ptr = data_ptr
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.base = base
|
||||||
|
self.is_cuda = base.is_cuda
|
||||||
|
self.device = base.device
|
||||||
|
|
||||||
def data_ptr(self):
|
def data_ptr(self):
|
||||||
return self._data_ptr
|
return self.base.data_ptr()
|
||||||
|
|
||||||
|
|
||||||
def reinterpret(tensor, dtype):
|
def reinterpret(tensor, dtype):
|
||||||
return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)
|
return TensorWrapper(tensor, dtype)
|
||||||
|
Reference in New Issue
Block a user