455 lines
32 KiB
HTML
455 lines
32 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
<html class="writer-html5" lang="en" >
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
|
||
<title>Low-Memory Dropout — Triton documentation</title>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<!--[if lt IE 9]>
|
||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||
<![endif]-->
|
||
|
||
|
||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||
<script src="../../_static/jquery.js"></script>
|
||
<script src="../../_static/underscore.js"></script>
|
||
<script src="../../_static/doctools.js"></script>
|
||
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||
|
||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||
|
||
|
||
<link rel="index" title="Index" href="../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../search.html" />
|
||
<link rel="next" title="Layer Normalization" href="05-layer-norm.html" />
|
||
<link rel="prev" title="Matrix Multiplication" href="03-matrix-multiplication.html" />
|
||
</head>
|
||
|
||
<body class="wy-body-for-nav">
|
||
|
||
|
||
<div class="wy-grid-for-nav">
|
||
|
||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||
<div class="wy-side-scroll">
|
||
<div class="wy-side-nav-search" >
|
||
|
||
|
||
|
||
<a href="../../index.html" class="icon icon-home"> Triton
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="search">
|
||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||
<input type="text" name="q" placeholder="Search docs" />
|
||
<input type="hidden" name="check_keywords" value="yes" />
|
||
<input type="hidden" name="area" value="default" />
|
||
</form>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
|
||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Low-Memory Dropout</a><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="#baseline">Baseline</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#seeded-dropout">Seeded dropout</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#exercises">Exercises</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="#references">References</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</nav>
|
||
|
||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||
|
||
|
||
<nav class="wy-nav-top" aria-label="top navigation">
|
||
|
||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||
<a href="../../index.html">Triton</a>
|
||
|
||
</nav>
|
||
|
||
|
||
<div class="wy-nav-content">
|
||
|
||
<div class="rst-content">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||
|
||
<ul class="wy-breadcrumbs">
|
||
|
||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||
|
||
<li><a href="index.html">Tutorials</a> »</li>
|
||
|
||
<li>Low-Memory Dropout</li>
|
||
|
||
|
||
<li class="wy-breadcrumbs-aside">
|
||
|
||
|
||
<a href="../../_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt" rel="nofollow"> View page source</a>
|
||
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
|
||
|
||
<hr/>
|
||
</div>
|
||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||
<div itemprop="articleBody">
|
||
|
||
<div class="sphx-glr-download-link-note admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">here</span></a>
|
||
to download the full example code</p>
|
||
</div>
|
||
<div class="sphx-glr-example-title section" id="low-memory-dropout">
|
||
<span id="sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"></span><h1>Low-Memory Dropout<a class="headerlink" href="#low-memory-dropout" title="Permalink to this headline">¶</a></h1>
|
||
<p>In this tutorial, you will write a memory-efficient implementation of dropout whose state
|
||
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
|
||
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:</p>
|
||
<ul class="simple">
|
||
<li><p>The limitations of naive implementations of Dropout with PyTorch</p></li>
|
||
<li><p>Parallel pseudo-random number generation in Triton</p></li>
|
||
</ul>
|
||
<div class="section" id="baseline">
|
||
<h2>Baseline<a class="headerlink" href="#baseline" title="Permalink to this headline">¶</a></h2>
|
||
<p>The <em>dropout</em> operator was first introduced in <a class="reference internal" href="#srivastava2014" id="id1"><span>[SRIVASTAVA2014]</span></a> as a way to improve the performance
|
||
of deep neural networks in the low-data regime (i.e. regularization).</p>
|
||
<p>It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
|
||
output has a probability <span class="math notranslate nohighlight">\(p\)</span> of being changed to zero and otherwise it is copied from the input.
|
||
This forces the network to perform well even when only a fraction <span class="math notranslate nohighlight">\(1 - p\)</span> of the scalars from the input is available.</p>
|
||
<p>At evaluation time we want to use the full power of the network so we set <span class="math notranslate nohighlight">\(p=0\)</span>. Naively this would
|
||
increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
|
||
in the output softmax temperature). To prevent this we multiply the output by <span class="math notranslate nohighlight">\(\frac{1}{1 - p}\)</span>, which
|
||
keeps the norm consistent regardless of the dropout probability.</p>
|
||
<p>Let’s first take a look at the baseline implementation.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tabulate</span>
|
||
<span class="kn">import</span> <span class="nn">torch</span>
|
||
|
||
<span class="kn">import</span> <span class="nn">triton</span>
|
||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||
|
||
|
||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||
<span class="k">def</span> <span class="nf">_dropout</span><span class="p">(</span>
|
||
<span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># pointer to the input</span>
|
||
<span class="n">x_keep_ptr</span><span class="p">,</span> <span class="c1"># pointer to a mask of 0s and 1s</span>
|
||
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># pointer to the output</span>
|
||
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># number of elements in the `x` tensor</span>
|
||
<span class="n">p</span><span class="p">,</span> <span class="c1"># probability that an element of `x` is changed to zero</span>
|
||
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||
<span class="p">):</span>
|
||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
|
||
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
|
||
<span class="c1"># Load data</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
<span class="n">x_keep</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_keep_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
<span class="c1"># The line below is the crucial part, described in the paragraph above!</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
|
||
<span class="c1"># Write-back output</span>
|
||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
|
||
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
|
||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]),)</span>
|
||
<span class="n">_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span>
|
||
|
||
|
||
<span class="c1"># Input tensor</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
||
<span class="c1"># Dropout mask</span>
|
||
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span>
|
||
<span class="n">x_keep</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span> <span class="o">&gt;</span> <span class="n">p</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
||
<span class="c1">#</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="o">=</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">p</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
|
||
<span class="p">[</span><span class="s2">"input"</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
|
||
<span class="p">[</span><span class="s2">"keep mask"</span><span class="p">]</span> <span class="o">+</span> <span class="n">x_keep</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
|
||
<span class="p">[</span><span class="s2">"output"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
|
||
<span class="p">]))</span>
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
|
||
input 1.541 -0.293429 -2.17879 0.568431 -1.08452 -1.3986 0.403347 0.838026 -0.719258 -0.403344
|
||
keep mask 1 1 0 1 0 1 1 0 0 0
|
||
output 3.08199 -0.586858 0 1.13686 0 -2.79719 0.806694 0 0 0
|
||
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
<div class="section" id="seeded-dropout">
|
||
<h2>Seeded dropout<a class="headerlink" href="#seeded-dropout" title="Permalink to this headline">¶</a></h2>
|
||
<p>The above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly,
|
||
we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
|
||
very tricky when using recompute/checkpointing (e.g. see all the notes about <cite>preserve_rng_state</cite> in
|
||
<a class="reference external" href="https://pytorch.org/docs/1.9.0/checkpoint.html">https://pytorch.org/docs/1.9.0/checkpoint.html</a>). In this tutorial we’ll describe an alternative implementation
|
||
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
|
||
of persisting randomness across multiple invocations of the kernel.</p>
|
||
<p>Pseudorandom number generation in Triton is simple! In this tutorial we will use the
|
||
<code class="code docutils literal notranslate"><span class="pre">triton.language.rand</span></code> function which generates a block of uniformly distributed <code class="code docutils literal notranslate"><span class="pre">float32</span></code>
|
||
values in [0, 1), given a seed and a block of <code class="code docutils literal notranslate"><span class="pre">int32</span></code> offsets. But if you need it, Triton also provides
|
||
other <a class="reference internal" href="../../python-api/triton.language.html#random-number-generation"><span class="std std-ref">random number generation strategies</span></a>.</p>
|
||
<div class="admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>Triton’s implementation of PRNG is based on the Philox algorithm (described in <a class="reference internal" href="#salmon2011" id="id2"><span>[SALMON2011]</span></a>).</p>
|
||
</div>
|
||
<p>Let’s put it all together.</p>
|
||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||
<span class="k">def</span> <span class="nf">_seeded_dropout</span><span class="p">(</span>
|
||
<span class="n">x_ptr</span><span class="p">,</span>
|
||
<span class="n">output_ptr</span><span class="p">,</span>
|
||
<span class="n">n_elements</span><span class="p">,</span>
|
||
<span class="n">p</span><span class="p">,</span>
|
||
<span class="n">seed</span><span class="p">,</span>
|
||
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
|
||
<span class="p">):</span>
|
||
<span class="c1"># compute memory offsets of elements handled by this instance</span>
|
||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
|
||
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
|
||
<span class="c1"># load data from x</span>
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
<span class="c1"># randomly prune it</span>
|
||
<span class="n">random</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)</span>
|
||
<span class="n">x_keep</span> <span class="o">=</span> <span class="n">random</span> <span class="o">&gt;</span> <span class="n">p</span>
|
||
<span class="c1"># write-back</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
|
||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||
|
||
|
||
<span class="k">def</span> <span class="nf">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
|
||
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
|
||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]),)</span>
|
||
<span class="n">_seeded_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span>
|
||
|
||
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
||
<span class="c1"># Compare this to the baseline - dropout mask is never instantiated!</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
|
||
<span class="n">output2</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
|
||
<span class="n">output3</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span>
|
||
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
|
||
<span class="p">[</span><span class="s2">"input"</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
|
||
<span class="p">[</span><span class="s2">"output (seed = 123)"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
|
||
<span class="p">[</span><span class="s2">"output (seed = 123)"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output2</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
|
||
<span class="p">[</span><span class="s2">"output (seed = 512)"</span><span class="p">]</span> <span class="o">+</span> <span class="n">output3</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
|
||
<span class="p">]))</span>
|
||
</pre></div>
|
||
</div>
|
||
<p class="sphx-glr-script-out">Out:</p>
|
||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
|
||
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
|
||
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
|
||
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
|
||
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
|
||
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
|
||
</pre></div>
|
||
</div>
|
||
<p>Et voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
|
||
If you’d like to explore further applications of pseudorandomness in GPU programming, we encourage you
|
||
to explore the <cite>triton/language/random</cite> folder!</p>
|
||
</div>
|
||
<div class="section" id="exercises">
|
||
<h2>Exercises<a class="headerlink" href="#exercises" title="Permalink to this headline">¶</a></h2>
|
||
<ol class="arabic simple">
|
||
<li><p>Extend the kernel to operate over a matrix and use a vector of seeds - one per row.</p></li>
|
||
<li><p>Add support for striding.</p></li>
|
||
<li><p>(challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform that generates the projection matrix on the fly each time using a seed.</p></li>
|
||
</ol>
|
||
</div>
|
||
<div class="section" id="references">
|
||
<h2>References<a class="headerlink" href="#references" title="Permalink to this headline">¶</a></h2>
|
||
<dl class="citation">
|
||
<dt class="label" id="salmon2011"><span class="brackets"><a class="fn-backref" href="#id2">SALMON2011</a></span></dt>
|
||
<dd><p>John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011</p>
|
||
</dd>
|
||
<dt class="label" id="srivastava2014"><span class="brackets"><a class="fn-backref" href="#id1">SRIVASTAVA2014</a></span></dt>
|
||
<dd><p>Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014</p>
|
||
</dd>
|
||
</dl>
|
||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.280 seconds)</p>
|
||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py">
|
||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">04-low-memory-dropout.py</span></code></a></p>
|
||
</div>
|
||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||
<p><a class="reference download internal" download="" href="../../_downloads/bc847dec325798bdc436c4ef5ac8b78a/04-low-memory-dropout.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">04-low-memory-dropout.ipynb</span></code></a></p>
|
||
</div>
|
||
</div>
|
||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
<footer>
|
||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||
<a href="05-layer-norm.html" class="btn btn-neutral float-right" title="Layer Normalization" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||
<a href="03-matrix-multiplication.html" class="btn btn-neutral float-left" title="Matrix Multiplication" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||
</div>
|
||
|
||
<hr/>
|
||
|
||
<div role="contentinfo">
|
||
<p>
|
||
© Copyright 2020, Philippe Tillet.
|
||
|
||
</p>
|
||
</div>
|
||
|
||
|
||
|
||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||
|
||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||
|
||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||
|
||
</footer>
|
||
</div>
|
||
</div>
|
||
|
||
</section>
|
||
|
||
</div>
|
||
|
||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
|
||
<span class="rst-current-version" data-toggle="rst-current-version">
|
||
<span class="fa fa-book"> Other Versions</span>
|
||
v: master
|
||
<span class="fa fa-caret-down"></span>
|
||
</span>
|
||
<div class="rst-other-versions">
|
||
<dl>
|
||
<dt>Tags</dt>
|
||
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
|
||
</dl>
|
||
<dl>
|
||
<dt>Branches</dt>
|
||
<dd><a href="04-low-memory-dropout.html">master</a></dd>
|
||
</dl>
|
||
</div>
|
||
</div>
|
||
|
||
<script type="text/javascript">
|
||
jQuery(function () {
|
||
SphinxRtdTheme.Navigation.enable(true);
|
||
});
|
||
</script>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</body>
|
||
</html> |