[DOCS] support for variable head dimensions in flash attention triton tutorial (#623)
This commit is contained in:
@@ -204,13 +204,16 @@ class _attention(torch.autograd.Function):
|
||||
def forward(ctx, q, k, v, sm_scale):
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k.shape[-1]
|
||||
assert Lq == Lk
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
tmp, L, m,
|
||||
@@ -221,14 +224,14 @@ class _attention(torch.autograd.Function):
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=64, num_warps=4,
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = 64
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
@@ -245,6 +248,8 @@ class _attention(torch.autograd.Function):
|
||||
do_scaled, delta,
|
||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
|
||||
num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
@@ -257,7 +262,7 @@ class _attention(torch.autograd.Function):
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
|
Reference in New Issue
Block a user