[GH-PAGES] Updated website

Philippe Tillet
2021-08-12 00:39:35 +00:00
parent 3f6d8e2afa
commit 7d91e06e08
19 changed files with 101 additions and 104 deletions


@@ -33,14 +33,14 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import torch\n\n\n@torch.jit.script\ndef naive_softmax(x):\n \"\"\"Compute row-wise softmax of X using native pytorch\n\n We subtract the maximum element in order to avoid overflows. Softmax is invariant to\n this shift.\n \"\"\"\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read 2MN elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(z)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read 2MN elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 7MN elements ; wrote 3MN + 2M elements\n return ret" "import torch\n\n\n@torch.jit.script\ndef naive_softmax(x):\n \"\"\"Compute row-wise softmax of X using native pytorch\n\n We subtract the maximum element in order to avoid overflows. Softmax is invariant to\n this shift.\n \"\"\"\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read MN + M elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(z)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read MN + M elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements\n return ret"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$\nrequires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads\nX once and does all the necessary computations on-chip.\nDoing so would require reading and writing back only $MN$ bytes, so we could\nexpect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nThe `torch.jit.script` flags aims to perform this kind of \"kernel fusion\" automatically\nbut, as we will see later, it is still far from ideal.\n\n" "When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$\nrequires reading $5MN + 2M$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads\nX once and does all the necessary computations on-chip.\nDoing so would require reading and writing back only $MN$ bytes, so we could\nexpect a theoretical speed-up of ~4x (i.e., $(8MN + 4M) / 2MN$).\nThe `torch.jit.script` flags aims to perform this kind of \"kernel fusion\" automatically\nbut, as we will see later, it is still far from ideal.\n\n"
] ]
}, },
{ {
@@ -133,7 +133,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"In the above plot, we can see that:\n\n - Triton is 2-3x faster than the Torch JIT.\n - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.\n This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.\n Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.\n\n" "In the above plot, we can see that:\n\n - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.\n - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. \n Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.\n\n"
] ]
} }
], ],
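The hunks above correct the notebook's DRAM-traffic accounting (total reads drop from 7MN to 5MN + 2M elements) and the advertised theoretical speed-up from ~5x to ~4x. As a quick sanity check of the new numbers (not part of the commit; the shape is an arbitrary example), the totals can be tallied directly:

    # Tally the per-step element counts from the updated comments for an M x N input.
    M, N = 4096, 1024  # arbitrary example shape

    naive_reads = 5 * M * N + 2 * M    # max: MN, subtract: MN+M, exp: MN, sum: MN, divide: MN+M
    naive_writes = 3 * M * N + 2 * M   # x_max: M, z: MN, numerator: MN, denominator: M, ret: MN
    fused_traffic = 2 * M * N          # a fused kernel reads X once and writes the result once

    speedup = (naive_reads + naive_writes) / fused_traffic
    print(f"naive traffic : {naive_reads + naive_writes} elements")
    print(f"fused traffic : {fused_traffic} elements")
    print(f"theoretical speed-up: {speedup:.2f}x")   # (8MN + 4M) / 2MN, about 4x for large N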


@@ -28,25 +28,25 @@ def naive_softmax(x):
""" """
# read MN elements ; write M elements # read MN elements ; write M elements
x_max = x.max(dim=1)[0] x_max = x.max(dim=1)[0]
# read 2MN elements ; write MN elements # read MN + M elements ; write MN elements
z = x - x_max[:, None] z = x - x_max[:, None]
# read MN elements ; write MN elements # read MN elements ; write MN elements
numerator = torch.exp(z) numerator = torch.exp(z)
# read MN elements ; write M elements # read MN elements ; write M elements
denominator = numerator.sum(dim=1) denominator = numerator.sum(dim=1)
# read 2MN elements ; write MN elements # read MN + M elements ; write MN elements
ret = numerator / denominator[:, None] ret = numerator / denominator[:, None]
# in total: read 7MN elements ; wrote 3MN + 2M elements # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret return ret
# %% # %%
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` # When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
# requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements. # requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads # This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
# X once and does all the necessary computations on-chip. # X once and does all the necessary computations on-chip.
# Doing so would require reading and writing back only :math:`MN` bytes, so we could # Doing so would require reading and writing back only :math:`MN` bytes, so we could
# expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`). # expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
# The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically # The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
# but, as we will see later, it is still far from ideal. # but, as we will see later, it is still far from ideal.
@@ -200,7 +200,6 @@ benchmark.run(show_plots=True, print_data=True)
# %% # %%
# In the above plot, we can see that: # In the above plot, we can see that:
# #
# - Triton is 2-3x faster than the Torch JIT. # - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax. # - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary. # Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
# Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
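The revised conclusion notes that :code:`torch.softmax` is more general and works on tensors of any shape, whereas the tutorial's row-wise kernel targets 2-D inputs. For readers following along, a small self-contained check (illustrative only, not part of this commit; the shape and tolerance are arbitrary) confirms that the stabilized naive_softmax above matches torch.softmax along dim=1:

    import torch

    def naive_softmax(x):
        # same numerically-stabilized row-wise softmax as in the tutorial above
        x_max = x.max(dim=1)[0]
        z = x - x_max[:, None]
        numerator = torch.exp(z)
        denominator = numerator.sum(dim=1)
        return numerator / denominator[:, None]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1823, 781, device=device)  # arbitrary, deliberately non-power-of-2 shape
    assert torch.allclose(naive_softmax(x), torch.softmax(x, dim=1), atol=1e-6)
    print("naive_softmax matches torch.softmax(x, dim=1)")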

6 binary image files changed (diffs not shown; sizes before → after: 25 → 25 KiB, 16 → 16 KiB, 37 → 37 KiB, 24 → 24 KiB, 55 → 55 KiB, 32 → 31 KiB).


@@ -254,7 +254,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
.. rst-class:: sphx-glr-timing .. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 10.981 seconds) **Total running time of the script:** ( 0 minutes 10.994 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py: .. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:


@@ -52,15 +52,15 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
""" """
# read MN elements ; write M elements # read MN elements ; write M elements
x_max = x.max(dim=1)[0] x_max = x.max(dim=1)[0]
# read 2MN elements ; write MN elements # read MN + M elements ; write MN elements
z = x - x_max[:, None] z = x - x_max[:, None]
# read MN elements ; write MN elements # read MN elements ; write MN elements
numerator = torch.exp(z) numerator = torch.exp(z)
# read MN elements ; write M elements # read MN elements ; write M elements
denominator = numerator.sum(dim=1) denominator = numerator.sum(dim=1)
# read 2MN elements ; write MN elements # read MN + M elements ; write MN elements
ret = numerator / denominator[:, None] ret = numerator / denominator[:, None]
# in total: read 7MN elements ; wrote 3MN + 2M elements # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret return ret
@@ -74,11 +74,11 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
.. GENERATED FROM PYTHON SOURCE LINES 44-52 .. GENERATED FROM PYTHON SOURCE LINES 44-52
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements. requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip. X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` bytes, so we could Doing so would require reading and writing back only :math:`MN` bytes, so we could
expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`). expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
but, as we will see later, it is still far from ideal. but, as we will see later, it is still far from ideal.
@@ -306,30 +306,29 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
3 640.0 682.666684 640.000002 160.000000 3 640.0 682.666684 640.000002 160.000000
4 768.0 702.171410 664.216187 163.839992 4 768.0 702.171410 664.216187 163.839992
.. ... ... ... ... .. ... ... ... ...
93 12160.0 812.359066 405.755985 199.038365 93 12160.0 812.359066 406.179533 198.936606
94 12288.0 812.429770 415.661740 199.298541 94 12288.0 812.429770 416.101597 199.298541
95 12416.0 810.840807 412.149375 198.954424 95 12416.0 810.840807 412.149375 198.854847
96 12544.0 810.925276 412.971190 199.209928 96 12544.0 810.925276 412.971190 199.209928
97 12672.0 809.389265 412.097543 199.167004 97 12672.0 811.007961 412.097543 199.167004
[98 rows x 4 columns] [98 rows x 4 columns]
.. GENERATED FROM PYTHON SOURCE LINES 201-207 .. GENERATED FROM PYTHON SOURCE LINES 201-206
In the above plot, we can see that: In the above plot, we can see that:
- Triton is 2-3x faster than the Torch JIT. - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
- Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax. - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary. Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
.. rst-class:: sphx-glr-timing .. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 12.602 seconds) **Total running time of the script:** ( 1 minutes 12.617 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py: .. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
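The columns retouched in the table above are the bandwidth figures (GB/s) reported for the Triton, native torch.softmax and Torch-JIT versions as the row length N grows. A hedged sketch of the conversion, assuming the benchmark charges each version for one read plus one write of the fp32 input (the convention used for this plot), shows how a runtime in milliseconds maps to such a number:

    # Hedged sketch (not from the commit): reproducing a GB/s entry from a runtime.
    # Assumes one read plus one write of the M x N fp32 tensor, i.e. 2 * M * N * 4 bytes.
    def to_gbps(ms: float, M: int, N: int, bytes_per_elem: int = 4) -> float:
        bytes_moved = 2 * M * N * bytes_per_elem
        return bytes_moved * 1e-9 / (ms * 1e-3)   # bytes -> GB, ms -> s

    # Example: 4096 rows of length 12288 processed in 0.5 ms is ~805 GB/s,
    # the same ballpark as the Triton column near the bottom of the table.
    print(f"{to_gbps(0.5, 4096, 12288):.1f} GB/s")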


@@ -471,37 +471,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
matmul-performance: matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU) M cuBLAS ... Triton Triton (+ LeakyReLU)
0 128.0 0.455111 ... 0.512000 0.512000 0 128.0 0.455111 ... 0.512000 0.512000
1 256.0 2.978909 ... 3.276800 2.978909 1 256.0 2.730667 ... 2.978909 2.978909
2 384.0 7.372800 ... 8.507077 8.507077 2 384.0 7.372800 ... 8.507077 8.507077
3 512.0 14.563555 ... 16.384000 15.420235 3 512.0 14.563555 ... 15.420235 16.384000
4 640.0 22.260869 ... 24.380953 24.380953 4 640.0 22.260869 ... 24.380953 23.272727
5 768.0 32.768000 ... 34.028308 34.028308 5 768.0 32.768000 ... 34.028308 34.028308
6 896.0 39.025776 ... 40.140799 39.025776 6 896.0 39.025776 ... 40.140799 39.025776
7 1024.0 49.932191 ... 53.773130 52.428801 7 1024.0 49.932191 ... 53.773130 52.428801
8 1152.0 45.242181 ... 46.656000 46.656000 8 1152.0 45.242181 ... 46.656000 46.656000
9 1280.0 51.200001 ... 56.888887 56.109587 9 1280.0 51.200001 ... 56.888887 56.888887
10 1408.0 64.138541 ... 64.902096 64.902096 10 1408.0 64.138541 ... 64.902096 64.902096
11 1536.0 80.430545 ... 76.106321 76.106321 11 1536.0 78.643199 ... 76.106321 75.296679
12 1664.0 63.372618 ... 62.061463 62.061463 12 1664.0 62.929456 ... 62.061463 62.061463
13 1792.0 72.983276 ... 69.810085 69.810085 13 1792.0 72.983276 ... 69.810085 69.379162
14 1920.0 69.120002 ... 70.172588 69.120002 14 1920.0 67.434145 ... 70.892307 70.530615
15 2048.0 73.584279 ... 74.898285 73.584279 15 2048.0 73.908442 ... 74.898285 74.565406
16 2176.0 82.813365 ... 78.916269 79.540109 16 2176.0 83.500614 ... 78.916269 79.855747
17 2304.0 68.056616 ... 73.275679 73.275679 17 2304.0 68.251065 ... 73.275679 72.828879
18 2432.0 71.125224 ... 81.197876 81.908060 18 2432.0 71.125224 ... 80.731218 80.731218
19 2560.0 77.649287 ... 76.560748 75.676673 19 2560.0 77.649287 ... 76.560748 76.382283
20 2688.0 84.108772 ... 81.053536 84.108772 20 2688.0 81.928846 ... 80.366642 82.823267
21 2816.0 80.469019 ... 79.298560 78.726003 21 2816.0 77.743683 ... 78.868366 78.301990
22 2944.0 81.832567 ... 79.737653 79.104810 22 2944.0 81.832567 ... 79.610276 78.605729
23 3072.0 82.540970 ... 80.890151 83.146995 23 3072.0 81.005868 ... 81.005868 82.420822
24 3200.0 84.210524 ... 89.260810 87.791493 24 3200.0 84.321474 ... 89.635851 85.106381
25 3328.0 83.130825 ... 86.736504 82.275764 25 3328.0 83.226931 ... 87.156532 86.113988
26 3456.0 78.578525 ... 83.459178 85.767626 26 3456.0 81.932484 ... 83.632331 85.313831
27 3584.0 87.808000 ... 92.410473 95.148565 27 3584.0 87.211821 ... 87.211821 91.563533
28 3712.0 85.601834 ... 85.601834 88.876645 28 3712.0 85.896254 ... 82.491612 84.874549
29 3840.0 84.615146 ... 86.875096 87.148936 29 3840.0 85.070769 ... 87.493673 87.701820
30 3968.0 92.512459 ... 83.807647 83.520835 30 3968.0 92.935215 ... 83.865247 83.578035
31 4096.0 93.596744 ... 90.748973 90.321484 31 4096.0 93.662059 ... 85.926841 84.840533
[32 rows x 5 columns] [32 rows x 5 columns]
@@ -511,7 +511,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
.. rst-class:: sphx-glr-timing .. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 2 minutes 16.405 seconds) **Total running time of the script:** ( 2 minutes 9.226 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py: .. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
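The matmul-performance figures updated above are throughput numbers, consistent with TFLOPS for the square problems the benchmark sweeps. A hedged sketch of that conversion, assuming the usual 2*M*N*K flop count for C = A @ B:

    # Hedged sketch (not from the commit): converting a measured matmul runtime
    # into a TFLOPS-style figure like those in the table. Assumes square problems,
    # as the single M column suggests, and 2 * M * N * K flops for C = A @ B.
    def to_tflops(ms: float, M: int, N: int, K: int) -> float:
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    # Example: a 4096 x 4096 x 4096 matmul finishing in 1.5 ms is ~91.6 TFLOPS,
    # in the range of the cuBLAS / Triton entries in the last row.
    print(f"{to_tflops(1.5, 4096, 4096, 4096):.1f} TFLOPS")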


@@ -5,12 +5,12 @@
Computation times Computation times
================= =================
**03:39.988** total execution time for **getting-started_tutorials** files: **03:32.837** total execution time for **getting-started_tutorials** files:
+---------------------------------------------------------------------------------------------------------+-----------+--------+ +---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:16.405 | 0.0 MB | | :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:09.226 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+ +---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:12.602 | 0.0 MB | | :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:12.617 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+ +---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:10.981 | 0.0 MB | | :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:10.994 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+ +---------------------------------------------------------------------------------------------------------+-----------+--------+


@@ -337,7 +337,7 @@ for different problem sizes.</p>
15 134217728.0 851.577704 850.656574 15 134217728.0 851.577704 850.656574
</pre></div> </pre></div>
</div> </div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 10.981 seconds)</p> <p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 10.994 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py"> <div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container"> <div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p> <p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>


@@ -219,24 +219,24 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
<span class="sd"> &quot;&quot;&quot;</span> <span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># read MN elements ; write M elements</span> <span class="c1"># read MN elements ; write M elements</span>
<span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># read 2MN elements ; write MN elements</span> <span class="c1"># read MN + M elements ; write MN elements</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># read MN elements ; write MN elements</span> <span class="c1"># read MN elements ; write MN elements</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="c1"># read MN elements ; write M elements</span> <span class="c1"># read MN elements ; write M elements</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">numerator</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">denominator</span> <span class="o">=</span> <span class="n">numerator</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># read 2MN elements ; write MN elements</span> <span class="c1"># read MN + M elements ; write MN elements</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># in total: read 7MN elements ; wrote 3MN + 2M elements</span> <span class="c1"># in total: read 5MN + 2M elements ; wrote 3MN + 2M elements</span>
<span class="k">return</span> <span class="n">ret</span> <span class="k">return</span> <span class="n">ret</span>
</pre></div> </pre></div>
</div> </div>
<p>When implemented naively in PyTorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span> <p>When implemented naively in PyTorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span>
requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements. requires reading <span class="math notranslate nohighlight">\(5MN + 2M\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
This is obviously wasteful; wed prefer to have a custom “fused” kernel that only reads This is obviously wasteful; wed prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip. X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could Doing so would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could
expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>). expect a theoretical speed-up of ~4x (i.e., <span class="math notranslate nohighlight">\((8MN + 4M) / 2MN\)</span>).
The <cite>torch.jit.script</cite> flag aims to perform this kind of “kernel fusion” automatically The <cite>torch.jit.script</cite> flag aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.</p> but, as we will see later, it is still far from ideal.</p>
</div> </div>
@@ -391,11 +391,11 @@ We will then compare its performance against (1) <code class="code docutils lite
3 640.0 682.666684 640.000002 160.000000 3 640.0 682.666684 640.000002 160.000000
4 768.0 702.171410 664.216187 163.839992 4 768.0 702.171410 664.216187 163.839992
.. ... ... ... ... .. ... ... ... ...
93 12160.0 812.359066 405.755985 199.038365 93 12160.0 812.359066 406.179533 198.936606
94 12288.0 812.429770 415.661740 199.298541 94 12288.0 812.429770 416.101597 199.298541
95 12416.0 810.840807 412.149375 198.954424 95 12416.0 810.840807 412.149375 198.854847
96 12544.0 810.925276 412.971190 199.209928 96 12544.0 810.925276 412.971190 199.209928
97 12672.0 809.389265 412.097543 199.167004 97 12672.0 811.007961 412.097543 199.167004
[98 rows x 4 columns] [98 rows x 4 columns]
</pre></div> </pre></div>
@@ -403,13 +403,12 @@ We will then compare its performance against (1) <code class="code docutils lite
<p>In the above plot, we can see that:</p> <p>In the above plot, we can see that:</p>
<blockquote> <blockquote>
<div><ul class="simple"> <div><ul class="simple">
<li><p>Triton is 2-3x faster than the Torch JIT.</p></li> <li><p>Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.</p></li>
<li><p>Triton is even faster than <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code>. My guess from looking at the source-code of the <a class="reference external" href="https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240">PyTorch kernel</a> is that PyTorch only partially fuses the computation of the softmax. <li><p>Triton is noticeably faster than <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> in addition to being <strong>easier to read, understand and maintain</strong>.
This means that when temporary data is too large to fit entirely in the GPUs cache it transfers almost twice the amount of memory necessary. Note however that the PyTorch <cite>softmax</cite> operation is more general and will works on tensors of any shape.</p></li>
Note that our Triton kernel is not only faster than PyTorchs CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
</ul> </ul>
</div></blockquote> </div></blockquote>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.602 seconds)</p> <p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.617 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py"> <div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container"> <div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p> <p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>


@@ -575,42 +575,42 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance: <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU) M cuBLAS ... Triton Triton (+ LeakyReLU)
0 128.0 0.455111 ... 0.512000 0.512000 0 128.0 0.455111 ... 0.512000 0.512000
1 256.0 2.978909 ... 3.276800 2.978909 1 256.0 2.730667 ... 2.978909 2.978909
2 384.0 7.372800 ... 8.507077 8.507077 2 384.0 7.372800 ... 8.507077 8.507077
3 512.0 14.563555 ... 16.384000 15.420235 3 512.0 14.563555 ... 15.420235 16.384000
4 640.0 22.260869 ... 24.380953 24.380953 4 640.0 22.260869 ... 24.380953 23.272727
5 768.0 32.768000 ... 34.028308 34.028308 5 768.0 32.768000 ... 34.028308 34.028308
6 896.0 39.025776 ... 40.140799 39.025776 6 896.0 39.025776 ... 40.140799 39.025776
7 1024.0 49.932191 ... 53.773130 52.428801 7 1024.0 49.932191 ... 53.773130 52.428801
8 1152.0 45.242181 ... 46.656000 46.656000 8 1152.0 45.242181 ... 46.656000 46.656000
9 1280.0 51.200001 ... 56.888887 56.109587 9 1280.0 51.200001 ... 56.888887 56.888887
10 1408.0 64.138541 ... 64.902096 64.902096 10 1408.0 64.138541 ... 64.902096 64.902096
11 1536.0 80.430545 ... 76.106321 76.106321 11 1536.0 78.643199 ... 76.106321 75.296679
12 1664.0 63.372618 ... 62.061463 62.061463 12 1664.0 62.929456 ... 62.061463 62.061463
13 1792.0 72.983276 ... 69.810085 69.810085 13 1792.0 72.983276 ... 69.810085 69.379162
14 1920.0 69.120002 ... 70.172588 69.120002 14 1920.0 67.434145 ... 70.892307 70.530615
15 2048.0 73.584279 ... 74.898285 73.584279 15 2048.0 73.908442 ... 74.898285 74.565406
16 2176.0 82.813365 ... 78.916269 79.540109 16 2176.0 83.500614 ... 78.916269 79.855747
17 2304.0 68.056616 ... 73.275679 73.275679 17 2304.0 68.251065 ... 73.275679 72.828879
18 2432.0 71.125224 ... 81.197876 81.908060 18 2432.0 71.125224 ... 80.731218 80.731218
19 2560.0 77.649287 ... 76.560748 75.676673 19 2560.0 77.649287 ... 76.560748 76.382283
20 2688.0 84.108772 ... 81.053536 84.108772 20 2688.0 81.928846 ... 80.366642 82.823267
21 2816.0 80.469019 ... 79.298560 78.726003 21 2816.0 77.743683 ... 78.868366 78.301990
22 2944.0 81.832567 ... 79.737653 79.104810 22 2944.0 81.832567 ... 79.610276 78.605729
23 3072.0 82.540970 ... 80.890151 83.146995 23 3072.0 81.005868 ... 81.005868 82.420822
24 3200.0 84.210524 ... 89.260810 87.791493 24 3200.0 84.321474 ... 89.635851 85.106381
25 3328.0 83.130825 ... 86.736504 82.275764 25 3328.0 83.226931 ... 87.156532 86.113988
26 3456.0 78.578525 ... 83.459178 85.767626 26 3456.0 81.932484 ... 83.632331 85.313831
27 3584.0 87.808000 ... 92.410473 95.148565 27 3584.0 87.211821 ... 87.211821 91.563533
28 3712.0 85.601834 ... 85.601834 88.876645 28 3712.0 85.896254 ... 82.491612 84.874549
29 3840.0 84.615146 ... 86.875096 87.148936 29 3840.0 85.070769 ... 87.493673 87.701820
30 3968.0 92.512459 ... 83.807647 83.520835 30 3968.0 92.935215 ... 83.865247 83.578035
31 4096.0 93.596744 ... 90.748973 90.321484 31 4096.0 93.662059 ... 85.926841 84.840533
[32 rows x 5 columns] [32 rows x 5 columns]
</pre></div> </pre></div>
</div> </div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 16.405 seconds)</p> <p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 9.226 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py"> <div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container"> <div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p> <p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>


@@ -174,7 +174,7 @@
<div class="section" id="computation-times"> <div class="section" id="computation-times">
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline"></a></h1> <span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline"></a></h1>
<p><strong>03:39.988</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p> <p><strong>03:32.837</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
<table class="docutils align-default"> <table class="docutils align-default">
<colgroup> <colgroup>
<col style="width: 85%" /> <col style="width: 85%" />
@@ -183,15 +183,15 @@
</colgroup> </colgroup>
<tbody> <tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td> <tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
<td><p>02:16.405</p></td> <td><p>02:09.226</p></td>
<td><p>0.0 MB</p></td> <td><p>0.0 MB</p></td>
</tr> </tr>
<tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td> <tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
<td><p>01:12.602</p></td> <td><p>01:12.617</p></td>
<td><p>0.0 MB</p></td> <td><p>0.0 MB</p></td>
</tr> </tr>
<tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td> <tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
<td><p>00:10.981</p></td> <td><p>00:10.994</p></td>
<td><p>0.0 MB</p></td> <td><p>0.0 MB</p></td>
</tr> </tr>
</tbody> </tbody>

File diff suppressed because one or more lines are too long