.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:
Fused Softmax
=================
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
You will learn about:

- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
.. GENERATED FROM PYTHON SOURCE LINES 12-16
Motivations
------------
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
.. GENERATED FROM PYTHON SOURCE LINES 16-37
.. code-block:: default

    import torch


    # Compute the row-wise softmax of x
    @torch.jit.script
    def naive_softmax(x):
        # read  MN elements ; write M  elements
        x_max = x.max(dim=1)[0]
        # read 2MN elements ; write MN elements
        z = x - x_max[:, None]
        # read  MN elements ; write MN elements
        # (exponentiate the shifted values `z`, not `x`, so large inputs cannot overflow)
        numerator = torch.exp(z)
        # read  MN elements ; write M  elements
        denominator = numerator.sum(dim=1)
        # read 2MN elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 7MN elements ; wrote 3MN + 2M elements
        return ret
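To see why the maximum is subtracted before exponentiating, here is a minimal sketch in plain Python (standard library only, no GPU required). The function names ``softmax_unstable`` and ``softmax_stable`` are illustrative and not part of the tutorial. With Python floats the overflow raises an exception; with floating-point tensors it would typically surface as ``inf`` or ``nan`` instead.

.. code-block:: python

    import math


    def softmax_unstable(row):
        # Direct definition: exp(x_i) / sum_j exp(x_j).
        exps = [math.exp(v) for v in row]
        total = sum(exps)
        return [e / total for e in exps]


    def softmax_stable(row):
        # Shift by the row maximum first. The result is mathematically unchanged
        # because the factor exp(-max) cancels between numerator and denominator.
        row_max = max(row)
        exps = [math.exp(v - row_max) for v in row]
        total = sum(exps)
        return [e / total for e in exps]


    row = [1000.0, 1001.0, 1002.0]
    try:
        softmax_unstable(row)
        overflowed = False
    except OverflowError:
        # math.exp(1000.0) does not fit in a double.
        overflowed = True

    stable = softmax_stable(row)
    print(overflowed, stable)

For small inputs both variants agree, so the shift costs one extra reduction but does not change the answer.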
.. GENERATED FROM PYTHON SOURCE LINES 38-42
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
Doing so would require reading only :math:`MN` elements and writing back only :math:`MN` elements, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
The :code:`torch.jit.script` decorator aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
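The speed-up estimate can be reproduced with a few lines of arithmetic. The sketch below only counts elements moved to and from DRAM, using the totals quoted in the comments of ``naive_softmax``; the helper names ``naive_traffic`` and ``fused_traffic`` are illustrative, and the shape is an arbitrary example.

.. code-block:: python

    def naive_traffic(M, N):
        # Element counts taken from the comments in `naive_softmax`.
        reads = 7 * M * N
        writes = 3 * M * N + 2 * M
        return reads + writes


    def fused_traffic(M, N):
        # A fused kernel reads X once and writes Y once.
        return 2 * M * N


    M, N = 4096, 1024
    speedup = naive_traffic(M, N) / fused_traffic(M, N)
    print(f"theoretical speed-up: {speedup:.3f}x")  # prints "theoretical speed-up: 5.001x"

The ratio simplifies to :math:`5 + 1/N`, which is why the estimate is essentially independent of the number of rows.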
.. GENERATED FROM PYTHON SOURCE LINES 44-49
Compute Kernel
----------------
Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
.. GENERATED FROM PYTHON SOURCE LINES 49-77
.. code-block:: default

    import triton
    import triton.language as tl


    @triton.jit
    def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
        # row index
        m = tl.program_id(0)
        # col indices
        # here BLOCK is the smallest power of two greater than or equal to `N`
        n = tl.arange(0, meta['BLOCK'])
        # the memory address of all the elements
        # that we want to load can be computed as follows
        X = X + m * stride_xm + n
        # out-of-bounds columns are padded with -inf so they affect neither the max nor the sum
        x = tl.load(X, mask=n < N, other=-float('inf'))
        # Subtract maximum for numerical stability
        z = x - tl.max(x, axis=0)
        # Note that exponentials in Triton are fast
        # but approximate (i.e., think __expf in CUDA)
        num = tl.exp(z)
        denom = tl.sum(num, axis=0)
        y = num / denom
        # Write back to Y
        Y = Y + m * stride_ym + n
        tl.store(Y, y, mask=n < N)
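The padding and masking logic can be checked without a GPU. The following is a rough emulation of what a single program does to one row, written with plain Python lists; ``softmax_row_padded`` is a hypothetical helper for illustration, not a Triton API.

.. code-block:: python

    import math


    def softmax_row_padded(row, block):
        """Emulate one program of the kernel on a single row (plain Python lists)."""
        n_cols = len(row)
        assert block >= n_cols and block & (block - 1) == 0, "block must be a power of two >= N"
        # Masked load: lanes with n >= N read -inf instead of touching memory.
        x = [row[n] if n < n_cols else -math.inf for n in range(block)]
        # Reductions run over the whole padded block.
        x_max = max(x)
        num = [math.exp(v - x_max) for v in x]  # exp(-inf) == 0.0 for padded lanes
        denom = sum(num)
        y = [v / denom for v in num]
        # Masked store: only the first N lanes are written back.
        return y[:n_cols]


    row = [0.5, -1.0, 2.0, 0.0, 1.5]  # N = 5, not a power of two
    padded = softmax_row_padded(row, block=8)
    print(padded)

Because the padded lanes contribute ``-inf`` to the maximum and ``0.0`` to the sum, the result is the same as a softmax over the original five values.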
.. GENERATED FROM PYTHON SOURCE LINES 78-79
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
.. GENERATED FROM PYTHON SOURCE LINES 79-110
.. code-block:: default

    def next_power_of_2(n):
        # Round `n` up to the nearest power of two (valid for 32-bit values)
        n -= 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16
        n += 1
        return n


    def softmax(x):
        M, N = x.shape
        # The block size is the smallest power of two greater than or equal to the number of columns in `x`
        BLOCK = next_power_of_2(N)
        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 4
        if BLOCK >= 2048:
            num_warps = 8
        if BLOCK >= 4096:
            num_warps = 16
        # Allocate output
        y = torch.empty_like(x)
        # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
        _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
        return y
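The host-side heuristics can be exercised on their own. The sketch below repeats ``next_power_of_2`` so that it runs standalone and wraps the warp selection in a hypothetical ``pick_num_warps`` helper (the tutorial inlines this logic in ``softmax``).

.. code-block:: python

    def next_power_of_2(n):
        # Same bit-smearing routine as in the tutorial (valid for 1 <= n <= 2**32).
        n -= 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16
        n += 1
        return n


    def pick_num_warps(block):
        # Hypothetical helper that mirrors the manual heuristic in `softmax`.
        if block >= 4096:
            return 16
        if block >= 2048:
            return 8
        return 4


    for n_cols in (1, 5, 781, 1024, 1025, 4096):
        block = next_power_of_2(n_cols)
        # `1 << (n - 1).bit_length()` is an equivalent closed form for n >= 1.
        assert block == 1 << (n_cols - 1).bit_length()
        print(n_cols, block, pick_num_warps(block))

Note that an input that is already a power of two is returned unchanged, so a row of 1024 columns needs no padding at all.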
.. GENERATED FROM PYTHON SOURCE LINES 111-113
Unit Test
----------
.. GENERATED FROM PYTHON SOURCE LINES 115-117
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 117-124
.. code-block:: default

    torch.manual_seed(0)
    x = torch.randn(1823, 781, device='cuda')
    y_tri = softmax(x)
    y_ref = torch.softmax(x, dim=1)
    print(torch.allclose(y_tri, y_ref))
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    True
.. GENERATED FROM PYTHON SOURCE LINES 125-126
As expected, the results match within floating-point tolerance.
.. GENERATED FROM PYTHON SOURCE LINES 128-132
Benchmark
-------------
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
.. GENERATED FROM PYTHON SOURCE LINES 132-161
.. code-block:: default

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_names`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['triton', 'torch-native', 'torch-jit'],  # possible values for `line_arg`
            line_names=["Triton", "Torch (native)", "Torch (jit)"],  # label name for the lines
            styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={'M': 4096}  # values for function arguments not in `x_names` and `line_arg`
        )
    )
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device='cuda', dtype=torch.float32)
        if provider == 'torch-native':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
        if provider == 'torch-jit':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
        # effective bandwidth: x is read once and y is written once
        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    benchmark.run(show_plots=True, print_data=True)
.. image:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
    :alt: 02 fused softmax
    :class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    softmax-performance:
              N      Triton  Torch (native)  Torch (jit)
    0     256.0  512.000001      512.000001   273.066674
    1     384.0  585.142862      585.142862   261.446801
    2     512.0  630.153853      606.814814   264.258068
    3     640.0  682.666684      640.000002   269.473696
    4     768.0  702.171410      664.216187   273.066663
    ..      ...         ...             ...          ...
    93  12160.0  812.359066      406.179533   329.483481
    94  12288.0  812.429770      415.661740   329.602681
    95  12416.0  810.840807      411.722274   329.173158
    96  12544.0  810.925276      412.971190   329.292871
    97  12672.0  811.007961      412.516771   329.142870

    [98 rows x 4 columns]
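The GB/s values reported by the benchmark are an effective bandwidth: the formula always charges :math:`2MN` elements (one read of the input, one write of the output), no matter how much data a given implementation actually moves. A standalone sketch of that conversion, with a made-up timing used purely for illustration:

.. code-block:: python

    def gbps(ms, n_rows, n_cols, element_size=4):
        # Bytes moved by an ideal softmax: read X once, write Y once.
        total_bytes = 2 * n_rows * n_cols * element_size
        # Convert bytes -> GB (1e-9) and milliseconds -> seconds (1e-3).
        return total_bytes * 1e-9 / (ms * 1e-3)


    # A 4096 x 1024 float32 matrix processed in a (made-up) 0.1 ms:
    print(round(gbps(0.1, 4096, 1024), 2))  # prints 335.54

An implementation that makes several passes over the data therefore shows up with a lower number, even if the hardware is moving bytes at full speed.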
.. GENERATED FROM PYTHON SOURCE LINES 162-167
In the above plot, we can see that:

- Triton is roughly 2-3x faster than the Torch JIT.
- Triton is even faster than :code:`torch.softmax` once rows get large (about 2x at :math:`N = 12672` in the table above), while the two are on par for the smallest rows. My guess from looking at the source code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
  This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
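The ratios behind these observations can be recomputed from the printed table. The snippet below uses only the first and last rows shown in the output; intermediate rows are elided there, so it says nothing about the sizes in between.

.. code-block:: python

    # (N, Triton, Torch native, Torch JIT) in GB/s, copied from the first and
    # last rows of the benchmark table.
    rows = [
        (256, 512.000001, 512.000001, 273.066674),
        (12672, 811.007961, 412.516771, 329.142870),
    ]

    ratios = {}
    for n_cols, triton_gbps, native_gbps, jit_gbps in rows:
        # Store (speed-up over native, speed-up over JIT) for each row.
        ratios[n_cols] = (triton_gbps / native_gbps, triton_gbps / jit_gbps)
        print(f"N={n_cols}: {ratios[n_cols][0]:.2f}x vs native, {ratios[n_cols][1]:.2f}x vs JIT")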
Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 8.252 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_