reduced some spilling

Phil Tillet
2023-01-02 19:28:54 -08:00
parent c11fe351e1
commit 05920e0b8b
3 changed files with 152 additions and 148 deletions

@@ -164,16 +164,14 @@ def _bwd_kernel(
             # NOTE: `do` is pre-divided by `l`; no normalization here
             qk = tl.dot(q, tl.trans(k))
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
-            # m = tl.load(m_ptrs + offs_m_curr)
-            # p = tl.exp(qk * sm_scale - m[:, None])
-            p = tl.exp(qk * sm_scale)
+            m = tl.load(m_ptrs + offs_m_curr)
+            p = tl.exp(qk * sm_scale - m[:, None])
             # compute dv
             do = tl.load(do_ptrs)
             dv += tl.dot(tl.trans(p.to(tl.float16)), do)
             # compute dp = dot(v, do)
-            # Di = tl.load(D_ptrs + offs_m_curr)
-            # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
-            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            Di = tl.load(D_ptrs + offs_m_curr)
+            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
             dp += tl.dot(do, tl.trans(v))
             # compute ds = p * (dp - delta[:, None])
             ds = p * dp * sm_scale
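Aside: the restored `m` is the per-row running maximum that the forward pass saves; subtracting it inside `tl.exp` is the standard log-sum-exp stabilization, keeping every exponential in (0, 1] instead of overflowing. A toy NumPy sketch of the same trick (shapes, seed, and magnitudes are made up to force the naive form to overflow; this is an illustration, not the kernel):

```python
import numpy as np

# Illustrative stand-in for one (BLOCK_M, BLOCK_N) tile of attention logits;
# magnitudes are exaggerated so the unstabilized exp overflows float32.
rng = np.random.default_rng(0)
qk = (rng.standard_normal((4, 4)) * 5000.0).astype(np.float32)
sm_scale = 0.5

with np.errstate(over="ignore"):
    p_naive = np.exp(qk * sm_scale)             # overflows to inf

m = (qk * sm_scale).max(axis=1, keepdims=True)  # the forward pass saves this as `m`
p_stable = np.exp(qk * sm_scale - m)            # each row's max maps to exp(0) = 1

print(np.isinf(p_naive).any())   # True: the naive form blew up
print(float(p_stable.max()))     # 1.0: the stabilized form stays bounded
```

After normalization both forms define the same softmax; only the stabilized one survives in float32.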
@@ -287,7 +285,7 @@ class _attention(torch.autograd.Function):
         # num_stages=1,
         # )
         # print(pgm.asm["ttgir"])
-        # exit(1)
+        # # exit(1)
         return dq, dk, dv, None
@@ -326,8 +324,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     # compare
     triton.testing.assert_almost_equal(ref_out, tri_out)
     triton.testing.assert_almost_equal(ref_dv, tri_dv)
-    # triton.testing.assert_almost_equal(ref_dk, tri_dk)
-    # triton.testing.assert_almost_equal(ref_dq, tri_dq)
+    triton.testing.assert_almost_equal(ref_dk, tri_dk)
+    triton.testing.assert_almost_equal(ref_dq, tri_dq)

 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
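With the dK/dQ checks re-enabled, the restored `Di` path in the kernel hunk above is exercised end to end. `Di`, loaded from `D_ptrs`, plays the role of the delta term in softmax backward: assuming (as in the tutorial) that it is precomputed as the row-wise sum of `do * o`, it substitutes for the usual rowsum(dp * p). A self-contained sketch checking that identity against autograd (plain PyTorch, no Triton; causal masking and blocking omitted, all names illustrative):

```python
import torch

# Verify the delta trick behind the kernel's `Di` term, assuming
# Di_i = sum_j(do_ij * o_ij).
torch.manual_seed(0)
N, d = 8, 4
scale = d ** -0.5

q = torch.randn(N, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(N, d, dtype=torch.float64)
v = torch.randn(N, d, dtype=torch.float64)
do = torch.randn(N, d, dtype=torch.float64)  # upstream gradient w.r.t. o

# Reference dq via autograd through plain attention.
p = torch.softmax((q @ k.T) * scale, dim=-1)
o = p @ v
o.backward(do)

# Manual backward using delta = rowsum(do * o) in place of rowsum(dp * p).
with torch.no_grad():
    dp = do @ v.T
    delta = (do * o).sum(dim=-1, keepdim=True)  # the role of `Di`
    ds = p * (dp - delta)                       # matches the kernel's `ds` comment
    dq_manual = (ds @ k) * scale

print(torch.allclose(dq_manual, q.grad))  # True
```

Folding `-Di[:, None]` into the `tl.zeros` initializer, as the hunk does, computes the same `ds` while keeping one fewer intermediate tile live, which is consistent with the commit message about reduced spilling.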