This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to improve maintainability, stability, and, ultimately, performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
"""
|
|
Fused Attention
|
|
===============
|
|
This is a Triton implementation of the Flash Attention algorithm
|
|
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
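

# [Editor's note] Illustrative sketch only, not part of the original file: a naive
# (non-fused, O(N^2)-memory) PyTorch reference for what the fused kernels below compute,
# assuming causal masking and a scalar softmax scale. The helper name is hypothetical
# and the function is never called by the kernels or tests.
def _naive_causal_attention(q, k, v, sm_scale):
    # q, k, v: (Z, H, N_CTX, D_HEAD)
    n_ctx = q.shape[-2]
    mask = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.bool, device=q.device))
    p = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    p = p.masked_fill(~mask, float("-inf"))
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    return torch.matmul(p, v)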


@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # NOTE: Q, K and V are assumed to share the same layout, so Q strides are reused below
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointers to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk --
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        qk *= sm_scale
        # causal mask: only keys at or before the query position contribute
        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij --
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i --
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(tl.float16)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
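

# [Editor's note] Illustrative sketch, not used by the kernel above: the loop in
# `_fwd_kernel` relies on the online-softmax recurrence, where a running maximum `m_i`,
# a running normalizer `l_i`, and a partially accumulated, already-normalized output are
# rescaled as each new key/value block arrives. The hypothetical helper below replays
# that recurrence with plain torch tensors to show it reproduces a full softmax.
def _online_softmax_reference(scores, values, block=16):
    # scores: (M, N) attention logits for one query block; values: (N, D)
    m_i = torch.full((scores.shape[0],), float("-inf"), dtype=torch.float32)
    l_i = torch.zeros(scores.shape[0], dtype=torch.float32)
    acc = torch.zeros((scores.shape[0], values.shape[1]), dtype=torch.float32)
    for start in range(0, scores.shape[1], block):
        qk = scores[:, start:start + block].float()
        m_ij = qk.max(dim=1).values
        p = torch.exp(qk - m_ij[:, None])
        l_ij = p.sum(dim=1)
        m_i_new = torch.maximum(m_i, m_ij)
        alpha = torch.exp(m_i - m_i_new)
        beta = torch.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # rescale the previous accumulator, then add the new block's contribution
        acc = acc * (l_i / l_i_new * alpha)[:, None]
        acc = acc + (p * (beta / l_i_new)[:, None]) @ values[start:start + block].float()
        m_i, l_i = m_i_new, l_i_new
    # equals torch.softmax(scores, dim=-1) @ values, up to numerical error
    return acc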


@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
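

# [Editor's note] Illustrative sketch, not part of the original file: `_bwd_preprocess`
# divides `do` by the saved row-wise softmax denominator `l` and precomputes
# delta = rowsum(o * do), the term that softmax backward subtracts from each row
# (ds = p * (dp - delta)). The hypothetical helper below states the same computation
# with plain torch ops.
def _bwd_preprocess_reference(o, do, denom):
    # o, do: (M, D_HEAD); denom: (M,)
    do_scaled = do.float() / denom.float()[:, None]
    delta = (o.float() * do_scaled).sum(dim=1)
    return do_scaled, delta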


@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    # NOTE: all tensors are assumed to share the same layout, so Q strides are reused below
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, tl.trans(k))
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(tl.float16)), do)
            # compute dp = dot(do, v.T) - delta; delta (Di) is folded in up front
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None]); delta was already subtracted above
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
            # compute dq
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(tl.float16), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)


empty = torch.empty(128, device="cuda")


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _fwd_kernel[grid](
            q, k, v, sm_scale,
            tmp, L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=1,
        )

        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )

        # NOTE: kernel currently buggy for other values of `num_warps`
        num_warps = 8
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
            num_stages=1,
        )
        return dq, dk, dv, None


attention = _attention.apply
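
# [Editor's note] Typical usage (illustrative, assuming fp16 CUDA tensors of shape
# (Z, H, N_CTX, D_HEAD) with D_HEAD in {16, 32, 64, 128}):
#
#   q = torch.randn(4, 48, 1024, 64, dtype=torch.float16, device="cuda", requires_grad=True)
#   k = torch.randn_like(q, requires_grad=True)
#   v = torch.randn_like(q, requires_grad=True)
#   out = attention(q, k, v, 0.2)          # forward: _fwd_kernel
#   out.backward(torch.randn_like(out))    # backward: _bwd_preprocess + _bwd_kernel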


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
    sm_scale = 0.2
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    p[:, :, M == 0] = float("-inf")  # causal mask
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    triton.testing.assert_almost_equal(ref_out, tri_out)
    triton.testing.assert_almost_equal(ref_dv, tri_dv)
    triton.testing.assert_almost_equal(ref_dk, tri_dk)
    triton.testing.assert_almost_equal(ref_dq, tri_dq)


try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary sequence length for fixed head dim and batch=4
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 16)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode},
) for mode in ['fwd']]


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms
    if provider == "flash":
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        return ms


# bench_flash_attention.run(save_path='.', print_data=True)
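
# [Editor's note] To run locally (illustrative): execute the correctness check with
# pytest on this file (e.g. `pytest -k test_op`), and uncomment the
# `bench_flash_attention.run(...)` line above to produce the benchmark report for
# the configurations defined in `configs`.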