"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
"\n",
"\n",
"```c\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // The `get_program_id(i)` returns the i-th coordinate\n",
" // of the program in the overaching SPMD context\n",
" // (a.k.a launch grid). This is what allows us to process\n",
" // different chunks of data in parallel.\n",
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
" // is similar to blockIdx.{x,y,z}\n",
" int pid = get_program_id(0);\n",
" // In Triton, arrays are first-class citizen. In other words,\n",
" // they are primitives data-types and are -- contrary to C and\n",
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
" // memory.\n",
" // In the few lines below, we create an array of `BLOCK` pointers\n",
" // whose memory values are, e.g.:\n",
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
" // Simple element-wise control-flow for load/store operations can\n",
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
" // or the conditional dereferencing operator `*?(cond)ptr\n",
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
" // write-back `z`\n",
" bool check[BLOCK] = offset < N;\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
"```\n",
"\n",
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
"\n",
"To create a `triton.kernel`, you only need three things:\n",
"* `source: string`: the source-code of the kernel you want to create\n",
"* `device: torch.device`: the device you want to compile this code for\n",
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
"\n",
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "sporting-keyboard",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source-code for Triton compute kernel\n",
"# here we just copy-paste the above code without the extensive comments.\n",
"# you may prefer to store it in a .c file and load it from there instead.\n",
"_src = \"\"\"\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "strange-luxembourg",
"metadata": {},
"outputs": [],
"source": [
"# We now want to benchmark the performance of `add`\n",
"# Against that of PyTorch for increasing vector sizes\n",
" # print the performance of triton and torch as well as the achieved bandwidth\n",
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
]
},
{
"cell_type": "markdown",
"id": "juvenile-supplement",
"metadata": {},
"source": [
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."