[FRONTEND] Disable P2P (#312)
This commit is contained in:
@@ -574,15 +574,15 @@ class Kernel:
|
||||
|
||||
device = torch.device('cuda', torch.cuda.current_device())
|
||||
device_idx = device.index
|
||||
if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
|
||||
# try to enable P2P communication
|
||||
for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
|
||||
if dst_idx != device_idx:
|
||||
try:
|
||||
_triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
|
||||
.format(device_idx, dst_idx, str(e)))
|
||||
# if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
|
||||
# # try to enable P2P communication
|
||||
# for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
|
||||
# if dst_idx != device_idx:
|
||||
# try:
|
||||
# _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
|
||||
# except RuntimeError as e:
|
||||
# raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
|
||||
# .format(device_idx, dst_idx, str(e)))
|
||||
|
||||
# enqueue kernel on the current device
|
||||
torch.cuda.set_device(device_idx)
|
||||
|
Reference in New Issue
Block a user