.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/01-vector-add.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_01-vector-add.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_01-vector-add.py:

Vector Addition
===============

In this tutorial, you will write a simple vector addition using Triton and learn about:

- The basic programming model of Triton
- The `triton.jit` decorator, which is used to define Triton kernels
- Best practices for validating and benchmarking your custom ops against native reference implementations

.. GENERATED FROM PYTHON SOURCE LINES 12-14

Compute Kernel
--------------

.. GENERATED FROM PYTHON SOURCE LINES 14-43

.. code-block:: default


    import torch

    import triton
    import triton.language as tl


    @triton.jit
    def _add(
        X,  # *Pointer* to first input vector
        Y,  # *Pointer* to second input vector
        Z,  # *Pointer* to output vector
        N,  # Size of the vector
        **meta  # Optional meta-parameters for the kernel
    ):
        pid = tl.program_id(0)
        # Create an offset for the blocks of pointers to be
        # processed by this program instance
        offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        # Create a mask to guard memory operations against
        # out-of-bounds accesses
        mask = offsets < N
        # Load x and y, masking out any out-of-bounds elements
        x = tl.load(X + offsets, mask=mask)
        y = tl.load(Y + offsets, mask=mask)
        # Write back x + y; the store must be masked as well, otherwise the last
        # program instance would write past the end of Z whenever N % BLOCK != 0
        z = x + y
        tl.store(Z + offsets, z, mask=mask)

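To make the indexing concrete: with `BLOCK = 8` and `N = 20`, program instance `pid = 2` covers offsets 16 through 23, of which only 16-19 are in bounds. Here is a host-side sketch of the same arithmetic in plain `torch` (the numbers are illustrative; inside the kernel this runs on `tl.arange` instead):

.. code-block:: default

    # Host-side illustration of the kernel's offset/mask arithmetic
    BLOCK, N, pid = 8, 20, 2
    offsets = pid * BLOCK + torch.arange(0, BLOCK)  # -> tensor([16, 17, 18, 19, 20, 21, 22, 23])
    mask = offsets < N  # -> tensor([ True,  True,  True,  True, False, False, False, False])
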
.. GENERATED FROM PYTHON SOURCE LINES 44-46

Let's also declare a helper function to (1) allocate the `z` tensor
and (2) enqueue the above kernel with appropriate grid/block sizes.

.. GENERATED FROM PYTHON SOURCE LINES 46-64

.. code-block:: default


    def add(x, y):
        z = torch.empty_like(x)
        N = z.shape[0]
        # The SPMD launch grid denotes the number of kernel instances that run in parallel.
        # It is analogous to CUDA launch grids. It can be either Tuple[int] or Callable(meta-parameters) -> Tuple[int].
        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
        # NOTE:
        #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
        #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
        #  - Don't forget to pass meta-parameters as keyword arguments.
        _add[grid](x, y, z, N, BLOCK=1024)
        # We return a handle to z, but since `torch.cuda.synchronize()` hasn't been called, the kernel is still
        # running asynchronously at this point.
        return z

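As noted above, the launch grid may also be a plain `Tuple[int]` when the meta-parameters are known ahead of time. A minimal, equivalent sketch (reusing the `_add` kernel above; the `torch.cuda.synchronize()` call is standard PyTorch, added here only to make the asynchronous completion explicit):

.. code-block:: default

    N = 98432
    x = torch.rand(N, device='cuda')
    y = torch.rand(N, device='cuda')
    z = torch.empty_like(x)
    BLOCK = 1024
    grid = (triton.cdiv(N, BLOCK), )  # static grid: ceil(98432 / 1024) = 97 program instances
    _add[grid](x, y, z, N, BLOCK=BLOCK)
    torch.cuda.synchronize()  # block until the kernel has actually finished
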
.. GENERATED FROM PYTHON SOURCE LINES 65-66

We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:

.. GENERATED FROM PYTHON SOURCE LINES 66-77

.. code-block:: default


    torch.manual_seed(0)
    size = 98432
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    za = x + y
    zb = add(x, y)
    print(za)
    print(zb)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(za - zb))}')

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
    tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
    The maximum difference between torch and triton is 0.0

.. GENERATED FROM PYTHON SOURCE LINES 78-79

Seems like we're good to go!

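For an automated check rather than eyeballing printed tensors, one could also assert that the two results agree. A minimal sketch using standard PyTorch (`torch.allclose` is not part of this tutorial's source):

.. code-block:: default

    # Fail loudly if the Triton and Torch results ever diverge
    assert torch.allclose(za, zb), "Triton and Torch results differ"
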
.. GENERATED FROM PYTHON SOURCE LINES 81-86

Benchmark
---------

We can now benchmark our custom op on vectors of increasing size to get a sense of how it performs relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of custom ops
for different problem sizes.

.. GENERATED FROM PYTHON SOURCE LINES 86-113

.. code-block:: default


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['size'],  # argument names to use as an x-axis for the plot
            x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_names`
            x_log=True,  # x axis is logarithmic
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['triton', 'torch'],  # possible values for `line_arg`
            line_names=["Triton", "Torch"],  # label name for the lines
            styles=[('blue', '-'), ('green', '-')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={}  # values for function arguments not in `x_names` and `y_name`
        )
    )
    def benchmark(size, provider):
        x = torch.rand(size, device='cuda', dtype=torch.float32)
        y = torch.rand(size, device='cuda', dtype=torch.float32)
        if provider == 'torch':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
        # 12 bytes of memory traffic per element: x and y are each read (2 x 4 bytes)
        # and z is written (4 bytes); converting ms to s and bytes to GB gives the 1e-6 factor.
        gbps = lambda ms: 12 * size / ms * 1e-6
        return gbps(ms), gbps(max_ms), gbps(min_ms)

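As a quick sanity check on the throughput formula, one can plug in the largest problem size by hand (the ~1.89 ms figure is back-derived from the last row of the results table below, not a new measurement):

.. code-block:: default

    # 12 bytes per element at size = 134217728 over ~1.89 ms
    size, ms = 134217728, 1.89
    print(12 * size / ms * 1e-6)  # ~852 GB/s, matching the table below
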
.. GENERATED FROM PYTHON SOURCE LINES 114-116

We can now run the decorated function above. Pass `print_data=True` to see the performance numbers, `show_plots=True` to plot them, and/or
`save_path='/path/to/results/'` to save them to disk along with raw CSV data.

.. GENERATED FROM PYTHON SOURCE LINES 116-116

.. code-block:: default

    benchmark.run(print_data=True, show_plots=True)

.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
    :alt: 01 vector add
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    vector-add-performance:
                size      Triton       Torch
    0         4096.0    9.600000    9.600000
    1         8192.0   19.200000   19.200000
    2        16384.0   38.400001   38.400001
    3        32768.0   76.800002   76.800002
    4        65536.0  127.999995  127.999995
    5       131072.0  219.428568  219.428568
    6       262144.0  384.000001  384.000001
    7       524288.0  472.615390  472.615390
    8      1048576.0  614.400016  614.400016
    9      2097152.0  722.823517  722.823517
    10     4194304.0  780.190482  780.190482
    11     8388608.0  819.200021  812.429770
    12    16777216.0  833.084721  833.084721
    13    33554432.0  843.811163  843.811163
    14    67108864.0  849.278610  848.362445
    15   134217728.0  851.577704  850.656574

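To also keep these results around, the same `run` call accepts the `save_path` argument mentioned above; for instance (the directory name is a placeholder):

.. code-block:: default

    # Writes the plot and the raw CSV data into ./results/
    benchmark.run(print_data=True, show_plots=False, save_path='./results/')
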
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 10.996 seconds)

.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 01-vector-add.py <01-vector-add.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 01-vector-add.ipynb <01-vector-add.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_