[FRONTEND] signed-integer math fixes and testing (#395)
- Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway. - Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now. - When testing signed integer operations, use random numbers from the full range of the type. - Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type. - Fix a bad `CompilationError` invocation. - Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8.
This commit is contained in:
committed by
GitHub
parent
4a8953efa3
commit
5cdb948c05
@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
|
||||
triton.testing.assert_almost_equal(z_ref, z_tri)
|
||||
|
||||
|
||||
def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
|
||||
def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
@@ -82,12 +82,12 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
|
||||
# inputs
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
|
||||
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
|
||||
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
|
||||
if mode_x == 'nan': x[:] = float('nan')
|
||||
if mode_y == 'nan': y[:] = float('nan')
|
||||
# reference result
|
||||
z_ref = eval(expr)
|
||||
z_ref = eval(expr if torch_expr is None else torch_expr)
|
||||
# triton result
|
||||
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
|
||||
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
|
||||
@@ -95,17 +95,56 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
|
||||
triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
|
||||
|
||||
|
||||
def _fake_fmod(x, y):
|
||||
"""
|
||||
Triton % (for both integers and floats) has the same semantics as torch
|
||||
fmod, but torch fmod doesn't work on integers until torch 1.8.
|
||||
`_fake_fmod` gives the same semantics but works on all versions of torch.
|
||||
"""
|
||||
z = torch.remainder(x, y)
|
||||
return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
|
||||
|
||||
|
||||
def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
|
||||
# The result of x % y is ill-conditioned if x % y is much smaller than x.
|
||||
# pytorch/CUDA has slightly different (probably better) rounding on
|
||||
# remainders than stock LLVM. We currently don't expect to match it
|
||||
# bit-for-bit.
|
||||
return (dtype_x, dtype_y) in [
|
||||
('int32', 'float16'),
|
||||
('int32', 'float32'),
|
||||
('int64', 'float16'),
|
||||
('int64', 'float32'),
|
||||
('int64', 'float64'),
|
||||
]
|
||||
|
||||
# ---------------
|
||||
# test binary ops
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
|
||||
(dtype_x, dtype_y, f' x {op} y') \
|
||||
for op in ['+', '-', '*', '/', '%'] \
|
||||
for dtype_x in dtypes \
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
||||
(dtype_x, dtype_y, op)
|
||||
for op in ['+', '-', '*', '/', '%']
|
||||
for dtype_x in dtypes
|
||||
for dtype_y in dtypes
|
||||
])
|
||||
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
|
||||
_test_binary(dtype_x, dtype_y, expr, device=device)
|
||||
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
expr = f' x {op} y'
|
||||
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
|
||||
# LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
|
||||
torch_expr = '_fake_fmod(x, y)'
|
||||
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
|
||||
# Triton promotes 16-bit floating-point / and % to 32-bit because there
|
||||
# are no native div or FRem operations on float16. Since we have to
|
||||
# convert anyway, we may as well take the accuracy bump.
|
||||
torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
|
||||
else:
|
||||
torch_expr = None
|
||||
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
|
||||
with pytest.raises(AssertionError, match='Arrays are not almost equal'):
|
||||
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
|
||||
else:
|
||||
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
|
||||
|
||||
|
||||
|
||||
# ---------------
|
||||
|
Reference in New Issue
Block a user