From 7394d732adcec5a67034926a28080af2a7853217 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 15 Aug 2022 19:16:49 -0700
Subject: [PATCH] [DOCS] support for variable head dimensions in flash
 attention triton tutorial (#623)

---
 python/tutorials/06-fused-attention.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index fb0f4f958..035514746 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -204,13 +204,16 @@ class _attention(torch.autograd.Function):
     def forward(ctx, q, k, v, sm_scale):
         BLOCK = 128
         # shape constraints
-        Lq, Lk = q.shape[-1], k.shape[-1]
-        assert Lq == Lk
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Lq == Lk and Lk == Lv
+        assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
         tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        num_warps = 4 if Lk <= 64 else 8
+
         _fwd_kernel[grid](
             q, k, v, sm_scale,
             tmp, L, m,
@@ -221,14 +224,14 @@ class _attention(torch.autograd.Function):
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=64, num_warps=4,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
             num_stages=1,
         )
         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.BLOCK = BLOCK
         ctx.grid = grid
         ctx.sm_scale = sm_scale
-        ctx.BLOCK_DMODEL = 64
+        ctx.BLOCK_DMODEL = Lk
         return o
 
     @staticmethod
@@ -245,6 +248,8 @@ class _attention(torch.autograd.Function):
             do_scaled, delta,
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
+
+        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
         _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
@@ -257,7 +262,7 @@ class _attention(torch.autograd.Function):
             q.shape[0], q.shape[1], q.shape[2],
             ctx.grid[0],
             BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
             num_stages=1,
         )
         return dq, dk, dv, None
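
Usage sketch for the patched entry point (illustration only, not part of the patch). It assumes the snippet runs where the tutorial's `_attention` class is in scope, a CUDA device with float16 support, and a sequence length divisible by BLOCK (128); the shapes below are arbitrary.

    import torch

    # The head dimension may now be any of {16, 32, 64, 128}; 32 is used here as an example.
    Z, H, N_CTX, D_HEAD = 4, 16, 1024, 32
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    sm_scale = D_HEAD ** -0.5

    # Forward passes BLOCK_DMODEL=Lk to _fwd_kernel and picks num_warps = 4 when Lk <= 64, else 8.
    out = _attention.apply(q, k, v, sm_scale)

    # Backward reuses the saved ctx.BLOCK_DMODEL and applies the same num_warps heuristic.
    out.sum().backward()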