diff --git a/include/triton/driver/device.h b/include/triton/driver/device.h
index c39c768ca..c408415c9 100755
--- a/include/triton/driver/device.h
+++ b/include/triton/driver/device.h
@@ -67,6 +67,7 @@ public:
   size_t max_sm_clock() const;
   size_t max_mem_clock() const;
   void set_max_clock();
+  void enable_peer_access(CUdeviceptr peer_mem_ptr) const;
   // Target
   std::unique_ptr make_target() const;
diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h
index a6a5d3309..ad6574f44 100755
--- a/include/triton/driver/dispatch.h
+++ b/include/triton/driver/dispatch.h
@@ -108,8 +108,9 @@ public:
   static CUresult cuCtxGetDevice(CUdevice* result);
   static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
   static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
-  static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
-  static CUresult cuFuncSetCacheConfig (CUfunction hfunc, CUfunc_cache config);
+  static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
+  static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
+  static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
   // NVML
   static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
   static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
@@ -178,6 +179,7 @@ private:
   static void* cuFuncGetAttribute_;
   static void* cuFuncSetAttribute_;
   static void* cuFuncSetCacheConfig_;
+  static void* cuCtxEnablePeerAccess_;
   // NVML
   static void* nvmlInit_v2_;
   static void* nvmlDeviceGetHandleByPciBusId_v2_;
diff --git a/lib/driver/device.cc b/lib/driver/device.cc
index dc5912d2d..9128e3fe1 100755
--- a/lib/driver/device.cc
+++ b/lib/driver/device.cc
@@ -27,6 +27,7 @@
 #include
 #include "triton/driver/device.h"
 #include "triton/driver/context.h"
+#include "triton/driver/error.h"
 #include "triton/codegen/target.h"

 namespace triton
@@ -159,6 +160,14 @@ void cu_device::set_max_clock() {
   dispatch::nvmlDeviceSetApplicationsClocks(nvml_device(), max_mem_clock(), max_sm_clock());
 }

+void cu_device::enable_peer_access(CUdeviceptr peer_mem_ptr) const{
+  CUcontext context;
+  dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, peer_mem_ptr);
+  try {
+    dispatch::cuCtxEnablePeerAccess(context, 0);
+  } catch (exception::cuda::peer_access_already_enabled) {}
+}
+
 // print infos
 std::string cu_device::infos() const{
   std::ostringstream oss;
diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc
index d8350881b..f2a2c519f 100755
--- a/lib/driver/dispatch.cc
+++ b/lib/driver/dispatch.cc
@@ -178,6 +178,7 @@
 CUDA_DEFINE1(CUresult, cuCtxPopCurrent_v2, CUcontext*)
 CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
 CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
 CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
+CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
 NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
 NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
@@ -252,6 +253,7 @@
 void* dispatch::cuCtxPopCurrent_v2_;
 void* dispatch::cuFuncGetAttribute_;
 void* dispatch::cuFuncSetAttribute_;
 void* dispatch::cuFuncSetCacheConfig_;
+void* dispatch::cuCtxEnablePeerAccess_;
 void* dispatch::nvmlInit_v2_;
 void* dispatch::nvmlDeviceGetHandleByPciBusId_v2_;
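The new cu_device::enable_peer_access helper works at the context level: cuPointerGetAttribute with CU_POINTER_ATTRIBUTE_CONTEXT resolves the context that owns peer_mem_ptr, cuCtxEnablePeerAccess then grants the current context access to it, and the "peer access already enabled" error is swallowed so repeated calls are harmless. A minimal sketch of the same flow driven from Python through the enable_peer_access binding added below; this is illustrative only and assumes two visible GPUs, with device indices 0 and 1 as placeholders:

    import torch
    import triton._C.libtriton.triton as _triton

    torch.cuda.set_device(0)                   # the current context is the one that gains access
    _ = torch.empty(1, device="cuda:0")        # make sure PyTorch has created a context on cuda:0
    peer = torch.empty(1024, device="cuda:1")  # memory owned by another GPU's context

    dev = _triton.driver.cu_device(0, False)   # same wrapper code_gen.py constructs for the current device
    dev.enable_peer_access(peer.data_ptr())    # resolve the pointer's owning context and enable access to it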
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 63214fa9c..1e7fe255d 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -36,6 +36,9 @@ void init_triton_driver(py::module &&m) {
       }))
       .def("max_shared_memory", [](drv::cu_device *self) {
         return self->max_shared_memory();
+      })
+      .def("enable_peer_access", [](drv::cu_device *self, unsigned long long int peer_mem_ptr) {
+        self->enable_peer_access(peer_mem_ptr);
       });
   // host device
   py::class_(m, "host_device")
diff --git a/python/test/test_comm.py b/python/test/test_comm.py
new file mode 100644
index 000000000..ae843a15f
--- /dev/null
+++ b/python/test/test_comm.py
@@ -0,0 +1,96 @@
+import torch
+import triton
+import pytest
+import subprocess
+import triton.language as tl
+import numpy as np
+
+
+def get_p2p_matrix():
+    try:
+        stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
+    except subprocess.CalledProcessError:
+        return pytest.skip("No multi-GPU topology", allow_module_level=True)
+
+    lines = stdout.split("Legend")[0].split('\n')[1:]
+    matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
+    if matrix.size <= 1:
+        return pytest.skip("No multi-GPU topology", allow_module_level=True)
+    else:
+        return matrix
+
+
+def get_p2p_devices():
+    matrix = get_p2p_matrix()
+    idx = np.where(matrix == "OK")
+    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
+
+
+def get_non_p2p_devices():
+    matrix = get_p2p_matrix()
+    idx = np.where(matrix == "NS")
+    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
+
+
+p2p_devices = get_p2p_devices()
+non_p2p_devices = get_non_p2p_devices()
+
+
+@triton.jit
+def _copy(from_ptr, to_ptr, N, **meta):
+    pid = tl.program_id(0)
+    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
+    values = tl.load(from_ptr + offsets, mask=offsets < N)
+    tl.store(to_ptr + offsets, values, mask=offsets < N)
+
+
+@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
+@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
+                         [(device_kernel, device_from, device_to, stream_from, stream_to)
+                          for device_kernel in p2p_devices
+                          for device_from in p2p_devices
+                          for device_to in p2p_devices
+                          for stream_from in ['default', 'custom']
+                          for stream_to in ['default', 'custom']
+                          ])
+def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
+    if device_to == device_from:
+        return pytest.skip()
+
+    torch.cuda.set_device(device_kernel)
+    N = 512
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
+
+    with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
+        x_from = torch.randn(N, dtype=torch.float32, device=device_from)
+    with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
+        x_to = torch.empty(N, dtype=torch.float32, device=device_to)
+
+    _copy[grid](x_from, x_to, N, BLOCK=1024)
+    assert torch.allclose(x_from, x_to.to(device_from))
+
+
+@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
+@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
+                         [(device_kernel, device_from, device_to, stream_from, stream_to)
+                          for device_kernel in non_p2p_devices
+                          for device_from in non_p2p_devices
+                          for device_to in non_p2p_devices
+                          for stream_from in ['default', 'custom']
+                          for stream_to in ['default', 'custom']
+                          ])
+def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
+    if device_to == device_from:
+        return pytest.skip()
+
+    with pytest.raises(RuntimeError):
+        torch.cuda.set_device(device_kernel)
+        N = 512
+        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
+
+        with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
+            x_from = torch.randn(N, dtype=torch.float32, device=device_from)
+        with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
+            x_to = torch.empty(N, dtype=torch.float32, device=device_to)
+
+        _copy[grid](x_from, x_to, N, BLOCK=1024)
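test_comm.py discovers which device pairs support peer-to-peer access by parsing the output of "nvidia-smi topo -p2p n" ("OK" pairs) and which do not ("NS" pairs). For comparison only, and not what the test itself does, the same information can be probed through PyTorch's capability query; a small sketch:

    import itertools
    import torch

    def p2p_pairs():
        # Yield ordered pairs of visible CUDA devices that can map each other's memory.
        devices = range(torch.cuda.device_count())
        for a, b in itertools.permutations(devices, 2):
            if torch.cuda.can_device_access_peer(a, b):
                yield f"cuda:{a}", f"cuda:{b}"

    print(list(p2p_pairs()))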
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 0172d5220..8b765a5c8 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -1,17 +1,16 @@
-import inspect
-import struct
-import enum
-import types
-import torch
 import ast
 import builtins
-import tempfile
-from .tools.disasm import extract
-import triton._C.libtriton.triton as _triton
-import triton
+import inspect
+import struct
 import sys
+import tempfile
 import textwrap
-import collections
+
+import torch
+import triton
+import triton._C.libtriton.triton as _triton
+
+from .tools.disasm import extract


 class CodeGenerator(ast.NodeVisitor):
@@ -454,6 +453,7 @@ class CompilationError(Exception):
         self.message += '\n Error: ' + str(err)
         super().__init__(self.message)

+
 class OutOfResources(Exception):
     def __init__(self, required, limit, name):
         self.message = f'out of resource: {name}'\
@@ -530,8 +530,6 @@ class Kernel:
         self.fn = fn

     def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
-        # explicitly set device
-        torch.cuda.set_device(device.index)
         # create IR module
         context = _triton.ir.context()
         # get just-in-time proto-type of kernel
@@ -549,11 +547,10 @@ class Kernel:
             if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                 raise e
             raise CompilationError(self.fn.src, node, e)
-        tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
-        if shared_mem > tt_device.max_shared_memory():
-            raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, device, num_warps, num_stages, force_nc_cache)
+        if shared_mem > device.max_shared_memory():
+            raise OutOfResources(shared_mem, device.max_shared_memory(), "shared memory")
         return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)

     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
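_compile now receives the already-constructed driver device rather than a torch.device, so the shared-memory check queries that object directly and compilation no longer switches the active CUDA device as a side effect. A short sketch of the wrapper involved, assuming device index 0 (the False flag simply mirrors the call made in code_gen.py):

    import triton._C.libtriton.triton as _triton

    tt_device = _triton.driver.cu_device(0, False)   # wrap CUDA device 0, as code_gen.py does for the current device
    limit = tt_device.max_shared_memory()            # bound used for the OutOfResources check after codegen
    print(limit)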
@@ -562,14 +559,30 @@ class Kernel:
         if len(tensor_idxs) == 0:
             raise ValueError("No Tensor argument found.")
         invalid_args = []
+        device_ids = []
         for idx in tensor_idxs:
             curr = wargs[idx]
             if not curr.is_cuda:
-                invalid_args += [idx]
+                invalid_args.append(idx)
+            else:
+                device_ids.append(curr.device.index)
         if invalid_args:
             raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) + " Only CUDA is supported at the moment")
-        device = wargs[tensor_idxs[0]].device
+
+        device = torch.device('cuda', torch.cuda.current_device())
+        tt_device = _triton.driver.cu_device(device.index, False)
+        if len(set(device_ids)) != 1 or device_ids[0] != device.index:
+            # try to enable P2P communication
+            for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
+                if dst_idx != device.index:
+                    try:
+                        tt_device.enable_peer_access(wargs[arg_idx].data_ptr())
+                    except RuntimeError as e:
+                        raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
+                                           .format(device.index, dst_idx, str(e)))
+
+        # enqueue kernel on the current device
         torch.cuda.set_device(device.index)
         # attributes
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
@@ -586,7 +599,7 @@ class Kernel:
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes,
+                *wargs, device=tt_device, attributes=attributes,
                 num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                 constants=constants, **meta
             )
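With these changes, a kernel is always enqueued on the device reported by torch.cuda.current_device(), and any tensor argument that lives on a different GPU triggers a peer-access request before launch (or a RuntimeError if the pair cannot peer). An end-to-end sketch reusing the _copy kernel from the test above; it assumes two P2P-capable GPUs at indices 0 and 1:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _copy(from_ptr, to_ptr, N, **meta):
        pid = tl.program_id(0)
        offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        values = tl.load(from_ptr + offsets, mask=offsets < N)
        tl.store(to_ptr + offsets, values, mask=offsets < N)

    N = 512
    torch.cuda.set_device("cuda:0")                            # kernel is compiled for and launched on cuda:0
    x = torch.randn(N, dtype=torch.float32, device="cuda:1")   # source tensor lives on the peer GPU
    y = torch.empty(N, dtype=torch.float32, device="cuda:0")

    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
    _copy[grid](x, y, N, BLOCK=1024)                           # peer access to cuda:1 is requested automatically
    assert torch.allclose(x.to("cuda:0"), y)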