[FRONTEND] signed-integer math fixes and testing (#395)

- Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway.
- Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now.
- When testing signed integer operations, use random numbers from the full range of the type.
- Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type.
- Fix a bad `CompilationError` invocation.
- Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8.
This commit is contained in:
Madeleine Thompson
2021-12-21 09:46:05 -08:00
committed by GitHub
parent 4a8953efa3
commit 5cdb948c05
4 changed files with 93 additions and 47 deletions

View File

@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
triton.testing.assert_almost_equal(z_ref, z_tri)
def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -82,12 +82,12 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
if mode_x == 'nan': x[:] = float('nan')
if mode_y == 'nan': y[:] = float('nan')
# reference result
z_ref = eval(expr)
z_ref = eval(expr if torch_expr is None else torch_expr)
# triton result
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
@@ -95,17 +95,56 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
def _fake_fmod(x, y):
"""
Triton % (for both integers and floats) has the same semantics as torch
fmod, but torch fmod doesn't work on integers until torch 1.8.
`_fake_fmod` gives the same semantics but works on all versions of torch.
"""
z = torch.remainder(x, y)
return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# The result of x % y is ill-conditioned if x % y is much smaller than x.
# pytorch/CUDA has slightly different (probably better) rounding on
# remainders than stock LLVM. We currently don't expect to match it
# bit-for-bit.
return (dtype_x, dtype_y) in [
('int32', 'float16'),
('int32', 'float32'),
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
]
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['+', '-', '*', '/', '%'] \
for dtype_x in dtypes \
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes
for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device)
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
# LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
torch_expr = '_fake_fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
# Triton promotes 16-bit floating-point / and % to 32-bit because there
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
else:
torch_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
with pytest.raises(AssertionError, match='Arrays are not almost equal'):
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
else:
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
# ---------------

View File

@@ -482,7 +482,8 @@ class CodeGenerator(ast.NodeVisitor):
with warnings.catch_warnings():
# The ast library added visit_Constant and deprecated some other
# methods but we can't move to that without breaking Python 3.6 and 3.7.
warnings.simplefilter("ignore", DeprecationWarning)
warnings.simplefilter("ignore", DeprecationWarning) # python 3.9
warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
return super().visit(node)
def generic_visit(self, node):
@@ -905,7 +906,7 @@ class JITFunction:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node, e)
raise CompilationError(self.src, node) from e
# - when `.src` attribute is set, cache path needs
# to be reinitialized

View File

@@ -89,14 +89,21 @@ def assert_allclose(x, y, tol=1e-2):
assert allclose(x, y, tol)
def random(shape, dtype, device):
torch.manual_seed(0)
def random(shape, dtype, device, seed=0):
"""
Override the seed in tests if you're calling this function twice and don't
want the same result for both calls.
"""
torch.manual_seed(seed)
if isinstance(shape, int):
shape = (shape, )
if dtype == torch.bool:
return torch.randint(0, 2, shape, dtype=dtype, device=device)
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
return torch.randint(1, 32, shape, dtype=dtype, device=device)
iinfo = torch.iinfo(dtype)
x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
if dtype in [torch.float16, torch.float32, torch.float64]:
return torch.normal(0, 1, shape, dtype=dtype, device=device)
raise RuntimeError(f'Unknown dtype {dtype}')