diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index ae69f876c..70799861d 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -557,6 +557,14 @@ class Kernel: tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] if len(tensor_idxs) == 0: raise ValueError("No Tensor argument found.") + invalid_args = [] + for idx in tensor_idxs: + curr = wargs[idx] + if not curr.is_cuda: + invalid_args += [idx] + if invalid_args: + raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) + + " Only CUDA is supported at the moment") device = wargs[tensor_idxs[0]].device torch.cuda.set_device(device.index) # attributes