# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# Instead, we want to write a custom "fused" PyTorch operator that only reads X once and does all the necessary computations on-chip.
# This would require reading and writing back only :math:`MN` elements each, so we could expect a theoretical speed-up of about 5x.
# In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.
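#
# For reference, here is one plausible definition of :code:`naive_softmax` that is consistent with the accounting above (the body is a reconstruction for illustration, not the exact source; broadcast operands are counted as full :math:`MN` reads):
#
# .. code-block:: python
#
#    import torch
#
#    def naive_softmax(x):
#        # read MN elements, write M elements
#        x_max = torch.max(x, dim=1)[0]
#        # read 2MN elements (x, plus x_max broadcast across N), write MN elements
#        z = x - x_max[:, None]
#        # read MN elements, write MN elements
#        num = torch.exp(z)
#        # read MN elements, write M elements
#        denom = torch.sum(num, dim=1)
#        # read 2MN elements (num, plus denom broadcast across N), write MN elements
#        return num / denom[:, None]
#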
# Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:
#
# .. code-block:: C
#
#    __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
#        // row index
#        int m = get_program_id(0);
#        // column indices
#        int n [BLOCK] = 0 ... BLOCK;
#        // the memory address of all the elements
#        // that we want to load can be computed as follows
#        float* px [BLOCK] = X + m*stride_xm + n;
#        // because BLOCK has to be a power of two
#        // (per Triton-C specs), it is important
#        // to guard each memory operation with predicates
#        // or we will read out of bounds
#        bool check[BLOCK] = n < N;
#        float x [BLOCK] = check ? *px : -F32_INFINITY;
#        // syntax for reduction in Triton is:
#        // x[..., OPERATOR, ...]
#        //            ^
#        //          index
#        // The operators currently supported are {min, max, +}
#        float z [BLOCK] = x - x[max];
#        // The exponential in Triton is fast but approximate
#        // (i.e., like __expf in CUDA)
#        float num [BLOCK] = exp(z);
#        float denom = num[+];
#        // The result of the reduction is now stored in y
#        float y [BLOCK] = num / denom;
#        // We write it back to Y, guarded by the same predicates
#        float* py [BLOCK] = Y + m*stride_ym + n;
#        *?(check)py = y;
#    }
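#
# Two properties that the kernel relies on are worth sanity-checking numerically: softmax is invariant to subtracting a per-row constant (which is why shifting by the row max :code:`x[max]` is safe), and lanes padded with :math:`-\infty` contribute :code:`exp(-inf) = 0` to the row sum, so padding a row up to the next power of two does not affect the result. A quick PyTorch check of both properties (the shapes below are arbitrary):
#
# .. code-block:: python
#
#    import torch
#
#    x = torch.randn(8, 1000)
#    y_ref = torch.softmax(x, dim=1)
#    # shifting each row by its max leaves the result unchanged
#    y_shift = torch.softmax(x - x.max(dim=1, keepdim=True)[0], dim=1)
#    # padding N=1000 up to the next power of two (1024) with -inf
#    # leaves the first N columns of the result unchanged
#    x_pad = torch.cat([x, torch.full((8, 24), float('-inf'))], dim=1)
#    y_pad = torch.softmax(x_pad, dim=1)[:, :1000]
#    assert torch.allclose(y_ref, y_shift) and torch.allclose(y_ref, y_pad)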