[LANG] Fixed semantics of NaN in float comparisons (#281)
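Context for the change: the semantics being pinned down are the IEEE-754 rules for comparisons involving NaN. Every ordered comparison (<, <=, >, >=) and equality (==) with a NaN operand evaluates to false, while != evaluates to true. The new test_compare_op cases in test_core.py below exercise exactly this by filling one or both inputs with NaN. A minimal CPU-side sketch of the expected reference behaviour (plain PyTorch, no Triton kernel involved; the tensor values are illustrative only):

import torch

x = torch.tensor([float('nan'), 1.0])
y = torch.tensor([2.0, float('nan')])
print(x == y)  # tensor([False, False]) -- NaN never compares equal
print(x != y)  # tensor([True, True])   -- the only comparison that is True with a NaN operand
print(x < y)   # tensor([False, False]) -- ordered comparisons with NaN are False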
python/test/unit/language/test_core.py (new file, 547 lines)
@@ -0,0 +1,547 @@
import torch
import triton
import triton.language as tl
import copy
import pytest
import ast
import itertools

torch.manual_seed(0)

# convert from string to torch.dtype
# Necessary because doesn't print torch.dtype properly
cvt = {
    'bool': torch.bool,
    'int8': torch.int8,
    'int16': torch.int16,
    'int32': torch.int32,
    'int64': torch.int64,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
}

int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes


def patch_kernel(template, to_replace):
    kernel = copy.deepcopy(template)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel


@pytest.mark.parametrize("dtype_x", [
    (dtype_x) for dtype_x in dtypes
])
def test_empty_kernel(dtype_x, device='cuda'):
    SIZE = 128
    @triton.jit
    def kernel(X, **meta):
        pass
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    kernel[(1, )](x, SIZE=SIZE, num_warps=4)

# generic test functions
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
    SIZE = 128
    # define the kernel / launch-grid
    @triton.jit
    def kernel(Z, X, **meta):
        off = tl.arange(0, meta['SIZE'])
        x = tl.load(X + off)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    if 'log' in expr: x = torch.abs(x) + 0.01
    # reference result
    z_ref = eval(expr if torch_expr is None else torch_expr)
    # triton result
    z_tri = torch.empty_like(z_ref)
    kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
    # compare
    triton.testing.assert_almost_equal(z_ref, z_tri)


def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
    SIZE = 128
    # define the kernel / launch-grid
    @triton.jit
    def kernel(Z, X, Y, **meta):
        off = tl.arange(0, meta['SIZE'])
        x = tl.load(X + off)
        y = tl.load(Y + off)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
    if mode_x == 'nan': x[:] = float('nan')
    if mode_y == 'nan': y[:] = float('nan')
    # reference result
    z_ref = eval(expr)
    # triton result
    z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
    kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
    # compare
    triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)


# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
  for op in ['+', '-', '*', '/', '%'] \
  for dtype_x in dtypes \
  for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
    _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
  for op in ['&', '|', '^'] \
  for dtype_x in dtypes \
  for dtype_y in dtypes
])
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
    if 'float' in dtype_x + dtype_y:
        with pytest.raises(RuntimeError):
            _test_binary(dtype_x, dtype_y, expr, device=device)
    else:
        _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
# real
[
    (dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
    for op in ops \
    for dtype_x in dtypes \
    for dtype_y in dtypes
] + \
# NaNs
[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
    for op in ops
    for mode_x, mode_y in [('nan' , 'real'),
                           ('real', 'nan'),
                           ('nan' , 'nan')]

])
def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
    _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)


# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
    (dtype_x, f' -x') for dtype_x in float_dtypes
] + [\
    (dtype_x, f' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
    _test_unary(dtype_x, expr, device=device)

# ----------------
# test math ops
# ----------------
# @pytest.mark.paramterize("expr", [
#     'exp', 'log', 'cos', 'sin'
# ])

@pytest.mark.parametrize("expr", [
    'exp', 'log', 'cos', 'sin'
])
def test_math_op(expr, device='cuda'):
    _test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)


# ----------------
# test indexing
# ----------------


def make_ptr_str(name, shape):
    rank = len(shape)
    offsets = []
    stride = 1
    for i in reversed(range(rank)):
        idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
        offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
        stride *= shape[i]
    return f"{name} + {' + '.join(offsets)}"


@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
    ['None, :', ':, None',\
     'None, :, :', ':, :, None']\
])
def test_index1d(expr, device='cuda'):
    dtype = torch.int32
    rank_x = expr.count(':')
    rank_y = expr.count(',') + 1
    shape_x = [32 for _ in range(rank_x)]
    shape_z = [32 for _ in range(rank_y)]

    # Triton kernel
    @triton.jit
    def kernel(Z, X, **meta):
        SIZE = meta['SIZE']
        m = tl.arange(0, SIZE)
        n = tl.arange(0, SIZE)
        x = tl.load(X_PTR_EXPR)
        z = GENERATE_TEST_HERE
        tl.store(Z_PTR_EXPR, z)

    to_replace = {
        'X_PTR_EXPR': make_ptr_str('X', shape_x),
        'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
        'GENERATE_TEST_HERE': expr,
    }
    kernel = patch_kernel(kernel, to_replace)

    # torch result
    x = triton.testing.random(shape_x, dtype=dtype, device=device)
    y = torch.zeros(shape_z, dtype=dtype, device=device)
    z_ref = eval(expr) + y
    # triton result
    z_tri = torch.empty_like(z_ref)
    kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
    # compare
    triton.testing.assert_almost_equal(z_ref, z_tri)


# ---------------
# test tuples
# ---------------


@triton.jit
def fn(a, b):
    return a + b, \
           a - b, \
           a * b


def test_tuples():
    device = 'cuda'

    @triton.jit
    def with_fn(X, Y, A, B, C):
        x = tl.load(X)
        y = tl.load(Y)
        a, b, c = fn(x, y)
        tl.store(A, a)
        tl.store(B, b)
        tl.store(C, c)

    @triton.jit
    def without_fn(X, Y, A, B, C):
        x = tl.load(X)
        y = tl.load(Y)
        a, b, c = x + y, x - y, x * y
        tl.store(A, a)
        tl.store(B, b)
        tl.store(C, c)

    x = torch.tensor([1.3], device=device, dtype=torch.float32)
    y = torch.tensor([1.9], device=device, dtype=torch.float32)
    a_tri = torch.tensor([0], device=device, dtype=torch.float32)
    b_tri = torch.tensor([0], device=device, dtype=torch.float32)
    c_tri = torch.tensor([0], device=device, dtype=torch.float32)
    for kernel in [with_fn, without_fn]:
        kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
        a_ref, b_ref, c_ref = x + y, x - y, x * y
        assert a_tri == a_ref
        assert b_tri == b_ref
        assert c_tri == c_ref


# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
    [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
     ('max', 'int32', mode), ('max', 'float32', mode),\
     ('min', 'int32', mode), ('min', 'float32', mode),\
    ]
    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
    dtype_x = cvt[dtype_x]
    n_programs = 5

    # triton kernel
    @triton.jit
    def kernel(X, Z, **meta):
        pid = tl.program_id(0)
        x = tl.load(X + pid)
        old = GENERATE_TEST_HERE

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
    torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
    max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
    min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

    # triton result
    x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
    if mode == 'all_neg':
        x_tri = -torch.abs(x_tri)
    if mode == 'all_pos':
        x_tri = torch.abs(x_tri)
    if mode == 'min_neg':
        idx = torch.randint(n_programs, size=(1, )).item()
        x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
    if mode == 'max_pos':
        idx = torch.randint(n_programs, size=(1, )).item()
        x_tri[idx] = torch.max(torch.abs(x_tri)) + 1

    z_tri = torch.empty([], dtype=dtype_x, device=device)
    z_tri.fill_(neutral)
    kernel[(n_programs, )](x_tri, z_tri)
    # torch result
    z_ref = torch_op(x_tri).to(dtype_x)
    # compare
    exact = op not in ['add']
    if exact:
        assert z_ref.item() == z_tri.item()
    else:
        triton.testing.assert_almost_equal(z_ref, z_tri)


# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
    (dtype_x, dtype_z, False) \
                        for dtype_x in dtypes\
                        for dtype_z in dtypes
] + [
    ('float32', 'bfloat16', False),
    ('bfloat16', 'float32', False),
    ('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
    x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)

    # triton kernel
    @triton.jit
    def kernel(X, Z, **meta):
        x = tl.load(X)
        z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
        tl.store(Z, z)

    # triton result
    z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
    kernel[(1, )](x, z_tri, BITCAST=bitcast)
    # torch result
    if bitcast:
        import numpy as np
        z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
        z_ref = torch.from_numpy(z_ref).to(device)
    else:
        z_ref = x.to(z_tri.dtype)
    assert z_tri == z_ref

# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype, shape",
                         [(dtype, shape) \
                            for dtype in dtypes\
                            for shape in [128, 512]])
def test_reduce1d(dtype, shape, device='cuda'):
    dtype = cvt[dtype]

    # triton kernel
    @triton.jit
    def kernel(X, Z, **meta):
        x = tl.load(X + tl.arange(0, meta['BLOCK']))
        tl.store(Z, tl.sum(x, axis=0))

    x = triton.testing.random((shape,), dtype=dtype, device=device)
    # triton result
    z_tri = triton.testing.random((1,), dtype=dtype, device=device)
    kernel[(1,)](x, z_tri, BLOCK=shape)
    # torch result
    z_ref = torch.sum(x).to(dtype)
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)


@pytest.mark.parametrize("dtype, shape, axis",
                         [(dtype, shape, 1) \
                            for dtype in ['float32']\
                            for shape in [(1, 1024)]])
def test_reduce2d(dtype, shape, axis, device='cuda'):
    dtype = cvt[dtype]
    # triton kernel
    @triton.jit
    def kernel(X, Z, **meta):
        range_m = tl.arange(0, meta['BLOCK_M'])
        range_n = tl.arange(0, meta['BLOCK_N'])
        x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
        z = tl.sum(x, axis=meta['AXIS'])
        tl.store(Z + range_m, z)
    # input
    x = triton.testing.random(shape, dtype=dtype, device=device)
    # triton result
    z_tri = torch.empty((shape[0],), dtype=dtype, device=device)
    kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
    # torch result
    z_ref = torch.sum(x, axis=axis).to(dtype)
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)

# ---------------
# test permute
# ---------------

# ---------------
# test permute
# ---------------

@pytest.mark.parametrize("dtype, shape, perm",
                         [(dtype, shape, perm) \
                            for dtype in ['float32']\
                                for shape in [(128, 128)]\
                                    for perm in [(1, 0)]])
def test_permute(dtype, shape, perm, device='cuda'):
    dtype = cvt[dtype]
    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xn,
               Z, stride_zm, stride_zn, **meta):
        BLOCK_M = meta['BLOCK_M']
        BLOCK_N = meta['BLOCK_N']
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
        tl.store(Zs, tl.load(Xs))
    # input
    x = triton.testing.random(shape, dtype=dtype, device=device)
    # triton result
    z_tri = torch.empty_like(x)
    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
                         z_tri, z_tri.stride(1), z_tri.stride(0),
                         BLOCK_M=shape[0], BLOCK_N=shape[1])
    # torch result
    z_ref = x.permute(*perm).contiguous()
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # parse ptx to make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx

# ---------------
# test dot
# ---------------

@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
    torch.manual_seed(0)
    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xk,
               Y, stride_yk, stride_yn,
               Z, stride_zm, stride_zn, **meta):
        BLOCK_M = meta['BLOCK_M']
        BLOCK_K = meta['BLOCK_K']
        BLOCK_N = meta['BLOCK_N']
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        off_k = tl.arange(0, BLOCK_K)
        Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
        Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
        z = tl.dot(tl.load(Xs), tl.load(Ys))
        if meta['ADD_MATRIX']:
            z += tl.load(Zs)
        if meta['ADD_ROWS']:
            ZRs = Z + off_m * stride_zm
            z += tl.load(ZRs)[:, None]
        if meta['ADD_COLS']:
            ZCs = Z + off_n * stride_zn
            z += tl.load(ZCs)[None, :]
        tl.store(Zs, z)
    # input
    M, N, K = 64, 64, 32
    x = triton.testing.random((M, K), dtype=torch.float16, device=device)
    y = triton.testing.random((K, N), dtype=torch.float16, device=device)
    # triton result
    z = triton.testing.random((M, N), dtype=torch.float16, device=device)
    z_tri = z.clone()
    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
                         y, y.stride(0), y.stride(1),
                         z_tri, z_tri.stride(0), z_tri.stride(1),
                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX = epilogue=='add-matrix',
                         ADD_ROWS = epilogue=='add-rows',
                         ADD_COLS = epilogue=='add-cols')
    # torch result
    z_ref = torch.matmul(x.float(), y.float())
    if epilogue == 'add-matrix':
        z_ref += z
    if epilogue == 'add-rows':
        z_ref += z[:,0][:, None]
    if epilogue == 'add-cols':
        z_ref += z[0,:][None, :]
    z_ref = z_ref.to(torch.float16)
    # compare
    ptx = pgm.asm['ptx']
    # print(ptx)
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # make sure ld/st are vectorized
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx


# ---------------
# test load
# ---------------

# ---------------
# test store
# ---------------

# ---------------
# test if
# ---------------

# ---------------
# test for
# ---------------

# ---------------
# test while
# ---------------

# ---------------
# test noop
#----------------
def test_noop(device='cuda'):
    @triton.jit
    def kernel(**meta):
        pass
    x = triton.testing.random((1,), dtype=torch.int32, device=device)
    kernel[(1, )](x)
python/test/unit/language/test_random.py (new file, 198 lines)
@@ -0,0 +1,198 @@
import torch
import triton
import triton.language as tl
import pytest
import scipy.stats
import numpy as np

from numpy.random import Philox

#####################################
## Reference Philox Implementation
#####################################

class PhiloxConfig:
    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
        self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
        self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
        self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
        self.DTYPE = DTYPE


# This is better for GPU
PHILOX_32 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B9,
    PHILOX_KEY_B=0xBB67AE85,
    PHILOX_ROUND_A=0xD2511F53,
    PHILOX_ROUND_B=0xCD9E8D57,
    DTYPE=np.uint32,
)

# This is what numpy implements
PHILOX_64 = PhiloxConfig(
    PHILOX_KEY_A=0x9E3779B97F4A7C15,
    PHILOX_KEY_B=0xBB67AE8584CAA73B,
    PHILOX_ROUND_A=0xD2E7470EE14C6C93,
    PHILOX_ROUND_B=0xCA5A826395121157,
    DTYPE=np.uint64,
)


class CustomPhilox4x:
    def __init__(self, seed, config):
        self._config = config
        seed = self._into_pieces(seed)
        self._key = np.array(seed[:2], dtype=self._dtype)
        self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)

    @property
    def _dtype(self):
        return self._config.DTYPE

    def _into_pieces(self, n, pad=4):
        res = []
        while len(res) < pad:
            res.append(np.array(n, dtype=self._dtype))
            n >>= (np.dtype(self._dtype).itemsize * 8)
        assert n == 0
        return tuple(res)

    def _multiply_low_high(self, a, b):
        low = a * b
        high = int(a) * int(b)
        high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
        return low, high

    def _single_round(self, counter, key):
        lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
        lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
        ret0 = hi1 ^ counter[1] ^ key[0]
        ret1 = lo1
        ret2 = hi0 ^ counter[3] ^ key[1]
        ret3 = lo0
        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)

    def _raise_key(self, key):
        ret0 = key[0] + self._config.PHILOX_KEY_A
        ret1 = key[1] + self._config.PHILOX_KEY_B
        return np.array([ret0, ret1], dtype=self._dtype)

    def random_raw(self):
        counter = self._counter
        key = self._key
        for _ in range(10):
            counter = self._single_round(counter, key)
            key = self._raise_key(key)
        self.advance(1)
        return counter

    def advance(self, n_steps):
        self._counter[0] += n_steps
        assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"


class CustomPhilox(CustomPhilox4x):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.buffer = []

    def random_raw(self):
        if len(self.buffer) == 0:
            self.buffer = list(super().random_raw())[::-1]
        return int(self.buffer.pop())


#####################################
## Unit Tests
#####################################

BLOCK = 1024

# test generation of random uint32
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in ['10', '4,53', '10000']\
                                       for seed in [0, 42, 124, 54]]
                        )
def test_randint(size, seed, device='cuda'):
    size = list(map(int, size.split(',')))
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randint(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.int32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
    # reference result
    gen = CustomPhilox4x(seed, config=PHILOX_32)
    out_ref = [gen.random_raw()[0] for _ in out_tri]
    assert out_tri == out_ref

# test conversion of random uint32 into random float in [0, 1]
def test_uint32_to_uniform_float():
    @triton.jit
    def kernel(SRC, TGT, N, **meta):
        pid = tl.program_id(0)
        offset = pid * BLOCK + tl.arange(0, BLOCK)
        src = tl.load(SRC + offset)
        tgt = tl.random.uint32_to_uniform_float(src)
        tl.store(TGT + offset, tgt, mask=offset < N)

    def run(source):
        target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
        N = source.numel()
        grid = lambda meta: (triton.cdiv(N, BLOCK),)
        kernel[grid](source, target, N)
        return target

    # check range of edge values
    n = 100
    source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
    target = run(source).tolist()
    assert target == sorted(target)
    assert all(0.0 <= num < 1.0 for num in target)
    # check distribution is uniform
    source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
    target = run(source).tolist()
    assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01

# test uniform PRNG
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]\
                                       for seed in [0, 42, 124, 54]]
                        )
def test_rand(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.rand(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01

# test normal PRNG
@pytest.mark.parametrize('size, seed',
                         [(size, seed) for size in [1000000]\
                                       for seed in [0, 42, 124, 54]]
                        )
def test_randn(size, seed, device='cuda'):
    @triton.jit
    def kernel(X, N, seed):
        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        rand = tl.randn(seed, offset)
        tl.store(X + offset, rand, mask=offset < N)
    # triton result
    x = torch.empty(size, dtype=torch.float32, device=device)
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert abs(x.mean()) < 1e-2
    assert abs(x.std() - 1) < 1e-2
python/test/unit/operators/test_blocksparse.py (new file, 160 lines)
@@ -0,0 +1,160 @@
import torch
import triton
import pytest


@pytest.mark.parametrize(
    "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
    [
        (mode, at, bt, block, dtype) for dtype in ["float16"] for mode in ["sdd", "dsd", "dds"]
        for at in [False, True] for bt in [False, True] for block in [16, 32, 64]
    ],
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
    shape = {
        "sdd": (M, N),
        "dsd": (a.shape[2], a.shape[3]),
        "dds": (b.shape[2], b.shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    # triton result
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
    rc = triton.testing.catch_oor(lambda : op(ra, rb), pytest)
    # torch result
    ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
    tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
    # compare
    triton.testing.assert_almost_equal(rc, tc)


@pytest.mark.parametrize(
    "BLOCK, WIDTH",
    [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
)
def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 2, 4, WIDTH, WIDTH
    scale = 0.4
    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
    kp_mask[kp_mask == 1.0] = float("-inf")
    # triton result
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
    tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
    ty = op(
        tx,
        scale=scale,
        key_padding_mask=kp_mask,
        key_padding_mask_mode="add",
        attn_mask=at_mask.to(DTYPE),
        attn_mask_mode="mul",
    )
    # torch result
    rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
    if at_mask is not None:
        # broadcast at_mask to the same shape as rx
        M = at_mask[None, None, :, :] + torch.zeros_like(rx)
        rx[M == 0] = float("-inf")
    if kp_mask is not None:
        rx += kp_mask[:, None, None, :]
    ry = torch.softmax(rx * scale, -1)
    ry = torch.softmax(rx * scale, -1)
    ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
    # compare
    triton.testing.assert_almost_equal(ry, ty)


def test_attention_fwd_bwd(
    input_scale=1.0,
    tol=2e-2,
    scale=1 / 8.0,
    n_ctx=256,
    dtype=torch.float16,
    batch_size=2,
    n_heads=2,
    block=64,
):
    # inputs
    qkv_shape = (batch_size, n_heads, n_ctx, 64)
    qkvs = [torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)]
    attn_mask = torch.tril(
        torch.ones(
            [n_ctx, n_ctx],
            device="cuda",
            dtype=dtype,
        ),
        diagonal=0,
    )

    # Triton:
    n_blocks = n_ctx // block
    layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
    query, key, value = [x.clone() for x in qkvs]
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()
    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
    # ad hoc loss
    loss = (attn_out**2).mean()
    loss.backward()
    grads = [query.grad, key.grad, value.grad]

    # Torch version:
    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
    torch_q.retain_grad()
    torch_k.retain_grad()
    torch_v.retain_grad()
    scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
    scores = scores + attn_mask
    probs = torch.softmax(scores, dim=-1)
    torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
    # ad hoc loss
    torch_loss = (torch_attn_out**2).mean()
    torch_loss.backward()
    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]

    # comparison
    # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
    triton.testing.assert_almost_equal(loss, torch_loss)
    for g1, g2 in zip(grads, torch_grads):
        triton.testing.assert_almost_equal(g1, g2)


def triton_attention(
    layout,
    block: int,
    attn_mask: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
):
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
    sparse_softmax = triton.ops.blocksparse.softmax(
        layout,
        block,
    )

    w = sparse_dot_sdd_nt(query, key)
    w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul")
    a = sparse_dot_dsd_nn(w, value)
    return a
python/test/unit/operators/test_cross_entropy.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import torch
import triton
import pytest

@pytest.mark.parametrize("M, N, dtype, mode",
    [
        (M, N, dtype, mode) for M in [1024, 821]
        for N in [512, 857, 1871, 2089, 8573, 31000]
        for dtype in ['float16', 'float32']\
        for mode in ['forward', 'backward']
    ]
)
def test_op(M, N, dtype, mode):
    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    # forward pass
    tt_y = triton.ops.cross_entropy(x, idx)
    th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
    if mode == 'forward':
        triton.testing.assert_almost_equal(th_y, tt_y)
    # backward pass
    elif mode == 'backward':
        dy = torch.randn_like(tt_y)
        # triton backward
        tt_y.backward(dy)
        tt_dx = x.grad.clone()
        # torch backward
        x.grad.zero_()
        th_y.backward(dy)
        th_dx = x.grad.clone()
        triton.testing.assert_almost_equal(th_dx, tt_dx)
python/test/unit/operators/test_matmul.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import pytest
import itertools
import triton
import torch


@pytest.mark.parametrize(
    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
    itertools.chain(
        *[
            [
                # 1 warp
                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                # 2 warp
                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                # 4 warp
                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                # 8 warp
                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
                # variable input
                (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
        ],
        # n-stage
        *[
            [
                (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
                # split-k
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
        ]
    ),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
    torch.manual_seed(0)
    # nuke kernel decorators -- will set meta-parameters manually
    META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
    configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
    kernel = triton.ops._matmul.kernel
    decorators = kernel.kernel_decorators
    kernel.kernel_decorators = []
    triton.autotune(configs, [])(kernel)
    kernel.kernel_decorators += decorators[1:]
    # get matrix shape
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K * SPLIT_K if K is None else K
    # allocate/transpose inputs
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
    a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    # run test
    th_c = torch.matmul(a, b)
    tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
    triton.testing.assert_almost_equal(th_c, tt_c)
python/test/unit/runtime/test_comm.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np


def get_p2p_matrix():
    try:
        stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
    except subprocess.CalledProcessError:
        return pytest.skip("No multi-GPU topology", allow_module_level=True)

    lines = stdout.split("Legend")[0].split('\n')[1:]
    matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
    if matrix.size <= 1:
        return pytest.skip("No multi-GPU topology", allow_module_level=True)
    else:
        return matrix


def get_p2p_devices():
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "OK")
    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"


def get_non_p2p_devices():
    matrix = get_p2p_matrix()
    idx = np.where(matrix == "NS")
    return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"


p2p_devices = get_p2p_devices()
non_p2p_devices = get_non_p2p_devices()


@triton.jit
def _copy(from_ptr, to_ptr, N, **meta):
    pid = tl.program_id(0)
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    values = tl.load(from_ptr + offsets, mask=offsets < N)
    tl.store(to_ptr + offsets, values, mask=offsets < N)


@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
                         [(device_kernel, device_from, device_to, stream_from, stream_to)
                          for device_kernel in p2p_devices
                          for device_from in p2p_devices
                          for device_to in p2p_devices
                          for stream_from in ['default', 'custom']
                          for stream_to in ['default', 'custom']
                         ])
def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
    if device_to == device_from:
        return pytest.skip()

    torch.cuda.set_device(device_kernel)
    N = 512
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)

    with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
        x_from = torch.randn(N, dtype=torch.float32, device=device_from)
    with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
        x_to = torch.empty(N, dtype=torch.float32, device=device_to)

    _copy[grid](x_from, x_to, N, BLOCK=1024)
    assert torch.allclose(x_from, x_to.to(device_from))


@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
                         [(device_kernel, device_from, device_to, stream_from, stream_to)
                          for device_kernel in non_p2p_devices
                          for device_from in non_p2p_devices
                          for device_to in non_p2p_devices
                          for stream_from in ['default', 'custom']
                          for stream_to in ['default', 'custom']
                         ])
def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
    if device_to == device_from:
        return pytest.skip()

    with pytest.raises(RuntimeError):
        torch.cuda.set_device(device_kernel)
        N = 512
        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)

        with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
            x_from = torch.randn(N, dtype=torch.float32, device=device_from)
        with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
            x_to = torch.empty(N, dtype=torch.float32, device=device_to)

        _copy[grid](x_from, x_to, N, BLOCK=1024)