[DOCS] Improved documentation and integration in CI (#139)

This commit is contained in:
Philippe Tillet
2021-07-22 22:45:19 -07:00
committed by Philippe Tillet
parent 76c6f24fb6
commit b253b77c71
5 changed files with 124 additions and 84 deletions

View File

@@ -1,10 +1,11 @@
"""
Fused Softmax
=================
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
You will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- The reduction operators in Triton.
- Reduction operators in Triton.
"""
# %%
@@ -17,15 +18,16 @@ import torch
# Compute the row-wise softmax of x
@torch.jit.script
def naive_softmax(x):
# read MN elements ; write M elements
x_max = torch.max(x, axis=1)[0]
x_max = x.max(dim=1)[0]
# read 2MN elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(x)
# read MN elements ; write M elements
denominator = torch.sum(numerator, axis=1)
denominator = numerator.sum(dim=1)
# read 2MN elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 7MN elements ; wrote 3MN + 2M elements
@@ -35,15 +37,15 @@ def naive_softmax(x):
# %%
# When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
# This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
# Doing so would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
# %%
# Compute Kernel
# ----------------
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
# so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@@ -54,6 +56,7 @@ def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index
m = tl.program_id(0)
# col indices
# here BLOCK is the smallest power of two greater than `N`
n = tl.arange(0, meta['BLOCK'])
# the memory address of all the elements
# that we want to load can be computed as follows
@@ -90,11 +93,10 @@ def softmax(x):
M, N = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK = next_power_of_2(N)
# Another trick we can use is to ask the compiler to parallelize each
# row-normalization more aggressively -- i.e., with more warps -- vectors
# that are longer
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself
# way so you don't have to come up with manual heuristics yourself.
num_warps = 4
if BLOCK >= 2048: num_warps = 8
if BLOCK >= 4096: num_warps = 16
@@ -132,10 +134,11 @@ print(torch.allclose(y_tri, y_ref))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg``
line_names=["Torch", "Triton", 'Naive'], # label name for the lines
line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``
line_names=["Triton", "Torch (native)", "Torch (jit)"], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
@@ -143,11 +146,11 @@ print(torch.allclose(y_tri, y_ref))
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
if provider == 'torch':
if provider == 'torch-native':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
if provider == 'naive':
if provider == 'torch-jit':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
@@ -158,7 +161,7 @@ benchmark.run(show_plots=True, print_data=True)
# %%
# In the above plot, we can see that:
#
# - Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.
# - Triton is significantly faster than :code:`torch.softmax` for very large input matrices. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of data necessary.
# - Triton is 2-3x faster than the Torch JIT.
# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
# Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.