Merge branch 'master' into rcom52_fixes
python/test/unit/language/test_dequantize.py (new file, 261 lines)
@@ -0,0 +1,261 @@
# flake8: noqa: F821,F841

import random

import torch

import triton
import triton.language as tl


@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int8(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int4(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
    shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int2(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


def test_dequantize_int8() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 128, 4)
        else:
            size = random.randrange(132, 1024, 4)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int8 * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 128)
        grid = (1,)
        dequantize_kernel_int8[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int8[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int4() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size // 2,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input_int8_h1 = input_int8 >> 4
        input_int8_h0 = input_int8 & 15

        input_int4_val = torch.stack(
            (input_int8_h0, input_int8_h1), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int4_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)
        dequantize_kernel_int4[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int4[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int2() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int16)

        input_int8 = torch.randint(
            0, 256, (size // 4,), dtype=torch.uint8, device=device
        )
        input_int16 = input_int8.view(torch.int16)

        input_int8_q3 = input_int8 >> 6
        input_int8_q2 = (input_int8 >> 4) & 3
        input_int8_q1 = (input_int8 >> 2) & 3
        input_int8_q0 = input_int8 & 3

        input_int2_val = torch.stack(
            (input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int16))
        expected = (input_int2_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)

        dequantize_kernel_int2[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int2[grid](
            output,
            input_int16,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)
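For orientation, here is a minimal host-side sketch (not part of the test file; the helper name pack_int8_weights is hypothetical) of the packed buffer that dequantize_kernel_int8 consumes, mirroring what test_dequantize_int8 builds above: the fp16 scale and shift are bit-packed into the leading int32 word, and the uint8 payload follows, viewed as int32.

import torch


def pack_int8_weights(values_uint8, scale, shift):
    # hypothetical helper mirroring test_dequantize_int8 above
    # word 0: scale in the low 16 bits, shift in the high 16 bits (fp16 bit patterns)
    scale_shift = torch.tensor([scale, shift], dtype=torch.float16,
                               device=values_uint8.device).view(torch.int32)
    # remaining words: four packed uint8 values per int32 (length must be a multiple of 4)
    payload = values_uint8.view(torch.int32)
    return torch.cat((scale_shift, payload))

The kernel reads word 0 back, bit-casts the two halves to fp16, and computes value * scale + shift for every unpacked byte, which is exactly the expected tensor the test compares against.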
@@ -1,16 +1,16 @@
import torch
import triton
import triton.language as tl
import numpy as np
import pytest
import scipy.stats
import numpy as np
import torch

from numpy.random import Philox
import triton
import triton.language as tl

#####################################
## Reference Philox Implementation
# Reference Philox Implementation
#####################################


class PhiloxConfig:
    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -74,9 +74,8 @@ class CustomPhilox4x:
        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)

    def _raise_key(self, key):
        ret0 = key[0] + self._config.PHILOX_KEY_A
        ret1 = key[1] + self._config.PHILOX_KEY_B
        return np.array([ret0, ret1], dtype=self._dtype)
        pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
        return key + np.array(pk, dtype=self._dtype)

    def random_raw(self):
        counter = self._counter
@@ -104,18 +103,21 @@ class CustomPhilox(CustomPhilox4x):


#####################################
## Unit Tests
# Unit Tests
#####################################

BLOCK = 1024

# test generation of random uint32


@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in ['10', '4,53', '10000']\
                          for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
                         )
                         [(size, seed) for size in ['10', '4,53', '10000']
                          for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
                         )
def test_randint(size, seed, device='cuda'):
    size = list(map(int, size.split(',')))

    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -133,10 +135,12 @@ def test_randint(size, seed, device='cuda'):
    assert out_tri == out_ref

# test uniform PRNG


@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]\
                          for seed in [0, 42, 124, 54]]
                         )
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]]
                         )
def test_rand(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
@@ -148,13 +152,16 @@ def test_rand(size, seed, device='cuda'):
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert all((x >= 0) & (x <= 1))
    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01

# test normal PRNG


@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]\
                          for seed in [0, 42, 124, 54]]
                         )
                         [(size, seed) for size in [1000000]
                          for seed in [0, 42, 124, 54]]
                         )
def test_randn(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
@@ -168,3 +175,24 @@ def test_randn(size, seed, device='cuda'):
    kernel[grid](x, N, seed)
    assert abs(x.mean()) < 1e-2
    assert abs(x.std() - 1) < 1e-2


# tl.rand() should never produce >=1.0

def test_rand_limits():
    @triton.jit
    def kernel(input, output, n: tl.constexpr):
        idx = tl.arange(0, n)
        x = tl.load(input + idx)
        y = tl.random.uint32_to_uniform_float(x)
        tl.store(output + idx, y)

    min_max_int32 = torch.tensor([
        torch.iinfo(torch.int32).min,
        torch.iinfo(torch.int32).max,
    ], dtype=torch.int32, device='cuda')
    output = torch.empty(2, dtype=torch.float32, device='cuda')
    kernel[(1,)](min_max_int32, output, 2)

    assert output[0] == output[1]
    assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0
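For context on how these PRNG primitives are consumed in practice, a minimal sketch (illustrative only, not part of this diff; the kernel name is made up) of a kernel that fills a tensor with uniform samples. tl.rand takes a seed plus a per-element offset, so each element draws an independent value in [0, 1):

import torch
import triton
import triton.language as tl


@triton.jit
def _fill_uniform(X, N, seed, BLOCK: tl.constexpr):
    # one program covers one BLOCK-sized chunk; distinct offsets make the draws independent
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    r = tl.rand(seed, offsets)  # uniform samples in [0, 1)
    tl.store(X + offsets, r, mask=offsets < N)


x = torch.empty(1 << 20, device='cuda')
_fill_uniform[(triton.cdiv(x.numel(), 1024),)](x, x.numel(), 0, BLOCK=1024)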
@@ -1,6 +1,7 @@
import torch
import triton
import pytest
import torch

import triton


@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -9,76 +10,108 @@ import pytest
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    # set seed
    torch.random.manual_seed(0)
    seed = 0
    torch.manual_seed(seed)
    is_sdd = MODE == "sdd"
    is_dsd = MODE == "dsd"
    is_dds = MODE == "dds"
    do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
    do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
    # create op
    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
    c_shape = (Z, H, M, N)
    shape = {
        "sdd": (M, N),
        "dsd": (a.shape[2], a.shape[3]),
        "dds": (b.shape[2], b.shape[3]),
        "dsd": (a_shape[2], a_shape[3]),
        "dds": (b_shape[2], b_shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    layout[1, 2, :] = 0
    layout[1, :, 1] = 0
    # create data
    a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
    b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
    dc_ref, dc_tri = triton.testing.make_pair(c_shape)
    # compute [torch]
    dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
    a_ref = do_mask(a_ref) if is_dsd else a_ref
    b_ref = do_mask(b_ref) if is_dds else b_ref
    a_ref.retain_grad()
    b_ref.retain_grad()
    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
    c_ref.backward(dc_ref)
    c_ref = do_sparsify(c_ref) if is_sdd else c_ref
    da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
    db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
    # triton result
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
    rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
    # torch result
    ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
    tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
    dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
    b_tri = do_sparsify(b_tri) if is_dds else b_tri
    a_tri.retain_grad()
    b_tri.retain_grad()
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
    c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
    triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
    da_tri = a_tri.grad
    db_tri = b_tri.grad
    # compare
    triton.testing.assert_almost_equal(rc, tc)
    triton.testing.assert_almost_equal(c_ref, c_tri)
    triton.testing.assert_almost_equal(da_ref, da_tri)
    triton.testing.assert_almost_equal(db_ref, db_tri)


@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
def test_softmax(BLOCK, WIDTH, DTYPE):
    is_causal = True
configs = [
    (16, 256),
    (32, 576),
    (64, 1871),
    (128, 2511),
]


@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 1, 1, WIDTH, WIDTH
    scale = 0.4
    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
    Z, H, M, N = 2, 3, WIDTH, WIDTH
    # initialize layout
    # make sure each row has at least one non-zero element
    torch.diagonal(layout)[:] = 1
    torch.diagonal(at_mask)[:] = 1
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
    kp_mask[:] = 0
    kp_mask[kp_mask == 1.0] = float("-inf")
    # triton result
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
    tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
    ty = op(
        tx,
        scale=scale,
        key_padding_mask=kp_mask,
        key_padding_mask_mode="add",
        attn_mask=at_mask.to(DTYPE),
        attn_mask_mode="mul",
        is_causal=is_causal,
    )
    # torch result
    rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
    # broadcast at_mask to the same shape as rx
    if is_causal: at_mask = torch.tril(at_mask)
    M = at_mask[None, None, :, :] + torch.zeros_like(rx)
    rx[M == 0] = float("-inf")
    # rx += kp_mask[:, None, None, :]
    ry = torch.softmax(rx * scale, -1)
    ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    if is_dense:
        layout[:] = 1
    else:
        layout[1, 2, :] = 0
        layout[1, :, 1] = 0
    # initialize data
    a_shape = (Z, H, M, N)
    a_ref, a_tri = triton.testing.make_pair(a_shape)
    dout_ref, dout_tri = triton.testing.make_pair(a_shape)
    # compute [torch]
    a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
    a_ref.retain_grad()
    at_mask = torch.ones((M, N), device="cuda")
    if is_causal:
        at_mask = torch.tril(at_mask)
    M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
    a_ref[M == 0] = float("-inf")
    out_ref = torch.softmax(a_ref * scale, -1)
    out_ref.backward(dout_ref)
    out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
    da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
    # compute [triton]
    a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
    a_tri.retain_grad()
    dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
    op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
    out_tri = op(a_tri, scale=scale, is_causal=is_causal)
    out_tri.backward(dout_tri)
    da_tri = a_tri.grad
    # compare
    triton.testing.assert_almost_equal(ry, ty)
    triton.testing.assert_almost_equal(out_tri, out_ref)
    triton.testing.assert_almost_equal(da_tri, da_ref)


@pytest.mark.parametrize("block", [16, 32, 64])
@@ -97,14 +130,6 @@ def test_attention_fwd_bwd(
    qkvs = [
        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
    ]
    attn_mask = torch.tril(
        torch.ones(
            [n_ctx, n_ctx],
            device="cuda",
            dtype=dtype,
        ),
        diagonal=0,
    )

    # Triton:
    n_blocks = n_ctx // block
@@ -113,7 +138,7 @@ def test_attention_fwd_bwd(
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()
    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
    attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
    # ad hoc loss
    loss = (attn_out ** 2).mean()
    loss.backward()
@@ -121,6 +146,8 @@ def test_attention_fwd_bwd(

    # Torch version:
    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
    attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
    attn_mask = torch.tril(attn_mask, diagonal=0)
    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
    torch_q.retain_grad()
    torch_k.retain_grad()
@@ -145,20 +172,16 @@ def test_attention_fwd_bwd(
def triton_attention(
    layout,
    block: int,
    attn_mask: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
):
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
    sparse_softmax = triton.ops.blocksparse.softmax(
        layout,
        block,
    )
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
    sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)

    w = sparse_dot_sdd_nt(query, key)
    w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul")
    w = sparse_softmax(w, scale=scale, is_causal=True)
    a = sparse_dot_dsd_nn(w, value)
    return a
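For readers new to these ops, a minimal usage sketch (illustrative, not part of this diff) of the construction pattern exercised above: a per-head binary layout of shape (H, M // block, N // block) declares which blocks exist, the op object is built once for that layout, and it is then applied like an ordinary matmul or softmax. A dense all-ones layout is used here to keep the sketch trivially valid; the tests above use random layouts.

import torch
import triton

Z, H, M, N, K, block = 2, 2, 256, 256, 64, 32
layout = torch.ones(H, M // block, N // block, dtype=torch.int64)

# "sdd": dense inputs, block-sparse output restricted to the layout's blocks
dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd",
                                           trans_a=False, trans_b=True, device="cuda")
softmax = triton.ops.blocksparse.softmax(layout, block, device="cuda")

q = torch.randn((Z, H, M, K), dtype=torch.float16, device="cuda")
k = torch.randn((Z, H, N, K), dtype=torch.float16, device="cuda")
w = dot_sdd_nt(q, k)                       # block-sparse attention scores
w = softmax(w, scale=0.5, is_causal=True)  # softmax over the stored blocks only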
@@ -1,17 +1,23 @@
import torch
import triton
import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize("M, N, dtype, mode",
                         [
                             (M, N, dtype, mode) for M in [1024, 821]
                             for N in [512, 857, 1871, 2089, 8573, 31000]
                             for dtype in ['float16', 'float32']\
                             for mode in ['forward', 'backward']
                         ]
                         [
                             (M, N, dtype, mode) for M in [1024, 821]
                             for N in [512, 857, 1871, 2089, 8573, 31000]
                             for dtype in ['bfloat16', 'float16', 'float32']
                             for mode in ['forward', 'backward']
                         ]
                         )
def test_op(M, N, dtype, mode):
    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 80 and dtype == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
@@ -30,4 +36,4 @@ def test_op(M, N, dtype, mode):
    x.grad.zero_()
    th_y.backward(dy)
    th_dx = x.grad.clone()
    triton.testing.assert_almost_equal(th_dx, tt_dx)
    triton.testing.assert_almost_equal(th_dx, tt_dx)
@@ -1,8 +1,11 @@
import pytest
import itertools
import triton

import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize(
    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
@@ -46,7 +49,7 @@ import torch
            (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
        ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
        ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
    ],
    # n-stage
    *[
@@ -59,31 +62,36 @@ import torch
            # split-k
            (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
            (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
        ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
        ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
    ]
    ),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 80 and DTYPE == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    if DTYPE == "bfloat16" and SPLIT_K != 1:
        pytest.skip("bfloat16 matmuls don't allow split_k for now")
    torch.manual_seed(0)
    # nuke kernel decorators -- will set meta-parameters manually
    META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
    configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
    kernel = triton.ops._matmul.kernel
    decorators = kernel.kernel_decorators
    kernel.kernel_decorators = []
    triton.autotune(configs, [])(kernel)
    kernel.kernel_decorators += decorators[1:]
    kernel.configs = configs
    # kernel.run = kernel.run.run.run

    # get matrix shape
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K * SPLIT_K if K is None else K
    # allocate/transpose inputs
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
    a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
    a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    # run test
    th_c = torch.matmul(a, b)
    tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
    triton.testing.assert_almost_equal(th_c, tt_c)
    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
    triton.testing.assert_almost_equal(th_c, tt_c)
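The Config objects built by hand above are the same ones normally attached with triton.autotune; a minimal sketch (illustrative, with a hypothetical kernel and parameter names) of that pattern, where the pre_hook zeroes the output buffer for a split-K config so the kernel can accumulate into it, as the test does:

import triton
import triton.language as tl

configs = [
    triton.Config(kwargs={'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1},
                  num_warps=4, num_stages=2),
    triton.Config(kwargs={'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 8},
                  num_warps=4, num_stages=3,
                  pre_hook=lambda nargs: nargs['C'].zero_()),  # split-K accumulates into C
]


@triton.autotune(configs, ['M', 'N', 'K'])  # second argument: autotuning key
@triton.jit
def _my_matmul_kernel(A, B, C, M, N, K,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                      BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr):
    pass  # body omitted; only the decorator wiring is the point of this sketch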
python/test/unit/runtime/test_cache.py (new file, 206 lines)
@@ -0,0 +1,206 @@
import multiprocessing
import os
import re
import shutil
from collections import namedtuple

import pytest
import torch

import triton
import triton.language as tl
from triton.runtime.jit import JITFunction

tmpdir = ".tmp"


@triton.jit
def function_1(i):
    i = i + 1
    i = function_2(i)
    return i


@triton.jit
def function_2(i):
    i = i + 1
    return i


@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)


@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)


def apply_src_change(target, old, new):
    kernel.hash = None
    function_1.hash = None
    function_2.hash = None
    function_1.src = function_1.src.replace(old, new)
    target.src = target.src.replace(old, new)
    ret = target.cache_key
    target.src = target.src.replace(new, old)
    return ret


def test_nochange():
    baseline = kernel.cache_key
    updated = apply_src_change(kernel, 'i + 1', 'i + 1')
    assert baseline == updated


def test_toplevel_change():
    baseline = kernel.cache_key
    updated = apply_src_change(kernel, 'i + 1', 'i + 2')
    assert baseline != updated


def test_nested1_change():
    baseline = kernel.cache_key
    updated = apply_src_change(function_1, 'i + 1', 'i + 2')
    assert baseline != updated


def reset_tmp_dir():
    os.environ["TRITON_CACHE_DIR"] = tmpdir
    if os.path.exists(tmpdir):
        shutil.rmtree(tmpdir)


def test_reuse():
    counter = 0

    def inc_counter(*args, **kwargs):
        nonlocal counter
        counter += 1
    JITFunction.cache_hook = inc_counter
    reset_tmp_dir()
    x = torch.empty(1, dtype=torch.int32, device='cuda')
    for i in range(10):
        kernel[(1,)](x, 1, BLOCK=1024)
    assert counter == 1


@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
    counter = 0

    def inc_counter(*args, **kwargs):
        nonlocal counter
        counter += 1
    JITFunction.cache_hook = inc_counter
    reset_tmp_dir()
    x = torch.empty(1, dtype=torch.int32, device='cuda')
    function = {'enable': kernel, 'disable': kernel_nospec}[mode]
    target = {'enable': 3, 'disable': 1}[mode]
    for i in [1, 2, 4, 8, 16, 32]:
        function[(1,)](x, i, BLOCK=512)
    assert counter == target


@pytest.mark.parametrize("value, value_type", [
    (-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
    (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

    @triton.jit
    def kernel(VALUE, X):
        pass

    cache_str = None

    def get_cache_str(*args, **kwargs):
        nonlocal cache_str
        cache_str = kwargs["repr"]
    triton.JITFunction.cache_hook = get_cache_str
    reset_tmp_dir()
    x = torch.tensor([3.14159], device='cuda')
    kernel[(1, )](value, x)
    triton.JITFunction.cache_hook = None

    cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
    spec_type = None if cache_str_match is None else cache_str_match.group(1)
    assert spec_type == value_type


def test_constexpr_not_callable() -> None:
    @triton.jit
    def kernel(X, c: tl.constexpr):
        tl.store(X, 2)

    x = torch.empty(1, dtype=torch.int32, device='cuda')
    error = False
    try:
        kernel[(1, )](x, c="str")
    except BaseException:
        error = True
    assert error is False
    # try and catch
    try:
        kernel[(1, )](x, c=tl.abs)
    except BaseException:
        error = True
    assert error is True


def test_jit_warmup_cache() -> None:
    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx,
                 tl.load(a + idx) + tl.load(b + idx))

    args = [
        torch.randn(32, dtype=torch.float32, device="cuda"),
        torch.randn(32, dtype=torch.float32, device="cuda"),
        torch.randn(32, dtype=torch.float32, device="cuda"),
        32,
    ]
    assert len(kernel_add.cache) == 0
    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
    assert len(kernel_add.cache) == 1
    kernel_add.warmup(*args, grid=(1,))
    assert len(kernel_add.cache) == 1
    kernel_add.warmup(*args, grid=(1,))
    assert len(kernel_add.cache) == 1


def test_compile_in_subproc() -> None:
    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx,
                 tl.load(a + idx) - tl.load(b + idx) * 777)

    major, minor = torch.cuda.get_device_capability(0)
    cc = major * 10 + minor
    config = namedtuple("instance_descriptor", [
        "divisible_by_16", "equal_to_1"])(
        tuple(range(4)),
        ())

    proc = multiprocessing.Process(
        target=triton.compile,
        kwargs=dict(
            fn=kernel_sub,
            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
            device=0,
            constants={3: 32},
            configs=[config],
            warm_cache_only=True,
            cc=cc,
        ))
    proc.start()
    proc.join()
    assert proc.exitcode == 0
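Outside of the tests, the same cache_hook mechanism can be used to observe compilations, for example to log every cache miss; a minimal sketch (illustrative) built only on the hook behaviour the tests above rely on: it fires when a kernel is actually compiled rather than served from cache, and it receives the specialization string under "repr".

from triton.runtime.jit import JITFunction


def log_compilation(*args, **kwargs):
    # "repr" is the same specialization string that test_value_specialization parses above
    print("triton compiled:", kwargs.get("repr"))


JITFunction.cache_hook = log_compilation
# ... launch kernels as usual; repeated launches of an already-cached kernel stay silent ...
JITFunction.cache_hook = None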
@@ -1,96 +0,0 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np


def get_p2p_matrix():
    try:
        stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
    except subprocess.CalledProcessError:
        return pytest.skip("No multi-GPU topology", allow_module_level=True)

    lines = stdout.split("Legend")[0].split('\n')[1:]
    matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
    if matrix.size <= 1:
        return pytest.skip("No multi-GPU topology", allow_module_level=True)
    else:
        return matrix


def get_p2p_devices():
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "OK")
    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []


def get_non_p2p_devices():
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "NS")
    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []


p2p_devices = get_p2p_devices()
non_p2p_devices = get_non_p2p_devices()


@triton.jit
def _copy(from_ptr, to_ptr, N, **meta):
    pid = tl.program_id(0)
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    values = tl.load(from_ptr + offsets, mask=offsets < N)
    tl.store(to_ptr + offsets, values, mask=offsets < N)


@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
                         [(device_kernel, device_from, device_to, stream_from, stream_to)
                          for device_kernel in p2p_devices
                          for device_from in p2p_devices
                          for device_to in p2p_devices
                          for stream_from in ['default', 'custom']
                          for stream_to in ['default', 'custom']
                          ])
def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
    if device_to == device_from:
        return pytest.skip()

    torch.cuda.set_device(device_kernel)
    N = 512
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)

    with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
        x_from = torch.randn(N, dtype=torch.float32, device=device_from)
    with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
        x_to = torch.empty(N, dtype=torch.float32, device=device_to)

    _copy[grid](x_from, x_to, N, BLOCK=1024)
    assert torch.allclose(x_from, x_to.to(device_from))


@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
                         [(device_kernel, device_from, device_to, stream_from, stream_to)
                          for device_kernel in non_p2p_devices
                          for device_from in non_p2p_devices
                          for device_to in non_p2p_devices
                          for stream_from in ['default', 'custom']
                          for stream_to in ['default', 'custom']
                          ])
def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
    if device_to == device_from:
        return pytest.skip()

    with pytest.raises(RuntimeError):
        torch.cuda.set_device(device_kernel)
        N = 512
        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)

        with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
            x_from = torch.randn(N, dtype=torch.float32, device=device_from)
        with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
            x_to = torch.empty(N, dtype=torch.float32, device=device_to)

        _copy[grid](x_from, x_to, N, BLOCK=1024)