Files
triton/master/getting-started/tutorials/06-fused-attention.html
2022-09-12 00:51:39 +00:00

628 lines
95 KiB
HTML

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Fused Attention &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Libdevice function" href="07-libdevice-function.html" />
<link rel="prev" title="Layer Normalization" href="05-layer-norm.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Fused Attention</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/06-fused-attention.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p><a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-06-fused-attention-py"><span class="std std-ref">Jump to the download links</span></a>
to get the full example code.</p>
</div>
<div class="sphx-glr-example-title section" id="fused-attention">
<span id="sphx-glr-getting-started-tutorials-06-fused-attention-py"></span><h1>Fused Attention<a class="headerlink" href="#fused-attention" title="Permalink to this headline"></a></h1>
<p>This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., <a class="reference external" href="https://arxiv.org/pdf/2205.14135v2.pdf">https://arxiv.org/pdf/2205.14135v2.pdf</a>; Rabe and Staats, <a class="reference external" href="https://arxiv.org/pdf/2112.05682v2.pdf">https://arxiv.org/pdf/2112.05682v2.pdf</a>).</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">pytest</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_fwd_kernel</span><span class="p">(</span>
<span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span>
<span class="n">TMP</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="c1"># NOTE: TMP is a scratchpad buffer to work around a compiler bug</span>
<span class="n">Out</span><span class="p">,</span>
<span class="n">stride_qz</span><span class="p">,</span> <span class="n">stride_qh</span><span class="p">,</span> <span class="n">stride_qm</span><span class="p">,</span> <span class="n">stride_qk</span><span class="p">,</span>
<span class="n">stride_kz</span><span class="p">,</span> <span class="n">stride_kh</span><span class="p">,</span> <span class="n">stride_kn</span><span class="p">,</span> <span class="n">stride_kk</span><span class="p">,</span>
<span class="n">stride_vz</span><span class="p">,</span> <span class="n">stride_vh</span><span class="p">,</span> <span class="n">stride_vk</span><span class="p">,</span> <span class="n">stride_vn</span><span class="p">,</span>
<span class="n">stride_oz</span><span class="p">,</span> <span class="n">stride_oh</span><span class="p">,</span> <span class="n">stride_om</span><span class="p">,</span> <span class="n">stride_on</span><span class="p">,</span>
<span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">BLOCK_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">start_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">off_hz</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># initialize offsets</span>
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
<span class="n">offs_d</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
<span class="n">off_q</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span>
<span class="n">off_k</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span>
<span class="n">off_v</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_qh</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_d</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span>
<span class="c1"># Initialize pointers to Q, K, V</span>
<span class="n">q_ptrs</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">+</span> <span class="n">off_q</span>
<span class="n">k_ptrs</span> <span class="o">=</span> <span class="n">K</span> <span class="o">+</span> <span class="n">off_k</span>
<span class="n">v_ptrs</span> <span class="o">=</span> <span class="n">V</span> <span class="o">+</span> <span class="n">off_v</span>
<span class="c1"># initialize pointer to the TMP scratchpad, and the m_i, l_i and acc accumulators</span>
<span class="n">t_ptrs</span> <span class="o">=</span> <span class="n">TMP</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="n">m_i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;inf&quot;</span><span class="p">)</span>
<span class="n">l_i</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># load q: it will stay in SRAM throughout</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">q_ptrs</span><span class="p">)</span>
<span class="c1"># loop over k, v and update accumulator</span>
<span class="k">for</span> <span class="n">start_n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="p">(</span><span class="n">start_m</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">):</span>
<span class="n">start_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">multiple_of</span><span class="p">(</span><span class="n">start_n</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
<span class="c1"># -- compute qk ----</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">k_ptrs</span> <span class="o">+</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">stride_kn</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">*=</span> <span class="n">sm_scale</span>
<span class="n">qk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="p">(</span><span class="n">start_n</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]),</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;-inf&quot;</span><span class="p">))</span>
<span class="c1"># -- compute m_ij, p, l_ij</span>
<span class="n">m_ij</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">qk</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">qk</span> <span class="o">-</span> <span class="n">m_ij</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span>
<span class="n">l_ij</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># -- update m_i and l_i</span>
<span class="n">m_i_new</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">m_i</span><span class="p">,</span> <span class="n">m_ij</span><span class="p">)</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_i</span> <span class="o">-</span> <span class="n">m_i_new</span><span class="p">)</span>
<span class="n">beta</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_ij</span> <span class="o">-</span> <span class="n">m_i_new</span><span class="p">)</span>
<span class="n">l_i_new</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">l_i</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">l_ij</span>
<span class="c1"># -- update output accumulator --</span>
<span class="c1"># scale p</span>
<span class="n">p_scale</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">/</span> <span class="n">l_i_new</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="n">p_scale</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># scale acc</span>
<span class="n">acc_scale</span> <span class="o">=</span> <span class="n">l_i</span> <span class="o">/</span> <span class="n">l_i_new</span> <span class="o">*</span> <span class="n">alpha</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">t_ptrs</span><span class="p">,</span> <span class="n">acc_scale</span><span class="p">)</span>
<span class="n">acc_scale</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">t_ptrs</span><span class="p">)</span> <span class="c1"># BUG: have to store and immediately load</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">acc</span> <span class="o">*</span> <span class="n">acc_scale</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># update acc</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">v_ptrs</span> <span class="o">+</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">stride_vk</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">acc</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="c1"># update m_i and l_i</span>
<span class="n">l_i</span> <span class="o">=</span> <span class="n">l_i_new</span>
<span class="n">m_i</span> <span class="o">=</span> <span class="n">m_i_new</span>
<span class="c1"># rematerialize offsets to save registers</span>
<span class="n">start_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="c1"># write back l and m</span>
<span class="n">l_ptrs</span> <span class="o">=</span> <span class="n">L</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="n">m_ptrs</span> <span class="o">=</span> <span class="n">M</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">l_ptrs</span><span class="p">,</span> <span class="n">l_i</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">m_ptrs</span><span class="p">,</span> <span class="n">m_i</span><span class="p">)</span>
<span class="c1"># initialize pointers to output</span>
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
<span class="n">off_o</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">stride_oh</span> <span class="o">+</span> <span class="n">offs_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_om</span> <span class="o">+</span> <span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_on</span>
<span class="n">out_ptrs</span> <span class="o">=</span> <span class="n">Out</span> <span class="o">+</span> <span class="n">off_o</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">out_ptrs</span><span class="p">,</span> <span class="n">acc</span><span class="p">)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_bwd_preprocess</span><span class="p">(</span>
<span class="n">Out</span><span class="p">,</span> <span class="n">DO</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span>
<span class="n">NewDO</span><span class="p">,</span> <span class="n">Delta</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">off_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">off_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">)</span>
<span class="c1"># load</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Out</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DO</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">L</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># compute</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">do</span> <span class="o">/</span> <span class="n">denom</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">delta</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">o</span> <span class="o">*</span> <span class="n">do</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># write-back</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">NewDO</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">D_HEAD</span> <span class="o">+</span> <span class="n">off_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">do</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Delta</span> <span class="o">+</span> <span class="n">off_m</span><span class="p">,</span> <span class="n">delta</span><span class="p">)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_bwd_kernel</span><span class="p">(</span>
<span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span> <span class="n">Out</span><span class="p">,</span> <span class="n">DO</span><span class="p">,</span>
<span class="n">DQ</span><span class="p">,</span> <span class="n">DK</span><span class="p">,</span> <span class="n">DV</span><span class="p">,</span>
<span class="n">L</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span>
<span class="n">D</span><span class="p">,</span>
<span class="n">stride_qz</span><span class="p">,</span> <span class="n">stride_qh</span><span class="p">,</span> <span class="n">stride_qm</span><span class="p">,</span> <span class="n">stride_qk</span><span class="p">,</span>
<span class="n">stride_kz</span><span class="p">,</span> <span class="n">stride_kh</span><span class="p">,</span> <span class="n">stride_kn</span><span class="p">,</span> <span class="n">stride_kk</span><span class="p">,</span>
<span class="n">stride_vz</span><span class="p">,</span> <span class="n">stride_vh</span><span class="p">,</span> <span class="n">stride_vk</span><span class="p">,</span> <span class="n">stride_vn</span><span class="p">,</span>
<span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span>
<span class="n">num_block</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">BLOCK_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">off_hz</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">off_z</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">//</span> <span class="n">H</span>
<span class="n">off_h</span> <span class="o">=</span> <span class="n">off_hz</span> <span class="o">%</span> <span class="n">H</span>
<span class="c1"># offset pointers for batch/head</span>
<span class="n">Q</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">K</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">V</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DO</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DQ</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DK</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="n">DV</span> <span class="o">+=</span> <span class="n">off_z</span> <span class="o">*</span> <span class="n">stride_qz</span> <span class="o">+</span> <span class="n">off_h</span> <span class="o">*</span> <span class="n">stride_qh</span>
<span class="k">for</span> <span class="n">start_n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_block</span><span class="p">):</span>
<span class="n">lo</span> <span class="o">=</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">BLOCK_M</span>
<span class="c1"># initialize row/col offsets</span>
<span class="n">offs_qm</span> <span class="o">=</span> <span class="n">lo</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">offs_n</span> <span class="o">=</span> <span class="n">start_n</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
<span class="n">offs_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">)</span>
<span class="c1"># initialize pointers to value-like data</span>
<span class="n">q_ptrs</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">k_ptrs</span> <span class="o">=</span> <span class="n">K</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span><span class="p">)</span>
<span class="n">v_ptrs</span> <span class="o">=</span> <span class="n">V</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">do_ptrs</span> <span class="o">=</span> <span class="n">DO</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">dq_ptrs</span> <span class="o">=</span> <span class="n">DQ</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_qm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="c1"># pointer to row-wise quantities in value-like data</span>
<span class="n">D_ptrs</span> <span class="o">=</span> <span class="n">D</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span>
<span class="n">m_ptrs</span> <span class="o">=</span> <span class="n">M</span> <span class="o">+</span> <span class="n">off_hz</span> <span class="o">*</span> <span class="n">N_CTX</span>
<span class="c1"># initialize dv and dk</span>
<span class="n">dv</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_DMODEL</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># k and v stay in SRAM throughout</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">k_ptrs</span><span class="p">)</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">v_ptrs</span><span class="p">)</span>
<span class="c1"># loop over rows</span>
<span class="k">for</span> <span class="n">start_m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">lo</span><span class="p">,</span> <span class="n">num_block</span> <span class="o">*</span> <span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">):</span>
<span class="n">offs_m_curr</span> <span class="o">=</span> <span class="n">start_m</span> <span class="o">+</span> <span class="n">offs_m</span>
<span class="c1"># load q, k, v, do on-chip</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">q_ptrs</span><span class="p">)</span>
<span class="c1"># recompute p = softmax(qk, dim=-1).T</span>
<span class="c1"># NOTE: `do` is pre-divided by `l`; no normalization here</span>
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">qk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">offs_m_curr</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]),</span> <span class="n">qk</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;-inf&quot;</span><span class="p">))</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">m_ptrs</span> <span class="o">+</span> <span class="n">offs_m_curr</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">qk</span> <span class="o">*</span> <span class="n">sm_scale</span> <span class="o">-</span> <span class="n">m</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])</span>
<span class="c1"># compute dv</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">do_ptrs</span><span class="p">)</span>
<span class="n">dv</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">do</span><span class="p">,</span> <span class="n">trans_a</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># compute dp = dot(do, v.T) - Di[:, None]</span>
<span class="n">Di</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">D_ptrs</span> <span class="o">+</span> <span class="n">offs_m_curr</span><span class="p">)</span>
<span class="n">dp</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="n">Di</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">dp</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">trans_b</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># compute ds = p * (dp - delta[:, None]) * sm_scale; delta (Di) is already subtracted in dp</span>
<span class="n">ds</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="n">dp</span> <span class="o">*</span> <span class="n">sm_scale</span>
<span class="c1"># compute dk = dot(ds.T, q)</span>
<span class="n">dk</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">q</span><span class="p">,</span> <span class="n">trans_a</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># compute dq</span>
<span class="n">dq</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">dq_ptrs</span><span class="p">,</span> <span class="n">eviction_policy</span><span class="o">=</span><span class="s2">&quot;evict_last&quot;</span><span class="p">)</span>
<span class="n">dq</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">),</span> <span class="n">k</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dq_ptrs</span><span class="p">,</span> <span class="n">dq</span><span class="p">,</span> <span class="n">eviction_policy</span><span class="o">=</span><span class="s2">&quot;evict_last&quot;</span><span class="p">)</span>
<span class="c1"># increment pointers</span>
<span class="n">dq_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
<span class="n">q_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
<span class="n">do_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_M</span> <span class="o">*</span> <span class="n">stride_qm</span>
<span class="c1"># write-back</span>
<span class="n">dv_ptrs</span> <span class="o">=</span> <span class="n">DV</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_qm</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_qk</span><span class="p">)</span>
<span class="n">dk_ptrs</span> <span class="o">=</span> <span class="n">DK</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_n</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_kn</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_kk</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dv_ptrs</span><span class="p">,</span> <span class="n">dv</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">dk_ptrs</span><span class="p">,</span> <span class="n">dk</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">_attention</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">):</span>
<span class="n">BLOCK</span> <span class="o">=</span> <span class="mi">128</span>
<span class="c1"># shape constraints</span>
<span class="n">Lq</span><span class="p">,</span> <span class="n">Lk</span><span class="p">,</span> <span class="n">Lv</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">assert</span> <span class="n">Lq</span> <span class="o">==</span> <span class="n">Lk</span> <span class="ow">and</span> <span class="n">Lk</span> <span class="o">==</span> <span class="n">Lv</span>
<span class="k">assert</span> <span class="n">Lk</span> <span class="ow">in</span> <span class="p">{</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">}</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
<span class="n">grid</span> <span class="o">=</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">BLOCK</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">tmp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">L</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="n">device</span><span class="o">=</span><span class="n">q</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span> <span class="k">if</span> <span class="n">Lk</span> <span class="o">&lt;=</span> <span class="mi">64</span> <span class="k">else</span> <span class="mi">8</span>
<span class="n">_fwd_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">,</span>
<span class="n">tmp</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span>
<span class="n">o</span><span class="p">,</span>
<span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">o</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">,</span>
<span class="n">BLOCK_DMODEL</span><span class="o">=</span><span class="n">Lk</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">,</span>
<span class="n">num_stages</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span> <span class="o">=</span> <span class="n">BLOCK</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span> <span class="o">=</span> <span class="n">sm_scale</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span> <span class="o">=</span> <span class="n">Lk</span>
<span class="k">return</span> <span class="n">o</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">do</span><span class="p">):</span>
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span> <span class="n">m</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">do</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">dq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dk</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="n">dv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
<span class="n">do_scaled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">do</span><span class="p">)</span>
<span class="n">delta</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">l</span><span class="p">)</span>
<span class="n">_bwd_preprocess</span><span class="p">[(</span><span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">)](</span>
<span class="n">o</span><span class="p">,</span> <span class="n">do</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span>
<span class="n">do_scaled</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span>
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span> <span class="k">if</span> <span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span> <span class="o">&lt;=</span> <span class="mi">64</span> <span class="k">else</span> <span class="mi">8</span>
<span class="n">_bwd_kernel</span><span class="p">[(</span><span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">1</span><span class="p">],)](</span>
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">.</span><span class="n">sm_scale</span><span class="p">,</span>
<span class="n">o</span><span class="p">,</span> <span class="n">do_scaled</span><span class="p">,</span>
<span class="n">dq</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span>
<span class="n">l</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span>
<span class="n">delta</span><span class="p">,</span>
<span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">q</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">k</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">v</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span>
<span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">grid</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">BLOCK_M</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">,</span>
<span class="n">BLOCK_DMODEL</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_DMODEL</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">,</span>
<span class="n">num_stages</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">dq</span><span class="p">,</span> <span class="n">dk</span><span class="p">,</span> <span class="n">dv</span><span class="p">,</span> <span class="kc">None</span>
<span class="n">attention</span> <span class="o">=</span> <span class="n">_attention</span><span class="o">.</span><span class="n">apply</span>
<span class="nd">@pytest</span><span class="o">.</span><span class="n">mark</span><span class="o">.</span><span class="n">parametrize</span><span class="p">(</span><span class="s1">&#39;Z, H, N_CTX, D_HEAD&#39;</span><span class="p">,</span> <span class="p">[(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">64</span><span class="p">)])</span>
<span class="k">def</span> <span class="nf">test_op</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">):</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">Z</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
<span class="n">sm_scale</span> <span class="o">=</span> <span class="mf">0.3</span>
<span class="n">dout</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
<span class="c1"># reference implementation</span>
<span class="n">M</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">N_CTX</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">))</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="o">*</span> <span class="n">sm_scale</span>
<span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">Z</span><span class="p">):</span>
<span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">H</span><span class="p">):</span>
<span class="n">p</span><span class="p">[:,</span> <span class="p">:,</span> <span class="n">M</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;-inf&quot;</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">half</span><span class="p">()</span>
<span class="n">ref_out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="n">ref_out</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dout</span><span class="p">)</span>
<span class="n">ref_dv</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">ref_dk</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">ref_dq</span><span class="p">,</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="c1"># triton implementation</span>
<span class="n">tri_out</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span>
<span class="n">tri_out</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dout</span><span class="p">)</span>
<span class="n">tri_dv</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">tri_dk</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="n">tri_dq</span><span class="p">,</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span>
<span class="c1"># compare</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_out</span><span class="p">,</span> <span class="n">tri_out</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dv</span><span class="p">,</span> <span class="n">tri_dv</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dk</span><span class="p">,</span> <span class="n">tri_dk</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">ref_dq</span><span class="p">,</span> <span class="n">tri_dq</span><span class="p">)</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">flash_attn.flash_attn_interface</span> <span class="kn">import</span> <span class="n">flash_attn_func</span>
<span class="n">HAS_FLASH</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">except</span> <span class="ne">BaseException</span><span class="p">:</span>
<span class="n">HAS_FLASH</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">BATCH</span><span class="p">,</span> <span class="n">N_HEADS</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">48</span><span class="p">,</span> <span class="mi">4096</span><span class="p">,</span> <span class="mi">64</span>
<span class="c1"># vary seq length for fixed head and batch=4</span>
<span class="n">configs</span> <span class="o">=</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;N_CTX&#39;</span><span class="p">],</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">16</span><span class="p">)],</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;triton&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;flash&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_FLASH</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;Triton&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;Flash&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_FLASH</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;red&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">)],</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s1">&#39;ms&#39;</span><span class="p">,</span>
<span class="n">plot_name</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;fused-attention-batch</span><span class="si">{</span><span class="n">BATCH</span><span class="si">}</span><span class="s1">-head</span><span class="si">{</span><span class="n">N_HEADS</span><span class="si">}</span><span class="s1">-d</span><span class="si">{</span><span class="n">D_HEAD</span><span class="si">}</span><span class="s1">-</span><span class="si">{</span><span class="n">mode</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;H&#39;</span><span class="p">:</span> <span class="n">N_HEADS</span><span class="p">,</span> <span class="s1">&#39;BATCH&#39;</span><span class="p">:</span> <span class="n">BATCH</span><span class="p">,</span> <span class="s1">&#39;D_HEAD&#39;</span><span class="p">:</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">&#39;mode&#39;</span><span class="p">:</span> <span class="n">mode</span><span class="p">}</span>
<span class="p">)</span> <span class="k">for</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;bwd&#39;</span><span class="p">]]</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span><span class="n">configs</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">bench_flash_attention</span><span class="p">(</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">,</span> <span class="n">mode</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;fwd&#39;</span><span class="p">,</span> <span class="s1">&#39;bwd&#39;</span><span class="p">]</span>
<span class="n">warmup</span> <span class="o">=</span> <span class="mi">25</span>
<span class="n">rep</span> <span class="o">=</span> <span class="mi">100</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s2">&quot;triton&quot;</span><span class="p">:</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">sm_scale</span> <span class="o">=</span> <span class="mf">1.3</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;bwd&#39;</span><span class="p">:</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">percentiles</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ms</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s2">&quot;flash&quot;</span><span class="p">:</span>
<span class="n">lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,),</span> <span class="n">fill_value</span><span class="o">=</span><span class="n">N_CTX</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="n">cu_seqlens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BATCH</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">cu_seqlens</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">lengths</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">BATCH</span> <span class="o">*</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">D_HEAD</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">flash_attn_func</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">cu_seqlens</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">N_CTX</span><span class="p">,</span> <span class="n">causal</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;bwd&#39;</span><span class="p">:</span>
<span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">percentiles</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ms</span>
<span class="c1"># only works on A100 at the moment</span>
<span class="c1"># bench_flash_attention.run(save_path=&#39;.&#39;, print_data=True)</span>
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> (0 minutes 0.072 seconds)</p>
<!-- Sphinx-Gallery download links for this tutorial: the Python script and the Jupyter notebook. -->
<div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-06-fused-attention-py">
  <div class="sphx-glr-download sphx-glr-download-python docutils container">
    <p><a class="reference download internal" download="" href="../../_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">06-fused-attention.py</span></code></a></p>
  </div>
  <div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
    <p><a class="reference download internal" download="" href="../../_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">06-fused-attention.ipynb</span></code></a></p>
  </div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
<!-- Page footer: next/previous tutorial links, copyright notice, and build credits. -->
<footer>
  <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
    <a class="btn btn-neutral float-right" href="07-libdevice-function.html" title="Libdevice function" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
    <a class="btn btn-neutral float-left" href="05-layer-norm.html" title="Layer Normalization" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
  </div>

  <hr>

  <div role="contentinfo">
    <p>&#169; Copyright 2020, Philippe Tillet.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a>
  using a <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<!--
  Version switcher: shows the current docs version ("master") and links to the
  other published tags and branches.
  NOTE(review): the data-toggle attributes are presumably hooks for the theme's
  JavaScript; confirm against theme.js before renaming them.
-->
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
  <span class="rst-current-version" data-toggle="rst-current-version">
    <span class="fa fa-book"> Other Versions</span>
    v: master
    <span class="fa fa-caret-down"></span>
  </span>
  <div class="rst-other-versions">
    <dl>
      <dt>Tags</dt>
      <dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
    </dl>
    <dl>
      <dt>Branches</dt>
      <dd><a href="06-fused-attention.html">master</a></dd>
    </dl>
  </div>
</div>
<script>
  // Passing a function to jQuery() registers it to run once the DOM is ready;
  // at that point the Read the Docs theme's navigation is switched on.
  // NOTE(review): the `true` argument presumably turns on the sticky sidebar;
  // confirm against sphinx_rtd_theme's theme.js.
  jQuery(function () {
    SphinxRtdTheme.Navigation.enable(true);
  });
</script>
</body>
</html>