From 045ab5d62a64c02becf3cd3336b0a5891e489265 Mon Sep 17 00:00:00 2001 From: Jared Kaplan Date: Fri, 19 Feb 2021 17:46:05 -0500 Subject: [PATCH] [PYTHON] Add Blocksparse Attention Fwd/Bwd Test (#69) Also includes small bugfix for block-sparse softmax --- python/test/test_blocksparse.py | 185 ++++++++++++++++++----- python/triton/ops/blocksparse/matmul.py | 2 +- python/triton/ops/blocksparse/softmax.py | 4 +- 3 files changed, 149 insertions(+), 42 deletions(-) diff --git a/python/test/test_blocksparse.py b/python/test/test_blocksparse.py index 0dd793c3b..6b12371eb 100644 --- a/python/test/test_blocksparse.py +++ b/python/test/test_blocksparse.py @@ -3,44 +3,57 @@ import triton import pytest -@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK", +@pytest.mark.parametrize( + "MODE, TRANS_A, TRANS_B, BLOCK", [ - (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\ - for at in [False, True]\ - for bt in [False, True]\ - for block in [16, 32, 64] - ] - ) -def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384): + (mode, at, bt, block) + for mode in ["sdd", "dsd", "dds"] + for at in [False, True] + for bt in [False, True] + for block in [16, 32, 64] + ], +) +def test_matmul( + MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384 +): # set seed torch.random.manual_seed(0) # create inputs - a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda') - b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda') - shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE] + a = torch.randn( + (Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda" + ) + b = torch.randn( + (Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda" + ) + shape = { + "sdd": (M, N), + "dsd": (a.shape[2], a.shape[3]), + "dds": (b.shape[2], b.shape[3]), + }[MODE] layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) # triton result - op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) - ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a - rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b + op = triton.ops.blocksparse.matmul( + layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B + ) + ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a + rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b rc = op(ra, rb) # torch result - ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a - tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b + ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a + tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b ta = ta.transpose(2, 3) if TRANS_A else ta tb = tb.transpose(2, 3) if TRANS_B else tb tc = torch.matmul(ta, tb) - tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc - tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc + tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc + tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc # compare assert triton.testing.allclose(rc, tc) -@pytest.mark.parametrize("BLOCK, WIDTH", - [ - (block, width) for block in [32]\ - for width in [256, 576, 1024, 1792] - ] - ) + +@pytest.mark.parametrize( + "BLOCK, WIDTH", + [(block, 
width) for block in [32] for width in [256, 576, 1024, 1792]], +) def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16): # set seed torch.random.manual_seed(0) @@ -48,31 +61,125 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16): scale = 0.4 # create inputs layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) - x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda') - at_mask = torch.randint(low=0, high=2, size=(N, N), \ - dtype=torch.bool, requires_grad=False, device='cuda') - kp_mask = torch.randint(low=0, high=2, size=(Z, N), \ - dtype=DTYPE, requires_grad=False, device='cuda') - kp_mask[kp_mask == 1.] = float('-inf') + x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda") + at_mask = torch.randint( + low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda" + ) + kp_mask = torch.randint( + low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda" + ) + kp_mask[kp_mask == 1.0] = float("-inf") # triton result op = triton.ops.blocksparse.softmax(layout, BLOCK) tx = triton.testing.sparsify_tensor(x, layout, BLOCK) - ty = op(tx, - scale=scale, - key_padding_mask=kp_mask, - key_padding_mask_mode='add', - attn_mask=at_mask.to(DTYPE), - attn_mask_mode='mul') + ty = op( + tx, + scale=scale, + key_padding_mask=kp_mask, + key_padding_mask_mode="add", + attn_mask=at_mask.to(DTYPE), + attn_mask_mode="mul", + ) # torch result - rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float('-inf')) + rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) if at_mask is not None: # broadcast at_mask to the same shape as rx M = at_mask[None, None, :, :] + torch.zeros_like(rx) - rx[M == 0] = float('-inf') + rx[M == 0] = float("-inf") if kp_mask is not None: rx += kp_mask[:, None, None, :] ry = torch.softmax(rx * scale, -1) ry = torch.softmax(rx * scale, -1) ry = triton.testing.sparsify_tensor(ry, layout, BLOCK) # compare - assert triton.testing.allclose(ry, ty) \ No newline at end of file + assert triton.testing.allclose(ry, ty) + + +def test_attention_fwd_bwd( + input_scale=1.0, + tol=2e-2, + scale=1 / 8.0, + n_ctx=256, + dtype=torch.float16, + batch_size=2, + n_heads=2, + block=64, +): + # inputs + qkv_shape = (batch_size, n_heads, n_ctx, 64) + qkvs = [ + torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True) + .to(dtype) + .cuda() + for _ in range(3) + ] + attn_mask = torch.tril( + torch.ones( + [n_ctx, n_ctx], + device="cuda", + dtype=dtype, + ), + diagonal=0, + ) + + # Triton: + n_blocks = n_ctx // block + layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) + query, key, value = [x.clone() for x in qkvs] + query.retain_grad() + key.retain_grad() + value.retain_grad() + attn_out = triton_attention( + layout, block, attn_mask, query=query, key=key, value=value, scale=scale + ) + # ad hoc loss + loss = (attn_out ** 2).mean() + loss.backward() + grads = [query.grad, key.grad, value.grad] + + # Torch version: + torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) + torch_q.retain_grad() + torch_k.retain_grad() + torch_v.retain_grad() + scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) + # ad hoc loss + torch_loss = (torch_attn_out ** 2).mean() + torch_loss.backward() + torch_grads = [torch_q.grad, torch_k.grad, 
torch_v.grad] + + # comparison + print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") + torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol) + for g1, g2 in zip(grads, torch_grads): + torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol) + + +def triton_attention( + layout, + block: int, + attn_mask: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +): + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul( + layout, block, "sdd", trans_a=False, trans_b=True + ) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul( + layout, block, "dsd", trans_a=False, trans_b=False + ) + sparse_softmax = triton.ops.blocksparse.softmax( + layout, + block, + ) + + w = sparse_dot_sdd_nt(query, key) + w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul") + a = sparse_dot_dsd_nn(w, value) + return a diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 7d5abe948..03dc32f21 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -126,7 +126,7 @@ class _matmul(torch.autograd.Function): num_lock = 1 key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple) if key not in _matmul.sdd_cache: - defines = {'TM': block*pack, 'TN': block*pack, + defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block, 'TK': 32, diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index fd3faeade..2b0d904fa 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -110,7 +110,7 @@ class _softmax(torch.autograd.Function): kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode) M = x.shape[0] - grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M] + grid = lambda opt: [spdims[0] * spdims[1] * block, M] # run kernel kernel(x.data_ptr(), @@ -151,7 +151,7 @@ class _softmax(torch.autograd.Function): ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask, ctx.kp_mask_mode, ctx.attn_mask_mode) M = x.shape[0] - grid = lambda opt: [triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block, opt.TM), M] + grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M] kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid) return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
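
Note for reviewers: below is a minimal standalone sketch of the block-sparse attention composition that test_attention_fwd_bwd exercises, using the same triton.ops.blocksparse calls and default sizes as the triton_attention helper added above (block=64, n_heads=2, n_ctx=256, head_dim=64, scale=1/8, float16). It is illustrative only and not part of the patch; any divergence from the real API should defer to the test code above.

    import torch
    import triton

    # Causal block-level layout: one entry per (head, row-block, col-block).
    block, n_heads, n_ctx, head_dim = 64, 2, 256, 64
    n_blocks = n_ctx // block
    layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))

    # Same three ops that triton_attention wires together:
    # sdd: dense x dense -> sparse scores, dsd: sparse x dense -> dense output.
    dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
    dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
    softmax = triton.ops.blocksparse.softmax(layout, block)

    q, k, v = [
        torch.randn(2, n_heads, n_ctx, head_dim, dtype=torch.float16, device="cuda")
        for _ in range(3)
    ]
    # Element-level causal mask, multiplied into the sparse scores (attn_mask_mode="mul").
    attn_mask = torch.tril(torch.ones([n_ctx, n_ctx], dtype=torch.float16, device="cuda"))

    w = dot_sdd_nt(q, k)                # block-sparse scores; only blocks allowed by `layout`
    w = softmax(w, scale=1 / 8.0, attn_mask=attn_mask, attn_mask_mode="mul")
    out = dot_dsd_nn(w, v)              # dense output of shape (2, n_heads, n_ctx, head_dim)

The scores stay in block-sparse form between the sdd matmul and the dsd matmul, so this is also the code path covered by the block-sparse softmax grid fix in softmax.py above.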