diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 699487232..cd919cc6c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -574,15 +574,15 @@ class Kernel: device = torch.device('cuda', torch.cuda.current_device()) device_idx = device.index - if len(set(device_ids)) != 1 or device_ids[0] != device_idx: - # try to enable P2P communication - for arg_idx, dst_idx in zip(tensor_idxs, device_ids): - if dst_idx != device_idx: - try: - _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr()) - except RuntimeError as e: - raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}" - .format(device_idx, dst_idx, str(e))) + # if len(set(device_ids)) != 1 or device_ids[0] != device_idx: + # # try to enable P2P communication + # for arg_idx, dst_idx in zip(tensor_idxs, device_ids): + # if dst_idx != device_idx: + # try: + # _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr()) + # except RuntimeError as e: + # raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}" + # .format(device_idx, dst_idx, str(e))) # enqueue kernel on the current device torch.cuda.set_device(device_idx)