[OPS][BLOCKSPARSE] safeguarding a couple more configurations (#292)

Author: Benjamin Lefaudeux
Date: 2021-09-20 17:15:31 -07:00
Committed by: GitHub
Parent: a12827848d
Commit: b53f5f3803

@@ -3,15 +3,12 @@ import triton
 import pytest
 
 
-@pytest.mark.parametrize(
-    "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
-    [
-        (mode, at, bt, block, dtype) for dtype in ["float16"] for mode in ["sdd", "dsd", "dds"]
-        for at in [False, True] for bt in [False, True] for block in [16, 32, 64]
-    ],
-)
+@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
+@pytest.mark.parametrize("TRANS_A", [False, True])
+@pytest.mark.parametrize("TRANS_B", [False, True])
+@pytest.mark.parametrize("BLOCK", [16, 32, 64])
+@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
 def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
-    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     # set seed
     torch.random.manual_seed(0)
     # create inputs
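Note on the pattern above: stacking several @pytest.mark.parametrize decorators makes pytest generate the Cartesian product of all listed values, so the five decorators cover every (MODE, TRANS_A, TRANS_B, BLOCK, DTYPE) combination that the old hand-rolled list comprehension produced, while also adding torch.float32. A minimal, self-contained sketch of the mechanism (hypothetical test, not part of this diff):

import pytest
import torch


@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
def test_cross_product(BLOCK, DTYPE):
    # Stacked decorators yield 3 x 2 = 6 generated test cases.
    assert BLOCK in (16, 32, 64)
    assert DTYPE in (torch.float16, torch.float32)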
@@ -40,11 +37,10 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     triton.testing.assert_almost_equal(rc, tc)
 
 
-@pytest.mark.parametrize(
-    "BLOCK, WIDTH",
-    [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
-)
-def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
+@pytest.mark.parametrize("BLOCK", [16, 32, 64])
+@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
+@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
+def test_softmax(BLOCK, WIDTH, DTYPE):
     # set seed
     torch.random.manual_seed(0)
     Z, H, M, N = 2, 4, WIDTH, WIDTH
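For context, a blocksparse softmax test of this shape needs a block layout whose grid size depends on WIDTH and BLOCK, which is why both are now swept. A hedged sketch of how such a layout is commonly built (the helper name and the random-mask choice are assumptions, not this file's code):

import torch


def make_random_layout(h: int, width: int, block: int, density: float = 0.5) -> torch.Tensor:
    # One (width // block) x (width // block) grid of 0/1 flags per head;
    # a 1 means that block participates in the sparse softmax.
    grid = width // block
    return (torch.rand(h, grid, grid) < density).long()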
@@ -81,19 +77,22 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
     triton.testing.assert_almost_equal(ry, ty)
 
 
+@pytest.mark.parametrize("block", [16, 32, 64])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_attention_fwd_bwd(
+    block,
+    dtype,
     input_scale=1.0,
+    tol=2e-2,
     scale=1 / 8.0,
     n_ctx=256,
-    dtype=torch.float16,
     batch_size=2,
     n_heads=2,
-    block=64,
 ):
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
-    qkvs = [torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)]
+    qkvs = [
+        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
+    ]
     attn_mask = torch.tril(
         torch.ones(
             [n_ctx, n_ctx],
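The signature above also gains a tol=2e-2 argument, which suggests a configurable comparison threshold now that both float16 and float32 runs are exercised. A hedged sketch of how such a tolerance is typically applied (the helper below is illustrative, not this file's actual comparison code):

import torch


def assert_within_tol(result: torch.Tensor, reference: torch.Tensor, tol: float = 2e-2) -> None:
    # Attention in float16 accumulates enough rounding error that exact
    # equality is not expected; compare with an absolute/relative tolerance.
    max_diff = (result.float() - reference.float()).abs().max().item()
    assert torch.allclose(result.float(), reference.float(), atol=tol, rtol=tol), (
        f"max abs diff {max_diff:.3e} exceeds tol={tol}"
    )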
@@ -138,6 +137,7 @@ def test_attention_fwd_bwd(
     triton.testing.assert_almost_equal(g1, g2)
 
 
+@pytest.mark.parametrize("block", [16, 32, 64])
 def triton_attention(
     layout,
     block: int,