[PYTHON][TESTS][DOC] Various improvements to the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency:
  > `.data_ptr()` must now be passed as a kernel argument; no more implicit conversion from `torch.tensor`.
  > Compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`.
  > `torch.device` must now be passed explicitly to `triton.kernel` (no longer inferred from `torch.tensor` arguments).
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
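A minimal sketch of the new calling convention (illustrative only; `_src` and the `BLOCK` macro are as defined in the vector-add tutorial added below):

```python
import torch
import triton

# device is now passed explicitly; compilation options are plain attributes (opt.BLOCK)
kernel = triton.kernel(_src, device=torch.device('cuda'), defines={'BLOCK': 1024})

x = torch.rand(4096, device='cuda')
y = torch.rand(4096, device='cuda')
z = torch.empty_like(x)
grid = lambda opt: (triton.cdiv(x.numel(), opt.BLOCK), )  # opt.d('BLOCK') becomes opt.BLOCK
# raw pointers are passed via .data_ptr() -- no implicit torch.tensor conversion
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), x.numel(), grid=grid)
```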
python/tutorials/01-vector-add.ipynb | 335 (new file)
@@ -0,0 +1,335 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "induced-zoning",
"metadata": {},
"source": [
"# Getting Started"
]
},
{
"cell_type": "markdown",
"id": "median-malaysia",
"metadata": {},
"source": [
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
"* The basic syntax of the Triton programming language\n",
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
"* The best practices for validating and benchmarking custom ops against native reference implementations"
]
},
{
"cell_type": "markdown",
"id": "identical-conditions",
"metadata": {},
"source": [
"# Writing the Compute Kernel"
]
},
{
"cell_type": "markdown",
"id": "collectible-belle",
"metadata": {},
"source": [
"Each compute kernel is declared using the `__global__` attribute, and is executed many times in parallel on different chunks of data (see the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
"\n",
"\n",
"```c\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
"    // `get_program_id(i)` returns the i-th coordinate\n",
"    // of the program in the overarching SPMD context\n",
"    // (a.k.a. launch grid). This is what allows us to process\n",
"    // different chunks of data in parallel.\n",
"    // For those familiar with CUDA, `get_program_id({0,1,2})`\n",
"    // is similar to blockIdx.{x,y,z}\n",
"    int pid = get_program_id(0);\n",
"    // In Triton, arrays are first-class citizens. In other words,\n",
"    // they are primitive data types and are -- contrary to C and\n",
"    // CUDA -- not implemented as pointers to contiguous chunks of\n",
"    // memory.\n",
"    // In the few lines below, we create an array of `BLOCK` pointers\n",
"    // whose memory values are, e.g.:\n",
"    // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
"    // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
"    int    offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
"    float* pz    [BLOCK] = z + offset;\n",
"    float* px    [BLOCK] = x + offset;\n",
"    float* py    [BLOCK] = y + offset;\n",
"    // Simple element-wise control-flow for load/store operations can\n",
"    // be achieved using the ternary operator `cond ? val_true : val_false`\n",
"    // or the conditional dereferencing operator `*?(cond)ptr`.\n",
"    // Here, we make sure that we do not access memory out-of-bounds when we\n",
"    // write back `z`\n",
"    bool check[BLOCK] = offset < N;\n",
"    *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
"```\n",
"\n",
"The existence of arrays as a primitive data type in Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
]
},
{
"cell_type": "markdown",
"id": "forbidden-wednesday",
"metadata": {},
"source": [
"# Writing the Torch Bindings"
]
},
{
"cell_type": "markdown",
"id": "numerical-agency",
"metadata": {},
"source": [
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class, which allows you to transform the above C-like function into a callable Python object that can be used to modify `torch.tensor` objects.\n",
"\n",
"To create a `triton.kernel`, you only need three things:\n",
"* `source: string`: the source-code of the kernel you want to create\n",
"* `device: torch.device`: the device you want to compile this code for\n",
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
"\n",
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see the `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "sporting-keyboard",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source-code for Triton compute kernel\n",
"# here we just copy-paste the above code without the extensive comments.\n",
"# you may prefer to store it in a .c file and load it from there instead.\n",
"_src = \"\"\"\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
"    // program id\n",
"    int pid = get_program_id(0);\n",
"    // create arrays of pointers\n",
"    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
"    float* pz[BLOCK] = z + offset;\n",
"    float* px[BLOCK] = x + offset;\n",
"    float* py[BLOCK] = y + offset;\n",
"    // bounds checking\n",
"    bool check[BLOCK] = offset < N;\n",
"    // write-back\n",
"    *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
"\"\"\"\n",
"# This function returns a callable `triton.kernel` object\n",
"# created from the above source code.\n",
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
"# We compile the kernel with -DBLOCK=1024\n",
"_kernels = dict()\n",
"def make_add_kernel(device):\n",
"    if device not in _kernels:\n",
"        defines = {'BLOCK': 1024}\n",
"        _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
"    return _kernels[device]\n",
"\n",
"# This is a standard torch custom autograd Function;\n",
"# the only difference is that we can now use the above kernel\n",
"# in the `forward` and `backward` functions.\n",
"class _add(torch.autograd.Function):\n",
"\n",
"    @staticmethod\n",
"    def forward(ctx, x, y):\n",
"        # constraints of the op\n",
"        assert x.dtype == torch.float32\n",
"        # *allocate output*\n",
"        z = torch.empty_like(x)\n",
"        # *create launch grid*:\n",
"        # this is a function which takes compilation parameters `opt`\n",
"        # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
"        # triton.cdiv is a shortcut for ceil division:\n",
"        # triton.cdiv(a, b) = (a + b - 1) // b\n",
"        N = z.shape[0]\n",
"        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
"        # *launch kernel*:\n",
"        # pointers to the data of torch tensors can be retrieved with\n",
"        # the `.data_ptr()` method\n",
"        kernel = make_add_kernel(z.device)\n",
"        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
"        return z\n",
"\n",
"# Just like with standard PyTorch ops,\n",
"# we use the `.apply` method to create a\n",
"# callable object for our function\n",
"add = _add.apply"
]
},
{
"cell_type": "markdown",
"id": "separated-polyester",
"metadata": {},
"source": [
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
]
},
{
"cell_type": "markdown",
"id": "exclusive-salvation",
"metadata": {},
"source": [
"# Writing a Unit Test"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "supported-ribbon",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"The maximum difference between torch and triton is 0.0\n"
]
}
],
"source": [
"torch.manual_seed(0)\n",
"x = torch.rand(98432, device='cuda')\n",
"y = torch.rand(98432, device='cuda')\n",
"za = x + y\n",
"zb = add(x, y)\n",
"print(za)\n",
"print(zb)\n",
"print(f'The maximum difference between torch and triton is '\n",
"      f'{torch.max(torch.abs(za - zb))}')"
]
},
{
"cell_type": "markdown",
"id": "otherwise-canadian",
"metadata": {},
"source": [
"Seems to work!"
]
},
{
"cell_type": "markdown",
"id": "polished-australia",
"metadata": {},
"source": [
"# Writing a Benchmark"
]
},
{
"cell_type": "markdown",
"id": "historic-glass",
"metadata": {},
"source": [
"The performance of our GPU code can be benchmarked using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "strange-luxembourg",
"metadata": {},
"outputs": [],
"source": [
"# We now want to benchmark the performance of `add`\n",
"# against that of PyTorch for increasing vector sizes\n",
"def do_bench(fn, warmup = 10, rep = 50):\n",
"    start_event = torch.cuda.Event(enable_timing=True)\n",
"    end_event = torch.cuda.Event(enable_timing=True)\n",
"    ret = fn()\n",
"    for i in range(warmup):\n",
"        fn()\n",
"    torch.cuda.synchronize()\n",
"    start_event.record()\n",
"    for i in range(rep):\n",
"        fn()\n",
"    end_event.record()\n",
"    torch.cuda.synchronize()\n",
"    time_ms = start_event.elapsed_time(end_event) / rep\n",
"    return time_ms"
]
},
{
"cell_type": "markdown",
"id": "hairy-claim",
"metadata": {},
"source": [
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it performs."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "pleasant-valley",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"131072 0.020 0.003\n",
"262144 0.019 0.004\n",
"524288 0.016 0.016\n",
"1048576 0.033 0.033\n",
"2097152 0.071 0.070\n",
"4194304 0.142 0.144\n",
"8388608 0.287 0.286\n",
"16777216 0.572 0.568\n",
"33554432 1.139 1.110\n"
]
}
],
"source": [
"for N in [2**i for i in range(17, 26, 1)]:\n",
"    x = torch.rand(N, device='cuda')\n",
"    y = torch.rand(N, device='cuda')\n",
"    triton_ms = do_bench(lambda: add(x, y))\n",
"    torch_ms = do_bench(lambda: x + y)\n",
"    # print the run time of triton and torch\n",
"    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
]
},
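{
"cell_type": "markdown",
"id": "bandwidth-note",
"metadata": {},
"source": [
"Each call to `add(x, y)` reads `x` and `y` and writes `z`, i.e. three float32 accesses (12 bytes) per element, so the timings above can also be turned into a rough estimate of the achieved memory bandwidth. The `gbps` helper below is just an illustrative sketch under that assumption:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bandwidth-check",
"metadata": {},
"outputs": [],
"source": [
"# Rough bandwidth estimate for the vector addition above.\n",
"# Assumption: 3 float32 accesses (read x, read y, write z) = 12 bytes per element.\n",
"# `gbps` is an illustrative helper, not part of the triton API.\n",
"gbps = lambda n, ms: 3 * n * 4 * 1e-9 / (ms * 1e-3)\n",
"N = 2**25\n",
"x = torch.rand(N, device='cuda')\n",
"y = torch.rand(N, device='cuda')\n",
"triton_ms = do_bench(lambda: add(x, y))\n",
"print(f'{N}: {triton_ms:.3f} ms, ~{gbps(N, triton_ms):.1f} GB/s')"
]
},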
{
"cell_type": "markdown",
"id": "juvenile-supplement",
"metadata": {},
"source": [
"Our op is on par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working to reduce."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "agreed-backing",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}