.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/03-matrix-multiplication.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_03-matrix-multiplication.py:

Matrix Multiplication
======================

In this tutorial, you will write a high-performance FP16 matrix multiplication kernel, only about 25 lines long, that achieves performance on par with cuBLAS.
You will specifically learn about:

- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning

.. GENERATED FROM PYTHON SOURCE LINES 14-37

Motivations
-------------
Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, so their implementation is generally left to hardware vendors, who ship them as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
In this tutorial, you will learn how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.

Roughly speaking, the kernel that we will write implements the following blocked algorithm:

.. code-block:: python

    # do in parallel
    for m in range(0, M, BLOCK_M):
        # do in parallel
        for n in range(0, N, BLOCK_N):
            acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
            for k in range(0, K, BLOCK_K):
                a = A[m : m+BLOCK_M, k : k+BLOCK_K]
                b = B[k : k+BLOCK_K, n : n+BLOCK_N]
                acc += dot(a, b)
            C[m : m+BLOCK_M, n : n+BLOCK_N] = acc

where each iteration of the doubly-nested "do in parallel" loop is executed by a separate Triton program instance.
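
To make the blocked algorithm concrete, here is a minimal NumPy sketch of it. NumPy stands in for Triton here, and :code:`blocked_matmul` and its default block sizes are illustrative choices, not part of Triton. The two outer loops are the ones Triton runs in parallel (one program instance per :code:`(m, n)` tile), the accumulator is kept in float32 as in the pseudo-code, and NumPy slicing simply clips the tiles at the matrix edges.

.. code-block:: python

    import numpy as np

    def blocked_matmul(A, B, BLOCK_M=32, BLOCK_N=32, BLOCK_K=16):
        """Blocked matmul: every (m, n) tile is computed independently,
        like one Triton program instance, with a float32 accumulator."""
        M, K = A.shape
        K2, N = B.shape
        assert K == K2, "incompatible dimensions"
        C = np.empty((M, N), dtype=A.dtype)
        for m in range(0, M, BLOCK_M):        # run in parallel by Triton
            for n in range(0, N, BLOCK_N):    # run in parallel by Triton
                # edge tiles are smaller because slicing clips at the matrix boundary
                acc = np.zeros((min(BLOCK_M, M - m), min(BLOCK_N, N - n)), dtype=np.float32)
                for k in range(0, K, BLOCK_K):
                    a = A[m:m + BLOCK_M, k:k + BLOCK_K].astype(np.float32)
                    b = B[k:k + BLOCK_K, n:n + BLOCK_N].astype(np.float32)
                    acc += a @ b
                # assignment casts the float32 accumulator back to the output dtype
                C[m:m + BLOCK_M, n:n + BLOCK_N] = acc
        return C

    rng = np.random.default_rng(0)
    A = rng.standard_normal((70, 50)).astype(np.float16)
    B = rng.standard_normal((50, 90)).astype(np.float16)
    C = blocked_matmul(A, B)
    ref = A.astype(np.float32) @ B.astype(np.float32)
    print(np.abs(C.astype(np.float32) - ref).max())  # small: float16 rounding of the output

The matrix sizes are deliberately not multiples of the block sizes, so the edge tiles are exercised as well.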

.. GENERATED FROM PYTHON SOURCE LINES 39-110

Compute Kernel
----------------

The above algorithm is actually fairly straightforward to implement in Triton.
The main difficulty comes from computing the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetic.

Pointer Arithmetics
~~~~~~~~~~~~~~~~~~~~

For a 2D tensor :code:`X` with strides :code:`stride_x_0` and :code:`stride_x_1`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1` (for a contiguous row-major tensor, :code:`stride_x_0` is the number of columns and :code:`stride_x_1 = 1`).
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k : k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:

.. code-block:: python

    &A[m : m+BLOCK_M, k : k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1)
    &B[k : k+BLOCK_K, n : n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1)
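
The pointer arithmetic above can be checked on the CPU with NumPy, where indexing a flat buffer with an array of offsets plays the role of dereferencing a block of pointers. This is a sketch under that analogy only: :code:`block_offsets` is a hypothetical helper, and since NumPy reports strides in bytes, they are first converted to element strides. The same formula also handles a non-contiguous view such as a transpose, as long as its own strides are used.

.. code-block:: python

    import numpy as np

    def block_offsets(row_start, col_start, BLOCK_R, BLOCK_C, stride_0, stride_1):
        """Element offsets of X[row_start : row_start+BLOCK_R, col_start : col_start+BLOCK_C]."""
        rows = row_start + np.arange(BLOCK_R)  # like `pid_m * BLOCK_M + tl.arange(0, BLOCK_M)`
        cols = col_start + np.arange(BLOCK_C)
        # broadcasting [:, None] against [None, :] gives a (BLOCK_R, BLOCK_C) grid of offsets
        return rows[:, None] * stride_0 + cols[None, :] * stride_1

    X = np.arange(6 * 8, dtype=np.float32).reshape(6, 8)
    flat = X.ravel()                                # flat view of the underlying storage
    s0, s1 = (s // X.itemsize for s in X.strides)   # byte strides -> element strides
    block = flat[block_offsets(2, 4, 3, 4, s0, s1)] # "dereference" the block of pointers

    # a transposed view shares the same buffer but has swapped strides
    Xt = X.T
    t0, t1 = (s // Xt.itemsize for s in Xt.strides)
    block_t = flat[block_offsets(1, 2, 3, 2, t0, t1)]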

This means that the pointers for the blocks of A and B can be initialized (i.e., for :code:`k=0`) in Triton as:

.. code-block:: python

    # assumes `import triton.language as tl`, as in the final code below
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointer for A operand
    pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1)
    # pointer for B operand
    pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1)

They are then advanced in the inner loop as follows:

.. code-block:: python

    # move both pointer blocks BLOCK_K elements further along the K dimension
    pa += BLOCK_K * stride_a_1
    pb += BLOCK_K * stride_b_0
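
Putting the initialization and the update together, the following NumPy sketch simulates the inner loop of a single program instance on flat buffers. It assumes contiguous row-major inputs whose dimensions are multiples of the block sizes (the sketch does no masking), and :code:`program_instance` is a hypothetical name. The tile it returns should match the corresponding block of :code:`A @ B`.

.. code-block:: python

    import numpy as np

    def program_instance(A, B, pid_m, pid_n, BLOCK_M, BLOCK_N, BLOCK_K):
        """Compute the (pid_m, pid_n) tile of A @ B using only flat offsets."""
        M, K = A.shape
        _, N = B.shape
        a_flat, b_flat = A.ravel(), B.ravel()
        stride_a_0, stride_a_1 = K, 1  # contiguous row-major A
        stride_b_0, stride_b_1 = N, 1  # contiguous row-major B
        rm = pid_m * BLOCK_M + np.arange(BLOCK_M)
        rn = pid_n * BLOCK_N + np.arange(BLOCK_N)
        rk = np.arange(BLOCK_K)
        # offsets for k = 0
        pa = rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1
        pb = rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1
        acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
        for _ in range(0, K, BLOCK_K):
            acc += a_flat[pa].astype(np.float32) @ b_flat[pb].astype(np.float32)
            pa += BLOCK_K * stride_a_1  # step right along the rows of A
            pb += BLOCK_K * stride_b_0  # step down along the columns of B
        return acc

    rng = np.random.default_rng(0)
    A = rng.standard_normal((64, 48))
    B = rng.standard_normal((48, 32))
    tile = program_instance(A, B, pid_m=1, pid_n=2, BLOCK_M=16, BLOCK_N=8, BLOCK_K=16)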

L2 Cache Optimizations
~~~~~~~~~~~~~~~~~~~~~~~~

As mentioned above, each program instance computes a :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
It is important to remember that the order in which these blocks are computed matters, since it affects the L2 cache hit rate of our program.
Unfortunately, a simple row-major ordering

.. code-block:: python

    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    pid_m = pid // grid_n
    pid_n = pid % grid_n

is just not going to cut it. With this ordering, program instances that run at roughly the same time all share the same block of rows of :code:`A`, but each of them reads a different block of columns of :code:`B`; if :code:`B` does not fit in L2, its blocks are likely evicted before the next row of blocks of :code:`C` can reuse them.

One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:

.. code-block:: python

    pid = tl.program_id(0)
    width = GROUP_M * grid_n
    group_id = pid // width
    # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architectures (e.g., from 220 to 245 TFLOPS on A100).
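
To get a feel for why the grouped ordering helps, the plain-Python sketch below reuses the two index formulas above and counts how many distinct row-blocks of :code:`A` and column-blocks of :code:`B` the first few program instances read. This is only a rough proxy for L2 traffic, assuming that the first programs run at about the same time and that the blocks they share stay in cache; it is not a cache simulation, and the helper names are hypothetical. On a 9x9 grid of blocks with :code:`GROUP_M = 3`, the first 9 programs read 10 blocks in row-major order but only 6 with grouping, although both compute 9 blocks of :code:`C`.

.. code-block:: python

    def row_major(pid, grid_m, grid_n):
        # consecutive pids walk along one row of blocks of C
        return pid // grid_n, pid % grid_n

    def grouped(pid, grid_m, grid_n, GROUP_M):
        # same arithmetic as the grouped ordering above
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size
        return pid_m, pid_n

    def blocks_touched(order, num_programs):
        """Distinct row-blocks of A and column-blocks of B read by pids 0..num_programs-1."""
        a_blocks, b_blocks = set(), set()
        for pid in range(num_programs):
            pid_m, pid_n = order(pid)
            a_blocks.add(pid_m)
            b_blocks.add(pid_n)
        return len(a_blocks), len(b_blocks)

    print(blocks_touched(lambda pid: row_major(pid, 9, 9), 9))   # (1, 9)
    print(blocks_touched(lambda pid: grouped(pid, 9, 9, 3), 9))  # (3, 3)

    # the grouped mapping still visits every block exactly once, even with a ragged last group
    pairs = [grouped(pid, 10, 7, 4) for pid in range(10 * 7)]
    assert sorted(pairs) == [(m, n) for m in range(10) for n in range(7)]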

.. GENERATED FROM PYTHON SOURCE LINES 112-115

Final Result
-------------

.. GENERATED FROM PYTHON SOURCE LINES 115-190

.. code-block:: default


    import torch
    import triton
    import triton.language as tl

    # %
    # :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
    #   - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
    #   - An autotuning *key* whose change in values will trigger evaluation of all the provided configs

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
            triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
        ],
        key=['M', 'N', 'K'],
    )
    # %
    # We can now define our kernel as normal, using all the techniques presented above
    @triton.jit
    def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
        # extract meta-parameters
        BLOCK_M = META['BLOCK_M']
        BLOCK_N = META['BLOCK_N']
        BLOCK_K = META['BLOCK_K']
        GROUP_M = META['GROUP_M']
        # matrix multiplication
        pid = tl.program_id(0)
        grid_m = (M + BLOCK_M - 1) // BLOCK_M
        grid_n = (N + BLOCK_N - 1) // BLOCK_N
        # re-order program ID for better L2 performance
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // (group_size)
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rk = tl.arange(0, BLOCK_K)
        A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
        B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # the loads below are not masked: this loop assumes K is a multiple of BLOCK_K
        for k in range(K, 0, -BLOCK_K):
            a = tl.load(A)
            b = tl.load(B)
            acc += tl.dot(a, b)
            A += BLOCK_K * stride_ak
            B += BLOCK_K * stride_bk
        # Triton can accept arbitrary activation functions
        # via meta-parameters!
        if META['ACTIVATION']:
            acc = META['ACTIVATION'](acc)
        # rematerialize rm and rn to save registers
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm[:, None] < M) & (rn[None, :] < N)
        tl.store(C, acc, mask=mask)


    # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
    @triton.jit
    def leaky_relu(x):
        return tl.where(x >= 0, x, 0.01*x)
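
Because the activation is applied elementwise to the accumulator, fusing it into the kernel gives the same result as applying it to the full product afterwards, and the mask in the final :code:`tl.store` is what keeps the edge tiles from writing past :code:`M` and :code:`N`. The NumPy sketch below illustrates both points one output tile at a time. :code:`store_tile` is a hypothetical helper, and the inputs are zero-padded to a multiple of the block size only so that every tile has a full accumulator; the padded lanes are then discarded by the mask.

.. code-block:: python

    import numpy as np

    def leaky_relu(x):
        # same formula as the Triton `leaky_relu` above
        return np.where(x >= 0, x, 0.01 * x)

    def store_tile(C, acc, pid_m, pid_n, BLOCK_M, BLOCK_N, activation=None):
        """Epilogue of one program instance: optional activation, then a masked store."""
        M, N = C.shape
        if activation is not None:
            acc = activation(acc)  # elementwise, so applying it per tile is exact
        rm = pid_m * BLOCK_M + np.arange(BLOCK_M)
        rn = pid_n * BLOCK_N + np.arange(BLOCK_N)
        mask = (rm[:, None] < M) & (rn[None, :] < N)  # same mask as in `_matmul`
        rows, cols = np.nonzero(mask)                  # in-bounds positions within the tile
        C[rm[rows], rn[cols]] = acc[rows, cols]        # out-of-bounds lanes are dropped

    rng = np.random.default_rng(0)
    A, B = rng.standard_normal((10, 6)), rng.standard_normal((6, 10))
    BLOCK = 4
    A_pad = np.zeros((12, 6)); A_pad[:10] = A      # pad M up to a multiple of BLOCK
    B_pad = np.zeros((6, 12)); B_pad[:, :10] = B   # pad N up to a multiple of BLOCK
    C = np.zeros((10, 10))
    for pid_m in range(3):
        for pid_n in range(3):
            acc = A_pad[pid_m * BLOCK:(pid_m + 1) * BLOCK] @ B_pad[:, pid_n * BLOCK:(pid_n + 1) * BLOCK]
            store_tile(C, acc, pid_m, pid_n, BLOCK, BLOCK, activation=leaky_relu)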

.. GENERATED FROM PYTHON SOURCE LINES 191-193

We can now create a convenience wrapper function that only takes two input tensors (plus an optional activation function) and (1) checks shape constraints, (2) allocates the output, and (3) launches the above kernel.

.. GENERATED FROM PYTHON SOURCE LINES 193-214

.. code-block:: default


    def matmul(a, b, activation=None):
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        assert a.is_contiguous(), "matrix A must be contiguous"
        assert b.is_contiguous(), "matrix B must be contiguous"
        M, K = a.shape
        _, N = b.shape
        # allocates output
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
        # launch kernel
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
        _matmul[grid](
            a, b, c, M, N, K,
            a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
            ACTIVATION=activation,
        )
        # done; return the output tensor
        return c
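
The launch grid is one-dimensional: one program instance per :code:`[BLOCK_M, BLOCK_N]` tile of :code:`C`. It is passed as a function of :code:`META` because the block sizes are only known once the autotuner has picked a configuration. The sketch below reproduces that arithmetic in plain Python for a few of the configurations listed above, assuming :code:`triton.cdiv` is ceiling division (which is how it is used here); :code:`cdiv` and :code:`num_programs` are hypothetical local helpers.

.. code-block:: python

    def cdiv(x, y):
        # ceiling division for positive integers
        return (x + y - 1) // y

    def num_programs(M, N, meta):
        # mirrors the `grid` lambda in `matmul`: one program per tile of C
        return cdiv(M, meta['BLOCK_M']) * cdiv(N, meta['BLOCK_N'])

    # a few of the autotuning configurations above, for a 512x512 output
    for meta in [{'BLOCK_M': 128, 'BLOCK_N': 256},
                 {'BLOCK_M': 64, 'BLOCK_N': 128},
                 {'BLOCK_M': 32, 'BLOCK_N': 64}]:
        print(meta, num_programs(512, 512, meta))  # 8, then 32, then 128 programs

When :code:`M` or :code:`N` is not a multiple of the block size, the rounding up creates partially filled edge tiles, which is exactly what the store mask in :code:`_matmul` takes care of.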

.. GENERATED FROM PYTHON SOURCE LINES 215-219

Unit Test
-----------

We can test our custom matrix multiplication operation against the native PyTorch implementation, :code:`torch.matmul` (which relies on cuBLAS).

.. GENERATED FROM PYTHON SOURCE LINES 219-229

.. code-block:: default


    torch.manual_seed(0)
    a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    c_0 = matmul(a, b, activation=None)
    c_1 = torch.matmul(a, b)
    print(c_0)
    print(c_1)
    print(triton.testing.allclose(c_0, c_1))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3984,  24.4531, -32.3438],
            [  6.3555, -19.6094,  34.0938,  ...,  -5.8945,   5.2891,   6.8867],
            [-32.0625,   5.9492,  15.3984,  ..., -21.3906, -23.9844, -10.1328],
            ...,
            [ -5.7031,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
            [ 25.5000,  24.3281,  -8.4688,  ..., -18.9375,  32.5312, -29.9219],
            [ -5.3477,   4.9844,  11.8906,  ...,   5.5898,   6.4023, -17.3125]],
           device='cuda:0', dtype=torch.float16)
    tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3906,  24.4531, -32.3438],
            [  6.3516, -19.6094,  34.0938,  ...,  -5.8906,   5.2812,   6.8828],
            [-32.0625,   5.9531,  15.3984,  ..., -21.4062, -23.9844, -10.1328],
            ...,
            [ -5.7070,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
            [ 25.5000,  24.3438,  -8.4609,  ..., -18.9375,  32.5312, -29.9219],
            [ -5.3477,   4.9805,  11.8828,  ...,   5.5859,   6.4023, -17.3125]],
           device='cuda:0', dtype=torch.float16)
    tensor(True, device='cuda:0')
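
The two printed tensors agree except in the last displayed digit of a few entries. A quick NumPy check suggests that these are one-ulp differences in float16, i.e., the two results are adjacent float16 values. That is what one would expect if the two kernels merely accumulate in a different order or round slightly differently, which is an assumption here and not something this test verifies. The pairs below are copied from the output above.

.. code-block:: python

    import numpy as np

    # (Triton value, torch.matmul value) pairs copied from the printed tensors
    pairs = [(-11.3984, -11.3906), (6.3555, 6.3516), (24.3281, 24.3438)]
    for triton_val, torch_val in pairs:
        x, y = np.float16(triton_val), np.float16(torch_val)
        ulp = np.spacing(np.float16(abs(x)))  # gap to the next float16 value above |x|
        print(float(x), float(y), float(abs(x - y)), float(ulp))

For each pair, the printed difference equals the float16 spacing at that magnitude (1/128, 1/256 and 1/64 respectively).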

.. GENERATED FROM PYTHON SOURCE LINES 230-236

Benchmark
--------------

Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~

We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to adapt this script to benchmark any other matrix shape.

.. GENERATED FROM PYTHON SOURCE LINES 236-268

.. code-block:: default


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(1, 33)],  # different possible values for `x_names`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],  # possible values for `line_arg`
            line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],  # label name for the lines
            styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],  # line styles
            ylabel="TFLOPS",  # label name for the y-axis
            plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={},
        )
    )
    def benchmark(M, N, K, provider):
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
        if provider == 'cublas':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
        if provider == 'cublas + relu':
            # same negative slope (0.01) as the Triton `leaky_relu` above, to match the line label
            torch_leaky_relu = torch.nn.LeakyReLU(negative_slope=0.01, inplace=True)
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_leaky_relu(torch.matmul(a, b)))
        if provider == 'triton + relu':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
        perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
        # the slowest run (max_ms) gives the lowest throughput, and vice versa
        return perf(ms), perf(max_ms), perf(min_ms)


    benchmark.run(show_plots=True, print_data=True)


.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
    :alt: 03 matrix multiplication
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    matmul-performance:
             M     cuBLAS  cuBLAS (+ torch.nn.LeakyReLU)     Triton  Triton (+ LeakyReLU)
    0    128.0   0.455111                       0.372364   0.512000              0.512000
    1    256.0   2.978909                       2.340571   3.276800              2.978909
    2    384.0   7.372800                       6.144000   8.507077              8.507077
    3    512.0  14.563555                      11.915636  16.384000             16.384000
    4    640.0  22.260869                      18.285714  23.272727             23.272727
    5    768.0  32.768000                      26.810182  34.028308             34.028308
    6    896.0  39.025776                      32.672744  39.025776             39.025776
    7   1024.0  49.932191                      41.943041  52.428801             52.428801
    8   1152.0  44.566925                      38.779015  46.656000             46.656000
    9   1280.0  51.200001                      44.521738  56.109587             56.109587
    10  1408.0  64.138541                      55.068446  65.684049             59.258433
    11  1536.0  79.526831                      67.408458  75.296679             75.296679
    12  1664.0  63.372618                      55.893862  61.636381             61.636381
    13  1792.0  72.983276                      63.860363  68.953520             68.953520
    14  1920.0  66.782607                      61.168141  68.776119             68.776119
    15  2048.0  73.262953                      65.793006  75.234154             75.234154
    16  2176.0  82.473969                      73.712993  79.540109             79.855747
    17  2304.0  68.251065                      62.207998  73.051599             73.051599
    18  2432.0  71.305746                      65.033481  80.963875             80.963875
    19  2560.0  77.649287                      70.773218  76.560748             75.851852
    20  2688.0  82.463163                      75.413632  82.106182             80.880718
    21  2816.0  82.602666                      73.424595  78.442822             77.330158
    22  2944.0  82.784108                      72.966370  80.122235             80.122235
    23  3072.0  79.638683                      74.997490  79.082550             82.903517
    24  3200.0  84.099871                      78.335374  89.385477             85.333333
    25  3328.0  83.226931                      77.828428  81.346098             81.530349
    26  3456.0  79.351933                      75.276907  82.858753             81.435930
    27  3584.0  87.466332                      81.518940  95.858629             91.470385
    28  3712.0  84.230479                      79.283603  81.682211             85.455380
    29  3840.0  84.421376                      79.562590  87.355452             87.562949
    30  3968.0  93.006050                      86.296981  84.038524             84.504108
    31  4096.0  93.662059                      87.381330  83.729089             92.119235
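
The TFLOPS figures in the table come from the :code:`perf` lambda in :code:`benchmark`: a matrix multiplication performs :code:`M*N*K` multiply-adds, i.e., :code:`2*M*N*K` floating-point operations, which are divided by the measured time. The small sketch below (the helper names are hypothetical) applies that formula and its inverse, for example to recover the runtime implied by the cuBLAS entry for :code:`M = N = K = 4096`, which is roughly 1.47 ms.

.. code-block:: python

    def tflops(M, N, K, ms):
        # same formula as the `perf` lambda: 2*M*N*K flops, runtime converted from ms to s
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    def ms_for(M, N, K, tflops_value):
        # inverse of `tflops`: runtime in milliseconds implied by a TFLOPS figure
        return 2 * M * N * K * 1e-12 / tflops_value * 1e3

    print(tflops(4096, 4096, 4096, 1.0))        # about 137.44 TFLOPS if one such matmul took 1 ms
    print(ms_for(4096, 4096, 4096, 93.662059))  # about 1.47 ms for the cuBLAS entry at M = 4096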

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 12.630 seconds)

.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 03-matrix-multiplication.py <03-matrix-multiplication.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 03-matrix-multiplication.ipynb <03-matrix-multiplication.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_