Let us start with something simple, and see how Triton can be used to create a custom vector addition for PyTorch. The Triton compute kernel for this operation is the following:
.. code-block:: C

    // Triton
    // launch on a grid of (N + TILE - 1) / TILE programs
    __global__ void add(float* z, float* x, float* y, int N){
        // program id
        int pid = get_program_id(0);
        // create a tile of TILE contiguous offsets
        int offset[TILE] = pid * TILE + 0 ... TILE;
        // bounds-checking mask
        bool check[TILE] = offset < N;
        // create arrays of pointers
        float* pz[TILE] = z + offset;
        float* px[TILE] = x + offset;
        float* py[TILE] = y + offset;
        // predicated element-wise addition
        *?(check)pz = *?(check)px + *?(check)py;
    }
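To see what each program instance in the grid computes, the tiled, bounds-checked addition can be emulated on the CPU in plain NumPy (an illustrative sketch only; the function name `add_emulated` and the `TILE` value are arbitrary choices, not part of Triton's API):

```python
import numpy as np

# Illustrative CPU emulation of the tiled kernel above; the function
# name and TILE value are arbitrary, not part of Triton's API.
def add_emulated(x, y, TILE=4):
    N = x.shape[0]
    z = np.empty_like(x)
    # launch one "program" per tile, on a grid of (N + TILE - 1) / TILE programs
    for pid in range((N + TILE - 1) // TILE):
        offset = pid * TILE + np.arange(TILE)  # the tile 0 ... TILE, shifted by pid
        check = offset < N                     # bounds-checking mask
        idx = offset[check]
        z[idx] = x[idx] + y[idx]               # predicated element-wise addition
    return z

x = np.linspace(0, 1, 10, dtype=np.float32)
y = np.linspace(1, 2, 10, dtype=np.float32)
assert np.allclose(add_emulated(x, y), x + y)
```

Note that the last program instance handles a partial tile (here, 10 is not a multiple of 4), which is exactly why the kernel needs the `check` mask.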
As you can see, arrays are first-class citizens in Triton. This has a number of important advantages that will be highlighted in the next tutorial. For now, let's keep it simple and see how to execute the above operation in PyTorch.
---------------
PyTorch Wrapper
---------------
As you will see, a wrapper for the above Triton function can be created in just a few lines of pure Python code.
.. code-block:: python

    import torch
    import triton

    class _add(triton.function):
        # source-code for Triton compute kernel
        src = """
    __global__ void add(float* z, float* x, float* y, int N){
        int offset[TILE] = get_program_id(0) * TILE + 0 ... TILE;
        bool check[TILE] = offset < N;
        float* pz[TILE] = z + offset;
        float* px[TILE] = x + offset;
        float* py[TILE] = y + offset;
        *?(check)pz = *?(check)px + *?(check)py;
    }
    """
        # compile the source-code into a callable kernel
        kernel = triton.kernel(src)

        @staticmethod
        def forward(ctx, x, y):
            # allocate output
            z = torch.empty_like(x)
            # launch on a grid of (N + TILE - 1) / TILE programs
            grid = lambda opt: [triton.cdiv(z.numel(), opt.d('TILE'))]
            _add.kernel(z, x, y, z.numel(), grid=grid, TILE=1024)
            return z

    # function that can be called on torch tensors
    add = _add.apply
Note that Triton compiles kernels just-in-time: the first run of the program will generate and cache a number of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
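The caching mechanism described above can be sketched with a cache keyed by a hash of the kernel source-code. This is a hypothetical illustration: `compile_cached` and the on-disk layout are made up for this sketch and do not mirror Triton's actual implementation.

```python
import hashlib
import os

def compile_cached(src, cache_dir, compile_fn):
    """Compile `src` with `compile_fn`, reusing a cached binary when possible.

    Hypothetical sketch: the real cache under $HOME/.triton/cache
    may use a different key and layout.
    """
    # key the cache entry by a hash of the kernel source-code
    key = hashlib.sha256(src.encode()).hexdigest()
    path = os.path.join(cache_dir, key + ".bin")
    if os.path.exists(path):
        # cache hit: subsequent runs skip compilation entirely
        with open(path, "rb") as f:
            return f.read()
    # cache miss: compile once and persist the result
    binary = compile_fn(src)
    with open(path, "wb") as f:
        f.write(binary)
    return binary
```

The important property is that compilation happens at most once per distinct kernel source: every later run pays only the cost of a hash and a file read.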