<span id="sphx-glr-getting-started-tutorials-05-layer-norm-py"></span><h1>Layer Normalization<a class="headerlink" href="#layer-normalization" title="Permalink to this headline">¶</a></h1>
[Plot: 05 layer norm (sphx_glr_05-layer-norm_001.png) — layer-norm-backward performance, Triton vs. Torch vs. Apex]

Out:

layer-norm-backward:
          N      Triton       Torch        Apex
0    1024.0  356.173905   99.902435  315.076934
1    1536.0  405.098894  134.050910  344.523365
2    2048.0  491.520012  159.067963  334.367350
3    2560.0  458.507457  182.314537  330.322572
4    3072.0  519.211251  191.501303  321.956335
5    3584.0  554.941930  207.768111  309.410081
6    4096.0  568.231237  220.907859  299.707322
7    4608.0  502.690905  232.336141  287.251954
8    5120.0  527.381977  243.809526  286.433562
9    5632.0  540.671974  244.426754  291.939522
10   6144.0  550.208948  251.202731  288.000001
11   6656.0  530.710976  255.590406  286.793541
12   7168.0  510.480705  253.734520  277.470965
13   7680.0  487.619051  266.358392  284.884090
14   8192.0  468.114289  258.354805  278.481578
15   8704.0  414.476194  267.472468  285.377055
16   9216.0  431.157889  272.394084  289.887291
17   9728.0  438.033784  279.942444  288.950501
18  10240.0  442.810829  287.102804  290.153487
19  10752.0  427.231788  246.699797  289.941565
20  11264.0  427.071098  245.760001  286.069848
21  11776.0  419.946507  249.447482  288.686414
22  12288.0  415.369018  254.673582  294.617366
23  12800.0  410.695192  253.884294  287.910035
24  13312.0  410.125805  252.559690  289.129403
25  13824.0  404.112047  257.190689  292.056329
26  14336.0  396.844280  256.000002  289.129416
27  14848.0  385.662341  257.479779  288.777966
28  15360.0  378.869469  258.332158  288.225185
29  15872.0  372.000001  261.806182  290.562936
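For reference, this is the math that the fused forward kernel below implements (restated from the code; $\epsilon$ is the usual numerical-stability constant). For each row $x \in \mathbb{R}^N$ with learnable weight $w$ and bias $b$:

$$
\mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
\sigma^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu)^2, \qquad
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
y_i = w_i \hat{x}_i + b_i.
$$

The kernel also stores $\mu$ (``mean``) and $1/\sqrt{\sigma^2 + \epsilon}$ (``rstd``) for each row so that the backward pass can reuse them.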
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
 | 
						|
 | 
						|
<span class="kn">import</span> <span class="nn">triton</span>
 | 
						|
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
 | 
						|
 | 
						|
<span class="k">try</span><span class="p">:</span>
 | 
						|
    <span class="c1"># This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it</span>
 | 
						|
    <span class="c1"># should not be added to extras_require in setup.py.</span>
 | 
						|
    <span class="kn">import</span> <span class="nn">apex</span>
 | 
						|
    <span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">True</span>
 | 
						|
<span class="k">except</span> <span class="ne">ModuleNotFoundError</span><span class="p">:</span>
 | 
						|
    <span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">False</span>
 | 
						|
 | 
						|
 | 
						|
<span class="c1"># Forward Pass</span>
 | 
						|
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
 | 
						|
<span class="k">def</span> <span class="nf">_layer_norm_fwd_fused</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
 | 
						|
                          <span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
 | 
						|
    <span class="c1"># position of elements processed by this program</span>
 | 
						|
    <span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
 | 
						|
    <span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
 | 
						|
    <span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o"><</span> <span class="n">N</span>
 | 
						|
    <span class="c1"># offset data pointers to start at the row of interest</span>
 | 
						|
    <span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
 | 
						|
    <span class="n">Y</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
 | 
						|
    <span class="c1"># load data and cast to float32</span>
 | 
						|
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
 | 
						|
    <span class="c1"># compute mean</span>
 | 
						|
    <span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
 | 
						|
    <span class="c1"># compute std</span>
 | 
						|
    <span class="n">xmean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
 | 
						|
    <span class="n">var</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xmean</span> <span class="o">*</span> <span class="n">xmean</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
 | 
						|
    <span class="n">rstd</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">tl</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
 | 
						|
    <span class="n">xhat</span> <span class="o">=</span> <span class="n">xmean</span> <span class="o">*</span> <span class="n">rstd</span>
 | 
						|
    <span class="c1"># write-back mean/rstd</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">mean</span><span class="p">)</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
 | 
						|
    <span class="c1"># multiply by weight and add bias</span>
 | 
						|
    <span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
    <span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
    <span class="n">y</span> <span class="o">=</span> <span class="n">xhat</span> <span class="o">*</span> <span class="n">w</span> <span class="o">+</span> <span class="n">b</span>
 | 
						|
    <span class="c1"># write-back</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
 | 
						|
 | 
						|
<span class="c1"># Backward pass (DX + partial DW + partial DB)</span>
 | 
						|
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
 | 
						|
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dx_fused</span><span class="p">(</span><span class="n">DX</span><span class="p">,</span> <span class="n">DY</span><span class="p">,</span> <span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">Lock</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
 | 
						|
                             <span class="n">GROUP_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
 | 
						|
    <span class="c1"># position of elements processed by this program</span>
 | 
						|
    <span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
 | 
						|
    <span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
 | 
						|
    <span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o"><</span> <span class="n">N</span>
 | 
						|
    <span class="c1"># offset data pointers to start at the row of interest</span>
 | 
						|
    <span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
 | 
						|
    <span class="n">DY</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
 | 
						|
    <span class="n">DX</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
 | 
						|
    <span class="c1"># offset locks and weight/bias gradient pointer</span>
 | 
						|
    <span class="c1"># each kernel instance accumulates partial sums for</span>
 | 
						|
    <span class="c1"># DW and DB into one of GROUP_SIZE_M independent buffers</span>
 | 
						|
    <span class="c1"># these buffers stay in the L2, which allow this kernel</span>
 | 
						|
    <span class="c1"># to be fast</span>
 | 
						|
    <span class="n">lock_id</span> <span class="o">=</span> <span class="n">row</span> <span class="o">%</span> <span class="n">GROUP_SIZE_M</span>
 | 
						|
    <span class="n">Lock</span> <span class="o">+=</span> <span class="n">lock_id</span>
 | 
						|
    <span class="n">Count</span> <span class="o">=</span> <span class="n">Lock</span> <span class="o">+</span> <span class="n">GROUP_SIZE_M</span>
 | 
						|
    <span class="n">DW</span> <span class="o">=</span> <span class="n">DW</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
 | 
						|
    <span class="n">DB</span> <span class="o">=</span> <span class="n">DB</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
 | 
						|
    <span class="c1"># load data to SRAM</span>
 | 
						|
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
 | 
						|
    <span class="n">dy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DY</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
 | 
						|
    <span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
 | 
						|
    <span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
 | 
						|
    <span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
 | 
						|
    <span class="c1"># compute dx</span>
 | 
						|
    <span class="n">xhat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">*</span> <span class="n">rstd</span>
 | 
						|
    <span class="n">wdy</span> <span class="o">=</span> <span class="n">w</span> <span class="o">*</span> <span class="n">dy</span>
 | 
						|
    <span class="n">xhat</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">xhat</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
 | 
						|
    <span class="n">wdy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">wdy</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
 | 
						|
    <span class="n">mean1</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
 | 
						|
    <span class="n">mean2</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
 | 
						|
    <span class="n">dx</span> <span class="o">=</span> <span class="p">(</span><span class="n">wdy</span> <span class="o">-</span> <span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">mean1</span> <span class="o">+</span> <span class="n">mean2</span><span class="p">))</span> <span class="o">*</span> <span class="n">rstd</span>
 | 
						|
    <span class="c1"># write-back dx</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DX</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">dx</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
    <span class="c1"># accumulate partial sums for dw/db</span>
 | 
						|
    <span class="n">partial_dw</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span> <span class="o">*</span> <span class="n">xhat</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
 | 
						|
    <span class="n">partial_db</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
 | 
						|
    <span class="k">while</span> <span class="n">tl</span><span class="o">.</span><span class="n">atomic_cas</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
 | 
						|
        <span class="k">pass</span>
 | 
						|
    <span class="n">count</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Count</span><span class="p">)</span>
 | 
						|
    <span class="c1"># first store doesn't accumulate</span>
 | 
						|
    <span class="k">if</span> <span class="n">count</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
 | 
						|
        <span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Count</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
 | 
						|
    <span class="k">else</span><span class="p">:</span>
 | 
						|
        <span class="n">partial_dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
        <span class="n">partial_db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">partial_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">partial_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						|
    <span class="c1"># release lock</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
 | 
						|
 | 
						|
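
# Note on the dx formula above (a restatement derived from the kernel, not
# original tutorial text): with xhat = (x - mean) * rstd and wdy = w * dy,
# differentiating y = w * xhat + b through mean and rstd gives, per row,
#     dx = (wdy - (xhat * mean(xhat * wdy) + mean(wdy))) * rstd
# which is exactly what mean1 and mean2 compute.
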
<span class="c1"># Backward pass (total DW + total DB)</span>
 | 
						|
 | 
						|
 | 
						|
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
 | 
						|
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dwdb</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">FINAL_DW</span><span class="p">,</span> <span class="n">FINAL_DB</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
 | 
						|
                         <span class="n">BLOCK_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
 | 
						|
    <span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
 | 
						|
    <span class="n">cols</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
 | 
						|
    <span class="n">dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
 | 
						|
    <span class="n">db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
 | 
						|
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
 | 
						|
        <span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
 | 
						|
        <span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
 | 
						|
        <span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
 | 
						|
        <span class="n">dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
 | 
						|
        <span class="n">db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
 | 
						|
    <span class="n">sum_dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dw</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 | 
						|
    <span class="n">sum_db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DW</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
 | 
						|
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DB</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
 | 
						|
 | 
						|
 | 
						|
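
# Putting the two backward kernels together (derived from the code above):
# for each column j they compute
#     dw[j] = sum_i dy[i, j] * xhat[i, j]        db[j] = sum_i dy[i, j]
# with the row sum first split across GROUP_SIZE_M partial buffers (guarded
# by the locks) and then reduced by _layer_norm_bwd_dwdb.
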
<span class="k">class</span> <span class="nc">LayerNorm</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
 | 
						|
 | 
						|
    <span class="nd">@staticmethod</span>
 | 
						|
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
 | 
						|
        <span class="c1"># allocate output</span>
 | 
						|
        <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
 | 
						|
        <span class="c1"># reshape input data into 2D tensor</span>
 | 
						|
        <span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
 | 
						|
        <span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
 | 
						|
        <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
 | 
						|
        <span class="n">rstd</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
 | 
						|
        <span class="c1"># Less than 64KB per feature: enqueue fused kernel</span>
 | 
						|
        <span class="n">MAX_FUSED_SIZE</span> <span class="o">=</span> <span class="mi">65536</span> <span class="o">//</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span>
 | 
						|
        <span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">MAX_FUSED_SIZE</span><span class="p">,</span> <span class="n">triton</span><span class="o">.</span><span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">))</span>
 | 
						|
        <span class="k">if</span> <span class="n">N</span> <span class="o">></span> <span class="n">BLOCK_SIZE</span><span class="p">:</span>
 | 
						|
            <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"This layer norm doesn't support feature dim >= 64KB."</span><span class="p">)</span>
 | 
						|
        <span class="c1"># heuristics for number of warps</span>
 | 
						|
        <span class="n">num_warps</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="n">BLOCK_SIZE</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">8</span><span class="p">)</span>
 | 
						|
        <span class="c1"># enqueue kernel</span>
 | 
						|
        <span class="n">_layer_norm_fwd_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">x_arg</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">,</span>
 | 
						|
                                    <span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
 | 
						|
                                    <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="n">BLOCK_SIZE</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">)</span>
 | 
						|
        <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
 | 
						|
        <span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">BLOCK_SIZE</span>
 | 
						|
        <span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span> <span class="o">=</span> <span class="n">num_warps</span>
 | 
						|
        <span class="n">ctx</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
 | 
						|
        <span class="k">return</span> <span class="n">y</span>
 | 
						|
 | 
						|
    <span class="nd">@staticmethod</span>
 | 
						|
    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dy</span><span class="p">):</span>
 | 
						|
        <span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
 | 
						|
        <span class="c1"># heuristics for amount of parallel reduction stream for DG/DB</span>
 | 
						|
        <span class="n">N</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
 | 
						|
        <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">64</span>
 | 
						|
        <span class="k">if</span> <span class="n">N</span> <span class="o"><=</span> <span class="mi">8192</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">96</span>
 | 
						|
        <span class="k">if</span> <span class="n">N</span> <span class="o"><=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">128</span>
 | 
						|
        <span class="k">if</span> <span class="n">N</span> <span class="o"><=</span> <span class="mi">1024</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">256</span>
 | 
						|
        <span class="c1"># allocate output</span>
 | 
						|
        <span class="n">locks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
 | 
						|
        <span class="n">_dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
 | 
						|
        <span class="n">_db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
 | 
						|
        <span class="n">dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
 | 
						|
        <span class="n">db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
 | 
						|
        <span class="n">dx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">dy</span><span class="p">)</span>
 | 
						|
        <span class="c1"># enqueue kernel using forward pass heuristics</span>
 | 
						|
        <span class="c1"># also compute partial sums for DW and DB</span>
 | 
						|
        <span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
 | 
						|
        <span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
 | 
						|
        <span class="n">_layer_norm_bwd_dx_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">dx</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">locks</span><span class="p">,</span>
 | 
						|
                                       <span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">ctx</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
 | 
						|
                                       <span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
 | 
						|
                                       <span class="n">GROUP_SIZE_M</span><span class="o">=</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span>
 | 
						|
                                       <span class="n">num_warps</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span><span class="p">)</span>
 | 
						|
        <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">])]</span>
 | 
						|
        <span class="c1"># accumulate partial sums in separate kernel</span>
 | 
						|
        <span class="n">_layer_norm_bwd_dwdb</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
 | 
						|
                                   <span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
 | 
						|
                                   <span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span>
 | 
						|
        <span class="k">return</span> <span class="n">dx</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="kc">None</span>
 | 
						|
 | 
						|
 | 
						|
<span class="n">layer_norm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="o">.</span><span class="n">apply</span>
 | 
						|
 | 
						|
 | 
						|
<span class="k">def</span> <span class="nf">test_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">):</span>
 | 
						|
    <span class="c1"># create data</span>
 | 
						|
    <span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
 | 
						|
    <span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
 | 
						|
    <span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
 | 
						|
    <span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
 | 
						|
    <span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
 | 
						|
    <span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
 | 
						|
    <span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
 | 
						|
    <span class="c1"># forward pass</span>
 | 
						|
    <span class="n">y_tri</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
    <span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
    <span class="c1"># backward pass (triton)</span>
    <span class="n">y_tri</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">dx_tri</span><span class="p">,</span> <span class="n">dw_tri</span><span class="p">,</span> <span class="n">db_tri</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
    <span class="n">x</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
    <span class="c1"># backward pass (torch)</span>
    <span class="n">y_ref</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">dx_ref</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">db_ref</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
    <span class="c1"># compare</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">)</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dx_tri</span><span class="p">,</span> <span class="n">dx_ref</span><span class="p">)</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">db_tri</span><span class="p">,</span> <span class="n">db_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dw_tri</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>


<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
        <span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'N'</span><span class="p">],</span>
        <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">512</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">32</span><span class="p">)],</span>
        <span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span>
        <span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'torch'</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">'apex'</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
        <span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'Triton'</span><span class="p">,</span> <span class="s1">'Torch'</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">'Apex'</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
        <span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'orange'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">)],</span>
        <span class="n">ylabel</span><span class="o">=</span><span class="s1">'GB/s'</span><span class="p">,</span>
        <span class="n">plot_name</span><span class="o">=</span><span class="s1">'layer-norm-backward'</span><span class="p">,</span>
        <span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">'M'</span><span class="p">:</span> <span class="mi">4096</span><span class="p">,</span> <span class="s1">'dtype'</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">'mode'</span><span class="p">:</span> <span class="s1">'backward'</span><span class="p">}</span>
    <span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">bench_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">'backward'</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">):</span>
    <span class="c1"># create data</span>
    <span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
    <span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
    <span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
    <span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
    <span class="c1"># utility functions</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
        <span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch'</span><span class="p">:</span>
        <span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'apex'</span><span class="p">:</span>
        <span class="n">apex_layer_norm</span> <span class="o">=</span> <span class="n">apex</span><span class="o">.</span><span class="n">normalization</span><span class="o">.</span><span class="n">FusedLayerNorm</span><span class="p">(</span><span class="n">w_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
        <span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">apex_layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="c1"># forward pass</span>
    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'forward'</span><span class="p">:</span>
        <span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">y_fwd</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
    <span class="c1"># backward pass</span>
    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'backward'</span><span class="p">:</span>
        <span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
        <span class="n">y</span> <span class="o">=</span> <span class="n">y_fwd</span><span class="p">()</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">y</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
                                                     <span class="n">grad_to_none</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">],</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>


<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">'.'</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
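<p>For readers running the script themselves (it can be downloaded below), a minimal usage sketch is shown here.
The test shape is illustrative rather than taken from this page, and the exact files written under
<code class="docutils literal notranslate"><span class="pre">save_path</span></code> can vary between Triton versions.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># hypothetical driver code, assuming the definitions above are in scope</span>
<span class="n">test_layer_norm</span><span class="p">(</span><span class="mi">1151</span><span class="p">,</span> <span class="mi">8192</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>   <span class="c1"># raises if Triton and torch disagree</span>
<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>  <span class="c1"># sweep N and write results to the current directory</span>
</pre></div></div>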
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes  12.415 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-05-layer-norm-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">05-layer-norm.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">05-layer-norm.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
           </div>
           
          </div>
          <footer>
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
        <a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
        <a href="04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
    </div>
  <hr/>
  <div role="contentinfo">
    <p>
        © Copyright 2020, Philippe Tillet.
    </p>
  </div>
    
    
    
    Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    
    provided by <a href="https://readthedocs.org">Read the Docs</a>. 
</footer>
        </div>
      </div>
    </section>
  </div>
  
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
    <span class="rst-current-version" data-toggle="rst-current-version">
        <span class="fa fa-book"> Other Versions</span>
        v: master
        <span class="fa fa-caret-down"></span>
    </span>
    <div class="rst-other-versions">
        <dl>
            <dt>Tags</dt>
            <dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
        </dl>
        <dl>
            <dt>Branches</dt>
            <dd><a href="05-layer-norm.html">master</a></dd>
        </dl>
    </div>
</div>
  <script type="text/javascript">
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>
  
  
    
   
</body>
</html>