diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index e4bc9cb82..a185724f7 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -222,8 +222,8 @@ class _attention(torch.autograd.Function):
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=num_warps,
-            num_stages=1,
+            BLOCK_DMODEL=Lk, num_warps=4,
+            num_stages=2,
         )
 
         ctx.save_for_backward(q, k, v, o, L, m)
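
For context, `num_warps` and `num_stages` are Triton launch meta-parameters (warps per program instance and software-pipelining depth), which this patch pins to 4 and 2 at the call site. An alternative to hardcoding them is Triton's `@triton.autotune` decorator, which benchmarks candidate configs and caches the best one per key. Below is a minimal sketch on a hypothetical toy kernel (`_copy_kernel` is illustrative only and not part of this file):

import torch
import triton
import triton.language as tl


# Toy element-wise copy kernel, used only to show where num_warps and
# num_stages live: the autotuner picks them from the candidate configs,
# so the call site no longer passes them explicitly.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 1024}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK': 1024}, num_warps=8, num_stages=1),
    ],
    key=['n_elements'],  # re-run tuning when the problem size changes
)
@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
_copy_kernel[grid](x, y, x.numel())  # no num_warps/num_stages at the call site

The trade-off: fixed values like num_warps=4, num_stages=2 avoid autotuning's first-call benchmarking cost but may not be optimal across GPUs or input shapes, whereas autotuning pays that cost once per key and adapts automatically.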