From 0dd2ec2e3a8dfb54092577c6418bc083684fca88 Mon Sep 17 00:00:00 2001 From: Yongjik Kim Date: Wed, 16 Mar 2022 14:38:56 -0700 Subject: [PATCH] [FRONTEND] Add an assert in case we get a CPU tensor. (#478) --- python/triton/code_gen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 3f170098b..09254c967 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -689,6 +689,10 @@ class Kernel: # handle annotations for pos, _type in self.fn.annotations.items(): wargs[pos] = _type(wargs[pos]) + # check that tensors are on GPU. + for arg in wargs: + if hasattr(arg, 'data_ptr'): + assert arg.is_cuda, "All tensors must be on GPU!" # query device index and cuda stream device = torch.cuda.current_device() torch.cuda.set_device(device)