***************
Triton vs. CUDA
***************

The purpose of this tutorial is to explore in more depth the major differences between Triton and CUDA. To keep things simple, we will again focus on the following vector addition code, written in Triton-C:

.. code-block:: C

    // Triton
    // launch on a grid of (N + TILE - 1) / TILE programs
    __global__ void add(float* z, float* x, float* y, int N){
        // TILE-wide block of element offsets owned by this program
        int offset[TILE] = get_program_id(0) * TILE + 0 ... TILE;
        // mask out offsets past the end of the vectors
        bool check[TILE] = offset < N;
        float* pz[TILE] = z + offset;
        float* px[TILE] = x + offset;
        float* py[TILE] = y + offset;
        // predicated (masked) loads and store
        *?(check)pz = *?(check)px + *?(check)py;
    }

And its CUDA equivalent:

.. code-block:: C

    // CUDA
    // launch on a grid of (N + TILE - 1) / TILE blocks of TILE threads each
    __global__ void add(float *z, float *x, float *y, int N) {
        // each thread handles exactly one element
        int off = blockIdx.x * TILE + threadIdx.x;
        if(off < N){
            float *pz = z + off;
            float *px = x + off;
            float *py = y + off;
            *pz = *px + *py;
        }
    }

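To make the contrast between the two kernels concrete, the sketch below emulates both execution models sequentially in plain C. It is an illustration under stated assumptions, not a description of how either compiler works: :code:`TILE` is fixed to 4, and the function names (:code:`triton_add_program`, :code:`cuda_add_thread`, :code:`launch_triton`, :code:`launch_cuda`) are hypothetical. A Triton program owns a whole tile and manipulates TILE-wide values, while a CUDA thread owns a single element and performs its own bounds check.

.. code-block:: C

    #include <assert.h>
    #include <stdbool.h>

    /* Hypothetical tile size; in both kernels TILE is a compile-time constant. */
    #define TILE 4

    /* Triton-C model: ONE program instance computes a whole TILE-wide block.
       Block-level values (offset, check) are arrays of length TILE. */
    void triton_add_program(int pid, float *z, const float *x, const float *y, int N) {
        int offset[TILE];
        bool check[TILE];
        for (int i = 0; i < TILE; i++) {
            offset[i] = pid * TILE + i;   /* get_program_id(0) * TILE + 0 ... TILE */
            check[i] = offset[i] < N;     /* offset < N */
        }
        for (int i = 0; i < TILE; i++)
            if (check[i])                 /* *?(check): masked lanes never touch memory */
                z[offset[i]] = x[offset[i]] + y[offset[i]];
    }

    /* CUDA model: each of the TILE threads in a block computes ONE element. */
    void cuda_add_thread(int block_idx, int thread_idx,
                         float *z, const float *x, const float *y, int N) {
        int off = block_idx * TILE + thread_idx;
        if (off < N)
            z[off] = x[off] + y[off];
    }

    /* Sequential stand-ins for launching (N + TILE - 1) / TILE programs or blocks.
       On a GPU they run concurrently; running them in a loop is safe here
       because every element is written by exactly one program or thread. */
    void launch_triton(float *z, const float *x, const float *y, int N) {
        assert(N >= 0);
        for (int pid = 0; pid < (N + TILE - 1) / TILE; pid++)
            triton_add_program(pid, z, x, y, N);
    }

    void launch_cuda(float *z, const float *x, const float *y, int N) {
        assert(N >= 0);
        for (int b = 0; b < (N + TILE - 1) / TILE; b++)
            for (int t = 0; t < TILE; t++)
                cuda_add_thread(b, t, z, x, y, N);
    }

Calling both launchers on the same inputs with :code:`N = 10` (so the last tile is only half full) yields identical outputs and leaves memory past :code:`z[N - 1]` untouched.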
==========================
Automatic parallelization
==========================

While the Triton and CUDA versions of the kernel may look similar at first sight, a closer look reveals one *fundamental* difference: CUDA kernels are launched on a cooperative array of threads, whereas **Triton kernels are single-threaded and automatically parallelized**.

This difference in programming model not only makes your life as a programmer much easier, but also lets the Triton compiler perform a number of optimizations automatically:

- *Automatic shared memory allocation and synchronization*

  That's right: programmers don't need to worry about shared memory allocation, usage or synchronization. Instead, the Triton compiler uses program analysis techniques to determine when shared memory should be used, where it should be synchronized, and how threads should access it to avoid bank conflicts.

- *Automatic memory coalescing*

  When you write Triton code, you also don't need to worry about memory coalescing. The compiler arranges threads so that consecutive threads access consecutive global memory addresses, which allows accesses to be coalesced when possible.

- *Automatic tensor core utilization*

  Using tensor cores on Volta and Turing is notoriously difficult: code is hard to write and even harder to optimize. Fortunately, the Triton compiler can also generate very efficient tensor core instructions (e.g., :code:`mma.sync.m8n8k4`) when low-precision matrices are multiplied together:

  .. code-block:: C

      half A[16, 8] = ... // initialize A
      half B[8, 16] = ... // initialize B
      float C[16, 16] = dot(A, B); // uses Tensor Cores!

- *Automatic instruction predication*

  Unlike CUDA, Triton directly exposes predicated instructions through masked load/store operations. This lets the Triton compiler emit predicated instructions in PTX directly, which sometimes performs better than I/O operations wrapped inside conditionals.

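The last point can be illustrated with a small C model of masked loads and stores. This sketch only captures the *semantics* of a mask: masked-off lanes never access memory and a load returns a fallback value for them instead. It does not show the PTX that Triton emits, where each lane's access would become a predicated instruction. :code:`TILE = 8` and the helper names (:code:`masked_load`, :code:`masked_store`, :code:`add_tile_masked`) are hypothetical.

.. code-block:: C

    #include <assert.h>
    #include <stdbool.h>

    #define TILE 8  /* hypothetical tile size */

    /* Masked load: lanes with mask[i] == false never dereference base and
       receive `other`; the conditional operator evaluates only one operand. */
    void masked_load(float dst[TILE], const float *base, const int offset[TILE],
                     const bool mask[TILE], float other) {
        for (int i = 0; i < TILE; i++)
            dst[i] = mask[i] ? base[offset[i]] : other;
    }

    /* Masked store: lanes with mask[i] == false leave memory untouched. */
    void masked_store(float *base, const int offset[TILE], const bool mask[TILE],
                      const float src[TILE]) {
        for (int i = 0; i < TILE; i++)
            if (mask[i])
                base[offset[i]] = src[i];
    }

    /* Vector-add tile body written with masked I/O instead of an if-block
       around the whole computation; the arithmetic itself is unconditional. */
    void add_tile_masked(int pid, float *z, const float *x, const float *y, int N) {
        int offset[TILE];
        bool mask[TILE];
        float xv[TILE], yv[TILE], zv[TILE];
        for (int i = 0; i < TILE; i++) {
            offset[i] = pid * TILE + i;
            mask[i] = offset[i] < N;
        }
        masked_load(xv, x, offset, mask, 0.0f);
        masked_load(yv, y, offset, mask, 0.0f);
        for (int i = 0; i < TILE; i++)
            zv[i] = xv[i] + yv[i];
        masked_store(z, offset, mask, zv);
    }

Here only the memory accesses depend on the mask, which is the property that makes it possible to lower each of them to a single predicated instruction rather than a branch.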
===========================
Vector Addition - Revisited
===========================

In light of these optimizations, it turns out that the GPU code generated from our Triton-C vector addition kernel is actually closer to the following CUDA code:

.. code-block:: C

    // CUDA
    // launch on a grid of (N + TILE - 1) / TILE blocks of TILE / 4 threads,
    // so that each thread handles 4 elements (TILE == 4 * blockDim.x)
    __global__ void add(float *z, float *x, float *y, int N) {
        int off[4];
        #pragma unroll
        for(int k = 0; k < 4; k++)
            off[k] = blockIdx.x * TILE + threadIdx.x + k * blockDim.x;
        #pragma unroll
        for(int k = 0; k < 4; k++)
            if(off[k] < N)  // plays the role of the Triton mask
                z[off[k]] = x[off[k]] + y[off[k]];
    }

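The index arithmetic of this kernel can be checked on the host. The C sketch below assumes, as the kernel does, that each thread handles 4 elements so that :code:`TILE == 4 * blockDim.x`; all helper names are hypothetical. It verifies that the strided mapping writes every element exactly once and that, for each unrolled step :code:`k`, the 32 threads of a warp touch 32 adjacent floats (128 contiguous bytes), which is the access pattern that allows global memory accesses to be coalesced. A naive mapping in which each thread owns 4 adjacent elements is included for contrast.

.. code-block:: C

    #include <assert.h>
    #include <stdbool.h>

    /* Assumption: each thread handles 4 elements (the unroll factor above),
       so TILE == ELEMS_PER_THREAD * blockDim.x. */
    #define ELEMS_PER_THREAD 4
    #define WARP_SIZE 32

    typedef int (*index_fn)(int block, int thread, int k, int block_dim);

    /* Mapping used by the kernel above: step k is offset by k * blockDim.x. */
    int strided_index(int block, int thread, int k, int block_dim) {
        return block * ELEMS_PER_THREAD * block_dim + thread + k * block_dim;
    }

    /* Naive alternative for contrast: each thread owns 4 adjacent elements. */
    int chunked_index(int block, int thread, int k, int block_dim) {
        return block * ELEMS_PER_THREAD * block_dim + thread * ELEMS_PER_THREAD + k;
    }

    /* True if, at unrolled step k, the 32 lanes of `warp` access 32 adjacent
       floats, i.e. one contiguous span of 128 bytes. */
    bool warp_is_contiguous(index_fn idx, int block, int warp, int k, int block_dim) {
        assert((warp + 1) * WARP_SIZE <= block_dim);
        int first = idx(block, warp * WARP_SIZE, k, block_dim);
        for (int lane = 1; lane < WARP_SIZE; lane++)
            if (idx(block, warp * WARP_SIZE + lane, k, block_dim) != first + lane)
                return false;
        return true;
    }

    /* True if every element in [0, n) is written exactly once; indices >= n
       are skipped, mirroring the bounds check. `counts` must hold n ints. */
    bool covers_exactly_once(index_fn idx, int n, int block_dim, int *counts) {
        int tile = ELEMS_PER_THREAD * block_dim;
        int grid = (n + tile - 1) / tile;
        for (int i = 0; i < n; i++) counts[i] = 0;
        for (int b = 0; b < grid; b++)
            for (int t = 0; t < block_dim; t++)
                for (int k = 0; k < ELEMS_PER_THREAD; k++) {
                    int off = idx(b, t, k, block_dim);
                    if (off < n) counts[off]++;
                }
        for (int i = 0; i < n; i++)
            if (counts[i] != 1) return false;
        return true;
    }

For example, with :code:`blockDim.x = 64` both mappings cover :code:`N = 1000` elements exactly once, but only :code:`strided_index` passes :code:`warp_is_contiguous`; this is the kind of layout decision that the section on automatic memory coalescing attributes to the Triton compiler.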
The code that Triton actually generates is more complicated still when x, y and z have :code:`half` type, because the compiler then automatically vectorizes data accesses using :code:`half2` so that each thread issues 32-bit memory transactions.

============================
Auto-Tuning
============================

Now suppose that you want to tune the above code for different data types, tile sizes and thread block sizes. This is doable in CUDA, but would require you to write cumbersome machinery to handle different vector sizes and loop unrolling factors. In Triton, this can be done trivially by adjusting a few compilation parameters. For example:

.. code-block:: python

    _vector_add.kernel(z, x, y, N, grid=grid,
                       defines={'TILE': [256, 512, 1024]},
                       num_warps=[2, 4, 8])

would benchmark the above Triton code for tile sizes of 256, 512 and 1024, each executed with 2, 4 or 8 warps, and cache the fastest kernel.

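Conceptually, this kind of auto-tuning benchmarks every combination of the listed values and remembers the winner. The C sketch below shows that idea under explicit assumptions: :code:`config`, :code:`tune_cache`, :code:`autotune` and :code:`toy_bench` are hypothetical names (they are not Triton's API), results are cached per problem size :code:`N`, and :code:`toy_bench` is a deterministic stand-in for actually launching and timing the kernel.

.. code-block:: C

    #include <assert.h>
    #include <float.h>
    #include <stddef.h>

    /* One point in the search space: a TILE value and a warp count. */
    typedef struct { int tile; int num_warps; } config;

    /* Benchmark callback: runtime of the kernel on n elements with config c.
       In practice this would compile, launch and time the kernel. */
    typedef double (*bench_fn)(config c, int n);

    /* Small cache of winning configs, keyed by problem size (an assumption). */
    #define CACHE_SLOTS 16
    typedef struct { int n[CACHE_SLOTS]; config best[CACHE_SLOTS]; int used; } tune_cache;

    config autotune(tune_cache *cache, int n,
                    const int *tiles, size_t n_tiles,
                    const int *warps, size_t n_warps, bench_fn bench) {
        for (int i = 0; i < cache->used; i++)
            if (cache->n[i] == n)
                return cache->best[i];          /* cache hit: skip benchmarking */
        config best = { tiles[0], warps[0] };
        double best_time = DBL_MAX;
        for (size_t i = 0; i < n_tiles; i++)    /* every TILE x num_warps pair */
            for (size_t j = 0; j < n_warps; j++) {
                config c = { tiles[i], warps[j] };
                double t = bench(c, n);
                if (t < best_time) { best_time = t; best = c; }
            }
        assert(cache->used < CACHE_SLOTS);
        cache->n[cache->used] = n;
        cache->best[cache->used] = best;
        cache->used++;
        return best;
    }

    /* Deterministic stand-in for a real timer, for illustration only:
       it simply pretends that TILE = 512 with 4 warps is fastest. */
    int toy_bench_calls = 0;
    double toy_bench(config c, int n) {
        toy_bench_calls++;
        double dt = c.tile - 512.0, dw = c.num_warps - 4.0;
        return dt * dt + 100.0 * dw * dw + 1e-6 * n;
    }

With the candidate lists from the Python example, the first call for a given :code:`N` runs the benchmark nine times (3 tile sizes by 3 warp counts); later calls with the same :code:`N` return the cached configuration without benchmarking again.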
=============================
Going Further
=============================

The benefits of Triton become more pronounced as compute kernels grow more complex. In the next few tutorials, you will see how to implement transposition and tensor-core-compatible matrix multiplication routines on par with cuBLAS and CUTLASS without having to know anything about GPU micro-architecture!