Files
triton/master/getting-started/tutorials/05-layer-norm.html
2022-04-12 00:41:56 +00:00

567 lines
70 KiB
HTML

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Layer Normalization &mdash; Triton documentation</title>
<!-- Stylesheets, each linked exactly once. pygments.css precedes theme.css, so
     theme rules still win equal-specificity conflicts between the two; the
     gallery and custom sheets come later and take precedence over both. -->
<link rel="stylesheet" href="../../_static/pygments.css" />
<link rel="stylesheet" href="../../_static/css/theme.css" />
<link rel="stylesheet" href="../../_static/gallery.css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" />
<link rel="stylesheet" href="../../_static/css/custom.css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<!-- documentation_options.js is included once: loading the same file twice
     also duplicated the id "documentation_options", which must be unique. -->
<script id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<!-- Script order matters: doctools.js is loaded after jquery.js and underscore.js. -->
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="triton" href="../../python-api/triton.html" />
<link rel="prev" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side" aria-label="Documentation sidebar">
<div class="wy-side-scroll">
<!-- Project home link and the documentation search form -->
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<!-- A placeholder is not a reliable label (it vanishes on input), so the
     query field also gets an accessible name through aria-label. -->
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<!-- Global table of contents; the entry for this page carries aria-current -->
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#" aria-current="page">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<!-- Narrow-screen top bar: a menu icon followed by a link to the project home.
     NOTE(review): the icon carries data-toggle="wy-nav-top", presumably wired
     up by theme.js as the menu toggle (confirm); it has no accessible name and
     cannot receive keyboard focus. Consider a real button once the theme CSS
     is adjusted to style it. -->
<nav class="wy-nav-top" aria-label="top navigation">
  <i class="fa fa-bars" data-toggle="wy-nav-top"></i>
  <a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<!-- The home crumb has no text content (icon classes only), so aria-label
     supplies its accessible name. -->
<li><a href="../../index.html" class="icon icon-home" aria-label="Home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<!-- Last crumb is the page being viewed -->
<li aria-current="page">Layer Normalization</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/05-layer-norm.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<!-- Note pointing to the in-page download anchor for the example script; the
     link text is descriptive on its own instead of "here". -->
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p><a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">Go to the download links for the full example code</span></a></p>
</div>
<div class="sphx-glr-example-title section" id="layer-normalization">
<span id="sphx-glr-getting-started-tutorials-05-layer-norm-py"></span><h1>Layer Normalization<a class="headerlink" href="#layer-normalization" title="Permalink to this headline"></a></h1>
<!-- Benchmark figure. The alt text describes what is plotted rather than
     repeating the file name; it is derived from the captured benchmark table
     that follows (layer-norm-backward, columns Triton/Torch/Apex by N) and
     should be confirmed against the figure itself. -->
