<!DOCTYPE html>
<html class="writer-html5" lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>Fused Softmax &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css"/>
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css"/>
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css"/>
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css"/>
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css"/>
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css"/>
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css"/>
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html"/>
<link rel="search" title="Search" href="../../search.html"/>
<link rel="next" title="Matrix Multiplication" href="03-matrix-multiplication.html"/>
<link rel="prev" title="Vector Addition" href="01-vector-add.html"/>
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
<a href="../../index.html" class="icon icon-home"> Triton</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs"/>
<input type="hidden" name="check_keywords" value="yes"/>
<input type="hidden" name="area" value="default"/>
</form>
</div>
</div>
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" role = "heading" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2" > < a class = "reference internal" href = "01-vector-add.html" > Vector Addition< / a > < / li >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Fused Softmax< / a > < ul >
< li class = "toctree-l3" > < a class = "reference internal" href = "#motivations" > Motivations< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#compute-kernel" > Compute Kernel< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#unit-test" > Unit Test< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#benchmark" > Benchmark< / a > < / li >
< / ul >
< / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "04-low-memory-dropout.html" > Low-Memory Dropout< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "05-layer-norm.html" > Layer Normalization< / a > < / li >
<li class="toctree-l2"><a class="reference internal" href="06-fused-attention.html">Fused Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="07-libdevice-function.html">Libdevice function</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
<li><a href="index.html">Tutorials</a> »</li>
<li>Fused Softmax</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/02-fused-softmax.rst.txt" rel="nofollow">View page source</a>
</li>
</ul>
<hr/>
</div>
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-02-fused-softmax-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "fused-softmax" >
< span id = "sphx-glr-getting-started-tutorials-02-fused-softmax-py" > < / span > < h1 > Fused Softmax< a class = "headerlink" href = "#fused-softmax" title = "Permalink to this headline" > ¶< / a > < / h1 >
<p>In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch’s native op for a particular class of matrices: those whose rows can fit in
the GPU’s SRAM.
You will learn about:</p>
<ul class="simple">
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
<li><p>Reduction operators in Triton.</p></li>
</ul>
< div class = "section" id = "motivations" >
< h2 > Motivations< a class = "headerlink" href = "#motivations" title = "Permalink to this headline" > ¶< / a > < / h2 >
< p > Custom GPU kernels for elementwise additions are educationally valuable but won’ t get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton.language< / span > < span class = "k" > as< / span > < span class = "nn" > tl< / span >
< span class = "nd" > @torch< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span > < span class = "o" > .< / span > < span class = "n" > script< / span >
< span class = "k" > def< / span > < span class = "nf" > naive_softmax< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Compute row-wise softmax of X using native pytorch< / span >
< span class = "sd" > We subtract the maximum element in order to avoid overflows. Softmax is invariant to< / span >
< span class = "sd" > this shift.< / span >
< span class = "sd" > " " " < / span >
< span class = "c1" > # read MN elements ; write M elements< / span >
< span class = "n" > x_max< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > max< / span > < span class = "p" > (< / span > < span class = "n" > dim< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )[< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "c1" > # read MN + M elements ; write MN elements< / span >
< span class = "n" > z< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > -< / span > < span class = "n" > x_max< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "c1" > # read MN elements ; write MN elements< / span >
< span class = "n" > numerator< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > z< / span > < span class = "p" > )< / span >
< span class = "c1" > # read MN elements ; write M elements< / span >
< span class = "n" > denominator< / span > < span class = "o" > =< / span > < span class = "n" > numerator< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > dim< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "c1" > # read MN + M elements ; write MN elements< / span >
< span class = "n" > ret< / span > < span class = "o" > =< / span > < span class = "n" > numerator< / span > < span class = "o" > /< / span > < span class = "n" > denominator< / span > < span class = "p" > [:,< / span > < span class = "kc" > None< / span > < span class = "p" > ]< / span >
< span class = "c1" > # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements< / span >
< span class = "k" > return< / span > < span class = "n" > ret< / span >
< / pre > < / div >
< / div >
<p>When implemented naively in PyTorch, computing <code class="code docutils literal notranslate"><span class="pre">y = naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span>
requires reading <span class="math notranslate nohighlight">\(5MN + 2M\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> elements each way, so we could
expect a theoretical speed-up of ~4x (i.e., <span class="math notranslate nohighlight">\((8MN + 4M) / 2MN\)</span>).
The <cite>torch.jit.script</cite> flag aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.</p>
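<p>To make that accounting concrete, here is a small back-of-the-envelope check (a sketch with hypothetical matrix dimensions, not part of the tutorial’s benchmark): it tallies the element reads and writes of <code class="code docutils literal notranslate"><span class="pre">naive_softmax</span></code> against those of an ideal fused kernel and prints the resulting speed-up ratio.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre># Hypothetical sizes, chosen only for illustration.
M, N = 4096, 1024

naive_reads = 5 * M * N + 2 * M   # elements read by naive_softmax
naive_writes = 3 * M * N + 2 * M  # elements written by naive_softmax
fused_reads = M * N               # an ideal fused kernel reads x once...
fused_writes = M * N              # ...and writes y once

speedup = (naive_reads + naive_writes) / (fused_reads + fused_writes)
print(f"theoretical speed-up: ~{speedup:.2f}x")  # ~4x for large N
</pre></div>
</div>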
</div>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally “pad” each row and guard the
memory operations properly if we want to handle any possible input shapes:</p>
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > jit< / span >
< span class = "k" > def< / span > < span class = "nf" > softmax_kernel< / span > < span class = "p" > (< / span >
< span class = "n" > output_ptr< / span > < span class = "p" > ,< / span > < span class = "n" > input_ptr< / span > < span class = "p" > ,< / span > < span class = "n" > input_row_stride< / span > < span class = "p" > ,< / span > < span class = "n" > output_row_stride< / span > < span class = "p" > ,< / span > < span class = "n" > n_cols< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "p" > :< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > constexpr< / span >
< span class = "p" > ):< / span >
< span class = "c1" > # The rows of the softmax are independent, so we parallelize across those< / span >
< span class = "n" > row_idx< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "c1" > # The stride represents how much we need to increase the pointer to advance 1 row< / span >
< span class = "n" > row_start_ptr< / span > < span class = "o" > =< / span > < span class = "n" > input_ptr< / span > < span class = "o" > +< / span > < span class = "n" > row_idx< / span > < span class = "o" > *< / span > < span class = "n" > input_row_stride< / span >
< span class = "c1" > # The block size is the next power of two greater than n_cols, so we can fit each< / span >
< span class = "c1" > # row in a single block< / span >
< span class = "n" > col_offsets< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > arange< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > )< / span >
< span class = "n" > input_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > row_start_ptr< / span > < span class = "o" > +< / span > < span class = "n" > col_offsets< / span >
< span class = "c1" > # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols< / span >
< span class = "n" > row< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > input_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > col_offsets< / span > < span class = "o" > < < / span > < span class = "n" > n_cols< / span > < span class = "p" > ,< / span > < span class = "n" > other< / span > < span class = "o" > =-< / span > < span class = "nb" > float< / span > < span class = "p" > (< / span > < span class = "s1" > ' inf' < / span > < span class = "p" > ))< / span >
< span class = "c1" > # Substract maximum for numerical stability< / span >
< span class = "n" > row_minus_max< / span > < span class = "o" > =< / span > < span class = "n" > row< / span > < span class = "o" > -< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > max< / span > < span class = "p" > (< / span > < span class = "n" > row< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "c1" > # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)< / span >
< span class = "n" > numerator< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > exp< / span > < span class = "p" > (< / span > < span class = "n" > row_minus_max< / span > < span class = "p" > )< / span >
< span class = "n" > denominator< / span > < span class = "o" > =< / span > < span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > numerator< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > softmax_output< / span > < span class = "o" > =< / span > < span class = "n" > numerator< / span > < span class = "o" > /< / span > < span class = "n" > denominator< / span >
< span class = "c1" > # Write back output to DRAM< / span >
< span class = "n" > output_row_start_ptr< / span > < span class = "o" > =< / span > < span class = "n" > output_ptr< / span > < span class = "o" > +< / span > < span class = "n" > row_idx< / span > < span class = "o" > *< / span > < span class = "n" > output_row_stride< / span >
< span class = "n" > output_ptrs< / span > < span class = "o" > =< / span > < span class = "n" > output_row_start_ptr< / span > < span class = "o" > +< / span > < span class = "n" > col_offsets< / span >
< span class = "n" > tl< / span > < span class = "o" > .< / span > < span class = "n" > store< / span > < span class = "p" > (< / span > < span class = "n" > output_ptrs< / span > < span class = "p" > ,< / span > < span class = "n" > softmax_output< / span > < span class = "p" > ,< / span > < span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > col_offsets< / span > < span class = "o" > < < / span > < span class = "n" > n_cols< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
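<p>As a sanity check on the masking logic, the snippet below (a plain PyTorch sketch, not Triton code) mimics what the kernel does to one row: it pads the row to the next power of two with <code class="code docutils literal notranslate"><span class="pre">-inf</span></code> and shows that the padded entries do not change the result, since <code class="code docutils literal notranslate"><span class="pre">exp(-inf) = 0</span></code> contributes nothing to the row sum.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>import torch

row = torch.randn(781)                    # a row whose length is not a power of two
BLOCK_SIZE = 1024                         # 1024 is the next power of two above 781
padded = torch.full((BLOCK_SIZE,), float('-inf'))
padded[:781] = row                        # same effect as the masked tl.load above

ref = torch.softmax(row, dim=0)
out = torch.softmax(padded, dim=0)[:781]  # padded tail maps to exp(-inf) = 0
assert torch.allclose(ref, out)
</pre></div>
</div>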
<p>We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.</p>
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > def< / span > < span class = "nf" > softmax< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ):< / span >
< span class = "n" > n_rows< / span > < span class = "p" > ,< / span > < span class = "n" > n_cols< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "c1" > # The block size is the smallest power of two greater than the number of columns in `x`< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > next_power_of_2< / span > < span class = "p" > (< / span > < span class = "n" > n_cols< / span > < span class = "p" > )< / span >
< span class = "c1" > # Another trick we can use is to ask the compiler to use more threads per row by< / span >
< span class = "c1" > # increasing the number of warps (`num_warps`) over which each row is distributed.< / span >
< span class = "c1" > # You will see in the next tutorial how to auto-tune this value in a more natural< / span >
< span class = "c1" > # way so you don' t have to come up with manual heuristics yourself.< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span >
< span class = "k" > if< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > > =< / span > < span class = "mi" > 2048< / span > < span class = "p" > :< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "mi" > 8< / span >
< span class = "k" > if< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "o" > > =< / span > < span class = "mi" > 4096< / span > < span class = "p" > :< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "mi" > 16< / span >
< span class = "c1" > # Allocate output< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "c1" > # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o< / span >
< span class = "c1" > # f the input matrix< / span >
< span class = "n" > softmax_kernel< / span > < span class = "p" > [(< / span > < span class = "n" > n_rows< / span > < span class = "p" > ,)](< / span >
< span class = "n" > y< / span > < span class = "p" > ,< / span >
< span class = "n" > x< / span > < span class = "p" > ,< / span >
< span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span >
< span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > stride< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span >
< span class = "n" > n_cols< / span > < span class = "p" > ,< / span >
< span class = "n" > num_warps< / span > < span class = "o" > =< / span > < span class = "n" > num_warps< / span > < span class = "p" > ,< / span >
< span class = "n" > BLOCK_SIZE< / span > < span class = "o" > =< / span > < span class = "n" > BLOCK_SIZE< / span > < span class = "p" > ,< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > y< / span >
< / pre > < / div >
< / div >
</div>
<div class="section" id="unit-test">
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
<p>We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.</p>
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > manual_seed< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "mi" > 1823< / span > < span class = "p" > ,< / span > < span class = "mi" > 781< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > )< / span >
< span class = "n" > y_triton< / span > < span class = "o" > =< / span > < span class = "n" > softmax< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "n" > y_torch< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > softmax< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > allclose< / span > < span class = "p" > (< / span > < span class = "n" > y_triton< / span > < span class = "p" > ,< / span > < span class = "n" > y_torch< / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "n" > y_triton< / span > < span class = "p" > ,< / span > < span class = "n" > y_torch< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
<p>As expected, the results are identical.</p>
</div>
< div class = "section" id = "benchmark" >
< h2 > Benchmark< a class = "headerlink" href = "#benchmark" title = "Permalink to this headline" > ¶< / a > < / h2 >
< p > Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) < code class = "code docutils literal notranslate" > < span class = "pre" > torch.softmax< / span > < / code > and (2) the < code class = "code docutils literal notranslate" > < span class = "pre" > naive_softmax< / span > < / code > defined above.< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > perf_report< / span > < span class = "p" > (< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > Benchmark< / span > < span class = "p" > (< / span >
< span class = "n" > x_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' N' < / span > < span class = "p" > ],< / span > < span class = "c1" > # argument names to use as an x-axis for the plot< / span >
< span class = "n" > x_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span >
< span class = "mi" > 128< / span > < span class = "o" > *< / span > < span class = "n" > i< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 100< / span > < span class = "p" > )< / span >
< span class = "p" > ],< / span > < span class = "c1" > # different possible values for `x_name`< / span >
< span class = "n" > line_arg< / span > < span class = "o" > =< / span > < span class = "s1" > ' provider' < / span > < span class = "p" > ,< / span > < span class = "c1" > # argument name whose value corresponds to a different line in the plot< / span >
< span class = "n" > line_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span >
< span class = "s1" > ' triton' < / span > < span class = "p" > ,< / span >
< span class = "s1" > ' torch-native' < / span > < span class = "p" > ,< / span >
< span class = "s1" > ' torch-jit' < / span > < span class = "p" > ,< / span >
< span class = "p" > ],< / span > < span class = "c1" > # possible values for `line_arg``< / span >
< span class = "n" > line_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span >
< span class = "s2" > " Triton" < / span > < span class = "p" > ,< / span >
< span class = "s2" > " Torch (native)" < / span > < span class = "p" > ,< / span >
< span class = "s2" > " Torch (jit)" < / span > < span class = "p" > ,< / span >
< span class = "p" > ],< / span > < span class = "c1" > # label name for the lines< / span >
< span class = "n" > styles< / span > < span class = "o" > =< / span > < span class = "p" > [(< / span > < span class = "s1" > ' blue' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > ' green' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' -' < / span > < span class = "p" > ),< / span > < span class = "p" > (< / span > < span class = "s1" > ' green' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' --' < / span > < span class = "p" > )],< / span > < span class = "c1" > # line styles< / span >
< span class = "n" > ylabel< / span > < span class = "o" > =< / span > < span class = "s2" > " GB/s" < / span > < span class = "p" > ,< / span > < span class = "c1" > # label name for the y-axis< / span >
< span class = "n" > plot_name< / span > < span class = "o" > =< / span > < span class = "s2" > " softmax-performance" < / span > < span class = "p" > ,< / span > < span class = "c1" > # name for the plot. Used also as a file name for saving the plot.< / span >
< span class = "n" > args< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s1" > ' M' < / span > < span class = "p" > :< / span > < span class = "mi" > 4096< / span > < span class = "p" > },< / span > < span class = "c1" > # values for function arguments not in `x_names` and `y_name`< / span >
< span class = "p" > )< / span >
< span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > benchmark< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > provider< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "n" > M< / span > < span class = "p" > ,< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' torch-native' < / span > < span class = "p" > :< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > softmax< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =-< / span > < span class = "mi" > 1< / span > < span class = "p" > ))< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > :< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > softmax< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ))< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' torch-jit' < / span > < span class = "p" > :< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > naive_softmax< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ))< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > nelement< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > element_size< / span > < span class = "p" > ()< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-9< / span > < span class = "o" > /< / span > < span class = "p" > (< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-3< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > max_ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > min_ms< / span > < span class = "p" > )< / span >
< span class = "n" > benchmark< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span > < span class = "n" > show_plots< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span > < span class = "n" > print_data< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
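<p>The <code class="code docutils literal notranslate"><span class="pre">gbps</span></code> lambda above converts a measured runtime into an effective memory bandwidth: a fused kernel ideally moves each element of <code class="code docutils literal notranslate"><span class="pre">x</span></code> twice (one read, one write), hence the factor of 2. As a quick illustration of the formula (the numbers below are made up for the example, not measured):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre># Hypothetical example: M = 4096 rows, N = 8192 columns of float32, timed at 0.5 ms.
M, N = 4096, 8192
bytes_moved = 2 * M * N * 4            # one read + one write, 4 bytes per float32
ms = 0.5                               # hypothetical measured runtime in milliseconds
gbps = bytes_moved * 1e-9 / (ms * 1e-3)
print(f"{gbps:.1f} GB/s")              # ~536.9 GB/s for these made-up numbers
</pre></div>
</div>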
<img alt="02 fused softmax" class="sphx-glr-single-img" src="../../_images/sphx_glr_02-fused-softmax_001.png"/>
<p class="sphx-glr-script-out">Out:</p>
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > softmax-performance:
N Triton Torch (native) Torch (jit)
2022-09-02 00:49:33 +00:00
0 256.0 546.133347 512.000001 188.321838
2022-08-31 00:55:11 +00:00
1 384.0 614.400016 585.142862 153.600004
2022-09-02 00:49:33 +00:00
2 512.0 655.360017 585.142849 154.566038
3 640.0 706.206879 640.000002 160.000000
2022-08-31 00:55:11 +00:00
4 768.0 722.823517 664.216187 162.754967
2022-06-05 21:05:02 +00:00
.. ... ... ... ...
2022-09-02 00:49:33 +00:00
93 12160.0 812.359066 406.179533 198.530610
94 12288.0 812.429770 415.661740 198.895304
95 12416.0 812.498981 412.149375 198.457532
96 12544.0 810.925276 412.971190 198.716830
97 12672.0 811.007961 412.097543 198.776477
2022-06-05 21:05:02 +00:00
[98 rows x 4 columns]
< / pre > < / div >
< / div >
<p>In the above plot, we can see that:</p>
<blockquote>
<div><ul class="simple">
<li><p>Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not do any fusion here.</p></li>
<li><p>Triton is noticeably faster than <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> – in addition to being <strong>easier to read, understand and maintain</strong>.
Note, however, that the PyTorch <cite>softmax</cite> operation is more general and works on tensors of any shape (see the dispatch sketch below).</p></li>
</ul>
</div></blockquote>
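<p>Since the Triton kernel assumes a 2D input whose rows fit in a single block, a practical deployment might fall back to <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> whenever that assumption does not hold. Below is a minimal sketch of such a wrapper, reusing the <code class="code docutils literal notranslate"><span class="pre">softmax</span></code> helper defined in this tutorial; the size threshold is an arbitrary assumption, not something prescribed by Triton.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre># Hypothetical wrapper: use the fused kernel when it applies, otherwise defer to PyTorch.
MAX_FUSED_BLOCK = 65536  # arbitrary upper bound on BLOCK_SIZE; tune for your GPU


def softmax_dispatch(x):
    if x.ndim == 2 and triton.next_power_of_2(x.shape[1]) &lt;= MAX_FUSED_BLOCK:
        return softmax(x)            # fused Triton kernel from this tutorial
    return torch.softmax(x, dim=-1)  # general fallback for other shapes
</pre></div>
</div>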
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > ( 3 minutes 31.344 seconds)< / p >
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-02-fused-softmax-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 02-fused-softmax.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/034d953b6214fedce6ea03803c712b89/02-fused-softmax.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 02-fused-softmax.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< / div >
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="03-matrix-multiplication.html" class="btn btn-neutral float-right" title="Matrix Multiplication" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
© Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
< div class = "rst-versions" data-toggle = "rst-versions" role = "note" aria-label = "versions" >
< span class = "rst-current-version" data-toggle = "rst-current-version" >
< span class = "fa fa-book" > Other Versions< / span >
v: master
< span class = "fa fa-caret-down" > < / span >
< / span >
< div class = "rst-other-versions" >
< dl >
< dt > Tags< / dt >
< dd > < a href = "../../../v1.1.2/index.html" > v1.1.2< / a > < / dd >
< / dl >
< dl >
< dt > Branches< / dt >
< dd > < a href = "02-fused-softmax.html" > master< / a > < / dd >
< / dl >
< / div >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >