diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 3f170098b..09254c967 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -689,6 +689,10 @@ class Kernel: # handle annotations for pos, _type in self.fn.annotations.items(): wargs[pos] = _type(wargs[pos]) + # check that tensors are on GPU. + for arg in wargs: + if hasattr(arg, 'data_ptr'): + assert arg.is_cuda, "All tensors must be on GPU!" # query device index and cuda stream device = torch.cuda.current_device() torch.cuda.set_device(device)