[DOCS] softmax tutorial fixup (#198)

Philippe Tillet
2021-08-11 17:35:00 -07:00
committed by GitHub
parent 83da7065da
commit 398d4b4aeb

@@ -28,25 +28,25 @@ def naive_softmax(x):
""" """
# read MN elements ; write M elements # read MN elements ; write M elements
x_max = x.max(dim=1)[0] x_max = x.max(dim=1)[0]
# read 2MN elements ; write MN elements # read MN + M elements ; write MN elements
z = x - x_max[:, None] z = x - x_max[:, None]
# read MN elements ; write MN elements # read MN elements ; write MN elements
numerator = torch.exp(z) numerator = torch.exp(z)
# read MN elements ; write M elements # read MN elements ; write M elements
denominator = numerator.sum(dim=1) denominator = numerator.sum(dim=1)
# read 2MN elements ; write MN elements # read MN + M elements ; write MN elements
ret = numerator / denominator[:, None] ret = numerator / denominator[:, None]
# in total: read 7MN elements ; wrote 3MN + 2M elements # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret return ret
# %% # %%
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` # When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
# requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements. # requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads # This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
# X once and does all the necessary computations on-chip. # X once and does all the necessary computations on-chip.
# Doing so would require reading and writing back only :math:`MN` bytes, so we could # Doing so would require reading and writing back only :math:`MN` bytes, so we could
# expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`). # expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically # The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically
# but, as we will see later, it is still far from ideal. # but, as we will see later, it is still far from ideal.
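
For reference, the snippet below is a self-contained sketch of the naive kernel whose accounting this commit corrects, with the updated per-step read/write counts carried along as comments. The shapes, the device fallback and the correctness check against torch.softmax are illustrative assumptions, not part of the commit.

import torch


def naive_softmax(x):
    """Row-wise softmax in native PyTorch, annotated with DRAM traffic per step."""
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret


if __name__ == "__main__":
    M, N = 1823, 781  # arbitrary shapes, chosen only for this check
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(M, N, device=device)
    # the naive version should agree with torch.softmax up to floating-point error
    assert torch.allclose(naive_softmax(x), torch.softmax(x, dim=1), atol=1e-6)
    # traffic of the naive version vs. an ideal fused kernel that reads x and writes y once
    naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # 8MN + 4M elements
    fused_traffic = 2 * M * N                                  # 2MN elements
    print(f"theoretical speed-up ~ {naive_traffic / fused_traffic:.2f}x")
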
@@ -200,7 +200,6 @@ benchmark.run(show_plots=True, print_data=True)
 # %%
 # In the above plot, we can see that:
 #
-# - Triton is 2-3x faster than the Torch JIT.
-# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
-#   This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
-#   Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
+# - Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not do any fusion here.
+# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
+#   Note however that the PyTorch :code:`softmax` operation is more general and works on tensors of any shape.
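
To sanity-check the updated claims about the PyTorch baselines without reproducing the full Triton kernel or the tutorial's triton.testing benchmark, a rough CUDA-event timing sketch like the one below can be used. The tensor shape, iteration counts and the availability of a CUDA device are assumptions made for illustration only.

import torch


def naive_softmax(x):
    # same eager baseline as in the tutorial hunk above
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]


jit_softmax = torch.jit.script(naive_softmax)  # the "Torch JIT" baseline


def time_ms(fn, x, iters=100):
    # warm up first so compilation and caching are excluded from the measurement
    for _ in range(10):
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


x = torch.randn(4096, 1024, device="cuda")
for name, fn in [
    ("naive (eager)", naive_softmax),
    ("torch.jit.script", jit_softmax),
    ("torch.softmax", lambda t: torch.softmax(t, dim=1)),
]:
    ms = time_ms(fn, x)
    # report bandwidth assuming the ideal 2MN bytes of traffic, mirroring the GB/s metric in the tutorial's plot
    gbps = 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    print(f"{name:>16}: {ms:.3f} ms (~{gbps:.0f} GB/s at ideal traffic)")
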