[DOCS] Improve tutorial readability (#185)
This commit is contained in:
@@ -13,31 +13,37 @@ In this tutorial, you will write a simple vector addition using Triton and learn
|
||||
# --------------------------
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(
|
||||
X, # *Pointer* to first input vector
|
||||
Y, # *Pointer* to second input vector
|
||||
Z, # *Pointer* to output vector
|
||||
N, # Size of the vector
|
||||
**meta # Optional meta-parameters for the kernel
|
||||
def add_kernel(
|
||||
x_ptr, # *Pointer* to first input vector
|
||||
y_ptr, # *Pointer* to second input vector
|
||||
output_ptr, # *Pointer* to output vector
|
||||
n_elements, # Size of the vector
|
||||
**meta, # Optional meta-parameters for the kernel
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
# Create an offset for the blocks of pointers to be
|
||||
# processed by this program instance
|
||||
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
|
||||
# Create a mask to guard memory operations against
|
||||
# out-of-bounds accesses
|
||||
mask = offsets < N
|
||||
# Load x
|
||||
x = tl.load(X + offsets, mask=mask)
|
||||
y = tl.load(Y + offsets, mask=mask)
|
||||
# Write back x + y
|
||||
z = x + y
|
||||
tl.store(Z + offsets, z)
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
|
||||
# There are multiple 'program's processing different data. We identify which program
|
||||
# we are here
|
||||
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
|
||||
# This program will process inputs that are offset from the initial data.
|
||||
# for instance, if you had a vector of length 256 and block_size of 64, the programs
|
||||
# would each access the elements [0:64, 64:128, 128:192, 192:256].
|
||||
# Note that offsets is a list of pointers
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
# Create a mask to guard memory operations against out-of-bounds accesses
|
||||
mask = offsets < n_elements
|
||||
# Load x and y from DRAM, masking out any extar elements in case the input is not a
|
||||
# multiple of the block size
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
# Write x + y back to DRAM
|
||||
tl.store(output_ptr + offsets, output)
|
||||
|
||||
|
||||
# %%
|
||||
@@ -45,20 +51,23 @@ def _add(
|
||||
# and (2) enqueue the above kernel with appropriate grid/block sizes.
|
||||
|
||||
|
||||
def add(x, y):
|
||||
z = torch.empty_like(x)
|
||||
N = z.shape[0]
|
||||
def add(x: torch.Tensor, y: torch.Tensor):
|
||||
# We need to preallocate the output
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_cuda and y.is_cuda and output.is_cuda
|
||||
n_elements = output.shape[0]
|
||||
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
|
||||
# In this case, we use a 1D grid where the size is the number of blocks
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
# NOTE:
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keywords arguments
|
||||
_add[grid](x, y, z, N, BLOCK=1024)
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
|
||||
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously at this point.
|
||||
return z
|
||||
return output
|
||||
|
||||
|
||||
# %%
|
||||
@@ -68,11 +77,14 @@ torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
y = torch.rand(size, device='cuda')
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(za)
|
||||
print(zb)
|
||||
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
|
||||
output_torch = x + y
|
||||
output_triton = add(x, y)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
# %%
|
||||
# Seems like we're good to go!
|
||||
@@ -88,15 +100,17 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['size'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
|
||||
x_vals=[
|
||||
2 ** i for i in range(12, 28, 1)
|
||||
], # different possible values for `x_name`
|
||||
x_log=True, # x axis is logarithmic
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['triton', 'torch'], # possible values for `line_arg`
|
||||
line_names=["Triton", "Torch"], # label name for the lines
|
||||
line_names=['Triton', 'Torch'], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={} # values for function arguments not in `x_names` and `y_name`
|
||||
ylabel='GB/s', # label name for the y-axis
|
||||
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(size, provider):
|
||||
@@ -113,4 +127,4 @@ def benchmark(size, provider):
|
||||
# %%
|
||||
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
|
||||
# `save_path='/path/to/results/' to save them to disk along with raw CSV data
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
|
Reference in New Issue
Block a user