<div class="sphx-glr-example-title section" id="low-memory-dropout">
<span id="sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"></span><h1>Low-Memory Dropout<a class="headerlink" href="#low-memory-dropout" title="Permalink to this headline"></a></h1>
In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit-mask tensor of the same shape as the input. You will learn about:

- The limitations of naive implementations of dropout in PyTorch
- Parallel pseudo-random number generation in Triton
<div class="section" id="baseline">
<h2>Baseline<a class="headerlink" href="#baseline" title="Permalink to this headline"></a></h2>
The *dropout* operator was first introduced in [SRIVASTAVA2014] as a way to improve the performance
of deep neural networks in low-data regimes (i.e., as a form of regularization).

It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
output has a probability $p$ of being changed to zero, and otherwise it is copied from the input.
This forces the network to perform well even when only $1 - p$ scalars from the input are available.

At evaluation time we want to use the full power of the network, so we set $p = 0$. Naively, this would
increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
in the output softmax temperature). To prevent this, we multiply the output by $\frac{1}{1 - p}$ during
training, which keeps the norm consistent regardless of the dropout probability.
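To see why this scaling keeps the norm consistent, consider a single output scalar during training:

$$
\mathbb{E}[\mathrm{output}_i] = (1 - p) \cdot \frac{x_i}{1 - p} + p \cdot 0 = x_i,
$$

so in expectation each activation matches what the network produces at evaluation time with $p = 0$.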
Let's first take a look at the baseline implementation.
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tabulate</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_dropout</span><span class="p">(</span>
<span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># pointer to the input</span>
<span class="n">x_keep_ptr</span><span class="p">,</span> <span class="c1"># pointer to a mask of 0s and 1s</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># pointer to the output</span>
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># number of elements in the `x` tensor</span>
<span class="n">p</span><span class="p">,</span> <span class="c1"># probability that an element of `x` is changed to zero</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
<span class="c1"># Load data</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">x_keep</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_keep_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># The line below is the crucial part, described in the paragraph above!</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
<span class="c1"># Write-back output</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]),)</span>
<span class="n">_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="c1"># Input tensor</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1"># Dropout mask</span>
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span>
<span class="n">x_keep</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span> <span class="o">&gt;</span> <span class="n">p</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1">#</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="o">=</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">p</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
<span class="p">[</span><span class="s2">&quot;input&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;keep mask&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">x_keep</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
<span class="p">]))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
input 1.541 -0.293429 -2.17879 0.568431 -1.08452 -1.3986 0.403347 0.838026 -0.719258 -0.403344
keep mask 1 1 0 1 0 1 1 0 0 0
output 3.08199 -0.586858 0 1.13686 0 -2.79719 0.806694 0 0 0
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
</pre></div>
</div>
</div>
<div class="section" id="seeded-dropout">
<h2>Seeded dropout<a class="headerlink" href="#seeded-dropout" title="Permalink to this headline"></a></h2>
The above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly,
we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
of persisting randomness across multiple invocations of the kernel.
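To make these savings concrete: for a 1024 x 1024 `float32` activation, the baseline's same-shape
`int32` keep-mask occupies an extra 4 MiB (1024 · 1024 · 4 bytes) that must be written when generated
and read back by the kernel, whereas the seeded version's entire persistent state is a single
4-byte `int32` seed.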
Pseudorandom number generation in Triton is simple! In this tutorial we will use the
`triton.language.rand` function, which generates a block of uniformly distributed `float32`
values in [0, 1), given a seed and a block of `int32` offsets. But if you need it, Triton also provides
other [random number generation strategies](../../python-api/triton.language.html#random-number-generation).
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Tritons implementation of PRNG is based on the Philox algorithm (described on <a class="reference internal" href="#salmon2011" id="id2"><span>[SALMON2011]</span></a>).</p>
</div>
Let's put it all together.
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_seeded_dropout</span><span class="p">(</span>
<span class="n">x_ptr</span><span class="p">,</span>
<span class="n">output_ptr</span><span class="p">,</span>
<span class="n">n_elements</span><span class="p">,</span>
<span class="n">p</span><span class="p">,</span>
<span class="n">seed</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="c1"># compute memory offsets of elements handled by this instance</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="c1"># load data from x</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># randomly prune it</span>
<span class="n">random</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)</span>
<span class="n">x_keep</span> <span class="o">=</span> <span class="n">random</span> <span class="o">&gt;</span> <span class="n">p</span>
<span class="c1"># write-back</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]),)</span>
<span class="n">_seeded_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1"># Compare this to the baseline - dropout mask is never instantiated!</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
<span class="n">output2</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
<span class="n">output3</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
<span class="p">[</span><span class="s2">&quot;input&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output (seed = 123)&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output (seed = 123)&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output2</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output (seed = 512)&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output3</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
<span class="p">]))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
</pre></div>
</div>
Et voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the `triton/language/random` folder!
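As a small taste of what lives there, the same seed-plus-offsets contract carries over to Triton's other
generators. Below is a minimal sketch, assuming `tl.randn` (the Gaussian counterpart to `tl.rand` listed
in the Python API reference), that injects reproducible noise instead of dropping elements:

```python
@triton.jit
def _seeded_noise(x_ptr, output_ptr, n_elements, seed, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Assumed to follow the same contract as tl.rand: a seed plus a block of
    # int32 offsets deterministically yields one value per offset, so the
    # same seed reproduces the exact same noise.
    noise = tl.randn(seed, offsets)
    tl.store(output_ptr + offsets, x + noise, mask=mask)
```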
<div class="section" id="exercises">
<h2>Exercises<a class="headerlink" href="#exercises" title="Permalink to this headline"></a></h2>
<ol class="arabic simple">
<li><p>Extend the kernel to operate over a matrix and use a vector of seeds - one per row.</p></li>
<li><p>Add support for striding.</p></li>
<li><p>(challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.</p></li>
</ol>
</div>
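For exercise 1, here is one possible starting point. It is only a sketch, not a reference solution: it
assumes a contiguous row-major matrix whose rows each fit in a single block, and leaves the strides of
exercise 2 out.

```python
@triton.jit
def _seeded_dropout_2d(
    x_ptr,  # pointer to a contiguous, row-major (n_rows, n_cols) matrix
    output_ptr,
    n_cols,
    p,
    seeds_ptr,  # hypothetical: a vector of n_rows int32 seeds, one per row
    BLOCK_SIZE: tl.constexpr,  # assumed >= n_cols for brevity
):
    # one program instance per row, each with its own seed
    row = tl.program_id(axis=0)
    seed = tl.load(seeds_ptr + row)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * n_cols + col_offsets, mask=mask)
    # the per-row seed and the column offsets fully determine the row's mask
    random = tl.rand(seed, col_offsets)
    x_keep = random > p
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + row * n_cols + col_offsets, output, mask=mask)
```

Launched with `grid = (n_rows,)`, this reproduces any single row's mask from that row's seed alone.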
<div class="section" id="references">
<h2>References<a class="headerlink" href="#references" title="Permalink to this headline"></a></h2>
<dl class="citation">
<dt class="label" id="salmon2011"><span class="brackets"><a class="fn-backref" href="#id2">SALMON2011</a></span></dt>
<dd><p>John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011</p>
</dd>
<dt class="label" id="srivastava2014"><span class="brackets"><a class="fn-backref" href="#id1">SRIVASTAVA2014</a></span></dt>
<dd><p>Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014</p>
</dd>
</dl>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.280 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">04-low-memory-dropout.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/bc847dec325798bdc436c4ef5ac8b78a/04-low-memory-dropout.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">04-low-memory-dropout.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="05-layer-norm.html" class="btn btn-neutral float-right" title="Layer Normalization" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="03-matrix-multiplication.html" class="btn btn-neutral float-left" title="Matrix Multiplication" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="04-low-memory-dropout.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>