[FRONTEND] updated TensorWrapper (#299)
This commit is contained in:
@@ -936,16 +936,16 @@ def next_power_of_2(n):
|
||||
|
||||
######
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, data_ptr, dtype, device):
|
||||
self._data_ptr = data_ptr
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.base = base
|
||||
self.is_cuda = base.is_cuda
|
||||
self.device = base.device
|
||||
|
||||
def data_ptr(self):
|
||||
return self._data_ptr
|
||||
return self.base.data_ptr()
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
return TensorWrapper(tensor.data_ptr(), dtype, tensor.device)
|
||||
return TensorWrapper(tensor, dtype)
|
||||
|
Reference in New Issue
Block a user