[PYTHON] Improved error message for CPU (#167)
This commit is contained in:
@@ -557,6 +557,14 @@ class Kernel:
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
if len(tensor_idxs) == 0:
|
||||
raise ValueError("No Tensor argument found.")
|
||||
invalid_args = []
|
||||
for idx in tensor_idxs:
|
||||
curr = wargs[idx]
|
||||
if not curr.is_cuda:
|
||||
invalid_args += [idx]
|
||||
if invalid_args:
|
||||
raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
|
||||
" Only CUDA is supported at the moment")
|
||||
device = wargs[tensor_idxs[0]].device
|
||||
torch.cuda.set_device(device.index)
|
||||
# attributes
|
||||
|
Reference in New Issue
Block a user