"""
Fused Softmax
=============

In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.

You will learn about:

- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
"""
# %%
# Motivations
# ------------
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch


@torch.jit.script
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read 2MN elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read 2MN elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 7MN elements ; wrote 3MN + 2M elements
    return ret


# %%
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
# requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
# X once and does all the necessary computations on-chip.
# Doing so would require reading :math:`MN` elements and writing back :math:`MN` elements,
# so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# The `@torch.jit.script` decorator aims to perform this kind of "kernel fusion"
# automatically but, as we will see later, it is still far from ideal.
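#
# As a back-of-the-envelope check, the small helper below (just for illustration, not part
# of the kernel we are about to write) evaluates this ratio for a concrete shape:


def theoretical_fused_speedup(M, N):
    # naive softmax traffic: 7MN reads + (3MN + 2M) writes
    # fused kernel traffic: MN reads + MN writes
    return (10 * M * N + 2 * M) / (2 * M * N)


print(theoretical_fused_speedup(4096, 1024))  # ~5.0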
# %%
# Compute Kernel
# ----------------
# Our softmax kernel works as follows: each program loads a row of the input matrix X,
# normalizes it and writes back the result to the output Y.
# Note that one important limitation of Triton is that each block must have a
# power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
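    # Padding out-of-bounds lanes with -inf keeps them inert: they cannot change the row
    # maximum, and they become exp(-inf) = 0 in the sum computed below.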
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
def next_power_of_2(n):
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n
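
# As a quick illustration: the unit test below uses a matrix with 781 columns, which this
# helper pads up to a 1024-wide block.
assert next_power_of_2(781) == 1024
assert next_power_of_2(1024) == 1024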


def softmax(x):
    n_rows, n_cols = x.shape
    # The block size is the smallest power of two greater than or equal to the number
    # of columns in `x`
    BLOCK_SIZE = next_power_of_2(n_cols)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row
    # of the input matrix
    softmax_kernel[(n_rows, )](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_cols,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y

# %%
# Unit Test
# ----------
# %%
# We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
# This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
print(torch.allclose(y_triton, y_torch))

# %%
# As expected, the results are identical.
# %%
# Benchmark
2021-03-06 22:04:00 -05:00
# -------------
# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[
            128 * i for i in range(2, 100)
        ],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            'triton',
            'torch-native',
            'torch-jit',
        ],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (jit)",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    if provider == 'torch-native':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'torch-jit':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
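    # Effective bandwidth: softmax reads and writes each element once, i.e. it moves
    # 2 * nelement * element_size bytes; dividing by the runtime (ms -> s) gives GB/s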
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(show_plots=True, print_data=True)
# %%
# In the above plot, we can see that:
#
# - Triton is 2-3x faster than the Torch JIT.
# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
# Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.