This commit is contained in:
Phil Tillet
2022-12-21 13:54:30 -08:00
parent 20100a7254
commit 88e572e54d

View File

@@ -32,7 +32,7 @@ def _fwd_kernel(
offs_n = tl.arange(0, BLOCK_N) offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL) offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V # Initialize pointers to Q, K, V
q_ptrs = Q + off_q q_ptrs = Q + off_q
@@ -50,7 +50,7 @@ def _fwd_kernel(
# -- compute qk ---- # -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn) k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k)) qk += tl.dot(q, k)
qk *= sm_scale qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij # -- compute m_ij, p, l_ij
@@ -165,26 +165,26 @@ def _bwd_kernel(
q = tl.load(q_ptrs) q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T # recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here # NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k)) qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr) m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None]) p = tl.exp(qk * sm_scale - m[:, None])
# compute dv # compute dv
do = tl.load(do_ptrs) do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(tl.float16)), do) dv += tl.dot(p.to(tl.float16), do, trans_a=True)
# compute dp = dot(v, do) # compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr) Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v)) dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None]) # compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale ds = p * dp * sm_scale
# compute dk = dot(ds.T, q) # compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(tl.float16)), q) dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# compute dq # # compute dq
dq = tl.load(dq_ptrs) dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(tl.float16), k) dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq) tl.store(dq_ptrs, dq)
# increment pointers # # increment pointers
dq_ptrs += BLOCK_M * stride_qm dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_qm
@@ -196,8 +196,6 @@ def _bwd_kernel(
empty = torch.empty(128, device="cuda") empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
@@ -250,8 +248,7 @@ class _attention(torch.autograd.Function):
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
) )
# NOTE: kernel currently buggy for other values of `num_warps` num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
num_warps = 8
_bwd_kernel[(ctx.grid[1],)]( _bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale, q, k, v, ctx.sm_scale,
o, do_scaled, o, do_scaled,
@@ -276,7 +273,7 @@ attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)]) @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20) torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.1).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2 sm_scale = 0.2
@@ -290,30 +287,23 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
p = torch.softmax(p.float(), dim=-1).half() p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p) # p = torch.exp(p)
ref_out = torch.matmul(p, v) ref_out = torch.matmul(p, v)
ref_out.backward(dout) # ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None # ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None # ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None # ref_dq, q.grad = q.grad.clone(), None
# # triton implementation # # triton implementation
tri_out = attention(q, k, v, sm_scale) tri_out = attention(q, k, v, sm_scale)
# print(ref_out) # print(ref_out)
# print(tri_out) # print(tri_out)
tri_out.backward(dout) # tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None # tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None # tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None # tri_dq, q.grad = q.grad.clone(), None
# compare # compare
triton.testing.assert_almost_equal(ref_out, tri_out) triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv) # triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk) # triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq) # triton.testing.assert_almost_equal(ref_dq, tri_dq)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4 # vary seq length for fixed head and batch=4
@@ -321,8 +311,8 @@ configs = [triton.testing.Benchmark(
x_names=['N_CTX'], x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)], x_vals=[2**i for i in range(10, 16)],
line_arg='provider', line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []), line_vals=['triton'],
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), line_names=['Triton'],
styles=[('red', '-'), ('blue', '-')], styles=[('red', '-'), ('blue', '-')],
ylabel='ms', ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
@@ -360,4 +350,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms return ms
# bench_flash_attention.run(save_path='.', print_data=True) bench_flash_attention.run(save_path='.', print_data=True)