[TEST] use numpy for reference results in test_core.py (#409)
Since numpy supports unsigned integers and pytorch doesn't, this makes it easier to test unsigned integer support. It adds an explicit requirement on numpy for the tests, but we already required scipy, so numpy was already an implicit dependency.
committed by GitHub, parent 03f1256f60, commit d8db0308cb
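A minimal sketch (my own illustration, not part of this commit) of why numpy is the more convenient engine for reference results once unsigned integer support lands: numpy provides the full set of unsigned dtypes with well-defined wrap-around arithmetic, while pytorch at the time only exposed torch.uint8.

# Hypothetical illustration, not taken from the diff below.
import numpy as np

x = np.array([250, 7], dtype=np.uint32)
y = np.array([10, 3], dtype=np.uint32)
print(x + y)        # [260 10], result stays uint32
print((x - y) % 5)  # unsigned semantics are well defined for the reference
# pytorch (at the time of this commit) only provides torch.uint8, so computing
# uint16/uint32/uint64 reference results there would require workarounds.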
@@ -1,2 +1,3 @@
+numpy
 pytest
 scipy >= 1.7.1
@@ -1,31 +1,59 @@
+import copy
+import itertools
+import re
+from typing import Optional
+
+import numpy as np
+import pytest
 import torch
+from numpy.random import RandomState
+
 import triton
 import triton.language as tl
-import copy
-import pytest
-import ast
-import itertools
-
-torch.manual_seed(0)
-
-# convert from string to torch.dtype
-# Necessary because doesn't print torch.dtype properly
-cvt = {
-    'bool': torch.bool,
-    'int8': torch.int8,
-    'int16': torch.int16,
-    'int32': torch.int32,
-    'int64': torch.int64,
-    'bfloat16': torch.bfloat16,
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'float64': torch.float64,
-}
 
 int_dtypes = ['int8', 'int16', 'int32', 'int64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + float_dtypes
 
 
+def _bitwidth(dtype: str) -> int:
+    # ex.: "int64" -> 64
+    return int(re.search(r'(\d+)$', dtype).group(1))
+
+
+def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
+    """
+    Override `rs` if you're calling this function twice and don't want the same
+    result for both calls.
+    """
+    if isinstance(shape, int):
+        shape = (shape, )
+    if rs is None:
+        rs = RandomState(seed=17)
+    dtype = getattr(np, dtype_str)
+    if dtype_str in int_dtypes:
+        iinfo = np.iinfo(getattr(np, dtype_str))
+        x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
+        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
+        return x
+    elif dtype_str in float_dtypes:
+        return rs.normal(0, 1, shape).astype(dtype)
+    else:
+        raise RuntimeError(f'Unknown dtype {dtype_str}')
+
+
+def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
+    # For now, this always converts to a torch tensor, but when we add unsigned
+    # integers, it will also support TensorWrapper, since torch doesn't have
+    # unsigned support.
+    return torch.tensor(x, device=device)
+
+
+def to_numpy(x):
+    if isinstance(x, torch.Tensor):
+        return x.cpu().numpy()
+    else:
+        raise ValueError(f"Not a triton-compatible tensor: {x}")
+
+
 def patch_kernel(template, to_replace):
     kernel = copy.deepcopy(template)
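A short usage sketch (my own illustration, not part of the diff) of the reference-testing pattern these helpers enable. It assumes the numpy_random, to_triton, and to_numpy functions defined in the hunk above are in scope; add_one and check_add_one are made-up names.

import numpy as np
import triton
import triton.language as tl


@triton.jit
def add_one(X, Z, SIZE: tl.constexpr):
    off = tl.arange(0, SIZE)
    tl.store(Z + off, tl.load(X + off) + 1)


def check_add_one(dtype_str='float32', SIZE=128, device='cuda'):
    x = numpy_random(SIZE, dtype_str=dtype_str)          # numpy input
    x_tri = to_triton(x, device=device)                  # torch copy for the kernel
    z_tri = to_triton(np.empty_like(x), device=device)   # output buffer
    add_one[(1, )](x_tri, z_tri, SIZE=SIZE)
    np.testing.assert_allclose(x + 1, to_numpy(z_tri), rtol=0.01)  # numpy reference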
@@ -34,19 +62,18 @@ def patch_kernel(template, to_replace):
     return kernel
 
 
-@pytest.mark.parametrize("dtype_x", [
-    (dtype_x) for dtype_x in dtypes
-])
+@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128
     @triton.jit
     def kernel(X, SIZE: tl.constexpr):
         pass
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
+    x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device)
    kernel[(1, )](x, SIZE=SIZE, num_warps=4)
 
 
 # generic test functions
-def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
+def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
@@ -58,18 +85,36 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
-    if 'log' in expr: x = torch.abs(x) + 0.01
+    x = numpy_random(SIZE, dtype_str=dtype_x)
+    if 'log' in expr:
+        x = np.abs(x) + 0.01
     # reference result
-    z_ref = eval(expr if torch_expr is None else torch_expr)
+    z_ref = eval(expr if numpy_expr is None else numpy_expr)
     # triton result
-    z_tri = torch.empty_like(z_ref)
-    kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
     # compare
-    triton.testing.assert_almost_equal(z_ref, z_tri)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
-def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
+def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
+    """
+    Given two dtype strings, returns the numpy dtype Triton thinks binary
+    operations on the two types should return. Returns None if the return value
+    matches numpy. This is generally needed because Triton and pytorch return
+    narrower floating point types than numpy in mixed operations.
+    """
+    overrides = {
+        ('float16', 'int16'): np.float16,
+        ('float16', 'int32'): np.float16,
+        ('float16', 'int64'): np.float16,
+    }
+    key = (a, b) if a < b else (b, a)
+    return overrides.get(key)
+
+
+def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
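A small sketch (my own, not in the diff) of the promotion mismatch _binary_op_dtype_override compensates for: numpy promotes mixed float16/int32 arithmetic to float64, while Triton and pytorch keep float16, so the numpy reference is narrowed before comparison.

import numpy as np

x = np.ones(4, dtype=np.float16)
y = np.ones(4, dtype=np.int32)
z_ref = x + y
print(z_ref.dtype)                                       # float64 under numpy's promotion rules
dtype_z = _binary_op_dtype_override('float16', 'int32')  # helper added in the hunk above
if dtype_z is not None:
    z_ref = z_ref.astype(dtype_z)                        # narrow to the dtype Triton produces
print(z_ref.dtype)                                       # float16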
@@ -82,27 +127,24 @@ def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y=
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
-    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
-    if mode_x == 'nan': x[:] = float('nan')
-    if mode_y == 'nan': y[:] = float('nan')
+    rs = RandomState(17)
+    x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
+    if mode_x == 'nan':
+        x[:] = float('nan')
+    if mode_y == 'nan':
+        y[:] = float('nan')
     # reference result
-    z_ref = eval(expr if torch_expr is None else torch_expr)
+    z_ref = eval(expr if numpy_expr is None else numpy_expr)
+    dtype_z = _binary_op_dtype_override(dtype_x, dtype_y)
+    if dtype_z is not None:
+        z_ref = z_ref.astype(dtype_z)
     # triton result
-    z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
-    kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
-    # compare
-    triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
+    x_tri = to_triton(x, device=device)
+    y_tri = to_triton(y, device=device)
+    z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
+    kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)
 
 
-def _fake_fmod(x, y):
-    """
-    Triton % (for both integers and floats) has the same semantics as torch
-    fmod, but torch fmod doesn't work on integers until torch 1.8.
-    `_fake_fmod` gives the same semantics but works on all versions of torch.
-    """
-    z = torch.remainder(x, y)
-    return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
-
-
 def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
@@ -130,36 +172,38 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
     if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
-        # LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
-        torch_expr = '_fake_fmod(x, y)'
+        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
+        numpy_expr = 'np.fmod(x, y)'
     elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
         # Triton promotes 16-bit floating-point / and % to 32-bit because there
         # are no native div or FRem operations on float16. Since we have to
         # convert anyway, we may as well take the accuracy bump.
-        torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
+        numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
     else:
-        torch_expr = None
+        numpy_expr = None
     if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
-        with pytest.raises(AssertionError, match='Arrays are not almost equal'):
-            _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
+        with pytest.raises(AssertionError, match='Not equal to tolerance'):
+            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
     else:
-        _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
+        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
 
 
 # ---------------
 # test bitwise ops
 # ---------------
-@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
-    (dtype_x, dtype_y, f' x {op} y') \
-    for op in ['&', '|', '^'] \
-    for dtype_x in dtypes \
+@pytest.mark.parametrize("dtype_x, dtype_y, op", [
+    (dtype_x, dtype_y, op)
+    for op in ['&', '|', '^']
+    for dtype_x in dtypes
     for dtype_y in dtypes
 ])
-def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
+def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
+    expr = f'x {op} y'
     if 'float' in dtype_x + dtype_y:
-        with pytest.raises(RuntimeError):
-            _test_binary(dtype_x, dtype_y, expr, device=device)
+        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+            _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
+        # The CompilationError must have been caused by a C++ exception with this text.
+        assert re.match('invalid operands of type', str(exc_info.value.__cause__))
     else:
         _test_binary(dtype_x, dtype_y, expr, device=device)
 
@@ -168,23 +212,24 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
 # test compare ops
 # ---------------
 ops = ['==', '!=', '>', '<', '>=', '<=']
-@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
+@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
 # real
 [
-    (dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
+    (dtype_x, dtype_y, op, 'real', 'real') \
     for op in ops \
     for dtype_x in dtypes \
     for dtype_y in dtypes
 ] + \
 # NaNs
-[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
+[('float32', 'float32', op, mode_x, mode_y) \
     for op in ops
     for mode_x, mode_y in [('nan' , 'real'),
                            ('real', 'nan'),
                            ('nan' , 'nan')]
 
 ])
-def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
+def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
+    expr = f'x {op} y'
     _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
 
 
@@ -192,9 +237,9 @@ def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
 # test unary ops
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
-    (dtype_x, f' -x') for dtype_x in float_dtypes
+    (dtype_x, ' -x') for dtype_x in dtypes
 ] + [\
-    (dtype_x, f' ~x') for dtype_x in int_dtypes
+    (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
 def test_unary_op(dtype_x, expr, device='cuda'):
     _test_unary(dtype_x, expr, device=device)
@@ -210,7 +255,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
     'exp', 'log', 'cos', 'sin'
 ])
 def test_math_op(expr, device='cuda'):
-    _test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)
+    _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
 
 
 # ----------------
@@ -229,12 +274,11 @@ def make_ptr_str(name, shape):
     return f"{name} + {' + '.join(offsets)}"
 
 
-@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
-    ['None, :', ':, None',\
-     'None, :, :', ':, :, None']\
+@pytest.mark.parametrize("expr, dtype_str", [
+    (f'x[{s}]', 'int32')
+    for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
 ])
-def test_index1d(expr, device='cuda'):
-    dtype = torch.int32
+def test_index1d(expr, dtype_str, device='cuda'):
     rank_x = expr.count(':')
     rank_y = expr.count(',') + 1
     shape_x = [32 for _ in range(rank_x)]
@@ -257,14 +301,15 @@ def test_index1d(expr, device='cuda'):
     kernel = patch_kernel(kernel, to_replace)
 
     # torch result
-    x = triton.testing.random(shape_x, dtype=dtype, device=device)
-    y = torch.zeros(shape_z, dtype=dtype, device=device)
+    x = numpy_random(shape_x, dtype_str=dtype_str)
+    y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
     z_ref = eval(expr) + y
     # triton result
-    z_tri = torch.empty_like(z_ref)
-    kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
+    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    x_tri = to_triton(x)
+    kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
     # compare
-    triton.testing.assert_almost_equal(z_ref, z_tri)
+    assert (z_ref == to_numpy(z_tri)).all()
 
 
 # ---------------
@@ -316,14 +361,15 @@ def test_tuples():
 # ---------------
 # test atomics
 # ---------------
-@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
-    [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
-     ('max', 'int32', mode), ('max', 'float32', mode),\
-     ('min', 'int32', mode), ('min', 'float32', mode),\
+@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
+    [
+        ('add', 'float16', mode),
+        ('add', 'int32', mode), ('add', 'float32', mode),
+        ('max', 'int32', mode), ('max', 'float32', mode),
+        ('min', 'int32', mode), ('min', 'float32', mode),
     ]
     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
-def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
-    dtype_x = cvt[dtype_x]
+def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     n_programs = 5
 
     # triton kernel
@@ -334,52 +380,59 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
         old = GENERATE_TEST_HERE
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
-    torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
-    max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
-    min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
+    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
 
     # triton result
-    x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
+    rs = RandomState(17)
+    x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
     if mode == 'all_neg':
-        x_tri = -torch.abs(x_tri)
+        x = -np.abs(x)
     if mode == 'all_pos':
-        x_tri = torch.abs(x_tri)
+        x = np.abs(x)
     if mode == 'min_neg':
-        idx = torch.randint(n_programs, size=(1, )).item()
-        x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = -np.max(np.abs(x)) - 1
     if mode == 'max_pos':
-        idx = torch.randint(n_programs, size=(1, )).item()
-        x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = np.max(np.abs(x)) + 1
+    x_tri = to_triton(x, device=device)
 
-    z_tri = torch.empty([], dtype=dtype_x, device=device)
-    z_tri.fill_(neutral)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
     kernel[(n_programs, )](x_tri, z_tri)
     # torch result
-    z_ref = torch_op(x_tri).to(dtype_x)
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
     # compare
     exact = op not in ['add']
     if exact:
-        assert z_ref.item() == z_tri.item()
+        assert z_ref.item() == to_numpy(z_tri).item()
     else:
-        triton.testing.assert_almost_equal(z_ref, z_tri)
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
 
 
 # ---------------
 # test cast
 # ---------------
 @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
-    (dtype_x, dtype_z, False) \
-    for dtype_x in dtypes\
+    (dtype_x, dtype_z, False)
+    for dtype_x in dtypes
     for dtype_z in dtypes
 ] + [
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
-    ('float32', 'int32', True)
-])
+    ('float32', 'int32', True),
+]
+)
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
-    x0 = 43 if dtype_x.startswith('int') else 43.5
-    x = torch.tensor([x0], dtype=cvt[dtype_x], device=device)
+    # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
+    x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x.startswith('bfloat'):
+        x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
+    else:
+        x = np.array([x0], dtype=getattr(np, dtype_x))
+        x_tri = to_triton(x)
 
     # triton kernel
     @triton.jit
@@ -389,26 +442,31 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         tl.store(Z, z)
 
     # triton result
-    z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
-    kernel[(1, )](x, z_tri, BITCAST=bitcast)
-    # torch result
-    if bitcast:
-        import numpy as np
-        z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z))
-        z_ref = torch.from_numpy(z_ref).to(device)
+    if dtype_z.startswith('bfloat'):
+        z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_ref = x.to(z_tri.dtype)
-    assert z_tri == z_ref
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device)
+    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
+    # torch result
+    if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
+        assert bitcast is False
+        z_ref = x_tri.to(z_tri.dtype)
+        assert z_tri == z_ref
+    else:
+        if bitcast:
+            z_ref = x.view(getattr(np, dtype_z))
+        else:
+            z_ref = x.astype(getattr(np, dtype_z))
+        assert to_numpy(z_tri) == z_ref
 
 
 # ---------------
 # test reduce
 # ---------------
-@pytest.mark.parametrize("dtype, shape",
+@pytest.mark.parametrize("dtype_str, shape",
                          [(dtype, shape) \
                           for dtype in dtypes\
                           for shape in [128, 512]])
-def test_reduce1d(dtype, shape, device='cuda'):
-    dtype = cvt[dtype]
+def test_reduce1d(dtype_str, shape, device='cuda'):
 
     # triton kernel
     @triton.jit
@@ -416,22 +474,22 @@ def test_reduce1d(dtype, shape, device='cuda'):
         x = tl.load(X + tl.arange(0, BLOCK))
         tl.store(Z, tl.sum(x, axis=0))
 
-    x = triton.testing.random((shape,), dtype=dtype, device=device)
+    rs = RandomState(17)
+    x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
+    # numpy result
+    z_ref = np.sum(x).astype(getattr(np, dtype_str))
     # triton result
-    z_tri = triton.testing.random((1,), dtype=dtype, device=device)
-    kernel[(1,)](x, z_tri, BLOCK=shape)
-    # torch result
-    z_ref = torch.sum(x).to(dtype)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
+    kernel[(1,)](x_tri, z_tri, BLOCK=shape)
     # compare
-    triton.testing.assert_almost_equal(z_tri, z_ref)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
-@pytest.mark.parametrize("dtype, shape, axis",
-                         [(dtype, shape, 1) \
-                          for dtype in ['float32']\
-                          for shape in [(1, 1024)]])
-def test_reduce2d(dtype, shape, axis, device='cuda'):
-    dtype = cvt[dtype]
+@pytest.mark.parametrize("dtype_str, shape, axis", [
+    ('float32', (1, 1024), 1)
+])
+def test_reduce2d(dtype_str, shape, axis, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -441,29 +499,30 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
         z = tl.sum(x, axis=AXIS)
         tl.store(Z + range_m, z)
     # input
-    x = triton.testing.random(shape, dtype=dtype, device=device)
+    x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
-    z_tri = torch.empty((shape[0],), dtype=dtype, device=device)
-    kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
-    # torch result
-    z_ref = torch.sum(x, axis=axis).to(dtype)
+    x_tri = to_triton(x)
+    z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    # numpy reference result
+    z_ref = np.sum(x, axis=axis).astype(x.dtype)
     # compare
-    triton.testing.assert_almost_equal(z_tri, z_ref)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 # ---------------
 # test permute
 # ---------------
-@pytest.mark.parametrize("dtype, shape, perm",
+@pytest.mark.parametrize("dtype_str, shape, perm",
                          [(dtype, shape, perm) \
                           for dtype in ['float32']\
                           for shape in [(128, 128)]\
                           for perm in [(1, 0)]])
-def test_permute(dtype, shape, perm, device='cuda'):
-    dtype = cvt[dtype]
+def test_permute(dtype_str, shape, perm, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xn,
               Z, stride_zm, stride_zn,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
@@ -471,14 +530,15 @@ def test_permute(dtype, shape, perm, device='cuda'):
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
         tl.store(Zs, tl.load(Xs))
     # input
-    x = triton.testing.random(shape, dtype=dtype, device=device)
+    x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
-    z_tri = torch.empty_like(x)
-    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
-                         z_tri, z_tri.stride(1), z_tri.stride(0),
-                         BLOCK_M=shape[0], BLOCK_N=shape[1])
+    z_tri = to_triton(np.empty_like(x), device=device)
+    x_tri = to_triton(x, device=device)
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         z_tri, z_tri.stride(1), z_tri.stride(0),
+                         BLOCK_M=shape[0], BLOCK_N=shape[1])
     # torch result
-    z_ref = x.permute(*perm).contiguous()
+    z_ref = x.transpose(*perm)
     # compare
     triton.testing.assert_almost_equal(z_tri, z_ref)
     # parse ptx to make sure ld/st are vectorized
@@ -491,13 +551,12 @@ def test_permute(dtype, shape, perm, device='cuda'):
 # ---------------
 
 @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
-def test_dot(epilogue, dtype=torch.float32, device='cuda'):
-    torch.manual_seed(0)
+def test_dot(epilogue, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
               Y, stride_yk, stride_yn,
               Z, stride_zm, stride_zn,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
@@ -513,36 +572,38 @@ def test_dot(epilogue, dtype=torch.float32, device='cuda'):
             ZRs = Z + off_m * stride_zm
             z += tl.load(ZRs)[:, None]
         if ADD_COLS:
            ZCs = Z + off_n * stride_zn
            z += tl.load(ZCs)[None, :]
        tl.store(Zs, z)
     # input
     M, N, K = 64, 64, 32
-    x = triton.testing.random((M, K), dtype=dtype, device=device)
-    y = triton.testing.random((K, N), dtype=dtype, device=device)
+    rs = RandomState(17)
+    x = numpy_random((M, K), dtype_str='float32', rs=rs)
+    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    x_tri = to_triton(x, device=device)
+    y_tri = to_triton(y, device=device)
     # triton result
-    z = triton.testing.random((M, N), dtype=dtype, device=device)
-    z_tri = z.clone()
+    z = numpy_random((M, N), dtype_str='float32', rs=rs)
+    z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
-    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
-                         y, y.stride(0), y.stride(1),
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         y_tri, y_tri.stride(0), y_tri.stride(1),
                          z_tri, z_tri.stride(0), z_tri.stride(1),
                          BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX = epilogue=='add-matrix',
                         ADD_ROWS = epilogue=='add-rows',
                         ADD_COLS = epilogue=='add-cols')
     # torch result
-    z_ref = torch.matmul(x.float(), y.float())
+    z_ref = np.matmul(x, y)
     if epilogue == 'add-matrix':
         z_ref += z
     if epilogue == 'add-rows':
         z_ref += z[:,0][:, None]
     if epilogue == 'add-cols':
         z_ref += z[0,:][None, :]
-    z_ref = z_ref.to(torch.float16)
     # compare
-    triton.testing.assert_almost_equal(z_tri, z_ref)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
@@ -558,7 +619,7 @@ def test_dot_without_load():
         c = tl.dot(a, b)
         pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
         tl.store(pout, c)
 
     out = torch.ones((32,32), dtype=torch.float32, device="cuda")
     kernel[(1,)](out)
 
@@ -571,7 +632,7 @@ def test_arange(start, device='cuda'):
     BLOCK = 128
     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
     @triton.jit
     def _kernel(z, BLOCK: tl.constexpr,
                 START: tl.constexpr, END: tl.constexpr):
         off = tl.arange(0, BLOCK)
         val = tl.arange(START, END)
@@ -605,8 +666,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
         N_offsets = tl.arange(0, N)
         K_offsets = tl.arange(0, K)
 
         in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
         in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
 
         # Load inputs.
         x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
@@ -616,7 +677,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
         o = tl.dot(x, w)
 
         # Store output
         output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
         tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
 
     pgm = _kernel[(1,)](in1, in2, out,
@@ -687,7 +748,7 @@ def test_default():
     def _kernel(ret0, ret1, value):
         tl.store(ret0, _impl())
         tl.store(ret1, _impl(value))
 
     _kernel[(1,)](ret0, ret1, value)
     assert ret0.item() == 10
     assert ret1.item() == value
@@ -699,5 +760,5 @@ def test_noop(device='cuda'):
     @triton.jit
     def kernel(x):
         pass
-    x = triton.testing.random((1,), dtype=torch.int32, device=device)
+    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
     kernel[(1, )](x)
@@ -85,31 +85,6 @@ def allclose(x, y, tol=1e-2):
     return err <= tol
 
 
-def assert_allclose(x, y, tol=1e-2):
-    assert x.dtype == y.dtype
-    assert allclose(x, y, tol)
-
-
-def random(shape, dtype, device, seed=0):
-    """
-    Override the seed in tests if you're calling this function twice and don't
-    want the same result for both calls.
-    """
-    torch.manual_seed(seed)
-    if isinstance(shape, int):
-        shape = (shape, )
-    if dtype == torch.bool:
-        return torch.randint(0, 2, shape, dtype=dtype, device=device)
-    if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
-        iinfo = torch.iinfo(dtype)
-        x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device)
-        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
-        return x
-    if dtype in [torch.float16, torch.float32, torch.float64]:
-        return torch.normal(0, 1, shape, dtype=dtype, device=device)
-    raise RuntimeError(f'Unknown dtype {dtype}')
-
-
 def nvsmi(attrs):
     attrs = ','.join(attrs)
     cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -203,7 +178,7 @@ class Benchmark:
         styles=None,
     ):
         """
         Constructor
 
         :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
         :type x_names: List[str]
@@ -344,4 +319,4 @@ def get_max_tensorcore_tflops(backend, device):
     else:
         ops_per_sub_core = 512
     tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
     return tflops