[DRIVER] Add CUDA P2P support (#209)
This commit is contained in:
@@ -36,6 +36,9 @@ void init_triton_driver(py::module &&m) {
|
||||
}))
|
||||
.def("max_shared_memory", [](drv::cu_device *self) {
|
||||
return self->max_shared_memory();
|
||||
})
|
||||
.def("enable_peer_access", [](drv::cu_device *self, unsigned long long int peer_mem_ptr) {
|
||||
self->enable_peer_access(peer_mem_ptr);
|
||||
});
|
||||
// host device
|
||||
py::class_<drv::host_device, drv::device>(m, "host_device")
|
||||
|
96
python/test/test_comm.py
Normal file
96
python/test/test_comm.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
import triton
|
||||
import pytest
|
||||
import subprocess
|
||||
import triton.language as tl
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_p2p_matrix():
|
||||
try:
|
||||
stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
|
||||
except subprocess.CalledProcessError:
|
||||
return pytest.skip("No multi-GPU topology", allow_module_level=True)
|
||||
|
||||
lines = stdout.split("Legend")[0].split('\n')[1:]
|
||||
matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
|
||||
if matrix.size <= 1:
|
||||
return pytest.skip("No multi-GPU topology", allow_module_level=True)
|
||||
else:
|
||||
return matrix
|
||||
|
||||
|
||||
def get_p2p_devices():
|
||||
matrix = get_p2p_matrix()
|
||||
idx = np.where(matrix == "OK")
|
||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
||||
|
||||
|
||||
def get_non_p2p_devices():
|
||||
matrix = get_p2p_matrix()
|
||||
idx = np.where(matrix == "NS")
|
||||
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
|
||||
|
||||
|
||||
p2p_devices = get_p2p_devices()
|
||||
non_p2p_devices = get_non_p2p_devices()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _copy(from_ptr, to_ptr, N, **meta):
|
||||
pid = tl.program_id(0)
|
||||
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
|
||||
values = tl.load(from_ptr + offsets, mask=offsets < N)
|
||||
tl.store(to_ptr + offsets, values, mask=offsets < N)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
|
||||
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
|
||||
[(device_kernel, device_from, device_to, stream_from, stream_to)
|
||||
for device_kernel in p2p_devices
|
||||
for device_from in p2p_devices
|
||||
for device_to in p2p_devices
|
||||
for stream_from in ['default', 'custom']
|
||||
for stream_to in ['default', 'custom']
|
||||
])
|
||||
def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
|
||||
if device_to == device_from:
|
||||
return pytest.skip()
|
||||
|
||||
torch.cuda.set_device(device_kernel)
|
||||
N = 512
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
|
||||
|
||||
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
|
||||
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
|
||||
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
|
||||
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
|
||||
|
||||
_copy[grid](x_from, x_to, N, BLOCK=1024)
|
||||
assert torch.allclose(x_from, x_to.to(device_from))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
|
||||
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
|
||||
[(device_kernel, device_from, device_to, stream_from, stream_to)
|
||||
for device_kernel in non_p2p_devices
|
||||
for device_from in non_p2p_devices
|
||||
for device_to in non_p2p_devices
|
||||
for stream_from in ['default', 'custom']
|
||||
for stream_to in ['default', 'custom']
|
||||
])
|
||||
def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
|
||||
if device_to == device_from:
|
||||
return pytest.skip()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
torch.cuda.set_device(device_kernel)
|
||||
N = 512
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
|
||||
|
||||
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
|
||||
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
|
||||
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
|
||||
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
|
||||
|
||||
_copy[grid](x_from, x_to, N, BLOCK=1024)
|
@@ -1,17 +1,16 @@
|
||||
import inspect
|
||||
import struct
|
||||
import enum
|
||||
import types
|
||||
import torch
|
||||
import ast
|
||||
import builtins
|
||||
import tempfile
|
||||
from .tools.disasm import extract
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton
|
||||
import inspect
|
||||
import struct
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
@@ -454,6 +453,7 @@ class CompilationError(Exception):
|
||||
self.message += '\n Error: ' + str(err)
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}'\
|
||||
@@ -530,8 +530,6 @@ class Kernel:
|
||||
self.fn = fn
|
||||
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
|
||||
# explicitly set device
|
||||
torch.cuda.set_device(device.index)
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
@@ -549,11 +547,10 @@ class Kernel:
|
||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||
raise e
|
||||
raise CompilationError(self.fn.src, node, e)
|
||||
tt_device = _triton.driver.cu_device(device.index, False)
|
||||
# Compile to machine code
|
||||
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
|
||||
if shared_mem > tt_device.max_shared_memory():
|
||||
raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
|
||||
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, device, num_warps, num_stages, force_nc_cache)
|
||||
if shared_mem > device.max_shared_memory():
|
||||
raise OutOfResources(shared_mem, device.max_shared_memory(), "shared memory")
|
||||
return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
|
||||
|
||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
|
||||
@@ -562,14 +559,30 @@ class Kernel:
|
||||
if len(tensor_idxs) == 0:
|
||||
raise ValueError("No Tensor argument found.")
|
||||
invalid_args = []
|
||||
device_ids = []
|
||||
for idx in tensor_idxs:
|
||||
curr = wargs[idx]
|
||||
if not curr.is_cuda:
|
||||
invalid_args += [idx]
|
||||
invalid_args.append(idx)
|
||||
else:
|
||||
device_ids.append(curr.device.index)
|
||||
if invalid_args:
|
||||
raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
|
||||
" Only CUDA is supported at the moment")
|
||||
device = wargs[tensor_idxs[0]].device
|
||||
|
||||
device = torch.device('cuda', torch.cuda.current_device())
|
||||
tt_device = _triton.driver.cu_device(device.index, False)
|
||||
if len(set(device_ids)) != 1 or device_ids[0] != device.index:
|
||||
# try to enable P2P communication
|
||||
for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
|
||||
if dst_idx != device.index:
|
||||
try:
|
||||
tt_device.enable_peer_access(wargs[arg_idx].data_ptr())
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
|
||||
.format(device.index, dst_idx, str(e)))
|
||||
|
||||
# enqueue kernel on the current device
|
||||
torch.cuda.set_device(device.index)
|
||||
# attributes
|
||||
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
||||
@@ -586,7 +599,7 @@ class Kernel:
|
||||
if key not in cache:
|
||||
# compile and cache configuration if necessary
|
||||
cache[key] = self._compile(
|
||||
*wargs, device=device, attributes=attributes,
|
||||
*wargs, device=tt_device, attributes=attributes,
|
||||
num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
|
||||
constants=constants, **meta
|
||||
)
|
||||
|
Reference in New Issue
Block a user