From b53f5f3803370d80bf06a1fa22f71087d9fb70a1 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Mon, 20 Sep 2021 17:15:31 -0700
Subject: [PATCH] [OPS][BLOCKSPARSE] safeguarding a couple more configurations (#292)

---
 .../test/unit/operators/test_blocksparse.py | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py
index 0f2232a7b..86006671f 100644
--- a/python/test/unit/operators/test_blocksparse.py
+++ b/python/test/unit/operators/test_blocksparse.py
@@ -3,15 +3,12 @@ import triton
 import pytest
 
 
-@pytest.mark.parametrize(
-    "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
-    [
-        (mode, at, bt, block, dtype) for dtype in ["float16"] for mode in ["sdd", "dsd", "dds"]
-        for at in [False, True] for bt in [False, True] for block in [16, 32, 64]
-    ],
-)
+@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
+@pytest.mark.parametrize("TRANS_A", [False, True])
+@pytest.mark.parametrize("TRANS_B", [False, True])
+@pytest.mark.parametrize("BLOCK", [16, 32, 64])
+@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
 def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
-    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     # set seed
     torch.random.manual_seed(0)
     # create inputs
@@ -27,7 +24,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
-    rc = triton.testing.catch_oor(lambda : op(ra, rb), pytest)
+    rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
     # torch result
     ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
@@ -40,11 +37,10 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     triton.testing.assert_almost_equal(rc, tc)
 
 
-@pytest.mark.parametrize(
-    "BLOCK, WIDTH",
-    [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
-)
-def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
+@pytest.mark.parametrize("BLOCK", [16, 32, 64])
+@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
+@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
+def test_softmax(BLOCK, WIDTH, DTYPE):
     # set seed
     torch.random.manual_seed(0)
     Z, H, M, N = 2, 4, WIDTH, WIDTH
@@ -81,19 +77,22 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
     triton.testing.assert_almost_equal(ry, ty)
 
 
+@pytest.mark.parametrize("block", [16, 32, 64])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_attention_fwd_bwd(
+    block,
+    dtype,
     input_scale=1.0,
-    tol=2e-2,
     scale=1 / 8.0,
     n_ctx=256,
-    dtype=torch.float16,
     batch_size=2,
     n_heads=2,
-    block=64,
 ):
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
-    qkvs = [torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)]
+    qkvs = [
+        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
+    ]
     attn_mask = torch.tril(
         torch.ones(
             [n_ctx, n_ctx],
@@ -112,7 +111,7 @@ def test_attention_fwd_bwd(
     value.retain_grad()
     attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
     # ad hoc loss
-    loss = (attn_out**2).mean()
+    loss = (attn_out ** 2).mean()
     loss.backward()
 
     grads = [query.grad, key.grad, value.grad]
@@ -127,7 +126,7 @@
     probs = torch.softmax(scores, dim=-1)
     torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
     # ad hoc loss
-    torch_loss = (torch_attn_out**2).mean()
+    torch_loss = (torch_attn_out ** 2).mean()
    torch_loss.backward()
 
     torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
@@ -138,6 +137,7 @@
     triton.testing.assert_almost_equal(g1, g2)
 
 
+@pytest.mark.parametrize("block", [16, 32, 64])
 def triton_attention(
     layout,
     block: int,
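
A minimal sketch, separate from the patch itself, of the two mechanisms the change relies on: stacking @pytest.mark.parametrize decorators so pytest runs the full cross-product of configurations, and guarding the kernel launch so that configurations which exceed hardware resources are skipped rather than failing the suite. guard_oor is a hypothetical stand-in for triton.testing.catch_oor, whose exact behavior is assumed here; only pytest itself is needed to run the sketch.

import pytest


def guard_oor(fn, pytest_module):
    # Hypothetical helper standing in for triton.testing.catch_oor: call fn()
    # and skip the current test if the backend reports an out-of-resources
    # error; re-raise anything else.
    try:
        return fn()
    except RuntimeError as err:
        if "out of resource" in str(err).lower():
            pytest_module.skip("configuration exceeds hardware resources")
        raise


@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", ["float16", "float32"])
def test_cross_product(BLOCK, DTYPE):
    # Stacked parametrize decorators expand to 3 x 2 = 6 test cases,
    # one per (BLOCK, DTYPE) combination.
    result = guard_oor(lambda: (BLOCK, DTYPE), pytest)
    assert result == (BLOCK, DTYPE)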