diff --git a/python/requirements-test.txt b/python/requirements-test.txt index 48b6d3be3..84893a889 100644 --- a/python/requirements-test.txt +++ b/python/requirements-test.txt @@ -1,2 +1,3 @@ +numpy pytest scipy >= 1.7.1 diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index aa0e7430a..fe33c9c6a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,31 +1,59 @@ +import copy +import itertools +import re +from typing import Optional + +import numpy as np +import pytest import torch +from numpy.random import RandomState + import triton import triton.language as tl -import copy -import pytest -import ast -import itertools - -torch.manual_seed(0) - -# convert from string to torch.dtype -# Necessary because doesn't print torch.dtype properly -cvt = { - 'bool': torch.bool, - 'int8': torch.int8, - 'int16': torch.int16, - 'int32': torch.int32, - 'int64': torch.int64, - 'bfloat16': torch.bfloat16, - 'float16': torch.float16, - 'float32': torch.float32, - 'float64': torch.float64, -} int_dtypes = ['int8', 'int16', 'int32', 'int64'] float_dtypes = ['float16', 'float32', 'float64'] dtypes = int_dtypes + float_dtypes +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + dtype = getattr(np, dtype_str) + if dtype_str in int_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype) + x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype) + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor: + # For now, this always converts to a torch tensor, but when we add unsigned + # integers, it will also support TensorWrapper, since torch doesn't have + # unsigned support. 
+ return torch.tensor(x, device=device) + + +def to_numpy(x): + if isinstance(x, torch.Tensor): + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + def patch_kernel(template, to_replace): kernel = copy.deepcopy(template) @@ -34,19 +62,18 @@ def patch_kernel(template, to_replace): return kernel -@pytest.mark.parametrize("dtype_x", [ - (dtype_x) for dtype_x in dtypes -]) +@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 @triton.jit def kernel(X, SIZE: tl.constexpr): pass - x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device) kernel[(1, )](x, SIZE=SIZE, num_warps=4) + # generic test functions -def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -58,18 +85,36 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs - x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) - if 'log' in expr: x = torch.abs(x) + 0.01 + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 # reference result - z_ref = eval(expr if torch_expr is None else torch_expr) + z_ref = eval(expr if numpy_expr is None else numpy_expr) # triton result - z_tri = torch.empty_like(z_ref) - kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + x_tri = to_triton(x, device=device) + z_tri = to_triton(np.empty_like(z_ref), device=device) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) # compare - triton.testing.assert_almost_equal(z_ref, z_tri) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) -def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'): +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations. 
+ """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -82,27 +127,24 @@ def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y= kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs - x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17) - y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144) - if mode_x == 'nan': x[:] = float('nan') - if mode_y == 'nan': y[:] = float('nan') + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') # reference result - z_ref = eval(expr if torch_expr is None else torch_expr) + z_ref = eval(expr if numpy_expr is None else numpy_expr) + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if dtype_z is not None: + z_ref = z_ref.astype(dtype_z) # triton result - z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device) - kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4) - # compare - triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr) - - -def _fake_fmod(x, y): - """ - Triton % (for both integers and floats) has the same semantics as torch - fmod, but torch fmod doesn't work on integers until torch 1.8. - `_fake_fmod` gives the same semantics but works on all versions of torch. - """ - z = torch.remainder(x, y) - return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: @@ -130,36 +172,38 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes: - # LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders. - torch_expr = '_fake_fmod(x, y)' + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = 'np.fmod(x, y)' elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): # Triton promotes 16-bit floating-point / and % to 32-bit because there # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. 
- torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)' + numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' else: - torch_expr = None + numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): - with pytest.raises(AssertionError, match='Arrays are not almost equal'): - _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device) + with pytest.raises(AssertionError, match='Not equal to tolerance'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) else: - _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device) - + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) # --------------- # test bitwise ops # --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, expr", [ - (dtype_x, dtype_y, f' x {op} y') \ - for op in ['&', '|', '^'] \ - for dtype_x in dtypes \ - for dtype_y in dtypes +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + for dtype_y in dtypes ]) -def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'): +def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): + expr = f'x {op} y' if 'float' in dtype_x + dtype_y: - with pytest.raises(RuntimeError): - _test_binary(dtype_x, dtype_y, expr, device=device) + with pytest.raises(triton.code_gen.CompilationError) as exc_info: + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) + # The CompilationError must have been caused by a C++ exception with this text. + assert re.match('invalid operands of type', str(exc_info.value.__cause__)) else: _test_binary(dtype_x, dtype_y, expr, device=device) @@ -168,23 +212,24 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'): # test compare ops # --------------- ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \ +@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \ # real [ - (dtype_x, dtype_y, f' x {op} y', 'real', 'real') \ + (dtype_x, dtype_y, op, 'real', 'real') \ for op in ops \ for dtype_x in dtypes \ for dtype_y in dtypes ] + \ # NaNs -[('float32', 'float32', f' x {op} y', mode_x, mode_y) \ +[('float32', 'float32', op, mode_x, mode_y) \ for op in ops - for mode_x, mode_y in [('nan' , 'real'), - ('real', 'nan'), + for mode_x, mode_y in [('nan' , 'real'), + ('real', 'nan'), ('nan' , 'nan')] ]) -def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'): +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): + expr = f'x {op} y' _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device) @@ -192,9 +237,9 @@ def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'): # test unary ops # --------------- @pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, f' -x') for dtype_x in float_dtypes + (dtype_x, ' -x') for dtype_x in dtypes ] + [\ - (dtype_x, f' ~x') for dtype_x in int_dtypes + (dtype_x, ' ~x') for dtype_x in int_dtypes ]) def test_unary_op(dtype_x, expr, device='cuda'): _test_unary(dtype_x, expr, device=device) @@ -210,7 +255,7 @@ def test_unary_op(dtype_x, expr, device='cuda'): 'exp', 'log', 'cos', 'sin' ]) def test_math_op(expr, device='cuda'): - _test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device) + _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) # ---------------- @@ -229,12 +274,11 @@ def make_ptr_str(name, shape): return f"{name} + {' + '.join(offsets)}" 
-@pytest.mark.parametrize("expr", [f'x[{s}]' for s in - ['None, :', ':, None',\ - 'None, :, :', ':, :, None']\ +@pytest.mark.parametrize("expr, dtype_str", [ + (f'x[{s}]', 'int32') + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] ]) -def test_index1d(expr, device='cuda'): - dtype = torch.int32 +def test_index1d(expr, dtype_str, device='cuda'): rank_x = expr.count(':') rank_y = expr.count(',') + 1 shape_x = [32 for _ in range(rank_x)] @@ -257,14 +301,15 @@ def test_index1d(expr, device='cuda'): kernel = patch_kernel(kernel, to_replace) # torch result - x = triton.testing.random(shape_x, dtype=dtype, device=device) - y = torch.zeros(shape_z, dtype=dtype, device=device) + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) z_ref = eval(expr) + y # triton result - z_tri = torch.empty_like(z_ref) - kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0]) + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x) + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) # compare - triton.testing.assert_almost_equal(z_ref, z_tri) + assert (z_ref == to_numpy(z_tri)).all() # --------------- @@ -316,14 +361,15 @@ def test_tuples(): # --------------- # test atomics # --------------- -@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([ - [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \ - ('max', 'int32', mode), ('max', 'float32', mode),\ - ('min', 'int32', mode), ('min', 'float32', mode),\ +@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([ + [ + ('add', 'float16', mode), + ('add', 'int32', mode), ('add', 'float32', mode), + ('max', 'int32', mode), ('max', 'float32', mode), + ('min', 'int32', mode), ('min', 'float32', mode), ] for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) -def test_atomic_rmw(op, dtype_x, mode, device='cuda'): - dtype_x = cvt[dtype_x] +def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): n_programs = 5 # triton kernel @@ -334,52 +380,59 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): old = GENERATE_TEST_HERE kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'}) - torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op] - max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min - min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] # triton result - x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device) + rs = RandomState(17) + x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs) if mode == 'all_neg': - x_tri = -torch.abs(x_tri) + x = -np.abs(x) if mode == 'all_pos': - x_tri = torch.abs(x_tri) + x = np.abs(x) if mode == 'min_neg': - idx = torch.randint(n_programs, size=(1, )).item() - x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1 + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 if mode == 'max_pos': - idx = torch.randint(n_programs, size=(1, )).item() - x_tri[idx] = torch.max(torch.abs(x_tri)) + 1 + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = 
to_triton(x, device=device) - z_tri = torch.empty([], dtype=dtype_x, device=device) - z_tri.fill_(neutral) + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) kernel[(n_programs, )](x_tri, z_tri) # torch result - z_ref = torch_op(x_tri).to(dtype_x) + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) # compare exact = op not in ['add'] if exact: - assert z_ref.item() == z_tri.item() + assert z_ref.item() == to_numpy(z_tri).item() else: - triton.testing.assert_almost_equal(z_ref, z_tri) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001) # --------------- # test cast # --------------- @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ - (dtype_x, dtype_z, False) \ - for dtype_x in dtypes\ + (dtype_x, dtype_z, False) + for dtype_x in dtypes for dtype_z in dtypes -] + [ +] + [ ('float32', 'bfloat16', False), ('bfloat16', 'float32', False), - ('float32', 'int32', True) -]) + ('float32', 'int32', True), +] +) def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): - x0 = 43 if dtype_x.startswith('int') else 43.5 - x = torch.tensor([x0], dtype=cvt[dtype_x], device=device) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + x0 = 43 if dtype_x in int_dtypes else 43.5 + if dtype_x.startswith('bfloat'): + x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device) + else: + x = np.array([x0], dtype=getattr(np, dtype_x)) + x_tri = to_triton(x) # triton kernel @triton.jit @@ -389,26 +442,31 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): tl.store(Z, z) # triton result - z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device) - kernel[(1, )](x, z_tri, BITCAST=bitcast) - # torch result - if bitcast: - import numpy as np - z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z)) - z_ref = torch.from_numpy(z_ref).to(device) + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device) else: - z_ref = x.to(z_tri.dtype) - assert z_tri == z_ref + z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + assert z_tri == z_ref + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z)) + else: + z_ref = x.astype(getattr(np, dtype_z)) + assert to_numpy(z_tri) == z_ref # --------------- # test reduce # --------------- -@pytest.mark.parametrize("dtype, shape", +@pytest.mark.parametrize("dtype_str, shape", [(dtype, shape) \ for dtype in dtypes\ for shape in [128, 512]]) -def test_reduce1d(dtype, shape, device='cuda'): - dtype = cvt[dtype] +def test_reduce1d(dtype_str, shape, device='cuda'): # triton kernel @triton.jit @@ -416,22 +474,22 @@ def test_reduce1d(dtype, shape, device='cuda'): x = tl.load(X + tl.arange(0, BLOCK)) tl.store(Z, tl.sum(x, axis=0)) - x = triton.testing.random((shape,), dtype=dtype, device=device) + rs = RandomState(17) + x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) + # numpy result + z_ref = np.sum(x).astype(getattr(np, dtype_str)) # triton result - z_tri = triton.testing.random((1,), dtype=dtype, device=device) - kernel[(1,)](x, z_tri, BLOCK=shape) - # torch result - z_ref = torch.sum(x).to(dtype) + x_tri = to_triton(x, device=device) + z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device) + kernel[(1,)](x_tri, z_tri, BLOCK=shape) # compare - 
triton.testing.assert_almost_equal(z_tri, z_ref) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) -@pytest.mark.parametrize("dtype, shape, axis", - [(dtype, shape, 1) \ - for dtype in ['float32']\ - for shape in [(1, 1024)]]) -def test_reduce2d(dtype, shape, axis, device='cuda'): - dtype = cvt[dtype] +@pytest.mark.parametrize("dtype_str, shape, axis", [ + ('float32', (1, 1024), 1) +]) +def test_reduce2d(dtype_str, shape, axis, device='cuda'): # triton kernel @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): @@ -441,29 +499,30 @@ def test_reduce2d(dtype, shape, axis, device='cuda'): z = tl.sum(x, axis=AXIS) tl.store(Z + range_m, z) # input - x = triton.testing.random(shape, dtype=dtype, device=device) + x = numpy_random(shape, dtype_str=dtype_str) # triton result - z_tri = torch.empty((shape[0],), dtype=dtype, device=device) - kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) - # torch result - z_ref = torch.sum(x, axis=axis).to(dtype) + x_tri = to_triton(x) + z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device) + kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) + # numpy reference result + z_ref = np.sum(x, axis=axis).astype(x.dtype) # compare - triton.testing.assert_almost_equal(z_tri, z_ref) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # --------------- # test permute # --------------- -@pytest.mark.parametrize("dtype, shape, perm", +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) \ for dtype in ['float32']\ for shape in [(128, 128)]\ - for perm in [(1, 0)]]) -def test_permute(dtype, shape, perm, device='cuda'): - dtype = cvt[dtype] + for perm in [(1, 0)]]) +def test_permute(dtype_str, shape, perm, device='cuda'): + # triton kernel @triton.jit - def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, + def kernel(X, stride_xm, stride_xn, + Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) @@ -471,14 +530,15 @@ def test_permute(dtype, shape, perm, device='cuda'): Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn tl.store(Zs, tl.load(Xs)) # input - x = triton.testing.random(shape, dtype=dtype, device=device) + x = numpy_random(shape, dtype_str=dtype_str) # triton result - z_tri = torch.empty_like(x) - pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), - z_tri, z_tri.stride(1), z_tri.stride(0), - BLOCK_M=shape[0], BLOCK_N=shape[1]) + z_tri = to_triton(np.empty_like(x), device=device) + x_tri = to_triton(x, device=device) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), + z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1]) # torch result - z_ref = x.permute(*perm).contiguous() + z_ref = x.transpose(*perm) # compare triton.testing.assert_almost_equal(z_tri, z_ref) # parse ptx to make sure ld/st are vectorized @@ -491,13 +551,12 @@ def test_permute(dtype, shape, perm, device='cuda'): # --------------- @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) -def test_dot(epilogue, dtype=torch.float32, device='cuda'): - torch.manual_seed(0) +def test_dot(epilogue, device='cuda'): # triton kernel @triton.jit - def kernel(X, stride_xm, stride_xk, + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, - Z, stride_zm, stride_zn, + Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr): off_m = tl.arange(0, BLOCK_M) @@ -513,36 +572,38 @@ def test_dot(epilogue, dtype=torch.float32, device='cuda'): ZRs = Z + off_m * stride_zm z += tl.load(ZRs)[:, None] if ADD_COLS: - ZCs = Z + off_n * stride_zn + ZCs = Z + off_n * stride_zn z += tl.load(ZCs)[None, :] tl.store(Zs, z) # input M, N, K = 64, 64, 32 - x = triton.testing.random((M, K), dtype=dtype, device=device) - y = triton.testing.random((K, N), dtype=dtype, device=device) + rs = RandomState(17) + x = numpy_random((M, K), dtype_str='float32', rs=rs) + y = numpy_random((K, N), dtype_str='float32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) # triton result - z = triton.testing.random((M, N), dtype=dtype, device=device) - z_tri = z.clone() + z = numpy_random((M, N), dtype_str='float32', rs=rs) + z_tri = to_triton(z, device=device) if epilogue == 'trans': z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) - pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), - y, y.stride(0), y.stride(1), + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), + y_tri, y_tri.stride(0), y_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX = epilogue=='add-matrix', ADD_ROWS = epilogue=='add-rows', ADD_COLS = epilogue=='add-cols') # torch result - z_ref = torch.matmul(x.float(), y.float()) + z_ref = np.matmul(x, y) if epilogue == 'add-matrix': z_ref += z if epilogue == 'add-rows': z_ref += z[:,0][:, None] if epilogue == 'add-cols': z_ref += z[0,:][None, :] - z_ref = z_ref.to(torch.float16) # compare - triton.testing.assert_almost_equal(z_tri, z_ref) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # make sure ld/st are vectorized ptx = pgm.asm['ptx'] assert 'ld.global.v4' in ptx @@ -558,7 +619,7 @@ def test_dot_without_load(): c = tl.dot(a, b) pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :] tl.store(pout, c) - + out = torch.ones((32,32), dtype=torch.float32, device="cuda") kernel[(1,)](out) @@ -571,7 +632,7 @@ def test_arange(start, device='cuda'): BLOCK = 128 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) @triton.jit - def _kernel(z, BLOCK: tl.constexpr, + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): off = tl.arange(0, BLOCK) val = tl.arange(START, END) @@ -605,8 +666,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'): N_offsets = tl.arange(0, N) K_offsets = tl.arange(0, K) - in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:] - in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:] + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:] # Load inputs. 
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel) @@ -616,7 +677,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'): o = tl.dot(x, w) # Store output - output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:] + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:] tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel) pgm = _kernel[(1,)](in1, in2, out, @@ -687,7 +748,7 @@ def test_default(): def _kernel(ret0, ret1, value): tl.store(ret0, _impl()) tl.store(ret1, _impl(value)) - + _kernel[(1,)](ret0, ret1, value) assert ret0.item() == 10 assert ret1.item() == value @@ -699,5 +760,5 @@ def test_noop(device='cuda'): @triton.jit def kernel(x): pass - x = triton.testing.random((1,), dtype=torch.int32, device=device) + x = to_triton(numpy_random((1,), dtype_str='int32'), device=device) kernel[(1, )](x) diff --git a/python/triton/testing.py b/python/triton/testing.py index f274e808f..eef7f5be6 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -85,31 +85,6 @@ def allclose(x, y, tol=1e-2): return err <= tol -def assert_allclose(x, y, tol=1e-2): - assert x.dtype == y.dtype - assert allclose(x, y, tol) - - -def random(shape, dtype, device, seed=0): - """ - Override the seed in tests if you're calling this function twice and don't - want the same result for both calls. - """ - torch.manual_seed(seed) - if isinstance(shape, int): - shape = (shape, ) - if dtype == torch.bool: - return torch.randint(0, 2, shape, dtype=dtype, device=device) - if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: - iinfo = torch.iinfo(dtype) - x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device) - x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. - return x - if dtype in [torch.float16, torch.float32, torch.float64]: - return torch.normal(0, 1, shape, dtype=dtype, device=device) - raise RuntimeError(f'Unknown dtype {dtype}') - - def nvsmi(attrs): attrs = ','.join(attrs) cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] @@ -203,7 +178,7 @@ class Benchmark: styles=None, ): """ - Constructor + Constructor :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value. :type x_names: List[str] @@ -344,4 +319,4 @@ def get_max_tensorcore_tflops(backend, device): else: ops_per_sub_core = 512 tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024) - return tflops \ No newline at end of file + return tflops
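
Usage note (not part of the patch): a minimal sketch, assuming the new numpy-based helpers introduced above in test_core.py (numpy_random, to_triton, to_numpy) are in scope, of how they are meant to be combined in a test. The kernel name copy_kernel and the wrapper check_copy below are hypothetical and only illustrate the pattern used throughout the patch: generate reference data with numpy, move it to the GPU via to_triton, run the kernel, and compare against the numpy reference with np.testing.assert_allclose.

    import numpy as np
    from numpy.random import RandomState

    import triton
    import triton.language as tl

    # Hypothetical example kernel; not part of the patch above.
    @triton.jit
    def copy_kernel(X, Z, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        tl.store(Z + off, tl.load(X + off))

    def check_copy(dtype_str='float32', size=128, device='cuda'):
        # Explicit RandomState so repeated calls don't all reuse the default seed,
        # mirroring the rs=RandomState(17) pattern in _test_binary above.
        rs = RandomState(17)
        x = numpy_random(size, dtype_str=dtype_str, rs=rs)   # numpy reference input
        x_tri = to_triton(x, device=device)                  # torch tensor on the GPU
        z_tri = to_triton(np.empty_like(x), device=device)   # output buffer
        copy_kernel[(1, )](x_tri, z_tri, SIZE=size)
        # Compare the Triton result against the numpy reference.
        np.testing.assert_allclose(x, to_numpy(z_tri), rtol=0.01)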