2022-06-05 21:05:02 +00:00
<!DOCTYPE html>
< html class = "writer-html5" lang = "en" >
< head >
< meta charset = "utf-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title > Low-Memory Dropout — Triton documentation< / title >
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
< link rel = "stylesheet" href = "../../_static/gallery.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-binder.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-dataframe.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-rendered-html.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/custom.css" type = "text/css" / >
<!-- [if lt IE 9]>
< script src = "../../_static/js/html5shiv.min.js" > < / script >
<![endif]-->
<script id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
< script src = "../../_static/jquery.js" > < / script >
< script src = "../../_static/underscore.js" > < / script >
< script src = "../../_static/doctools.js" > < / script >
< script async = "async" src = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" > < / script >
< script type = "text/javascript" src = "../../_static/js/theme.js" > < / script >
< link rel = "index" title = "Index" href = "../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../search.html" / >
< link rel = "next" title = "Layer Normalization" href = "05-layer-norm.html" / >
< link rel = "prev" title = "Matrix Multiplication" href = "03-matrix-multiplication.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../index.html" class = "icon icon-home" > Triton
< / a >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" role = "heading" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2" > < a class = "reference internal" href = "01-vector-add.html" > Vector Addition< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "02-fused-softmax.html" > Fused Softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Low-Memory Dropout< / a > < ul >
< li class = "toctree-l3" > < a class = "reference internal" href = "#baseline" > Baseline< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#seeded-dropout" > Seeded dropout< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#exercises" > Exercises< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#references" > References< / a > < / li >
< / ul >
< / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "05-layer-norm.html" > Layer Normalization< / a > < / li >
< / ul >
< / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Python API< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.html" > triton< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.language.html" > triton.language< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../python-api/triton.testing.html" > triton.testing< / a > < / li >
< / ul >
< p class = "caption" role = "heading" > < span class = "caption-text" > Programming Guide< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../index.html" > Triton< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../index.html" class = "icon icon-home" > < / a > » < / li >
< li > < a href = "index.html" > Tutorials< / a > » < / li >
< li > Low-Memory Dropout< / li >
< li class = "wy-breadcrumbs-aside" >
< a href = "../../_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt" rel = "nofollow" > View page source< / a >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "low-memory-dropout" >
< span id = "sphx-glr-getting-started-tutorials-04-low-memory-dropout-py" > < / span > < h1 > Low-Memory Dropout< a class = "headerlink" href = "#low-memory-dropout" title = "Permalink to this headline" > ¶< / a > < / h1 >
< p > In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:< / p >
< ul class = "simple" >
< li > < p > The limitations of naive implementations of Dropout with PyTorch< / p > < / li >
< li > < p > Parallel pseudo-random number generation in Triton< / p > < / li >
< / ul >
< div class = "section" id = "baseline" >
< h2 > Baseline< a class = "headerlink" href = "#baseline" title = "Permalink to this headline" > ¶< / a > < / h2 >
<p>The <em>dropout</em> operator was first introduced in <a class="reference internal" href="#srivastava2014" id="id1"><span>[SRIVASTAVA2014]</span></a> as a way to improve the performance
of deep neural networks in the low-data regime (i.e. regularization).</p>
< p > It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
output has a probability < span class = "math notranslate nohighlight" > \(p\)< / span > of being changed to zero and otherwise it is copied from the input.
This forces the network to perform well even when only < span class = "math notranslate nohighlight" > \(1 - p\)< / span > scalars from the input are available.< / p >
<p>At evaluation time we want to use the full power of the network, so we set <span class="math notranslate nohighlight">\(p=0\)</span>. Naively, this would
increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
in the output softmax temperature). To prevent this, during training we multiply the kept outputs by <span class="math notranslate nohighlight">\(\frac{1}{1 - p}\)</span>, which
keeps the norm consistent regardless of the dropout probability.</p>
<p>Let’s first take a look at the baseline implementation.</p>
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > tabulate< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton.language< / span > < span class = "k" > as< / span > < span class = "nn" > tl< / span >
< span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _dropout< / span > < span class = "p" > (< / span >
< span class = "n" > x_ptr< / span > < span class = "p" > ,< / span > < span class = "c1" > # pointer to the input< / span >
< span class = "n" > x_keep_ptr< / span > < span class = "p" > ,< / span > < span class = "c1" > # pointer to a mask of 0s and 1s< / span >
< span class = "n" > output_ptr< / span > < span class = "p" > ,< / span > < span class = "c1" > # pointer to the output< / span >
< span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "c1" > # number of elements in the `x` tensor< / span >
< span class = "n" > p< / span > < span class = "p" > ,< / span > < span class = "c1" > # probability that an element of `x` is changed to zero< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > block_start< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_SIZE< / span >
< span class = "n" > offsets< / span > < span class = "o" > =< / span > < span class = "n" > block_start< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > offsets< / span > < span class = "o" > < < / span > < span class = "n" > n_elements< / span >
< span class = "c1" > # Load data< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > x_ptr< / span > < span class = "o" > +< / span > < span class = "n" > offsets< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "n" > x_keep< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > x_keep_ptr< / span > < span class = "o" > +< / span > < span class = "n" > offsets< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "c1" > # The line below is the crucial part, described in the paragraph above!< / span >
< span class = "n" > output< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > x_keep< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > /< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "o" > -< / span > < span class = "n" > p< / span > < span class = "p" > ),< / span > < span class = "mf" > 0.0< / span > < span class = "p" > )< / span >
< span class = "c1" > # Write-back output< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > output_ptr< / span > < span class = "o" > +< / span > < span class = "n" > offsets< / span > < span class = "p" > ,< / span > < span class = "n" > output< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > dropout< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > x_keep< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "p" > ):< / span >
< span class = "n" > output< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > is_contiguous< / span > < span class = "p" > ()< / span >
< span class = "n" > n_elements< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > meta< / span > < span class = "p" > :< / span > < span class = "p" > (< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s1" > ' BLOCK_SIZE' < / span > < span class = "p" > ]),)< / span >
< span class = "n" > _dropout< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > x_keep< / span > < span class = "p" > ,< / span > < span class = "n" > output< / span > < span class = "p" > ,< / span > < span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "mi" > 1024< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > output< / span >
< span class = "c1" > # Input tensor< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 10< / span > < span class = "p" > ,))< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "c1" > # Dropout mask< / span >
< span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "mf" > 0.5< / span >
< span class = "n" > x_keep< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 10< / span > < span class = "p" > ,))< / span > < span class = "o" > > < / span > < span class = "n" > p< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > to< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > int32< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "c1" > #< / span >
< span class = "n" > output< / span > < span class = "o" > =< / span > < span class = "n" > dropout< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > x_keep< / span > < span class = "o" > =< / span > < span class = "n" > x_keep< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "n" > p< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > tabulate< / span > < span class = "o" > .< / span > < span class = "n" > tabulate< / span > < span class = "p" > ([< / span >
< span class = "p" > [< / span > < span class = "s2" > " input" < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > tolist< / span > < span class = "p" > (),< / span >
< span class = "p" > [< / span > < span class = "s2" > " keep mask" < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "n" > x_keep< / span > < span class = "o" > .< / span > < span class = "n" > tolist< / span > < span class = "p" > (),< / span >
< span class = "p" > [< / span > < span class = "s2" > " output" < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "n" > output< / span > < span class = "o" > .< / span > < span class = "n" > tolist< / span > < span class = "p" > ()< / span >
< span class = "p" > ]))< / span >
< / pre > < / div >
< / div >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > --------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
input 1.541 -0.293429 -2.17879 0.568431 -1.08452 -1.3986 0.403347 0.838026 -0.719258 -0.403344
keep mask 1 1 0 1 0 1 1 0 0 0
output 3.08199 -0.586858 0 1.13686 0 -2.79719 0.806694 0 0 0
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
< / pre > < / div >
< / div >
< / div >
< div class = "section" id = "seeded-dropout" >
< h2 > Seeded dropout< a class = "headerlink" href = "#seeded-dropout" title = "Permalink to this headline" > ¶< / a > < / h2 >
<p>The above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly,
we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
very tricky when using recompute/checkpointing (e.g. see all the notes about <cite>preserve_rng_state</cite> in
<a class="reference external" href="https://pytorch.org/docs/1.9.0/checkpoint.html">https://pytorch.org/docs/1.9.0/checkpoint.html</a>). In this tutorial we’ll describe an alternative implementation
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
of persisting randomness across multiple invocations of the kernel.</p>
< p > Pseudorandom number generation in Triton is simple! In this tutorial we will use the
< code class = "code docutils literal notranslate" > < span class = "pre" > triton.language.rand< / span > < / code > function which generates a block of uniformly distributed < code class = "code docutils literal notranslate" > < span class = "pre" > float32< / span > < / code >
values in [0, 1), given a seed and a block of < code class = "code docutils literal notranslate" > < span class = "pre" > int32< / span > < / code > offsets. But if you need it, Triton also provides
other < a class = "reference internal" href = "../../python-api/triton.language.html#random-number-generation" > < span class = "std std-ref" > random number generation strategies< / span > < / a > .< / p >
< div class = "admonition note" >
< p class = "admonition-title" > Note< / p >
<p>Triton’s implementation of PRNG is based on the Philox algorithm (described in <a class="reference internal" href="#salmon2011" id="id2"><span>[SALMON2011]</span></a>).</p>
< / div >
<p>Let’s put it all together.</p>
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > _seeded_dropout< / span > < span class = "p" > (< / span >
< span class = "n" > x_ptr< / span > < span class = "p" > ,< / span >
< span class = "n" > output_ptr< / span > < span class = "p" > ,< / span >
< span class = "n" > n_elements< / span > < span class = "p" > ,< / span >
< span class = "n" > p< / span > < span class = "p" > ,< / span >
< span class = "n" > seed< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span > < span class = "p" > ,< / span >
< span class = "p" > ):< / span >
< span class = "c1" > # compute memory offsets of elements handled by this instance< / span >
< span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > block_start< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK_SIZE< / span >
< span class = "n" > offsets< / span > < span class = "o" > =< / span > < span class = "n" > block_start< / span > < span class = "o" > +< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "c1" > # load data from x< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > offsets< / span > < span class = "o" > < < / span > < span class = "n" > n_elements< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > x_ptr< / span > < span class = "o" > +< / span > < span class = "n" > offsets< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "c1" > # randomly prune it< / span >
< span class = "n" > random< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > seed< / span > < span class = "p" > ,< / span > < span class = "n" > offsets< / span > < span class = "p" > )< / span >
< span class = "n" > x_keep< / span > < span class = "o" > =< / span > < span class = "n" > random< / span > < span class = "o" > > < / span > < span class = "n" > p< / span >
< span class = "c1" > # write-back< / span >
< span class = "n" > output< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > where< / span > < span class = "p" > (< / span > < span class = "n" > x_keep< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "o" > /< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "o" > -< / span > < span class = "n" > p< / span > < span class = "p" > ),< / span > < span class = "mf" > 0.0< / span > < span class = "p" > )< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > output_ptr< / span > < span class = "o" > +< / span > < span class = "n" > offsets< / span > < span class = "p" > ,< / span > < span class = "n" > output< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > mask< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > seeded_dropout< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "p" > ,< / span > < span class = "n" > seed< / span > < span class = "p" > ):< / span >
< span class = "n" > output< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > is_contiguous< / span > < span class = "p" > ()< / span >
< span class = "n" > n_elements< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > meta< / span > < span class = "p" > :< / span > < span class = "p" > (< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "n" > meta< / span > < span class = "p" > [< / span > < span class = "s1" > ' BLOCK_SIZE' < / span > < span class = "p" > ]),)< / span >
< span class = "n" > _seeded_dropout< / span > < span class = "p" > [< / span > < span class = "n" > grid< / span > < span class = "p" > ](< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > output< / span > < span class = "p" > ,< / span > < span class = "n" > n_elements< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "p" > ,< / span > < span class = "n" > seed< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "mi" > 1024< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > output< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 10< / span > < span class = "p" > ,))< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "c1" > # Compare this to the baseline - dropout mask is never instantiated!< / span >
< span class = "n" > output< / span > < span class = "o" > =< / span > < span class = "n" > seeded_dropout< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "mf" > 0.5< / span > < span class = "p" > ,< / span > < span class = "n" > seed< / span > < span class = "o" > =< / span > < span class = "mi" > 123< / span > < span class = "p" > )< / span >
< span class = "n" > output2< / span > < span class = "o" > =< / span > < span class = "n" > seeded_dropout< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "mf" > 0.5< / span > < span class = "p" > ,< / span > < span class = "n" > seed< / span > < span class = "o" > =< / span > < span class = "mi" > 123< / span > < span class = "p" > )< / span >
< span class = "n" > output3< / span > < span class = "o" > =< / span > < span class = "n" > seeded_dropout< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > p< / span > < span class = "o" > =< / span > < span class = "mf" > 0.5< / span > < span class = "p" > ,< / span > < span class = "n" > seed< / span > < span class = "o" > =< / span > < span class = "mi" > 512< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > tabulate< / span > < span class = "o" > .< / span > < span class = "n" > tabulate< / span > < span class = "p" > ([< / span >
< span class = "p" > [< / span > < span class = "s2" > " input" < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > tolist< / span > < span class = "p" > (),< / span >
< span class = "p" > [< / span > < span class = "s2" > " output (seed = 123)" < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "n" > output< / span > < span class = "o" > .< / span > < span class = "n" > tolist< / span > < span class = "p" > (),< / span >
< span class = "p" > [< / span > < span class = "s2" > " output (seed = 123)" < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "n" > output2< / span > < span class = "o" > .< / span > < span class = "n" > tolist< / span > < span class = "p" > (),< / span >
< span class = "p" > [< / span > < span class = "s2" > " output (seed = 512)" < / span > < span class = "p" > ]< / span > < span class = "o" > +< / span > < span class = "n" > output3< / span > < span class = "o" > .< / span > < span class = "n" > tolist< / span > < span class = "p" > ()< / span >
< span class = "p" > ]))< / span >
< / pre > < / div >
< / div >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > ------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
< / pre > < / div >
< / div >
< p > Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same!
If you’d like to explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the < cite > triton/language/random< / cite > folder!< / p >
< / div >
< div class = "section" id = "exercises" >
< h2 > Exercises< a class = "headerlink" href = "#exercises" title = "Permalink to this headline" > ¶< / a > < / h2 >
< ol class = "arabic simple" >
< li > < p > Extend the kernel to operate over a matrix and use a vector of seeds - one per row.< / p > < / li >
< li > < p > Add support for striding.< / p > < / li >
< li > < p > (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.< / p > < / li >
< / ol >
< / div >
< div class = "section" id = "references" >
< h2 > References< a class = "headerlink" href = "#references" title = "Permalink to this headline" > ¶< / a > < / h2 >
< dl class = "citation" >
< dt class = "label" id = "salmon2011" > < span class = "brackets" > < a class = "fn-backref" href = "#id2" > SALMON2011< / a > < / span > < / dt >
< dd > < p > John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011< / p >
< / dd >
< dt class = "label" id = "srivastava2014" > < span class = "brackets" > < a class = "fn-backref" href = "#id1" > SRIVASTAVA2014< / a > < / span > < / dt >
< dd > < p > Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014< / p >
< / dd >
< / dl >
2022-06-14 00:49:31 +00:00
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > ( 0 minutes 0.500 seconds)< / p >
2022-06-05 21:05:02 +00:00
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 04-low-memory-dropout.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/bc847dec325798bdc436c4ef5ac8b78a/04-low-memory-dropout.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 04-low-memory-dropout.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "05-layer-norm.html" class = "btn btn-neutral float-right" title = "Layer Normalization" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" aria-hidden = "true" > < / span > < / a >
< a href = "03-matrix-multiplication.html" class = "btn btn-neutral float-left" title = "Matrix Multiplication" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" aria-hidden = "true" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2020, Philippe Tillet.
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a
< a href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a >
provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< div class = "rst-versions" data-toggle = "rst-versions" role = "note" aria-label = "versions" >
< span class = "rst-current-version" data-toggle = "rst-current-version" >
< span class = "fa fa-book" > Other Versions< / span >
v: master
< span class = "fa fa-caret-down" > < / span >
< / span >
< div class = "rst-other-versions" >
< dl >
< dt > Tags< / dt >
< dd > < a href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd >
< / dl >
< dl >
< dt > Branches< / dt >
< dd > < a href = "04-low-memory-dropout.html" > master< / a > < / dd >
< / dl >
< / div >
< / div >
< script type = "text/javascript" >
// Hand jQuery a callback; jQuery invokes it once the DOM is ready.
jQuery(function () {
// Enable the theme's navigation. NOTE(review): the meaning of `true` is defined
// in theme.js (presumably the sticky-nav toggle) — confirm there.
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >