[GH-PAGES] Updated website

Philippe Tillet
2021-08-12 00:39:35 +00:00
parent 3f6d8e2afa
commit 7d91e06e08
19 changed files with 101 additions and 104 deletions


@@ -219,24 +219,24 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># read MN elements ; write M elements</span>
<span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># read 2MN elements ; write MN elements</span>
<span class="c1"># read MN + M elements ; write MN elements</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># read MN elements ; write MN elements</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="c1"># read MN elements ; write M elements</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">numerator</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># read 2MN elements ; write MN elements</span>
<span class="c1"># read MN + M elements ; write MN elements</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># in total: read 7MN elements ; wrote 3MN + 2M elements</span>
<span class="c1"># in total: read 5MN + 2M elements ; wrote 3MN + 2M elements</span>
<span class="k">return</span> <span class="n">ret</span>
</pre></div>
</div>
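As a quick sanity check of the tally in the comments above, the per-operation reads and writes can be summed for arbitrary sizes; the M and N below are hypothetical, chosen only for illustration:

M, N = 1823, 781   # hypothetical sizes, only to verify the element counts

reads = (M * N           # x.max(dim=1)
         + M * N + M     # x - x_max[:, None]
         + M * N         # torch.exp(z)
         + M * N         # numerator.sum(dim=1)
         + M * N + M)    # numerator / denominator[:, None]
writes = M + M * N + M * N + M + M * N

assert reads == 5 * M * N + 2 * M
assert writes == 3 * M * N + 2 * M

# A fused kernel would read MN and write MN elements, hence the ~4x estimate:
print((reads + writes) / (2 * M * N))  # ~4.0 for large N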
 When implemented naively in PyTorch, computing y = naive_softmax(x) for x ∈ R^{M×N}
-requires reading 7MN elements from DRAM and writing back 3MN + 2M elements.
+requires reading 5MN + 2M elements from DRAM and writing back 3MN + 2M elements.
 This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
 X once and does all the necessary computations on-chip.
 Doing so would require reading and writing back only MN elements each way, so we could
-expect a theoretical speed-up of ~5x (i.e., (10MN + 2M) / 2MN).
+expect a theoretical speed-up of ~4x (i.e., (8MN + 4M) / 2MN).
 The torch.jit.script decorator aims to perform this kind of "kernel fusion" automatically
 but, as we will see later, it is still far from ideal.
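To make the fusion concrete, here is a minimal sketch of what such a fused softmax kernel can look like in Triton: each program instance loads one row into on-chip memory, normalizes it, and writes it back, so every element of X crosses DRAM exactly twice (one read, one write). This is only a sketch written against the triton.language (tl) API, not necessarily the exact kernel developed later in the tutorial, and it assumes BLOCK_SIZE is a power of two at least as large as the number of columns.

import triton
import triton.language as tl

@triton.jit
def fused_softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program instance per row.
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # Single read of the row from DRAM; out-of-bounds lanes get -inf so they
    # do not affect the max and contribute exp(-inf) = 0 to the sum.
    row = tl.load(input_ptr + row_idx * row_stride + col_offsets, mask=mask, other=-float('inf'))
    # Numerically stabilized softmax, computed entirely on-chip.
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    # Single write of the result back to DRAM.
    tl.store(output_ptr + row_idx * row_stride + col_offsets, numerator / denominator, mask=mask)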
@@ -391,11 +391,11 @@ We will then compare its performance against (1) <code class="code docutils lite
  3    640.0  682.666684  640.000002  160.000000
  4    768.0  702.171410  664.216187  163.839992
 ..      ...         ...         ...         ...
-93  12160.0  812.359066  405.755985  199.038365
-94  12288.0  812.429770  415.661740  199.298541
-95  12416.0  810.840807  412.149375  198.954424
+93  12160.0  812.359066  406.179533  198.936606
+94  12288.0  812.429770  416.101597  199.298541
+95  12416.0  810.840807  412.149375  198.854847
 96  12544.0  810.925276  412.971190  199.209928
-97  12672.0  809.389265  412.097543  199.167004
+97  12672.0  811.007961  412.097543  199.167004
 [98 rows x 4 columns]
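Assuming, as in the surrounding tutorial, that each column reports the effective memory bandwidth of one implementation in GB/s as the row length N grows, an entry can be derived from a measured runtime by counting the bytes that must cross DRAM: MN elements read plus MN written. A hypothetical helper for that conversion:

# Hypothetical helper: convert a measured runtime in milliseconds into an
# effective bandwidth in GB/s, assuming MN float32 elements are read and
# MN are written exactly once.
def effective_gbps(ms: float, M: int, N: int, bytes_per_element: int = 4) -> float:
    bytes_moved = 2 * M * N * bytes_per_element  # one read + one write per element
    return bytes_moved * 1e-9 / (ms * 1e-3)

# Example: a 4096 x 12288 float32 input processed in 1.0 ms is ~403 GB/s.
print(effective_gbps(1.0, 4096, 12288))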
@@ -403,13 +403,12 @@ We will then compare its performance against (1) <code class="code docutils lite
 In the above plot, we can see that:

-  - Triton is 2-3x faster than the Torch JIT.
-  - Triton is even faster than torch.softmax. My guess from looking at the source code of the PyTorch kernel
-    (https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240)
-    is that PyTorch only partially fuses the computation of the softmax.
-    This means that when temporary data is too large to fit entirely in the GPU's cache, it transfers almost twice the amount of memory necessary.
-    Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also easier to read, understand and maintain.
+  - Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not do any fusion here (a sketch of that baseline follows this list).
+  - Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand and maintain.
+    Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
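For reference, the Torch JIT baseline mentioned in the first bullet can be obtained by decorating the naive implementation with torch.jit.script. The sketch below simply repeats the naive_softmax steps from earlier under a hypothetical name; whether the tutorial's benchmark wraps it in exactly this way is an assumption here.

import torch

@torch.jit.script
def naive_softmax_jit(x):
    # Same element-wise steps as the naive PyTorch version above. TorchScript
    # compiles the function, but the benchmark suggests it does not fuse these
    # operations, so each intermediate still round-trips through DRAM.
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]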
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.602 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.617 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>