[LANG] Fixed semantics of NaN in float comparisons (#281)
This commit is contained in:
@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
|
||||
triton.testing.assert_almost_equal(z_ref, z_tri)
|
||||
|
||||
|
||||
def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
|
||||
def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
@@ -84,6 +84,8 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
|
||||
# inputs
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
|
||||
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
|
||||
if mode_x == 'nan': x[:] = float('nan')
|
||||
if mode_y == 'nan': y[:] = float('nan')
|
||||
# reference result
|
||||
z_ref = eval(expr)
|
||||
# triton result
|
||||
@@ -126,14 +128,25 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
|
||||
# ---------------
|
||||
# test compare ops
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
|
||||
(dtype_x, dtype_y, f' x {op} y') \
|
||||
for op in ['==', '!=', '>', '<', '>=', '<='] \
|
||||
ops = ['==', '!=', '>', '<', '>=', '<=']
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
|
||||
# real
|
||||
[
|
||||
(dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
|
||||
for op in ops \
|
||||
for dtype_x in dtypes \
|
||||
for dtype_y in dtypes
|
||||
] + \
|
||||
# NaNs
|
||||
[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
|
||||
for op in ops
|
||||
for mode_x, mode_y in [('nan' , 'real'),
|
||||
('real', 'nan'),
|
||||
('nan' , 'nan')]
|
||||
|
||||
])
|
||||
def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
|
||||
_test_binary(dtype_x, dtype_y, expr, device=device)
|
||||
def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
|
||||
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
|
||||
|
||||
|
||||
# ---------------
|
Reference in New Issue
Block a user