[CODEGEN] Major performance improvements on A100 (#70)
Improved handling of asynchronous copies, scheduling, and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks.
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
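The performance claim above refers to the dense float16 matmul path exercised by test_op in the diff below. As a rough illustration of how one might check correctness and time a large square fp16 matmul through the same public API the tests use (triton.ops.matmul, triton.testing.allclose), here is a minimal sketch; the matrix size, warm-up/iteration counts, and CUDA-event timing are assumptions for illustration and are not part of this commit:

import torch
import triton

# Illustrative problem size only; "large square" per the commit message.
M = N = K = 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)

# Correctness against cuBLAS (torch.matmul), mirroring test_op below.
tt_c = triton.ops.matmul(a, b)
th_c = torch.matmul(a, b)
assert triton.testing.allclose(th_c, tt_c)

# Rough throughput measurement with CUDA events (assumption: this simple
# timing loop is illustrative and not part of the test suite or this commit).
for _ in range(10):  # warm-up (also triggers compilation/autotuning)
    triton.ops.matmul(a, b)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 100
start.record()
for _ in range(iters):
    triton.ops.matmul(a, b)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / iters  # milliseconds per matmul
print(f"{ms:.3f} ms/iter, {2 * M * N * K / (ms * 1e-3) / 1e12:.1f} TFLOP/s")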
@@ -2,29 +2,17 @@ import torch
 import triton
 import pytest
 
-
 @pytest.mark.parametrize(
     "MODE, TRANS_A, TRANS_B, BLOCK",
-    [
-        (mode, at, bt, block)
-        for mode in ["sdd", "dsd", "dds"]
-        for at in [False, True]
-        for bt in [False, True]
-        for block in [16, 32, 64]
-    ],
+    [(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
+     for block in [16, 32, 64]],
 )
-def test_matmul(
-    MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384
-):
+def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
     # set seed
     torch.random.manual_seed(0)
     # create inputs
-    a = torch.randn(
-        (Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
-    )
-    b = torch.randn(
-        (Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
-    )
+    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
+    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
     shape = {
         "sdd": (M, N),
         "dsd": (a.shape[2], a.shape[3]),
@@ -32,9 +20,7 @@ def test_matmul(
     }[MODE]
     layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
     # triton result
-    op = triton.ops.blocksparse.matmul(
-        layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B
-    )
+    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
     rc = op(ra, rb)
@@ -49,7 +35,6 @@ def test_matmul(
     # compare
     assert triton.testing.allclose(rc, tc)
 
-
 @pytest.mark.parametrize(
     "BLOCK, WIDTH",
     [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
     # create inputs
     layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
     x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
-    at_mask = torch.randint(
-        low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda"
-    )
-    kp_mask = torch.randint(
-        low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda"
-    )
+    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
+    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
     kp_mask[kp_mask == 1.0] = float("-inf")
     # triton result
     op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
     # compare
     assert triton.testing.allclose(ry, ty)
 
-
 def test_attention_fwd_bwd(
     input_scale=1.0,
     tol=2e-2,
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
     qkvs = [
-        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
-        .to(dtype)
-        .cuda()
-        for _ in range(3)
+        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
     ]
     attn_mask = torch.tril(
         torch.ones(
@@ -129,11 +106,9 @@ def test_attention_fwd_bwd(
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-    attn_out = triton_attention(
-        layout, block, attn_mask, query=query, key=key, value=value, scale=scale
-    )
+    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
     # ad hoc loss
-    loss = (attn_out ** 2).mean()
+    loss = (attn_out**2).mean()
     loss.backward()
     grads = [query.grad, key.grad, value.grad]
 
@@ -148,17 +123,16 @@ def test_attention_fwd_bwd(
     probs = torch.softmax(scores, dim=-1)
     torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
     # ad hoc loss
-    torch_loss = (torch_attn_out ** 2).mean()
+    torch_loss = (torch_attn_out**2).mean()
     torch_loss.backward()
     torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
 
     # comparison
-    print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
+    # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
     torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
     for g1, g2 in zip(grads, torch_grads):
         torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)
 
-
 def triton_attention(
     layout,
     block: int,
@@ -168,12 +142,8 @@ def triton_attention(
     value: torch.Tensor,
     scale: float,
 ):
-    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
-        layout, block, "sdd", trans_a=False, trans_b=True
-    )
-    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
-        layout, block, "dsd", trans_a=False, trans_b=False
-    )
+    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
+    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
     sparse_softmax = triton.ops.blocksparse.softmax(
         layout,
         block,
@@ -4,7 +4,7 @@ import triton
 import torch
 
 @pytest.mark.parametrize(
-    "TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
+    "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
     itertools.chain(*[
         [
             # 1 warp
@@ -17,14 +17,14 @@ import torch
             (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
             (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
             (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
-            # 2 warp
+            # # 2 warp
            (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
-            # 4 warp
+            # # 4 warp
            (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -40,24 +40,28 @@ import torch
             (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
             (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
             # variable input
             (128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
             (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
             (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
             (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
-            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
-        ] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
-    ]))
-def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
-    DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
+            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
+        ] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
+    ]),
+)
+def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
+    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     torch.manual_seed(0)
     triton.ops._matmul._kernels = dict()
-    triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
-    if M is None: M = TM
-    if N is None: N = TN
-    if K is None: K = TK * TZ
-    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
-    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
+    triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
+    if M is None:
+        M = TM
+    if N is None:
+        N = TN
+    if K is None:
+        K = TK * SPLITK
+    a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
+    b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
     a = a.t() if AT else a
     b = b.t() if BT else b
     th_c = torch.matmul(a, b)
     tt_c = triton.ops.matmul(a, b)
     assert triton.testing.allclose(th_c, tt_c)
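For readers unfamiliar with the block-sparse API exercised by the blocksparse tests in the diff above, a minimal usage sketch follows. The shapes, BLOCK size, and sparsity density are illustrative assumptions; the constructor and call pattern (triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=..., trans_b=...) followed by op(a, b)) mirror test_matmul above:

import torch
import triton

# Illustrative sizes only (assumptions, not taken from the tests).
Z, H, M, N, K, BLOCK = 2, 2, 256, 256, 384, 32

# Block-level sparsity pattern for the (M, N) output: "sdd" means sparse
# output with dense inputs, as in the MODE == "sdd" case of test_matmul.
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))

a = torch.randn((Z, H, M, K), dtype=torch.float16, device="cuda")
b = torch.randn((Z, H, K, N), dtype=torch.float16, device="cuda")

# Same constructor as in the tests; returns a callable kernel object.
op = triton.ops.blocksparse.matmul(layout, BLOCK, "sdd", trans_a=False, trans_b=False)
c = op(a, b)  # output holds only the blocks selected by `layout`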