<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Fused Softmax — Triton documentation</title>
</head>
<body>
<h1>Fused Softmax</h1>
<p>In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch’s native op for a particular class of matrices: those whose rows can fit in
the GPU’s SRAM. You will learn about:</p>
<ul>
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
<li><p>Reduction operators in Triton.</p></li>
</ul>
<h2>Motivations</h2>
<p>Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:</p>
<pre>
import torch

import triton
import triton.language as tl


@torch.jit.script
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
</pre>
<p>When implemented naively in PyTorch, computing <code>y = naive_softmax(x)</code> for \(x \in R^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading X and writing Y back only once, i.e., \(MN\) elements each, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The <cite>torch.jit.script</cite> flag aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.</p>
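<p>To make the arithmetic above concrete, here is a small (hypothetical) helper that tallies the
DRAM traffic of the naive implementation against that of a fused kernel; the element counts simply
restate the comments in <code>naive_softmax</code>:</p>
<pre>
def theoretical_speedup(M, N):
    # naive: 5MN + 2M elements read and 3MN + 2M elements written
    naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)
    # fused: read X once and write Y once (MN elements each)
    fused_traffic = 2 * M * N
    return naive_traffic / fused_traffic

print(theoretical_speedup(4096, 4096))  # ~4.0
</pre>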
<h2>Compute Kernel</h2>
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally “pad” each row and guard the
memory operations properly if we want to handle any possible input shapes:</p>
<pre>
@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
</pre>
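<p>A brief aside on why padding with <code>-inf</code> is safe: after the row maximum is subtracted,
the masked-off lanes become <code>exp(-inf) = 0</code>, so they contribute nothing to the row sum.
A minimal PyTorch sketch of that idea (not part of the tutorial code, just an illustration):</p>
<pre>
import torch

row = torch.tensor([1.0, 2.0, float('-inf'), float('-inf')])  # 2 real columns padded to 4
numerator = torch.exp(row - row.max())
print(numerator / numerator.sum())  # padded entries are exactly 0; the rest match a 2-element softmax
</pre>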
<p>We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.</p>
<pre>
def softmax(x):
    n_rows, n_cols = x.shape
    # The block size is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of the input matrix
    softmax_kernel[(n_rows,)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_cols,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y
</pre>
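<p>For a concrete feel of the padding, <code>triton.next_power_of_2</code> rounds the column count
up to the nearest power of two, which becomes the kernel's <code>BLOCK_SIZE</code>; a quick sketch
(assuming the imports above are in scope):</p>
<pre>
for n_cols in (781, 1024, 1823):
    print(n_cols, '->', triton.next_power_of_2(n_cols))
# 781  -> 1024
# 1024 -> 1024
# 1823 -> 2048
</pre>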
<h2>Unit Test</h2>
<p>We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.</p>
<pre>
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
</pre>
<p>As expected, the results are identical.</p>
<h2>Benchmark</h2>
<p>Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) <code>torch.softmax</code> and (2) the <code>naive_softmax</code> defined above.</p>
<pre>
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[
            128 * i for i in range(2, 100)
        ],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            'triton',
            'torch-native',
            'torch-jit',
        ],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (jit)",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    if provider == 'torch-native':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'torch-jit':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(show_plots=True, print_data=True)
</pre>
<img alt="02 fused softmax" class="sphx-glr-single-img" src="../../_images/sphx_glr_02-fused-softmax_001.png" />
<p>Out:</p>
<pre>
softmax-performance:
          N      Triton  Torch (native)  Torch (jit)
0     256.0  546.133347      546.133347   188.321838
1     384.0  614.400016      585.142862   153.600004
2     512.0  655.360017      606.814814   154.566038
3     640.0  706.206879      640.000002   160.000000
4     768.0  722.823517      664.216187   162.754967
..      ...         ...             ...          ...
93  12160.0  812.359066      406.179533   198.733401
94  12288.0  812.429770      415.661740   198.995960
95  12416.0  812.498981      412.149375   198.655991
96  12544.0  810.925276      412.546756   198.864492
97  12672.0  811.007961      412.097543   198.971549

[98 rows x 4 columns]
</pre>
<p>In the above plot, we can see that:</p>
<ul>
<li><p>Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.</p></li>
<li><p>Triton is noticeably faster than <code>torch.softmax</code> – in addition to being <strong>easier to read, understand and maintain</strong>.
Note however that the PyTorch <cite>softmax</cite> operation is more general and works on tensors of any shape.</p></li>
</ul>
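<p>To relate the bandwidth numbers back to wall-clock time, a single row of the table can be
converted to milliseconds using the same <code>gbps</code> formula as the benchmark. A small
sketch using the N = 12288 row (the exact figure depends on the GPU used, which is not specified here):</p>
<pre>
M, N = 4096, 12288
bytes_moved = 2 * M * N * 4              # one fp32 read plus one fp32 write per element
ms = bytes_moved * 1e-9 / 812.4 * 1e3    # GB divided by GB/s, converted to milliseconds
print(ms)                                # ~0.50 ms per softmax call for the Triton kernel
</pre>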
<p><strong>Total running time of the script:</strong> (3 minutes 30.087 seconds)</p>
</body>
</html> |