triton/_sources/getting-started/tutorials/02-fused-softmax.rst.txt
2021-04-21 01:40:29 -04:00


.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:

Fused Softmax
=================

In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:

- The benefits of kernel fusion for bandwidth-bound operations.
- The reduction operators in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 11-15
Motivations
------------
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
.. GENERATED FROM PYTHON SOURCE LINES 15-35
.. code-block:: default

    import torch


    # Compute the row-wise softmax of x
    def naive_softmax(x):
        # read MN elements ; write M elements
        x_max = torch.max(x, axis=1)[0]
        # read 2MN elements ; write MN elements
        z = x - x_max[:, None]
        # read MN elements ; write MN elements
        # (exponentiate the shifted values `z`, not `x`, to stay numerically stable)
        numerator = torch.exp(z)
        # read MN elements ; write M elements
        denominator = torch.sum(numerator, axis=1)
        # read 2MN elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 7MN elements ; wrote 3MN + 2M elements
        return ret

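
The "numerically stabilized" qualifier refers to subtracting the row maximum before exponentiating. The following pure-Python sketch illustrates why that shift matters; the helper names :code:`softmax_unstable` and :code:`softmax_stable` are made up for this illustration and are not part of the tutorial or of PyTorch.

```python
import math


def softmax_unstable(row):
    # Exponentiate the raw values; math.exp overflows for large inputs.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(row):
    # Shift by the row maximum so the largest exponent is exp(0) = 1.
    row_max = max(row)
    exps = [math.exp(v - row_max) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


row = [1000.0, 1001.0, 1002.0]

try:
    softmax_unstable(row)
    overflowed = False
except OverflowError:
    overflowed = True

# Softmax is invariant to adding a constant to every element of the row,
# so the shifted computation gives the same probabilities without overflow.
stable = softmax_stable(row)
shifted = softmax_stable([v - 1000.0 for v in row])
```

Because softmax is unchanged by adding a constant to every element of a row, the shifted version computes the same result while keeping every exponent at or below zero.
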
.. GENERATED FROM PYTHON SOURCE LINES 36-40

When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that reads :code:`x` only once and does all the necessary computations on-chip.
Such a kernel would read only :math:`MN` elements and write back only :math:`MN` elements, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
In practice, though, we would get a bit less, as our kernel also computes exponentials and internally moves data around in shared memory.
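
The traffic estimate above can be reproduced with a few lines of arithmetic. This is only a sketch that tallies the per-step element counts quoted in the comments of :code:`naive_softmax`; the function names :code:`naive_traffic` and :code:`fused_traffic` are hypothetical helpers, and the shape is an arbitrary example.

```python
def naive_traffic(M, N):
    # Per-step element counts, taken from the comments in naive_softmax.
    reads = M * N + 2 * M * N + M * N + M * N + 2 * M * N  # 7MN
    writes = M + M * N + M * N + M + M * N                 # 3MN + 2M
    return reads + writes


def fused_traffic(M, N):
    # A fused kernel reads each input element once and writes each output once.
    return 2 * M * N


M, N = 4096, 1024  # example shape, chosen arbitrarily
speedup = naive_traffic(M, N) / fused_traffic(M, N)
```

The ratio simplifies to :math:`5 + 1/N`, so the :math:`2M` term is negligible for any realistic number of columns.
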
.. GENERATED FROM PYTHON SOURCE LINES 42-47
Compute Kernel
----------------
Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
.. GENERATED FROM PYTHON SOURCE LINES 47-73
.. code-block:: default

    import triton


    @triton.jit
    def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
        # row index
        m = triton.program_id(0)
        # col indices
        n = triton.arange(0, meta['BLOCK'])
        # the memory address of all the elements
        # that we want to load can be computed as follows
        X = X + m * stride_xm + n
        x = triton.load(X, mask=n < N, other=-float('inf'))
        # Subtract maximum for numerical stability
        z = x - triton.max(x, axis=0)
        # Note that exponentials in Triton are fast
        # but approximate (i.e., think __expf in CUDA)
        num = triton.exp(z)
        denom = triton.sum(num, axis=0)
        y = num / denom
        # Write back to Y
        Y = Y + m * stride_ym + n
        triton.store(Y, y, mask=n < N)

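
To see why padding with :code:`-inf` is safe, the work of a single program can be emulated on the CPU. The sketch below is plain Python, not Triton, and :code:`emulate_program` is a hypothetical name; it assumes the row contains at least one finite value.

```python
import math


def emulate_program(row, block):
    # `row` holds the N valid elements; `block` is a power of two with block >= N.
    n_valid = len(row)
    # Masked load: lanes with index >= N get -inf instead of touching memory.
    x = [row[i] if i < n_valid else -math.inf for i in range(block)]
    x_max = max(x)  # -inf never wins the max reduction
    num = [math.exp(v - x_max) for v in x]  # exp(-inf) == 0.0 for padded lanes
    denom = sum(num)  # padded lanes contribute nothing to the sum
    y = [v / denom for v in num]
    # Masked store: only the first N lanes are written back.
    return y[:n_valid]


row = [0.5, -1.0, 2.0, 0.0, 3.5]  # N = 5, not a power of two
padded_result = emulate_program(row, block=8)

# Reference softmax computed directly on the unpadded row.
row_max = max(row)
exps = [math.exp(v - row_max) for v in row]
reference = [e / sum(exps) for e in exps]
```

The padded lanes affect neither the maximum nor the sum, so the result on the valid lanes matches a softmax computed on the unpadded row.
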
.. GENERATED FROM PYTHON SOURCE LINES 74-75
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
.. GENERATED FROM PYTHON SOURCE LINES 75-107
.. code-block:: default

    def next_power_of_2(n):
        n -= 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16
        n += 1
        return n


    def softmax(x):
        M, N = x.shape
        # The block size is the smallest power of two greater than or equal to the number of columns in `x`
        BLOCK = next_power_of_2(N)
        # Another trick we can use is to ask the compiler to parallelize each
        # row-normalization more aggressively -- i.e., with more warps -- for rows
        # that are longer.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself
        num_warps = 4
        if BLOCK >= 2048:
            num_warps = 8
        if BLOCK >= 4096:
            num_warps = 16
        # Allocate output
        y = torch.empty_like(x)
        # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
        _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
        return y

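
The :code:`next_power_of_2` helper relies on a bit-smearing trick: after decrementing, it copies the highest set bit into every lower position and then adds one. The sketch below repeats the helper and compares it with an alternative built on Python's :code:`int.bit_length`; the name :code:`next_power_of_2_alt` is made up for this comparison. Note that the shift sequence stops at 16, so the trick as written only covers values that fit in 32 bits.

```python
def next_power_of_2(n):
    # Same bit-smearing trick as in the tutorial: copy the highest set bit
    # of n - 1 into every lower position, then add one.
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def next_power_of_2_alt(n):
    # Alternative formulation using Python's arbitrary-precision integers.
    return 1 << (n - 1).bit_length()


sizes = [1, 2, 3, 781, 1024, 1025, 4097]
blocks = [next_power_of_2(n) for n in sizes]
# blocks == [1, 2, 4, 1024, 1024, 2048, 8192]
```

An exact power of two maps to itself, which is why the block size is best described as the smallest power of two greater than or equal to the number of columns.
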
.. GENERATED FROM PYTHON SOURCE LINES 108-110
Unit Test
----------
.. GENERATED FROM PYTHON SOURCE LINES 112-114
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 114-121
.. code-block:: default

    torch.manual_seed(0)
    x = torch.randn(1823, 781, device='cuda')
    y_tri = softmax(x)
    y_ref = torch.softmax(x, axis=1)
    print(torch.allclose(y_tri, y_ref))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    True

.. GENERATED FROM PYTHON SOURCE LINES 122-123

As expected, the results match to within the default tolerances of :code:`torch.allclose`.
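
Since the kernel uses approximate exponentials, it helps to remember what :code:`torch.allclose` checks: an element-wise tolerance test, not bitwise equality. The sketch below mirrors the documented condition :math:`|a - b| \le \text{atol} + \text{rtol} \cdot |b|` with the documented defaults (:code:`rtol=1e-5`, :code:`atol=1e-8`) on plain Python lists; it is an illustration, not PyTorch's implementation.

```python
def allclose(actual, expected, rtol=1e-5, atol=1e-8):
    # Element-wise check: |a - e| <= atol + rtol * |e|
    return all(
        abs(a - e) <= atol + rtol * abs(e)
        for a, e in zip(actual, expected)
    )


# A relative difference of about 1e-7 is within the default tolerance.
small_difference = allclose([1.0], [1.0 + 1e-7])
# A relative difference of about 1e-3 is not.
large_difference = allclose([1.0], [1.001])
```
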
.. GENERATED FROM PYTHON SOURCE LINES 125-129
Benchmark
-------------
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
.. GENERATED FROM PYTHON SOURCE LINES 129-157
.. code-block:: default

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[256 * i for i in range(2, 50)],  # different possible values for `x_name`
            y_name='provider',  # argument name whose value corresponds to a different line in the plot
            y_vals=['torch', 'triton', 'naive'],  # possible keys for `y_name`
            y_lines=["Torch", "Triton", 'Naive'],  # label name for the lines
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
        )
    )
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device='cuda', dtype=torch.float32)
        if provider == 'torch':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
        if provider == 'naive':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    benchmark.run(show_plots=True)

.. image:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
    :alt: 02 fused softmax
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 158-163

In the above plot, we can see that:

- Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.
- Triton is significantly faster than :code:`torch.softmax` for very large input matrices. My guess from looking at the source code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
  This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of data necessary.
  Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.

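
The GB/s figures reported by the benchmark are an *effective* bandwidth: the formula charges every provider for one read and one write of the tensor, regardless of how much data it actually moves. The sketch below restates that formula as a standalone function; :code:`effective_gbps` is a hypothetical name, and the 0.1 ms runtime is a made-up value for illustration, not a measurement.

```python
def effective_gbps(num_elements, element_size, ms):
    # Each element is counted as read once and written once, hence the factor of 2.
    total_bytes = 2 * num_elements * element_size
    # Convert bytes to gigabytes and milliseconds to seconds.
    return total_bytes * 1e-9 / (ms * 1e-3)


# A 4096 x 1024 float32 tensor with a hypothetical runtime of 0.1 ms.
gbps = effective_gbps(4096 * 1024, 4, 0.1)
```

Under this metric, an implementation that moves five times more data than necessary shows roughly one fifth of the effective bandwidth, which is how the plot makes the cost of unfused kernels visible.
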

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 20.767 seconds)

.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_