2021-03-06 14:03:01 -05:00
|
|
|
"""
|
|
|
|
Vector Addition
|
|
|
|
=================
|
2021-03-06 22:04:00 -05:00
|
|
|
In this tutorial, you will write a simple vector addition using Triton and learn about:
|
2021-03-06 17:26:49 -05:00
|
|
|
|
2021-07-22 22:45:19 -07:00
|
|
|
- The basic programming model of Triton
|
|
|
|
- The `triton.jit` decorator, which is used to define Triton kernels.
|
|
|
|
- The best practices for validating and benchmarking your custom ops against native reference implementations
|
2021-03-06 14:03:01 -05:00
|
|
|
"""
|
|
|
|
|
|
|
|
# %%
|
2021-03-06 17:26:49 -05:00
|
|
|
# Compute Kernel
|
2021-03-06 14:03:01 -05:00
|
|
|
# --------------------------
|
|
|
|
|
2021-04-22 10:27:02 -04:00
|
|
|
import torch
|
2021-04-23 17:18:14 -04:00
|
|
|
import triton.language as tl
|
2021-04-22 10:27:02 -04:00
|
|
|
import triton
|
|
|
|
|
2021-03-06 14:03:01 -05:00
|
|
|
|
2021-04-20 22:29:40 -04:00
|
|
|
# Element-wise vector addition kernel: Z[i] = X[i] + Y[i] for 0 <= i < N.
#
# Each program instance (identified by `tl.program_id(0)`) handles one
# contiguous chunk of `meta['BLOCK']` elements starting at pid * BLOCK.
# The launch grid must provide at least cdiv(N, BLOCK) instances to cover
# all N elements. The `add` helper launches exactly that many.
#
# Parameters:
#   X, Y -- pointers to the two input vectors
#   Z    -- pointer to the output vector
#   N    -- number of elements to process
#   meta -- meta-parameters passed as keywords at launch time; must contain
#           'BLOCK', the number of elements handled per program instance
@triton.jit
def _add(
    X,  # *Pointer* to first input vector
    Y,  # *Pointer* to second input vector
    Z,  # *Pointer* to output vector
    N,  # Size of the vector
    **meta  # Meta-parameters for the kernel; must provide 'BLOCK'
):
    pid = tl.program_id(0)
    # Create an offset for the blocks of pointers to be
    # processed by this program instance
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    # Create a mask to guard memory operations against out-of-bounds
    # accesses: the last block extends past N whenever N is not a
    # multiple of BLOCK.
    mask = offsets < N
    # Load x and y; masked-off lanes are not read.
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    # Write back x + y. The store must use the same mask as the loads.
    # Otherwise the final block writes past the end of Z.
    z = x + y
    tl.store(Z + offsets, z, mask=mask)
