[GH-PAGES] Updated website

Philippe Tillet
2022-06-05 21:05:02 +00:00
parent a598db498f
commit fd3a9985ea
351 changed files with 43281 additions and 140 deletions


@@ -0,0 +1,286 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/01-vector-add.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_01-vector-add.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_01-vector-add.py:
Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:
- The basic programming model of Triton
- The `triton.jit` decorator, which is used to define Triton kernels.
- The best practices for validating and benchmarking your custom ops against native reference implementations
.. GENERATED FROM PYTHON SOURCE LINES 12-14
Compute Kernel
--------------------------
.. GENERATED FROM PYTHON SOURCE LINES 14-50
.. code-block:: default
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
.. GENERATED FROM PYTHON SOURCE LINES 51-53
Let's also declare a helper function to (1) allocate the output tensor
and (2) enqueue the above kernel with appropriate grid/block sizes.
.. GENERATED FROM PYTHON SOURCE LINES 53-74
.. code-block:: default
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
# In this case, we use a 1D grid where the size is the number of blocks
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE:
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keyword arguments
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to the output but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
.. GENERATED FROM PYTHON SOURCE LINES 75-76
We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
.. GENERATED FROM PYTHON SOURCE LINES 76-90
.. code-block:: default
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0
.. GENERATED FROM PYTHON SOURCE LINES 91-92
Seems like we're good to go!
.. GENERATED FROM PYTHON SOURCE LINES 94-99
Benchmark
-----------
We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow you to concisely plot the performance of your custom ops
for different problem sizes.
.. GENERATED FROM PYTHON SOURCE LINES 99-128
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[
2 ** i for i in range(12, 28, 1)
], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch'], # possible values for `line_arg`
line_names=['Triton', 'Torch'], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel='GB/s', # label name for the y-axis
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)
.. GENERATED FROM PYTHON SOURCE LINES 129-131
We can now run the decorated function above. Pass `print_data=True` to see the performance numbers, `show_plots=True` to plot them, and/or
`save_path='/path/to/results/'` to save them to disk along with raw CSV data.
.. GENERATED FROM PYTHON SOURCE LINES 131-132
.. code-block:: default
benchmark.run(print_data=True, show_plots=True)
.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
:alt: 01 vector add
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
vector-add-performance:
size Triton Torch
0 4096.0 9.600000 9.600000
1 8192.0 19.200000 19.200000
2 16384.0 38.400001 38.400001
3 32768.0 63.999998 63.999998
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 341.333321 341.333321
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 702.171410
10 4194304.0 780.190482 780.190482
11 8388608.0 812.429770 812.429770
12 16777216.0 833.084721 833.084721
13 33554432.0 842.004273 842.004273
14 67108864.0 847.448255 848.362445
15 134217728.0 849.737435 850.656574
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 39.514 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 01-vector-add.py <01-vector-add.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 01-vector-add.ipynb <01-vector-add.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_


@@ -0,0 +1,337 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:
Fused Softmax
=================
In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.
You will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
.. GENERATED FROM PYTHON SOURCE LINES 14-18
Motivations
------------
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
.. GENERATED FROM PYTHON SOURCE LINES 18-46
.. code-block:: default
import torch
import triton
import triton.language as tl
@torch.jit.script
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
.. GENERATED FROM PYTHON SOURCE LINES 47-55
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` elements, so we could
expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The `torch.jit.script` decorator aims to perform this kind of "kernel fusion" automatically
but, as we will see later, it is still far from ideal.
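To make this accounting concrete, here is a small sketch of the element-count model (plain Python; the `softmax_traffic` helper is purely illustrative and not part of the tutorial code):

.. code-block:: python

    # Rough model of DRAM traffic (in elements) for a row-wise softmax on an (M, N) matrix.
    # The counts mirror the comments in `naive_softmax` above.
    def softmax_traffic(M, N):
        naive = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes
        fused = M * N + M * N                              # read X once, write Y once
        return naive, fused, naive / fused

    # (8MN + 4M) / 2MN is ~4 for large N, matching the theoretical speed-up above
    print(softmax_traffic(4096, 1024))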
.. GENERATED FROM PYTHON SOURCE LINES 57-64
Compute Kernel
----------------
Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shapes:
.. GENERATED FROM PYTHON SOURCE LINES 64-93
.. code-block:: default
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Subtract the maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
.. GENERATED FROM PYTHON SOURCE LINES 94-95
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
.. GENERATED FROM PYTHON SOURCE LINES 95-125
.. code-block:: default
def softmax(x):
n_rows, n_cols = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
y = torch.empty_like(x)
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per
# row of the input matrix
softmax_kernel[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return y
.. GENERATED FROM PYTHON SOURCE LINES 126-128
Unit Test
----------
.. GENERATED FROM PYTHON SOURCE LINES 130-132
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 132-139
.. code-block:: default
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
.. GENERATED FROM PYTHON SOURCE LINES 140-141
As expected, the results are identical.
.. GENERATED FROM PYTHON SOURCE LINES 143-147
Benchmark
-------------
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
.. GENERATED FROM PYTHON SOURCE LINES 147-186
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 100)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=[
'triton',
'torch-native',
'torch-jit',
], # possible values for `line_arg``
line_names=[
"Triton",
"Torch (native)",
"Torch (jit)",
], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
if provider == 'torch-native':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
if provider == 'torch-jit':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
benchmark.run(show_plots=True, print_data=True)
.. image:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
:alt: 02 fused softmax
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 512.000001 512.000001 190.511628
1 384.0 614.400016 585.142862 153.600004
2 512.0 655.360017 585.142849 154.566038
3 640.0 706.206879 640.000002 158.759699
4 768.0 722.823517 664.216187 162.754967
.. ... ... ... ...
93 12160.0 812.359066 406.179533 198.631953
94 12288.0 812.429770 415.881552 198.995960
95 12416.0 812.498981 412.149375 198.556711
96 12544.0 812.566838 412.546756 198.815254
97 12672.0 811.007961 412.097543 198.971549
[98 rows x 4 columns]
.. GENERATED FROM PYTHON SOURCE LINES 187-192
In the above plot, we can see that:
- Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
- Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
Note however that the PyTorch `softmax` operation is more general and works on tensors of any shape.
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 3 minutes 22.699 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_


@@ -0,0 +1,530 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/03-matrix-multiplication.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_03-matrix-multiplication.py:
Matrix Multiplication
======================
In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
kernel that achieves performance on par with cuBLAS.
You will specifically learn about:
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
.. GENERATED FROM PYTHON SOURCE LINES 15-42
Motivations
-------------
Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is generally done by
hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized
to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
In this tutorial, you will learn how to implement efficient matrix multiplications by
yourself with Triton, in a way that is easy to customize and extend.
Roughly speaking, the kernel that we will write will implement the following blocked
algorithm to multiply a (M, K) by a (K, N) matrix:
.. code-block:: python
# do in parallel
for m in range(0, M, BLOCK_SIZE_M):
# do in parallel
for n in range(0, N, BLOCK_SIZE_N):
acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
for k in range(0, K, BLOCK_SIZE_K):
a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
acc += dot(a, b)
C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
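As a rough sketch of how this pseudo-code maps onto a launch grid, each `(m, n)` iteration becomes one program instance, so the grid has `ceil(M / BLOCK_SIZE_M) * ceil(N / BLOCK_SIZE_N)` entries. The `num_programs` helper below is an assumption made for illustration, not part of the kernel:

.. code-block:: python

    import math

    # One Triton program instance per (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C.
    def num_programs(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
        return math.ceil(M / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N)

    print(num_programs(1024, 1024, 128, 256))  # 8 * 4 = 32 program instances

This is the same count that the `grid` lambda in the wrapper function further below computes with `triton.cdiv`.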
.. GENERATED FROM PYTHON SOURCE LINES 44-137
Compute Kernel
----------------
The above algorithm is actually fairly straightforward to implement in Triton.
The main difficulty comes from the computation of the memory locations at which blocks
of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
multi-dimensional pointer arithmetic.
Pointer Arithmetics
~~~~~~~~~~~~~~~~~~~~
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by
:code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
.. code-block:: python
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
.. code-block:: python
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
And then updated in the inner loop as follows:
.. code-block:: python
pa += BLOCK_SIZE_K * stride_ak;
pb += BLOCK_SIZE_K * stride_bk;
L2 Cache Optimizations
~~~~~~~~~~~~~~~~~~~~~~~~
As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
block of :code:`C`.
It is important to remember that the order in which these blocks are computed does
matter, since it affects the L2 cache hit rate of our program. Unfortunately, a
simple row-major ordering
.. code-block:: Python
pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid // grid_n;
pid_n = pid % grid_n;
is just not going to cut it.
One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE_M` rows before
switching to the next column:
.. code-block:: python
# program ID
pid = tl.program_id(axis=0)
# number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# id of the group this program is in
group_id = pid // num_pid_in_group
# row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *within groups*, programs are ordered in a column-major order
# row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
we can see that if we compute the output in row-major ordering, we need to load 90
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
ordering, we only need to load 54 blocks.
.. image:: grouped_vs_row_major_ordering.png
In practice, this can improve the performance of our matrix multiplication kernel by
more than 10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
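The 90-versus-54 count above can be reproduced with a tiny counting sketch (illustrative only; the `blocks_loaded` helper and the choice of `GROUP_SIZE_M = 3` are assumptions made for this example):

.. code-block:: python

    # Count how many distinct A and B blocks must be loaded to compute a set of C blocks,
    # assuming each matrix is a 9x9 grid of blocks.
    def blocks_loaded(c_blocks, k_blocks=9):
        a_rows = {m for m, _ in c_blocks}   # block-rows of A needed
        b_cols = {n for _, n in c_blocks}   # block-columns of B needed
        return (len(a_rows) + len(b_cols)) * k_blocks

    row_major = [(0, n) for n in range(9)]                  # first 9 C blocks, row-major
    grouped = [(m, n) for m in range(3) for n in range(3)]  # first 9 C blocks, grouped

    print(blocks_loaded(row_major))  # 1*9 + 9*9 = 90
    print(blocks_loaded(grouped))    # 3*9 + 3*9 = 54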
.. GENERATED FROM PYTHON SOURCE LINES 139-142
Final Result
-------------
.. GENERATED FROM PYTHON SOURCE LINES 142-259
.. code-block:: default
import torch
import triton
import triton.language as tl
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
# decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of
# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
# See above `L2 Cache Optimizations` section for details
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# see above `Pointer Arithmetics` section for details
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K,
# this will access out-of-bounds memory and produce an
# error or (worse!) incorrect results.
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
# We accumulate along the K dimension
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION:
accumulator = ACTIVATION(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul`
@triton.jit
def leaky_relu(x):
x = x + 1
return tl.where(x >= 0, x, 0.01 * x)
.. GENERATED FROM PYTHON SOURCE LINES 260-262
We can now create a convenience wrapper function that only takes two input tensors
and (1) checks shape constraints; (2) allocates the output; (3) launches the above kernel.
.. GENERATED FROM PYTHON SOURCE LINES 262-291
.. code-block:: default
def matmul(a, b, activation=None):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
assert b.is_contiguous(), "matrix B must be contiguous"
M, K = a.shape
K, N = b.shape
assert (
K % 32 == 0
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
# allocates output
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
ACTIVATION=activation,
)
return c
.. GENERATED FROM PYTHON SOURCE LINES 292-296
Unit Test
-----------
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
.. GENERATED FROM PYTHON SOURCE LINES 296-309
.. code-block:: default
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")
if triton.testing.allclose(triton_output, torch_output):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
triton_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
...,
[ -5.7031, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
device='cuda:0', dtype=torch.float16)
torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
...,
[ -5.7070, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
.. GENERATED FROM PYTHON SOURCE LINES 310-316
Benchmark
--------------
Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
.. GENERATED FROM PYTHON SOURCE LINES 316-357
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 33)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
# label name for the lines
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={},
)
)
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_relu(torch.matmul(a, b))
)
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, activation=leaky_relu)
)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True, print_data=True)
.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
:alt: 03 matrix multiplication
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
1 384.0 7.372800 ... 8.507077 8.507077
2 512.0 14.563555 ... 15.420235 15.420235
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 35.389441 34.028308
5 896.0 39.025776 ... 41.321411 39.025776
6 1024.0 49.932191 ... 53.773130 52.428801
7 1152.0 45.242181 ... 48.161033 47.396572
8 1280.0 51.200001 ... 57.690139 57.690139
9 1408.0 64.138541 ... 68.147202 67.305878
10 1536.0 80.430545 ... 80.430545 78.643199
11 1664.0 63.372618 ... 63.372618 62.492442
12 1792.0 72.983276 ... 63.499573 63.142831
13 1920.0 69.120002 ... 71.626943 71.257735
14 2048.0 73.908442 ... 78.033565 76.959706
15 2176.0 83.155572 ... 87.115360 85.632545
16 2304.0 68.251065 ... 78.064941 76.809875
17 2432.0 71.125224 ... 75.522751 74.521127
18 2560.0 77.833728 ... 82.331658 80.908642
19 2688.0 84.108772 ... 90.966561 89.254248
20 2816.0 83.552120 ... 83.712490 83.552120
21 2944.0 82.237674 ... 84.324925 83.899046
22 3072.0 81.825298 ... 89.735509 89.170242
23 3200.0 84.432717 ... 96.240602 94.674553
24 3328.0 82.939284 ... 86.736504 86.113988
25 3456.0 82.688790 ... 88.014813 81.269178
26 3584.0 86.457107 ... 99.684470 99.025764
27 3712.0 83.247783 ... 89.594031 85.675250
28 3840.0 85.070769 ... 93.090912 87.980905
29 3968.0 93.648452 ... 86.114283 87.284643
30 4096.0 89.627865 ... 89.240508 93.792965
[31 rows x 5 columns]
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 6 minutes 16.582 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 03-matrix-multiplication.py <03-matrix-multiplication.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 03-matrix-multiplication.ipynb <03-matrix-multiplication.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_


@@ -0,0 +1,271 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/04-low-memory-dropout.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_04-low-memory-dropout.py:
Low-Memory Dropout
====================
In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:
- The limitations of naive implementations of Dropout with PyTorch
- Parallel pseudo-random number generation in Triton
.. GENERATED FROM PYTHON SOURCE LINES 14-29
Baseline
-------------
The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
of deep neural networks in low-data regimes (i.e., regularization).
It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
output has a probability :math:`p` of being changed to zero and otherwise it is copied from the input.
This forces the network to perform well even when only :math:`1 - p` scalars from the input are available.
At evaluation time we want to use the full power of the network so we set :math:`p=0`. Naively this would
increase the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease
in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which
keeps the norm consistent regardless of the dropout probability.
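As a quick sanity check of this rescaling, the short PyTorch sketch below (not part of the tutorial code) shows that multiplying the kept elements by :math:`\frac{1}{1 - p}` keeps the output mean close to the input mean:

.. code-block:: python

    import torch

    torch.manual_seed(0)
    p = 0.5
    x = torch.randn(1_000_000)
    keep = torch.rand_like(x) > p
    # drop elements with probability p and rescale the survivors by 1 / (1 - p)
    out = torch.where(keep, x / (1 - p), torch.zeros_like(x))
    print(x.mean().item(), out.mean().item())  # the two means should be close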
Let's first take a look at the baseline implementation.
.. GENERATED FROM PYTHON SOURCE LINES 29-82
.. code-block:: default
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
# The line below is the crucial part, described in the paragraph above!
output = tl.where(x_keep, x / (1 - p), 0.0)
# Write-back output
tl.store(output_ptr + offsets, output, mask=mask)
def dropout(x, x_keep, p):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
["input"] + x.tolist(),
["keep mask"] + x_keep.tolist(),
["output"] + output.tolist()
]))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
input 1.541 -0.293429 -2.17879 0.568431 -1.08452 -1.3986 0.403347 0.838026 -0.719258 -0.403344
keep mask 1 1 0 1 0 1 1 0 0 0
output 3.08199 -0.586858 0 1.13686 0 -2.79719 0.806694 0 0 0
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
.. GENERATED FROM PYTHON SOURCE LINES 83-101
Seeded dropout
--------------
The above implementation of dropout works fine, but it can be a bit awkward to deal with. First,
we need to store the dropout mask for backpropagation. Second, dropout state management can get
very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
of persisting randomness across multiple invocations of the kernel.
Pseudorandom number generation in Triton is simple! In this tutorial we will use the
:code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
other :ref:`random number generation strategies <Random Number Generation>`.
.. note::
Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_).
Let's put it all together.
.. GENERATED FROM PYTHON SOURCE LINES 101-149
.. code-block:: default
@triton.jit
def _seeded_dropout(
x_ptr,
output_ptr,
n_elements,
p,
seed,
BLOCK_SIZE: tl.constexpr,
):
# compute memory offsets of elements handled by this instance
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# load data from x
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# randomly prune it
random = tl.rand(seed, offsets)
x_keep = random > p
# write-back
output = tl.where(x_keep, x / (1 - p), 0.0)
tl.store(output_ptr + offsets, output, mask=mask)
def seeded_dropout(x, p, seed):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
return output
x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)
print(tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist()
]))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
.. GENERATED FROM PYTHON SOURCE LINES 150-153
Et Voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the `triton/language/random` folder!
.. GENERATED FROM PYTHON SOURCE LINES 155-160
Exercises
-------------
1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.
2. Add support for striding.
3. (challenge) Implement a kernel for the sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.
.. GENERATED FROM PYTHON SOURCE LINES 162-167
References
--------------
.. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
.. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.484 seconds)
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 04-low-memory-dropout.py <04-low-memory-dropout.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 04-low-memory-dropout.ipynb <04-low-memory-dropout.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_


@@ -0,0 +1,420 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:
Layer Normalization
====================
.. GENERATED FROM PYTHON SOURCE LINES 5-312
.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
:alt: 05 layer norm
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
layer-norm:
N Triton Torch Apex
0 1024.0 585.142849 277.694907 468.114273
1 1536.0 630.153868 323.368435 511.999982
2 2048.0 682.666643 334.367358 520.126988
3 2560.0 694.237267 362.477870 512.000013
4 3072.0 712.347810 375.206126 501.551037
5 3584.0 725.873439 384.859062 458.751978
6 4096.0 728.177767 381.023256 458.293714
7 4608.0 670.254540 394.267384 426.173427
8 5120.0 688.403381 397.669909 426.666652
9 5632.0 704.000002 395.228063 413.357796
10 6144.0 702.171410 402.885254 409.600010
11 6656.0 705.271522 398.861429 400.360920
12 7168.0 690.891575 396.844306 387.459443
13 7680.0 686.480466 392.587863 387.634072
14 8192.0 636.271854 393.609605 371.308771
15 8704.0 630.153861 389.005597 380.502740
16 9216.0 609.322328 407.337026 383.999986
17 9728.0 589.575753 409.599987 383.369452
18 10240.0 568.888869 408.578556 382.803739
19 10752.0 551.384634 411.559798 381.445676
20 11264.0 536.380957 406.826188 373.134567
21 11776.0 523.377770 409.599991 377.587162
22 12288.0 517.389457 413.911572 383.251457
23 12800.0 505.679014 410.420828 376.470582
24 13312.0 494.180982 405.699062 376.310952
25 13824.0 482.934503 411.888257 379.389355
26 14336.0 471.967074 406.695045 374.185964
27 14848.0 461.297068 408.192434 374.712936
28 15360.0 454.269882 406.214870 378.092307
29 15872.0 447.887117 406.974373 376.225175
|
.. code-block:: default
import torch
import triton
import triton.language as tl
try:
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
# should not be added to extras_require in setup.py.
import apex
HAS_APEX = True
except ModuleNotFoundError:
HAS_APEX = False
@triton.jit
def _layer_norm_fwd_fused(
Out,
A,
Weight,
Bias,
Mean, Rstd,
stride, N, eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
A += row * stride
# compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
a = tl.where(cols < N, a - mean, 0.)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# write-back mean/rstd
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# write-back
tl.store(Out + cols, out, mask=mask)
# Backward pass (DA + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(
_DA,
_DOut,
_A,
Weight,
Mean, Rstd,
stride, NumRows, NumCols, eps,
BLOCK_SIZE_N: tl.constexpr,
):
# position of elements processed by this program
pid = tl.program_id(0)
row = pid
A = _A + row * stride
DOut = _DOut + row * stride
DA = _DA + row * stride
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# load data to SRAM
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
_mean1 += a_hat * wdout
_mean2 += wdout
mean1 = tl.sum(_mean1, axis=0) / NumCols
mean2 = 0.
mean2 = tl.sum(_mean2, axis=0) / NumCols
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DA + cols, da, mask=mask)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
A, DOut,
Mean, Var,
DW,
DB,
M, N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
tl.store(DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, a, normalized_shape, weight, bias, eps):
# allocate output
out = torch.empty_like(a)
# reshape input data into 2D tensor
a_arg = a.reshape(-1, a.shape[-1])
M, N = a_arg.shape
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_layer_norm_fwd_fused[(M,)](
out,
a_arg,
weight,
bias,
mean, rstd,
a_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.save_for_backward(
a, weight, bias, mean, rstd,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
if hasattr(bias, "config"):
assert bias.config.grad_scale_name == weight.config.grad_scale_name
grad_scale_name = bias.config.grad_scale_name
else:
grad_scale_name = None
ctx.grad_scale_gain_bias_name = grad_scale_name
return out
@staticmethod
def backward(ctx, dout):
assert dout.is_contiguous()
a, weight, bias, mean, var = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = weight.shape[0]
# allocate output
da = torch.empty_like(dout)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = a.reshape(-1, a.shape[-1])
M, N = x_arg.shape
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
_layer_norm_bwd_dx_fused[(M,)](
da,
dout,
a,
weight,
mean, var,
x_arg.stride(0), M, N,
ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
# accumulate partial sums in separate kernel
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
mean, var,
dweight,
dbias,
M,
N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128,
)
return (da, None, dweight, dbias, None, None,
None, None, None, None,
None,
None, None, None,
None,
None, None, None,
None, None, None,
None, None, None)
def layer_norm(a, normalized_shape, weight, bias, eps):
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
torch.manual_seed(0)
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
x.grad, weight.grad, bias.grad = None, None, None
# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
triton.testing.assert_almost_equal(y_tri, y_ref)
triton.testing.assert_almost_equal(dx_tri, dx_ref)
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[512 * i for i in range(2, 32)],
line_arg='provider',
line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
if provider == 'torch':
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
# test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 5 minutes 26.747 seconds)
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_


@@ -0,0 +1,152 @@
:orphan:
.. _sphx_glr_getting-started_tutorials:
Tutorials
==================
Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.
To install the dependencies for the tutorials:
.. code-block:: bash
cd triton
pip install -e './python[tutorials]'
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model of Triton - The triton.jit decorator, which is used to define Tri...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_01-vector-add_thumb.png
:alt: Vector Addition
:ref:`sphx_glr_getting-started_tutorials_01-vector-add.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/01-vector-add
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - Reduction operators in Triton...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_02-fused-softmax_thumb.png
:alt: Fused Softmax
:ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/02-fused-softmax
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- Block-level matrix multiplications - Multi-dimensional pointer arithmetic - Program re-orderi...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_03-matrix-multiplication_thumb.png
:alt: Matrix Multiplication
:ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/03-matrix-multiplication
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, you will write a memory-efficient implementation of dropout whose state will ...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_04-low-memory-dropout_thumb.png
:alt: Low-Memory Dropout
:ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/04-low-memory-dropout
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Layer Normalization">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_05-layer-norm_thumb.png
:alt: Layer Normalization
:ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/05-layer-norm
.. raw:: html
<div class="sphx-glr-clear"></div>
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-gallery
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download all examples in Python source code: tutorials_python.zip </getting-started/tutorials/tutorials_python.zip>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download all examples in Jupyter notebooks: tutorials_jupyter.zip </getting-started/tutorials/tutorials_jupyter.zip>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_


@@ -0,0 +1,20 @@
:orphan:
.. _sphx_glr_getting-started_tutorials_sg_execution_times:
Computation times
=================
**16:46.026** total execution time for **getting-started_tutorials** files:
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 06:16.582 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:26.747 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:22.699 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:39.514 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.484 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+