[TEST] use numpy for reference results in test_core.py (#409)

Since numpy supports unsigned integers, and pytorch doesn't, this will make it easier to test unsigned integer support.

This adds an explicit requirement on numpy for the tests; since we already required scipy, numpy was already an implicit dependency.
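As a quick aside on the motivation (an illustration, not part of the diff, and the exact torch behaviour depends on the release): numpy ships the full set of unsigned integer dtypes, while torch at this point only provides uint8, so numpy is the natural place to compute unsigned reference results.

    import numpy as np
    import torch

    x = np.arange(4, dtype=np.uint32)   # unsigned 32-bit works out of the box in numpy
    print((x * 3).dtype)                # uint32
    print(hasattr(torch, 'uint32'))     # False on the torch releases this commit targets; only torch.uint8 exists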
Madeleine Thompson
2022-01-04 13:07:29 -08:00
committed by GitHub
parent 03f1256f60
commit d8db0308cb
3 changed files with 241 additions and 204 deletions

Test requirements:

@@ -1,2 +1,3 @@
+numpy
 pytest
 scipy >= 1.7.1

test_core.py:

@@ -1,31 +1,59 @@
+import copy
+import itertools
+import re
+from typing import Optional
+
+import numpy as np
+import pytest
 import torch
+from numpy.random import RandomState
+
 import triton
 import triton.language as tl
-import copy
-import pytest
-import ast
-import itertools
-
-torch.manual_seed(0)
-
-# convert from string to torch.dtype
-# Necessary because doesn't print torch.dtype properly
-cvt = {
-    'bool': torch.bool,
-    'int8': torch.int8,
-    'int16': torch.int16,
-    'int32': torch.int32,
-    'int64': torch.int64,
-    'bfloat16': torch.bfloat16,
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'float64': torch.float64,
-}
 
 int_dtypes = ['int8', 'int16', 'int32', 'int64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + float_dtypes
+
+
+def _bitwidth(dtype: str) -> int:
+    # ex.: "int64" -> 64
+    return int(re.search(r'(\d+)$', dtype).group(1))
+
+
+def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
+    """
+    Override `rs` if you're calling this function twice and don't want the same
+    result for both calls.
+    """
+    if isinstance(shape, int):
+        shape = (shape, )
+    if rs is None:
+        rs = RandomState(seed=17)
+    dtype = getattr(np, dtype_str)
+    if dtype_str in int_dtypes:
+        iinfo = np.iinfo(getattr(np, dtype_str))
+        x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
+        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
+        return x
+    elif dtype_str in float_dtypes:
+        return rs.normal(0, 1, shape).astype(dtype)
+    else:
+        raise RuntimeError(f'Unknown dtype {dtype_str}')
+
+
+def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
+    # For now, this always converts to a torch tensor, but when we add unsigned
+    # integers, it will also support TensorWrapper, since torch doesn't have
+    # unsigned support.
+    return torch.tensor(x, device=device)
+
+
+def to_numpy(x):
+    if isinstance(x, torch.Tensor):
+        return x.cpu().numpy()
+    else:
+        raise ValueError(f"Not a triton-compatible tensor: {x}")
+
 
 def patch_kernel(template, to_replace):
     kernel = copy.deepcopy(template)
@@ -34,19 +62,18 @@ def patch_kernel(template, to_replace):
     return kernel
 
 
-@pytest.mark.parametrize("dtype_x", [
-    (dtype_x) for dtype_x in dtypes
-])
+@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128
 
     @triton.jit
     def kernel(X, SIZE: tl.constexpr):
         pass
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
+    x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device)
     kernel[(1, )](x, SIZE=SIZE, num_warps=4)
 
 
 # generic test functions
-def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
+def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
@@ -58,18 +85,36 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
-    if 'log' in expr: x = torch.abs(x) + 0.01
+    x = numpy_random(SIZE, dtype_str=dtype_x)
+    if 'log' in expr:
+        x = np.abs(x) + 0.01
     # reference result
-    z_ref = eval(expr if torch_expr is None else torch_expr)
+    z_ref = eval(expr if numpy_expr is None else numpy_expr)
     # triton result
-    z_tri = torch.empty_like(z_ref)
-    kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
     # compare
-    triton.testing.assert_almost_equal(z_ref, z_tri)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
-def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
+def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
+    """
+    Given two dtype strings, returns the numpy dtype Triton thinks binary
+    operations on the two types should return. Returns None if the return value
+    matches numpy. This is generally needed because Triton and pytorch return
+    narrower floating point types than numpy in mixed operations.
+    """
+    overrides = {
+        ('float16', 'int16'): np.float16,
+        ('float16', 'int32'): np.float16,
+        ('float16', 'int64'): np.float16,
+    }
+    key = (a, b) if a < b else (b, a)
+    return overrides.get(key)
+
+
+def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
@@ -82,27 +127,24 @@ def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y=
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
-    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
-    if mode_x == 'nan': x[:] = float('nan')
-    if mode_y == 'nan': y[:] = float('nan')
+    rs = RandomState(17)
+    x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
+    if mode_x == 'nan':
+        x[:] = float('nan')
+    if mode_y == 'nan':
+        y[:] = float('nan')
     # reference result
-    z_ref = eval(expr if torch_expr is None else torch_expr)
+    z_ref = eval(expr if numpy_expr is None else numpy_expr)
+    dtype_z = _binary_op_dtype_override(dtype_x, dtype_y)
+    if dtype_z is not None:
+        z_ref = z_ref.astype(dtype_z)
     # triton result
-    z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
-    kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
-    # compare
-    triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
-
-
-def _fake_fmod(x, y):
-    """
-    Triton % (for both integers and floats) has the same semantics as torch
-    fmod, but torch fmod doesn't work on integers until torch 1.8.
-    `_fake_fmod` gives the same semantics but works on all versions of torch.
-    """
-    z = torch.remainder(x, y)
-    return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
+    x_tri = to_triton(x, device=device)
+    y_tri = to_triton(y, device=device)
+    z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
+    kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)
 
 
 def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
@@ -130,36 +172,38 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
     if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
-        # LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
-        torch_expr = '_fake_fmod(x, y)'
+        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
+        numpy_expr = 'np.fmod(x, y)'
     elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
         # Triton promotes 16-bit floating-point / and % to 32-bit because there
         # are no native div or FRem operations on float16. Since we have to
         # convert anyway, we may as well take the accuracy bump.
-        torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
+        numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
     else:
-        torch_expr = None
+        numpy_expr = None
     if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
-        with pytest.raises(AssertionError, match='Arrays are not almost equal'):
-            _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
+        with pytest.raises(AssertionError, match='Not equal to tolerance'):
+            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
     else:
-        _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
+        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
 
 
 # ---------------
 # test bitwise ops
 # ---------------
-@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
-    (dtype_x, dtype_y, f' x {op} y') \
-    for op in ['&', '|', '^'] \
-    for dtype_x in dtypes \
+@pytest.mark.parametrize("dtype_x, dtype_y, op", [
+    (dtype_x, dtype_y, op)
+    for op in ['&', '|', '^']
+    for dtype_x in dtypes
     for dtype_y in dtypes
 ])
-def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
+def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
+    expr = f'x {op} y'
     if 'float' in dtype_x + dtype_y:
-        with pytest.raises(RuntimeError):
-            _test_binary(dtype_x, dtype_y, expr, device=device)
+        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+            _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
+        # The CompilationError must have been caused by a C++ exception with this text.
+        assert re.match('invalid operands of type', str(exc_info.value.__cause__))
     else:
         _test_binary(dtype_x, dtype_y, expr, device=device)
@@ -168,23 +212,24 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
 # test compare ops
 # ---------------
 ops = ['==', '!=', '>', '<', '>=', '<=']
-@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
+@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
 # real
 [
-    (dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
+    (dtype_x, dtype_y, op, 'real', 'real') \
     for op in ops \
     for dtype_x in dtypes \
     for dtype_y in dtypes
 ] + \
 # NaNs
-[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
+[('float32', 'float32', op, mode_x, mode_y) \
     for op in ops
     for mode_x, mode_y in [('nan' , 'real'),
                            ('real', 'nan'),
                            ('nan' , 'nan')]
 ])
-def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
+def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
+    expr = f'x {op} y'
     _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
@@ -192,9 +237,9 @@ def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, f' -x') for dtype_x in float_dtypes
+    (dtype_x, ' -x') for dtype_x in dtypes
 ] + [\
-    (dtype_x, f' ~x') for dtype_x in int_dtypes
+    (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
 def test_unary_op(dtype_x, expr, device='cuda'):
     _test_unary(dtype_x, expr, device=device)
@@ -210,7 +255,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
     'exp', 'log', 'cos', 'sin'
 ])
 def test_math_op(expr, device='cuda'):
-    _test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)
+    _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
 
 
 # ----------------
@@ -229,12 +274,11 @@ def make_ptr_str(name, shape):
     return f"{name} + {' + '.join(offsets)}"
 
 
-@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
-    ['None, :', ':, None',\
-    'None, :, :', ':, :, None']\
+@pytest.mark.parametrize("expr, dtype_str", [
+    (f'x[{s}]', 'int32')
+    for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
 ])
-def test_index1d(expr, device='cuda'):
-    dtype = torch.int32
+def test_index1d(expr, dtype_str, device='cuda'):
     rank_x = expr.count(':')
     rank_y = expr.count(',') + 1
     shape_x = [32 for _ in range(rank_x)]
@@ -257,14 +301,15 @@ def test_index1d(expr, device='cuda'):
     kernel = patch_kernel(kernel, to_replace)
     # torch result
-    x = triton.testing.random(shape_x, dtype=dtype, device=device)
-    y = torch.zeros(shape_z, dtype=dtype, device=device)
+    x = numpy_random(shape_x, dtype_str=dtype_str)
+    y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
     z_ref = eval(expr) + y
     # triton result
-    z_tri = torch.empty_like(z_ref)
-    kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
+    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    x_tri = to_triton(x)
+    kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
     # compare
-    triton.testing.assert_almost_equal(z_ref, z_tri)
+    assert (z_ref == to_numpy(z_tri)).all()
 
 
 # ---------------
@@ -316,14 +361,15 @@ def test_tuples():
 # ---------------
 # test atomics
 # ---------------
-@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
-    [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
-    ('max', 'int32', mode), ('max', 'float32', mode),\
-    ('min', 'int32', mode), ('min', 'float32', mode),\
+@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
+    [
+        ('add', 'float16', mode),
+        ('add', 'int32', mode), ('add', 'float32', mode),
+        ('max', 'int32', mode), ('max', 'float32', mode),
+        ('min', 'int32', mode), ('min', 'float32', mode),
     ]
     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
-def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
-    dtype_x = cvt[dtype_x]
+def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     n_programs = 5
     # triton kernel
@@ -334,52 +380,59 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
         old = GENERATE_TEST_HERE
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
-    torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
-    max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
-    min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
+    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
     # triton result
-    x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
+    rs = RandomState(17)
+    x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
     if mode == 'all_neg':
-        x_tri = -torch.abs(x_tri)
+        x = -np.abs(x)
     if mode == 'all_pos':
-        x_tri = torch.abs(x_tri)
+        x = np.abs(x)
     if mode == 'min_neg':
-        idx = torch.randint(n_programs, size=(1, )).item()
-        x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = -np.max(np.abs(x)) - 1
     if mode == 'max_pos':
-        idx = torch.randint(n_programs, size=(1, )).item()
-        x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = np.max(np.abs(x)) + 1
+    x_tri = to_triton(x, device=device)
-    z_tri = torch.empty([], dtype=dtype_x, device=device)
-    z_tri.fill_(neutral)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
     kernel[(n_programs, )](x_tri, z_tri)
     # torch result
-    z_ref = torch_op(x_tri).to(dtype_x)
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
     # compare
     exact = op not in ['add']
     if exact:
-        assert z_ref.item() == z_tri.item()
+        assert z_ref.item() == to_numpy(z_tri).item()
     else:
-        triton.testing.assert_almost_equal(z_ref, z_tri)
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
 
 
 # ---------------
 # test cast
 # ---------------
 @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
-    (dtype_x, dtype_z, False) \
-    for dtype_x in dtypes\
+    (dtype_x, dtype_z, False)
+    for dtype_x in dtypes
     for dtype_z in dtypes
 ] + [
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
-    ('float32', 'int32', True)
-])
+    ('float32', 'int32', True),
+]
+)
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
-    x0 = 43 if dtype_x.startswith('int') else 43.5
-    x = torch.tensor([x0], dtype=cvt[dtype_x], device=device)
+    # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
+    x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x.startswith('bfloat'):
+        x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
+    else:
+        x = np.array([x0], dtype=getattr(np, dtype_x))
+        x_tri = to_triton(x)
     # triton kernel
     @triton.jit
@@ -389,26 +442,31 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         tl.store(Z, z)
     # triton result
-    z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
-    kernel[(1, )](x, z_tri, BITCAST=bitcast)
-    # torch result
-    if bitcast:
-        import numpy as np
-        z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
-        z_ref = torch.from_numpy(z_ref).to(device)
+    if dtype_z.startswith('bfloat'):
+        z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_ref = x.to(z_tri.dtype)
-    assert z_tri == z_ref
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device)
+    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
+    # torch result
+    if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
+        assert bitcast is False
+        z_ref = x_tri.to(z_tri.dtype)
+        assert z_tri == z_ref
+    else:
+        if bitcast:
+            z_ref = x.view(getattr(np, dtype_z))
+        else:
+            z_ref = x.astype(getattr(np, dtype_z))
+        assert to_numpy(z_tri) == z_ref
 
 
 # ---------------
 # test reduce
 # ---------------
-@pytest.mark.parametrize("dtype, shape",
+@pytest.mark.parametrize("dtype_str, shape",
     [(dtype, shape) \
     for dtype in dtypes\
     for shape in [128, 512]])
-def test_reduce1d(dtype, shape, device='cuda'):
-    dtype = cvt[dtype]
+def test_reduce1d(dtype_str, shape, device='cuda'):
     # triton kernel
     @triton.jit
@@ -416,22 +474,22 @@ def test_reduce1d(dtype, shape, device='cuda'):
         x = tl.load(X + tl.arange(0, BLOCK))
         tl.store(Z, tl.sum(x, axis=0))
-    x = triton.testing.random((shape,), dtype=dtype, device=device)
+    rs = RandomState(17)
+    x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
+    # numpy result
+    z_ref = np.sum(x).astype(getattr(np, dtype_str))
     # triton result
-    z_tri = triton.testing.random((1,), dtype=dtype, device=device)
-    kernel[(1,)](x, z_tri, BLOCK=shape)
-    # torch result
-    z_ref = torch.sum(x).to(dtype)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
+    kernel[(1,)](x_tri, z_tri, BLOCK=shape)
     # compare
-    triton.testing.assert_almost_equal(z_tri, z_ref)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
-@pytest.mark.parametrize("dtype, shape, axis",
-    [(dtype, shape, 1) \
-    for dtype in ['float32']\
-    for shape in [(1, 1024)]])
-def test_reduce2d(dtype, shape, axis, device='cuda'):
-    dtype = cvt[dtype]
+@pytest.mark.parametrize("dtype_str, shape, axis", [
+    ('float32', (1, 1024), 1)
+])
+def test_reduce2d(dtype_str, shape, axis, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -441,29 +499,30 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
         z = tl.sum(x, axis=AXIS)
         tl.store(Z + range_m, z)
     # input
-    x = triton.testing.random(shape, dtype=dtype, device=device)
+    x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
-    z_tri = torch.empty((shape[0],), dtype=dtype, device=device)
-    kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
-    # torch result
-    z_ref = torch.sum(x, axis=axis).to(dtype)
+    x_tri = to_triton(x)
+    z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    # numpy reference result
+    z_ref = np.sum(x, axis=axis).astype(x.dtype)
     # compare
-    triton.testing.assert_almost_equal(z_tri, z_ref)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
 # ---------------
 # test permute
 # ---------------
-@pytest.mark.parametrize("dtype, shape, perm",
+@pytest.mark.parametrize("dtype_str, shape, perm",
     [(dtype, shape, perm) \
     for dtype in ['float32']\
     for shape in [(128, 128)]\
     for perm in [(1, 0)]])
-def test_permute(dtype, shape, perm, device='cuda'):
-    dtype = cvt[dtype]
+def test_permute(dtype_str, shape, perm, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xn,
                Z, stride_zm, stride_zn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
@@ -471,14 +530,15 @@ def test_permute(dtype, shape, perm, device='cuda'):
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
         tl.store(Zs, tl.load(Xs))
     # input
-    x = triton.testing.random(shape, dtype=dtype, device=device)
+    x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
-    z_tri = torch.empty_like(x)
-    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
-                         z_tri, z_tri.stride(1), z_tri.stride(0),
-                         BLOCK_M=shape[0], BLOCK_N=shape[1])
+    z_tri = to_triton(np.empty_like(x), device=device)
+    x_tri = to_triton(x, device=device)
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         z_tri, z_tri.stride(1), z_tri.stride(0),
+                         BLOCK_M=shape[0], BLOCK_N=shape[1])
     # torch result
-    z_ref = x.permute(*perm).contiguous()
+    z_ref = x.transpose(*perm)
     # compare
     triton.testing.assert_almost_equal(z_tri, z_ref)
     # parse ptx to make sure ld/st are vectorized
@@ -491,13 +551,12 @@ def test_permute(dtype, shape, perm, device='cuda'):
 # ---------------
 @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
-def test_dot(epilogue, dtype=torch.float32, device='cuda'):
-    torch.manual_seed(0)
+def test_dot(epilogue, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
                Y, stride_yk, stride_yn,
                Z, stride_zm, stride_zn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
@@ -513,36 +572,38 @@ def test_dot(epilogue, dtype=torch.float32, device='cuda'):
             ZRs = Z + off_m * stride_zm
             z += tl.load(ZRs)[:, None]
         if ADD_COLS:
             ZCs = Z + off_n * stride_zn
             z += tl.load(ZCs)[None, :]
         tl.store(Zs, z)
     # input
     M, N, K = 64, 64, 32
-    x = triton.testing.random((M, K), dtype=dtype, device=device)
-    y = triton.testing.random((K, N), dtype=dtype, device=device)
+    rs = RandomState(17)
+    x = numpy_random((M, K), dtype_str='float32', rs=rs)
+    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    x_tri = to_triton(x, device=device)
+    y_tri = to_triton(y, device=device)
     # triton result
-    z = triton.testing.random((M, N), dtype=dtype, device=device)
-    z_tri = z.clone()
+    z = numpy_random((M, N), dtype_str='float32', rs=rs)
+    z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
-    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
-                         y, y.stride(0), y.stride(1),
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         y_tri, y_tri.stride(0), y_tri.stride(1),
                          z_tri, z_tri.stride(0), z_tri.stride(1),
                          BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                          ADD_MATRIX = epilogue=='add-matrix',
                          ADD_ROWS = epilogue=='add-rows',
                          ADD_COLS = epilogue=='add-cols')
     # torch result
-    z_ref = torch.matmul(x.float(), y.float())
+    z_ref = np.matmul(x, y)
     if epilogue == 'add-matrix':
         z_ref += z
     if epilogue == 'add-rows':
         z_ref += z[:,0][:, None]
     if epilogue == 'add-cols':
         z_ref += z[0,:][None, :]
-    z_ref = z_ref.to(torch.float16)
     # compare
-    triton.testing.assert_almost_equal(z_tri, z_ref)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
@@ -558,7 +619,7 @@ def test_dot_without_load():
         c = tl.dot(a, b)
         pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
         tl.store(pout, c)
     out = torch.ones((32,32), dtype=torch.float32, device="cuda")
     kernel[(1,)](out)
@@ -571,7 +632,7 @@ def test_arange(start, device='cuda'):
     BLOCK = 128
     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
     @triton.jit
     def _kernel(z, BLOCK: tl.constexpr,
                 START: tl.constexpr, END: tl.constexpr):
         off = tl.arange(0, BLOCK)
         val = tl.arange(START, END)
@@ -605,8 +666,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
         N_offsets = tl.arange(0, N)
         K_offsets = tl.arange(0, K)
         in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
         in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
         # Load inputs.
         x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
@@ -616,7 +677,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
         o = tl.dot(x, w)
         # Store output
         output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
         tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
     pgm = _kernel[(1,)](in1, in2, out,
@@ -687,7 +748,7 @@ def test_default():
     def _kernel(ret0, ret1, value):
         tl.store(ret0, _impl())
         tl.store(ret1, _impl(value))
     _kernel[(1,)](ret0, ret1, value)
     assert ret0.item() == 10
     assert ret1.item() == value
@@ -699,5 +760,5 @@ def test_noop(device='cuda'):
     @triton.jit
     def kernel(x):
         pass
-    x = triton.testing.random((1,), dtype=torch.int32, device=device)
+    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
     kernel[(1, )](x)
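One detail worth calling out from the new _test_binary path (illustrative only, using numpy's own promotion rules): numpy widens mixed float16/integer arithmetic, while Triton and pytorch keep float16, which is what _binary_op_dtype_override compensates for before comparing.

    import numpy as np
    print(np.promote_types(np.float16, np.int16))   # float32
    print(np.promote_types(np.float16, np.int32))   # float64
    # Triton/pytorch keep float16 for these mixes, so the test casts z_ref
    # back down via _binary_op_dtype_override before the allclose check.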

triton/testing.py:

@@ -85,31 +85,6 @@ def allclose(x, y, tol=1e-2):
     return err <= tol
 
 
-def assert_allclose(x, y, tol=1e-2):
-    assert x.dtype == y.dtype
-    assert allclose(x, y, tol)
-
-
-def random(shape, dtype, device, seed=0):
-    """
-    Override the seed in tests if you're calling this function twice and don't
-    want the same result for both calls.
-    """
-    torch.manual_seed(seed)
-    if isinstance(shape, int):
-        shape = (shape, )
-    if dtype == torch.bool:
-        return torch.randint(0, 2, shape, dtype=dtype, device=device)
-    if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
-        iinfo = torch.iinfo(dtype)
-        x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device)
-        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
-        return x
-    if dtype in [torch.float16, torch.float32, torch.float64]:
-        return torch.normal(0, 1, shape, dtype=dtype, device=device)
-    raise RuntimeError(f'Unknown dtype {dtype}')
-
-
 def nvsmi(attrs):
     attrs = ','.join(attrs)
     cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -203,7 +178,7 @@ class Benchmark:
         styles=None,
     ):
         """
         Constructor
         :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
         :type x_names: List[str]
@@ -344,4 +319,4 @@ def get_max_tensorcore_tflops(backend, device):
     else:
         ops_per_sub_core = 512
     tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
     return tflops
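For readers who want the shape of the new testing pattern without wading through the hunks above, the following sketch shows roughly how the helpers added to test_core.py (numpy_random, to_triton, to_numpy) fit together. It is a hypothetical standalone example, not a file in this commit, and it assumes those helper definitions are in scope on a CUDA machine.

    import numpy as np
    import triton
    import triton.language as tl

    SIZE = 128
    x = numpy_random(SIZE, dtype_str='float32')   # reference input produced by numpy
    z_ref = x + x                                 # reference result computed by numpy

    x_tri = to_triton(x)                          # torch tensor on the GPU for the kernel
    z_tri = to_triton(np.empty_like(z_ref))

    @triton.jit
    def kernel(Z, X, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        tl.store(Z + off, x + x)

    kernel[(1, )](z_tri, x_tri, SIZE=SIZE)
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)   # numpy does the comparison

The pattern mirrors the converted tests above: numpy owns the reference arithmetic and the final comparison, while torch only carries the data to and from the GPU.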