[OPS][BLOCKSPARSE] safeguarding a couple more configurations (#292)
commit b53f5f3803
parent a12827848d
committed by GitHub
@@ -3,15 +3,12 @@ import triton
 import pytest


-@pytest.mark.parametrize(
-    "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
-    [
-        (mode, at, bt, block, dtype) for dtype in ["float16"] for mode in ["sdd", "dsd", "dds"]
-        for at in [False, True] for bt in [False, True] for block in [16, 32, 64]
-    ],
-)
+@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
+@pytest.mark.parametrize("TRANS_A", [False, True])
+@pytest.mark.parametrize("TRANS_B", [False, True])
+@pytest.mark.parametrize("BLOCK", [16, 32, 64])
+@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
 def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
-    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     # set seed
     torch.random.manual_seed(0)
     # create inputs
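A note on the pattern introduced above: stacking several @pytest.mark.parametrize decorators makes pytest collect the full Cartesian product of the argument lists, so the five stacked decorators cover the same grid as the old comprehension-based decorator while also adding float32 and giving each axis its own test ID. A minimal standalone sketch (the names x, y, test_stacked, and test_single are illustrative, not from the diff):

import pytest

# Stacked decorators: pytest runs test_stacked once per (x, y) pair,
# i.e. the Cartesian product of [1, 2] and ["a", "b"].
@pytest.mark.parametrize("x", [1, 2])
@pytest.mark.parametrize("y", ["a", "b"])
def test_stacked(x, y):
    assert (x, y) in [(1, "a"), (1, "b"), (2, "a"), (2, "b")]

# The equivalent single decorator built from a comprehension, the style
# the commit moves away from.
@pytest.mark.parametrize("x, y", [(x, y) for x in [1, 2] for y in ["a", "b"]])
def test_single(x, y):
    assert (x, y) in [(1, "a"), (1, "b"), (2, "a"), (2, "b")]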
@@ -27,7 +24,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
-    rc = triton.testing.catch_oor(lambda : op(ra, rb), pytest)
+    rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
     # torch result
     ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
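The catch_oor wrapper above is the safeguard this commit leans on: with larger BLOCK sizes and float32 now in the grid, some configurations can exceed device resources, and those runs should skip rather than fail. A self-contained sketch of that pattern, with all names below assumed for illustration rather than taken from Triton's API:

import pytest

class OutOfResources(Exception):
    """Stand-in for a kernel launcher's out-of-resources error."""

def catch_oor(thunk, pytest_module):
    # Run the kernel thunk; convert resource exhaustion into a skip so
    # oversized parameter combinations do not fail the whole suite.
    try:
        return thunk()
    except OutOfResources as exc:
        pytest_module.skip(f"out of resources: {exc}")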
@@ -40,11 +37,10 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     triton.testing.assert_almost_equal(rc, tc)


-@pytest.mark.parametrize(
-    "BLOCK, WIDTH",
-    [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
-)
-def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
+@pytest.mark.parametrize("BLOCK", [16, 32, 64])
+@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
+@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
+def test_softmax(BLOCK, WIDTH, DTYPE):
     # set seed
     torch.random.manual_seed(0)
     Z, H, M, N = 2, 4, WIDTH, WIDTH
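Both the matmul and softmax tests drive dense reference tensors through a block-sparse layout; in the matmul hunk above, mask_tensor zeroes the tiles the layout drops and sparsify_tensor compacts the surviving tiles. The masking step can be pictured with this plain-PyTorch stand-in (assumed semantics, not Triton's actual helper):

import torch

def mask_tensor_sketch(x, layout, block):
    # Expand each 0/1 layout entry into a (block, block) tile, then zero
    # the dropped tiles; leading batch dims of x broadcast over the mask.
    tile_mask = torch.kron(layout.to(x.dtype), torch.ones(block, block, dtype=x.dtype))
    return x * tile_mask

# Example: (Z, H, M, N) = (2, 3, 8, 8) data with a (H, M//block, N//block)
# layout of 4x4 tiles.
x = torch.randn(2, 3, 8, 8)
layout = torch.randint(0, 2, (3, 2, 2))
masked = mask_tensor_sketch(x, layout, 4)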
@@ -81,19 +77,22 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
     triton.testing.assert_almost_equal(ry, ty)


+@pytest.mark.parametrize("block", [16, 32, 64])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_attention_fwd_bwd(
+    block,
+    dtype,
     input_scale=1.0,
     tol=2e-2,
     scale=1 / 8.0,
     n_ctx=256,
-    dtype=torch.float16,
     batch_size=2,
     n_heads=2,
-    block=64,
 ):
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
-    qkvs = [torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)]
+    qkvs = [
+        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
+    ]
     attn_mask = torch.tril(
         torch.ones(
             [n_ctx, n_ctx],
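One subtlety in the qkvs construction above: calling .to(dtype).cuda() on a torch.nn.Parameter returns a new non-leaf tensor, which is why the test must call retain_grad() on query, key, and value (visible in the next hunk) before gradients can be read back. A small illustration of that behavior:

import torch

p = torch.nn.Parameter(torch.randn(2, 2))
q = p.to(torch.float16)   # .to() is a differentiable op: q is non-leaf
assert p.is_leaf and not q.is_leaf
q.retain_grad()           # without this, q.grad would stay None
(q.float() ** 2).mean().backward()
assert q.grad is not None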
@@ -112,7 +111,7 @@ def test_attention_fwd_bwd(
     value.retain_grad()
     attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
     # ad hoc loss
-    loss = (attn_out**2).mean()
+    loss = (attn_out ** 2).mean()
     loss.backward()
     grads = [query.grad, key.grad, value.grad]
@@ -127,7 +126,7 @@ def test_attention_fwd_bwd(
     probs = torch.softmax(scores, dim=-1)
     torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
     # ad hoc loss
-    torch_loss = (torch_attn_out**2).mean()
+    torch_loss = (torch_attn_out ** 2).mean()
     torch_loss.backward()
     torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
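The reference path in this hunk is plain causal attention computed with PyTorch. Pieced together from the context lines (the einsum contraction is taken verbatim from the diff; the remaining names are assumptions), the dense computation looks like:

import torch

def dense_attention_reference(q, k, v, scale, attn_mask):
    # scaled dot-product scores, restricted to the lower triangle
    scores = scale * torch.einsum("bhsd,bhtd->bhst", q, k)
    scores = scores.masked_fill(attn_mask == 0, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    # same contraction as the diff's einsum("bhst,bhtd->bhsd", probs, torch_v)
    return torch.einsum("bhst,bhtd->bhsd", probs, v)

# Shapes matching the test: (batch_size, n_heads, n_ctx, head_dim)
q = k = v = torch.randn(2, 2, 256, 64)
mask = torch.tril(torch.ones(256, 256))
out = dense_attention_reference(q, k, v, 1 / 8.0, mask)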
@@ -138,6 +137,7 @@ def test_attention_fwd_bwd(
     triton.testing.assert_almost_equal(g1, g2)


+@pytest.mark.parametrize("block", [16, 32, 64])
 def triton_attention(
     layout,
     block: int,