Files
triton/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb
2022-07-14 07:22:19 +00:00

54 lines
15 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Fused Attention\nThis is a Triton implementation of the Flash Attention algorithm \n(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import pytest\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # offset pointers for batch/head\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_ptrs)\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(tl.float16), do, trans_a=True)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(ds.to(tl.float16), q, trans_a=True)\n # # compute dq\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n # # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n # write-back\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n tmp, L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=64, num_warps=4,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = 64\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n num_stages=1,\n )\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n\n\n@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])\ndef test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):\n torch.manual_seed(20)\n q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n sm_scale = 0.3\n dout = torch.randn_like(q)\n # reference implementation\n M = torch.tril(torch.ones((N_CTX, N_CTX), device=\"cuda\"))\n p = torch.matmul(q, k.transpose(2, 3)) * sm_scale\n for z in range(Z):\n for h in range(H):\n p[:, :, M == 0] = float(\"-inf\")\n p = torch.softmax(p.float(), dim=-1).half()\n ref_out = torch.matmul(p, v)\n ref_out.backward(dout)\n ref_dv, v.grad = v.grad.clone(), None\n ref_dk, k.grad = k.grad.clone(), None\n ref_dq, q.grad = q.grad.clone(), None\n # triton implementation\n tri_out = attention(q, k, v, sm_scale)\n tri_out.backward(dout)\n tri_dv, v.grad = v.grad.clone(), None\n tri_dk, k.grad = k.grad.clone(), None\n tri_dq, q.grad = q.grad.clone(), None\n # compare\n triton.testing.assert_almost_equal(ref_out, tri_out)\n triton.testing.assert_almost_equal(ref_dv, tri_dv)\n triton.testing.assert_almost_equal(ref_dk, tri_dk)\n triton.testing.assert_almost_equal(ref_dq, tri_dq)\n\n\ntry:\n from flash_attn.flash_attn_interface import flash_attn_func\n HAS_FLASH = True\nexcept BaseException:\n HAS_FLASH = False\n\nBATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64\n# vary seq length for fixed head and batch=4\nconfigs = [triton.testing.Benchmark(\n x_names=['N_CTX'],\n x_vals=[2**i for i in range(10, 16)],\n line_arg='provider',\n line_vals=['triton'] + (['flash'] if HAS_FLASH else []),\n line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),\n styles=[('red', '-'), ('blue', '-')],\n ylabel='ms',\n plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',\n args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}\n) for mode in ['bwd']]\n\n\n@triton.testing.perf_report(configs)\ndef bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device=\"cuda\"):\n assert mode in ['fwd', 'bwd']\n warmup = 25\n rep = 100\n if provider == \"triton\":\n q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n sm_scale = 1.3\n fn = lambda: attention(q, k, v, sm_scale)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n if provider == \"flash\":\n lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)\n cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)\n cu_seqlens[1:] = lengths.cumsum(0)\n qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)\n fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n\n# only works on A100 at the moment\n# bench_flash_attention.run(save_path='.', print_data=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}