[CI] Moved from assert_allclose to assert_almost_equal (#200)

This commit is contained in:
Philippe Tillet
2021-08-12 12:00:30 -07:00
committed by GitHub
parent 70e28ff380
commit b120d70a0a
5 changed files with 24 additions and 16 deletions

View File

@@ -55,7 +55,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_allclose(z_ref, z_tri)
triton.testing.assert_almost_equal(z_ref, z_tri)
def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
@@ -79,7 +79,7 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_allclose(z_ref, z_tri)
triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
# ---------------
@@ -202,7 +202,7 @@ def test_index1d(expr, device='cuda'):
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
# compare
triton.testing.assert_allclose(z_ref, z_tri)
triton.testing.assert_almost_equal(z_ref, z_tri)
# ---------------
@@ -262,7 +262,7 @@ def test_tuples():
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
dtype_x = cvt[dtype_x]
n_programs = 37
n_programs = 5
# triton kernel
@triton.jit
@@ -300,7 +300,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
if exact:
assert z_ref.item() == z_tri.item()
else:
triton.testing.assert_allclose(z_ref, z_tri)
triton.testing.assert_almost_equal(z_ref, z_tri)
# ---------------