<div class="sphx-glr-example-title section" id="low-memory-dropout">
 | 
						||
<span id="sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"></span><h1>Low-Memory Dropout<a class="headerlink" href="#low-memory-dropout" title="Permalink to this headline">¶</a></h1>
 | 
						||
<p>In this tutorial, you will write a memory-efficient implementation of dropout whose state
 | 
						||
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
 | 
						||
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:</p>
 | 
						||
<ul class="simple">
 | 
						||
<li><p>The limitations of naive implementations of Dropout with PyTorch</p></li>
 | 
						||
<li><p>Parallel pseudo-random number generation in Triton</p></li>
 | 
						||
</ul>
 | 
						||
<div class="section" id="baseline">
 | 
						||
<h2>Baseline<a class="headerlink" href="#baseline" title="Permalink to this headline">¶</a></h2>
 | 
						||
<p>The <em>dropout</em> operator was first introduced in <a class="reference internal" href="#srivastava2014" id="id1"><span>[SRIVASTAVA2014]</span></a> as a way to improve the performance
 | 
						||
of deep neural networks in low-data regime (i.e. regularization).</p>
 | 
						||
<p>It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
 | 
						||
output has a probability <span class="math notranslate nohighlight">\(p\)</span> of being changed to zero and otherwise it is copied from the input.
 | 
						||
This forces the network to perform well even when only <span class="math notranslate nohighlight">\(1 - p\)</span> scalars from the input are available.</p>
 | 
						||
<p>At evaluation time we want to use the full power of the network so we set <span class="math notranslate nohighlight">\(p=0\)</span>. Naively this would
 | 
						||
increase the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease
 | 
						||
in the output softmax temperature). To prevent this we multiply the output by <span class="math notranslate nohighlight">\(\frac{1}{1 - p}\)</span>, which
 | 
						||
keeps the norm consistent regardless of the dropout probability.</p>
 | 
						||
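Concretely, each element is kept with probability \(1 - p\), so the expected value of a rescaled output element is \((1 - p) \cdot \frac{x}{1 - p} = x\). The short check below is our own illustration (not part of the tutorial script) and verifies this empirically:

```python
import torch

# With p = 0.5, each kept element is doubled, so the mean is preserved in expectation.
p = 0.5
x = torch.ones(1_000_000)
keep = torch.rand_like(x) > p
y = torch.where(keep, x / (1 - p), torch.zeros_like(x))
print(x.mean().item(), y.mean().item())  # both values are close to 1.0
```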
Let's first take a look at the baseline implementation.
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tabulate</span>
 | 
						||
<span class="kn">import</span> <span class="nn">torch</span>
 | 
						||
 | 
						||
<span class="kn">import</span> <span class="nn">triton</span>
 | 
						||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
 | 
						||
 | 
						||
 | 
						||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
 | 
						||
<span class="k">def</span> <span class="nf">_dropout</span><span class="p">(</span>
 | 
						||
        <span class="n">x_ptr</span><span class="p">,</span>  <span class="c1"># pointer to the input</span>
 | 
						||
        <span class="n">x_keep_ptr</span><span class="p">,</span>  <span class="c1"># pointer to a mask of 0s and 1s</span>
 | 
						||
        <span class="n">output_ptr</span><span class="p">,</span>  <span class="c1"># pointer to the output</span>
 | 
						||
        <span class="n">n_elements</span><span class="p">,</span>  <span class="c1"># number of elements in the `x` tensor</span>
 | 
						||
        <span class="n">p</span><span class="p">,</span>  <span class="c1"># probability that an element of `x` is changed to zero</span>
 | 
						||
        <span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
 | 
						||
<span class="p">):</span>
 | 
						||
    <span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 | 
						||
    <span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
 | 
						||
    <span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
 | 
						||
    <span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">n_elements</span>
 | 
						||
    <span class="c1"># Load data</span>
 | 
						||
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						||
    <span class="n">x_keep</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_keep_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						||
    <span class="c1"># The line below is the crucial part, described in the paragraph above!</span>
 | 
						||
    <span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
 | 
						||
    <span class="c1"># Write-back output</span>
 | 
						||
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						||
 | 
						||
 | 
						||
<span class="k">def</span> <span class="nf">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span>
 | 
						||
    <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
 | 
						||
    <span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
 | 
						||
    <span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
 | 
						||
    <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]),)</span>
 | 
						||
    <span class="n">_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
 | 
						||
    <span class="k">return</span> <span class="n">output</span>
 | 
						||
 | 
						||
 | 
						||
<span class="c1"># Input tensor</span>
 | 
						||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
 | 
						||
<span class="c1"># Dropout mask</span>
 | 
						||
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span>
 | 
						||
<span class="n">x_keep</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span> <span class="o">></span> <span class="n">p</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
 | 
						||
<span class="c1">#</span>
 | 
						||
<span class="n">output</span> <span class="o">=</span> <span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="o">=</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">p</span><span class="p">)</span>
 | 
						||
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
 | 
						||
    <span class="p">[</span><span class="s2">"input"</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
 | 
						||
    <span class="p">[</span><span class="s2">"keep mask"</span><span class="p">]</span> <span class="o">+</span> <span class="n">x_keep</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
 | 
						||
    <span class="p">[</span><span class="s2">"output"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
 | 
						||
<span class="p">]))</span>
 | 
						||
</pre></div>
 | 
						||
</div>
 | 
						||
<p class="sphx-glr-script-out">Out:</p>
 | 
						||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>---------  -------  ---------  --------  --------  --------  --------  --------  --------  ---------  ---------
 | 
						||
input      1.541    -0.293429  -2.17879  0.568431  -1.08452  -1.3986   0.403347  0.838026  -0.719258  -0.403344
 | 
						||
keep mask  1         1          0        1          0         1        1         0          0          0
 | 
						||
output     3.08199  -0.586858   0        1.13686    0        -2.79719  0.806694  0          0          0
 | 
						||
---------  -------  ---------  --------  --------  --------  --------  --------  --------  ---------  ---------
 | 
						||
</pre></div>
 | 
						||
</div>
 | 
						||
</div>
 | 
						||
<div class="section" id="seeded-dropout">
 | 
						||
<h2>Seeded dropout<a class="headerlink" href="#seeded-dropout" title="Permalink to this headline">¶</a></h2>
 | 
						||
<p>Above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly
 | 
						||
we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
 | 
						||
very tricky when using recompute/checkpointing (e.g. see all the notes about <cite>preserve_rng_state</cite> in
 | 
						||
<a class="reference external" href="https://pytorch.org/docs/1.9.0/checkpoint.html">https://pytorch.org/docs/1.9.0/checkpoint.html</a>). In this tutorial we’ll describe an alternative implementation
 | 
						||
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
 | 
						||
of persisting randomness across multiple invocations of the kernel.</p>
 | 
						||
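To make the footprint difference concrete, here is a rough, illustrative comparison of our own (the sizes are an example, not from the tutorial): the mask-based variant keeps an extra tensor the size of the input alive until the backward pass, while the seeded variant keeps a single int32.

```python
import torch

x = torch.randn(1024, 1024, device='cuda')
# State the baseline must keep around for backward: a mask shaped like x.
mask_state = (torch.rand_like(x) > 0.5).to(torch.int32)
# State the seeded variant must keep around: one int32 seed.
seed_state = torch.tensor(123, dtype=torch.int32)

print(mask_state.numel() * mask_state.element_size())  # 4194304 bytes (4 MiB)
print(seed_state.numel() * seed_state.element_size())  # 4 bytes
```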
Pseudorandom number generation in Triton is simple! In this tutorial we will use the
`triton.language.rand` function, which generates a block of uniformly distributed `float32`
values in [0, 1), given a seed and a block of `int32` offsets. But if you need it, Triton also provides
other [random number generation strategies](../../python-api/triton.language.html#random-number-generation).
<div class="admonition note">
 | 
						||
<p class="admonition-title">Note</p>
 | 
						||
<p>Triton’s implementation of PRNG is based on the Philox algorithm (described on <a class="reference internal" href="#salmon2011" id="id2"><span>[SALMON2011]</span></a>).</p>
 | 
						||
</div>
 | 
						||
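Before the full kernel, here is a minimal sketch of our own (using only constructs shown in this tutorial) of the property we rely on: `tl.rand(seed, offsets)` is a deterministic function of the seed and the offsets, so re-running a kernel with the same seed regenerates the same random block.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _rand_block(out_ptr, seed, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    # Same (seed, offset) pairs always yield the same float32 values in [0, 1).
    tl.store(out_ptr + offsets, tl.rand(seed, offsets))


out1 = torch.empty(16, device='cuda')
out2 = torch.empty(16, device='cuda')
_rand_block[(1,)](out1, 42, BLOCK_SIZE=16)
_rand_block[(1,)](out2, 42, BLOCK_SIZE=16)
assert torch.equal(out1, out2)  # identical seeds produce identical blocks
```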
Let's put it all together.
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
 | 
						||
<span class="k">def</span> <span class="nf">_seeded_dropout</span><span class="p">(</span>
 | 
						||
        <span class="n">x_ptr</span><span class="p">,</span>
 | 
						||
        <span class="n">output_ptr</span><span class="p">,</span>
 | 
						||
        <span class="n">n_elements</span><span class="p">,</span>
 | 
						||
        <span class="n">p</span><span class="p">,</span>
 | 
						||
        <span class="n">seed</span><span class="p">,</span>
 | 
						||
        <span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
 | 
						||
<span class="p">):</span>
 | 
						||
    <span class="c1"># compute memory offsets of elements handled by this instance</span>
 | 
						||
    <span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
 | 
						||
    <span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
 | 
						||
    <span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
 | 
						||
    <span class="c1"># load data from x</span>
 | 
						||
    <span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">n_elements</span>
 | 
						||
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						||
    <span class="c1"># randomly prune it</span>
 | 
						||
    <span class="n">random</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)</span>
 | 
						||
    <span class="n">x_keep</span> <span class="o">=</span> <span class="n">random</span> <span class="o">></span> <span class="n">p</span>
 | 
						||
    <span class="c1"># write-back</span>
 | 
						||
    <span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
 | 
						||
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
 | 
						||
 | 
						||
 | 
						||
<span class="k">def</span> <span class="nf">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
 | 
						||
    <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
 | 
						||
    <span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
 | 
						||
    <span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
 | 
						||
    <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]),)</span>
 | 
						||
    <span class="n">_seeded_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
 | 
						||
    <span class="k">return</span> <span class="n">output</span>
 | 
						||
 | 
						||
 | 
						||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
 | 
						||
<span class="c1"># Compare this to the baseline - dropout mask is never instantiated!</span>
 | 
						||
<span class="n">output</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
 | 
						||
<span class="n">output2</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
 | 
						||
<span class="n">output3</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span>
 | 
						||
 | 
						||
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
 | 
						||
    <span class="p">[</span><span class="s2">"input"</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
 | 
						||
    <span class="p">[</span><span class="s2">"output (seed = 123)"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
 | 
						||
    <span class="p">[</span><span class="s2">"output (seed = 123)"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output2</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
 | 
						||
    <span class="p">[</span><span class="s2">"output (seed = 512)"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output3</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
 | 
						||
<span class="p">]))</span>
 | 
						||
</pre></div>
 | 
						||
</div>
 | 
						||
<p class="sphx-glr-script-out">Out:</p>
 | 
						||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>-------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------
 | 
						||
input                -0.952835  0.371721  0.408716  1.42142  0.149397  -0.67086  -0.214186  -0.431969  -0.707878  -0.106434
 | 
						||
output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
 | 
						||
output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
 | 
						||
output (seed = 512)   0         0         0.817432  2.84284  0         -1.34172  -0.428372   0          0          0
 | 
						||
-------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------
 | 
						||
</pre></div>
 | 
						||
</div>
 | 
						||
Et Voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the `triton/language/random` folder!
<div class="section" id="exercises">
 | 
						||
<h2>Exercises<a class="headerlink" href="#exercises" title="Permalink to this headline">¶</a></h2>
 | 
						||
<ol class="arabic simple">
 | 
						||
<li><p>Extend the kernel to operate over a matrix and use a vector of seeds - one per row.</p></li>
 | 
						||
<li><p>Add support for striding.</p></li>
 | 
						||
<li><p>(challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.</p></li>
 | 
						||
</ol>
 | 
						||
</div>
 | 
						||
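For exercise 1, here is a possible starting point of our own devising; it is a sketch, not a reference solution, and it assumes a contiguous row-major matrix whose row length fits in one block (the kernel name `_seeded_dropout_2d` is hypothetical):

```python
@triton.jit
def _seeded_dropout_2d(
        x_ptr,  # pointer to a contiguous (n_rows, n_cols) matrix
        seeds_ptr,  # pointer to a vector of n_rows int32 seeds
        output_ptr,
        n_cols,
        p,
        BLOCK_SIZE: tl.constexpr,  # smallest power of two >= n_cols
):
    # One program per row; each row draws its randomness from its own seed.
    row = tl.program_id(axis=0)
    seed = tl.load(seeds_ptr + row)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * n_cols + col_offsets, mask=mask)
    # The per-row seed makes each row's dropout mask independently reproducible.
    random = tl.rand(seed, col_offsets)
    x_keep = random > p
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + row * n_cols + col_offsets, output, mask=mask)
```

It would be launched with one program per row, e.g. `_seeded_dropout_2d[(n_rows,)](x, seeds, output, n_cols, p, BLOCK_SIZE=1024)`.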
<div class="section" id="references">
 | 
						||
<h2>References<a class="headerlink" href="#references" title="Permalink to this headline">¶</a></h2>
 | 
						||
<dl class="citation">
 | 
						||
<dt class="label" id="salmon2011"><span class="brackets"><a class="fn-backref" href="#id2">SALMON2011</a></span></dt>
 | 
						||
<dd><p>John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011</p>
 | 
						||
</dd>
 | 
						||
<dt class="label" id="srivastava2014"><span class="brackets"><a class="fn-backref" href="#id1">SRIVASTAVA2014</a></span></dt>
 | 
						||
<dd><p>Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014</p>
 | 
						||
</dd>
 | 
						||
</dl>
 | 
						||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes  0.014 seconds)</p>
 | 
						||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py">
 | 
						||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
 | 
						||
<p><a class="reference download internal" download="" href="../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">04-low-memory-dropout.py</span></code></a></p>
 | 
						||
</div>
 | 
						||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
 | 
						||
<p><a class="reference download internal" download="" href="../../_downloads/bc847dec325798bdc436c4ef5ac8b78a/04-low-memory-dropout.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">04-low-memory-dropout.ipynb</span></code></a></p>
 | 
						||
</div>
 | 
						||
</div>
 | 
						||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
 | 
						||
</div>
 | 
						||
</div>
 | 
						||
 | 
						||
 | 
						||
           </div>
 | 
						||
           
 | 
						||
          </div>
 | 
						||
          <footer>
 | 
						||
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
 | 
						||
        <a href="05-layer-norm.html" class="btn btn-neutral float-right" title="Layer Normalization" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
 | 
						||
        <a href="03-matrix-multiplication.html" class="btn btn-neutral float-left" title="Matrix Multiplication" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
 | 
						||
    </div>
 | 
						||
 | 
						||
  <hr/>
 | 
						||
 | 
						||
  <div role="contentinfo">
 | 
						||
    <p>
 | 
						||
        © Copyright 2020, Philippe Tillet.
 | 
						||
 | 
						||
    </p>
 | 
						||
  </div>
 | 
						||
    
 | 
						||
    
 | 
						||
    
 | 
						||
    Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
 | 
						||
    
 | 
						||
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
 | 
						||
    
 | 
						||
    provided by <a href="https://readthedocs.org">Read the Docs</a>. 
 | 
						||
 | 
						||
</footer>
 | 
						||
        </div>
 | 
						||
      </div>
 | 
						||
 | 
						||
    </section>
 | 
						||
 | 
						||
  </div>
 | 
						||
  
 | 
						||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
 | 
						||
    <span class="rst-current-version" data-toggle="rst-current-version">
 | 
						||
        <span class="fa fa-book"> Other Versions</span>
 | 
						||
        v: master
 | 
						||
        <span class="fa fa-caret-down"></span>
 | 
						||
    </span>
 | 
						||
    <div class="rst-other-versions">
 | 
						||
        <dl>
 | 
						||
            <dt>Tags</dt>
 | 
						||
            <dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
 | 
						||
        </dl>
 | 
						||
        <dl>
 | 
						||
            <dt>Branches</dt>
 | 
						||
            <dd><a href="04-low-memory-dropout.html">master</a></dd>
 | 
						||
        </dl>
 | 
						||
    </div>
 | 
						||
</div>
 | 
						||
 | 
						||
  <script type="text/javascript">
 | 
						||
      jQuery(function () {
 | 
						||
          SphinxRtdTheme.Navigation.enable(true);
 | 
						||
      });
 | 
						||
  </script>
 | 
						||
 | 
						||
  
 | 
						||
  
 | 
						||
    
 | 
						||
   
 | 
						||
 | 
						||
</body>
 | 
						||
</html> |