[GH-PAGES] Updated website

This commit is contained in:
Philippe Tillet
2021-03-15 13:58:20 -04:00
parent b4495e0ddc
commit 746b15ee0a
39 changed files with 3933 additions and 1113 deletions

View File

@@ -65,7 +65,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import torch\nimport triton\n\n# Source code for the Triton kernel\n_src = \"\"\"\n__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n int m = get_program_id(0);\n int n [BLOCK] = 0 ... BLOCK;\n float* px [BLOCK] = X + m*stride_xm + n;\n bool check[BLOCK] = n < N;\n float x [BLOCK] = check ? *px : -F32_INFINITY;\n float z [BLOCK] = x - x[max];\n float num [BLOCK] = exp(z);\n float denom = num[+];\n float y [BLOCK] = num / denom;\n float* py [BLOCK] = Y + m*stride_ym + n;\n *?(check)py = y; \n}\n\"\"\"\n\n\n# helper function to get the smaller power-of-two larger than a given number\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\n# kernel caching mechanism\ndef make_kernel(N, device):\n cache = make_kernel.cache\n # Now are kernels are indexed not only by the provided device but also\n # by the rounded number of columns in the input matrix\n BLOCK = next_power_of_2(N)\n key = (BLOCK, device)\n if key not in cache:\n defines = {'BLOCK': BLOCK}\n cache[key] = triton.kernel(_src, device=device, defines=defines)\n return cache[key]\n\n\nmake_kernel.cache = dict()\n\n\nclass _softmax(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x):\n # constraints of the op\n assert x.dtype == torch.float32\n y = torch.empty_like(x)\n # The launch grid is simple: we have one kernel instance per row of the input matrix\n M, N = y.shape\n grid = lambda opt: (M, )\n # Launch kernel\n kernel = make_kernel(N, y.device)\n kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)\n return y\n\n\nsoftmax = _softmax.apply" "import torch\nimport triton\n\n# Source code for the Triton kernel\n_src = \"\"\"\n__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n int m = get_program_id(0);\n int n [BLOCK] = 0 ... BLOCK;\n float* px [BLOCK] = X + m*stride_xm + n;\n bool check[BLOCK] = n < N;\n float x [BLOCK] = check ? *px : -F32_INFINITY;\n float z [BLOCK] = x - x[max];\n float num [BLOCK] = exp(z);\n float denom = num[+];\n float y [BLOCK] = num / denom;\n float* py [BLOCK] = Y + m*stride_ym + n;\n *?(check)py = y; \n}\n\"\"\"\n\n\n# helper function to get the smaller power-of-two larger than a given number\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\n# kernel caching mechanism\ndef make_kernel(N, device):\n cache = make_kernel.cache\n # Now are kernels are indexed not only by the provided device but also\n # by the rounded number of columns in the input matrix\n BLOCK = next_power_of_2(N)\n # Another trick we can use is to ask the compiler to parallelize each\n # row-normalization more aggressively -- i.e., with more warps -- vectors\n # that are longer\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself\n num_warps = 4\n if BLOCK >= 2048: num_warps = 8\n if BLOCK >= 4096: num_warps = 16\n # Each (BLOCK, num_warps, device) results in a different kernel\n key = (BLOCK, num_warps, device)\n if key not in cache:\n defines = {'BLOCK': BLOCK}\n cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)\n return cache[key]\n\n\nmake_kernel.cache = dict()\n\n\nclass _softmax(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x):\n # constraints of the op\n assert x.dtype == torch.float32\n y = torch.empty_like(x)\n # The launch grid is simple: we have one kernel instance per row of the input matrix\n M, N = y.shape\n grid = lambda opt: (M, )\n # Launch kernel\n kernel = make_kernel(N, y.device)\n kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)\n return y\n\n\nsoftmax = _softmax.apply"
] ]
}, },
{ {
@@ -111,7 +111,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Benchmarking\nHere we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.\nWe will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.\n\n" "## Benchmark\nHere we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.\nWe will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.\n\n"
] ]
}, },
{ {
@@ -149,7 +149,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.9" "version": "3.7.3"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -50,7 +50,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_. # The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
# %% # %%
# Torch bindings # Torch Bindings
# -------------------------- # --------------------------
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things: # The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
# #
@@ -127,7 +127,7 @@ add = _add.apply
# %% # %%
# Unit Test # Unit Test
# -------------------------- # -----------
# #
# Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below: # Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:
@@ -144,8 +144,8 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
# Seems like we're good to go! # Seems like we're good to go!
# %% # %%
# Benchmarking # Benchmark
# -------------------------- # -----------
# We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch. # We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op. # To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
# for different problem sizes. # for different problem sizes.

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,369 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
You will specifically learn about:
- The block-level matrix multiplication operator `@`
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
"""
# %%
# Motivations
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
#
# .. code-block:: python
#
# # do in parallel
# for m in range(0, M, MB):
# # do in parallel
# for n in range(0, N, NB):
# acc = zeros((MB, NB), dtype=float32)
# for k in range(0, K, KB):
# acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
# C[m : m+MB, n : n+NB] = acc;
#
# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
# %%
# Compute Kernel
# ----------------
#
# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
#
# Pointer Arithmetics
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
# &A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
# &B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
#
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
#
# .. code-block:: C
# :force:
#
# int rm[MB] = program_id_m * MB + 0 ... MB;
# int rn[NB] = program_id_n * NB + 0 ... NB;
# int rk[KB] = 0 ... KB;
# TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
#
# These pointers can then be updated in the inner loop as:
#
# .. code-block:: C
#
# pa += KB * 1;
# pb += KB * ldb;
#
#
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
# However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
# This means that a naive row-major ordering:
#
# .. code-block:: C
#
# int program_id = get_program_id(0);
# int grid_m = (M + MB - 1) / MB;
# int grid_n = (N + NB - 1) / NB;
# int program_id_m = program_id / grid_n;
# int program_id_n = program_id % grid_n;
#
# is unlikely to result in optimal performance.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
#
# .. code-block:: C
#
# int program_id = get_program_id(0);
# int width = GROUP_SIZE * grid_n;
# int group_id = pid / width;
# // we need to handle the case where M % (GROUP_SIZE*BM) != 0
# int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
# int pid_m = group_id * GROUP_SIZE + (pid % group_size);
# int pid_n = (pid % width) / (group_size);
#
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
#
# Final Result
# ~~~~~~~~~~~~~~
#
# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
# Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure.
# This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.
#
# .. code-block:: C
# :force:
#
# #define MAX_GROUP_SIZE 8
#
# __global__ void dot(TYPE* A, TYPE* B, TYPE* C,
# int M, int N, int K,
# int stride_a_0, int stride_b_0, int stride_c_0) {
# // prologue
# int pid = get_program_id(0);
# int grid_m = (M + MB - 1) / MB;
# int grid_n = (N + NB - 1) / NB;
# // re-order program ID for better L2 performance
# int width = MAX_GROUP_SIZE * grid_n;
# int group_id = pid / width;
# int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
# int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
# int pid_n = (pid % width) / (group_size);
# // pointers to operands
# // note the parentheses here; they force the offset
# // computation to happen in typeof(stride_a_0) = int32 rather than
# // typeof(A) = int64
# int rm[MB] = pid_m * MB + 0 ... MB;
# int rn[NB] = pid_n * NB + 0 ... NB;
# int rk[KB] = 0 ... KB;
# TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
# // reduction loop
# float acc[MB, NB] = 0;
# for (int k = K; k > 0; k -= KB) {
# acc += (*pa) @ (*pb);
# pa += KB * 1;
# pb += KB * stride_b_0;
# }
# // pointers to output
# // here we rematerialize `rm` and `rn` so that they are not live through
# // the above reduction loop. In the future, the compiler should be able to
# // do this automatically.
# rm = pid_m * MB + 0 ... MB;
# rn = pid_n * NB + 0 ... NB;
# TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
# // we write back using *?() operator. `acc` gets casted to `float32` implicitly.
# *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
# }
#
# Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
#
# %%
# Torch Bindings
# ----------------
#
# Auto-Tuning
# ~~~~~~~~~~~~~~
#
# In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows:
import torch
import triton
autotune_configs = [
triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
]
# %%
# we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
# Here, we want to re-tune our kernel only when the shape of input matrices changes.
autotune_key = ["M", "N", "K"]
# %%
# We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
src = """
#define MAX_GROUP_SIZE 8
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
int M, int N, int K,
int lda, int ldb, int ldc) {
int pid = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
int width = MAX_GROUP_SIZE * grid_n;
int group_id = pid / width;
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
int rm[MB] = pid_m * MB + 0 ... MB;
int rn[NB] = pid_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
float acc[MB, NB] = 0;
for (int k = K; k > 0; k -= KB) {
acc += (*pa) @ (*pb);
pa += KB * 1;
pb += KB * ldb;
}
rm = pid_m * MB + 0 ... MB;
rn = pid_n * NB + 0 ... NB;
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
"""
def make_kernel(device, dtype):
key = (device, dtype)
cache = make_kernel.cache
if key not in cache:
defines = {'TYPE': dtype}
cache[key] = triton.kernel(src, device=device, defines=defines, autotune_vals=autotune_configs, autotune_key=autotune_key)
return cache[key]
make_kernel.cache = dict()
# %%
# Autograd Function
# ~~~~~~~~~~~~~~~~~~
#
# Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
# To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output.
class _dot(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
M, Ka = a.shape
Kb, N = b.shape
assert Ka == Kb, "incompatible dimensions"
assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
kernel = make_kernel(a.device, a.dtype)
grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
M, N, Ka, \
a.stride(0), b.stride(0), c.stride(0), \
grid=grid)
return c
dot = _dot.apply
# %%
# Unit Test
# -----------
#
# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
# Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
c_0 = dot(a, b)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
# %%
# Benchmark
# --------------
#
# Installing The CUTLASS Bindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
# For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
# To install CUTLASS, you need a recent version of cmake:
#
# .. code-block:: bash
#
# cd /path/to/cutlass/
# git clone https://github.com/NVIDIA/cutlass.git
# cd cutlass
# mkdir build
# cd build
# wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
# tar xzvf *.tar.gz
#
# You can then install CUTLASS as follows for V100
#
# .. code-block:: bash
#
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
# make -j8 install
#
# Or as follows for A100:
#
# .. code-block:: bash
#
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
# make -j8 install
#
# Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
# Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
# To re-install Triton with the updated CUTLASS bindings, run the following command:
#
# .. code-block:: bash
#
# export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
# export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/a
# pip uninstall -y triton
# pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
#
# Which we can test as follows:
import triton
c_2 = triton.testing.cutlass_matmul(a, b)
print(c_2)
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
# %%
# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton', 'cutlass'], # possible keys for `y_name`
y_lines=["Torch", "Triton", 'CUTLASS'], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
)
)
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
if provider == 'cutlass':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True)
# %%
# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.

View File

@@ -126,10 +126,19 @@ def make_kernel(N, device):
# Now are kernels are indexed not only by the provided device but also # Now are kernels are indexed not only by the provided device but also
# by the rounded number of columns in the input matrix # by the rounded number of columns in the input matrix
BLOCK = next_power_of_2(N) BLOCK = next_power_of_2(N)
key = (BLOCK, device) # Another trick we can use is to ask the compiler to parallelize each
# row-normalization more aggressively -- i.e., with more warps -- vectors
# that are longer
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself
num_warps = 4
if BLOCK >= 2048: num_warps = 8
if BLOCK >= 4096: num_warps = 16
# Each (BLOCK, num_warps, device) results in a different kernel
key = (BLOCK, num_warps, device)
if key not in cache: if key not in cache:
defines = {'BLOCK': BLOCK} defines = {'BLOCK': BLOCK}
cache[key] = triton.kernel(_src, device=device, defines=defines) cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
return cache[key] return cache[key]
@@ -174,7 +183,7 @@ print(torch.allclose(y_tri, y_ref))
# As expected, the results are identical. # As expected, the results are identical.
# %% # %%
# Benchmarking # Benchmark
# ------------- # -------------
# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows. # Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above. # We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.

View File

@@ -29,7 +29,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Torch bindings\nThe only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:\n\n- :code:`source: string`: the source-code of the kernel you want to create\n- :code:`device: torch.device`: the device you want to compile this code for\n- :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n\n" "## Torch Bindings\nThe only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:\n\n- :code:`source: string`: the source-code of the kernel you want to create\n- :code:`device: torch.device`: the device you want to compile this code for\n- :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n\n"
] ]
}, },
{ {
@@ -79,7 +79,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Benchmarking\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.\nTo make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.\nfor different problem sizes.\n\n" "## Benchmark\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.\nTo make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.\nfor different problem sizes.\n\n"
] ]
}, },
{ {
@@ -128,7 +128,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.9" "version": "3.7.3"
} }
}, },
"nbformat": 4, "nbformat": 4,

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

View File

@@ -70,7 +70,7 @@ The existence of arrays as a primitive data-type for Triton comes with a number
.. GENERATED FROM PYTHON SOURCE LINES 53-60 .. GENERATED FROM PYTHON SOURCE LINES 53-60
Torch bindings Torch Bindings
-------------------------- --------------------------
The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things: The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
@@ -161,7 +161,7 @@ We can now use the above function to compute the sum of two `torch.tensor` objec
.. GENERATED FROM PYTHON SOURCE LINES 129-133 .. GENERATED FROM PYTHON SOURCE LINES 129-133
Unit Test Unit Test
-------------------------- -----------
Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below: Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:
@@ -189,8 +189,8 @@ Of course, the first thing that we should check is that whether kernel is correc
.. code-block:: none .. code-block:: none
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0') tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0') tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0 The maximum difference between torch and triton is 0.0
@@ -202,8 +202,8 @@ Seems like we're good to go!
.. GENERATED FROM PYTHON SOURCE LINES 147-152 .. GENERATED FROM PYTHON SOURCE LINES 147-152
Benchmarking Benchmark
-------------------------- -----------
We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch. We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op. To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
for different problem sizes. for different problem sizes.
@@ -268,7 +268,7 @@ We can now run the decorated function above. Pass `show_plots=True` to see the p
.. rst-class:: sphx-glr-timing .. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 5.901 seconds) **Total running time of the script:** ( 0 minutes 7.521 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py: .. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

View File

@@ -121,7 +121,7 @@ Here our torch bindings is quite similar to that of the vector addition mentione
We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix. We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
This means that different values of BLOCK will result in different kernels This means that different values of BLOCK will result in different kernels
.. GENERATED FROM PYTHON SOURCE LINES 89-156 .. GENERATED FROM PYTHON SOURCE LINES 89-165
.. code-block:: default .. code-block:: default
@@ -165,10 +165,19 @@ This means that different values of BLOCK will result in different kernels
# Now are kernels are indexed not only by the provided device but also # Now are kernels are indexed not only by the provided device but also
# by the rounded number of columns in the input matrix # by the rounded number of columns in the input matrix
BLOCK = next_power_of_2(N) BLOCK = next_power_of_2(N)
key = (BLOCK, device) # Another trick we can use is to ask the compiler to parallelize each
# row-normalization more aggressively -- i.e., with more warps -- vectors
# that are longer
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself
num_warps = 4
if BLOCK >= 2048: num_warps = 8
if BLOCK >= 4096: num_warps = 16
# Each (BLOCK, num_warps, device) results in a different kernel
key = (BLOCK, num_warps, device)
if key not in cache: if key not in cache:
defines = {'BLOCK': BLOCK} defines = {'BLOCK': BLOCK}
cache[key] = triton.kernel(_src, device=device, defines=defines) cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
return cache[key] return cache[key]
@@ -199,21 +208,21 @@ This means that different values of BLOCK will result in different kernels
.. GENERATED FROM PYTHON SOURCE LINES 157-158 .. GENERATED FROM PYTHON SOURCE LINES 166-167
We can use the above softmax function to compute the row-wise softmax of a given matrix. We can use the above softmax function to compute the row-wise softmax of a given matrix.
.. GENERATED FROM PYTHON SOURCE LINES 160-162 .. GENERATED FROM PYTHON SOURCE LINES 169-171
Unit Test Unit Test
---------- ----------
.. GENERATED FROM PYTHON SOURCE LINES 164-166 .. GENERATED FROM PYTHON SOURCE LINES 173-175
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works. This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 166-173 .. GENERATED FROM PYTHON SOURCE LINES 175-182
.. code-block:: default .. code-block:: default
@@ -239,18 +248,18 @@ This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 174-175 .. GENERATED FROM PYTHON SOURCE LINES 183-184
As expected, the results are identical. As expected, the results are identical.
.. GENERATED FROM PYTHON SOURCE LINES 177-181 .. GENERATED FROM PYTHON SOURCE LINES 186-190
Benchmarking Benchmark
------------- -------------
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows. Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above. We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
.. GENERATED FROM PYTHON SOURCE LINES 181-209 .. GENERATED FROM PYTHON SOURCE LINES 190-218
.. code-block:: default .. code-block:: default
@@ -293,7 +302,7 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
.. GENERATED FROM PYTHON SOURCE LINES 210-215 .. GENERATED FROM PYTHON SOURCE LINES 219-224
In the above plot, we can see that: In the above plot, we can see that:
@@ -305,7 +314,7 @@ In the above plot, we can see that:
.. rst-class:: sphx-glr-timing .. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 21.805 seconds) **Total running time of the script:** ( 0 minutes 19.896 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py: .. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

View File

@@ -0,0 +1,558 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/03-matrix-multiplication.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_03-matrix-multiplication.py:
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
You will specifically learn about:
- The block-level matrix multiplication operator `@`
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
.. GENERATED FROM PYTHON SOURCE LINES 14-35
Motivations
-------------
Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
.. code-block:: python
# do in parallel
for m in range(0, M, MB):
# do in parallel
for n in range(0, N, NB):
acc = zeros((MB, NB), dtype=float32)
for k in range(0, K, KB):
acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
C[m : m+MB, n : n+NB] = acc;
where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
.. GENERATED FROM PYTHON SOURCE LINES 37-161
Compute Kernel
----------------
The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
Pointer Arithmetics
~~~~~~~~~~~~~~~~~~~~
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
.. code-block:: python
&A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
&B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
.. code-block:: C
:force:
int rm[MB] = program_id_m * MB + 0 ... MB;
int rn[NB] = program_id_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
These pointers can then be updated in the inner loop as:
.. code-block:: C
pa += KB * 1;
pb += KB * ldb;
L2 Cache Optimizations
~~~~~~~~~~~~~~~~~~~~~~~~
As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
This means that a naive row-major ordering:
.. code-block:: C
int program_id = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
int program_id_m = program_id / grid_n;
int program_id_n = program_id % grid_n;
is unlikely to result in optimal performance.
One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
.. code-block:: C
int program_id = get_program_id(0);
int width = GROUP_SIZE * grid_n;
int group_id = pid / width;
// we need to handle the case where M % (GROUP_SIZE*BM) != 0
int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
int pid_m = group_id * GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
Final Result
~~~~~~~~~~~~~~
We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure.
This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.
.. code-block:: C
:force:
#define MAX_GROUP_SIZE 8
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
int M, int N, int K,
int stride_a_0, int stride_b_0, int stride_c_0) {
// prologue
int pid = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
// re-order program ID for better L2 performance
int width = MAX_GROUP_SIZE * grid_n;
int group_id = pid / width;
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
// pointers to operands
// note the parentheses here; they force the offset
// computation to happen in typeof(stride_a_0) = int32 rather than
// typeof(A) = int64
int rm[MB] = pid_m * MB + 0 ... MB;
int rn[NB] = pid_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
// reduction loop
float acc[MB, NB] = 0;
for (int k = K; k > 0; k -= KB) {
acc += (*pa) @ (*pb);
pa += KB * 1;
pb += KB * stride_b_0;
}
// pointers to output
// here we rematerialize `rm` and `rn` so that they are not live through
// the above reduction loop. In the future, the compiler should be able to
// do this automatically.
rm = pid_m * MB + 0 ... MB;
rn = pid_n * NB + 0 ... NB;
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
// we write back using *?() operator. `acc` gets casted to `float32` implicitly.
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
.. GENERATED FROM PYTHON SOURCE LINES 163-170
Torch Bindings
----------------
Auto-Tuning
~~~~~~~~~~~~~~
In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows:
.. GENERATED FROM PYTHON SOURCE LINES 170-185
.. code-block:: default
import torch
import triton
autotune_configs = [
triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
]
.. GENERATED FROM PYTHON SOURCE LINES 186-188
we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
Here, we want to re-tune our kernel only when the shape of input matrices changes.
.. GENERATED FROM PYTHON SOURCE LINES 188-191
.. code-block:: default
autotune_key = ["M", "N", "K"]
.. GENERATED FROM PYTHON SOURCE LINES 192-193
We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
.. GENERATED FROM PYTHON SOURCE LINES 193-238
.. code-block:: default
src = """
#define MAX_GROUP_SIZE 8
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
int M, int N, int K,
int lda, int ldb, int ldc) {
int pid = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
int width = MAX_GROUP_SIZE * grid_n;
int group_id = pid / width;
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
int rm[MB] = pid_m * MB + 0 ... MB;
int rn[NB] = pid_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
float acc[MB, NB] = 0;
for (int k = K; k > 0; k -= KB) {
acc += (*pa) @ (*pb);
pa += KB * 1;
pb += KB * ldb;
}
rm = pid_m * MB + 0 ... MB;
rn = pid_n * NB + 0 ... NB;
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
"""
def make_kernel(device, dtype):
key = (device, dtype)
cache = make_kernel.cache
if key not in cache:
defines = {'TYPE': dtype}
cache[key] = triton.kernel(src, device=device, defines=defines, autotune_vals=autotune_configs, autotune_key=autotune_key)
return cache[key]
make_kernel.cache = dict()
.. GENERATED FROM PYTHON SOURCE LINES 239-244
Autograd Function
~~~~~~~~~~~~~~~~~~
Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output.
.. GENERATED FROM PYTHON SOURCE LINES 244-265
.. code-block:: default
class _dot(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
M, Ka = a.shape
Kb, N = b.shape
assert Ka == Kb, "incompatible dimensions"
assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
kernel = make_kernel(a.device, a.dtype)
grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
M, N, Ka, \
a.stride(0), b.stride(0), c.stride(0), \
grid=grid)
return c
dot = _dot.apply
.. GENERATED FROM PYTHON SOURCE LINES 266-271
Unit Test
-----------
We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
.. GENERATED FROM PYTHON SOURCE LINES 271-280
.. code-block:: default
a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
c_0 = dot(a, b)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device='cuda:0', dtype=torch.float16)
tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device='cuda:0', dtype=torch.float16)
True
.. GENERATED FROM PYTHON SOURCE LINES 281-327
Benchmark
--------------
Installing The CUTLASS Bindings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
To install CUTLASS, you need a recent version of cmake:
.. code-block:: bash
cd /path/to/cutlass/
git clone https://github.com/NVIDIA/cutlass.git
cd cutlass
mkdir build
cd build
wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
tar xzvf *.tar.gz
You can then install CUTLASS as follows for V100
.. code-block:: bash
./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
make -j8 install
Or as follows for A100:
.. code-block:: bash
./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
make -j8 install
Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
To re-install Triton with the updated CUTLASS bindings, run the following command:
.. code-block:: bash
export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/a
pip uninstall -y triton
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
Which we can test as follows:
.. GENERATED FROM PYTHON SOURCE LINES 327-333
.. code-block:: default
import triton
c_2 = triton.testing.cutlass_matmul(a, b)
print(c_2)
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device='cuda:0', dtype=torch.float16)
True
.. GENERATED FROM PYTHON SOURCE LINES 334-339
Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
.. GENERATED FROM PYTHON SOURCE LINES 339-368
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton', 'cutlass'], # possible keys for `y_name`
y_lines=["Torch", "Triton", 'CUTLASS'], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
)
)
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
if provider == 'cutlass':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True)
.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
:alt: matmul-performance
:class: sphx-glr-single-img
.. GENERATED FROM PYTHON SOURCE LINES 369-369
As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 10.181 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 03-matrix-multiplication.py <03-matrix-multiplication.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 03-matrix-multiplication.ipynb <03-matrix-multiplication.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -51,6 +51,27 @@ Below is a gallery of tutorials for writing various basic operations with Triton
:hidden: :hidden:
/getting-started/tutorials/02-fused-softmax /getting-started/tutorials/02-fused-softmax
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- The block-level matrix multiplication operator @ - Multi-dimensional pointer arithmetic - Pro...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_03-matrix-multiplication_thumb.png
:alt: Matrix Multiplication
:ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/03-matrix-multiplication
.. raw:: html .. raw:: html
<div class="sphx-glr-clear"></div> <div class="sphx-glr-clear"></div>

View File

@@ -5,10 +5,12 @@
Computation times Computation times
================= =================
**00:27.706** total execution time for **getting-started_tutorials** files: **01:10.181** total execution time for **getting-started_tutorials** files:
+-----------------------------------------------------------------------------------------+-----------+--------+ +---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:21.805 | 0.0 MB | | :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 01:10.181 | 0.0 MB |
+-----------------------------------------------------------------------------------------+-----------+--------+ +---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:05.901 | 0.0 MB | | :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------+-----------+--------+ +---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:00.000 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+

Binary file not shown.

View File

@@ -29,9 +29,14 @@ if (!window.console || !console.firebug) {
/** /**
* small helper function to urldecode strings * small helper function to urldecode strings
*
* See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL
*/ */
jQuery.urldecode = function(x) { jQuery.urldecode = function(x) {
return decodeURIComponent(x).replace(/\+/g, ' '); if (!x) {
return x
}
return decodeURIComponent(x.replace(/\+/g, ' '));
}; };
/** /**

View File

@@ -13,7 +13,8 @@
var stopwords = ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"]; var stopwords = ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"];
/* Non-minified version JS is _stemmer.js if file is provided */ /* Non-minified version is copied as a separate JS file, is available */
/** /**
* Porter Stemmer * Porter Stemmer
*/ */
@@ -199,7 +200,6 @@ var Stemmer = function() {
var splitChars = (function() { var splitChars = (function() {
var result = {}; var result = {};
var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648, var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648,

View File

@@ -1,7 +1,7 @@
pre { line-height: 125%; } pre { line-height: 125%; }
td.linenos pre { color: #000000; background-color: #f0f0f0; padding-left: 5px; padding-right: 5px; } td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
span.linenos { color: #000000; background-color: #f0f0f0; padding-left: 5px; padding-right: 5px; } span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
td.linenos pre.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
.highlight .hll { background-color: #ffffcc } .highlight .hll { background-color: #ffffcc }
.highlight { background: #eeffcc; } .highlight { background: #eeffcc; }

View File

@@ -248,7 +248,7 @@ var Search = {
// results left, load the summary and display it // results left, load the summary and display it
if (results.length) { if (results.length) {
var item = results.pop(); var item = results.pop();
var listItem = $('<li style="display:none"></li>'); var listItem = $('<li></li>');
var requestUrl = ""; var requestUrl = "";
var linkUrl = ""; var linkUrl = "";
if (DOCUMENTATION_OPTIONS.BUILDER === 'dirhtml') { if (DOCUMENTATION_OPTIONS.BUILDER === 'dirhtml') {
@@ -273,9 +273,9 @@ var Search = {
if (item[3]) { if (item[3]) {
listItem.append($('<span> (' + item[3] + ')</span>')); listItem.append($('<span> (' + item[3] + ')</span>'));
Search.output.append(listItem); Search.output.append(listItem);
listItem.slideDown(5, function() { setTimeout(function() {
displayNextItem(); displayNextItem();
}); }, 5);
} else if (DOCUMENTATION_OPTIONS.HAS_SOURCE) { } else if (DOCUMENTATION_OPTIONS.HAS_SOURCE) {
$.ajax({url: requestUrl, $.ajax({url: requestUrl,
dataType: "text", dataType: "text",
@@ -285,16 +285,16 @@ var Search = {
listItem.append(Search.makeSearchSummary(data, searchterms, hlterms)); listItem.append(Search.makeSearchSummary(data, searchterms, hlterms));
} }
Search.output.append(listItem); Search.output.append(listItem);
listItem.slideDown(5, function() { setTimeout(function() {
displayNextItem(); displayNextItem();
}); }, 5);
}}); }});
} else { } else {
// no source available, just display title // no source available, just display title
Search.output.append(listItem); Search.output.append(listItem);
listItem.slideDown(5, function() { setTimeout(function() {
displayNextItem(); displayNextItem();
}); }, 5);
} }
} }
// search finished, update title and status message // search finished, update title and status message
@@ -379,6 +379,13 @@ var Search = {
return results; return results;
}, },
/**
* See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Regular_Expressions
*/
escapeRegExp : function(string) {
return string.replace(/[.*+\-?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
},
/** /**
* search for full-text terms in the index * search for full-text terms in the index
*/ */
@@ -402,13 +409,14 @@ var Search = {
]; ];
// add support for partial matches // add support for partial matches
if (word.length > 2) { if (word.length > 2) {
var word_regex = this.escapeRegExp(word);
for (var w in terms) { for (var w in terms) {
if (w.match(word) && !terms[word]) { if (w.match(word_regex) && !terms[word]) {
_o.push({files: terms[w], score: Scorer.partialTerm}) _o.push({files: terms[w], score: Scorer.partialTerm})
} }
} }
for (var w in titleterms) { for (var w in titleterms) {
if (w.match(word) && !titleterms[word]) { if (w.match(word_regex) && !titleterms[word]) {
_o.push({files: titleterms[w], score: Scorer.partialTitle}) _o.push({files: titleterms[w], score: Scorer.partialTitle})
} }
} }

2027
_static/underscore-1.12.0.js Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,999 +0,0 @@
// Underscore.js 1.3.1
// (c) 2009-2012 Jeremy Ashkenas, DocumentCloud Inc.
// Underscore is freely distributable under the MIT license.
// Portions of Underscore are inspired or borrowed from Prototype,
// Oliver Steele's Functional, and John Resig's Micro-Templating.
// For all details and documentation:
// http://documentcloud.github.com/underscore
(function() {
// Baseline setup
// --------------
// Establish the root object, `window` in the browser, or `global` on the server.
var root = this;
// Save the previous value of the `_` variable.
var previousUnderscore = root._;
// Establish the object that gets returned to break out of a loop iteration.
var breaker = {};
// Save bytes in the minified (but not gzipped) version:
var ArrayProto = Array.prototype, ObjProto = Object.prototype, FuncProto = Function.prototype;
// Create quick reference variables for speed access to core prototypes.
var slice = ArrayProto.slice,
unshift = ArrayProto.unshift,
toString = ObjProto.toString,
hasOwnProperty = ObjProto.hasOwnProperty;
// All **ECMAScript 5** native function implementations that we hope to use
// are declared here.
var
nativeForEach = ArrayProto.forEach,
nativeMap = ArrayProto.map,
nativeReduce = ArrayProto.reduce,
nativeReduceRight = ArrayProto.reduceRight,
nativeFilter = ArrayProto.filter,
nativeEvery = ArrayProto.every,
nativeSome = ArrayProto.some,
nativeIndexOf = ArrayProto.indexOf,
nativeLastIndexOf = ArrayProto.lastIndexOf,
nativeIsArray = Array.isArray,
nativeKeys = Object.keys,
nativeBind = FuncProto.bind;
// Create a safe reference to the Underscore object for use below.
var _ = function(obj) { return new wrapper(obj); };
// Export the Underscore object for **Node.js**, with
// backwards-compatibility for the old `require()` API. If we're in
// the browser, add `_` as a global object via a string identifier,
// for Closure Compiler "advanced" mode.
if (typeof exports !== 'undefined') {
if (typeof module !== 'undefined' && module.exports) {
exports = module.exports = _;
}
exports._ = _;
} else {
root['_'] = _;
}
// Current version.
_.VERSION = '1.3.1';
// Collection Functions
// --------------------
// The cornerstone, an `each` implementation, aka `forEach`.
// Handles objects with the built-in `forEach`, arrays, and raw objects.
// Delegates to **ECMAScript 5**'s native `forEach` if available.
var each = _.each = _.forEach = function(obj, iterator, context) {
if (obj == null) return;
if (nativeForEach && obj.forEach === nativeForEach) {
obj.forEach(iterator, context);
} else if (obj.length === +obj.length) {
for (var i = 0, l = obj.length; i < l; i++) {
if (i in obj && iterator.call(context, obj[i], i, obj) === breaker) return;
}
} else {
for (var key in obj) {
if (_.has(obj, key)) {
if (iterator.call(context, obj[key], key, obj) === breaker) return;
}
}
}
};
// Return the results of applying the iterator to each element.
// Delegates to **ECMAScript 5**'s native `map` if available.
_.map = _.collect = function(obj, iterator, context) {
var results = [];
if (obj == null) return results;
if (nativeMap && obj.map === nativeMap) return obj.map(iterator, context);
each(obj, function(value, index, list) {
results[results.length] = iterator.call(context, value, index, list);
});
if (obj.length === +obj.length) results.length = obj.length;
return results;
};
// **Reduce** builds up a single result from a list of values, aka `inject`,
// or `foldl`. Delegates to **ECMAScript 5**'s native `reduce` if available.
_.reduce = _.foldl = _.inject = function(obj, iterator, memo, context) {
var initial = arguments.length > 2;
if (obj == null) obj = [];
if (nativeReduce && obj.reduce === nativeReduce) {
if (context) iterator = _.bind(iterator, context);
return initial ? obj.reduce(iterator, memo) : obj.reduce(iterator);
}
each(obj, function(value, index, list) {
if (!initial) {
memo = value;
initial = true;
} else {
memo = iterator.call(context, memo, value, index, list);
}
});
if (!initial) throw new TypeError('Reduce of empty array with no initial value');
return memo;
};
// The right-associative version of reduce, also known as `foldr`.
// Delegates to **ECMAScript 5**'s native `reduceRight` if available.
_.reduceRight = _.foldr = function(obj, iterator, memo, context) {
var initial = arguments.length > 2;
if (obj == null) obj = [];
if (nativeReduceRight && obj.reduceRight === nativeReduceRight) {
if (context) iterator = _.bind(iterator, context);
return initial ? obj.reduceRight(iterator, memo) : obj.reduceRight(iterator);
}
var reversed = _.toArray(obj).reverse();
if (context && !initial) iterator = _.bind(iterator, context);
return initial ? _.reduce(reversed, iterator, memo, context) : _.reduce(reversed, iterator);
};
// Return the first value which passes a truth test. Aliased as `detect`.
_.find = _.detect = function(obj, iterator, context) {
var result;
any(obj, function(value, index, list) {
if (iterator.call(context, value, index, list)) {
result = value;
return true;
}
});
return result;
};
// Return all the elements that pass a truth test.
// Delegates to **ECMAScript 5**'s native `filter` if available.
// Aliased as `select`.
_.filter = _.select = function(obj, iterator, context) {
var results = [];
if (obj == null) return results;
if (nativeFilter && obj.filter === nativeFilter) return obj.filter(iterator, context);
each(obj, function(value, index, list) {
if (iterator.call(context, value, index, list)) results[results.length] = value;
});
return results;
};
// Return all the elements for which a truth test fails.
_.reject = function(obj, iterator, context) {
var results = [];
if (obj == null) return results;
each(obj, function(value, index, list) {
if (!iterator.call(context, value, index, list)) results[results.length] = value;
});
return results;
};
// Determine whether all of the elements match a truth test.
// Delegates to **ECMAScript 5**'s native `every` if available.
// Aliased as `all`.
_.every = _.all = function(obj, iterator, context) {
var result = true;
if (obj == null) return result;
if (nativeEvery && obj.every === nativeEvery) return obj.every(iterator, context);
each(obj, function(value, index, list) {
if (!(result = result && iterator.call(context, value, index, list))) return breaker;
});
return result;
};
// Determine if at least one element in the object matches a truth test.
// Delegates to **ECMAScript 5**'s native `some` if available.
// Aliased as `any`.
var any = _.some = _.any = function(obj, iterator, context) {
iterator || (iterator = _.identity);
var result = false;
if (obj == null) return result;
if (nativeSome && obj.some === nativeSome) return obj.some(iterator, context);
each(obj, function(value, index, list) {
if (result || (result = iterator.call(context, value, index, list))) return breaker;
});
return !!result;
};
// Determine if a given value is included in the array or object using `===`.
// Aliased as `contains`.
_.include = _.contains = function(obj, target) {
var found = false;
if (obj == null) return found;
if (nativeIndexOf && obj.indexOf === nativeIndexOf) return obj.indexOf(target) != -1;
found = any(obj, function(value) {
return value === target;
});
return found;
};
// Invoke a method (with arguments) on every item in a collection.
_.invoke = function(obj, method) {
var args = slice.call(arguments, 2);
return _.map(obj, function(value) {
return (_.isFunction(method) ? method || value : value[method]).apply(value, args);
});
};
// Convenience version of a common use case of `map`: fetching a property.
_.pluck = function(obj, key) {
return _.map(obj, function(value){ return value[key]; });
};
// Return the maximum element or (element-based computation).
_.max = function(obj, iterator, context) {
if (!iterator && _.isArray(obj)) return Math.max.apply(Math, obj);
if (!iterator && _.isEmpty(obj)) return -Infinity;
var result = {computed : -Infinity};
each(obj, function(value, index, list) {
var computed = iterator ? iterator.call(context, value, index, list) : value;
computed >= result.computed && (result = {value : value, computed : computed});
});
return result.value;
};
// Return the minimum element (or element-based computation).
_.min = function(obj, iterator, context) {
if (!iterator && _.isArray(obj)) return Math.min.apply(Math, obj);
if (!iterator && _.isEmpty(obj)) return Infinity;
var result = {computed : Infinity};
each(obj, function(value, index, list) {
var computed = iterator ? iterator.call(context, value, index, list) : value;
computed < result.computed && (result = {value : value, computed : computed});
});
return result.value;
};
// Shuffle an array.
_.shuffle = function(obj) {
var shuffled = [], rand;
each(obj, function(value, index, list) {
if (index == 0) {
shuffled[0] = value;
} else {
rand = Math.floor(Math.random() * (index + 1));
shuffled[index] = shuffled[rand];
shuffled[rand] = value;
}
});
return shuffled;
};
// Sort the object's values by a criterion produced by an iterator.
_.sortBy = function(obj, iterator, context) {
return _.pluck(_.map(obj, function(value, index, list) {
return {
value : value,
criteria : iterator.call(context, value, index, list)
};
}).sort(function(left, right) {
var a = left.criteria, b = right.criteria;
return a < b ? -1 : a > b ? 1 : 0;
}), 'value');
};
// Groups the object's values by a criterion. Pass either a string attribute
// to group by, or a function that returns the criterion.
_.groupBy = function(obj, val) {
var result = {};
var iterator = _.isFunction(val) ? val : function(obj) { return obj[val]; };
each(obj, function(value, index) {
var key = iterator(value, index);
(result[key] || (result[key] = [])).push(value);
});
return result;
};
// Use a comparator function to figure out at what index an object should
// be inserted so as to maintain order. Uses binary search.
_.sortedIndex = function(array, obj, iterator) {
iterator || (iterator = _.identity);
var low = 0, high = array.length;
while (low < high) {
var mid = (low + high) >> 1;
iterator(array[mid]) < iterator(obj) ? low = mid + 1 : high = mid;
}
return low;
};
// Safely convert anything iterable into a real, live array.
_.toArray = function(iterable) {
if (!iterable) return [];
if (iterable.toArray) return iterable.toArray();
if (_.isArray(iterable)) return slice.call(iterable);
if (_.isArguments(iterable)) return slice.call(iterable);
return _.values(iterable);
};
// Return the number of elements in an object.
_.size = function(obj) {
return _.toArray(obj).length;
};
// Array Functions
// ---------------
// Get the first element of an array. Passing **n** will return the first N
// values in the array. Aliased as `head`. The **guard** check allows it to work
// with `_.map`.
_.first = _.head = function(array, n, guard) {
return (n != null) && !guard ? slice.call(array, 0, n) : array[0];
};
// Returns everything but the last entry of the array. Especcialy useful on
// the arguments object. Passing **n** will return all the values in
// the array, excluding the last N. The **guard** check allows it to work with
// `_.map`.
_.initial = function(array, n, guard) {
return slice.call(array, 0, array.length - ((n == null) || guard ? 1 : n));
};
// Get the last element of an array. Passing **n** will return the last N
// values in the array. The **guard** check allows it to work with `_.map`.
_.last = function(array, n, guard) {
if ((n != null) && !guard) {
return slice.call(array, Math.max(array.length - n, 0));
} else {
return array[array.length - 1];
}
};
// Returns everything but the first entry of the array. Aliased as `tail`.
// Especially useful on the arguments object. Passing an **index** will return
// the rest of the values in the array from that index onward. The **guard**
// check allows it to work with `_.map`.
_.rest = _.tail = function(array, index, guard) {
return slice.call(array, (index == null) || guard ? 1 : index);
};
// Trim out all falsy values from an array.
_.compact = function(array) {
return _.filter(array, function(value){ return !!value; });
};
// Return a completely flattened version of an array.
_.flatten = function(array, shallow) {
return _.reduce(array, function(memo, value) {
if (_.isArray(value)) return memo.concat(shallow ? value : _.flatten(value));
memo[memo.length] = value;
return memo;
}, []);
};
// Return a version of the array that does not contain the specified value(s).
_.without = function(array) {
return _.difference(array, slice.call(arguments, 1));
};
// Produce a duplicate-free version of the array. If the array has already
// been sorted, you have the option of using a faster algorithm.
// Aliased as `unique`.
_.uniq = _.unique = function(array, isSorted, iterator) {
var initial = iterator ? _.map(array, iterator) : array;
var result = [];
_.reduce(initial, function(memo, el, i) {
if (0 == i || (isSorted === true ? _.last(memo) != el : !_.include(memo, el))) {
memo[memo.length] = el;
result[result.length] = array[i];
}
return memo;
}, []);
return result;
};
// Produce an array that contains the union: each distinct element from all of
// the passed-in arrays.
_.union = function() {
return _.uniq(_.flatten(arguments, true));
};
// Produce an array that contains every item shared between all the
// passed-in arrays. (Aliased as "intersect" for back-compat.)
_.intersection = _.intersect = function(array) {
var rest = slice.call(arguments, 1);
return _.filter(_.uniq(array), function(item) {
return _.every(rest, function(other) {
return _.indexOf(other, item) >= 0;
});
});
};
// Take the difference between one array and a number of other arrays.
// Only the elements present in just the first array will remain.
_.difference = function(array) {
var rest = _.flatten(slice.call(arguments, 1));
return _.filter(array, function(value){ return !_.include(rest, value); });
};
// Zip together multiple lists into a single array -- elements that share
// an index go together.
_.zip = function() {
var args = slice.call(arguments);
var length = _.max(_.pluck(args, 'length'));
var results = new Array(length);
for (var i = 0; i < length; i++) results[i] = _.pluck(args, "" + i);
return results;
};
// If the browser doesn't supply us with indexOf (I'm looking at you, **MSIE**),
// we need this function. Return the position of the first occurrence of an
// item in an array, or -1 if the item is not included in the array.
// Delegates to **ECMAScript 5**'s native `indexOf` if available.
// If the array is large and already in sort order, pass `true`
// for **isSorted** to use binary search.
_.indexOf = function(array, item, isSorted) {
if (array == null) return -1;
var i, l;
if (isSorted) {
i = _.sortedIndex(array, item);
return array[i] === item ? i : -1;
}
if (nativeIndexOf && array.indexOf === nativeIndexOf) return array.indexOf(item);
for (i = 0, l = array.length; i < l; i++) if (i in array && array[i] === item) return i;
return -1;
};
// Delegates to **ECMAScript 5**'s native `lastIndexOf` if available.
_.lastIndexOf = function(array, item) {
if (array == null) return -1;
if (nativeLastIndexOf && array.lastIndexOf === nativeLastIndexOf) return array.lastIndexOf(item);
var i = array.length;
while (i--) if (i in array && array[i] === item) return i;
return -1;
};
// Generate an integer Array containing an arithmetic progression. A port of
// the native Python `range()` function. See
// [the Python documentation](http://docs.python.org/library/functions.html#range).
_.range = function(start, stop, step) {
if (arguments.length <= 1) {
stop = start || 0;
start = 0;
}
step = arguments[2] || 1;
var len = Math.max(Math.ceil((stop - start) / step), 0);
var idx = 0;
var range = new Array(len);
while(idx < len) {
range[idx++] = start;
start += step;
}
return range;
};
// Function (ahem) Functions
// ------------------
// Reusable constructor function for prototype setting.
var ctor = function(){};
// Create a function bound to a given object (assigning `this`, and arguments,
// optionally). Binding with arguments is also known as `curry`.
// Delegates to **ECMAScript 5**'s native `Function.bind` if available.
// We check for `func.bind` first, to fail fast when `func` is undefined.
_.bind = function bind(func, context) {
var bound, args;
if (func.bind === nativeBind && nativeBind) return nativeBind.apply(func, slice.call(arguments, 1));
if (!_.isFunction(func)) throw new TypeError;
args = slice.call(arguments, 2);
return bound = function() {
if (!(this instanceof bound)) return func.apply(context, args.concat(slice.call(arguments)));
ctor.prototype = func.prototype;
var self = new ctor;
var result = func.apply(self, args.concat(slice.call(arguments)));
if (Object(result) === result) return result;
return self;
};
};
// Bind all of an object's methods to that object. Useful for ensuring that
// all callbacks defined on an object belong to it.
_.bindAll = function(obj) {
var funcs = slice.call(arguments, 1);
if (funcs.length == 0) funcs = _.functions(obj);
each(funcs, function(f) { obj[f] = _.bind(obj[f], obj); });
return obj;
};
// Memoize an expensive function by storing its results.
_.memoize = function(func, hasher) {
var memo = {};
hasher || (hasher = _.identity);
return function() {
var key = hasher.apply(this, arguments);
return _.has(memo, key) ? memo[key] : (memo[key] = func.apply(this, arguments));
};
};
// Delays a function for the given number of milliseconds, and then calls
// it with the arguments supplied.
_.delay = function(func, wait) {
var args = slice.call(arguments, 2);
return setTimeout(function(){ return func.apply(func, args); }, wait);
};
// Defers a function, scheduling it to run after the current call stack has
// cleared.
_.defer = function(func) {
return _.delay.apply(_, [func, 1].concat(slice.call(arguments, 1)));
};
// Returns a function, that, when invoked, will only be triggered at most once
// during a given window of time.
_.throttle = function(func, wait) {
var context, args, timeout, throttling, more;
var whenDone = _.debounce(function(){ more = throttling = false; }, wait);
return function() {
context = this; args = arguments;
var later = function() {
timeout = null;
if (more) func.apply(context, args);
whenDone();
};
if (!timeout) timeout = setTimeout(later, wait);
if (throttling) {
more = true;
} else {
func.apply(context, args);
}
whenDone();
throttling = true;
};
};
// Returns a function, that, as long as it continues to be invoked, will not
// be triggered. The function will be called after it stops being called for
// N milliseconds.
_.debounce = function(func, wait) {
var timeout;
return function() {
var context = this, args = arguments;
var later = function() {
timeout = null;
func.apply(context, args);
};
clearTimeout(timeout);
timeout = setTimeout(later, wait);
};
};
// Returns a function that will be executed at most one time, no matter how
// often you call it. Useful for lazy initialization.
_.once = function(func) {
var ran = false, memo;
return function() {
if (ran) return memo;
ran = true;
return memo = func.apply(this, arguments);
};
};
// Returns the first function passed as an argument to the second,
// allowing you to adjust arguments, run code before and after, and
// conditionally execute the original function.
_.wrap = function(func, wrapper) {
return function() {
var args = [func].concat(slice.call(arguments, 0));
return wrapper.apply(this, args);
};
};
// Returns a function that is the composition of a list of functions, each
// consuming the return value of the function that follows.
_.compose = function() {
var funcs = arguments;
return function() {
var args = arguments;
for (var i = funcs.length - 1; i >= 0; i--) {
args = [funcs[i].apply(this, args)];
}
return args[0];
};
};
// Returns a function that will only be executed after being called N times.
_.after = function(times, func) {
if (times <= 0) return func();
return function() {
if (--times < 1) { return func.apply(this, arguments); }
};
};
// Object Functions
// ----------------
// Retrieve the names of an object's properties.
// Delegates to **ECMAScript 5**'s native `Object.keys`
_.keys = nativeKeys || function(obj) {
if (obj !== Object(obj)) throw new TypeError('Invalid object');
var keys = [];
for (var key in obj) if (_.has(obj, key)) keys[keys.length] = key;
return keys;
};
// Retrieve the values of an object's properties.
_.values = function(obj) {
return _.map(obj, _.identity);
};
// Return a sorted list of the function names available on the object.
// Aliased as `methods`
_.functions = _.methods = function(obj) {
var names = [];
for (var key in obj) {
if (_.isFunction(obj[key])) names.push(key);
}
return names.sort();
};
// Extend a given object with all the properties in passed-in object(s).
_.extend = function(obj) {
each(slice.call(arguments, 1), function(source) {
for (var prop in source) {
obj[prop] = source[prop];
}
});
return obj;
};
// Fill in a given object with default properties.
_.defaults = function(obj) {
each(slice.call(arguments, 1), function(source) {
for (var prop in source) {
if (obj[prop] == null) obj[prop] = source[prop];
}
});
return obj;
};
// Create a (shallow-cloned) duplicate of an object.
_.clone = function(obj) {
if (!_.isObject(obj)) return obj;
return _.isArray(obj) ? obj.slice() : _.extend({}, obj);
};
// Invokes interceptor with the obj, and then returns obj.
// The primary purpose of this method is to "tap into" a method chain, in
// order to perform operations on intermediate results within the chain.
_.tap = function(obj, interceptor) {
interceptor(obj);
return obj;
};
// Internal recursive comparison function.
function eq(a, b, stack) {
// Identical objects are equal. `0 === -0`, but they aren't identical.
// See the Harmony `egal` proposal: http://wiki.ecmascript.org/doku.php?id=harmony:egal.
if (a === b) return a !== 0 || 1 / a == 1 / b;
// A strict comparison is necessary because `null == undefined`.
if (a == null || b == null) return a === b;
// Unwrap any wrapped objects.
if (a._chain) a = a._wrapped;
if (b._chain) b = b._wrapped;
// Invoke a custom `isEqual` method if one is provided.
if (a.isEqual && _.isFunction(a.isEqual)) return a.isEqual(b);
if (b.isEqual && _.isFunction(b.isEqual)) return b.isEqual(a);
// Compare `[[Class]]` names.
var className = toString.call(a);
if (className != toString.call(b)) return false;
switch (className) {
// Strings, numbers, dates, and booleans are compared by value.
case '[object String]':
// Primitives and their corresponding object wrappers are equivalent; thus, `"5"` is
// equivalent to `new String("5")`.
return a == String(b);
case '[object Number]':
// `NaN`s are equivalent, but non-reflexive. An `egal` comparison is performed for
// other numeric values.
return a != +a ? b != +b : (a == 0 ? 1 / a == 1 / b : a == +b);
case '[object Date]':
case '[object Boolean]':
// Coerce dates and booleans to numeric primitive values. Dates are compared by their
// millisecond representations. Note that invalid dates with millisecond representations
// of `NaN` are not equivalent.
return +a == +b;
// RegExps are compared by their source patterns and flags.
case '[object RegExp]':
return a.source == b.source &&
a.global == b.global &&
a.multiline == b.multiline &&
a.ignoreCase == b.ignoreCase;
}
if (typeof a != 'object' || typeof b != 'object') return false;
// Assume equality for cyclic structures. The algorithm for detecting cyclic
// structures is adapted from ES 5.1 section 15.12.3, abstract operation `JO`.
var length = stack.length;
while (length--) {
// Linear search. Performance is inversely proportional to the number of
// unique nested structures.
if (stack[length] == a) return true;
}
// Add the first object to the stack of traversed objects.
stack.push(a);
var size = 0, result = true;
// Recursively compare objects and arrays.
if (className == '[object Array]') {
// Compare array lengths to determine if a deep comparison is necessary.
size = a.length;
result = size == b.length;
if (result) {
// Deep compare the contents, ignoring non-numeric properties.
while (size--) {
// Ensure commutative equality for sparse arrays.
if (!(result = size in a == size in b && eq(a[size], b[size], stack))) break;
}
}
} else {
// Objects with different constructors are not equivalent.
if ('constructor' in a != 'constructor' in b || a.constructor != b.constructor) return false;
// Deep compare objects.
for (var key in a) {
if (_.has(a, key)) {
// Count the expected number of properties.
size++;
// Deep compare each member.
if (!(result = _.has(b, key) && eq(a[key], b[key], stack))) break;
}
}
// Ensure that both objects contain the same number of properties.
if (result) {
for (key in b) {
if (_.has(b, key) && !(size--)) break;
}
result = !size;
}
}
// Remove the first object from the stack of traversed objects.
stack.pop();
return result;
}
// Perform a deep comparison to check if two objects are equal.
_.isEqual = function(a, b) {
return eq(a, b, []);
};
// Is a given array, string, or object empty?
// An "empty" object has no enumerable own-properties.
_.isEmpty = function(obj) {
if (_.isArray(obj) || _.isString(obj)) return obj.length === 0;
for (var key in obj) if (_.has(obj, key)) return false;
return true;
};
// Is a given value a DOM element?
_.isElement = function(obj) {
return !!(obj && obj.nodeType == 1);
};
// Is a given value an array?
// Delegates to ECMA5's native Array.isArray
_.isArray = nativeIsArray || function(obj) {
return toString.call(obj) == '[object Array]';
};
// Is a given variable an object?
_.isObject = function(obj) {
return obj === Object(obj);
};
// Is a given variable an arguments object?
_.isArguments = function(obj) {
return toString.call(obj) == '[object Arguments]';
};
if (!_.isArguments(arguments)) {
_.isArguments = function(obj) {
return !!(obj && _.has(obj, 'callee'));
};
}
// Is a given value a function?
_.isFunction = function(obj) {
return toString.call(obj) == '[object Function]';
};
// Is a given value a string?
_.isString = function(obj) {
return toString.call(obj) == '[object String]';
};
// Is a given value a number?
_.isNumber = function(obj) {
return toString.call(obj) == '[object Number]';
};
// Is the given value `NaN`?
_.isNaN = function(obj) {
// `NaN` is the only value for which `===` is not reflexive.
return obj !== obj;
};
// Is a given value a boolean?
_.isBoolean = function(obj) {
return obj === true || obj === false || toString.call(obj) == '[object Boolean]';
};
// Is a given value a date?
_.isDate = function(obj) {
return toString.call(obj) == '[object Date]';
};
// Is the given value a regular expression?
_.isRegExp = function(obj) {
return toString.call(obj) == '[object RegExp]';
};
// Is a given value equal to null?
_.isNull = function(obj) {
return obj === null;
};
// Is a given variable undefined?
_.isUndefined = function(obj) {
return obj === void 0;
};
// Has own property?
_.has = function(obj, key) {
return hasOwnProperty.call(obj, key);
};
// Utility Functions
// -----------------
// Run Underscore.js in *noConflict* mode, returning the `_` variable to its
// previous owner. Returns a reference to the Underscore object.
_.noConflict = function() {
root._ = previousUnderscore;
return this;
};
// Keep the identity function around for default iterators.
_.identity = function(value) {
return value;
};
// Run a function **n** times.
_.times = function (n, iterator, context) {
for (var i = 0; i < n; i++) iterator.call(context, i);
};
// Escape a string for HTML interpolation.
_.escape = function(string) {
return (''+string).replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;').replace(/'/g, '&#x27;').replace(/\//g,'&#x2F;');
};
// Add your own custom functions to the Underscore object, ensuring that
// they're correctly added to the OOP wrapper as well.
_.mixin = function(obj) {
each(_.functions(obj), function(name){
addToWrapper(name, _[name] = obj[name]);
});
};
// Generate a unique integer id (unique within the entire client session).
// Useful for temporary DOM ids.
var idCounter = 0;
_.uniqueId = function(prefix) {
var id = idCounter++;
return prefix ? prefix + id : id;
};
// By default, Underscore uses ERB-style template delimiters, change the
// following template settings to use alternative delimiters.
_.templateSettings = {
evaluate : /<%([\s\S]+?)%>/g,
interpolate : /<%=([\s\S]+?)%>/g,
escape : /<%-([\s\S]+?)%>/g
};
// When customizing `templateSettings`, if you don't want to define an
// interpolation, evaluation or escaping regex, we need one that is
// guaranteed not to match.
var noMatch = /.^/;
// Within an interpolation, evaluation, or escaping, remove HTML escaping
// that had been previously added.
var unescape = function(code) {
return code.replace(/\\\\/g, '\\').replace(/\\'/g, "'");
};
// JavaScript micro-templating, similar to John Resig's implementation.
// Underscore templating handles arbitrary delimiters, preserves whitespace,
// and correctly escapes quotes within interpolated code.
_.template = function(str, data) {
var c = _.templateSettings;
var tmpl = 'var __p=[],print=function(){__p.push.apply(__p,arguments);};' +
'with(obj||{}){__p.push(\'' +
str.replace(/\\/g, '\\\\')
.replace(/'/g, "\\'")
.replace(c.escape || noMatch, function(match, code) {
return "',_.escape(" + unescape(code) + "),'";
})
.replace(c.interpolate || noMatch, function(match, code) {
return "'," + unescape(code) + ",'";
})
.replace(c.evaluate || noMatch, function(match, code) {
return "');" + unescape(code).replace(/[\r\n\t]/g, ' ') + ";__p.push('";
})
.replace(/\r/g, '\\r')
.replace(/\n/g, '\\n')
.replace(/\t/g, '\\t')
+ "');}return __p.join('');";
var func = new Function('obj', '_', tmpl);
if (data) return func(data, _);
return function(data) {
return func.call(this, data, _);
};
};
// Add a "chain" function, which will delegate to the wrapper.
_.chain = function(obj) {
return _(obj).chain();
};
// The OOP Wrapper
// ---------------
// If Underscore is called as a function, it returns a wrapped object that
// can be used OO-style. This wrapper holds altered versions of all the
// underscore functions. Wrapped objects may be chained.
var wrapper = function(obj) { this._wrapped = obj; };
// Expose `wrapper.prototype` as `_.prototype`
_.prototype = wrapper.prototype;
// Helper function to continue chaining intermediate results.
var result = function(obj, chain) {
return chain ? _(obj).chain() : obj;
};
// A method to easily add functions to the OOP wrapper.
var addToWrapper = function(name, func) {
wrapper.prototype[name] = function() {
var args = slice.call(arguments);
unshift.call(args, this._wrapped);
return result(func.apply(_, args), this._chain);
};
};
// Add all of the Underscore functions to the wrapper object.
_.mixin(_);
// Add all mutator Array functions to the wrapper.
each(['pop', 'push', 'reverse', 'shift', 'sort', 'splice', 'unshift'], function(name) {
var method = ArrayProto[name];
wrapper.prototype[name] = function() {
var wrapped = this._wrapped;
method.apply(wrapped, arguments);
var length = wrapped.length;
if ((name == 'shift' || name == 'splice') && length === 0) delete wrapped[0];
return result(wrapped, this._chain);
};
});
// Add all accessor Array functions to the wrapper.
each(['concat', 'join', 'slice'], function(name) {
var method = ArrayProto[name];
wrapper.prototype[name] = function() {
return result(method.apply(this._wrapped, arguments), this._chain);
};
});
// Start chaining a wrapped Underscore object.
wrapper.prototype.chain = function() {
this._chain = true;
return this;
};
// Extracts the result from a wrapped and chained object.
wrapper.prototype.value = function() {
return this._wrapped;
};
}).call(this);

File diff suppressed because one or more lines are too long

View File

@@ -36,7 +36,6 @@
<script src="_static/jquery.js"></script> <script src="_static/jquery.js"></script>
<script src="_static/underscore.js"></script> <script src="_static/underscore.js"></script>
<script src="_static/doctools.js"></script> <script src="_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="_static/js/theme.js"></script> <script type="text/javascript" src="_static/js/theme.js"></script>

View File

@@ -36,7 +36,6 @@
<script src="../_static/jquery.js"></script> <script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script> <script src="../_static/underscore.js"></script>
<script src="../_static/doctools.js"></script> <script src="../_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../_static/js/theme.js"></script> <script type="text/javascript" src="../_static/js/theme.js"></script>

View File

@@ -36,7 +36,6 @@
<script src="../../_static/jquery.js"></script> <script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script> <script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script> <script src="../../_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script> <script type="text/javascript" src="../../_static/js/theme.js"></script>
@@ -95,12 +94,13 @@
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current"> <li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul> <li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li> <li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch bindings</a></li> <li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li> <li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmarking">Benchmarking</a></li> <li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
</ul> </ul>
</li> </li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li> <li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
</ul> </ul>
</li> </li>
</ul> </ul>
@@ -225,7 +225,7 @@ programming model for more details).</p>
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL2019 Triton paper</a>.</p> <p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL2019 Triton paper</a>.</p>
</div> </div>
<div class="section" id="torch-bindings"> <div class="section" id="torch-bindings">
<h2>Torch bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline"></a></h2> <h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline"></a></h2>
<p>The only thing that matters when it comes to Triton and Torch is the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="code docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects. To create a <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p> <p>The only thing that matters when it comes to Triton and Torch is the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="code docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects. To create a <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
<ul class="simple"> <ul class="simple">
<li><p><code class="code docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</p></li> <li><p><code class="code docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</p></li>
@@ -313,15 +313,15 @@ programming model for more details).</p>
</pre></div> </pre></div>
</div> </div>
<p class="sphx-glr-script-out">Out:</p> <p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;) <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;) tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 0.0 The maximum difference between torch and triton is 0.0
</pre></div> </pre></div>
</div> </div>
<p>Seems like were good to go!</p> <p>Seems like were good to go!</p>
</div> </div>
<div class="section" id="benchmarking"> <div class="section" id="benchmark">
<h2>Benchmarking<a class="headerlink" href="#benchmarking" title="Permalink to this headline"></a></h2> <h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch. <p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op. To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
for different problem sizes.</p> for different problem sizes.</p>
@@ -355,7 +355,7 @@ for different problem sizes.</p>
</pre></div> </pre></div>
</div> </div>
<img alt="vector-add-performance" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" /> <img alt="vector-add-performance" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 5.901 seconds)</p> <p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 7.521 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py"> <div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container"> <div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p> <p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>

View File

@@ -43,6 +43,7 @@
<link rel="index" title="Index" href="../../genindex.html" /> <link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" /> <link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Matrix Multiplication" href="03-matrix-multiplication.html" />
<link rel="prev" title="Vector Addition" href="01-vector-add.html" /> <link rel="prev" title="Vector Addition" href="01-vector-add.html" />
</head> </head>
@@ -98,9 +99,10 @@
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li> <li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li> <li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li> <li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmarking">Benchmarking</a></li> <li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
</ul> </ul>
</li> </li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
</ul> </ul>
</li> </li>
</ul> </ul>
@@ -298,10 +300,19 @@ This means that different values of BLOCK will result in different kernels</p>
<span class="c1"># Now are kernels are indexed not only by the provided device but also</span> <span class="c1"># Now are kernels are indexed not only by the provided device but also</span>
<span class="c1"># by the rounded number of columns in the input matrix</span> <span class="c1"># by the rounded number of columns in the input matrix</span>
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span> <span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span> <span class="c1"># Another trick we can use is to ask the compiler to parallelize each</span>
<span class="c1"># row-normalization more aggressively -- i.e., with more warps -- vectors</span>
<span class="c1"># that are longer</span>
<span class="c1"># You will see in the next tutorial how to auto-tune this value in a more natural</span>
<span class="c1"># way so you don&#39;t have to come up with manual heuristics yourself</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
<span class="k">if</span> <span class="n">BLOCK</span> <span class="o">&gt;=</span> <span class="mi">2048</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
<span class="k">if</span> <span class="n">BLOCK</span> <span class="o">&gt;=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
<span class="c1"># Each (BLOCK, num_warps, device) results in a different kernel</span>
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">num_warps</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span> <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span> <span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span>
<span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">)</span>
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
@@ -345,8 +356,8 @@ This will allow us to verify that our padding mechanism works.</p>
</div> </div>
<p>As expected, the results are identical.</p> <p>As expected, the results are identical.</p>
</div> </div>
<div class="section" id="benchmarking"> <div class="section" id="benchmark">
<h2>Benchmarking<a class="headerlink" href="#benchmarking" title="Permalink to this headline"></a></h2> <h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<p>Here we will benchmark our operation as a function of the number of columns in the input matrix assuming 4096 rows. <p>Here we will benchmark our operation as a function of the number of columns in the input matrix assuming 4096 rows.
We will then compare its performance against (1) <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> and (2) the <code class="code docutils literal notranslate"><span class="pre">naive_softmax</span></code> defined above.</p> We will then compare its performance against (1) <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> and (2) the <code class="code docutils literal notranslate"><span class="pre">naive_softmax</span></code> defined above.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
@@ -386,7 +397,7 @@ This means that when temporary data is too large to fit entirely in the GPU
Note that our Triton kernel is not only faster than PyTorchs CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li> Note that our Triton kernel is not only faster than PyTorchs CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
</ul> </ul>
</div></blockquote> </div></blockquote>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 21.805 seconds)</p> <p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 19.896 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py"> <div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container"> <div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p> <p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
@@ -405,6 +416,7 @@ Note that our Triton kernel is not only faster than PyTorchs CUDA kernel, it
</div> </div>
<footer> <footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation"> <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="03-matrix-multiplication.html" class="btn btn-neutral float-right" title="Matrix Multiplication" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a> <a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div> </div>

View File

@@ -0,0 +1,629 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Matrix Multiplication &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="prev" title="Fused Softmax" href="02-fused-softmax.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Matrix Multiplication</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#pointer-arithmetics">Pointer Arithmetics</a></li>
<li class="toctree-l4"><a class="reference internal" href="#l2-cache-optimizations">L2 Cache Optimizations</a></li>
<li class="toctree-l4"><a class="reference internal" href="#final-result">Final Result</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#auto-tuning">Auto-Tuning</a></li>
<li class="toctree-l4"><a class="reference internal" href="#autograd-function">Autograd Function</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#installing-the-cutlass-bindings">Installing The CUTLASS Bindings</a></li>
<li class="toctree-l4"><a class="reference internal" href="#square-matrix-performance">Square Matrix Performance</a></li>
</ul>
</li>
</ul>
</li>
</ul>
</li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Matrix Multiplication</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="matrix-multiplication">
<span id="sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"></span><h1>Matrix Multiplication<a class="headerlink" href="#matrix-multiplication" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLASs performance.
You will specifically learn about:</p>
<ul class="simple">
<li><p>The block-level matrix multiplication operator <cite>&#64;</cite></p></li>
<li><p>Multi-dimensional pointer arithmetic</p></li>
<li><p>Program re-ordering for improved L2 cache hit rate</p></li>
<li><p>Automatic performance tuning</p></li>
</ul>
<div class="section" id="motivations">
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline"></a></h2>
<p>Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.</p>
<p>Roughly speaking, the kernel that we will write will implement the following blocked algorithm:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
<span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">MB</span><span class="p">):</span>
<span class="c1"># do in parallel</span>
<span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">NB</span><span class="p">):</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">zeros</span><span class="p">((</span><span class="n">MB</span><span class="p">,</span> <span class="n">NB</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">KB</span><span class="p">):</span>
<span class="n">acc</span> <span class="o">+=</span> <span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">,</span> <span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">]</span> <span class="o">@</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">]</span>
<span class="n">C</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">acc</span><span class="p">;</span>
</pre></div>
</div>
</div></blockquote>
<p>where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.</p>
</div>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<p>The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the <code class="code docutils literal notranslate"><span class="pre">&#64;</span></code> operator for block-level matrix multiplication.
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> that we need to read in the inner loop.</p>
<div class="section" id="pointer-arithmetics">
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline"></a></h3>
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given by <code class="code docutils literal notranslate"><span class="pre">&amp;X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">i</span> <span class="pre">+</span> <span class="pre">X.stride(0)</span> <span class="pre">+</span> <span class="pre">j</span></code>.
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+MB,</span> <span class="pre">k:k+KB]</span></code> and <code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+KB,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+NB]</span></code> can be defined in pseudo-code as:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">MB</span><span class="p">)[:,</span> <span class="n">newaxis</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">)[</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:];</span>
<span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">KB</span><span class="p">)[:,</span> <span class="n">newaxis</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">NB</span><span class="p">)[</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:];</span>
</pre></div>
</div>
</div></blockquote>
<p>Which means that, at initialization (i.e., <code class="code docutils literal notranslate"><span class="pre">k</span> <span class="pre">=</span> <span class="pre">0</span></code>), pointers for blocks of A and B can be initialized in Triton as:</p>
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span> <span class="n">rm</span><span class="p">[</span><span class="n">MB</span><span class="p">]</span> <span class="o">=</span> <span class="n">program_id_m</span> <span class="o">*</span> <span class="n">MB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">MB</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">rn</span><span class="p">[</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">program_id_n</span> <span class="o">*</span> <span class="n">NB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">NB</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">rk</span><span class="p">[</span><span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">KB</span><span class="p">;</span>
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pa</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_a_0</span> <span class="o">+</span> <span class="n">rk</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span><span class="p">);</span>
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pb</span><span class="p">[</span><span class="n">KB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_b_0</span> <span class="o">+</span> <span class="n">rn</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span><span class="p">);</span>
</pre></div>
</div>
</div></blockquote>
<p>These pointers can then be updated in the inner loop as:</p>
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">pa</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="mi">1</span><span class="p">;</span>
<span class="n">pb</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="n">ldb</span><span class="p">;</span>
</pre></div>
</div>
</div></blockquote>
</div>
<div class="section" id="l2-cache-optimizations">
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline"></a></h3>
<p>As mentioned above, each program instance computes an <code class="code docutils literal notranslate"><span class="pre">[MB,</span> <span class="pre">NB]</span></code> block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
This means that a naive row-major ordering:</p>
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span> <span class="n">program_id</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">MB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">MB</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">NB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">NB</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">program_id_m</span> <span class="o">=</span> <span class="n">program_id</span> <span class="o">/</span> <span class="n">grid_n</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">program_id_n</span> <span class="o">=</span> <span class="n">program_id</span> <span class="o">%</span> <span class="n">grid_n</span><span class="p">;</span>
</pre></div>
</div>
</div></blockquote>
<p>is unlikely to result in optimal performance.</p>
<p>One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by super-grouping blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_SIZE</span></code> before switching to the next column:</p>
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="kt">int</span> <span class="n">program_id</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">width</span> <span class="o">=</span> <span class="n">GROUP_SIZE</span> <span class="o">*</span> <span class="n">grid_n</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">/</span> <span class="n">width</span><span class="p">;</span>
<span class="c1">// we need to handle the case where M % (GROUP_SIZE*BM) != 0</span>
<span class="kt">int</span> <span class="n">group_size</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE</span><span class="p">,</span> <span class="n">GROUP_SIZE</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">group_size</span><span class="p">);</span>
</pre></div>
</div>
</div></blockquote>
<p>In practice, this can improve the performance of our matrix multiplication kernel by &gt;10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</p>
</div>
<div class="section" id="final-result">
<h3>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline"></a></h3>
<p>We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
Note that we rematerialize <code class="code docutils literal notranslate"><span class="pre">rm</span></code> and <code class="code docutils literal notranslate"><span class="pre">rn:</span></code> after the inner loop to decrease register pressure.
This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.</p>
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="cp">#define MAX_GROUP_SIZE 8</span>
<span class="n">__global__</span> <span class="kt">void</span> <span class="n">dot</span><span class="p">(</span><span class="n">TYPE</span><span class="o">*</span> <span class="n">A</span><span class="p">,</span> <span class="n">TYPE</span><span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="n">TYPE</span><span class="o">*</span> <span class="n">C</span><span class="p">,</span>
<span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">,</span> <span class="kt">int</span> <span class="n">K</span><span class="p">,</span>
<span class="kt">int</span> <span class="n">stride_a_0</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_b_0</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_c_0</span><span class="p">)</span> <span class="p">{</span>
<span class="c1">// prologue</span>
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">MB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">MB</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">NB</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">NB</span><span class="p">;</span>
<span class="c1">// re-order program ID for better L2 performance</span>
<span class="kt">int</span> <span class="n">width</span> <span class="o">=</span> <span class="n">MAX_GROUP_SIZE</span> <span class="o">*</span> <span class="n">grid_n</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">/</span> <span class="n">width</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">group_size</span> <span class="o">=</span> <span class="n">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">MAX_GROUP_SIZE</span><span class="p">,</span> <span class="n">MAX_GROUP_SIZE</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">MAX_GROUP_SIZE</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">group_size</span><span class="p">);</span>
<span class="c1">// pointers to operands</span>
<span class="c1">// note the parentheses here; they force the offset</span>
<span class="c1">// computation to happen in typeof(stride_a_0) = int32 rather than</span>
<span class="c1">// typeof(A) = int64</span>
<span class="kt">int</span> <span class="n">rm</span><span class="p">[</span><span class="n">MB</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">MB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">MB</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">rn</span><span class="p">[</span><span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">NB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">NB</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">rk</span><span class="p">[</span><span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">KB</span><span class="p">;</span>
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pa</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">KB</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_a_0</span><span class="p">);</span>
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pb</span><span class="p">[</span><span class="n">KB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_b_0</span> <span class="o">+</span> <span class="n">rn</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">*</span> <span class="mi">1</span><span class="p">);</span>
<span class="c1">// reduction loop</span>
<span class="kt">float</span> <span class="n">acc</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="n">K</span><span class="p">;</span> <span class="n">k</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">-=</span> <span class="n">KB</span><span class="p">)</span> <span class="p">{</span>
<span class="n">acc</span> <span class="o">+=</span> <span class="p">(</span><span class="o">*</span><span class="n">pa</span><span class="p">)</span> <span class="err">@</span> <span class="p">(</span><span class="o">*</span><span class="n">pb</span><span class="p">);</span>
<span class="n">pa</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="mi">1</span><span class="p">;</span>
<span class="n">pb</span> <span class="o">+=</span> <span class="n">KB</span> <span class="o">*</span> <span class="n">stride_b_0</span><span class="p">;</span>
<span class="p">}</span>
<span class="c1">// pointers to output</span>
<span class="c1">// here we rematerialize `rm` and `rn` so that they are not live through</span>
<span class="c1">// the above reduction loop. In the future, the compiler should be able to</span>
<span class="c1">// do this automatically.</span>
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">MB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">MB</span><span class="p">;</span>
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">NB</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">NB</span><span class="p">;</span>
<span class="n">TYPE</span> <span class="o">*</span><span class="n">pc</span><span class="p">[</span><span class="n">MB</span><span class="p">,</span> <span class="n">NB</span><span class="p">]</span> <span class="o">=</span> <span class="n">C</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_c_0</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]);</span>
<span class="c1">// we write back using *?() operator. `acc` gets casted to `float32` implicitly.</span>
<span class="o">*?</span> <span class="p">(</span><span class="n">rm</span><span class="p">[</span><span class="o">:</span><span class="p">,</span> <span class="n">newaxis</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">M</span> <span class="o">&amp;&amp;</span> <span class="n">rn</span> <span class="p">[</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">:</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="n">pc</span> <span class="o">=</span> <span class="n">acc</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
</div></blockquote>
<p>Where <code class="code docutils literal notranslate"><span class="pre">TYPE</span></code> is the data-type of the input matrices and <code class="code docutils literal notranslate"><span class="pre">MB</span></code>, <code class="code docutils literal notranslate"><span class="pre">NB</span></code>, <code class="code docutils literal notranslate"><span class="pre">KB</span></code> are the block sizes defined in the above pseudo-code.
Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
If <code class="code docutils literal notranslate"><span class="pre">TYPE</span></code> is <code class="code docutils literal notranslate"><span class="pre">half</span></code>, then tensor cores will be used automatically provided that <code class="code docutils literal notranslate"><span class="pre">MB</span></code>, <code class="code docutils literal notranslate"><span class="pre">NB</span></code> and <code class="code docutils literal notranslate"><span class="pre">KB</span></code> are multiples of 16.</p>
</div>
</div>
<div class="section" id="torch-bindings">
<h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline"></a></h2>
<div class="section" id="auto-tuning">
<h3>Auto-Tuning<a class="headerlink" href="#auto-tuning" title="Permalink to this headline"></a></h3>
<p>In order to use Tritons built-in auto-tuner in the above kernel, we need to define a list of <code class="code docutils literal notranslate"><span class="pre">triton.config</span></code> objects. that can be constructed as follows:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="n">autotune_configs</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;MB&quot;</span><span class="p">:</span> <span class="s2">&quot;128&quot;</span><span class="p">,</span> <span class="s2">&quot;NB&quot;</span><span class="p">:</span> <span class="s2">&quot;128&quot;</span><span class="p">,</span> <span class="s2">&quot;KB&quot;</span><span class="p">:</span> <span class="s2">&quot;32&quot;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;MB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">,</span> <span class="s1">&#39;NB&#39;</span><span class="p">:</span> <span class="s1">&#39;128&#39;</span><span class="p">,</span> <span class="s1">&#39;KB&#39;</span><span class="p">:</span> <span class="s1">&#39;32&#39;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;MB&#39;</span><span class="p">:</span> <span class="s1">&#39;128&#39;</span><span class="p">,</span> <span class="s1">&#39;NB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">,</span> <span class="s1">&#39;KB&#39;</span><span class="p">:</span> <span class="s1">&#39;32&#39;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;MB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">,</span> <span class="s1">&#39;NB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">,</span> <span class="s1">&#39;KB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;MB&#39;</span><span class="p">:</span> <span class="s1">&#39;32&#39;</span><span class="p">,</span> <span class="s1">&#39;NB&#39;</span><span class="p">:</span> <span class="s1">&#39;128&#39;</span><span class="p">,</span> <span class="s1">&#39;KB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;MB&#39;</span><span class="p">:</span> <span class="s1">&#39;128&#39;</span><span class="p">,</span> <span class="s1">&#39;NB&#39;</span><span class="p">:</span> <span class="s1">&#39;32&#39;</span><span class="p">,</span> <span class="s1">&#39;KB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;MB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">,</span> <span class="s1">&#39;NB&#39;</span><span class="p">:</span> <span class="s1">&#39;32&#39;</span><span class="p">,</span> <span class="s1">&#39;KB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;MB&#39;</span><span class="p">:</span> <span class="s1">&#39;32&#39;</span><span class="p">,</span> <span class="s1">&#39;NB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">,</span> <span class="s1">&#39;KB&#39;</span><span class="p">:</span> <span class="s1">&#39;64&#39;</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="p">]</span>
</pre></div>
</div>
<p>we also need to define a list of <code class="code docutils literal notranslate"><span class="pre">string</span></code> (i.e., “autotuning key”) that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
Here, we want to re-tune our kernel only when the shape of input matrices changes.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">autotune_key</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;M&quot;</span><span class="p">,</span> <span class="s2">&quot;N&quot;</span><span class="p">,</span> <span class="s2">&quot;K&quot;</span><span class="p">]</span>
</pre></div>
</div>
<p>We can now create an auto-tuned kernel by passing the <cite>autotune_configs</cite> and <cite>autotune_key</cite> lists to the constructor of the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">src</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">#define MAX_GROUP_SIZE 8</span>
<span class="s2">__global__ void dot(TYPE* A, TYPE* B, TYPE* C,</span>
<span class="s2"> int M, int N, int K,</span>
<span class="s2"> int lda, int ldb, int ldc) {</span>
<span class="s2"> int pid = get_program_id(0);</span>
<span class="s2"> int grid_m = (M + MB - 1) / MB;</span>
<span class="s2"> int grid_n = (N + NB - 1) / NB;</span>
<span class="s2"> int width = MAX_GROUP_SIZE * grid_n;</span>
<span class="s2"> int group_id = pid / width;</span>
<span class="s2"> int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);</span>
<span class="s2"> int pid_m = group_id * MAX_GROUP_SIZE + (pid </span><span class="si">% g</span><span class="s2">roup_size);</span>
<span class="s2"> int pid_n = (pid % width) / (group_size);</span>
<span class="s2"> int rm[MB] = pid_m * MB + 0 ... MB;</span>
<span class="s2"> int rn[NB] = pid_n * NB + 0 ... NB;</span>
<span class="s2"> int rk[KB] = 0 ... KB;</span>
<span class="s2"> TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);</span>
<span class="s2"> TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);</span>
<span class="s2"> float acc[MB, NB] = 0;</span>
<span class="s2"> for (int k = K; k &gt; 0; k -= KB) {</span>
<span class="s2"> acc += (*pa) @ (*pb);</span>
<span class="s2"> pa += KB * 1;</span>
<span class="s2"> pb += KB * ldb;</span>
<span class="s2"> }</span>
<span class="s2"> rm = pid_m * MB + 0 ... MB;</span>
<span class="s2"> rn = pid_n * NB + 0 ... NB;</span>
<span class="s2"> TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);</span>
<span class="s2"> *? (rm[:, newaxis] &lt; M &amp;&amp; rn [newaxis, :] &lt; N) pc = acc;</span>
<span class="s2">}</span>
<span class="s2">&quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span>
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;TYPE&#39;</span><span class="p">:</span> <span class="n">dtype</span><span class="p">}</span>
<span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">,</span> <span class="n">autotune_vals</span><span class="o">=</span><span class="n">autotune_configs</span><span class="p">,</span> <span class="n">autotune_key</span><span class="o">=</span><span class="n">autotune_key</span><span class="p">)</span>
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
<span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
</pre></div>
</div>
</div>
<div class="section" id="autograd-function">
<h3>Autograd Function<a class="headerlink" href="#autograd-function" title="Permalink to this headline"></a></h3>
<p>Now we are ready to expose our auto-tuned kernel as a <cite>torch.autograd.Function</cite>.
To do so, we just need to define a <cite>forward</cite> function that takes a two tensors as input and returns a tensor as output.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">_dot</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="n">M</span><span class="p">,</span> <span class="n">Ka</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
<span class="n">Kb</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span>
<span class="k">assert</span> <span class="n">Ka</span> <span class="o">==</span> <span class="n">Kb</span><span class="p">,</span> <span class="s2">&quot;incompatible dimensions&quot;</span>
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span> <span class="ow">and</span> <span class="n">b</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">&quot;inputs must be contiguous&quot;</span>
<span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">MB</span><span class="p">)</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">NB</span><span class="p">),</span> <span class="p">)</span>
<span class="n">kernel</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">b</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">c</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> \
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">Ka</span><span class="p">,</span> \
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> \
<span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
<span class="k">return</span> <span class="n">c</span>
<span class="n">dot</span> <span class="o">=</span> <span class="n">_dot</span><span class="o">.</span><span class="n">apply</span>
</pre></div>
</div>
</div>
</div>
<div class="section" id="unit-test">
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline"></a></h2>
<p>We can test our custom matrix multiplication operation against cuBLAS (i.e., <code class="code docutils literal notranslate"><span class="pre">torch.matmul</span></code>).
Note that we need to modify the :code`atol` and <code class="code docutils literal notranslate"><span class="pre">rtol</span></code> parameters of <cite>torch.allclose</cite> to account for the fact that we are comparing FP16 tensors.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">768</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">((</span><span class="mi">768</span><span class="p">,</span> <span class="mi">896</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">c_0</span> <span class="o">=</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="n">c_1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">c_0</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">c_1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">c_0</span><span class="p">,</span> <span class="n">c_1</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device=&#39;cuda:0&#39;, dtype=torch.float16)
tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device=&#39;cuda:0&#39;, dtype=torch.float16)
True
</pre></div>
</div>
</div>
<div class="section" id="benchmark">
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<div class="section" id="installing-the-cutlass-bindings">
<h3>Installing The CUTLASS Bindings<a class="headerlink" href="#installing-the-cutlass-bindings" title="Permalink to this headline"></a></h3>
<p>The cuBLAS library (used by <code class="code docutils literal notranslate"><span class="pre">torch.matmul</span></code>) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
For this reason, we will instead compare the performance of our kernel against <a class="reference external" href="https://github.com/NVIDIA/cutlass/">CUTLASS</a> , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
To install CUTLASS, you need a recent version of cmake:</p>
<blockquote>
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span> /path/to/cutlass/
git clone https://github.com/NVIDIA/cutlass.git
<span class="nb">cd</span> cutlass
mkdir build
<span class="nb">cd</span> build
wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
tar xzvf *.tar.gz
</pre></div>
</div>
</div></blockquote>
<p>You can then install CUTLASS as follows for V100</p>
<blockquote>
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED<span class="o">=</span><span class="m">70</span> -DCUTLASS_LIBRARY_KERNELS<span class="o">=</span>cutlass_tensorop_f16_s884gemm_f16_*_align8
make -j8 install
</pre></div>
</div>
</div></blockquote>
<p>Or as follows for A100:</p>
<blockquote>
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED<span class="o">=</span><span class="m">80</span> -DCUTLASS_LIBRARY_KERNELS<span class="o">=</span>cutlass_tensorop_f16_s16816gemm_*align8
make -j8 install
</pre></div>
</div>
</div></blockquote>
<p>Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables <code class="code docutils literal notranslate"><span class="pre">CUTLASS_INCLUDE_DIR</span></code> and <code class="code docutils literal notranslate"><span class="pre">CUTLASS_LIBRARY_DIR</span></code> are set during the installation process.
To re-install Triton with the updated CUTLASS bindings, run the following command:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">export</span> <span class="nv">CUTLASS_INCLUDE_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/include/
<span class="nb">export</span> <span class="nv">CUTLASS_LIBRARY_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/lib/a
pip uninstall -y triton
pip install -e <span class="s2">&quot;git+https://github.com/ptillet/triton.git#egg=triton&amp;subdirectory=python&quot;</span>
</pre></div>
</div>
<p>Which we can test as follows:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
<span class="n">c_2</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">cutlass_matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">c_2</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">c_0</span><span class="p">,</span> <span class="n">c_2</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device=&#39;cuda:0&#39;, dtype=torch.float16)
True
</pre></div>
</div>
<p>Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.</p>
</div>
<div class="section" id="square-matrix-performance">
<h3>Square Matrix Performance<a class="headerlink" href="#square-matrix-performance" title="Permalink to this headline"></a></h3>
<p>We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="s1">&#39;K&#39;</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">256</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">33</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
<span class="n">y_name</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;torch&#39;</span><span class="p">,</span> <span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="s1">&#39;cutlass&#39;</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;Torch&quot;</span><span class="p">,</span> <span class="s2">&quot;Triton&quot;</span><span class="p">,</span> <span class="s1">&#39;CUTLASS&#39;</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s2">&quot;TFLOPS&quot;</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s2">&quot;matmul-performance&quot;</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
<span class="n">args</span><span class="o">=</span><span class="p">{}</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;cutlass&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">cutlass_matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="n">perf</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
<span class="k">return</span> <span class="n">perf</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<img alt="matmul-performance" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
<p>As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 10.181 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/b51b68bc1c6b1a5e509f67800b6235af/03-matrix-multiplication.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">03-matrix-multiplication.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="02-fused-softmax.html" class="btn btn-neutral float-left" title="Fused Softmax" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -36,7 +36,6 @@
<script src="../../_static/jquery.js"></script> <script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script> <script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script> <script src="../../_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script> <script type="text/javascript" src="../../_static/js/theme.js"></script>
@@ -95,6 +94,7 @@
<li class="toctree-l1 current"><a class="current reference internal" href="#">Tutorials</a><ul> <li class="toctree-l1 current"><a class="current reference internal" href="#">Tutorials</a><ul>
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li> <li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li> <li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
</ul> </ul>
</li> </li>
</ul> </ul>
@@ -179,6 +179,12 @@
</div> </div>
</div><div class="toctree-wrapper compound"> </div><div class="toctree-wrapper compound">
</div> </div>
<div class="sphx-glr-thumbcontainer" tooltip="- The block-level matrix multiplication operator @ - Multi-dimensional pointer arithmetic - Pro..."><div class="figure align-default" id="id3">
<img alt="Matrix Multiplication" src="../../_images/sphx_glr_03-matrix-multiplication_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a></span><a class="headerlink" href="#id3" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-clear"></div><div class="sphx-glr-footer class sphx-glr-footer-gallery docutils container"> <div class="sphx-glr-clear"></div><div class="sphx-glr-footer class sphx-glr-footer-gallery docutils container">
<div class="sphx-glr-download sphx-glr-download-python docutils container"> <div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tutorials_python.zip</span></code></a></p> <p><a class="reference download internal" download="" href="../../_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tutorials_python.zip</span></code></a></p>

View File

@@ -36,7 +36,6 @@
<script src="../../_static/jquery.js"></script> <script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script> <script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script> <script src="../../_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script> <script type="text/javascript" src="../../_static/js/theme.js"></script>
@@ -160,20 +159,24 @@
<div class="section" id="computation-times"> <div class="section" id="computation-times">
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline"></a></h1> <span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline"></a></h1>
<p><strong>00:27.706</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p> <p><strong>01:10.181</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
<table class="docutils align-default"> <table class="docutils align-default">
<colgroup> <colgroup>
<col style="width: 82%" /> <col style="width: 85%" />
<col style="width: 10%" /> <col style="width: 9%" />
<col style="width: 7%" /> <col style="width: 6%" />
</colgroup> </colgroup>
<tbody> <tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td> <tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
<td><p>00:21.805</p></td> <td><p>01:10.181</p></td>
<td><p>0.0 MB</p></td> <td><p>0.0 MB</p></td>
</tr> </tr>
<tr class="row-even"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td> <tr class="row-even"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
<td><p>00:05.901</p></td> <td><p>00:00.000</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
<td><p>00:00.000</p></td>
<td><p>0.0 MB</p></td> <td><p>0.0 MB</p></td>
</tr> </tr>
</tbody> </tbody>

View File

@@ -36,7 +36,6 @@
<script src="_static/jquery.js"></script> <script src="_static/jquery.js"></script>
<script src="_static/underscore.js"></script> <script src="_static/underscore.js"></script>
<script src="_static/doctools.js"></script> <script src="_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="_static/js/theme.js"></script> <script type="text/javascript" src="_static/js/theme.js"></script>

Binary file not shown.

View File

@@ -37,7 +37,6 @@
<script src="_static/jquery.js"></script> <script src="_static/jquery.js"></script>
<script src="_static/underscore.js"></script> <script src="_static/underscore.js"></script>
<script src="_static/doctools.js"></script> <script src="_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="_static/js/theme.js"></script> <script type="text/javascript" src="_static/js/theme.js"></script>

File diff suppressed because one or more lines are too long