[DOCS] support for variable head dimensions in flash attention triton tutorial (#623)
@@ -204,13 +204,16 @@ class _attention(torch.autograd.Function):
     def forward(ctx, q, k, v, sm_scale):
         BLOCK = 128
         # shape constraints
-        Lq, Lk = q.shape[-1], k.shape[-1]
-        assert Lq == Lk
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Lq == Lk and Lk == Lv
+        assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
         tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        num_warps = 4 if Lk <= 64 else 8
+
         _fwd_kernel[grid](
             q, k, v, sm_scale,
             tmp, L, m,
@@ -221,14 +224,14 @@ class _attention(torch.autograd.Function):
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=64, num_warps=4,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
             num_stages=1,
         )
         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.BLOCK = BLOCK
         ctx.grid = grid
         ctx.sm_scale = sm_scale
-        ctx.BLOCK_DMODEL = 64
+        ctx.BLOCK_DMODEL = Lk
         return o

     @staticmethod
@@ -245,6 +248,8 @@ class _attention(torch.autograd.Function):
             do_scaled, delta,
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
+
+        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
         _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
@@ -257,7 +262,7 @@ class _attention(torch.autograd.Function):
             q.shape[0], q.shape[1], q.shape[2],
             ctx.grid[0],
             BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
             num_stages=1,
         )
         return dq, dk, dv, None
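
With these changes the tutorial's fused attention accepts head dimensions of 16, 32, 64, or 128 instead of a hard-coded 64, and derives num_warps from the head size in both passes. Below is a minimal smoke-test sketch of how the updated path could be exercised, assuming the tutorial's usual attention = _attention.apply wrapper and a CUDA device; the shapes and names are illustrative and not part of this diff.

import torch

Z, H, N_CTX = 2, 4, 1024          # hypothetical batch size, head count, sequence length
sm_scale = 0.3
for d_head in (16, 32, 64, 128):  # any other value now fails the new assert Lk in {16, 32, 64, 128}
    q = torch.randn((Z, H, N_CTX, d_head), dtype=torch.float16, device="cuda", requires_grad=True)
    k = torch.randn((Z, H, N_CTX, d_head), dtype=torch.float16, device="cuda", requires_grad=True)
    v = torch.randn((Z, H, N_CTX, d_head), dtype=torch.float16, device="cuda", requires_grad=True)
    out = attention(q, k, v, sm_scale)  # forward launches with BLOCK_DMODEL=d_head and 4 warps (d_head <= 64) or 8 warps
    out.sum().backward()                # backward reuses ctx.BLOCK_DMODEL and applies the same warp heuristic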