[PYTHON] Improved error message for CPU (#167)

This commit is contained in:
Philippe Tillet
2021-07-30 09:47:27 -07:00
committed by GitHub
parent 2293afece7
commit a34c57402f

View File

@@ -557,6 +557,14 @@ class Kernel:
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
raise ValueError("No Tensor argument found.")
invalid_args = []
for idx in tensor_idxs:
curr = wargs[idx]
if not curr.is_cuda:
invalid_args += [idx]
if invalid_args:
raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
" Only CUDA is supported at the moment")
device = wargs[tensor_idxs[0]].device
torch.cuda.set_device(device.index)
# attributes