From 5211f23a63158fafebb23ff4ae2d59ff9dab3f7b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 22 Sep 2021 13:53:27 -0700 Subject: [PATCH] [FRONTEND] updated TensorWrapper (#299) --- python/triton/code_gen.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 6353ddbd0..ef3145a40 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -936,16 +936,16 @@ def next_power_of_2(n): ###### - class TensorWrapper: - def __init__(self, data_ptr, dtype, device): - self._data_ptr = data_ptr + def __init__(self, base, dtype): self.dtype = dtype - self.device = device - + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + def data_ptr(self): - return self._data_ptr + return self.base.data_ptr() def reinterpret(tensor, dtype): - return TensorWrapper(tensor.data_ptr(), dtype, tensor.device) + return TensorWrapper(tensor, dtype)