diff --git a/master/.buildinfo b/master/.buildinfo index ce8489ef7..76e485747 100644 --- a/master/.buildinfo +++ b/master/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: cac2b435c7b1f5ec0953d5824bf41314 +config: 53dd097a7168e06dae1f9b4d065b8571 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/master/.doctrees/environment.pickle b/master/.doctrees/environment.pickle index 0883f51d7..b1a075bec 100644 Binary files a/master/.doctrees/environment.pickle and b/master/.doctrees/environment.pickle differ diff --git a/master/.doctrees/getting-started/installation.doctree b/master/.doctrees/getting-started/installation.doctree index 535fb5331..4829bc6ed 100644 Binary files a/master/.doctrees/getting-started/installation.doctree and b/master/.doctrees/getting-started/installation.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/01-vector-add.doctree b/master/.doctrees/getting-started/tutorials/01-vector-add.doctree index 786f2d1f6..3465ed88b 100644 Binary files a/master/.doctrees/getting-started/tutorials/01-vector-add.doctree and b/master/.doctrees/getting-started/tutorials/01-vector-add.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree b/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree index 1c47d5145..413e039c4 100644 Binary files a/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree and b/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree b/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree index 0e8c79ed8..0f33e065b 100644 Binary files a/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree and b/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree b/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree index c69d186ff..d76d9ce09 100644 Binary files a/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree and b/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree b/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree index 7809a1e3e..8680279aa 100644 Binary files a/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree and b/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree b/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree index 12c07cf08..85e7c73ae 100644 Binary files a/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree and b/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree b/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree index 01ca88d47..7721156e1 100644 Binary files a/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree and b/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/index.doctree 
b/master/.doctrees/getting-started/tutorials/index.doctree index cd6d5f47f..e83b9fb2c 100644 Binary files a/master/.doctrees/getting-started/tutorials/index.doctree and b/master/.doctrees/getting-started/tutorials/index.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree b/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree index 48b78639c..e7bbe5f7d 100644 Binary files a/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree and b/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree differ diff --git a/master/.doctrees/index.doctree b/master/.doctrees/index.doctree index 78a0e0f8f..17545f871 100644 Binary files a/master/.doctrees/index.doctree and b/master/.doctrees/index.doctree differ diff --git a/master/.doctrees/programming-guide/chapter-1/introduction.doctree b/master/.doctrees/programming-guide/chapter-1/introduction.doctree index b653b7411..216315eb2 100644 Binary files a/master/.doctrees/programming-guide/chapter-1/introduction.doctree and b/master/.doctrees/programming-guide/chapter-1/introduction.doctree differ diff --git a/master/.doctrees/programming-guide/chapter-2/related-work.doctree b/master/.doctrees/programming-guide/chapter-2/related-work.doctree index 67fab65d9..3683fd77d 100644 Binary files a/master/.doctrees/programming-guide/chapter-2/related-work.doctree and b/master/.doctrees/programming-guide/chapter-2/related-work.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.Config.doctree b/master/.doctrees/python-api/generated/triton.Config.doctree index e47f23145..052b759fa 100644 Binary files a/master/.doctrees/python-api/generated/triton.Config.doctree and b/master/.doctrees/python-api/generated/triton.Config.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.autotune.doctree b/master/.doctrees/python-api/generated/triton.autotune.doctree index 14aa86c38..704e01473 100644 Binary files a/master/.doctrees/python-api/generated/triton.autotune.doctree and b/master/.doctrees/python-api/generated/triton.autotune.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.heuristics.doctree b/master/.doctrees/python-api/generated/triton.heuristics.doctree index 058694d96..518b278ae 100644 Binary files a/master/.doctrees/python-api/generated/triton.heuristics.doctree and b/master/.doctrees/python-api/generated/triton.heuristics.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.jit.doctree b/master/.doctrees/python-api/generated/triton.jit.doctree index 75d0baa76..2d6ccaf47 100644 Binary files a/master/.doctrees/python-api/generated/triton.jit.doctree and b/master/.doctrees/python-api/generated/triton.jit.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.arange.doctree b/master/.doctrees/python-api/generated/triton.language.arange.doctree index e43b09b24..84b569f23 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.arange.doctree and b/master/.doctrees/python-api/generated/triton.language.arange.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree index ac2a189dd..2bbb17f5c 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree differ diff --git 
a/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree index 5297e9b66..b48a4ffc7 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree index 2e1d3de4c..70fa74c88 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree index c5274409c..b062bd582 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree index e11ba87e2..80e9363c0 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree index b52ff719a..c9ecce340 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree index a1c5c5d5b..01f79e3a0 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree index c7e208d15..c4b54b77c 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree b/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree index 66a3ed81d..6cce2bc53 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree and b/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.cos.doctree b/master/.doctrees/python-api/generated/triton.language.cos.doctree index 40840d6ca..af5d2139f 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.cos.doctree and b/master/.doctrees/python-api/generated/triton.language.cos.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.dot.doctree b/master/.doctrees/python-api/generated/triton.language.dot.doctree index db7ea6e7b..de27a0c9d 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.dot.doctree 
and b/master/.doctrees/python-api/generated/triton.language.dot.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.exp.doctree b/master/.doctrees/python-api/generated/triton.language.exp.doctree index 4252d23f7..fe499dcb7 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.exp.doctree and b/master/.doctrees/python-api/generated/triton.language.exp.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.load.doctree b/master/.doctrees/python-api/generated/triton.language.load.doctree index 87e41175e..0dfaa4144 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.load.doctree and b/master/.doctrees/python-api/generated/triton.language.load.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.log.doctree b/master/.doctrees/python-api/generated/triton.language.log.doctree index 1b0be15f2..e406898ed 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.log.doctree and b/master/.doctrees/python-api/generated/triton.language.log.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.max.doctree b/master/.doctrees/python-api/generated/triton.language.max.doctree index 3214eb4ce..454afc251 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.max.doctree and b/master/.doctrees/python-api/generated/triton.language.max.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.maximum.doctree b/master/.doctrees/python-api/generated/triton.language.maximum.doctree index daefad1c4..eb2b906d8 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.maximum.doctree and b/master/.doctrees/python-api/generated/triton.language.maximum.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.min.doctree b/master/.doctrees/python-api/generated/triton.language.min.doctree index 11ebbd22d..41b943dfa 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.min.doctree and b/master/.doctrees/python-api/generated/triton.language.min.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.minimum.doctree b/master/.doctrees/python-api/generated/triton.language.minimum.doctree index f1c82d49e..9f6329985 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.minimum.doctree and b/master/.doctrees/python-api/generated/triton.language.minimum.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree b/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree index 209e10cd9..8cdf91ea0 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree and b/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.num_programs.doctree b/master/.doctrees/python-api/generated/triton.language.num_programs.doctree index e82fc59fb..25347ce44 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.num_programs.doctree and b/master/.doctrees/python-api/generated/triton.language.num_programs.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.program_id.doctree b/master/.doctrees/python-api/generated/triton.language.program_id.doctree index 770666834..6d5740ac2 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.program_id.doctree and 
b/master/.doctrees/python-api/generated/triton.language.program_id.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.rand.doctree b/master/.doctrees/python-api/generated/triton.language.rand.doctree index 1fb733682..a0343dc5d 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.rand.doctree and b/master/.doctrees/python-api/generated/triton.language.rand.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.randint.doctree b/master/.doctrees/python-api/generated/triton.language.randint.doctree index ad4a6b874..70370b42a 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.randint.doctree and b/master/.doctrees/python-api/generated/triton.language.randint.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.randint4x.doctree b/master/.doctrees/python-api/generated/triton.language.randint4x.doctree index 35097c6d0..7970adabb 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.randint4x.doctree and b/master/.doctrees/python-api/generated/triton.language.randint4x.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.randn.doctree b/master/.doctrees/python-api/generated/triton.language.randn.doctree index 6ce25c399..1e20b6961 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.randn.doctree and b/master/.doctrees/python-api/generated/triton.language.randn.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.ravel.doctree b/master/.doctrees/python-api/generated/triton.language.ravel.doctree index fe94d8f46..6ef8e8b1e 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.ravel.doctree and b/master/.doctrees/python-api/generated/triton.language.ravel.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.reshape.doctree b/master/.doctrees/python-api/generated/triton.language.reshape.doctree index f465cc38c..7e71b9dc1 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.reshape.doctree and b/master/.doctrees/python-api/generated/triton.language.reshape.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree b/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree index f4790c51a..7897bb169 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree and b/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sin.doctree b/master/.doctrees/python-api/generated/triton.language.sin.doctree index d16aff2be..2a6370ee1 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sin.doctree and b/master/.doctrees/python-api/generated/triton.language.sin.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.softmax.doctree b/master/.doctrees/python-api/generated/triton.language.softmax.doctree index 02c5769ce..99b8ae7fb 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.softmax.doctree and b/master/.doctrees/python-api/generated/triton.language.softmax.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sqrt.doctree b/master/.doctrees/python-api/generated/triton.language.sqrt.doctree index cfc711328..0e3c3edcc 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sqrt.doctree and 
b/master/.doctrees/python-api/generated/triton.language.sqrt.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.store.doctree b/master/.doctrees/python-api/generated/triton.language.store.doctree index 083f7bdd8..1aea9e2c2 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.store.doctree and b/master/.doctrees/python-api/generated/triton.language.store.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sum.doctree b/master/.doctrees/python-api/generated/triton.language.sum.doctree index b1e2b6548..76dfbe8b7 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sum.doctree and b/master/.doctrees/python-api/generated/triton.language.sum.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.where.doctree b/master/.doctrees/python-api/generated/triton.language.where.doctree index 2fd5b265f..bd557e3e9 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.where.doctree and b/master/.doctrees/python-api/generated/triton.language.where.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.zeros.doctree b/master/.doctrees/python-api/generated/triton.language.zeros.doctree index 35056cd41..ac1e2fca7 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.zeros.doctree and b/master/.doctrees/python-api/generated/triton.language.zeros.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree b/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree index 45a93036f..97375e25f 100644 Binary files a/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree and b/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree b/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree index 210137801..4c83c85d5 100644 Binary files a/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree and b/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree b/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree index a70776736..b30e51938 100644 Binary files a/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree and b/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree differ diff --git a/master/.doctrees/python-api/triton.doctree b/master/.doctrees/python-api/triton.doctree index 2e79abb56..f36e47101 100644 Binary files a/master/.doctrees/python-api/triton.doctree and b/master/.doctrees/python-api/triton.doctree differ diff --git a/master/.doctrees/python-api/triton.language.doctree b/master/.doctrees/python-api/triton.language.doctree index c9c2ee1c3..9a4694bc4 100644 Binary files a/master/.doctrees/python-api/triton.language.doctree and b/master/.doctrees/python-api/triton.language.doctree differ diff --git a/master/.doctrees/python-api/triton.testing.doctree b/master/.doctrees/python-api/triton.testing.doctree index 10709f67b..cd4757724 100644 Binary files a/master/.doctrees/python-api/triton.testing.doctree and b/master/.doctrees/python-api/triton.testing.doctree differ diff --git a/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb b/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb index 
4b7bc0039..9ff8abf47 100644 --- a/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb +++ b/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "import pytest\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, 
None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # offset pointers for batch/head\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_ptrs)\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(tl.float16), do, trans_a=True)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(ds.to(tl.float16), q, trans_a=True)\n # # compute dq\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n # # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n # write-back\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # 
shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n tmp, L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=64, num_warps=4,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = 64\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n num_stages=1,\n )\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n\n\n@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])\ndef test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):\n torch.manual_seed(20)\n q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n sm_scale = 0.3\n dout = torch.randn_like(q)\n # reference implementation\n M = torch.tril(torch.ones((N_CTX, N_CTX), device=\"cuda\"))\n p = torch.matmul(q, k.transpose(2, 3)) * sm_scale\n for z in range(Z):\n for h in range(H):\n p[:, :, M == 0] = float(\"-inf\")\n p = torch.softmax(p.float(), dim=-1).half()\n ref_out = torch.matmul(p, v)\n ref_out.backward(dout)\n ref_dv, v.grad = v.grad.clone(), None\n ref_dk, k.grad = k.grad.clone(), None\n ref_dq, q.grad = q.grad.clone(), None\n # triton implementation\n tri_out = attention(q, k, v, sm_scale)\n tri_out.backward(dout)\n tri_dv, v.grad = v.grad.clone(), None\n tri_dk, k.grad = k.grad.clone(), None\n tri_dq, q.grad = q.grad.clone(), None\n # compare\n triton.testing.assert_almost_equal(ref_out, tri_out)\n triton.testing.assert_almost_equal(ref_dv, tri_dv)\n triton.testing.assert_almost_equal(ref_dk, tri_dk)\n triton.testing.assert_almost_equal(ref_dq, tri_dq)\n\n\ntry:\n from flash_attn.flash_attn_interface import flash_attn_func\n HAS_FLASH = True\nexcept BaseException:\n HAS_FLASH = False\n\nBATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64\n# vary seq length for fixed head and batch=4\nconfigs = 
[triton.testing.Benchmark(\n x_names=['N_CTX'],\n x_vals=[2**i for i in range(10, 16)],\n line_arg='provider',\n line_vals=['triton'] + (['flash'] if HAS_FLASH else []),\n line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),\n styles=[('red', '-'), ('blue', '-')],\n ylabel='ms',\n plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',\n args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}\n) for mode in ['bwd']]\n\n\n@triton.testing.perf_report(configs)\ndef bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device=\"cuda\"):\n assert mode in ['fwd', 'bwd']\n warmup = 25\n rep = 100\n if provider == \"triton\":\n q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n sm_scale = 1.3\n fn = lambda: attention(q, k, v, sm_scale)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n if provider == \"flash\":\n lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)\n cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)\n cu_seqlens[1:] = lengths.cumsum(0)\n qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)\n fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n\n# only works on A100 at the moment\n# bench_flash_attention.run(save_path='.', print_data=True)" + "import pytest\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, 
BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # offset pointers for batch/head\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd 
dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_ptrs)\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(tl.float16), do, trans_a=True)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(ds.to(tl.float16), q, trans_a=True)\n # # compute dq\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n # # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n # write-back\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n tmp, L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk, num_warps=num_warps,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n\n num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n 
ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,\n num_stages=1,\n )\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n\n\n@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])\ndef test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):\n torch.manual_seed(20)\n q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n sm_scale = 0.3\n dout = torch.randn_like(q)\n # reference implementation\n M = torch.tril(torch.ones((N_CTX, N_CTX), device=\"cuda\"))\n p = torch.matmul(q, k.transpose(2, 3)) * sm_scale\n for z in range(Z):\n for h in range(H):\n p[:, :, M == 0] = float(\"-inf\")\n p = torch.softmax(p.float(), dim=-1).half()\n ref_out = torch.matmul(p, v)\n ref_out.backward(dout)\n ref_dv, v.grad = v.grad.clone(), None\n ref_dk, k.grad = k.grad.clone(), None\n ref_dq, q.grad = q.grad.clone(), None\n # triton implementation\n tri_out = attention(q, k, v, sm_scale)\n tri_out.backward(dout)\n tri_dv, v.grad = v.grad.clone(), None\n tri_dk, k.grad = k.grad.clone(), None\n tri_dq, q.grad = q.grad.clone(), None\n # compare\n triton.testing.assert_almost_equal(ref_out, tri_out)\n triton.testing.assert_almost_equal(ref_dv, tri_dv)\n triton.testing.assert_almost_equal(ref_dk, tri_dk)\n triton.testing.assert_almost_equal(ref_dq, tri_dq)\n\n\ntry:\n from flash_attn.flash_attn_interface import flash_attn_func\n HAS_FLASH = True\nexcept BaseException:\n HAS_FLASH = False\n\nBATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64\n# vary seq length for fixed head and batch=4\nconfigs = [triton.testing.Benchmark(\n x_names=['N_CTX'],\n x_vals=[2**i for i in range(10, 16)],\n line_arg='provider',\n line_vals=['triton'] + (['flash'] if HAS_FLASH else []),\n line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),\n styles=[('red', '-'), ('blue', '-')],\n ylabel='ms',\n plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',\n args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}\n) for mode in ['bwd']]\n\n\n@triton.testing.perf_report(configs)\ndef bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device=\"cuda\"):\n assert mode in ['fwd', 'bwd']\n warmup = 25\n rep = 100\n if provider == \"triton\":\n q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n sm_scale = 1.3\n fn = lambda: attention(q, k, v, sm_scale)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n if provider == \"flash\":\n lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)\n cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)\n cu_seqlens[1:] = lengths.cumsum(0)\n qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)\n fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: 
o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n\n# only works on A100 at the moment\n# bench_flash_attention.run(save_path='.', print_data=True)" ] } ], diff --git a/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py b/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py index fb0f4f958..035514746 100644 --- a/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py +++ b/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py @@ -204,13 +204,16 @@ class _attention(torch.autograd.Function): def forward(ctx, q, k, v, sm_scale): BLOCK = 128 # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( q, k, v, sm_scale, tmp, L, m, @@ -221,14 +224,14 @@ class _attention(torch.autograd.Function): o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], q.shape[2], BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=64, num_warps=4, + BLOCK_DMODEL=Lk, num_warps=num_warps, num_stages=1, ) ctx.save_for_backward(q, k, v, o, L, m) ctx.BLOCK = BLOCK ctx.grid = grid ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = 64 + ctx.BLOCK_DMODEL = Lk return o @staticmethod @@ -245,6 +248,8 @@ class _attention(torch.autograd.Function): do_scaled, delta, BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) + + num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8 _bwd_kernel[(ctx.grid[1],)]( q, k, v, ctx.sm_scale, o, do_scaled, @@ -257,7 +262,7 @@ class _attention(torch.autograd.Function): q.shape[0], q.shape[1], q.shape[2], ctx.grid[0], BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, num_stages=1, ) return dq, dk, dv, None diff --git a/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip b/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip index 71c9ec1f2..9e6ef2932 100644 Binary files a/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip and b/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip differ diff --git a/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip b/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip index 0dc37655b..61ae15f94 100644 Binary files a/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip and b/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip differ diff --git a/master/_images/sphx_glr_01-vector-add_001.png b/master/_images/sphx_glr_01-vector-add_001.png index 81078eb20..4a566f9b8 100644 Binary files a/master/_images/sphx_glr_01-vector-add_001.png and b/master/_images/sphx_glr_01-vector-add_001.png differ diff --git a/master/_images/sphx_glr_01-vector-add_thumb.png b/master/_images/sphx_glr_01-vector-add_thumb.png index 2ee73199a..9c4738255 100644 Binary files a/master/_images/sphx_glr_01-vector-add_thumb.png and 
b/master/_images/sphx_glr_01-vector-add_thumb.png differ diff --git a/master/_images/sphx_glr_02-fused-softmax_001.png b/master/_images/sphx_glr_02-fused-softmax_001.png index e24386cc4..162f66a2a 100644 Binary files a/master/_images/sphx_glr_02-fused-softmax_001.png and b/master/_images/sphx_glr_02-fused-softmax_001.png differ diff --git a/master/_images/sphx_glr_02-fused-softmax_thumb.png b/master/_images/sphx_glr_02-fused-softmax_thumb.png index e41b71c7c..12a87d081 100644 Binary files a/master/_images/sphx_glr_02-fused-softmax_thumb.png and b/master/_images/sphx_glr_02-fused-softmax_thumb.png differ diff --git a/master/_images/sphx_glr_03-matrix-multiplication_001.png b/master/_images/sphx_glr_03-matrix-multiplication_001.png index 5945406a0..17fe71fe8 100644 Binary files a/master/_images/sphx_glr_03-matrix-multiplication_001.png and b/master/_images/sphx_glr_03-matrix-multiplication_001.png differ diff --git a/master/_images/sphx_glr_03-matrix-multiplication_thumb.png b/master/_images/sphx_glr_03-matrix-multiplication_thumb.png index 2f15385fd..3bc5ee06b 100644 Binary files a/master/_images/sphx_glr_03-matrix-multiplication_thumb.png and b/master/_images/sphx_glr_03-matrix-multiplication_thumb.png differ diff --git a/master/_images/sphx_glr_05-layer-norm_001.png b/master/_images/sphx_glr_05-layer-norm_001.png index d59837a37..854edf803 100644 Binary files a/master/_images/sphx_glr_05-layer-norm_001.png and b/master/_images/sphx_glr_05-layer-norm_001.png differ diff --git a/master/_images/sphx_glr_05-layer-norm_thumb.png b/master/_images/sphx_glr_05-layer-norm_thumb.png index 05255746a..aadf42620 100644 Binary files a/master/_images/sphx_glr_05-layer-norm_thumb.png and b/master/_images/sphx_glr_05-layer-norm_thumb.png differ diff --git a/master/_sources/getting-started/tutorials/01-vector-add.rst.txt b/master/_sources/getting-started/tutorials/01-vector-add.rst.txt index 49004d98d..d3eb6c50b 100644 --- a/master/_sources/getting-started/tutorials/01-vector-add.rst.txt +++ b/master/_sources/getting-started/tutorials/01-vector-add.rst.txt @@ -235,17 +235,17 @@ We can now run the decorated function above. Pass `print_data=True` to see the p 0 4096.0 9.600000 9.600000 1 8192.0 19.200000 19.200000 2 16384.0 38.400001 38.400001 - 3 32768.0 63.999998 76.800002 + 3 32768.0 76.800002 76.800002 4 65536.0 127.999995 127.999995 5 131072.0 219.428568 219.428568 - 6 262144.0 341.333321 384.000001 + 6 262144.0 384.000001 384.000001 7 524288.0 472.615390 472.615390 8 1048576.0 614.400016 614.400016 9 2097152.0 722.823517 722.823517 10 4194304.0 780.190482 780.190482 11 8388608.0 812.429770 812.429770 12 16777216.0 833.084721 833.084721 - 13 33554432.0 842.004273 843.811163 + 13 33554432.0 842.004273 842.004273 14 67108864.0 847.448255 848.362445 15 134217728.0 849.737435 850.656574 @@ -255,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 1 minutes 49.928 seconds) + **Total running time of the script:** ( 1 minutes 41.917 seconds) .. 
_sphx_glr_download_getting-started_tutorials_01-vector-add.py: diff --git a/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt b/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt index 020cf3a05..fced1c069 100644 --- a/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt +++ b/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt @@ -278,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t softmax-performance: N Triton Torch (native) Torch (jit) - 0 256.0 546.133347 546.133347 188.321838 + 0 256.0 512.000001 546.133347 190.511628 1 384.0 614.400016 585.142862 153.600004 - 2 512.0 655.360017 606.814814 154.566038 + 2 512.0 655.360017 585.142849 154.566038 3 640.0 706.206879 640.000002 160.000000 4 768.0 722.823517 664.216187 162.754967 .. ... ... ... ... - 93 12160.0 812.359066 406.179533 198.834951 - 94 12288.0 812.429770 415.661740 199.197579 - 95 12416.0 812.498981 412.149375 198.755369 - 96 12544.0 810.925276 412.971190 199.012395 - 97 12672.0 811.007961 412.097543 199.167004 + 93 12160.0 812.359066 406.603966 199.038365 + 94 12288.0 812.429770 416.101597 199.197579 + 95 12416.0 812.498981 413.006241 198.854847 + 96 12544.0 810.925276 412.971190 199.111113 + 97 12672.0 811.007961 412.516771 199.167004 [98 rows x 4 columns] @@ -306,7 +306,7 @@ In the above plot, we can see that: .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 3 minutes 30.792 seconds) + **Total running time of the script:** ( 3 minutes 30.054 seconds) .. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py: diff --git a/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt b/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt index bb773de44..11259d3b9 100644 --- a/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt +++ b/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt @@ -459,37 +459,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we matmul-performance: M cuBLAS ... Triton Triton (+ LeakyReLU) - 0 256.0 2.978909 ... 2.978909 3.276800 + 0 256.0 2.730667 ... 2.978909 2.978909 1 384.0 7.372800 ... 8.507077 8.507077 - 2 512.0 14.563555 ... 16.384000 16.384000 + 2 512.0 14.563555 ... 15.420235 16.384000 3 640.0 22.260869 ... 24.380953 24.380953 4 768.0 32.768000 ... 35.389441 34.028308 - 5 896.0 39.025776 ... 40.140799 39.025776 + 5 896.0 37.971025 ... 40.140799 39.025776 6 1024.0 49.932191 ... 53.773130 52.428801 - 7 1152.0 45.242181 ... 48.161033 48.161033 + 7 1152.0 45.242181 ... 48.161033 47.396572 8 1280.0 51.200001 ... 57.690139 57.690139 - 9 1408.0 64.138541 ... 69.009825 67.305878 - 10 1536.0 80.430545 ... 81.355034 79.526831 + 9 1408.0 64.138541 ... 69.009825 68.147202 + 10 1536.0 80.430545 ... 80.430545 78.643199 11 1664.0 62.929456 ... 63.372618 62.492442 - 12 1792.0 72.512412 ... 73.460287 59.467852 - 13 1920.0 68.776119 ... 71.257735 71.257735 + 12 1792.0 72.983276 ... 73.460287 59.467852 + 13 1920.0 69.120002 ... 71.626943 71.257735 14 2048.0 73.908442 ... 78.398206 77.314362 - 15 2176.0 83.500614 ... 87.876193 86.367588 - 16 2304.0 68.446623 ... 77.810656 77.307030 - 17 2432.0 71.305746 ... 86.711310 85.653855 - 18 2560.0 77.833728 ... 82.956960 81.108913 - 19 2688.0 83.369354 ... 90.316801 90.102270 - 20 2816.0 79.587973 ... 84.687779 83.153880 - 21 2944.0 81.967162 ... 83.617504 81.967162 - 22 3072.0 81.707223 ... 
90.020831 88.060814 - 23 3200.0 83.879425 ... 95.238096 87.673110 - 24 3328.0 83.226931 ... 82.748617 84.895397 - 25 3456.0 81.353753 ... 88.207407 91.200871 - 26 3584.0 87.296493 ... 99.354022 97.628001 - 27 3712.0 82.421427 ... 89.353616 83.247783 - 28 3840.0 83.339866 ... 91.398346 86.840987 - 29 3968.0 86.849777 ... 92.302520 84.066569 - 30 4096.0 93.077479 ... 83.055527 82.340585 + 15 2176.0 83.155572 ... 87.876193 85.998493 + 16 2304.0 68.446623 ... 78.064941 77.307030 + 17 2432.0 71.305746 ... 86.179335 85.653855 + 18 2560.0 77.833728 ... 82.539044 81.310171 + 19 2688.0 83.737433 ... 90.532356 89.676257 + 20 2816.0 80.767055 ... 83.873477 81.674548 + 21 2944.0 82.237674 ... 83.477440 82.373605 + 22 3072.0 81.707223 ... 89.877939 88.197981 + 23 3200.0 84.544253 ... 96.822991 94.814812 + 24 3328.0 83.226931 ... 85.398926 84.895397 + 25 3456.0 81.766291 ... 91.511426 86.503829 + 26 3584.0 83.876297 ... 95.756542 95.350361 + 27 3712.0 84.159518 ... 88.837126 87.937800 + 28 3840.0 85.070769 ... 93.326587 85.663823 + 29 3968.0 91.198760 ... 87.097744 91.609561 + 30 4096.0 86.204508 ... 93.792965 89.240508 [31 rows x 5 columns] @@ -499,7 +499,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 7 minutes 16.663 seconds) + **Total running time of the script:** ( 6 minutes 38.507 seconds) .. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py: diff --git a/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt b/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt index edbf188ee..d1778b13a 100644 --- a/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt +++ b/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt @@ -240,7 +240,7 @@ References .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 0 minutes 0.279 seconds) + **Total running time of the script:** ( 0 minutes 0.012 seconds) .. 
_sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py: diff --git a/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt b/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt index cf401b478..58bd3828e 100644 --- a/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt +++ b/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt @@ -38,36 +38,36 @@ Layer Normalization layer-norm: N Triton Torch Apex - 0 1024.0 585.142849 277.694907 481.882344 + 0 1024.0 585.142849 277.694907 468.114273 1 1536.0 630.153868 323.368435 511.999982 - 2 2048.0 668.734716 334.367358 528.516136 - 3 2560.0 694.237267 365.714281 518.481028 + 2 2048.0 682.666643 337.814445 520.126988 + 3 2560.0 694.237267 362.477870 512.000013 4 3072.0 712.347810 378.092307 501.551037 - 5 3584.0 725.873439 384.859062 458.751978 - 6 4096.0 728.177767 381.023256 458.293714 - 7 4608.0 670.254540 396.387087 426.173427 - 8 5120.0 688.403381 397.669909 426.666652 - 9 5632.0 698.542675 395.228063 413.357796 - 10 6144.0 702.171410 402.885254 411.313806 - 11 6656.0 700.631610 400.360920 398.861429 - 12 7168.0 690.891575 396.844306 387.459443 - 13 7680.0 678.895043 392.587863 387.634072 - 14 8192.0 639.375598 393.609605 373.424507 - 15 8704.0 627.315309 389.005597 380.502740 - 16 9216.0 606.814809 407.337026 383.999986 - 17 9728.0 587.350922 409.599987 382.427505 - 18 10240.0 564.965524 408.578556 382.803739 - 19 10752.0 547.872604 411.559798 381.445676 - 20 11264.0 533.207081 406.826188 373.134567 + 5 3584.0 725.873439 384.859062 451.527536 + 6 4096.0 728.177767 381.023256 451.972420 + 7 4608.0 670.254540 396.387087 428.651163 + 8 5120.0 688.403381 397.669909 420.102563 + 9 5632.0 704.000002 395.228063 413.357796 + 10 6144.0 702.171410 402.885254 413.042029 + 11 6656.0 700.631610 400.360920 400.360920 + 12 7168.0 690.891575 392.767108 382.293315 + 13 7680.0 678.895043 393.846167 386.415087 + 14 8192.0 636.271854 394.795186 377.729113 + 15 8704.0 624.502255 389.005597 379.465939 + 16 9216.0 604.327881 406.214877 382.010363 + 17 9728.0 585.142883 408.524944 383.369452 + 18 10240.0 564.965524 409.600010 382.803739 + 19 10752.0 546.133312 411.559798 380.601764 + 20 11264.0 531.634232 404.997742 371.595879 21 11776.0 520.486200 409.599991 377.587162 - 22 12288.0 514.680630 413.911572 383.251457 - 23 12800.0 504.433489 410.420828 376.470582 - 24 13312.0 494.180982 405.699062 376.310952 - 25 13824.0 481.882350 411.888257 379.389355 - 26 14336.0 470.997935 406.695045 374.185964 - 27 14848.0 460.403127 408.192434 374.712936 - 28 15360.0 454.269882 406.214870 378.092307 - 29 15872.0 447.887117 406.974373 376.225175 + 22 12288.0 516.031509 413.911572 383.251457 + 23 12800.0 504.433489 409.599981 377.163903 + 24 13312.0 494.180982 406.473303 377.645399 + 25 13824.0 482.934503 412.656711 379.389355 + 26 14336.0 471.967074 402.414053 370.558967 + 27 14848.0 461.297068 407.492270 373.534584 + 28 15360.0 454.269882 406.214870 377.511515 + 29 15872.0 447.098578 409.599996 377.343238 @@ -393,7 +393,7 @@ Layer Normalization .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 5 minutes 37.042 seconds) + **Total running time of the script:** ( 5 minutes 37.218 seconds) .. 
_sphx_glr_download_getting-started_tutorials_05-layer-norm.py:

diff --git a/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt b/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt
index 8ea794afc..d78c6a14f 100644
--- a/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt
+++ b/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt
@@ -23,7 +23,7 @@ Fused Attention
 This is a Triton implementation of the Flash Attention algorithm
 (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)

-.. GENERATED FROM PYTHON SOURCE LINES 7-355
+.. GENERATED FROM PYTHON SOURCE LINES 7-360
@@ -233,13 +233,16 @@ This is a Triton implementation of the Flash Attention algorithm
         def forward(ctx, q, k, v, sm_scale):
             BLOCK = 128
             # shape constraints
-            Lq, Lk = q.shape[-1], k.shape[-1]
-            assert Lq == Lk
+            Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+            assert Lq == Lk and Lk == Lv
+            assert Lk in {16, 32, 64, 128}
             o = torch.empty_like(q)
             grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
             tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
             L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
             m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+            num_warps = 4 if Lk <= 64 else 8
+
             _fwd_kernel[grid](
                 q, k, v, sm_scale,
                 tmp, L, m,
@@ -250,14 +253,14 @@ This is a Triton implementation of the Flash Attention algorithm
                 o.stride(0), o.stride(1), o.stride(2), o.stride(3),
                 q.shape[0], q.shape[1], q.shape[2],
                 BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-                BLOCK_DMODEL=64, num_warps=4,
+                BLOCK_DMODEL=Lk, num_warps=num_warps,
                 num_stages=1,
             )
             ctx.save_for_backward(q, k, v, o, L, m)
             ctx.BLOCK = BLOCK
             ctx.grid = grid
             ctx.sm_scale = sm_scale
-            ctx.BLOCK_DMODEL = 64
+            ctx.BLOCK_DMODEL = Lk
             return o

         @staticmethod
@@ -274,6 +277,8 @@ This is a Triton implementation of the Flash Attention algorithm
                 do_scaled, delta,
                 BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
             )
+
+            num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
             _bwd_kernel[(ctx.grid[1],)](
                 q, k, v, ctx.sm_scale,
                 o, do_scaled,
@@ -286,7 +291,7 @@ This is a Triton implementation of the Flash Attention algorithm
                 q.shape[0], q.shape[1], q.shape[2],
                 ctx.grid[0],
                 BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
                 num_stages=1,
             )
             return dq, dk, dv, None
@@ -385,7 +390,7 @@ This is a Triton implementation of the Flash Attention algorithm
 .. rst-class:: sphx-glr-timing

-   **Total running time of the script:** ( 0 minutes 0.078 seconds)
+   **Total running time of the script:** ( 0 minutes 0.073 seconds)

.. _sphx_glr_download_getting-started_tutorials_06-fused-attention.py:

diff --git a/master/_sources/getting-started/tutorials/07-libdevice-function.rst.txt b/master/_sources/getting-started/tutorials/07-libdevice-function.rst.txt
index 7db38fb6e..297ebf116 100644
--- a/master/_sources/getting-started/tutorials/07-libdevice-function.rst.txt
+++ b/master/_sources/getting-started/tutorials/07-libdevice-function.rst.txt
@@ -152,7 +152,7 @@ We can also customize the libdevice library path by passing the path to the `lib
 .. rst-class:: sphx-glr-timing

-   **Total running time of the script:** ( 0 minutes 0.254 seconds)
+   **Total running time of the script:** ( 0 minutes 0.010 seconds)

..
_sphx_glr_download_getting-started_tutorials_07-libdevice-function.py: diff --git a/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt b/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt index 445f72e13..9f35e906e 100644 --- a/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt +++ b/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt @@ -5,20 +5,20 @@ Computation times ================= -**18:15.037** total execution time for **getting-started_tutorials** files: +**17:27.791** total execution time for **getting-started_tutorials** files: +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 07:16.663 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 06:38.507 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:37.042 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:37.218 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:30.792 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:30.054 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:49.928 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:41.917 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.279 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` (``06-fused-attention.py``) | 00:00.073 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` (``07-libdevice-function.py``) | 00:00.254 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.012 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` (``06-fused-attention.py``) | 00:00.078 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` (``07-libdevice-function.py``) | 00:00.010 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ diff --git a/master/getting-started/tutorials/01-vector-add.html b/master/getting-started/tutorials/01-vector-add.html index d11c7dfd4..05fda03ae 100644 --- a/master/getting-started/tutorials/01-vector-add.html +++ 
b/master/getting-started/tutorials/01-vector-add.html
@@ -327,22 +327,22 @@ for different problem sizes.
  0       4096.0    9.600000    9.600000
  1       8192.0   19.200000   19.200000
  2      16384.0   38.400001   38.400001
- 3      32768.0   63.999998   76.800002
+ 3      32768.0   76.800002   76.800002
  4      65536.0  127.999995  127.999995
  5     131072.0  219.428568  219.428568
- 6     262144.0  341.333321  384.000001
+ 6     262144.0  384.000001  384.000001
  7     524288.0  472.615390  472.615390
  8    1048576.0  614.400016  614.400016
  9    2097152.0  722.823517  722.823517
 10    4194304.0  780.190482  780.190482
 11    8388608.0  812.429770  812.429770
 12   16777216.0  833.084721  833.084721
-13   33554432.0  842.004273  843.811163
+13   33554432.0  842.004273  842.004273
 14   67108864.0  847.448255  848.362445
 15  134217728.0  849.737435  850.656574
-Total running time of the script: ( 1 minutes 49.928 seconds)
+Total running time of the script: ( 1 minutes 41.917 seconds)
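The throughput columns in the 01-vector-add, 02-fused-softmax and 05-layer-norm tables above are effective memory bandwidths in GB/s; since none of those kernel sources change in this diff, the shifted values are run-to-run variance of the benchmark rather than behavioral changes. As a rough, self-contained sketch of where such a figure comes from (timed with CUDA events here, not the tutorials' own ``triton.testing`` harness), the vector-add number counts two reads and one write of float32 per element::

    import torch

    def measured_gbps(n: int, iters: int = 100) -> float:
        # Sketch only: time z = x + y and convert to effective bandwidth.
        x = torch.rand(n, device="cuda", dtype=torch.float32)
        y = torch.rand(n, device="cuda", dtype=torch.float32)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            z = x + y
        end.record()
        torch.cuda.synchronize()
        ms = start.elapsed_time(end) / iters      # average milliseconds per call
        bytes_moved = 3 * n * x.element_size()    # read x, read y, write z
        return bytes_moved * 1e-9 / (ms * 1e-3)   # GB/s

    if __name__ == "__main__":
        for n in (4096, 262144, 134217728):
            print(n, round(measured_gbps(n), 1))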
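The 03-matrix-multiplication table reports arithmetic throughput in TFLOPS rather than bandwidth. Assuming the square ``M = N = K`` problem sizes used by the sweep, converting a measured runtime into such a figure only needs the ``2*M*N*K`` operation count of a matrix product (a sketch, not the harness that produced the table)::

    def tflops(M: int, N: int, K: int, ms: float) -> float:
        # C = A @ B with A of shape (M, K) and B of shape (K, N)
        # performs 2*M*N*K floating-point operations.
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    # Example: M = N = K = 4096 finishing in ~1.5 ms is roughly 92 TFLOPS,
    # the same order as the 4096 row of the table above.
    print(tflops(4096, 4096, 4096, 1.5))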
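The 06-fused-attention hunks above replace the hard-coded ``BLOCK_DMODEL=64`` with the actual head dimension and derive ``num_warps`` from it, for both the forward and backward launches. The selection logic in isolation looks like the following sketch (the ``launch_config`` helper and its dictionary return value are illustrative, not part of the tutorial)::

    import torch

    def launch_config(q, k, v):
        # Mirror of the heuristic in the diff above: all three inputs must
        # share a head dimension the kernel is specialized for, and larger
        # heads get more warps per program.
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk == Lv, "q, k and v must share the same head dimension"
        assert Lk in {16, 32, 64, 128}, "unsupported head dimension"
        num_warps = 4 if Lk <= 64 else 8
        return {"BLOCK_DMODEL": Lk, "num_warps": num_warps}

    # Example: (batch=4, heads=48, seq=1024, head_dim=128) -> 8 warps.
    q = k = v = torch.empty(4, 48, 1024, 128)
    print(launch_config(q, k, v))  # {'BLOCK_DMODEL': 128, 'num_warps': 8}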