[DOCS] softmax tutorial fixup (#198)

Author: Philippe Tillet
Date: 2021-08-11 17:35:00 -07:00
Committed by: GitHub
Parent: 83da7065da
Commit: 398d4b4aeb


@@ -28,25 +28,25 @@ def naive_softmax(x):
     """
     # read MN elements ; write M elements
     x_max = x.max(dim=1)[0]
-    # read 2MN elements ; write MN elements
+    # read MN + M elements ; write MN elements
     z = x - x_max[:, None]
     # read MN elements ; write MN elements
     numerator = torch.exp(z)
     # read MN elements ; write M elements
     denominator = numerator.sum(dim=1)
-    # read 2MN elements ; write MN elements
+    # read MN + M elements ; write MN elements
     ret = numerator / denominator[:, None]
-    # in total: read 7MN elements ; wrote 3MN + 2M elements
+    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
     return ret
 
 # %%
 # When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
-# requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
+# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
 # This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
 # X once and does all the necessary computations on-chip.
 # Doing so would require reading and writing back only :math:`MN` elements, so we could
-# expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
+# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
 # The :code:`torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
 # but, as we will see later, it is still far from ideal.
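
The corrected accounting is easy to double-check outside of the tutorial file. The short script below is not part of this commit; the shape M, N = 1823, 781 is an arbitrary choice for illustration. It verifies that naive_softmax matches torch.softmax and reproduces the ~4x ratio (8MN + 4M) / 2MN:

import torch


def naive_softmax(x):
    # same five steps as in the tutorial: reads 5MN + 2M elements, writes 3MN + 2M elements
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]


M, N = 1823, 781  # arbitrary shape, chosen only for illustration
x = torch.randn(M, N)
assert torch.allclose(naive_softmax(x), torch.softmax(x, dim=1), atol=1e-6)

naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes of the naive version
fused_traffic = 2 * M * N                                  # a fused kernel reads X once and writes Y once
print(naive_traffic / fused_traffic)                       # ~4.0
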
@@ -200,7 +200,6 @@ benchmark.run(show_plots=True, print_data=True)
 # %%
 # In the above plot, we can see that:
 #
-# - Triton is 2-3x faster than the Torch JIT.
-# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
-#   This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
-#   Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
+# - Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not do any fusion here.
+# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
+#   Note, however, that the PyTorch :code:`softmax` operation is more general and works on tensors of any shape.
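
To spot-check the Torch JIT observation without the tutorial's full benchmark harness, one option is the minimal timing sketch below. It assumes a CUDA device is available; the time_ms helper is an ad-hoc illustration rather than part of the tutorial, and the tutorial's own benchmark additionally measures the Triton kernel itself.

import torch


@torch.jit.script
def naive_softmax(x):
    # same function as in the tutorial, wrapped with torch.jit.script as the benchmark does
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]


def time_ms(fn, x, reps: int = 100):
    # crude CUDA-event timer: warm up, then average over `reps` calls
    for _ in range(10):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(reps):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / reps


x = torch.randn(4096, 1024, device='cuda')
print('torch.softmax    :', time_ms(lambda t: torch.softmax(t, dim=1), x), 'ms')
print('jit naive_softmax:', time_ms(naive_softmax, x), 'ms')
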