.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/01-vector-add.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_01-vector-add.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_01-vector-add.py:

Vector Addition
=================

In this tutorial, you will write a simple vector addition using Triton and learn about:

- The basic programming model used by Triton.
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
- The best practices for validating and benchmarking custom ops against native reference implementations.

.. GENERATED FROM PYTHON SOURCE LINES 12-14

Compute Kernel
--------------------------

.. GENERATED FROM PYTHON SOURCE LINES 14-42

.. code-block:: default

    import torch
    import triton


    @triton.jit
    def _add(
        X,  # *Pointer* to first input vector
        Y,  # *Pointer* to second input vector
        Z,  # *Pointer* to output vector
        N,  # Size of the vector
        **meta  # Optional meta-parameters for the kernel
    ):
        pid = triton.program_id(0)
        # Create an offset for the blocks of pointers to be
        # processed by this program instance
        offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
        # Create a mask to guard memory operations against
        # out-of-bounds accesses
        mask = offsets < N
        # Load x and y
        x = triton.load(X + offsets, mask=mask)
        y = triton.load(Y + offsets, mask=mask)
        # Write back x + y, guarding the store with the same mask
        z = x + y
        triton.store(Z + offsets, z, mask=mask)

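As a sanity check on the indexing logic, it can help to trace the `offsets` and `mask` values
on the host for a toy size. The snippet below is illustrative only and not part of the original
tutorial; the values `N = 10` and `BLOCK = 4` were chosen just for this sketch.

.. code-block:: default

    # Illustrative sketch (plain Python, no GPU needed): what each program
    # instance of the kernel above would see for N = 10 and BLOCK = 4.
    N, BLOCK = 10, 4
    num_programs = -(-N // BLOCK)  # same as triton.cdiv(N, BLOCK) -> 3
    for pid in range(num_programs):
        offsets = [pid * BLOCK + i for i in range(BLOCK)]
        mask = [off < N for off in offsets]
        print(pid, offsets, mask)
    # The last program instance (pid = 2) covers offsets [8, 9, 10, 11];
    # its mask disables the last two lanes, so the out-of-bounds elements
    # are neither loaded nor stored.
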
.. GENERATED FROM PYTHON SOURCE LINES 43-45

We can also declare a helper function that handles allocating the output vector
and enqueueing the kernel.

.. GENERATED FROM PYTHON SOURCE LINES 45-63

.. code-block:: default

    def add(x, y):
        z = torch.empty_like(x)
        N = z.shape[0]
        # The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
        # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
        # NOTE:
        # - torch.tensor objects are implicitly converted to pointers to their first element.
        # - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel.
        # - don't forget to pass meta-parameters as keyword arguments.
        _add[grid](x, y, z, N, BLOCK=1024)
        # We return a handle to z, but since `torch.cuda.synchronize()` hasn't been called, the kernel is still
        # running asynchronously at this point.
        return z

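To make the callable grid concrete: the test below launches on a vector of 98432 elements with
`BLOCK=1024`, so the grid resolves to a single-axis launch of `triton.cdiv(98432, 1024) = 97`
program instances. This worked example is ours and not part of the original script.

.. code-block:: default

    # Illustrative only: resolving the callable grid by hand.
    # 98432 is the vector size used in the unit test below; BLOCK=1024 as above.
    print(triton.cdiv(98432, 1024))  # 97 -> grid = (97,), i.e. 97 kernel instances
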
.. GENERATED FROM PYTHON SOURCE LINES 64-65

We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test our results:

.. GENERATED FROM PYTHON SOURCE LINES 65-76

.. code-block:: default

    torch.manual_seed(0)
    size = 98432
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    za = x + y
    zb = add(x, y)
    print(za)
    print(zb)
    print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
    tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
    The maximum difference between torch and triton is 0.0

.. GENERATED FROM PYTHON SOURCE LINES 77-78

Seems like we're good to go!

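If you prefer an automated check over eyeballing the printed tensors, the comparison can also be
phrased as an assertion; the line below is a suggestion on our part and uses only standard PyTorch.

.. code-block:: default

    # Optional: turn the visual comparison above into a hard check.
    assert torch.allclose(za, zb), "Triton and torch results differ"
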
.. GENERATED FROM PYTHON SOURCE LINES 80-85

Benchmark
-----------

We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op
for different problem sizes.

.. GENERATED FROM PYTHON SOURCE LINES 85-111

.. code-block:: default

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['size'],  # argument names to use as an x-axis for the plot
            x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_names`
            x_log=True,  # x axis is logarithmic
            y_name='provider',  # argument name whose value corresponds to a different line in the plot
            y_vals=['torch', 'triton'],  # possible values for `y_name`
            y_lines=["Torch", "Triton"],  # label name for the lines
            ylabel="GB/s",  # label name for the y-axis
            plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={}  # values for function arguments not in `x_names` and `y_name`
        )
    )
    def benchmark(size, provider):
        x = torch.rand(size, device='cuda', dtype=torch.float32)
        y = torch.rand(size, device='cuda', dtype=torch.float32)
        if provider == 'torch':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
        gbps = lambda ms: 12 * size / ms * 1e-6
        return gbps(ms), gbps(max_ms), gbps(min_ms)

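A note on the `gbps` helper above: each output element requires reading one float32 from `x`, one
from `y`, and writing one to `z`, i.e. :math:`3 \times 4 = 12` bytes of memory traffic per element.
Dividing the total bytes moved by the runtime in milliseconds and scaling by :math:`10^{-6}`
converts bytes/ms into GB/s:

.. math::

    \mathrm{GB/s} = \frac{12 \cdot \mathrm{size}}{\mathrm{ms}} \times 10^{-6}
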
.. GENERATED FROM PYTHON SOURCE LINES 112-114

We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or
`save_path='/path/to/results/'` to save them to disk along with raw CSV data.

.. GENERATED FROM PYTHON SOURCE LINES 114-114

.. code-block:: default

    benchmark.run(show_plots=True)

.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
    :alt: 01 vector add
    :class: sphx-glr-single-img

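If you would rather keep the results than display them interactively, the same `run` call accepts
the `save_path` argument mentioned above; the output directory below is just a placeholder.

.. code-block:: default

    # Hedged sketch: write the plot image and the raw CSV data to disk
    # instead of showing the plot. './results/' is a hypothetical directory.
    benchmark.run(show_plots=False, save_path='./results/')
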
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 5.812 seconds)

.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 01-vector-add.py <01-vector-add.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 01-vector-add.ipynb <01-vector-add.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_