Vector Addition¶
In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:

* The basic syntax of the Triton programming language
* The best practices for creating PyTorch custom operators using the triton.kernel Python API
* The best practices for validating and benchmarking custom ops against native reference implementations
Writing the Compute Kernel¶
Each compute kernel is declared using the __global__
attribute, and executed many times in parallel on different chunks of data (See the Single Program, Multiple Data programming model for more details).
__global__ void add(float* z, float* x, float* y, int N){
    // The `get_program_id(i)` function returns the i-th coordinate
    // of the program in the overarching SPMD context
    // (a.k.a. launch grid). This is what allows us to process
    // different chunks of data in parallel.
    // For those familiar with CUDA, `get_program_id({0,1,2})`
    // is similar to blockIdx.{x,y,z}
    int pid = get_program_id(0);
    // In Triton, arrays are first-class citizens. In other words,
    // they are primitive data types and are -- contrary to C and
    // CUDA -- not implemented as pointers to contiguous chunks of
    // memory.
    // In the few lines below, we create an array of `BLOCK` pointers
    // whose memory values are, e.g.:
    // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
    // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float* pz [BLOCK] = z + offset;
    float* px [BLOCK] = x + offset;
    float* py [BLOCK] = y + offset;
    // Simple element-wise control-flow for load/store operations can
    // be achieved using the ternary operator `cond ? val_true : val_false`
    // or the conditional dereferencing operator `*?(cond)ptr`.
    // Here, we make sure that we do not access memory out-of-bounds when we
    // write-back `z`
    bool check[BLOCK] = offset < N;
    *?(check)pz = *?(check)px + *?(check)py;
}
The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the MAPL’2019 Triton paper.
Writing the Torch bindings¶
The only thing that matters when it comes to Triton and Torch is the triton.kernel class, which allows you to transform the above C-like function into a callable Python object that can be used to modify torch.tensor objects.
To create a triton.kernel, you only need three things:

* source: string: the source-code of the kernel you want to create
* device: torch.device: the device you want to compile this code for
* defines: dict: the set of macros that you want the pre-processor to #define for you
Note: The constructor of triton.kernel does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see the _kernels variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs.
[10]:
import torch
import triton
# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
    // program id
    int pid = get_program_id(0);
    // create arrays of pointers
    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float* pz[BLOCK] = z + offset;
    float* px[BLOCK] = x + offset;
    float* py[BLOCK] = y + offset;
    // bounds checking
    bool check[BLOCK] = offset < N;
    // write-back
    *?(check)pz = *?(check)px + *?(check)py;
}
"""
# This function returns a callable `triton.kernel` object
# created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`
# We compile the kernel with -DBLOCK=1024
_kernels = dict()
def make_add_kernel(device):
    if device not in _kernels:
        defines = {'BLOCK': 1024}
        _kernels[device] = triton.kernel(_src, device=device, defines=defines)
    return _kernels[device]
# This is a standard torch custom autograd Function.
# The only difference is that we can now use the above kernel
# in the `forward` and `backward` functions.
class _add(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, y):
        # constraints of the op
        assert x.dtype == torch.float32
        # *allocate output*
        z = torch.empty_like(x)
        # *create launch grid*:
        # this is a function which takes compilation parameters `opt`
        # as input and returns a tuple of int (i.e., launch grid) for the kernel.
        # triton.cdiv is a shortcut for ceil division:
        # triton.cdiv(a, b) = (a + b - 1) // b
        N = z.shape[0]
        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
        # *launch kernel*:
        # the pointer to the data of a torch tensor can be retrieved with
        # the `.data_ptr()` method
        kernel = make_add_kernel(z.device)
        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)
        return z

# Just like with standard PyTorch ops,
# we use the `.apply` method to create a
# callable object for our function
add = _add.apply
At this point, add(x, y) is equivalent to x + y for contiguous tensors. Now let's test and benchmark it!
Writing a Unit Test¶
[9]:
torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(za - zb))}')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
The maximum difference between torch and triton is 0.0
Seems to work!
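As a quick aside on the contiguity caveat mentioned above: the kernel indexes memory linearly through raw pointers, so non-contiguous views (e.g. a transpose) should be flattened into a contiguous 1-D buffer before being passed to our op. Below is a minimal, hypothetical sketch of that pattern, reusing the add defined above; the shapes and variable names are illustrative only, not part of the tutorial code.

# hypothetical example: `a` and `b` are non-contiguous (transposed) views
a = torch.rand(64, 64, device='cuda').t()
b = torch.rand(64, 64, device='cuda').t()
# flatten to contiguous 1-D buffers before calling `add`, then reshape back
c = add(a.contiguous().view(-1), b.contiguous().view(-1)).view(64, 64)
assert torch.allclose(c, a + b)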
Writing a Benchmark¶
The performance of our GPU code can be benchmarked using the torch.cuda.Event(enable_timing=True) wrapper. Below is a simple function that benchmarks rep runs of our kernel after warmup "cold" runs.
[11]:
# We now want to benchmark the performance of `add`
# against that of PyTorch for increasing vector sizes
def do_bench(fn, warmup = 10, rep = 50):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    torch.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    torch.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms
We can now benchmark our custom op on vectors of increasing size to get a sense of how it performs.
[15]:
for N in [2**i for i in range(17, 26, 1)]:
    x = torch.rand(N, device='cuda')
    y = torch.rand(N, device='cuda')
    triton_ms = do_bench(lambda: add(x, y))
    torch_ms = do_bench(lambda: x + y)
    # print the vector size along with the triton and torch execution times (in ms)
    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')
131072 0.020 0.003
262144 0.019 0.004
524288 0.016 0.016
1048576 0.033 0.033
2097152 0.071 0.070
4194304 0.142 0.144
8388608 0.287 0.286
16777216 0.572 0.568
33554432 1.139 1.110
Our op is on par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that PyTorch's latency is much lower for small vectors (3 us vs 18-20 us); this is something we are actively working to reduce.
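For context, element-wise addition is memory-bound, so a useful sanity check is the effective memory bandwidth these timings imply. The helper below is a rough sketch (not part of the tutorial code) that assumes each call reads x and y and writes z, i.e. 3 * N * 4 bytes of traffic per call; plugging in the largest measurement above (N = 33554432) gives roughly 354 GB/s for Triton and 363 GB/s for Torch, consistent with both kernels saturating memory bandwidth at large sizes.

# rough sketch: effective bandwidth in GB/s, assuming 3 * N * 4 bytes moved per call
def gbps(n_elements, time_ms):
    bytes_moved = 3 * n_elements * 4   # read x, read y, write z (fp32)
    return bytes_moved * 1e-9 / (time_ms * 1e-3)

print(f'triton: {gbps(33554432, 1.139):.0f} GB/s, torch: {gbps(33554432, 1.110):.0f} GB/s')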