[LANG] Fixed semantics of NaN in float comparisons (#281)

This commit is contained in:
Philippe Tillet
2021-09-13 15:06:29 -07:00
committed by GitHub
parent cecca90bea
commit 3e395bc84e
8 changed files with 46 additions and 17 deletions

View File

@@ -0,0 +1,547 @@
import torch
import triton
import triton.language as tl
import copy
import pytest
import ast
import itertools
torch.manual_seed(0)
# convert from string to torch.dtype
# Necessary because doesn't print torch.dtype properly
cvt = {
'bool': torch.bool,
'int8': torch.int8,
'int16': torch.int16,
'int32': torch.int32,
'int64': torch.int64,
'bfloat16': torch.bfloat16,
'float16': torch.float16,
'float32': torch.float32,
'float64': torch.float64,
}
int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
def patch_kernel(template, to_replace):
kernel = copy.deepcopy(template)
for key, value in to_replace.items():
kernel.src = kernel.src.replace(key, value)
return kernel
@pytest.mark.parametrize("dtype_x", [
(dtype_x) for dtype_x in dtypes
])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, **meta):
pass
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
# generic test functions
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
z = GENERATE_TEST_HERE
tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
if 'log' in expr: x = torch.abs(x) + 0.01
# reference result
z_ref = eval(expr if torch_expr is None else torch_expr)
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_almost_equal(z_ref, z_tri)
def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, **meta):
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
y = tl.load(Y + off)
z = GENERATE_TEST_HERE
tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
if mode_x == 'nan': x[:] = float('nan')
if mode_y == 'nan': y[:] = float('nan')
# reference result
z_ref = eval(expr)
# triton result
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['+', '-', '*', '/', '%'] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['&', '|', '^'] \
for dtype_x in dtypes \
for dtype_y in dtypes
])
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
if 'float' in dtype_x + dtype_y:
with pytest.raises(RuntimeError):
_test_binary(dtype_x, dtype_y, expr, device=device)
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
# ---------------
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
# real
[
(dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
for op in ops \
for dtype_x in dtypes \
for dtype_y in dtypes
] + \
# NaNs
[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
for op in ops
for mode_x, mode_y in [('nan' , 'real'),
('real', 'nan'),
('nan' , 'nan')]
])
def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, f' -x') for dtype_x in float_dtypes
] + [\
(dtype_x, f' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
# ----------------
# test math ops
# ----------------
# @pytest.mark.paramterize("expr", [
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
def test_math_op(expr, device='cuda'):
_test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)
# ----------------
# test indexing
# ----------------
def make_ptr_str(name, shape):
rank = len(shape)
offsets = []
stride = 1
for i in reversed(range(rank)):
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
stride *= shape[i]
return f"{name} + {' + '.join(offsets)}"
@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
['None, :', ':, None',\
'None, :, :', ':, :, None']\
])
def test_index1d(expr, device='cuda'):
dtype = torch.int32
rank_x = expr.count(':')
rank_y = expr.count(',') + 1
shape_x = [32 for _ in range(rank_x)]
shape_z = [32 for _ in range(rank_y)]
# Triton kernel
@triton.jit
def kernel(Z, X, **meta):
SIZE = meta['SIZE']
m = tl.arange(0, SIZE)
n = tl.arange(0, SIZE)
x = tl.load(X_PTR_EXPR)
z = GENERATE_TEST_HERE
tl.store(Z_PTR_EXPR, z)
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
'GENERATE_TEST_HERE': expr,
}
kernel = patch_kernel(kernel, to_replace)
# torch result
x = triton.testing.random(shape_x, dtype=dtype, device=device)
y = torch.zeros(shape_z, dtype=dtype, device=device)
z_ref = eval(expr) + y
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
# compare
triton.testing.assert_almost_equal(z_ref, z_tri)
# ---------------
# test tuples
# ---------------
@triton.jit
def fn(a, b):
return a + b, \
a - b, \
a * b
def test_tuples():
device = 'cuda'
@triton.jit
def with_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = fn(x, y)
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
@triton.jit
def without_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = x + y, x - y, x * y
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
x = torch.tensor([1.3], device=device, dtype=torch.float32)
y = torch.tensor([1.9], device=device, dtype=torch.float32)
a_tri = torch.tensor([0], device=device, dtype=torch.float32)
b_tri = torch.tensor([0], device=device, dtype=torch.float32)
c_tri = torch.tensor([0], device=device, dtype=torch.float32)
for kernel in [with_fn, without_fn]:
kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
a_ref, b_ref, c_ref = x + y, x - y, x * y
assert a_tri == a_ref
assert b_tri == b_ref
assert c_tri == c_ref
# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
[('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
('max', 'int32', mode), ('max', 'float32', mode),\
('min', 'int32', mode), ('min', 'float32', mode),\
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
dtype_x = cvt[dtype_x]
n_programs = 5
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
pid = tl.program_id(0)
x = tl.load(X + pid)
old = GENERATE_TEST_HERE
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
# triton result
x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
if mode == 'all_neg':
x_tri = -torch.abs(x_tri)
if mode == 'all_pos':
x_tri = torch.abs(x_tri)
if mode == 'min_neg':
idx = torch.randint(n_programs, size=(1, )).item()
x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
if mode == 'max_pos':
idx = torch.randint(n_programs, size=(1, )).item()
x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
z_tri = torch.empty([], dtype=dtype_x, device=device)
z_tri.fill_(neutral)
kernel[(n_programs, )](x_tri, z_tri)
# torch result
z_ref = torch_op(x_tri).to(dtype_x)
# compare
exact = op not in ['add']
if exact:
assert z_ref.item() == z_tri.item()
else:
triton.testing.assert_almost_equal(z_ref, z_tri)
# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False) \
for dtype_x in dtypes\
for dtype_z in dtypes
] + [
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
tl.store(Z, z)
# triton result
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, BITCAST=bitcast)
# torch result
if bitcast:
import numpy as np
z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
z_ref = torch.from_numpy(z_ref).to(device)
else:
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype, shape",
[(dtype, shape) \
for dtype in dtypes\
for shape in [128, 512]])
def test_reduce1d(dtype, shape, device='cuda'):
dtype = cvt[dtype]
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X + tl.arange(0, meta['BLOCK']))
tl.store(Z, tl.sum(x, axis=0))
x = triton.testing.random((shape,), dtype=dtype, device=device)
# triton result
z_tri = triton.testing.random((1,), dtype=dtype, device=device)
kernel[(1,)](x, z_tri, BLOCK=shape)
# torch result
z_ref = torch.sum(x).to(dtype)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
@pytest.mark.parametrize("dtype, shape, axis",
[(dtype, shape, 1) \
for dtype in ['float32']\
for shape in [(1, 1024)]])
def test_reduce2d(dtype, shape, axis, device='cuda'):
dtype = cvt[dtype]
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
range_m = tl.arange(0, meta['BLOCK_M'])
range_n = tl.arange(0, meta['BLOCK_N'])
x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
z = tl.sum(x, axis=meta['AXIS'])
tl.store(Z + range_m, z)
# input
x = triton.testing.random(shape, dtype=dtype, device=device)
# triton result
z_tri = torch.empty((shape[0],), dtype=dtype, device=device)
kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
# torch result
z_ref = torch.sum(x, axis=axis).to(dtype)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
# test permute
# ---------------
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype, shape, perm",
[(dtype, shape, perm) \
for dtype in ['float32']\
for shape in [(128, 128)]\
for perm in [(1, 0)]])
def test_permute(dtype, shape, perm, device='cuda'):
dtype = cvt[dtype]
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn, **meta):
BLOCK_M = meta['BLOCK_M']
BLOCK_N = meta['BLOCK_N']
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
# input
x = triton.testing.random(shape, dtype=dtype, device=device)
# triton result
z_tri = torch.empty_like(x)
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
z_tri, z_tri.stride(1), z_tri.stride(0),
BLOCK_M=shape[0], BLOCK_N=shape[1])
# torch result
z_ref = x.permute(*perm).contiguous()
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# ---------------
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
torch.manual_seed(0)
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
Z, stride_zm, stride_zn, **meta):
BLOCK_M = meta['BLOCK_M']
BLOCK_K = meta['BLOCK_K']
BLOCK_N = meta['BLOCK_N']
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
z = tl.dot(tl.load(Xs), tl.load(Ys))
if meta['ADD_MATRIX']:
z += tl.load(Zs)
if meta['ADD_ROWS']:
ZRs = Z + off_m * stride_zm
z += tl.load(ZRs)[:, None]
if meta['ADD_COLS']:
ZCs = Z + off_n * stride_zn
z += tl.load(ZCs)[None, :]
tl.store(Zs, z)
# input
M, N, K = 64, 64, 32
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
# triton result
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
z_tri = z.clone()
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
y, y.stride(0), y.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX = epilogue=='add-matrix',
ADD_ROWS = epilogue=='add-rows',
ADD_COLS = epilogue=='add-cols')
# torch result
z_ref = torch.matmul(x.float(), y.float())
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:,0][:, None]
if epilogue == 'add-cols':
z_ref += z[0,:][None, :]
z_ref = z_ref.to(torch.float16)
# compare
ptx = pgm.asm['ptx']
# print(ptx)
triton.testing.assert_almost_equal(z_tri, z_ref)
# make sure ld/st are vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# ---------------
# test load
# ---------------
# ---------------
# test store
# ---------------
# ---------------
# test if
# ---------------
# ---------------
# test for
# ---------------
# ---------------
# test while
# ---------------
# ---------------
# test noop
#----------------
def test_noop(device='cuda'):
@triton.jit
def kernel(**meta):
pass
x = triton.testing.random((1,), dtype=torch.int32, device=device)
kernel[(1, )](x)