<img alt="Plot of the layer-norm-backward benchmark comparing Triton, Torch and Apex across values of N" class="sphx-glr-single-img" src="../../_images/sphx_glr_05-layer-norm_001.png" />
<p class="sphx-glr-script-out">Out:</p>
<!-- Captured stdout of the example script, reproduced verbatim inside <pre>:
     a "layer-norm-backward" table with one row per N and one column each for
     Triton, Torch and Apex. Units are not stated in this output (presumably the
     benchmark's throughput metric; confirm against the script). Do not edit the
     numbers by hand; they come from the script's run. -->
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>layer-norm-backward:
N Triton Torch Apex
0 1024.0 361.411758 99.902435 315.076934
1 1536.0 405.098894 134.540150 344.523365
2 2048.0 491.520012 159.584422 323.368435
3 2560.0 458.507457 182.857144 325.079368
4 3072.0 519.211251 191.999993 320.556515
5 3584.0 554.941930 208.271186 310.527060
6 4096.0 568.231237 221.405403 301.546004
7 4608.0 502.690905 232.336141 287.251954
8 5120.0 527.381977 243.809526 286.433562
9 5632.0 542.843364 244.647957 291.939522
10 6144.0 550.208948 251.202731 287.438593
11 6656.0 532.479975 255.590406 286.793541
12 7168.0 510.480705 253.734520 277.470965
13 7680.0 487.619051 266.358392 284.884090
14 8192.0 468.114289 258.354805 278.481578
15 8704.0 415.300208 267.815384 285.377055
16 9216.0 431.157889 272.394084 289.887291
17 9728.0 438.033784 280.278512 289.308559
18 10240.0 442.810829 287.438599 290.496460
19 10752.0 426.525614 246.935876 289.941565
20 11264.0 427.071098 245.536784 286.069848
21 11776.0 419.323436 249.447482 288.686414
22 12288.0 415.954875 254.893699 294.617366
23 12800.0 411.244989 254.094291 289.538159
24 13312.0 410.652963 252.559690 289.391298
25 13824.0 404.604870 257.190689 292.056329
26 14336.0 396.387109 256.190622 289.129416
27 14848.0 386.080180 257.665934 288.777966
28 15360.0 379.649845 258.513318 286.656296
29 15872.0 372.000001 261.806182 290.341468
</pre></div>
</div>
<!-- Empty line block: a single <br /> used only as vertical spacing between
     the captured output and the code listing that follows. -->
<div class="line-block">
<div class="line"><br /></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="k">try</span><span class="p">:</span>
<span class="c1"># This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it</span>
<span class="c1"># should not be added to extras_require in setup.py.</span>
<span class="kn">import</span> <span class="nn">apex</span>
<span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">except</span> <span class="ne">ModuleNotFoundError</span><span class="p">:</span>
<span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># Forward Pass</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_fwd_fused</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="c1"># position of elements processed by this program</span>
<span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span>
<span class="c1"># offset data pointers to start at the row of interest</span>
<span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">Y</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="c1"># load data and cast to float32</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># compute mean</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="c1"># compute std</span>
<span class="n">xmean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">var</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xmean</span> <span class="o">*</span> <span class="n">xmean</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">tl</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="n">xmean</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="c1"># write-back mean/rstd</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">mean</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
<span class="c1"># multiply by weight and add bias</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">xhat</span> <span class="o">*</span> <span class="n">w</span> <span class="o">+</span> <span class="n">b</span>
<span class="c1"># write-back</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># Backward pass (DX + partial DW + partial DB)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dx_fused</span><span class="p">(</span><span class="n">DX</span><span class="p">,</span> <span class="n">DY</span><span class="p">,</span> <span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">Lock</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="c1"># position of elements processed by this program</span>
<span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span>
<span class="c1"># offset data pointers to start at the row of interest</span>
<span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">DY</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">DX</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="c1"># offset locks and weight/bias gradient pointer</span>
<span class="c1"># each kernel instance accumulates partial sums for</span>
<span class="c1"># DW and DB into one of GROUP_SIZE_M independent buffers</span>
<span class="c1"># these buffers stay in the L2, which allow this kernel</span>
<span class="c1"># to be fast</span>
<span class="n">lock_id</span> <span class="o">=</span> <span class="n">row</span> <span class="o">%</span> <span class="n">GROUP_SIZE_M</span>
<span class="n">Lock</span> <span class="o">+=</span> <span class="n">lock_id</span>
<span class="n">Count</span> <span class="o">=</span> <span class="n">Lock</span> <span class="o">+</span> <span class="n">GROUP_SIZE_M</span>
<span class="n">DW</span> <span class="o">=</span> <span class="n">DW</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
<span class="n">DB</span> <span class="o">=</span> <span class="n">DB</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
<span class="c1"># load data to SRAM</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DY</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
<span class="c1"># compute dx</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="n">wdy</span> <span class="o">=</span> <span class="n">w</span> <span class="o">*</span> <span class="n">dy</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">xhat</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">wdy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">wdy</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">mean1</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">mean2</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">dx</span> <span class="o">=</span> <span class="p">(</span><span class="n">wdy</span> <span class="o">-</span> <span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">mean1</span> <span class="o">+</span> <span class="n">mean2</span><span class="p">))</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="c1"># write-back dx</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DX</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">dx</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># accumulate partial sums for dw/db</span>
<span class="n">partial_dw</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span> <span class="o">*</span> <span class="n">xhat</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">partial_db</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">while</span> <span class="n">tl</span><span class="o">.</span><span class="n">atomic_cas</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">pass</span>
<span class="n">count</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Count</span><span class="p">)</span>
<span class="c1"># first store doesn&#39;t accumulate</span>
<span class="k">if</span> <span class="n">count</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Count</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">partial_dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">partial_db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">partial_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">partial_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># release lock</span>
<span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># Backward pass (total DW + total DB)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dwdb</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">FINAL_DW</span><span class="p">,</span> <span class="n">FINAL_DB</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
<span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">sum_dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dw</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">sum_db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DW</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DB</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="c1"># autograd.Function wrapping the fused Triton kernels: x is normalized over its last dimension</span>
<span class="k">class</span> <span class="nc">LayerNorm</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
<span class="c1"># normalized_shape is accepted for parity with torch.nn.functional.layer_norm but is not read:</span>
<span class="c1"># only the last dimension of x is normalized</span>
<span class="c1"># allocate output</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># reshape input data into 2D tensor</span>
<span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># per-row statistics, allocated on the same device as x rather than the default CUDA device</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Up to 64KB per row (N * element_size bytes): enqueue fused kernel</span>
<span class="n">MAX_FUSED_SIZE</span> <span class="o">=</span> <span class="mi">65536</span> <span class="o">//</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">MAX_FUSED_SIZE</span><span class="p">,</span> <span class="n">triton</span><span class="o">.</span><span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">))</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&gt;</span> <span class="n">BLOCK_SIZE</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;This layer norm doesn&#39;t support feature dim &gt; 64KB.&quot;</span><span class="p">)</span>
<span class="c1"># heuristics for number of warps</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="n">BLOCK_SIZE</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">8</span><span class="p">)</span>
<span class="c1"># enqueue kernel</span>
<span class="n">_layer_norm_fwd_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">x_arg</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">,</span>
<span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="n">BLOCK_SIZE</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">)</span>
<span class="c1"># keep what backward needs: inputs, per-row mean/rstd and the launch configuration</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">BLOCK_SIZE</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span> <span class="o">=</span> <span class="n">num_warps</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="k">return</span> <span class="n">y</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dy</span><span class="p">):</span>
<span class="c1"># returns one gradient per forward input: (x, normalized_shape, weight, bias, eps)</span>
<span class="c1"># m and v are the per-row mean and rstd saved by forward</span>
<span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
<span class="c1"># heuristics for amount of parallel reduction stream for DW/DB</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">64</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">8192</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">96</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">128</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">1024</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">256</span>
<span class="c1"># allocate outputs on the device of w: zeroed int32 locks for the dx kernel,</span>
<span class="c1"># per-group DW/DB partial sums, final DW/DB, and DX</span>
<span class="n">locks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">_dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">_db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">dy</span><span class="p">)</span>
<span class="c1"># enqueue kernel using forward pass heuristics</span>
<span class="c1"># also compute partial sums for DW and DB</span>
<span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># pass the same 2D tensor whose row stride is supplied (x_arg), as forward does</span>
<span class="n">_layer_norm_bwd_dx_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">dx</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">x_arg</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">locks</span><span class="p">,</span>
<span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">ctx</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="o">=</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span>
<span class="n">num_warps</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span><span class="p">)</span>
<span class="c1"># one program per BLOCK_SIZE_N columns of the N features</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">])]</span>
<span class="c1"># accumulate partial sums in separate kernel</span>
<span class="n">_layer_norm_bwd_dwdb</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span>
<span class="k">return</span> <span class="n">dx</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="kc">None</span>
<span class="c1"># functional entry point, same argument order as torch.nn.functional.layer_norm</span>
<span class="n">layer_norm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="o">.</span><span class="n">apply</span>
<span class="c1"># Check the Triton layer norm (forward output and dx, dw, db) against torch.nn.functional.layer_norm</span>
<span class="k">def</span> <span class="nf">test_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
<span class="c1"># create data on the requested device</span>
<span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># enable grad after the arithmetic above so x is a leaf whose .grad gets populated</span>
<span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># forward pass</span>
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># backward pass (triton)</span>
<span class="n">y_tri</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">dx_tri</span><span class="p">,</span> <span class="n">dw_tri</span><span class="p">,</span> <span class="n">db_tri</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
<span class="c1"># reset grads so the reference backward does not accumulate onto the Triton results</span>
<span class="n">x</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="c1"># backward pass (torch)</span>
<span class="n">y_ref</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">dx_ref</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">db_ref</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
<span class="c1"># compare; dw and db are reductions over all M rows and use a looser tolerance (decimal=1)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dx_tri</span><span class="p">,</span> <span class="n">dx_ref</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">db_tri</span><span class="p">,</span> <span class="n">db_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dw_tri</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Benchmark GB/s of the Triton, torch and (when installed) apex layer norms in backward mode:</span>
<span class="c1"># M=4096 fp16 rows, N from 1024 to 15872 in steps of 512</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;N&#39;</span><span class="p">],</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">512</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">32</span><span class="p">)],</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="s1">&#39;torch&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;apex&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;Triton&#39;</span><span class="p">,</span> <span class="s1">&#39;Torch&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;Apex&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;orange&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">)],</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s1">&#39;GB/s&#39;</span><span class="p">,</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s1">&#39;layer-norm-backward&#39;</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;M&#39;</span><span class="p">:</span> <span class="mi">4096</span><span class="p">,</span> <span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">&#39;mode&#39;</span><span class="p">:</span> <span class="s1">&#39;backward&#39;</span><span class="p">}</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">bench_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;backward&#39;</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
<span class="c1"># create data on the requested device</span>
<span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># y_fwd: zero-argument closure running the forward pass of the selected provider</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch&#39;</span><span class="p">:</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;apex&#39;</span><span class="p">:</span>
<span class="c1"># apex uses the parameters of its own module instead of the weight/bias tensors above</span>
<span class="n">apex_layer_norm</span> <span class="o">=</span> <span class="n">apex</span><span class="o">.</span><span class="n">normalization</span><span class="o">.</span><span class="n">FusedLayerNorm</span><span class="p">(</span><span class="n">w_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">apex_layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># forward pass: bytes moved approximated as 2 * x.numel() * element_size</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;forward&#39;</span><span class="p">:</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">y_fwd</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="c1"># backward pass: 3 * x.numel() * element_size bytes; the graph is built once and retained across repetitions</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;backward&#39;</span><span class="p">:</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y_fwd</span><span class="p">()</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">y</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="n">grad_to_none</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">],</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="c1"># the slowest time (max_ms) maps to the lowest bandwidth, the fastest (min_ms) to the highest</span>
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
<span class="c1"># run the sweep; print_data=True prints the measurements, save_path=&#39;.&#39; presumably selects the output directory</span>
<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> (2 minutes 12.229 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-05-layer-norm-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">05-layer-norm.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">05-layer-norm.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="05-layer-norm.html">master</a></dd>
</dl>
</div>
</div>
<script>
// Runs once the DOM is ready (jQuery(fn) is shorthand for $(document).ready(fn))
// and turns on the Read the Docs theme's sidebar navigation from theme.js.
// NOTE(review): the `true` argument presumably enables sticky navigation —
// confirm against the enable() signature in _static/js/theme.js.
jQuery(function () {
  SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>