{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Vector Addition\nIn this tutorial, we will see how to write a simple, high-performance vector-addition kernel using Triton. You will learn:\n* The basic syntax of the Triton programming language\n* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n* The best practices for validating and benchmarking custom ops against native reference implementations\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Writing the Compute Kernel\n\nEach compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel\non different chunks of data (see the `Single Program, Multiple Data <https://en.wikipedia.org/wiki/SPMD>`_\nprogramming model for more details).\n\n .. code-block:: C\n\n    __global__ void add(float* z, float* x, float* y, int N){\n        // `get_program_id(i)` returns the i-th coordinate\n        // of the program in the overarching SPMD context\n        // (a.k.a. the launch grid). This is what allows us to process\n        // different chunks of data in parallel.\n        // For those familiar with CUDA, `get_program_id({0,1,2})`\n        // is similar to blockIdx.{x,y,z}\n        int pid = get_program_id(0);\n        // In Triton, arrays are first-class citizens. In other words,\n        // they are primitive data types and are -- contrary to C and\n        // CUDA -- not implemented as pointers to contiguous chunks of\n        // memory.\n        // In the few lines below, we create an array of `BLOCK` pointers\n        // whose memory values are, e.g.:\n        // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n        // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n        int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n        float* pz[BLOCK] = z + offset;\n        float* px[BLOCK] = x + offset;\n        float* py[BLOCK] = y + offset;\n        // Simple element-wise control-flow for load/store operations can\n        // be achieved using the ternary operator `cond ? val_true : val_false`\n        // or the conditional dereferencing operator `*?(cond)ptr`.\n        // Here, we make sure that we do not access memory out-of-bounds when we\n        // write back `z`\n        bool check[BLOCK] = offset < N;\n        *?(check)pz = *?(check)px + *?(check)py;\n    }\n\nThe existence of arrays as a primitive data type in Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.\n\n"
]
},
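{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the SPMD model concrete, here is a small emulation -- in plain PyTorch, written for this walkthrough and not part of the Triton API -- of what a single program instance of the above kernel computes, and of how the launch grid covers the whole vector.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n\n\ndef add_program_instance(z, x, y, pid, BLOCK=4):\n    # emulate one SPMD program instance of the kernel above\n    # `int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK` in Triton:\n    offset = pid * BLOCK + torch.arange(BLOCK)\n    # `bool check[BLOCK] = offset < N`:\n    idx = offset[offset < z.shape[0]]\n    # `*?(check)pz = *?(check)px + *?(check)py`:\n    z[idx] = x[idx] + y[idx]\n\n\n# the \"launch grid\": one program instance per BLOCK-sized chunk\nN = 10\nx, y, z = torch.rand(N), torch.rand(N), torch.zeros(N)\nfor pid in range((N + 3) // 4):  # ceil(N / BLOCK) instances\n    add_program_instance(z, x, y, pid)\nassert torch.allclose(z, x + y)"
]
},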
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Writing the Torch Bindings\nThe bridge between Triton and Torch is the `triton.kernel` class. It allows you to transform the above C-like function into a callable Python object that can be used to modify `torch.tensor` objects.\n\nTo create a `triton.kernel`, you only need three things:\n- `source: string`: the source code of the kernel you want to create\n- `device: torch.device`: the device you want to compile this code for\n- `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\nimport triton"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is the source code for the Triton compute kernel.\nHere we just copy-paste the above code without the extensive comments;\nyou may prefer to store it in a `.c` file and load it from there instead.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"_src = \"\"\"\n__global__ void add(float* z, float* x, float* y, int N){\n    // program id\n    int pid = get_program_id(0);\n    // create arrays of pointers\n    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n    float* pz[BLOCK] = z + offset;\n    float* px[BLOCK] = x + offset;\n    float* py[BLOCK] = y + offset;\n    // bounds checking\n    bool check[BLOCK] = offset < N;\n    // write-back\n    *?(check)pz = *?(check)px + *?(check)py;\n}\n\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The function below returns a callable `triton.kernel` object created from the above source code.\nFor portability, we maintain a cache of kernels for each `torch.device` encountered.\nWe compile the kernel with `-DBLOCK=1024`.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def make_add_kernel(device):\n    # per-device cache of compiled kernels\n    cache = make_add_kernel.cache\n    if device not in cache:\n        defines = {'BLOCK': 1024}\n        cache[device] = triton.kernel(_src, device=device, defines=defines)\n    return cache[device]\n\n\nmake_add_kernel.cache = dict()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a standard Torch custom autograd Function.\nThe only difference is that we can now use the above kernel\nin the `forward` and `backward` functions.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class _add(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, y):\n        # constraints of the op\n        assert x.dtype == torch.float32 and y.dtype == torch.float32\n        # *allocate output*\n        z = torch.empty_like(x)\n        # *create launch grid*:\n        # this is a function which takes compilation parameters `opt`\n        # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n        # triton.cdiv is a shortcut for ceil division:\n        # triton.cdiv(a, b) = (a + b - 1) // b\n        N = z.shape[0]\n        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n        # *launch kernel*:\n        # pointers to the data of torch tensors can be retrieved with\n        # the `.data_ptr()` method\n        kernel = make_add_kernel(z.device)\n        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)\n        return z\n\n\n# Just like with standard PyTorch ops, we use the `.apply` method to create a\n# callable object for our function\nadd = _add.apply"
]
},
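{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tutorial only implements `forward`. As a hypothetical sketch (an addition for this walkthrough, not part of the original code), here is what a `backward` could look like for this op: since `z = x + y`, the gradient of `z` with respect to each input is 1, so the upstream gradient can be returned unchanged for both `x` and `y`.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class _add_with_grad(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, y):\n        # same launch sequence as `_add.forward` above\n        z = torch.empty_like(x)\n        N = z.shape[0]\n        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n        kernel = make_add_kernel(z.device)\n        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)\n        return z\n\n    @staticmethod\n    def backward(ctx, dz):\n        # d(x + y)/dx = d(x + y)/dy = 1, so the upstream gradient\n        # flows through unchanged to both inputs\n        return dz, dz"
]
},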
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Writing a Unit Test\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"torch.manual_seed(0)\nx = torch.rand(98432, device='cuda')\ny = torch.rand(98432, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')"
]
},
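{
"cell_type": "markdown",
"metadata": {},
"source": [
"The print-based check above is easy to eyeball. As an optional, stricter variant (an addition for this walkthrough, not in the original tutorial), we can turn it into a hard assertion:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# fail loudly if the Triton result drifts from the PyTorch reference\nassert torch.allclose(za, zb), 'Triton and Torch results differ'"
]
},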
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Writing a Benchmark\nWe can now benchmark our custom op on vectors of increasing size to get a sense of how it performs relative to PyTorch.\nSince vector addition reads `x` and `y` and writes `z`, each element moves 3 * 4 = 12 bytes of data, which lets us convert the measured times into achieved memory bandwidth.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"warmup = 10\nrep = 200\nfor N in [2**i for i in range(17, 26, 1)]:\n    x = torch.rand(N, device='cuda')\n    y = torch.rand(N, device='cuda')\n    triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep)\n    torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)\n    # 12 bytes per element (read x, read y, write z), converted to GB/s\n    gbps = lambda ms: 12 * N * 1e-9 / (ms * 1e-3)\n    # print the performance of triton and torch as well as the achieved bandwidth\n    print(f'{N} triton: {triton_ms:.3f} ms ({gbps(triton_ms):.2f} GB/s) '\n          f'torch: {torch_ms:.3f} ms ({gbps(torch_ms):.2f} GB/s)')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}