Files
triton/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb

54 lines
15 KiB
Plaintext
Raw Normal View History

2022-07-14 07:22:19 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Fused Attention\nThis is a Triton implementation of the Flash Attention algorithm \n(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import pytest\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L,
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}