diff --git a/python/examples/tutorials/vec_add.py b/python/examples/tutorials/vec_add.py
index acce14062..542efa57b 100644
--- a/python/examples/tutorials/vec_add.py
+++ b/python/examples/tutorials/vec_add.py
@@ -14,7 +14,7 @@ __global__ void add(float* z, float* x, float* y, int N) {
     bool check[TILE] = offset < N;
-    *?(check)pz = *?(check)px + *?(check)py;
+    *pz = *px + *py;
 }
 """
@@ -32,9 +32,8 @@ add = _add.apply
 # test
 torch.manual_seed(0)
-x = torch.rand(98432).cuda()
-y = torch.rand(98432).cuda()
+x = torch.rand(900).cuda()
+y = torch.rand(900).cuda()
 za = x + y
 zb = add(x, y)
-
 print(torch.allclose(za,zb))
diff --git a/python/src/launch.cc b/python/src/launch.cc
index 1f6a8988f..ad0cac7e9 100644
--- a/python/src/launch.cc
+++ b/python/src/launch.cc
@@ -57,7 +57,19 @@ void synchronize(int64_t dev_id) {
   triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
   stream.synchronize();
 }
+}
+torch::Tensor raw_like(torch::Tensor x){
+  if(x.nbytes() == 0)
+    return torch::empty_like(x);
+  C10_CUDA_CHECK(cudaSetDevice(x.device().index()));
+  auto shape = x.sizes();
+  CUdeviceptr data;
+  triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
+  auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
+  auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
+  ret.copy_(x);
+  return ret;
 }
 
 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
@@ -82,5 +94,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
 static auto registry = torch::RegisterOperators()
                        .op("triton::launch_kernel", &launch_kernel)
+                       .op("triton::raw_like", &raw_like)
                        .op("triton::cdiv_sum", &cdiv_sum)
                        .op("triton::synchronize", &synchronize);
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 5bab8f6ea..c002c803e 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -89,6 +89,12 @@ class kernel:
         return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
 
     def __call__(self, *args, **kwargs):
+        if 'TRITON_DEBUG_MODE' in os.environ:
+            _args = args
+            args = [x for x in args]
+            for i in range(len(args)):
+                if isinstance(args[i], torch.Tensor):
+                    args[i] = torch.ops.triton.raw_like(args[i])
         for x in args:
             if isinstance(x, torch.Tensor):
                 device = x.device.index
@@ -108,4 +114,8 @@ class kernel:
         params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
         names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
         constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
-        torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
\ No newline at end of file
+        torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
+        if 'TRITON_DEBUG_MODE' in os.environ:
+            for i in range(len(args)):
+                if isinstance(args[i], torch.Tensor):
+                    _args[i].copy_(args[i])
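
The `TRITON_DEBUG_MODE` branch added to `kernel.__call__` is a copy-in/copy-out wrapper around the launch: when the environment variable is present, every `torch.Tensor` argument is re-materialized into its own `cuMemAlloc` buffer via the new `triton::raw_like` op (outside PyTorch's caching allocator), the kernel runs on those copies, and the results are copied back into the caller's tensors afterwards. Together with the tutorial change that drops the `*?(check)` bounds predicate and shrinks the inputs to 900 elements, the intent appears to be to make out-of-bounds accesses land in isolated allocations where they are easier to catch. The sketch below restates that control flow in plain Python; `debug_launch` and its `launch` parameter are hypothetical stand-ins for the `torch.ops.triton.launch_kernel` call inside `kernel.__call__`, not part of the patch.

```python
import os
import torch

def debug_launch(launch, *args):
    """Sketch of the TRITON_DEBUG_MODE path in kernel.__call__ (illustrative only)."""
    if 'TRITON_DEBUG_MODE' not in os.environ:
        return launch(*args)
    originals = args
    # Swap every tensor argument for a fresh copy that lives in its own
    # cuMemAlloc allocation, outside the PyTorch caching allocator.
    args = [torch.ops.triton.raw_like(a) if isinstance(a, torch.Tensor) else a
            for a in args]
    launch(*args)
    # Copy the (possibly written-to) buffers back into the caller's tensors.
    for orig, copy in zip(originals, args):
        if isinstance(orig, torch.Tensor):
            orig.copy_(copy)
```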
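
For reference, debug mode is enabled purely by the presence of the environment variable (its value is ignored), and the `triton::raw_like` op registered above can be exercised directly from Python once the extension is loaded. The snippet below is an illustrative check under those assumptions, not something the patch adds.

```python
import os
import torch

os.environ['TRITON_DEBUG_MODE'] = '1'   # any value works; kernel.__call__ only checks presence

x = torch.rand(900).cuda()
y = torch.ops.triton.raw_like(x)         # copy of x backed by a dedicated cuMemAlloc buffer
assert y.data_ptr() != x.data_ptr()      # lives in its own allocation
assert torch.equal(x, y)                 # same values as the original
```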