[TEST] use numpy for reference results in test_core.py (#409)

Since numpy supports unsigned integers and PyTorch doesn't, this will make it easier to test unsigned integer support.

This adds an explicit requirement for numpy in the tests, but we already required scipy, so numpy was already an implicit dependency.
Author: Madeleine Thompson
Date: 2022-01-04 13:07:29 -08:00
Committed by: GitHub
Parent: 03f1256f60
Commit: d8db0308cb
3 changed files with 241 additions and 204 deletions
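As a rough illustration of the motivation above (a sketch for this write-up, not part of the commit): numpy ships the full unsigned integer family, which is what makes it a workable source of reference results for the planned unsigned tests, while torch at the time only provided uint8.

    import numpy as np

    # numpy knows all four unsigned widths, so reference inputs and results
    # for unsigned arithmetic (including wrap-around) are easy to build.
    for name in ('uint8', 'uint16', 'uint32', 'uint64'):
        print(name, np.iinfo(getattr(np, name)).max)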


@@ -1,2 +1,3 @@
numpy
pytest
scipy >= 1.7.1


@@ -1,31 +1,59 @@
import copy
import itertools
import re
from typing import Optional
import numpy as np
import pytest
import torch
from numpy.random import RandomState
import triton
import triton.language as tl
import copy
import pytest
import ast
import itertools
torch.manual_seed(0)
# convert from string to torch.dtype
# Necessary because doesn't print torch.dtype properly
cvt = {
'bool': torch.bool,
'int8': torch.int8,
'int16': torch.int16,
'int32': torch.int32,
'int64': torch.int64,
'bfloat16': torch.bfloat16,
'float16': torch.float16,
'float32': torch.float32,
'float64': torch.float64,
}
int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
def _bitwidth(dtype: str) -> int:
# ex.: "int64" -> 64
return int(re.search(r'(\d+)$', dtype).group(1))
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
"""
Override `rs` if you're calling this function twice and don't want the same
result for both calls.
"""
if isinstance(shape, int):
shape = (shape, )
if rs is None:
rs = RandomState(seed=17)
dtype = getattr(np, dtype_str)
if dtype_str in int_dtypes:
iinfo = np.iinfo(getattr(np, dtype_str))
x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
elif dtype_str in float_dtypes:
return rs.normal(0, 1, shape).astype(dtype)
else:
raise RuntimeError(f'Unknown dtype {dtype_str}')
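A minimal usage sketch of numpy_random (illustrative only, mirroring how the tests below call it): pass one shared RandomState so consecutive draws differ, rather than both falling back to the default seed of 17.

    rs = RandomState(17)
    a = numpy_random((128,), dtype_str='float32', rs=rs)
    b = numpy_random((128,), dtype_str='float32', rs=rs)  # differs from `a`
    c = numpy_random(128, dtype_str='int32')  # no rs: deterministic default seed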
def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
# For now, this always converts to a torch tensor, but when we add unsigned
# integers, it will also support TensorWrapper, since torch doesn't have
# unsigned support.
return torch.tensor(x, device=device)
def to_numpy(x):
if isinstance(x, torch.Tensor):
return x.cpu().numpy()
else:
raise ValueError(f"Not a triton-compatible tensor: {x}")
def patch_kernel(template, to_replace):
kernel = copy.deepcopy(template)
@@ -34,19 +62,18 @@ def patch_kernel(template, to_replace):
return kernel
@pytest.mark.parametrize("dtype_x", [
(dtype_x) for dtype_x in dtypes
])
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, SIZE: tl.constexpr):
pass
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device)
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
# generic test functions
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -58,18 +85,36 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
if 'log' in expr: x = torch.abs(x) + 0.01
x = numpy_random(SIZE, dtype_str=dtype_x)
if 'log' in expr:
x = np.abs(x) + 0.01
# reference result
z_ref = eval(expr if torch_expr is None else torch_expr)
z_ref = eval(expr if numpy_expr is None else numpy_expr)
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.empty_like(z_ref), device=device)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_almost_equal(z_ref, z_tri)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
"""
Given two dtype strings, returns the numpy dtype Triton thinks binary
operations on the two types should return. Returns None if the return value
matches numpy. This is generally needed because Triton and pytorch return
narrower floating point types than numpy in mixed operations.
"""
overrides = {
('float16', 'int16'): np.float16,
('float16', 'int32'): np.float16,
('float16', 'int64'): np.float16,
}
key = (a, b) if a < b else (b, a)
return overrides.get(key)
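For context, a small sketch (illustrative only, not in the diff) of the promotion mismatch this table compensates for: numpy widens mixed float16/integer arithmetic, whereas Triton, like torch, keeps the result in float16.

    import numpy as np

    print(np.promote_types('float16', 'int16'))  # float32 under numpy's rules
    print(np.promote_types('float16', 'int32'))  # float64 under numpy's rules
    # Triton returns float16 for these pairs, so _test_binary casts the numpy
    # reference down via _binary_op_dtype_override before comparing.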
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -82,27 +127,24 @@ def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y=
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
if mode_x == 'nan': x[:] = float('nan')
if mode_y == 'nan': y[:] = float('nan')
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
if mode_x == 'nan':
x[:] = float('nan')
if mode_y == 'nan':
y[:] = float('nan')
# reference result
z_ref = eval(expr if torch_expr is None else torch_expr)
z_ref = eval(expr if numpy_expr is None else numpy_expr)
dtype_z = _binary_op_dtype_override(dtype_x, dtype_y)
if dtype_z is not None:
z_ref = z_ref.astype(dtype_z)
# triton result
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
def _fake_fmod(x, y):
"""
Triton % (for both integers and floats) has the same semantics as torch
fmod, but torch fmod doesn't work on integers until torch 1.8.
`_fake_fmod` gives the same semantics but works on all versions of torch.
"""
z = torch.remainder(x, y)
return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)
def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
@@ -130,36 +172,38 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
# LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
torch_expr = '_fake_fmod(x, y)'
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
# Triton promotes 16-bit floating-point / and % to 32-bit because there
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
else:
torch_expr = None
numpy_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
with pytest.raises(AssertionError, match='Arrays are not almost equal'):
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
with pytest.raises(AssertionError, match='Not equal to tolerance'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
else:
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
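A short aside (illustrative, not part of the diff) on the fmod/remainder distinction referenced in the comment above: np.fmod takes the sign of the dividend, which matches LLVM's integer remainder, while np.remainder (Python's %) takes the sign of the divisor.

    import numpy as np

    print(np.fmod(-7, 3))       # -1: sign follows the dividend (LLVM srem)
    print(np.remainder(-7, 3))  #  2: sign follows the divisor (Python %)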
# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['&', '|', '^'] \
for dtype_x in dtypes \
for dtype_y in dtypes
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['&', '|', '^']
for dtype_x in dtypes
for dtype_y in dtypes
])
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
if 'float' in dtype_x + dtype_y:
with pytest.raises(RuntimeError):
_test_binary(dtype_x, dtype_y, expr, device=device)
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# The CompilationError must have been caused by a C++ exception with this text.
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
@@ -168,23 +212,24 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
# real
[
(dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
(dtype_x, dtype_y, op, 'real', 'real') \
for op in ops \
for dtype_x in dtypes \
for dtype_y in dtypes
] + \
# NaNs
[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
[('float32', 'float32', op, mode_x, mode_y) \
for op in ops
for mode_x, mode_y in [('nan' , 'real'),
('real', 'nan'),
for mode_x, mode_y in [('nan' , 'real'),
('real', 'nan'),
('nan' , 'nan')]
])
def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
@@ -192,9 +237,9 @@ def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, f' -x') for dtype_x in float_dtypes
(dtype_x, ' -x') for dtype_x in dtypes
] + [\
(dtype_x, f' ~x') for dtype_x in int_dtypes
(dtype_x, ' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
@@ -210,7 +255,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
'exp', 'log', 'cos', 'sin'
])
def test_math_op(expr, device='cuda'):
_test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
# ----------------
@@ -229,12 +274,11 @@ def make_ptr_str(name, shape):
return f"{name} + {' + '.join(offsets)}"
@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
['None, :', ':, None',\
'None, :, :', ':, :, None']\
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', 'int32')
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
])
def test_index1d(expr, device='cuda'):
dtype = torch.int32
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
rank_y = expr.count(',') + 1
shape_x = [32 for _ in range(rank_x)]
@@ -257,14 +301,15 @@ def test_index1d(expr, device='cuda'):
kernel = patch_kernel(kernel, to_replace)
# torch result
x = triton.testing.random(shape_x, dtype=dtype, device=device)
y = torch.zeros(shape_z, dtype=dtype, device=device)
x = numpy_random(shape_x, dtype_str=dtype_str)
y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
z_ref = eval(expr) + y
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
z_tri = to_triton(np.empty_like(z_ref), device=device)
x_tri = to_triton(x)
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
# compare
triton.testing.assert_almost_equal(z_ref, z_tri)
assert (z_ref == to_numpy(z_tri)).all()
# ---------------
@@ -316,14 +361,15 @@ def test_tuples():
# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
[('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
('max', 'int32', mode), ('max', 'float32', mode),\
('min', 'int32', mode), ('min', 'float32', mode),\
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'int32', mode), ('add', 'float32', mode),
('max', 'int32', mode), ('max', 'float32', mode),
('min', 'int32', mode), ('min', 'float32', mode),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
dtype_x = cvt[dtype_x]
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
n_programs = 5
# triton kernel
@@ -334,52 +380,59 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
old = GENERATE_TEST_HERE
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
# triton result
x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
rs = RandomState(17)
x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
if mode == 'all_neg':
x_tri = -torch.abs(x_tri)
x = -np.abs(x)
if mode == 'all_pos':
x_tri = torch.abs(x_tri)
x = np.abs(x)
if mode == 'min_neg':
idx = torch.randint(n_programs, size=(1, )).item()
x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
idx = rs.randint(n_programs, size=(1, )).item()
x[idx] = -np.max(np.abs(x)) - 1
if mode == 'max_pos':
idx = torch.randint(n_programs, size=(1, )).item()
x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
idx = rs.randint(n_programs, size=(1, )).item()
x[idx] = np.max(np.abs(x)) + 1
x_tri = to_triton(x, device=device)
z_tri = torch.empty([], dtype=dtype_x, device=device)
z_tri.fill_(neutral)
z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
kernel[(n_programs, )](x_tri, z_tri)
# torch result
z_ref = torch_op(x_tri).to(dtype_x)
z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
# compare
exact = op not in ['add']
if exact:
assert z_ref.item() == z_tri.item()
assert z_ref.item() == to_numpy(z_tri).item()
else:
triton.testing.assert_almost_equal(z_ref, z_tri)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
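As an aside (illustrative sketch, not in the diff), the "neutral" value chosen above is simply the identity element of the reduction, so initializing the output with it cannot perturb the atomically accumulated result:

    import numpy as np

    x = np.array([3, -1, 7], dtype=np.int32)
    acc = np.iinfo(np.int32).min  # neutral element for max
    for v in x:
        acc = max(acc, v)
    assert acc == np.max(x)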
# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False) \
for dtype_x in dtypes\
(dtype_x, dtype_z, False)
for dtype_x in dtypes
for dtype_z in dtypes
] + [
] + [
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True)
])
('float32', 'int32', True),
]
)
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x0 = 43 if dtype_x.startswith('int') else 43.5
x = torch.tensor([x0], dtype=cvt[dtype_x], device=device)
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
if dtype_x.startswith('bfloat'):
x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
else:
x = np.array([x0], dtype=getattr(np, dtype_x))
x_tri = to_triton(x)
# triton kernel
@triton.jit
@@ -389,26 +442,31 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
tl.store(Z, z)
# triton result
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, BITCAST=bitcast)
# torch result
if bitcast:
import numpy as np
z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
z_ref = torch.from_numpy(z_ref).to(device)
if dtype_z.startswith('bfloat'):
z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
else:
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device)
kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
# torch result
if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
assert bitcast is False
z_ref = x_tri.to(z_tri.dtype)
assert z_tri == z_ref
else:
if bitcast:
z_ref = x.view(getattr(np, dtype_z))
else:
z_ref = x.astype(getattr(np, dtype_z))
assert to_numpy(z_tri) == z_ref
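A tiny sketch (illustrative only) of the bitcast reference used above: ndarray.view reinterprets the same bytes under a new dtype, which is what BITCAST=True asks the kernel to do, whereas astype converts the value.

    import numpy as np

    x = np.array([43.5], dtype=np.float32)
    print(x.view(np.int32))    # raw IEEE-754 bit pattern of 43.5f as an int32
    print(x.astype(np.int32))  # value conversion instead: [43]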
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype, shape",
@pytest.mark.parametrize("dtype_str, shape",
[(dtype, shape) \
for dtype in dtypes\
for shape in [128, 512]])
def test_reduce1d(dtype, shape, device='cuda'):
dtype = cvt[dtype]
def test_reduce1d(dtype_str, shape, device='cuda'):
# triton kernel
@triton.jit
@@ -416,22 +474,22 @@ def test_reduce1d(dtype, shape, device='cuda'):
x = tl.load(X + tl.arange(0, BLOCK))
tl.store(Z, tl.sum(x, axis=0))
x = triton.testing.random((shape,), dtype=dtype, device=device)
rs = RandomState(17)
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
# numpy result
z_ref = np.sum(x).astype(getattr(np, dtype_str))
# triton result
z_tri = triton.testing.random((1,), dtype=dtype, device=device)
kernel[(1,)](x, z_tri, BLOCK=shape)
# torch result
z_ref = torch.sum(x).to(dtype)
x_tri = to_triton(x, device=device)
z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
kernel[(1,)](x_tri, z_tri, BLOCK=shape)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@pytest.mark.parametrize("dtype, shape, axis",
[(dtype, shape, 1) \
for dtype in ['float32']\
for shape in [(1, 1024)]])
def test_reduce2d(dtype, shape, axis, device='cuda'):
dtype = cvt[dtype]
@pytest.mark.parametrize("dtype_str, shape, axis", [
('float32', (1, 1024), 1)
])
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -441,29 +499,30 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
z = tl.sum(x, axis=AXIS)
tl.store(Z + range_m, z)
# input
x = triton.testing.random(shape, dtype=dtype, device=device)
x = numpy_random(shape, dtype_str=dtype_str)
# triton result
z_tri = torch.empty((shape[0],), dtype=dtype, device=device)
kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
# torch result
z_ref = torch.sum(x, axis=axis).to(dtype)
x_tri = to_triton(x)
z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
# numpy reference result
z_ref = np.sum(x, axis=axis).astype(x.dtype)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype, shape, perm",
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm) \
for dtype in ['float32']\
for shape in [(128, 128)]\
for perm in [(1, 0)]])
def test_permute(dtype, shape, perm, device='cuda'):
dtype = cvt[dtype]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn,
def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
@@ -471,14 +530,15 @@ def test_permute(dtype, shape, perm, device='cuda'):
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
# input
x = triton.testing.random(shape, dtype=dtype, device=device)
x = numpy_random(shape, dtype_str=dtype_str)
# triton result
z_tri = torch.empty_like(x)
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
z_tri, z_tri.stride(1), z_tri.stride(0),
BLOCK_M=shape[0], BLOCK_N=shape[1])
z_tri = to_triton(np.empty_like(x), device=device)
x_tri = to_triton(x, device=device)
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
z_tri, z_tri.stride(1), z_tri.stride(0),
BLOCK_M=shape[0], BLOCK_N=shape[1])
# torch result
z_ref = x.permute(*perm).contiguous()
z_ref = x.transpose(*perm)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
# parse ptx to make sure ld/st are vectorized
@@ -491,13 +551,12 @@ def test_permute(dtype, shape, perm, device='cuda'):
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, dtype=torch.float32, device='cuda'):
torch.manual_seed(0)
def test_dot(epilogue, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
Z, stride_zm, stride_zn,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
@@ -513,36 +572,38 @@ def test_dot(epilogue, dtype=torch.float32, device='cuda'):
ZRs = Z + off_m * stride_zm
z += tl.load(ZRs)[:, None]
if ADD_COLS:
ZCs = Z + off_n * stride_zn
ZCs = Z + off_n * stride_zn
z += tl.load(ZCs)[None, :]
tl.store(Zs, z)
# input
M, N, K = 64, 64, 32
x = triton.testing.random((M, K), dtype=dtype, device=device)
y = triton.testing.random((K, N), dtype=dtype, device=device)
rs = RandomState(17)
x = numpy_random((M, K), dtype_str='float32', rs=rs)
y = numpy_random((K, N), dtype_str='float32', rs=rs)
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
# triton result
z = triton.testing.random((M, N), dtype=dtype, device=device)
z_tri = z.clone()
z = numpy_random((M, N), dtype_str='float32', rs=rs)
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
y, y.stride(0), y.stride(1),
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX = epilogue=='add-matrix',
ADD_ROWS = epilogue=='add-rows',
ADD_COLS = epilogue=='add-cols')
# torch result
z_ref = torch.matmul(x.float(), y.float())
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:,0][:, None]
if epilogue == 'add-cols':
z_ref += z[0,:][None, :]
z_ref = z_ref.to(torch.float16)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
@@ -558,7 +619,7 @@ def test_dot_without_load():
c = tl.dot(a, b)
pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
tl.store(pout, c)
out = torch.ones((32,32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)
@@ -571,7 +632,7 @@ def test_arange(start, device='cuda'):
BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit
def _kernel(z, BLOCK: tl.constexpr,
def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr):
off = tl.arange(0, BLOCK)
val = tl.arange(START, END)
@@ -605,8 +666,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
N_offsets = tl.arange(0, N)
K_offsets = tl.arange(0, K)
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
@@ -616,7 +677,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
o = tl.dot(x, w)
# Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
pgm = _kernel[(1,)](in1, in2, out,
@@ -687,7 +748,7 @@ def test_default():
def _kernel(ret0, ret1, value):
tl.store(ret0, _impl())
tl.store(ret1, _impl(value))
_kernel[(1,)](ret0, ret1, value)
assert ret0.item() == 10
assert ret1.item() == value
@@ -699,5 +760,5 @@ def test_noop(device='cuda'):
@triton.jit
def kernel(x):
pass
x = triton.testing.random((1,), dtype=torch.int32, device=device)
x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
kernel[(1, )](x)


@@ -85,31 +85,6 @@ def allclose(x, y, tol=1e-2):
return err <= tol
def assert_allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
assert allclose(x, y, tol)
def random(shape, dtype, device, seed=0):
"""
Override the seed in tests if you're calling this function twice and don't
want the same result for both calls.
"""
torch.manual_seed(seed)
if isinstance(shape, int):
shape = (shape, )
if dtype == torch.bool:
return torch.randint(0, 2, shape, dtype=dtype, device=device)
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
iinfo = torch.iinfo(dtype)
x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
if dtype in [torch.float16, torch.float32, torch.float64]:
return torch.normal(0, 1, shape, dtype=dtype, device=device)
raise RuntimeError(f'Unknown dtype {dtype}')
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -203,7 +178,7 @@ class Benchmark:
styles=None,
):
"""
Constructor
Constructor
:param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
:type x_names: List[str]
@@ -344,4 +319,4 @@ def get_max_tensorcore_tflops(backend, device):
else:
ops_per_sub_core = 512
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
return tflops
return tflops