[DOCS] softmax tutorial fixup (#198)
@@ -28,25 +28,25 @@ def naive_softmax(x):
     """
     # read MN elements ; write M elements
     x_max = x.max(dim=1)[0]
-    # read 2MN elements ; write MN elements
+    # read MN + M elements ; write MN elements
     z = x - x_max[:, None]
     # read MN elements ; write MN elements
     numerator = torch.exp(z)
     # read MN elements ; write M elements
     denominator = numerator.sum(dim=1)
-    # read 2MN elements ; write MN elements
+    # read MN + M elements ; write MN elements
     ret = numerator / denominator[:, None]
-    # in total: read 7MN elements ; wrote 3MN + 2M elements
+    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
     return ret


 # %%
 # When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
-# requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
+# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
 # This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
 # X once and does all the necessary computations on-chip.
 # Doing so would require reading and writing back only :math:`MN` bytes, so we could
-# expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
+# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
 # The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
 # but, as we will see later, it is still far from ideal.
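As a sanity check on the corrected totals, here is a minimal standalone sketch (not part of the commit) that mirrors the patched `naive_softmax` and works through the traffic arithmetic for one concrete shape; the chosen `M`, `N` values and the `naive_traffic` / `fused_traffic` names are illustrative only.

```python
import torch


def naive_softmax(x):
    # Mirrors the patched tutorial code; each comment counts DRAM traffic in elements.
    x_max = x.max(dim=1)[0]                 # read MN, write M
    z = x - x_max[:, None]                  # read MN + M, write MN
    numerator = torch.exp(z)                # read MN, write MN
    denominator = numerator.sum(dim=1)      # read MN, write M
    ret = numerator / denominator[:, None]  # read MN + M, write MN
    return ret                              # total: read 5MN + 2M, write 3MN + 2M


# Illustrative check of the ~4x figure for one arbitrary shape.
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes = 8MN + 4M
fused_traffic = 2 * M * N                                  # read X once, write the result once
print(naive_traffic / fused_traffic)                       # ~4.0
```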
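The context line above mentions the `torch.jit.script` flag; a rough sketch of how the tutorial-style function could be wrapped to request that kind of fusion might look as follows. The `naive_softmax_jit` name is ours, and whether TorchScript actually fuses these ops is exactly what the benchmark later examines.

```python
import torch


@torch.jit.script
def naive_softmax_jit(x):
    # Same elementwise/reduction chain as naive_softmax; the decorator asks
    # TorchScript to compile (and ideally fuse) it into fewer kernels.
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]


x = torch.randn(1823, 781, device='cuda' if torch.cuda.is_available() else 'cpu')
y = naive_softmax_jit(x)
```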
@@ -200,7 +200,6 @@ benchmark.run(show_plots=True, print_data=True)
 # %%
 # In the above plot, we can see that:
 #
-# - Triton is 2-3x faster than the Torch JIT.
-# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
-#   This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
-#   Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
+# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
+# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
+#   Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
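To illustrate the point about generality, a small example (shapes picked arbitrarily): `torch.softmax` takes a `dim` argument and accepts tensors of any rank, whereas the fused kernel in this tutorial targets row-wise softmax on 2D inputs.

```python
import torch

x2d = torch.randn(1823, 781)
x4d = torch.randn(8, 16, 128, 64)

y2d = torch.softmax(x2d, dim=1)   # row-wise softmax, the setting benchmarked here
y4d = torch.softmax(x4d, dim=-1)  # same op on a 4D tensor, reducing over the last dim
```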