.
This commit is contained in:
@@ -222,8 +222,8 @@ class _attention(torch.autograd.Function):
|
|||||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||||
q.shape[0], q.shape[1], q.shape[2],
|
q.shape[0], q.shape[1], q.shape[2],
|
||||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
BLOCK_DMODEL=Lk, num_warps=4,
|
||||||
num_stages=1,
|
num_stages=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.save_for_backward(q, k, v, o, L, m)
|
ctx.save_for_backward(q, k, v, o, L, m)
|
||||||
|
Reference in New Issue
Block a user