[DRIVER] Add CUDA P2P support (#209)

Author: milesial
Date: 2021-08-21 06:00:54 +02:00
Committed by: GitHub
Parent: 6aa5720d75
Commit: 5b29da719d

7 changed files with 147 additions and 21 deletions

python/src/triton.cc

@@ -36,6 +36,9 @@ void init_triton_driver(py::module &&m) {
      }))
      .def("max_shared_memory", [](drv::cu_device *self) {
        return self->max_shared_memory();
-     });
+     })
+     .def("enable_peer_access", [](drv::cu_device *self, unsigned long long int peer_mem_ptr) {
+       self->enable_peer_access(peer_mem_ptr);
+     });
  // host device
  py::class_<drv::host_device, drv::device>(m, "host_device")
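
The new binding takes the raw CUDA pointer of a peer buffer and enables access to it from the bound device's context. A minimal sketch of driving it directly from Python (device indices illustrative; the launcher changes below normally call it for you):

    import torch
    import triton._C.libtriton.triton as _triton

    peer = torch.randn(512, device='cuda:1')   # buffer owned by GPU 1
    dev0 = _triton.driver.cu_device(0, False)  # device the kernel will run on
    dev0.enable_peer_access(peer.data_ptr())   # raises RuntimeError if P2P is unavailable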

python/test/test_comm.py (new file, 96 lines)

@@ -0,0 +1,96 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np
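

# Discover GPU pairs with and without P2P support by parsing the matrix from
# `nvidia-smi topo -p2p n`: "OK" marks a supported pair, "NS" an unsupported one.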
def get_p2p_matrix():
    try:
        stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
    except (subprocess.CalledProcessError, FileNotFoundError):
        # nvidia-smi missing or failing means there is no usable topology
        return pytest.skip("No multi-GPU topology", allow_module_level=True)
    lines = stdout.split("Legend")[0].split('\n')[1:]
    matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
    if matrix.size <= 1:
        return pytest.skip("No multi-GPU topology", allow_module_level=True)
    else:
        return matrix


def get_p2p_devices():
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "OK")
    if len(idx[0]) == 0:
        return []  # falsy, so the skipif guard below fires
    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"


def get_non_p2p_devices():
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "NS")
    if len(idx[0]) == 0:
        return []  # falsy, so the skipif guard below fires
    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"


p2p_devices = get_p2p_devices()
non_p2p_devices = get_non_p2p_devices()
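

# Generic element-wise copy kernel; the masked loads/stores make it safe when
# N is not a multiple of BLOCK.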
@triton.jit
def _copy(from_ptr, to_ptr, N, **meta):
    pid = tl.program_id(0)
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    values = tl.load(from_ptr + offsets, mask=offsets < N)
    tl.store(to_ptr + offsets, values, mask=offsets < N)


@pytest.mark.skipif(not p2p_devices, reason="No pair of devices with P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
                         [(device_kernel, device_from, device_to, stream_from, stream_to)
                          for device_kernel in p2p_devices
                          for device_from in p2p_devices
                          for device_to in p2p_devices
                          for stream_from in ['default', 'custom']
                          for stream_to in ['default', 'custom']])
def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
    if device_to == device_from:
        return pytest.skip()

    torch.cuda.set_device(device_kernel)
    N = 512
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)

    with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
        x_from = torch.randn(N, dtype=torch.float32, device=device_from)
    with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
        x_to = torch.empty(N, dtype=torch.float32, device=device_to)
        _copy[grid](x_from, x_to, N, BLOCK=1024)

    assert torch.allclose(x_from, x_to.to(device_from))


@pytest.mark.skipif(not non_p2p_devices, reason="No pair of devices without P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
                         [(device_kernel, device_from, device_to, stream_from, stream_to)
                          for device_kernel in non_p2p_devices
                          for device_from in non_p2p_devices
                          for device_to in non_p2p_devices
                          for stream_from in ['default', 'custom']
                          for stream_to in ['default', 'custom']])
def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
    if device_to == device_from:
        return pytest.skip()

    with pytest.raises(RuntimeError):
        torch.cuda.set_device(device_kernel)
        N = 512
        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)

        with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
            x_from = torch.randn(N, dtype=torch.float32, device=device_from)
        with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
            x_to = torch.empty(N, dtype=torch.float32, device=device_to)
            _copy[grid](x_from, x_to, N, BLOCK=1024)
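
For reference, get_p2p_matrix assumes `nvidia-smi topo -p2p n` prints a tab-separated matrix of roughly the following shape (illustrative; columns are space-aligned here, real rows are tab-separated with a trailing tab, which the [1:-1] slice accounts for, and the exact legend wording varies by driver version):

          GPU0  GPU1  GPU2  GPU3
    GPU0  X     OK    NS    NS
    GPU1  OK    X     NS    NS
    GPU2  NS    NS    X     OK
    GPU3  NS    NS    OK    X

    Legend:
      X  = Self
      OK = Supported
      NS = Not Supported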

python/triton/code_gen.py

@@ -1,17 +1,16 @@
-import inspect
-import struct
-import enum
-import types
-import torch
-import ast
-import builtins
-import tempfile
-from .tools.disasm import extract
-import triton._C.libtriton.triton as _triton
-import triton
+import inspect
+import struct
+import sys
+import tempfile
+import textwrap
+import collections
+import torch
+import triton
+import triton._C.libtriton.triton as _triton
+from .tools.disasm import extract

class CodeGenerator(ast.NodeVisitor):
@@ -454,6 +453,7 @@ class CompilationError(Exception):
        self.message += '\n Error: ' + str(err)
        super().__init__(self.message)


class OutOfResources(Exception):
    def __init__(self, required, limit, name):
        self.message = f'out of resource: {name}'\
@@ -530,8 +530,6 @@ class Kernel:
        self.fn = fn

    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
-       # explicitly set device
-       torch.cuda.set_device(device.index)
        # create IR module
        context = _triton.ir.context()
        # get just-in-time proto-type of kernel
@@ -549,11 +547,10 @@ class Kernel:
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.fn.src, node, e)
-       tt_device = _triton.driver.cu_device(device.index, False)
        # Compile to machine code
-       mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
-       if shared_mem > tt_device.max_shared_memory():
-           raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
+       mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, device, num_warps, num_stages, force_nc_cache)
+       if shared_mem > device.max_shared_memory():
+           raise OutOfResources(shared_mem, device.max_shared_memory(), "shared memory")
        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)

    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
@@ -562,14 +559,30 @@ class Kernel:
        if len(tensor_idxs) == 0:
            raise ValueError("No Tensor argument found.")
        invalid_args = []
+       device_ids = []
        for idx in tensor_idxs:
            curr = wargs[idx]
            if not curr.is_cuda:
-               invalid_args += [idx]
+               invalid_args.append(idx)
+           else:
+               device_ids.append(curr.device.index)
        if invalid_args:
            raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
                             " Only CUDA is supported at the moment")
-       device = wargs[tensor_idxs[0]].device
+       device = torch.device('cuda', torch.cuda.current_device())
+       tt_device = _triton.driver.cu_device(device.index, False)
+       if len(set(device_ids)) != 1 or device_ids[0] != device.index:
+           # try to enable P2P communication
+           for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
+               if dst_idx != device.index:
+                   try:
+                       tt_device.enable_peer_access(wargs[arg_idx].data_ptr())
+                   except RuntimeError as e:
+                       raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
+                                          .format(device.index, dst_idx, str(e)))
+       # enqueue kernel on the current device
+       torch.cuda.set_device(device.index)
        # attributes
        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
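
For intuition, a worked instance of the dispatch logic above (values illustrative):

    # two tensor args: one on cuda:0 (the current device), one on cuda:1
    device_ids = [0, 1]
    current_index = 0
    # len(set([0, 1])) != 1 -> True, so the launcher walks the tensor args and
    # calls tt_device.enable_peer_access() on the cuda:1 buffer before enqueueing
    needs_p2p = len(set(device_ids)) != 1 or device_ids[0] != current_index
    assert needs_p2p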
@@ -586,7 +599,7 @@ class Kernel:
        if key not in cache:
            # compile and cache configuration if necessary
            cache[key] = self._compile(
-               *wargs, device=device, attributes=attributes,
+               *wargs, device=tt_device, attributes=attributes,
                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                constants=constants, **meta
            )
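
Taken together, the launcher changes let a kernel consume tensors that live on a different GPU than the one it runs on, provided the pair supports P2P. A sketch of the resulting usage (kernel, sizes, and device indices are illustrative, and assume GPU 0 and GPU 1 report "OK" in `nvidia-smi topo -p2p n`):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _plus1(src, dst, N, **meta):
        offs = tl.program_id(0) * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        x = tl.load(src + offs, mask=offs < N)
        tl.store(dst + offs, x + 1, mask=offs < N)

    torch.cuda.set_device(0)               # kernel is enqueued on GPU 0
    x = torch.randn(512, device='cuda:1')  # input lives on GPU 1
    y = torch.empty(512, device='cuda:0')  # output on GPU 0
    grid = lambda meta: (triton.cdiv(512, meta['BLOCK']),)
    _plus1[grid](x, y, 512, BLOCK=1024)    # P2P access for x is enabled automatically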