This commit is contained in:
Philippe Tillet
2022-12-21 14:02:10 -08:00
parent 88e572e54d
commit 033e82060d

View File

@@ -222,8 +222,8 @@ class _attention(torch.autograd.Function):
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=1,
BLOCK_DMODEL=Lk, num_warps=4,
num_stages=2,
)
ctx.save_for_backward(q, k, v, o, L, m)