# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be customized to accommodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
#
# .. code-block:: python
#
#     # do in parallel
#     for m in range(0, M, MB):
#         # do in parallel
#         for n in range(0, N, NB):
#             acc = zeros((MB, NB), dtype=float32)
#             for k in range(0, K, KB):
#                 acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
#             C[m : m+MB, n : n+NB] = acc
#
# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
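#
# For instance, since each program instance computes one (MB, NB) tile of :code:`C`, the number
# of program instances to launch is simply the number of such tiles. A minimal sketch of the
# corresponding 1D launch grid (mirroring the :code:`grid_m` / :code:`grid_n` prologue of the
# kernel shown later):
#
# .. code-block:: python
#
#     grid_m = (M + MB - 1) // MB   # number of tiles along M (ceiling division)
#     grid_n = (N + NB - 1) // NB   # number of tiles along N (ceiling division)
#     grid = (grid_m * grid_n,)     # one Triton program instance per output tile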
# %%
# Compute Kernel
# ----------------
#
# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
#
# Pointer Arithmetic
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i * X.stride(0) + j`.
# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
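#
# .. code-block:: python
#
#     &A[m : m+MB, k : k+KB] = A + (m : m+MB)[:, newaxis] * A.stride(0) + (k : k+KB)[newaxis, :]
#     &B[k : k+KB, n : n+NB] = B + (k : k+KB)[:, newaxis] * B.stride(0) + (n : n+NB)[newaxis, :]
#
# These correspond to the :code:`pa` and :code:`pb` pointer blocks constructed in the final kernel.
#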
# As you will see in the kernel below, we also re-order the program IDs so that blocks which reuse the same input tiles are more likely to be scheduled close together, which improves L2 cache hit rate.
# In practice, this can improve the performance of our matrix multiplication kernel by >10% on some hardware architectures (e.g., from 220 to 245 TFLOPS on A100).
#
# Final Result
# ~~~~~~~~~~~~~~
#
# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
# Note that we rematerialize :code:`rm` and :code:`rn` after the inner loop to decrease register pressure.
# This optimization provides an additional 5% performance improvement and cannot currently be done automatically by the Triton compiler.
#
# .. code-block:: C
#     :force:
#
#     #define MAX_GROUP_SIZE 8
#
#     __global__ void dot(TYPE* A, TYPE* B, TYPE* C,
#                         int M, int N, int K,
#                         int stride_a_0, int stride_b_0, int stride_c_0) {
#         // prologue
#         int pid = get_program_id(0);
#         int grid_m = (M + MB - 1) / MB;
#         int grid_n = (N + NB - 1) / NB;
#         // re-order program ID for better L2 performance
#         int width = MAX_GROUP_SIZE * grid_n;
#         int group_id = pid / width;
#         int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
#         int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
#         int pid_n = (pid % width) / group_size;
#         // pointers to operands
#         // note the parentheses here; they force the offset
#         // computation to happen in typeof(stride_a_0) = int32 rather than
#         // typeof(A) = int64
#         int rm[MB] = pid_m * MB + 0 ... MB;
#         int rn[NB] = pid_n * NB + 0 ... NB;
#         int rk[KB] = 0 ... KB;
#         TYPE *pa[MB, KB] = A + (rk[newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
#         TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn[newaxis, :] * 1);
#         // reduction loop
#         float acc[MB, NB] = 0;
#         for (int k = K; k > 0; k -= KB) {
#             acc += (*pa) @ (*pb);
#             pa += KB * 1;
#             pb += KB * stride_b_0;
#         }
#         // pointers to output
#         // here we rematerialize `rm` and `rn` so that they are not live through
#         // the above reduction loop. In the future, the compiler should be able to
#         // do this automatically.
#         rm = pid_m * MB + 0 ... MB;
#         rn = pid_n * NB + 0 ... NB;
#         TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
#         // we write back using the *?() predicated-store operator; `acc` gets cast to `TYPE` implicitly.
#         *?(rm[:, newaxis] < M && rn[newaxis, :] < N) pc = acc;
#     }
#
# Here, :code:`TYPE` is the data type of the input matrices, and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
#
# %%
# Torch Bindings
# ----------------
#
# Auto-Tuning
# ~~~~~~~~~~~~~~
#
# In order to use Triton's built-in auto-tuner with the above kernel, we need to define a list of :code:`triton.config` objects, which can be constructed as shown below.
# We also need to define a list of strings (the "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
# Here, we want to re-tune our kernel only when the shape of input matrices changes.
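# The block-size values below, and the :code:`defines` / :code:`num_warps` arguments of :code:`triton.config`, are illustrative and may need to be adapted to your hardware and Triton version.
autotune_configs = [
    triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
    triton.config(defines={"MB": "128", "NB": "64", "KB": "32"}, num_warps=4),
    triton.config(defines={"MB": "64", "NB": "128", "KB": "32"}, num_warps=4),
]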
autotune_key = ["M", "N", "K"]
# %%
# We can now create an auto-tuned kernel by passing the :code:`autotune_configs` and :code:`autotune_key` lists to the constructor of the :code:`triton.kernel` class.
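# A minimal sketch of this construction is given below; the keyword-argument names are illustrative and may differ across Triton versions, and :code:`src` is assumed to hold the Triton-C source listed above.
#
# .. code-block:: python
#
#     # sketch only: keyword names are illustrative and may differ across Triton versions
#     kernel = triton.kernel(src, device=torch.device('cuda'),
#                            autotune_configs=autotune_configs,
#                            autotune_key=autotune_key)
#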
# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
# Note that we need to modify the :code:`atol` and :code:`rtol` parameters of :code:`torch.allclose` to account for the fact that we are comparing FP16 tensors.
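#
# A minimal sketch of such a check (the wrapper name :code:`dot` and the tolerance values are
# illustrative):
#
# .. code-block:: python
#
#     a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
#     b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
#     c_triton = dot(a, b)            # our Triton-based matrix multiplication
#     c_cublas = torch.matmul(a, b)   # cuBLAS reference
#     assert torch.allclose(c_triton, c_cublas, atol=1e-2, rtol=0)
#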
# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
# For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_, a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves.
# To install CUTLASS, you need a recent version of cmake; the :code:`CUTLASS_LIBRARY_KERNELS` cmake option lets you restrict which kernels get compiled, and here we are only interested in FP16 tensor core performance.
# Triton comes with some basic Python bindings for benchmarking CUTLASS. These are compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process, so Triton needs to be re-installed once CUTLASS is built.
# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against that of CUTLASS. Here we focus on square matrices, but feel free to adapt the script to benchmark any other matrix shape.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
        x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
        y_name='provider',  # argument name whose value corresponds to a different line in the plot
# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.