[DRIVER] Add CUDA P2P support (#209)
@@ -67,6 +67,7 @@ public:
   size_t max_sm_clock() const;
   size_t max_mem_clock() const;
   void set_max_clock();
+  void enable_peer_access(CUdeviceptr peer_mem_ptr) const;
   // Target
   std::unique_ptr<codegen::target> make_target() const;
 
@@ -108,8 +108,9 @@ public:
   static CUresult cuCtxGetDevice(CUdevice* result);
   static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
   static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
   static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
-  static CUresult cuFuncSetCacheConfig (CUfunction hfunc, CUfunc_cache config);
+  static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
+  static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
   // NVML
   static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
   static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
@@ -178,6 +179,7 @@ private:
   static void* cuFuncGetAttribute_;
   static void* cuFuncSetAttribute_;
   static void* cuFuncSetCacheConfig_;
+  static void* cuCtxEnablePeerAccess_;
   // NVML
   static void* nvmlInit_v2_;
   static void* nvmlDeviceGetHandleByPciBusId_v2_;
@@ -27,6 +27,7 @@
 #include <memory>
 #include "triton/driver/device.h"
 #include "triton/driver/context.h"
+#include "triton/driver/error.h"
 #include "triton/codegen/target.h"
 
 namespace triton
@@ -159,6 +160,14 @@ void cu_device::set_max_clock() {
   dispatch::nvmlDeviceSetApplicationsClocks(nvml_device(), max_mem_clock(), max_sm_clock());
 }
 
+void cu_device::enable_peer_access(CUdeviceptr peer_mem_ptr) const{
+  CUcontext context;
+  dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, peer_mem_ptr);
+  try {
+    dispatch::cuCtxEnablePeerAccess(context, 0);
+  } catch (exception::cuda::peer_access_already_enabled) {}
+}
+
 // print infos
 std::string cu_device::infos() const{
   std::ostringstream oss;
@@ -178,6 +178,7 @@ CUDA_DEFINE1(CUresult, cuCtxPopCurrent_v2, CUcontext*)
 CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
 CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
 CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
+CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
 
 NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
 NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
@@ -252,6 +253,7 @@ void* dispatch::cuCtxPopCurrent_v2_;
 void* dispatch::cuFuncGetAttribute_;
 void* dispatch::cuFuncSetAttribute_;
 void* dispatch::cuFuncSetCacheConfig_;
+void* dispatch::cuCtxEnablePeerAccess_;
 
 void* dispatch::nvmlInit_v2_;
 void* dispatch::nvmlDeviceGetHandleByPciBusId_v2_;
@@ -36,6 +36,9 @@ void init_triton_driver(py::module &&m) {
       }))
       .def("max_shared_memory", [](drv::cu_device *self) {
         return self->max_shared_memory();
+      })
+      .def("enable_peer_access", [](drv::cu_device *self, unsigned long long int peer_mem_ptr) {
+          self->enable_peer_access(peer_mem_ptr);
       });
   // host device
   py::class_<drv::host_device, drv::device>(m, "host_device")
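The binding above exposes the new driver call to Python. A minimal sketch of driving it by hand, assuming two visible CUDA devices and that the second `cu_device` constructor argument is an ownership flag (it is passed as `False` in code_gen.py below):

import torch
import triton._C.libtriton.triton as _triton

# Make sure both devices have a live context (torch initializes them lazily).
torch.cuda.set_device(0)
local = torch.empty(1, device="cuda:0")
peer = torch.empty(16, device="cuda:1")   # memory owned by device 1

# Driver handle for device 0, mirroring code_gen.py's cu_device(index, False).
dev = _triton.driver.cu_device(0, False)

# Enable cuda:0 -> cuda:1 peer access for the context that owns `peer`.
# Repeated calls are harmless: the C++ side swallows peer_access_already_enabled.
# A RuntimeError signals that no P2P path exists between the devices.
dev.enable_peer_access(peer.data_ptr())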
python/test/test_comm.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+import torch
+import triton
+import pytest
+import subprocess
+import triton.language as tl
+import numpy as np
+
+
+def get_p2p_matrix():
+    try:
+        stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
+    except subprocess.CalledProcessError:
+        return pytest.skip("No multi-GPU topology", allow_module_level=True)
+
+    lines = stdout.split("Legend")[0].split('\n')[1:]
+    matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
+    if matrix.size <= 1:
+        return pytest.skip("No multi-GPU topology", allow_module_level=True)
+    else:
+        return matrix
+
+
+def get_p2p_devices():
+    matrix = get_p2p_matrix()
+    idx = np.where(matrix == "OK")
+    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
+
+
+def get_non_p2p_devices():
+    matrix = get_p2p_matrix()
+    idx = np.where(matrix == "NS")
+    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
+
+
+p2p_devices = get_p2p_devices()
+non_p2p_devices = get_non_p2p_devices()
+
+
+@triton.jit
+def _copy(from_ptr, to_ptr, N, **meta):
+    pid = tl.program_id(0)
+    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
+    values = tl.load(from_ptr + offsets, mask=offsets < N)
+    tl.store(to_ptr + offsets, values, mask=offsets < N)
+
+
+@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
+@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
+                         [(device_kernel, device_from, device_to, stream_from, stream_to)
+                          for device_kernel in p2p_devices
+                          for device_from in p2p_devices
+                          for device_to in p2p_devices
+                          for stream_from in ['default', 'custom']
+                          for stream_to in ['default', 'custom']
+                          ])
+def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
+    if device_to == device_from:
+        return pytest.skip()
+
+    torch.cuda.set_device(device_kernel)
+    N = 512
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
+
+    with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
+        x_from = torch.randn(N, dtype=torch.float32, device=device_from)
+    with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
+        x_to = torch.empty(N, dtype=torch.float32, device=device_to)
+
+    _copy[grid](x_from, x_to, N, BLOCK=1024)
+    assert torch.allclose(x_from, x_to.to(device_from))
+
+
+@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
+@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
+                         [(device_kernel, device_from, device_to, stream_from, stream_to)
+                          for device_kernel in non_p2p_devices
+                          for device_from in non_p2p_devices
+                          for device_to in non_p2p_devices
+                          for stream_from in ['default', 'custom']
+                          for stream_to in ['default', 'custom']
+                          ])
+def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
+    if device_to == device_from:
+        return pytest.skip()
+
+    with pytest.raises(RuntimeError):
+        torch.cuda.set_device(device_kernel)
+        N = 512
+        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
+
+        with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
+            x_from = torch.randn(N, dtype=torch.float32, device=device_from)
+        with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
+            x_to = torch.empty(N, dtype=torch.float32, device=device_to)
+
+        _copy[grid](x_from, x_to, N, BLOCK=1024)
@@ -1,17 +1,16 @@
-import inspect
-import struct
-import enum
-import types
-import torch
 import ast
 import builtins
-import tempfile
-from .tools.disasm import extract
-import triton._C.libtriton.triton as _triton
-import triton
+import inspect
+import struct
 import sys
+import tempfile
 import textwrap
-import collections
+
+import torch
+import triton
+import triton._C.libtriton.triton as _triton
+
+from .tools.disasm import extract
 
 
 class CodeGenerator(ast.NodeVisitor):
@@ -454,6 +453,7 @@ class CompilationError(Exception):
         self.message += '\n Error: ' + str(err)
         super().__init__(self.message)
 
+
 class OutOfResources(Exception):
     def __init__(self, required, limit, name):
         self.message = f'out of resource: {name}'\
@@ -530,8 +530,6 @@ class Kernel:
         self.fn = fn
 
     def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
-        # explicitly set device
-        torch.cuda.set_device(device.index)
         # create IR module
         context = _triton.ir.context()
         # get just-in-time proto-type of kernel
@@ -549,11 +547,10 @@ class Kernel:
             if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                 raise e
             raise CompilationError(self.fn.src, node, e)
-        tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
-        if shared_mem > tt_device.max_shared_memory():
-            raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, device, num_warps, num_stages, force_nc_cache)
+        if shared_mem > device.max_shared_memory():
+            raise OutOfResources(shared_mem, device.max_shared_memory(), "shared memory")
         return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
 
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
@@ -562,14 +559,30 @@
         if len(tensor_idxs) == 0:
             raise ValueError("No Tensor argument found.")
         invalid_args = []
+        device_ids = []
         for idx in tensor_idxs:
             curr = wargs[idx]
             if not curr.is_cuda:
-                invalid_args += [idx]
+                invalid_args.append(idx)
+            else:
+                device_ids.append(curr.device.index)
         if invalid_args:
             raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
                              " Only CUDA is supported at the moment")
-        device = wargs[tensor_idxs[0]].device
+
+        device = torch.device('cuda', torch.cuda.current_device())
+        tt_device = _triton.driver.cu_device(device.index, False)
+        if len(set(device_ids)) != 1 or device_ids[0] != device.index:
+            # try to enable P2P communication
+            for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
+                if dst_idx != device.index:
+                    try:
+                        tt_device.enable_peer_access(wargs[arg_idx].data_ptr())
+                    except RuntimeError as e:
+                        raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
+                                           .format(device.index, dst_idx, str(e)))
+
+        # enqueue kernel on the current device
         torch.cuda.set_device(device.index)
         # attributes
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
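When no P2P path exists, `enable_peer_access` surfaces as a RuntimeError from `__call__`, which is exactly what `test_non_p2p` asserts. A hypothetical fallback helper, sketched around the `_copy` kernel from test_comm.py, that retries with an explicit staging copy:

import torch
import triton

# Hypothetical helper: run `_copy` on x_to's device and fall back to a staging
# copy if peer access between the two devices cannot be enabled.
def copy_with_fallback(_copy, x_from, x_to, N):
    torch.cuda.set_device(x_to.device.index)
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
    try:
        # The launcher calls enable_peer_access() for the off-device argument.
        _copy[grid](x_from, x_to, N, BLOCK=1024)
    except RuntimeError:
        # No P2P path: move x_from onto the kernel device first.
        _copy[grid](x_from.to(x_to.device), x_to, N, BLOCK=1024)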
@@ -586,7 +599,7 @@ class Kernel:
             if key not in cache:
                 # compile and cache configuration if necessary
                 cache[key] = self._compile(
-                    *wargs, device=device, attributes=attributes,
+                    *wargs, device=tt_device, attributes=attributes,
                     num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                     constants=constants, **meta
                 )
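Taken together, these changes let a kernel consume arguments that live on a different GPU than the one it is launched on. A minimal end-to-end sketch, assuming two P2P-capable devices and reusing the `_copy` kernel from test_comm.py:

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(from_ptr, to_ptr, N, **meta):
    pid = tl.program_id(0)
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    values = tl.load(from_ptr + offsets, mask=offsets < N)
    tl.store(to_ptr + offsets, values, mask=offsets < N)

N = 512
torch.cuda.set_device(0)                      # kernel runs on device 0
x_from = torch.randn(N, device="cuda:1")      # input lives on device 1
x_to = torch.empty(N, device="cuda:0")

# __call__ sees mixed device ids, enables cuda:0 -> cuda:1 peer access,
# then launches on the current device.
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
_copy[grid](x_from, x_to, N, BLOCK=1024)
assert torch.allclose(x_from.to("cuda:0"), x_to)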