|
2021-03-06 17:26:49 -05:00
|
|
|
|
2021-03-06 14:03:01 -05:00
|
|
|
|
2021-03-06 22:04:00 -05:00
|
|
|
# %%
|
2021-07-22 22:45:19 -07:00
|
|
|
# Let's also declare a helper function to (1) allocate the output vector
|
|
|
|
# and (2) enqueue the above kernel.
|
2021-04-20 22:29:40 -04:00
|
|
|
|
|
|
|
|
|
|
|
def add(x, y):
    """Return ``x + y`` computed element-wise by the ``_add`` Triton kernel.

    Allocates the output tensor and enqueues ``_add`` on it. The kernel runs
    asynchronously: no ``torch.cuda.synchronize()`` is issued here, so the
    kernel may still be running when the tensor is returned.

    Args:
        x: First input tensor. Must be a contiguous CUDA tensor.
        y: Second input tensor. Must be a contiguous CUDA tensor with the
           same shape as ``x``.

    Returns:
        A new tensor created with ``torch.empty_like(x)`` that receives the
        element-wise sum.

    Raises:
        ValueError: if ``x`` and ``y`` differ in shape, are not both CUDA
            tensors, or are not both contiguous.
    """
    # The kernel addresses memory as a flat array of N elements. That is only
    # correct when both inputs share a shape and are laid out contiguously.
    if x.shape != y.shape:
        raise ValueError(
            f'add: shape mismatch {tuple(x.shape)} vs {tuple(y.shape)}')
    if not (x.is_cuda and y.is_cuda):
        raise ValueError('add: both inputs must be CUDA tensors')
    if not (x.is_contiguous() and y.is_contiguous()):
        raise ValueError('add: both inputs must be contiguous')
    z = torch.empty_like(x)
    # Count every element, not just the first dimension. With `z.shape[0]`,
    # a multi-dimensional input would leave most of the output unwritten.
    N = z.numel()
    if N == 0:
        # Nothing to compute. A grid with zero programs is also not a valid
        # CUDA launch configuration.
        return z
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
    # NOTE:
    #  - each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
    #  - don't forget to pass meta-parameters as keyword arguments
    _add[grid](x, y, z, N, BLOCK=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return z
|
|
|
|
|
2021-03-06 22:04:00 -05:00
|
|
|
|
2021-03-06 14:03:01 -05:00
|
|
|
# %%
|
2021-07-22 22:45:19 -07:00
|
|
|
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
|
2021-03-06 22:04:00 -05:00
|
|
|
|
2021-03-06 14:03:01 -05:00
|
|
|
# Correctness check: run PyTorch's native add and the Triton `add` on the same
# random inputs, print both results and report the largest absolute deviation.
# 98432 is not a multiple of the BLOCK=1024 used by `add`, so the last program
# instance processes a partial block.
torch.manual_seed(0)
size = 98432
x, y = (torch.rand(size, device='cuda') for _ in range(2))
za = x + y  # reference result from PyTorch
zb = add(x, y)  # result from the Triton kernel
for result in (za, zb):
    print(result)
max_abs_diff = (za - zb).abs().max()
print(f'The maximum difference between torch and triton is {max_abs_diff}')
|
|
|
|
|
2021-03-06 22:04:00 -05:00
|
|
|
# %%
|
|
|
|
# Seems like we're good to go!
|
|
|
|
|
2021-03-06 14:03:01 -05:00
|
|
|
# %%
|
2021-03-14 18:49:59 -04:00
|
|
|
# Benchmark
|
|
|
|
# -----------
|
2021-03-06 22:04:00 -05:00
|
|
|
# We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
|
2021-07-22 22:45:19 -07:00
|
|
|
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops
|
2021-03-11 00:29:16 -05:00
|
|
|
# for different problem sizes.
|
|
|
|
|
|
|
|
|
|
|
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # argument names to use as an x-axis for the plot
        x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_names`
        x_log=True,  # x axis is logarithmic
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=["Triton", "Torch"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={},  # values for function arguments not in `x_names` and `line_arg`
    )
)
def benchmark(size, provider):
    """Measure the memory bandwidth of one vector-add implementation.

    Args:
        size: Number of float32 elements in each input vector.
        provider: ``'torch'`` for PyTorch's native ``x + y`` or ``'triton'``
            for the ``add`` helper.

    Returns:
        A ``(gbps, gbps_low, gbps_high)`` tuple. It holds the bandwidth for
        the ``ms`` timing reported by ``do_bench``, then the bandwidths for
        its ``max_ms`` and ``min_ms`` timings. The slowest time gives the
        lowest bandwidth.

    Raises:
        ValueError: if ``provider`` is not one of the values above.
    """
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
    else:
        # Previously this fell through to an UnboundLocalError on `ms`.
        raise ValueError(f'unknown provider: {provider!r}')
    # Each element is read once from x and once from y, and written once to
    # the output: three accesses of `element_size()` bytes each.
    # For float32 this equals the former hard-coded 12 * size.
    bytes_moved = 3 * x.numel() * x.element_size()
    # bytes / (t ms * 1e-3 s/ms) / 1e9 bytes/GB == bytes / t * 1e-6
    gbps = lambda t: bytes_moved / t * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)
|
2021-03-06 22:04:00 -05:00
|
|
|
|
|
|
|
|
|
|
|
# %%
|
2021-03-11 00:29:16 -05:00
|
|
|
# We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or
|
|
|
|
# `save_path='/path/to/results/'` to save them to disk along with raw CSV data.
|
2021-07-22 22:45:19 -07:00
|
|
|
# Execute the benchmark for every configuration declared in its `perf_report`
# decorator. `print_data=True` requests the collected results be printed, and
# `show_plots=True` displays the resulting plot.
benchmark.run(print_data=True, show_plots=True)
|