[GH-PAGES] Updated website

This commit is contained in:
Philippe Tillet
2021-08-05 23:10:57 +00:00
parent 08f1b16893
commit c0858c5dac
23 changed files with 1185 additions and 697 deletions

View File

@@ -31,37 +31,43 @@ In this tutorial, you will write a simple vector addition using Triton and learn
Compute Kernel
--------------------------
.. GENERATED FROM PYTHON SOURCE LINES 14-43
.. GENERATED FROM PYTHON SOURCE LINES 14-49
.. code-block:: default
import torch
import triton.language as tl
import triton
import triton.language as tl
@triton.jit
def _add(
X, # *Pointer* to first input vector
Y, # *Pointer* to second input vector
Z, # *Pointer* to output vector
N, # Size of the vector
**meta # Optional meta-parameters for the kernel
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel
):
pid = tl.program_id(0)
# Create an offset for the blocks of pointers to be
# processed by this program instance
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
# Create a mask to guard memory operations against
# out-of-bounds accesses
mask = offsets < N
# Load x
x = tl.load(X + offsets, mask=mask)
y = tl.load(Y + offsets, mask=mask)
# Write back x + y
z = x + y
tl.store(Z + offsets, z)
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extar elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output)
@@ -71,31 +77,34 @@ Compute Kernel
.. GENERATED FROM PYTHON SOURCE LINES 44-46
.. GENERATED FROM PYTHON SOURCE LINES 50-52
Let's also declare a helper function to (1) allocate the `z` tensor
and (2) enqueue the above kernel with appropriate grid/block sizes.
.. GENERATED FROM PYTHON SOURCE LINES 46-64
.. GENERATED FROM PYTHON SOURCE LINES 52-73
.. code-block:: default
def add(x, y):
z = torch.empty_like(x)
N = z.shape[0]
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.shape[0]
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
# In this case, we use a 1D grid where the size is the number of blocks
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE:
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
_add[grid](x, y, z, N, BLOCK=1024)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return z
return output
@@ -105,11 +114,11 @@ and (2) enqueue the above kernel with appropriate grid/block sizes.
.. GENERATED FROM PYTHON SOURCE LINES 65-66
.. GENERATED FROM PYTHON SOURCE LINES 74-75
We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
.. GENERATED FROM PYTHON SOURCE LINES 66-77
.. GENERATED FROM PYTHON SOURCE LINES 75-89
.. code-block:: default
@@ -118,11 +127,14 @@ We can now use the above function to compute the element-wise sum of two `torch.
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
@@ -141,11 +153,11 @@ We can now use the above function to compute the element-wise sum of two `torch.
.. GENERATED FROM PYTHON SOURCE LINES 78-79
.. GENERATED FROM PYTHON SOURCE LINES 90-91
Seems like we're good to go!
.. GENERATED FROM PYTHON SOURCE LINES 81-86
.. GENERATED FROM PYTHON SOURCE LINES 93-98
Benchmark
-----------
@@ -153,7 +165,7 @@ We can now benchmark our custom op on vectors of increasing sizes to get a sense
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
for different problem sizes.
.. GENERATED FROM PYTHON SOURCE LINES 86-113
.. GENERATED FROM PYTHON SOURCE LINES 98-127
.. code-block:: default
@@ -162,15 +174,17 @@ for different problem sizes.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
x_vals=[
2 ** i for i in range(12, 28, 1)
], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch'], # possible values for `line_arg`
line_names=["Triton", "Torch"], # label name for the lines
line_names=['Triton', 'Torch'], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
args={} # values for function arguments not in `x_names` and `y_name`
ylabel='GB/s', # label name for the y-axis
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(size, provider):
@@ -191,18 +205,19 @@ for different problem sizes.
.. GENERATED FROM PYTHON SOURCE LINES 114-116
.. GENERATED FROM PYTHON SOURCE LINES 128-130
We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
`save_path='/path/to/results/' to save them to disk along with raw CSV data
.. GENERATED FROM PYTHON SOURCE LINES 116-116
.. GENERATED FROM PYTHON SOURCE LINES 130-131
.. code-block:: default
benchmark.run(print_data=True, show_plots=True)
.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
:alt: 01 vector add
:class: sphx-glr-single-img
@@ -218,16 +233,16 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
size Triton Torch
0 4096.0 9.600000 9.600000
1 8192.0 19.200000 19.200000
2 16384.0 38.400001 38.400001
3 32768.0 76.800002 76.800002
2 16384.0 31.999999 31.999999
3 32768.0 63.999998 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 384.000001 384.000001
6 262144.0 341.333321 384.000001
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
10 4194304.0 780.190482 780.190482
11 8388608.0 819.200021 812.429770
11 8388608.0 812.429770 812.429770
12 16777216.0 833.084721 833.084721
13 33554432.0 843.811163 843.811163
14 67108864.0 849.278610 848.362445
@@ -239,7 +254,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 10.996 seconds)
**Total running time of the script:** ( 0 minutes 11.055 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

View File

@@ -20,20 +20,22 @@
Fused Softmax
=================
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.
You will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
.. GENERATED FROM PYTHON SOURCE LINES 12-16
.. GENERATED FROM PYTHON SOURCE LINES 14-18
Motivations
------------
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
.. GENERATED FROM PYTHON SOURCE LINES 16-37
.. GENERATED FROM PYTHON SOURCE LINES 18-43
.. code-block:: default
@@ -41,9 +43,13 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
import torch
# Compute the row-wise softmax of x
@torch.jit.script
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read 2MN elements ; write MN elements
@@ -65,22 +71,28 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
.. GENERATED FROM PYTHON SOURCE LINES 38-42
.. GENERATED FROM PYTHON SOURCE LINES 44-52
When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` bytes, so we could
expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically
but, as we will see later, it is still far from ideal.
.. GENERATED FROM PYTHON SOURCE LINES 44-49
.. GENERATED FROM PYTHON SOURCE LINES 54-61
Compute Kernel
----------------
Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shapes:
.. GENERATED FROM PYTHON SOURCE LINES 49-77
.. GENERATED FROM PYTHON SOURCE LINES 61-94
.. code-block:: default
@@ -90,26 +102,31 @@ so we need to internally "pad" each row and guard the memory operations properly
@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
# row index
m = tl.program_id(0)
# col indices
# here BLOCK is the smallest power of two greater than `N`
n = tl.arange(0, meta['BLOCK'])
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n
x = tl.load(X, mask=n < N, other=-float('inf'))
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
BLOCK_SIZE = meta['BLOCK_SIZE']
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Substract maximum for numerical stability
z = x - tl.max(x, axis=0)
# Note that exponentials in Triton are fast
# but approximate (i.e., think __expf in CUDA)
num = tl.exp(z)
denom = tl.sum(num, axis=0)
y = num / denom
# Write back to Y
Y = Y + m * stride_ym + n
tl.store(Y, y, mask=n < N)
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
@@ -119,17 +136,18 @@ so we need to internally "pad" each row and guard the memory operations properly
.. GENERATED FROM PYTHON SOURCE LINES 78-79
.. GENERATED FROM PYTHON SOURCE LINES 95-96
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
.. GENERATED FROM PYTHON SOURCE LINES 79-110
.. GENERATED FROM PYTHON SOURCE LINES 96-139
.. code-block:: default
def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
@@ -141,20 +159,31 @@ We can create a helper function that enqueues the kernel and its (meta-)argument
def softmax(x):
M, N = x.shape
n_rows, n_cols = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK = next_power_of_2(N)
BLOCK_SIZE = next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 4
if BLOCK >= 2048: num_warps = 8
if BLOCK >= 4096: num_warps = 16
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
y = torch.empty_like(x)
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
# f the input matrix
softmax_kernel[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return y
@@ -165,26 +194,26 @@ We can create a helper function that enqueues the kernel and its (meta-)argument
.. GENERATED FROM PYTHON SOURCE LINES 111-113
.. GENERATED FROM PYTHON SOURCE LINES 140-142
Unit Test
----------
.. GENERATED FROM PYTHON SOURCE LINES 115-117
.. GENERATED FROM PYTHON SOURCE LINES 144-146
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 117-124
.. GENERATED FROM PYTHON SOURCE LINES 146-153
.. code-block:: default
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_tri = softmax(x)
y_ref = torch.softmax(x, axis=1)
print(torch.allclose(y_tri, y_ref))
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
print(torch.allclose(y_triton, y_torch))
@@ -201,18 +230,18 @@ This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 125-126
.. GENERATED FROM PYTHON SOURCE LINES 154-155
As expected, the results are identical.
.. GENERATED FROM PYTHON SOURCE LINES 128-132
.. GENERATED FROM PYTHON SOURCE LINES 157-161
Benchmark
-------------
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
.. GENERATED FROM PYTHON SOURCE LINES 132-161
.. GENERATED FROM PYTHON SOURCE LINES 161-200
.. code-block:: default
@@ -221,14 +250,24 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
x_vals=[
128 * i for i in range(2, 100)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``
line_names=["Triton", "Torch (native)", "Torch (jit)"], # label name for the lines
line_vals=[
'triton',
'torch-native',
'torch-jit',
], # possible values for `line_arg``
line_names=[
"Triton",
"Torch (native)",
"Torch (jit)",
], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(M, N, provider):
@@ -263,22 +302,22 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
N Triton Torch (native) Torch (jit)
0 256.0 512.000001 546.133347 186.181817
1 384.0 585.142862 585.142862 153.600004
2 512.0 630.153853 585.142849 154.566038
3 640.0 682.666684 640.000002 160.000000
2 512.0 630.153853 606.814814 154.566038
3 640.0 660.645170 640.000002 160.000000
4 768.0 702.171410 664.216187 163.839992
.. ... ... ... ...
93 12160.0 812.359066 406.179533 199.140227
94 12288.0 812.429770 415.661740 199.399583
95 12416.0 810.840807 412.149375 199.054102
96 12544.0 810.925276 412.971190 199.308841
97 12672.0 811.007961 412.097543 199.264875
93 12160.0 812.359066 406.179533 199.038365
94 12288.0 812.429770 415.222812 199.298541
95 12416.0 810.840807 412.149375 198.854847
96 12544.0 810.925276 412.971190 199.209928
97 12672.0 809.389265 412.097543 199.167004
[98 rows x 4 columns]
.. GENERATED FROM PYTHON SOURCE LINES 162-167
.. GENERATED FROM PYTHON SOURCE LINES 201-207
In the above plot, we can see that:
@@ -290,7 +329,7 @@ In the above plot, we can see that:
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 12.626 seconds)
**Total running time of the script:** ( 1 minutes 13.186 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

View File

@@ -20,58 +20,67 @@
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication
kernel that achieves performance on par with cuBLAS.
You will specifically learn about:
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
.. GENERATED FROM PYTHON SOURCE LINES 14-37
.. GENERATED FROM PYTHON SOURCE LINES 15-42
Motivations
-------------
Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
They are notoriously hard to optimize, hence their implementation is generally done by
hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized
to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
In this tutorial, you will learn how to implement efficient matrix multiplications by
yourself with Triton, in a way that is easy to customize and extend.
Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
Roughly speaking, the kernel that we will write will implement the following blocked
algorithm to multiply a (MxK) by a (KxN) matrix:
.. code-block:: python
# do in parallel
for m in range(0, M, BLOCK_M):
for m in range(0, M, BLOCK_SIZE_M):
# do in parallel
for n in range(0, N, BLOCK_N):
acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
for k in range(0, K, BLOCK_K):
a = A[m : m+BLOCK_M, k : k+BLOCK_K]
b = B[k : k+BLOCK_K, n : n+BLOCK_N]
for n in range(0, N, BLOCK_SIZE_N):
acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
for k in range(0, K, BLOCK_SIZE_K):
a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
acc += dot(a, b)
C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
.. GENERATED FROM PYTHON SOURCE LINES 39-110
.. GENERATED FROM PYTHON SOURCE LINES 44-119
Compute Kernel
----------------
The above algorithm is, actually, fairly straightforward to implement in Triton.
The main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.
The main difficulty comes from the computation of the memory locations at which blocks
of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
multi-dimensional pointer arithmetics.
Pointer Arithmetics
~~~~~~~~~~~~~~~~~~~~
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
y :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
.. code-block:: python
&A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
@@ -79,9 +88,9 @@ Which means that pointers for blocks of A and B can be initialized (i.e., :code:
pid_m = triton.program_id(0)
pid_n = triton.program_id(1)
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
rk = triton.arange(0, BLOCK_SIZE_K)
// pointer for A operand
pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
// pointer for B operand
@@ -91,50 +100,72 @@ And then updated in the inner loop as follows:
.. code-block:: python
pa += BLOCK_K * stride_a_1;
pb += BLOCK_K * stride_b_0;
pa += BLOCK_SIZE_K * stride_a_1;
pb += BLOCK_SIZE_K * stride_b_0;
L2 Cache Optimizations
~~~~~~~~~~~~~~~~~~~~~~~~
As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
And unfortunately, a simple row-major ordering
As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
block of :code:`C`.
It is important to remember that the order in which these blocks are computed does
matter, since it affects the L2 cache hit rate of our program. and unfortunately, a
a simple row-major ordering
.. code-block:: Python
pid = triton.program_id(0);
grid_m = (M + BLOCK_M - 1) // BLOCK_M;
grid_n = (N + BLOCK_N - 1) // BLOCK_N;
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;
is just not going to cut it.
One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before
switching to the next column:
.. code-block:: python
pid = triton.program_id(0);
width = GROUP_M * grid_n;
group_id = pid // width;
# we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
# we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0
group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
pid_m = group_id * GROUP_M + (pid % group_size);
pid_n = (pid % width) // (group_size);
In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
.. GENERATED FROM PYTHON SOURCE LINES 119-130
.. code-block:: default
.. GENERATED FROM PYTHON SOURCE LINES 112-115
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
# we can see that if we compute the output in row-major ordering, we need to load 90
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
# ordering, we only need to load 54 blocks.
# .. image:: grouped_vs_row_major_ordering.png
#
# In practice, this can improve the performance of our matrix multiplication kernel by
# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
#
.. GENERATED FROM PYTHON SOURCE LINES 131-134
Final Result
-------------
.. GENERATED FROM PYTHON SOURCE LINES 115-190
.. GENERATED FROM PYTHON SOURCE LINES 134-263
.. code-block:: default
@@ -144,74 +175,127 @@ Final Result
import triton.language as tl
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
# decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of
# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
def matmul_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
# Matrix dimensions
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
**meta,
):
"""Kernel for computing the matmul AB = C
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# extract meta-parameters
BLOCK_M = META['BLOCK_M']
BLOCK_N = META['BLOCK_N']
BLOCK_K = META['BLOCK_K']
GROUP_M = 8
# matrix multiplication
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K):
a = tl.load(A)
b = tl.load(B)
acc += tl.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# triton can accept arbitrary activation function
# via metaparameters!
if META['ACTIVATION']:
acc = META['ACTIVATION'](acc)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm[:, None] < M) & (rn[None, :] < N)
tl.store(C, acc, mask=mask)
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
GROUP_SIZE_M = 8
pid = tl.program_id(axis=0)
# the number of blocks is the ceil(M / BLOCK_SIZE_M) since we need an extra block
# Note that this will lead to some quantization in performance where time-taken jumps
# when you need to add a new block
n_blocks_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
n_blocks_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
# Map PIDs to the block they should compute. This is done in a grouped ordering
# to promote L2 cache reuse.
n_output_blocks_in_group = GROUP_SIZE_M * n_blocks_n
group_id = pid // n_output_blocks_in_group
first_m_block_in_group = group_id * GROUP_SIZE_M
# If the number of blocks is not divisible by the group size, the last group is smaller
group_size_m = min(n_blocks_m - first_m_block_in_group, GROUP_SIZE_M)
# Within a group, we compute in col-major ordering, block_m and block_n are the
# output row and col that this program is computing in terms of blocks
block_m = first_m_block_in_group + (pid % group_size_m)
block_n = (pid % n_output_blocks_in_group) // group_size_m
# Convert from block indices back to element indices
m_start = block_m * BLOCK_SIZE_M
n_start = block_n * BLOCK_SIZE_N
# Expand out to all the offsets for each of the elements in this block.
m_offsets_a = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
n_offsets_b = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
k_offsets = tl.arange(0, BLOCK_SIZE_K)
# Get the pointers for the first block of each. We will advance this pointer
# as we move in the K direction and accumulate.
# a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers
a_ptrs = a_ptr + (stride_am * m_offsets_a + stride_ak * k_offsets[None, :])
# b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers
b_ptrs = b_ptr + (stride_bk * k_offsets[:, None] + stride_bn * n_offsets_b)
# We accumulate internally in fp32, but the output is written out in the dtype
# of the tensor when it is stored
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here. This means that if K is
# not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and
# accumulate it incorrectly.
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
# We accumulate along the K dimension
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# triton can accept arbitrary activation function via metaparameters!
if meta['ACTIVATION']:
accumulator = meta['ACTIVATION'](accumulator)
m_offsets_c = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
n_offsets_c = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
c_ptrs = c_ptr + stride_cm * m_offsets_c + stride_cn * n_offsets_c
mask = (m_offsets_c < M) & (n_offsets_c < N)
tl.store(c_ptrs, accumulator, mask=mask)
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01*x)
return tl.where(x >= 0, x, 0.01 * x)
@@ -220,33 +304,49 @@ Final Result
.. GENERATED FROM PYTHON SOURCE LINES 191-193
.. GENERATED FROM PYTHON SOURCE LINES 264-266
We can now create a convenience wrapper function that only takes two input tensors
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
.. GENERATED FROM PYTHON SOURCE LINES 193-214
.. GENERATED FROM PYTHON SOURCE LINES 266-302
.. code-block:: default
def matmul(a, b, activation=None):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
assert b.is_contiguous(), "matrix B must be contiguous"
M, K = a.shape
_, N = b.shape
K, N = b.shape
assert (
K % 32 == 0
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
# allocates output
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = _matmul[grid](
a, b, c, M, N, K, \
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
ACTIVATION = activation
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
ACTIVATION=activation,
)
# done; return the output tensor
return c
@@ -257,14 +357,14 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
.. GENERATED FROM PYTHON SOURCE LINES 215-219
.. GENERATED FROM PYTHON SOURCE LINES 303-307
Unit Test
-----------
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
.. GENERATED FROM PYTHON SOURCE LINES 219-229
.. GENERATED FROM PYTHON SOURCE LINES 307-320
.. code-block:: default
@@ -272,11 +372,14 @@ We can test our custom matrix multiplication operation against a native torch im
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=None)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))
triton_output = matmul(a, b, activation=None)
torch_output = torch.matmul(a, b)
print(f"{triton_output=}")
print(f"{torch_output=}")
if triton.testing.allclose(triton_output, torch_output):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
@@ -288,7 +391,7 @@ We can test our custom matrix multiplication operation against a native torch im
.. code-block:: none
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
triton_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
...,
@@ -296,7 +399,7 @@ We can test our custom matrix multiplication operation against a native torch im
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
device='cuda:0', dtype=torch.float16)
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
...,
@@ -304,12 +407,12 @@ We can test our custom matrix multiplication operation against a native torch im
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
device='cuda:0', dtype=torch.float16)
tensor(True, device='cuda:0')
✅ Triton and Torch match
.. GENERATED FROM PYTHON SOURCE LINES 230-236
.. GENERATED FROM PYTHON SOURCE LINES 321-327
Benchmark
--------------
@@ -318,7 +421,7 @@ Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
.. GENERATED FROM PYTHON SOURCE LINES 236-268
.. GENERATED FROM PYTHON SOURCE LINES 327-368
.. code-block:: default
@@ -327,14 +430,19 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(1, 33)], # different possible values for `x_name`
x_vals=[
128 * i for i in range(1, 33)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], # possible values for `line_arg``
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], # label name for the lines
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], # line styles
# possible values for `line_arg``
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
# label name for the lines
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
args={},
)
)
def benchmark(M, N, K, provider):
@@ -346,9 +454,13 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_relu(torch.matmul(a, b))
)
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, activation=leaky_relu)
)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
@@ -371,37 +483,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 128.0 0.455111 ... 0.512000 0.512000
1 256.0 2.730667 ... 2.978909 2.978909
2 384.0 7.372800 ... 7.899428 8.507077
3 512.0 14.563555 ... 16.384000 16.384000
1 256.0 2.978909 ... 2.978909 2.978909
2 384.0 7.372800 ... 8.507077 7.899428
3 512.0 14.563555 ... 16.384000 15.420235
4 640.0 22.260869 ... 24.380953 24.380953
5 768.0 32.768000 ... 34.028308 34.028308
6 896.0 39.025776 ... 39.025776 39.025776
6 896.0 39.025776 ... 39.025776 35.123201
7 1024.0 49.932191 ... 52.428801 52.428801
8 1152.0 45.242181 ... 46.656000 45.938215
9 1280.0 51.200001 ... 56.109587 56.109587
10 1408.0 64.138541 ... 65.684049 58.640951
11 1536.0 79.526831 ... 75.296679 75.296679
12 1664.0 63.372618 ... 61.636381 62.061463
13 1792.0 72.983276 ... 69.379162 68.953520
14 1920.0 69.467336 ... 67.434145 70.172588
15 2048.0 73.584279 ... 75.573044 74.898285
16 2176.0 83.155572 ... 80.817862 77.398646
17 2304.0 68.251065 ... 72.828879 73.051599
18 2432.0 71.305746 ... 80.963875 80.963875
19 2560.0 77.649287 ... 75.676673 74.983980
20 2688.0 83.186525 ... 84.671999 82.823267
21 2816.0 82.916747 ... 76.115547 79.733474
22 2944.0 82.237674 ... 80.771529 78.358539
23 3072.0 82.062468 ... 84.892208 82.782312
24 3200.0 84.544253 ... 88.397792 89.385477
25 3328.0 79.812967 ... 80.617354 81.071278
26 3456.0 81.518272 ... 86.970406 81.600781
27 3584.0 87.042978 ... 96.372338 90.640517
28 3712.0 84.230479 ... 82.764991 82.423549
29 3840.0 80.255442 ... 81.377484 80.783056
30 3968.0 89.329379 ... 85.932350 87.347124
31 4096.0 93.531519 ... 85.816960 91.056800
8 1152.0 44.566925 ... 46.656000 46.656000
9 1280.0 51.200001 ... 56.888887 56.109587
10 1408.0 64.138541 ... 64.902096 64.902096
11 1536.0 78.643199 ... 76.106321 76.106321
12 1664.0 62.929456 ... 62.061463 62.061463
13 1792.0 72.983276 ... 69.810085 69.379162
14 1920.0 67.764707 ... 70.530615 70.530615
15 2048.0 73.908442 ... 75.234154 74.898285
16 2176.0 83.500614 ... 81.143743 81.143743
17 2304.0 68.446623 ... 73.501144 73.501144
18 2432.0 71.305746 ... 82.147552 82.147552
19 2560.0 77.833728 ... 77.283019 77.101175
20 2688.0 81.053536 ... 81.928846 83.922689
21 2816.0 81.981598 ... 79.443003 80.320825
22 2944.0 82.373605 ... 77.385141 78.112900
23 3072.0 81.472093 ... 83.761985 79.638683
24 3200.0 84.768213 ... 88.888888 85.561498
25 3328.0 83.905938 ... 87.794262 87.156532
26 3456.0 80.220468 ... 85.676480 84.068369
27 3584.0 86.707226 ... 95.553020 94.847460
28 3712.0 83.247783 ... 84.303780 85.309435
29 3840.0 80.255442 ... 83.339866 85.005380
30 3968.0 88.938731 ... 87.409694 87.159957
31 4096.0 91.616198 ... 89.597949 89.538177
[32 rows x 5 columns]
@@ -411,7 +523,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 2 minutes 14.738 seconds)
**Total running time of the script:** ( 2 minutes 30.425 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:

View File

@@ -5,12 +5,12 @@
Computation times
=================
**03:38.360** total execution time for **getting-started_tutorials** files:
**03:54.665** total execution time for **getting-started_tutorials** files:
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:14.738 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:30.425 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:12.626 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:13.186 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:10.996 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:11.055 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+