[FRONTEND] Add an assert in case we get a CPU tensor. (#478)
This commit is contained in:
@@ -689,6 +689,10 @@ class Kernel:
|
||||
# handle annotations
|
||||
for pos, _type in self.fn.annotations.items():
|
||||
wargs[pos] = _type(wargs[pos])
|
||||
# check that tensors are on GPU.
|
||||
for arg in wargs:
|
||||
if hasattr(arg, 'data_ptr'):
|
||||
assert arg.is_cuda, "All tensors must be on GPU!"
|
||||
# query device index and cuda stream
|
||||
device = torch.cuda.current_device()
|
||||
torch.cuda.set_device(device)
|
||||
|
Reference in New Issue
Block a user