diff --git a/python/test/test_blocksparse.py b/python/test/test_blocksparse.py index 991a34c26..0f2232a7b 100644 --- a/python/test/test_blocksparse.py +++ b/python/test/test_blocksparse.py @@ -6,7 +6,7 @@ import pytest @pytest.mark.parametrize( "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE", [ - (mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"] + (mode, at, bt, block, dtype) for dtype in ["float16"] for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True] for block in [16, 32, 64] ], ) @@ -37,7 +37,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc # compare - assert triton.testing.allclose(rc, tc) + triton.testing.assert_almost_equal(rc, tc) @pytest.mark.parametrize( @@ -78,7 +78,7 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16): ry = torch.softmax(rx * scale, -1) ry = triton.testing.sparsify_tensor(ry, layout, BLOCK) # compare - assert triton.testing.allclose(ry, ty) + triton.testing.assert_almost_equal(ry, ty) def test_attention_fwd_bwd( @@ -133,9 +133,9 @@ def test_attention_fwd_bwd( # comparison # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") - torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol) + triton.testing.assert_almost_equal(loss, torch_loss) for g1, g2 in zip(grads, torch_grads): - torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol) + triton.testing.assert_almost_equal(g1, g2) def triton_attention( diff --git a/python/test/test_cross_entropy.py b/python/test/test_cross_entropy.py index 969774a11..48cb303bb 100644 --- a/python/test/test_cross_entropy.py +++ b/python/test/test_cross_entropy.py @@ -19,7 +19,7 @@ def test_op(M, N, dtype, mode): tt_y = triton.ops.cross_entropy(x, idx) th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) if mode == 'forward': - assert torch.allclose(th_y, tt_y, atol=1e-3, rtol=1e-2) + triton.testing.assert_almost_equal(th_y, tt_y) # backward pass elif mode == 'backward': dy = torch.randn_like(tt_y) @@ -30,4 +30,4 @@ def test_op(M, N, dtype, mode): x.grad.zero_() th_y.backward(dy) th_dx = x.grad.clone() - assert torch.allclose(th_dx, tt_dx, atol=1e-3, rtol=1e-2) \ No newline at end of file + triton.testing.assert_almost_equal(th_dx, tt_dx) \ No newline at end of file diff --git a/python/test/test_language.py b/python/test/test_language.py index 30ff75bfb..71a151f47 100644 --- a/python/test/test_language.py +++ b/python/test/test_language.py @@ -55,7 +55,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): z_tri = torch.empty_like(z_ref) kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) # compare - triton.testing.assert_allclose(z_ref, z_tri) + triton.testing.assert_almost_equal(z_ref, z_tri) def _test_binary(dtype_x, dtype_y, expr, device='cuda'): @@ -79,7 +79,7 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'): z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device) kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4) # compare - triton.testing.assert_allclose(z_ref, z_tri) + triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr) # --------------- @@ -202,7 +202,7 @@ def test_index1d(expr, device='cuda'): z_tri = torch.empty_like(z_ref) kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0]) # compare - triton.testing.assert_allclose(z_ref, z_tri) + triton.testing.assert_almost_equal(z_ref, z_tri) # --------------- @@ -262,7 +262,7 @@ def test_tuples(): for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) def test_atomic_rmw(op, dtype_x, mode, device='cuda'): dtype_x = cvt[dtype_x] - n_programs = 37 + n_programs = 5 # triton kernel @triton.jit @@ -300,7 +300,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): if exact: assert z_ref.item() == z_tri.item() else: - triton.testing.assert_allclose(z_ref, z_tri) + triton.testing.assert_almost_equal(z_ref, z_tri) # --------------- diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index b5f60cf4a..e5c1961ef 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -79,11 +79,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, K = BLOCK_K * SPLIT_K if K is None else K # allocate/transpose inputs DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] - a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) - b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) + a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) + b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) a = a.t() if AT else a b = b.t() if BT else b # run test th_c = torch.matmul(a, b) tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest) - assert triton.testing.allclose(th_c, tt_c) + triton.testing.assert_almost_equal(th_c, tt_c) diff --git a/python/triton/testing.py b/python/triton/testing.py index f604ecd89..5e8236a9e 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -55,6 +55,14 @@ def mask_tensor(x, mask, block, value=0): ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value return ret +def assert_almost_equal(x, y, decimal=2, err_msg=''): + import numpy.testing as npt + if isinstance(x, torch.Tensor): + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + y = y.cpu().detach().numpy() + npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) + def allclose(x, y, tol=1e-2): if x.dtype != y.dtype: @@ -86,7 +94,7 @@ def random(shape, dtype, device): if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: return torch.randint(1, 32, shape, dtype=dtype, device=device) if dtype in [torch.float16, torch.float32, torch.float64]: - return torch.normal(0, 10, shape, dtype=dtype, device=device) + return torch.normal(0, 1, shape, dtype=dtype, device=device) raise RuntimeError(f'Unknown dtype {dtype}')