[DOCS] Improve tutorial readability (#185)

This commit is contained in:
Nicholas Joseph
2021-08-05 12:27:06 -04:00
committed by GitHub
parent 3cb77aa126
commit 23c71538fc

View File

@@ -13,31 +13,37 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# -------------------------- # --------------------------
import torch import torch
import triton.language as tl
import triton import triton
import triton.language as tl
@triton.jit @triton.jit
def _add( def add_kernel(
X, # *Pointer* to first input vector x_ptr, # *Pointer* to first input vector
Y, # *Pointer* to second input vector y_ptr, # *Pointer* to second input vector
Z, # *Pointer* to output vector output_ptr, # *Pointer* to output vector
N, # Size of the vector n_elements, # Size of the vector
**meta # Optional meta-parameters for the kernel **meta, # Optional meta-parameters for the kernel
): ):
pid = tl.program_id(0) BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# Create an offset for the blocks of pointers to be # There are multiple 'program's processing different data. We identify which program
# processed by this program instance # we are here
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK']) pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# Create a mask to guard memory operations against # This program will process inputs that are offset from the initial data.
# out-of-bounds accesses # for instance, if you had a vector of length 256 and block_size of 64, the programs
mask = offsets < N # would each access the elements [0:64, 64:128, 128:192, 192:256].
# Load x # Note that offsets is a list of pointers
x = tl.load(X + offsets, mask=mask) block_start = pid * BLOCK_SIZE
y = tl.load(Y + offsets, mask=mask) offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Write back x + y # Create a mask to guard memory operations against out-of-bounds accesses
z = x + y mask = offsets < n_elements
tl.store(Z + offsets, z) # Load x and y from DRAM, masking out any extar elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output)
# %% # %%
@@ -45,20 +51,23 @@ def _add(
# and (2) enqueue the above kernel with appropriate grid/block sizes. # and (2) enqueue the above kernel with appropriate grid/block sizes.
def add(x, y): def add(x: torch.Tensor, y: torch.Tensor):
z = torch.empty_like(x) # We need to preallocate the output
N = z.shape[0] output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.shape[0]
# The SPMD launch grid denotes the number of kernel instances that run in parallel. # The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int] # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), ) # In this case, we use a 1D grid where the size is the number of blocks
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE: # NOTE:
# - each torch.tensor object is implicitly converted into a pointer to its first element. # - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments # - don't forget to pass meta-parameters as keywords arguments
_add[grid](x, y, z, N, BLOCK=1024) add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point. # running asynchronously at this point.
return z return output
# %% # %%
@@ -68,11 +77,14 @@ torch.manual_seed(0)
size = 98432 size = 98432
x = torch.rand(size, device='cuda') x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda') y = torch.rand(size, device='cuda')
za = x + y output_torch = x + y
zb = add(x, y) output_triton = add(x, y)
print(za) print(output_torch)
print(zb) print(output_triton)
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}') print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
# %% # %%
# Seems like we're good to go! # Seems like we're good to go!
@@ -88,15 +100,17 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['size'], # argument names to use as an x-axis for the plot x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name` x_vals=[
2 ** i for i in range(12, 28, 1)
], # different possible values for `x_name`
x_log=True, # x axis is logarithmic x_log=True, # x axis is logarithmic
line_arg='provider', # argument name whose value corresponds to a different line in the plot line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch'], # possible values for `line_arg` line_vals=['triton', 'torch'], # possible values for `line_arg`
line_names=["Triton", "Torch"], # label name for the lines line_names=['Triton', 'Torch'], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="GB/s", # label name for the y-axis ylabel='GB/s', # label name for the y-axis
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot. plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
args={} # values for function arguments not in `x_names` and `y_name` args={}, # values for function arguments not in `x_names` and `y_name`
) )
) )
def benchmark(size, provider): def benchmark(size, provider):