{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Vector Addition\nIn this tutorial, you will write a simple vector addition using Triton and learn about:\n\n- The basic programming model used by Triton\n- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.\n- The best practices for validating and benchmarking custom ops against native reference implementations\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute Kernel\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport triton\n\n\n@triton.jit\ndef _add(\n X, # *Pointer* to first input vector\n Y, # *Pointer* to second input vector\n Z, # *Pointer* to output vector\n N, # Size of the vector\n **meta # Optional meta-parameters for the kernel\n):\n pid = triton.program_id(0)\n # Create an offset for the blocks of pointers to be\n # processed by this program instance\n offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])\n # Create a mask to guard memory operations against\n # out-of-bounds accesses\n mask = offsets < N\n # Load x\n x = triton.load(X + offsets, mask=mask)\n y = triton.load(Y + offsets, mask=mask)\n # Write back x + y\n z = x + y\n triton.store(Z + offsets, z)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also declara a helper function that handles allocating the output vector\nand enqueueing the kernel.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def add(x, y):\n z = torch.empty_like(x)\n N = z.shape[0]\n # The SPMD launch grid denotes the number of kernel instances that should execute in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )\n # NOTE:\n # - torch.tensor objects are implicitly converted to pointers to their first element.\n # - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n _add[grid](x, y, z, N, BLOCK=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously.\n return z" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Seems like we're good to go!\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.\nTo make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.\nfor different problem sizes.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['size'], # argument names to use as an x-axis for the plot\n x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`\n x_log=True, # x axis is logarithmic\n y_name='provider', # argument name whose value corresponds to a different line in the plot\n y_vals=['torch', 'triton'], # possible keys for `y_name`\n y_lines=[\"Torch\", \"Triton\"], # label name for the lines\n ylabel=\"GB/s\", # label name for the y-axis\n plot_name=\"vector-add-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={} # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(size, provider):\n x = torch.rand(size, device='cuda', dtype=torch.float32)\n y = torch.rand(size, device='cuda', dtype=torch.float32)\n if provider == 'torch':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))\n gbps = lambda ms: 12 * size / ms * 1e-6\n return gbps(ms), gbps(max_ms), gbps(min_ms)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or\n`save_path='/path/to/results/' to save them to disk along with raw CSV data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "benchmark.run(show_plots=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 0 }