diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py
index e908cca4e..07b405a92 100644
--- a/python/tutorials/02-fused-softmax.py
+++ b/python/tutorials/02-fused-softmax.py
@@ -28,25 +28,25 @@ def naive_softmax(x):
     """
     # read MN elements ; write M elements
     x_max = x.max(dim=1)[0]
-    # read 2MN elements ; write MN elements
+    # read MN + M elements ; write MN elements
     z = x - x_max[:, None]
     # read MN elements ; write MN elements
     numerator = torch.exp(z)
     # read MN elements ; write M elements
     denominator = numerator.sum(dim=1)
-    # read 2MN elements ; write MN elements
+    # read MN + M elements ; write MN elements
     ret = numerator / denominator[:, None]
-    # in total: read 7MN elements ; wrote 3MN + 2M elements
+    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
     return ret
 
 
 # %%
 # When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
-# requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
+# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
 # This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
 # X once and does all the necessary computations on-chip.
 # Doing so would require reading and writing back only :math:`MN` bytes, so we could
-# expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
+# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
 # The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically
 # but, as we will see later, it is still far from ideal.
 
@@ -200,7 +200,6 @@ benchmark.run(show_plots=True, print_data=True)
 # %%
 # In the above plot, we can see that:
 #
-# - Triton is 2-3x faster than the Torch JIT.
-# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel `_ is that PyTorch only partially fuses the computation of the softmax.
-# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
-# Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
+# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
+# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
+# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
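
To double-check the revised accounting, here is a small standalone Python sketch (not part of the patch; the sizes M and N are illustrative) that tallies the per-step element counts from the comments above and recovers the ~4x theoretical speed-up:

    # Sanity check of the memory-traffic accounting above (element counts, not bytes).
    M, N = 4096, 1024  # illustrative sizes; any large M, N give essentially the same ratio

    # naive_softmax: per-step reads and writes, mirroring the per-line comments
    naive_reads = (M * N) + (M * N + M) + (M * N) + (M * N) + (M * N + M)  # 5MN + 2M
    naive_writes = M + (M * N) + (M * N) + M + (M * N)                     # 3MN + 2M

    # fused kernel: read X once, write the result once
    fused_traffic = 2 * M * N

    speedup = (naive_reads + naive_writes) / fused_traffic  # (8MN + 4M) / 2MN
    print(f"theoretical speed-up ~ {speedup:.2f}x")         # ~4x for large N

For large N the 4M term is negligible next to 8MN, which is why the ratio is rounded to ~4x in the updated comment.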