[TUTORIALS] Attention tutorial fixup

Author: Phil Tillet
Date:   2022-09-30 19:30:46 -07:00
Parent: 7b61303ea1
Commit: b244db06da


@@ -249,7 +249,8 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
-        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
+        # NOTE: kernel currently buggy for other values of `num_warps`
+        num_warps = 8
         _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
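
For context, `num_warps` is a launch-time meta-parameter in Triton rather than a kernel argument: it controls how many warps execute each program instance, which is why it can be pinned here without touching the kernel body. A minimal sketch of the launch pattern, using a trivial stand-in kernel (`_copy_kernel` below is hypothetical, not the tutorial's `_bwd_kernel`):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _copy_kernel(x_ptr, y_ptr, N, BLOCK: tl.constexpr):
        # Each program instance copies one BLOCK-sized tile.
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    # `num_warps` rides along with the grid launch, mirroring how the
    # diff's `num_warps = 8` is presumably forwarded to `_bwd_kernel`.
    _copy_kernel[grid](x, y, x.numel(), BLOCK=1024, num_warps=8)

Pinning `num_warps = 8` trades tuning flexibility for correctness: per the NOTE in the diff, the backward kernel is currently buggy for other warp counts.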