[TUTORIALS] Attention tutorial fixup
@@ -249,7 +249,8 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
 
-        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
+        # NOTE: kernel currently buggy for other values of `num_warps`
+        num_warps = 8
         _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
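Before this patch, `num_warps` was derived from the head dimension (4 warps when `BLOCK_DMODEL <= 64`, 8 otherwise); since the backward kernel is noted as buggy for values other than 8, the patch pins it. Below is a minimal sketch of how `num_warps` is supplied at a Triton kernel launch; `_copy_kernel` and its arguments are hypothetical and only illustrate the launch syntax, not the tutorial's attention kernel.

    import torch
    import triton
    import triton.language as tl

    # Hypothetical kernel for illustration; not part of the tutorial.
    @triton.jit
    def _copy_kernel(src_ptr, dst_ptr, N, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(src_ptr + offs, mask=mask)
        tl.store(dst_ptr + offs, x, mask=mask)

    src = torch.randn(4096, device='cuda')
    dst = torch.empty_like(src)
    grid = (triton.cdiv(src.numel(), 1024),)
    # `num_warps` is a launch-time meta-parameter passed alongside the
    # constexpr arguments; as in the patched backward pass, it is pinned
    # to 8 here rather than derived from a block dimension.
    _copy_kernel[grid](src, dst, src.numel(), BLOCK=1024, num_warps=8)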