[PYTHON] Added TRITON_DEBUG_MODE which reallocates input tensors outside of the pytorch memory pool to spot out-of-bounds accesses more easily
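A minimal usage sketch (using the tutorial's `add` wrapper from the diff below; the flag is read from the environment at call time, and only its presence is checked):

    import os
    import torch

    os.environ['TRITON_DEBUG_MODE'] = '1'   # any value works; only presence matters

    x = torch.rand(900).cuda()
    y = torch.rand(900).cuda()
    # Tensor arguments are copied into fresh cuMemAlloc'd buffers outside the
    # PyTorch caching pool before launch, so an out-of-bounds access is more
    # likely to fault (or be flagged by a tool such as cuda-memcheck) than to
    # silently corrupt a neighboring tensor; results are copied back afterwards.
    z = add(x, y)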
@@ -14,7 +14,7 @@ __global__ void add(float* z, float* x, float* y, int N) {
 
   bool check[TILE] = offset < N;
 
-  *?(check)pz = *?(check)px + *?(check)py;
+  *pz = *px + *py;
 }
 """
 
@@ -32,9 +32,8 @@ add = _add.apply
 
 # test
 torch.manual_seed(0)
-x = torch.rand(98432).cuda()
-y = torch.rand(98432).cuda()
+x = torch.rand(900).cuda()
+y = torch.rand(900).cuda()
 za = x + y
 zb = add(x, y)
 
 print(torch.allclose(za,zb))
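Note that the tutorial change above drops the `?(check)` predication and shrinks the inputs, so the example now reads and writes past the end of its 900-element buffers, presumably to exercise the new debug mode. A rough illustration, assuming the tutorial's usual TILE of 1024 (the TILE value is not part of this diff):

    TILE, N = 1024, 900                      # TILE is an assumption here
    oob = [i for i in range(TILE) if i >= N] # offsets the unmasked kernel still touches
    print(len(oob), oob[0], oob[-1])         # 124 900 1023: accesses past the end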
@@ -57,7 +57,19 @@ void synchronize(int64_t dev_id) {
   triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
   stream.synchronize();
 }
+}
 
+torch::Tensor raw_like(torch::Tensor x){
+  if(x.nbytes() == 0)
+    return torch::empty_like(x);
+  C10_CUDA_CHECK(cudaSetDevice(x.device().index()));
+  auto shape = x.sizes();
+  CUdeviceptr data;
+  triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
+  auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
+  auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
+  ret.copy_(x);
+  return ret;
 }
 
 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
@@ -82,5 +94,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
 
 static auto registry = torch::RegisterOperators()
                        .op("triton::launch_kernel", &launch_kernel)
+                       .op("triton::raw_like", &raw_like)
                        .op("triton::cdiv_sum", &cdiv_sum)
                        .op("triton::synchronize", &synchronize);
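Because `raw_like` is registered as a torch operator above, it is also reachable from Python once the extension library is loaded, which is how the `class kernel:` changes below use it. A quick sketch of the contract (zero-byte tensors fall back to `torch::empty_like`):

    import torch

    x = torch.rand(900).cuda()
    y = torch.ops.triton.raw_like(x)     # same shape/dtype/values, fresh cuMemAlloc'd storage
    assert torch.equal(x, y)
    assert y.data_ptr() != x.data_ptr()  # lives outside the caching allocator's pool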
@@ -89,6 +89,12 @@ class kernel:
         return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
 
     def __call__(self, *args, **kwargs):
+        if 'TRITON_DEBUG_MODE' in os.environ:
+            _args = args
+            args = [x for x in args]
+            for i in range(len(args)):
+                if isinstance(args[i], torch.Tensor):
+                    args[i] = torch.ops.triton.raw_like(args[i])
         for x in args:
             if isinstance(x, torch.Tensor):
                 device = x.device.index
@@ -109,3 +115,7 @@ class kernel:
         names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
         constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
         torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
+        if 'TRITON_DEBUG_MODE' in os.environ:
+            for i in range(len(args)):
+                if isinstance(args[i], torch.Tensor):
+                    _args[i].copy_(args[i])
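Taken together, the two `class kernel:` hunks implement a copy-in/copy-out protocol around the launch. Restated as a standalone sketch (only `torch.ops.triton.raw_like` comes from this commit; the helper names are illustrative):

    import torch

    def debug_copy_in(args):
        # Keep the caller's tensors; launch on out-of-pool copies instead.
        originals = args
        copies = [torch.ops.triton.raw_like(a) if isinstance(a, torch.Tensor) else a
                  for a in args]
        return originals, copies

    def debug_copy_out(originals, copies):
        # Copy results (including in-place writes) back into the caller's tensors.
        for orig, c in zip(originals, copies):
            if isinstance(c, torch.Tensor):
                orig.copy_(c)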