[GH-PAGES] Updated website

This commit is contained in:
Philippe Tillet
2021-03-15 13:58:20 -04:00
parent b4495e0ddc
commit 746b15ee0a
39 changed files with 3933 additions and 1113 deletions

View File

@@ -70,7 +70,7 @@ The existence of arrays as a primitive data-type for Triton comes with a number
.. GENERATED FROM PYTHON SOURCE LINES 53-60
Torch bindings
Torch Bindings
--------------------------
The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
@@ -161,7 +161,7 @@ We can now use the above function to compute the sum of two `torch.tensor` objec
.. GENERATED FROM PYTHON SOURCE LINES 129-133
Unit Test
--------------------------
-----------
Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:
@@ -189,8 +189,8 @@ Of course, the first thing that we should check is that whether kernel is correc
.. code-block:: none
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0
@@ -202,8 +202,8 @@ Seems like we're good to go!
.. GENERATED FROM PYTHON SOURCE LINES 147-152
Benchmarking
--------------------------
Benchmark
-----------
We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
for different problem sizes.
@@ -268,7 +268,7 @@ We can now run the decorated function above. Pass `show_plots=True` to see the p
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 5.901 seconds)
**Total running time of the script:** ( 0 minutes 7.521 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

View File

@@ -121,7 +121,7 @@ Here our torch bindings is quite similar to that of the vector addition mentione
We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
This means that different values of BLOCK will result in different kernels
.. GENERATED FROM PYTHON SOURCE LINES 89-156
.. GENERATED FROM PYTHON SOURCE LINES 89-165
.. code-block:: default
@@ -165,10 +165,19 @@ This means that different values of BLOCK will result in different kernels
# Now are kernels are indexed not only by the provided device but also
# by the rounded number of columns in the input matrix
BLOCK = next_power_of_2(N)
key = (BLOCK, device)
# Another trick we can use is to ask the compiler to parallelize each
# row-normalization more aggressively -- i.e., with more warps -- vectors
# that are longer
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself
num_warps = 4
if BLOCK >= 2048: num_warps = 8
if BLOCK >= 4096: num_warps = 16
# Each (BLOCK, num_warps, device) results in a different kernel
key = (BLOCK, num_warps, device)
if key not in cache:
defines = {'BLOCK': BLOCK}
cache[key] = triton.kernel(_src, device=device, defines=defines)
cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
return cache[key]
@@ -199,21 +208,21 @@ This means that different values of BLOCK will result in different kernels
.. GENERATED FROM PYTHON SOURCE LINES 157-158
.. GENERATED FROM PYTHON SOURCE LINES 166-167
We can use the above softmax function to compute the row-wise softmax of a given matrix.
.. GENERATED FROM PYTHON SOURCE LINES 160-162
.. GENERATED FROM PYTHON SOURCE LINES 169-171
Unit Test
----------
.. GENERATED FROM PYTHON SOURCE LINES 164-166
.. GENERATED FROM PYTHON SOURCE LINES 173-175
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 166-173
.. GENERATED FROM PYTHON SOURCE LINES 175-182
.. code-block:: default
@@ -239,18 +248,18 @@ This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 174-175
.. GENERATED FROM PYTHON SOURCE LINES 183-184
As expected, the results are identical.
.. GENERATED FROM PYTHON SOURCE LINES 177-181
.. GENERATED FROM PYTHON SOURCE LINES 186-190
Benchmarking
Benchmark
-------------
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
.. GENERATED FROM PYTHON SOURCE LINES 181-209
.. GENERATED FROM PYTHON SOURCE LINES 190-218
.. code-block:: default
@@ -293,7 +302,7 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
.. GENERATED FROM PYTHON SOURCE LINES 210-215
.. GENERATED FROM PYTHON SOURCE LINES 219-224
In the above plot, we can see that:
@@ -305,7 +314,7 @@ In the above plot, we can see that:
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 21.805 seconds)
**Total running time of the script:** ( 0 minutes 19.896 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

View File

@@ -0,0 +1,558 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/03-matrix-multiplication.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_03-matrix-multiplication.py:
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
You will specifically learn about:
- The block-level matrix multiplication operator `@`
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
.. GENERATED FROM PYTHON SOURCE LINES 14-35
Motivations
-------------
Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
.. code-block:: python
# do in parallel
for m in range(0, M, MB):
# do in parallel
for n in range(0, N, NB):
acc = zeros((MB, NB), dtype=float32)
for k in range(0, K, KB):
acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
C[m : m+MB, n : n+NB] = acc;
where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
.. GENERATED FROM PYTHON SOURCE LINES 37-161
Compute Kernel
----------------
The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
Pointer Arithmetics
~~~~~~~~~~~~~~~~~~~~
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
.. code-block:: python
&A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
&B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
.. code-block:: C
:force:
int rm[MB] = program_id_m * MB + 0 ... MB;
int rn[NB] = program_id_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
These pointers can then be updated in the inner loop as:
.. code-block:: C
pa += KB * 1;
pb += KB * ldb;
L2 Cache Optimizations
~~~~~~~~~~~~~~~~~~~~~~~~
As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
This means that a naive row-major ordering:
.. code-block:: C
int program_id = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
int program_id_m = program_id / grid_n;
int program_id_n = program_id % grid_n;
is unlikely to result in optimal performance.
One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
.. code-block:: C
int program_id = get_program_id(0);
int width = GROUP_SIZE * grid_n;
int group_id = pid / width;
// we need to handle the case where M % (GROUP_SIZE*BM) != 0
int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
int pid_m = group_id * GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
Final Result
~~~~~~~~~~~~~~
We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure.
This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.
.. code-block:: C
:force:
#define MAX_GROUP_SIZE 8
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
int M, int N, int K,
int stride_a_0, int stride_b_0, int stride_c_0) {
// prologue
int pid = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
// re-order program ID for better L2 performance
int width = MAX_GROUP_SIZE * grid_n;
int group_id = pid / width;
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
// pointers to operands
// note the parentheses here; they force the offset
// computation to happen in typeof(stride_a_0) = int32 rather than
// typeof(A) = int64
int rm[MB] = pid_m * MB + 0 ... MB;
int rn[NB] = pid_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
// reduction loop
float acc[MB, NB] = 0;
for (int k = K; k > 0; k -= KB) {
acc += (*pa) @ (*pb);
pa += KB * 1;
pb += KB * stride_b_0;
}
// pointers to output
// here we rematerialize `rm` and `rn` so that they are not live through
// the above reduction loop. In the future, the compiler should be able to
// do this automatically.
rm = pid_m * MB + 0 ... MB;
rn = pid_n * NB + 0 ... NB;
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
// we write back using *?() operator. `acc` gets casted to `float32` implicitly.
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
.. GENERATED FROM PYTHON SOURCE LINES 163-170
Torch Bindings
----------------
Auto-Tuning
~~~~~~~~~~~~~~
In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows:
.. GENERATED FROM PYTHON SOURCE LINES 170-185
.. code-block:: default
import torch
import triton
autotune_configs = [
triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
]
.. GENERATED FROM PYTHON SOURCE LINES 186-188
we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
Here, we want to re-tune our kernel only when the shape of input matrices changes.
.. GENERATED FROM PYTHON SOURCE LINES 188-191
.. code-block:: default
autotune_key = ["M", "N", "K"]
.. GENERATED FROM PYTHON SOURCE LINES 192-193
We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
.. GENERATED FROM PYTHON SOURCE LINES 193-238
.. code-block:: default
src = """
#define MAX_GROUP_SIZE 8
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
int M, int N, int K,
int lda, int ldb, int ldc) {
int pid = get_program_id(0);
int grid_m = (M + MB - 1) / MB;
int grid_n = (N + NB - 1) / NB;
int width = MAX_GROUP_SIZE * grid_n;
int group_id = pid / width;
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
int pid_n = (pid % width) / (group_size);
int rm[MB] = pid_m * MB + 0 ... MB;
int rn[NB] = pid_n * NB + 0 ... NB;
int rk[KB] = 0 ... KB;
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
float acc[MB, NB] = 0;
for (int k = K; k > 0; k -= KB) {
acc += (*pa) @ (*pb);
pa += KB * 1;
pb += KB * ldb;
}
rm = pid_m * MB + 0 ... MB;
rn = pid_n * NB + 0 ... NB;
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
"""
def make_kernel(device, dtype):
key = (device, dtype)
cache = make_kernel.cache
if key not in cache:
defines = {'TYPE': dtype}
cache[key] = triton.kernel(src, device=device, defines=defines, autotune_vals=autotune_configs, autotune_key=autotune_key)
return cache[key]
make_kernel.cache = dict()
.. GENERATED FROM PYTHON SOURCE LINES 239-244
Autograd Function
~~~~~~~~~~~~~~~~~~
Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output.
.. GENERATED FROM PYTHON SOURCE LINES 244-265
.. code-block:: default
class _dot(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
M, Ka = a.shape
Kb, N = b.shape
assert Ka == Kb, "incompatible dimensions"
assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
kernel = make_kernel(a.device, a.dtype)
grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
M, N, Ka, \
a.stride(0), b.stride(0), c.stride(0), \
grid=grid)
return c
dot = _dot.apply
.. GENERATED FROM PYTHON SOURCE LINES 266-271
Unit Test
-----------
We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
.. GENERATED FROM PYTHON SOURCE LINES 271-280
.. code-block:: default
a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
c_0 = dot(a, b)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device='cuda:0', dtype=torch.float16)
tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device='cuda:0', dtype=torch.float16)
True
.. GENERATED FROM PYTHON SOURCE LINES 281-327
Benchmark
--------------
Installing The CUTLASS Bindings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
To install CUTLASS, you need a recent version of cmake:
.. code-block:: bash
cd /path/to/cutlass/
git clone https://github.com/NVIDIA/cutlass.git
cd cutlass
mkdir build
cd build
wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
tar xzvf *.tar.gz
You can then install CUTLASS as follows for V100
.. code-block:: bash
./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
make -j8 install
Or as follows for A100:
.. code-block:: bash
./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
make -j8 install
Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
To re-install Triton with the updated CUTLASS bindings, run the following command:
.. code-block:: bash
export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/a
pip uninstall -y triton
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
Which we can test as follows:
.. GENERATED FROM PYTHON SOURCE LINES 327-333
.. code-block:: default
import triton
c_2 = triton.testing.cutlass_matmul(a, b)
print(c_2)
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([[186.7500, 195.3750, 196.1250, ..., 197.0000, 199.1250, 200.1250],
[181.8750, 181.1250, 187.2500, ..., 191.5000, 192.3750, 185.1250],
[183.0000, 192.7500, 194.3750, ..., 200.3750, 195.1250, 193.5000],
...,
[176.1250, 183.0000, 182.1250, ..., 184.7500, 190.8750, 187.5000],
[182.0000, 181.8750, 183.2500, ..., 187.8750, 190.5000, 186.2500],
[173.0000, 182.3750, 187.2500, ..., 191.2500, 187.6250, 184.5000]],
device='cuda:0', dtype=torch.float16)
True
.. GENERATED FROM PYTHON SOURCE LINES 334-339
Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
.. GENERATED FROM PYTHON SOURCE LINES 339-368
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
y_name='provider', # argument name whose value corresponds to a different line in the plot
y_vals=['torch', 'triton', 'cutlass'], # possible keys for `y_name`
y_lines=["Torch", "Triton", 'CUTLASS'], # label name for the lines
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
)
)
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
if provider == 'cutlass':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True)
.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
:alt: matmul-performance
:class: sphx-glr-single-img
.. GENERATED FROM PYTHON SOURCE LINES 369-369
As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 10.181 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 03-matrix-multiplication.py <03-matrix-multiplication.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 03-matrix-multiplication.ipynb <03-matrix-multiplication.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -51,6 +51,27 @@ Below is a gallery of tutorials for writing various basic operations with Triton
:hidden:
/getting-started/tutorials/02-fused-softmax
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- The block-level matrix multiplication operator @ - Multi-dimensional pointer arithmetic - Pro...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_03-matrix-multiplication_thumb.png
:alt: Matrix Multiplication
:ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/03-matrix-multiplication
.. raw:: html
<div class="sphx-glr-clear"></div>

View File

@@ -5,10 +5,12 @@
Computation times
=================
**00:27.706** total execution time for **getting-started_tutorials** files:
**01:10.181** total execution time for **getting-started_tutorials** files:
+-----------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:21.805 | 0.0 MB |
+-----------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:05.901 | 0.0 MB |
+-----------------------------------------------------------------------------------------+-----------+--------+
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 01:10.181 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:00.000 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:00.000 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+