View File

@@ -0,0 +1,198 @@
import torch
import triton
import triton.language as tl
import pytest
import scipy.stats
import numpy as np
from numpy.random import Philox
#####################################
## Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
self.DTYPE = DTYPE
# This is better for GPU
PHILOX_32 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B9,
PHILOX_KEY_B=0xBB67AE85,
PHILOX_ROUND_A=0xD2511F53,
PHILOX_ROUND_B=0xCD9E8D57,
DTYPE=np.uint32,
)
# This is what numpy implements
PHILOX_64 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B97F4A7C15,
PHILOX_KEY_B=0xBB67AE8584CAA73B,
PHILOX_ROUND_A=0xD2E7470EE14C6C93,
PHILOX_ROUND_B=0xCA5A826395121157,
DTYPE=np.uint64,
)
class CustomPhilox4x:
def __init__(self, seed, config):
self._config = config
seed = self._into_pieces(seed)
self._key = np.array(seed[:2], dtype=self._dtype)
self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
@property
def _dtype(self):
return self._config.DTYPE
def _into_pieces(self, n, pad=4):
res = []
while len(res) < pad:
res.append(np.array(n, dtype=self._dtype))
n >>= (np.dtype(self._dtype).itemsize * 8)
assert n == 0
return tuple(res)
def _multiply_low_high(self, a, b):
low = a * b
high = int(a) * int(b)
high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
return low, high
def _single_round(self, counter, key):
lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
ret0 = hi1 ^ counter[1] ^ key[0]
ret1 = lo1
ret2 = hi0 ^ counter[3] ^ key[1]
ret3 = lo0
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
def _raise_key(self, key):
ret0 = key[0] + self._config.PHILOX_KEY_A
ret1 = key[1] + self._config.PHILOX_KEY_B
return np.array([ret0, ret1], dtype=self._dtype)
def random_raw(self):
counter = self._counter
key = self._key
for _ in range(10):
counter = self._single_round(counter, key)
key = self._raise_key(key)
self.advance(1)
return counter
def advance(self, n_steps):
self._counter[0] += n_steps
assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
class CustomPhilox(CustomPhilox4x):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.buffer = []
def random_raw(self):
if len(self.buffer) == 0:
self.buffer = list(super().random_raw())[::-1]
return int(self.buffer.pop())
#####################################
## Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\
for seed in [0, 42, 124, 54]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randint(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.int32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
# reference result
gen = CustomPhilox4x(seed, config=PHILOX_32)
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref
# test conversion of random uint32 into random float in [0, 1]
def test_uint32_to_uniform_float():
@triton.jit
def kernel(SRC, TGT, N, **meta):
pid = tl.program_id(0)
offset = pid * BLOCK + tl.arange(0, BLOCK)
src = tl.load(SRC + offset)
tgt = tl.random.uint32_to_uniform_float(src)
tl.store(TGT + offset, tgt, mask=offset < N)
def run(source):
target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
N = source.numel()
grid = lambda meta: (triton.cdiv(N, BLOCK),)
kernel[grid](source, target, N)
return target
# check range of edge values
n = 100
source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
target = run(source).tolist()
assert target == sorted(target)
assert all(0.0 <= num < 1.0 for num in target)
# check distribution is uniform
source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
target = run(source).tolist()
assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.rand(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randn(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert abs(x.mean()) < 1e-2
assert abs(x.std() - 1) < 1e-2

View File

@@ -0,0 +1,160 @@
import torch
import triton
import pytest
@pytest.mark.parametrize(
"MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
[
(mode, at, bt, block, dtype) for dtype in ["float16"] for mode in ["sdd", "dsd", "dds"]
for at in [False, True] for bt in [False, True] for block in [16, 32, 64]
],
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
shape = {
"sdd": (M, N),
"dsd": (a.shape[2], a.shape[3]),
"dds": (b.shape[2], b.shape[3]),
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
# triton result
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
rc = triton.testing.catch_oor(lambda : op(ra, rb), pytest)
# torch result
ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
ta = ta.transpose(2, 3) if TRANS_A else ta
tb = tb.transpose(2, 3) if TRANS_B else tb
tc = torch.matmul(ta, tb)
tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
# compare
triton.testing.assert_almost_equal(rc, tc)
@pytest.mark.parametrize(
"BLOCK, WIDTH",
[(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
)
def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 4, WIDTH, WIDTH
scale = 0.4
# create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
kp_mask[kp_mask == 1.0] = float("-inf")
# triton result
op = triton.ops.blocksparse.softmax(layout, BLOCK)
tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
ty = op(
tx,
scale=scale,
key_padding_mask=kp_mask,
key_padding_mask_mode="add",
attn_mask=at_mask.to(DTYPE),
attn_mask_mode="mul",
)
# torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
if at_mask is not None:
# broadcast at_mask to the same shape as rx
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf")
if kp_mask is not None:
rx += kp_mask[:, None, None, :]
ry = torch.softmax(rx * scale, -1)
ry = torch.softmax(rx * scale, -1)
ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
# compare
triton.testing.assert_almost_equal(ry, ty)
def test_attention_fwd_bwd(
input_scale=1.0,
tol=2e-2,
scale=1 / 8.0,
n_ctx=256,
dtype=torch.float16,
batch_size=2,
n_heads=2,
block=64,
):
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)]
attn_mask = torch.tril(
torch.ones(
[n_ctx, n_ctx],
device="cuda",
dtype=dtype,
),
diagonal=0,
)
# Triton:
n_blocks = n_ctx // block
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
query, key, value = [x.clone() for x in qkvs]
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out**2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
# Torch version:
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
torch_q.retain_grad()
torch_k.retain_grad()
torch_v.retain_grad()
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
scores = scores + attn_mask
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out**2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
triton.testing.assert_almost_equal(loss, torch_loss)
for g1, g2 in zip(grads, torch_grads):
triton.testing.assert_almost_equal(g1, g2)
def triton_attention(
layout,
block: int,
attn_mask: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
sparse_softmax = triton.ops.blocksparse.softmax(
layout,
block,
)
w = sparse_dot_sdd_nt(query, key)
w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul")
a = sparse_dot_dsd_nn(w, value)
return a

View File

@@ -0,0 +1,33 @@
import torch
import triton
import pytest
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']\
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
# create inputs
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
# forward pass
tt_y = triton.ops.cross_entropy(x, idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
if mode == 'forward':
triton.testing.assert_almost_equal(th_y, tt_y)
# backward pass
elif mode == 'backward':
dy = torch.randn_like(tt_y)
# triton backward
tt_y.backward(dy)
tt_dx = x.grad.clone()
# torch backward
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)

View File

@@ -0,0 +1,89 @@
import pytest
import itertools
import triton
import torch
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
# split-k
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
]
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []
triton.autotune(configs, [])(kernel)
kernel.kernel_decorators += decorators[1:]
# get matrix shape
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)

View File

@@ -0,0 +1,96 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np
def get_p2p_matrix():
try:
stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
except subprocess.CalledProcessError:
return pytest.skip("No multi-GPU topology", allow_module_level=True)
lines = stdout.split("Legend")[0].split('\n')[1:]
matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
if matrix.size <= 1:
return pytest.skip("No multi-GPU topology", allow_module_level=True)
else:
return matrix
def get_p2p_devices():
matrix = get_p2p_matrix()
idx = np.where(matrix == "OK")
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
def get_non_p2p_devices():
matrix = get_p2p_matrix()
idx = np.where(matrix == "NS")
return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"
p2p_devices = get_p2p_devices()
non_p2p_devices = get_non_p2p_devices()
@triton.jit
def _copy(from_ptr, to_ptr, N, **meta):
pid = tl.program_id(0)
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
values = tl.load(from_ptr + offsets, mask=offsets < N)
tl.store(to_ptr + offsets, values, mask=offsets < N)
@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
[(device_kernel, device_from, device_to, stream_from, stream_to)
for device_kernel in p2p_devices
for device_from in p2p_devices
for device_to in p2p_devices
for stream_from in ['default', 'custom']
for stream_to in ['default', 'custom']
])
def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
if device_to == device_from:
return pytest.skip()
torch.cuda.set_device(device_kernel)
N = 512
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
_copy[grid](x_from, x_to, N, BLOCK=1024)
assert torch.allclose(x_from, x_to.to(device_from))
@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
[(device_kernel, device_from, device_to, stream_from, stream_to)
for device_kernel in non_p2p_devices
for device_from in non_p2p_devices
for device_to in non_p2p_devices
for stream_from in ['default', 'custom']
for stream_to in ['default', 'custom']
])
def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
if device_to == device_from:
return pytest.skip()
with pytest.raises(RuntimeError):
torch.cuda.set_device(device_kernel)
N = 512
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
_copy[grid](x_from, x_to, N, BLOCK=1024)