From 23c71538fcf3055a30ec5aea40e1b6874c6f0e6e Mon Sep 17 00:00:00 2001
From: Nicholas Joseph
Date: Thu, 5 Aug 2021 12:27:06 -0400
Subject: [PATCH] [DOCS] Improve tutorial readability (#185)

---
 python/tutorials/01-vector-add.py | 98 +++++++++++++++++++++-----------------
 1 file changed, 56 insertions(+), 42 deletions(-)

diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index ad6303858..5ac5a1225 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -13,31 +13,37 @@ In this tutorial, you will write a simple vector addition using Triton and learn
 # --------------------------
 
 import torch
-import triton.language as tl
 import triton
+import triton.language as tl
 
 
 @triton.jit
-def _add(
-    X,  # *Pointer* to first input vector
-    Y,  # *Pointer* to second input vector
-    Z,  # *Pointer* to output vector
-    N,  # Size of the vector
-    **meta  # Optional meta-parameters for the kernel
+def add_kernel(
+    x_ptr,  # *Pointer* to first input vector
+    y_ptr,  # *Pointer* to second input vector
+    output_ptr,  # *Pointer* to output vector
+    n_elements,  # Size of the vector
+    **meta,  # Optional meta-parameters for the kernel
 ):
-    pid = tl.program_id(0)
-    # Create an offset for the blocks of pointers to be
-    # processed by this program instance
-    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
-    # Create a mask to guard memory operations against
-    # out-of-bounds accesses
-    mask = offsets < N
-    # Load x
-    x = tl.load(X + offsets, mask=mask)
-    y = tl.load(Y + offsets, mask=mask)
-    # Write back x + y
-    z = x + y
-    tl.store(Z + offsets, z)
+    BLOCK_SIZE = meta['BLOCK_SIZE']  # How many inputs each program should process
+    # There are multiple 'programs' processing different data. We identify which program
+    # we are here
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so the axis is 0
+    # This program will process inputs that are offset from the initial data.
+    # For instance, if you had a vector of length 256 and a BLOCK_SIZE of 64, the programs
+    # would each access the elements [0:64, 64:128, 128:192, 192:256].
+    # Note that offsets is a list of indices
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Create a mask to guard memory operations against out-of-bounds accesses
+    mask = offsets < n_elements
+    # Load x and y from DRAM, masking out any extra elements in case the input is not a
+    # multiple of the block size
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    # Write x + y back to DRAM, masking again to guard against out-of-bounds writes
+    tl.store(output_ptr + offsets, output, mask=mask)
 
 
 # %%
@@ -45,20 +51,23 @@ def _add(
 # and (2) enqueue the above kernel with appropriate grid/block sizes.
 
 
-def add(x, y):
-    z = torch.empty_like(x)
-    N = z.shape[0]
+def add(x: torch.Tensor, y: torch.Tensor):
+    # We need to preallocate the output
+    output = torch.empty_like(x)
+    assert x.is_cuda and y.is_cuda and output.is_cuda
+    n_elements = output.shape[0]
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
-    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
+    # In this case, we use a 1D grid where the size is the number of blocks
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
     # NOTE:
     # - each torch.tensor object is implicitly converted into a pointer to its first element.
-    # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
-    # - don't forget to pass meta-parameters as keywords arguments
-    _add[grid](x, y, z, N, BLOCK=1024)
-    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
+    # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
+    # - don't forget to pass meta-parameters as keyword arguments
+    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+    # We return a handle to output but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
     # running asynchronously at this point.
-    return z
+    return output
 
 
 # %%
@@ -68,11 +77,14 @@ torch.manual_seed(0)
 size = 98432
 x = torch.rand(size, device='cuda')
 y = torch.rand(size, device='cuda')
-za = x + y
-zb = add(x, y)
-print(za)
-print(zb)
-print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
+output_torch = x + y
+output_triton = add(x, y)
+print(output_torch)
+print(output_triton)
+print(
+    f'The maximum difference between torch and triton is '
+    f'{torch.max(torch.abs(output_torch - output_triton))}'
+)
 
 # %%
 # Seems like we're good to go!
@@ -88,15 +100,17 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=['size'],  # argument names to use as an x-axis for the plot
-        x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_name`
+        x_vals=[
+            2 ** i for i in range(12, 28, 1)
+        ],  # different possible values for `x_name`
         x_log=True,  # x axis is logarithmic
         line_arg='provider',  # argument name whose value corresponds to a different line in the plot
         line_vals=['triton', 'torch'],  # possible values for `line_arg`
-        line_names=["Triton", "Torch"],  # label name for the lines
+        line_names=['Triton', 'Torch'],  # label name for the lines
         styles=[('blue', '-'), ('green', '-')],  # line styles
-        ylabel="GB/s",  # label name for the y-axis
-        plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
-        args={}  # values for function arguments not in `x_names` and `y_name`
+        ylabel='GB/s',  # label name for the y-axis
+        plot_name='vector-add-performance',  # name for the plot. Used also as a file name for saving the plot.
+        args={},  # values for function arguments not in `x_names` and `y_name`
     )
 )
 def benchmark(size, provider):
@@ -113,4 +127,4 @@ def benchmark(size, provider):
 # %%
-# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
-# `save_path='/path/to/results/' to save them to disk along with raw CSV data
-benchmark.run(print_data=True, show_plots=True)
\ No newline at end of file
+# We can now run the decorated function above. Pass `print_data=True` to see the performance numbers, `show_plots=True` to plot them, and/or
+# `save_path='/path/to/results/'` to save them to disk along with raw CSV data
+benchmark.run(print_data=True, show_plots=True)
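
For anyone wanting to sanity-check the launch-grid arithmetic in the patch, here is a minimal plain-Python sketch. `cdiv` mirrors `triton.cdiv`, and the sizes are the ones the tutorial itself uses; everything else is illustrative:

    def cdiv(x, y):
        # Ceiling division: the number of BLOCK_SIZE-sized programs needed to cover x elements
        return (x + y - 1) // y

    n_elements = 98432  # vector size from the tutorial's unit test
    BLOCK_SIZE = 1024   # meta-parameter passed at the kernel launch site
    print(cdiv(n_elements, BLOCK_SIZE))  # 97

With these numbers, programs 0 through 95 each process a full 1024-element block (elements 0 to 98303), and program 96 processes the remaining 128 elements; `mask = offsets < n_elements` keeps that last, partial program from reading or writing out of bounds.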