reduced some spilling
This commit is contained in:
@@ -164,16 +164,14 @@ def _bwd_kernel(
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
# m = tl.load(m_ptrs + offs_m_curr)
|
||||
# p = tl.exp(qk * sm_scale - m[:, None])
|
||||
p = tl.exp(qk * sm_scale)
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
|
||||
# compute dp = dot(v, do)
|
||||
# Di = tl.load(D_ptrs + offs_m_curr)
|
||||
# dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, tl.trans(v))
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
@@ -287,7 +285,7 @@ class _attention(torch.autograd.Function):
|
||||
# num_stages=1,
|
||||
# )
|
||||
# print(pgm.asm["ttgir"])
|
||||
# exit(1)
|
||||
# # exit(1)
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
@@ -326,8 +324,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
# triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
# triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
|
Reference in New Issue
Block a user