From b244db06da24a87453a40ad35b085ee37dac3705 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 30 Sep 2022 19:30:46 -0700 Subject: [PATCH] [TUTORIALS] Attention tutorial fixup --- python/tutorials/06-fused-attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 035514746..996d9df40 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -249,7 +249,8 @@ class _attention(torch.autograd.Function): BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8 + # NOTE: kernel currently buggy for other values of `num_warps` + num_warps = 8 _bwd_kernel[(ctx.grid[1],)]( q, k, v, ctx.sm_scale, o, do_scaled,