{ "cells": [ { "cell_type": "markdown", "id": "induced-zoning", "metadata": {}, "source": [ "# Getting Started" ] }, { "cell_type": "markdown", "id": "median-malaysia", "metadata": {}, "source": [ "In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n", "* The basic syntax of the Triton programming language\n", "* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n", "* The best practices for validating and benchmarking custom ops against native reference implementations" ] }, { "cell_type": "markdown", "id": "identical-conditions", "metadata": {}, "source": [ "# Writing the Compute Kernel" ] }, { "cell_type": "markdown", "id": "collectible-belle", "metadata": {}, "source": [ "Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n", "\n", "\n", "```c\n", "__global__ void add(float* z, float* x, float* y, int N){\n", " // The `get_program_id(i)` returns the i-th coordinate\n", " // of the program in the overaching SPMD context\n", " // (a.k.a launch grid). This is what allows us to process\n", " // different chunks of data in parallel.\n", " // For those similar with CUDA, `get_program_id({0,1,2})`\n", " // is similar to blockIdx.{x,y,z}\n", " int pid = get_program_id(0);\n", " // In Triton, arrays are first-class citizen. In other words,\n", " // they are primitives data-types and are -- contrary to C and\n", " // CUDA -- not implemented as pointers to contiguous chunks of\n", " // memory.\n", " // In the few lines below, we create an array of `BLOCK` pointers\n", " // whose memory values are, e.g.:\n", " // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n", " // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n", " int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n", " float* pz [BLOCK] = z + offset;\n", " float* px [BLOCK] = x + offset;\n", " float* py [BLOCK] = y + offset;\n", " // Simple element-wise control-flow for load/store operations can\n", " // be achieved using the the ternary operator `cond ? val_true : val_false`\n", " // or the conditional dereferencing operator `*?(cond)ptr\n", " // Here, we make sure that we do not access memory out-of-bounds when we\n", " // write-back `z`\n", " bool check[BLOCK] = offset < N;\n", " *?(check)pz = *?(check)px + *?(check)py;\n", "}\n", "```\n", "\n", "The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)." ] }, { "cell_type": "markdown", "id": "forbidden-wednesday", "metadata": {}, "source": [ "# Writing the Torch bindings" ] }, { "cell_type": "markdown", "id": "numerical-agency", "metadata": {}, "source": [ "The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. 
 { "cell_type": "markdown", "id": "forbidden-wednesday", "metadata": {}, "source": [ "# Writing the Torch bindings" ] },
 { "cell_type": "markdown", "id": "numerical-agency", "metadata": {}, "source": [
 "The main point of contact between Triton and Torch is the `triton.kernel` class. It allows you to transform the above C-like function into a callable Python object that can be used to modify `torch.tensor` objects.\n",
 "\n",
 "To create a `triton.kernel`, you only need three things:\n",
 "* `source: string`: the source-code of the kernel you want to create\n",
 "* `device: torch.device`: the device you want to compile this code for\n",
 "* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
 "\n",
 "Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see the `_kernels` variable below). This also makes it possible to choose the compilation device dynamically, based on the type of the operator's inputs." ] },
 { "cell_type": "code", "execution_count": 10, "id": "sporting-keyboard", "metadata": {}, "outputs": [], "source": [
 "import torch\n",
 "import triton\n",
 "\n",
 "# source-code for the Triton compute kernel\n",
 "# here we just copy-paste the above code without the extensive comments.\n",
 "# you may prefer to store it in a .c file and load it from there instead.\n",
 "_src = \"\"\"\n",
 "__global__ void add(float* z, float* x, float* y, int N){\n",
 "    // program id\n",
 "    int pid = get_program_id(0);\n",
 "    // create arrays of pointers\n",
 "    int    offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
 "    float* pz[BLOCK] = z + offset;\n",
 "    float* px[BLOCK] = x + offset;\n",
 "    float* py[BLOCK] = y + offset;\n",
 "    // bounds checking\n",
 "    bool check[BLOCK] = offset < N;\n",
 "    // write-back\n",
 "    *?(check)pz = *?(check)px + *?(check)py;\n",
 "}\n",
 "\"\"\"\n",
 "\n",
 "# This function returns a callable `triton.kernel` object\n",
 "# created from the above source code.\n",
 "# For portability, we maintain a cache of kernels for different `torch.device`s.\n",
 "# We compile the kernel with -DBLOCK=1024.\n",
 "_kernels = dict()\n",
 "def make_add_kernel(device):\n",
 "    if device not in _kernels:\n",
 "        defines = {'BLOCK': 1024}\n",
 "        _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
 "    return _kernels[device]\n",
 "\n",
 "# This is a standard torch custom autograd Function;\n",
 "# the only difference is that we can now use the above kernel\n",
 "# in the `forward` and `backward` functions.\n",
 "class _add(torch.autograd.Function):\n",
 "\n",
 "    @staticmethod\n",
 "    def forward(ctx, x, y):\n",
 "        # constraints of the op\n",
 "        assert x.dtype == torch.float32\n",
 "        # *allocate output*\n",
 "        z = torch.empty_like(x)\n",
 "        # *create launch grid*:\n",
 "        # this is a function which takes compilation parameters `opt`\n",
 "        # as input and returns a tuple of ints (i.e., the launch grid) for the kernel.\n",
 "        # triton.cdiv is a shortcut for ceil division:\n",
 "        # triton.cdiv(a, b) = (a + b - 1) // b\n",
 "        N = z.shape[0]\n",
 "        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
 "        # *launch kernel*:\n",
 "        # pointers to the data of torch tensors can be retrieved with\n",
 "        # the `.data_ptr()` method\n",
 "        kernel = make_add_kernel(z.device)\n",
 "        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)\n",
 "        return z\n",
 "\n",
 "# Just like with standard PyTorch ops,\n",
 "# we use the `.apply` method to create a\n",
 "# callable object for our function.\n",
 "add = _add.apply" ] },
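 { "cell_type": "markdown", "id": "backward-sketch", "metadata": {}, "source": [
 "Note that `_add` only implements `forward`: this is enough to call the op, but autograd cannot differentiate through it yet. Since `z = x + y`, the gradient with respect to each input is just the incoming gradient, so a minimal `backward` that could be added to the `_add` class above (a sketch, not part of the code in this tutorial) would look like this:\n",
 "\n",
 "```python\n",
 "    @staticmethod\n",
 "    def backward(ctx, dz):\n",
 "        # d(x + y)/dx = d(x + y)/dy = 1, so the upstream gradient\n",
 "        # simply passes through to both inputs of forward\n",
 "        return dz, dz\n",
 "```" ] },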
 { "cell_type": "markdown", "id": "separated-polyester", "metadata": {}, "source": [ "At this point, `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!" ] },
 { "cell_type": "markdown", "id": "exclusive-salvation", "metadata": {}, "source": [ "# Writing a Unit Test" ] },
 { "cell_type": "code", "execution_count": 9, "id": "supported-ribbon", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n", "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n", "The maximum difference between torch and triton is 0.0\n" ] } ], "source": [
 "torch.manual_seed(0)\n",
 "x = torch.rand(98432, device='cuda')\n",
 "y = torch.rand(98432, device='cuda')\n",
 "za = x + y\n",
 "zb = add(x, y)\n",
 "print(za)\n",
 "print(zb)\n",
 "print(f'The maximum difference between torch and triton is '\n",
 "      f'{torch.max(torch.abs(za - zb))}')" ] },
 { "cell_type": "markdown", "id": "otherwise-canadian", "metadata": {}, "source": [ "Seems to work!" ] },
 { "cell_type": "markdown", "id": "polished-australia", "metadata": {}, "source": [ "# Writing a Benchmark" ] },
 { "cell_type": "markdown", "id": "historic-glass", "metadata": {}, "source": [ "The performance of our GPU code can be benchmarked using CUDA events, i.e. `torch.cuda.Event(enable_timing=True)`. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs." ] },
 { "cell_type": "code", "execution_count": 11, "id": "strange-luxembourg", "metadata": {}, "outputs": [], "source": [
 "# We now want to benchmark the performance of `add`\n",
 "# against that of PyTorch for increasing vector sizes.\n",
 "def do_bench(fn, warmup = 10, rep = 50):\n",
 "    start_event = torch.cuda.Event(enable_timing=True)\n",
 "    end_event   = torch.cuda.Event(enable_timing=True)\n",
 "    # run once to trigger just-in-time compilation, then warm up\n",
 "    fn()\n",
 "    for i in range(warmup):\n",
 "        fn()\n",
 "    torch.cuda.synchronize()\n",
 "    start_event.record()\n",
 "    for i in range(rep):\n",
 "        fn()\n",
 "    end_event.record()\n",
 "    torch.cuda.synchronize()\n",
 "    time_ms = start_event.elapsed_time(end_event) / rep\n",
 "    return time_ms" ] },
 { "cell_type": "markdown", "id": "hairy-claim", "metadata": {}, "source": [ "We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does." ] },
 { "cell_type": "code", "execution_count": 15, "id": "pleasant-valley", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "131072 0.020 0.003\n", "262144 0.019 0.004\n", "524288 0.016 0.016\n", "1048576 0.033 0.033\n", "2097152 0.071 0.070\n", "4194304 0.142 0.144\n", "8388608 0.287 0.286\n", "16777216 0.572 0.568\n", "33554432 1.139 1.110\n" ] } ], "source": [
 "for N in [2**i for i in range(17, 26, 1)]:\n",
 "    x = torch.rand(N, device='cuda')\n",
 "    y = torch.rand(N, device='cuda')\n",
 "    triton_ms = do_bench(lambda: add(x, y))\n",
 "    torch_ms  = do_bench(lambda: x + y)\n",
 "    # print the vector size and the timings (in ms) of triton and torch\n",
 "    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')" ] },
 { "cell_type": "markdown", "id": "juvenile-supplement", "metadata": {}, "source": [ "Our op is on par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us); this is overhead that we are actively working to reduce."
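 ] }, { "cell_type": "markdown", "id": "bandwidth-sketch", "metadata": {}, "source": [
 "Since vector addition is memory-bound, a useful sanity check is to convert the timings above into effective memory bandwidth and compare against the peak bandwidth of your GPU. Each call reads `x` and `y` and writes `z`, i.e. it moves `3 * N * 4` bytes of `float32` data. The helper below is a sketch (it is not part of Triton, and the sample numbers are taken from the table above):\n",
 "\n",
 "```python\n",
 "def effective_bandwidth_gbps(N, time_ms):\n",
 "    # 3 float32 tensors of N elements touched per call: 2 reads + 1 write\n",
 "    bytes_moved = 3 * N * 4\n",
 "    return bytes_moved / (time_ms * 1e-3) / 1e9\n",
 "\n",
 "# e.g., the largest size measured above (N = 2**25)\n",
 "print(f'{effective_bandwidth_gbps(33554432, 1.139):.0f} GB/s')\n",
 "```"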
] }, { "cell_type": "code", "execution_count": null, "id": "agreed-backing", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 5 }