[DOCS] support for variable head dimensions in flash attention triton tutorial (#623)

Author: Phil Wang
Date:   2022-08-15 19:16:49 -07:00
Committed by: GitHub
Parent: 3e2953f357
Commit: 7394d732ad


@@ -204,13 +204,16 @@ class _attention(torch.autograd.Function):
     def forward(ctx, q, k, v, sm_scale):
         BLOCK = 128
         # shape constraints
-        Lq, Lk = q.shape[-1], k.shape[-1]
-        assert Lq == Lk
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Lq == Lk and Lk == Lv
+        assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
         tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        num_warps = 4 if Lk <= 64 else 8
         _fwd_kernel[grid](
             q, k, v, sm_scale,
             tmp, L, m,
@@ -221,14 +224,14 @@ class _attention(torch.autograd.Function):
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=64, num_warps=4,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
             num_stages=1,
         )
         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.BLOCK = BLOCK
         ctx.grid = grid
         ctx.sm_scale = sm_scale
-        ctx.BLOCK_DMODEL = 64
+        ctx.BLOCK_DMODEL = Lk
         return o

     @staticmethod
@@ -245,6 +248,8 @@ class _attention(torch.autograd.Function):
             do_scaled, delta,
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
+        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
         _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
@@ -257,7 +262,7 @@ class _attention(torch.autograd.Function):
             q.shape[0], q.shape[1], q.shape[2],
             ctx.grid[0],
             BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
             num_stages=1,
         )
         return dq, dk, dv, None
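
For reference, a minimal sketch of how the widened head-dimension support might be exercised, assuming the tutorial's usual entry point attention = _attention.apply, fp16 CUDA tensors of shape (batch, heads, seq_len, d_head), and a seq_len that is a multiple of BLOCK = 128. The shapes, scale, and loop below are illustrative only and are not part of this commit:

    import torch

    # assumption: the op is exposed as in the tutorial, via the autograd Function above
    attention = _attention.apply

    for d_head in (16, 32, 64, 128):  # head dims now accepted by the new assert
        q = torch.randn(2, 4, 1024, d_head, device="cuda", dtype=torch.float16, requires_grad=True)
        k = torch.randn_like(q, requires_grad=True)
        v = torch.randn_like(q, requires_grad=True)
        sm_scale = d_head ** -0.5              # illustrative softmax scale
        out = attention(q, k, v, sm_scale)     # forward now sets BLOCK_DMODEL from d_head
        out.sum().backward()                   # backward picks num_warps = 4 or 8 from d_head

The num_warps heuristic (4 for head dims up to 64, 8 above) mirrors the forward/backward changes in the diff: larger BLOCK_DMODEL tiles get more warps rather than a hard-coded count.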