[documentation][triton-c] clearer motivations; now starting each snippet
with the language it's written in
This commit is contained in:
112
docs/triton-c.md
112
docs/triton-c.md
@@ -2,8 +2,6 @@
|
||||
|
||||
## <span style="color:darkred"> Table of Contents </span>
|
||||
1. [Motivations](#motivations)
|
||||
1. [Issues of C/C++ for Linear Algebra](#issues-c-c++)
|
||||
2. [Design Philosophy of Triton-C](#design-philosophy)
|
||||
2. [Vector Addition](#vector-addition)
|
||||
1. [Differences with CUDA](#differences-with-cuda)
|
||||
2. [Advantages over CUDA](#advantages-over-cuda)
|
||||
@@ -18,50 +16,47 @@
|
||||
2. [Optimizations](#optimizations)
|
||||
1. [Pre-Fetching](#pre-fetching)
|
||||
1. [Rematerialization](#rematerialization)
|
||||
3. [Fused Transpositions](#fused-trans)
|
||||
3. [Fused Transpositions and Auto-Tuning](#fused-trans-autotuning)
|
||||
|
||||
## <span style="color:darkred"> Motivations </span> <a name="motivations"></a>
|
||||
|
||||
## <span style="color:darkblue"> Issues of C/C++ for Linear Algebra </span> <a name="issues-c-c++"></a>
|
||||
|
||||
In C and C++, arrays and pointers have similar semantics. Indeed, there is no way to manipulate statically shaped multi-dimensional arrays (beyond initialization) as a whole without resorting to third-party libraries:
|
||||
In C and C++, arrays and pointers have similar semantics. Indeed, there is no native way to manipulate statically shaped multi-dimensional arrays (beyond initialization) as a whole:
|
||||
|
||||
```c
|
||||
// C99
|
||||
float x[16][8] = {3.14};
|
||||
float y[16][8] = {5.17};
|
||||
// z = x + y
|
||||
float z[16][8];
|
||||
for(int i = 0; i < 16; i ++)
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 16; i++)
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++)
|
||||
z[i][j] = x[i][j] + y[i][j];
|
||||
```
|
||||
|
||||
This issue can be somewhat mitigated using templates metaprogramming in C++:
|
||||
While it does not seem like a big deal at first sight, there are two issues with this:
|
||||
|
||||
- **Ergonomics**: Of course, it is possible to simplify the above code using functions in C
|
||||
```
|
||||
float z[16][8];
|
||||
add(z, x, y, 16, 8);
|
||||
```
|
||||
but this would be semantically different as the loops can no longer be unrolled due to their bounds being now dynamic arguments of the add function. This can be mitigated using templates metaprogramming (and operator overloads) in C++:
|
||||
|
||||
```c
|
||||
// C++
|
||||
template<typename T, int M, int N>
|
||||
class matrix;
|
||||
|
||||
matrix<float, 16, 8> x = {3.14};
|
||||
matrix<float, 16, 8> y = {5.17};
|
||||
matrix<float, 16, 8> z = x + y;
|
||||
```
|
||||
|
||||
This is better, but there are still some important issues with this approach:
|
||||
|
||||
- The syntax could be better, especially when it comes to broadcasting and reshaping.
|
||||
|
||||
- Data-flow information for array operations does not propagate beyond the program's AST, thereby making it difficult for compilers to optimize moderately complicated array programs (i.e., Matrix-Multiplication). This can be worked around using heavy metaprogramming techniques (see [CUTLASS](https://github.com/NVIDIA/cutlass)), but even then programmers still have to allocate and synchronize shared memory manually and endure prohibitively long compilation procedures not easily amenable to auto-tuning.
|
||||
|
||||
For these reasons, most Deep-Learning frameworks still rely heavily on highly optimized subroutines (e.g., BLAS), which makes the development of novel custom primitives time-consuming for experts and challenging for others. This is where Triton comes into play.
|
||||
|
||||
## <span style="color:darkblue"> Design Philosophy of Triton-C </span> <a name="design-philosophy"></a>
|
||||
|
||||
The purpose of Triton is to bring native support for efficient numerical multi-dimensional array operations into a standard procedural languages. We achieve this through:
|
||||
|
||||
* **Triton-C**: Syntactic and semantical extensions to the C language. In particular, native support for reshaping, broadcasting, matrix-multiplication, transposition, etc. This is the object of this tutorial.
|
||||
|
||||
* **Triton-IR**: An LLVM-like IR for array operations, as well as various (automatic memory coalescing, automatic vectorization, shared memory allocation/synchronization, tensor core instruction selection, etc.). Although our system generates Triton-IR programs from Triton-C source-code, this is beyond the scope of this tutorial. More information can be found [here](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf).
|
||||
|
||||
Anyway, the Triton-C code corresponding to the above matrix addition operation can be written and extended as follows:
|
||||
While this is better and now equivalent to our initial code snippet, the syntax is not quite as ergonomically satisfying as what native syntactic support could provide:
|
||||
```c
|
||||
// Triton-C
|
||||
float x[16, 8] = 3.14;
|
||||
float y[16, 8] = 5.17;
|
||||
// float z[8, 8] = x + y; // doesn't compile -- incompatible shapes!
|
||||
@@ -69,10 +64,30 @@ float z[16, 8] = x + y;
|
||||
float u[16] = z[:, +]; // sum along the second axis
|
||||
float v[16, 32] = u[:, newaxis]; // broadcasting along the second axis
|
||||
```
|
||||
which is valid _Triton-C_.
|
||||
|
||||
_Note: Triton-C is single-threaded._
|
||||
|
||||
- **Portability**: One other issue with our initial C99 program is that it is not portable. While it will run well on a single CPU thread, the operation `z = x + y` would perform poorly on a GPU Streaming Processor as it would execute on a single thread only. For this readon, it would have to be rewritten for GPUs as follows:
|
||||
```
|
||||
// CUDA
|
||||
// Launch on a block of 16 x 8 threads
|
||||
float x = 3.14;
|
||||
float y = 5.17;
|
||||
float z = x + y
|
||||
```
|
||||
In Triton-C, the same code can be used on many different platforms (only CPUs and GPUs are supported at the moment).
|
||||
|
||||
- **Performance**: Another issue with our initial C99 code snippet is its performance. Although the loops are unrolled, the program does not carry the data-flow information of array operations. This issue gets more and more problematic as programs get increasingly complex, eventually culminating in matrix multiplication being remarkably hard to optimize.
|
||||
|
||||
This can be worked around using heavy metaprogramming techniques (see [CUTLASS](https://github.com/NVIDIA/cutlass)), but even then programmers still have to allocate and synchronize shared memory manually and endure prohibitively long compilation procedures not easily amenable to auto-tuning. For these reasons, most Deep-Learning frameworks still rely heavily on highly optimized subroutines (e.g., BLAS), which makes the development of novel custom primitives time-consuming for experts and challenging for others.
|
||||
|
||||
Triton addresses this issue by relying on **Triton-IR**, an LLVM-like IR for array operations, and **Triton-JIT**, an optimizing compiler for Triton-IR. These two systems are, however, beyond the scope of this tutorial. More information can be found [here](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf).
|
||||
|
||||
|
||||
_Note: You might be thinking that this is exactly what [MLIR](https://github.com/tensorflow/mlir) was made for... and you're right! You can think of Triton-IR as a dialect for MLIR, and Triton-C as a frontend for it. If you're interested in making this a thing, let me know._
|
||||
|
||||
Of course, we can do much more than additions, reduction and broadcasting. The purpose of this tutorial is to walk you through all the features of Triton-C, and eventually show you how it can be used to build auto-tuned matrix-multiplication kernels on par with state-of-the-art CUDA-C implementation in less than an afternoon.
|
||||
|
||||
_Note: You might be thinking that this is exactly what [MLIR](https://github.com/tensorflow/mlir) was made for... and you're right! You can think of Triton-IR conceptually as a dialect for MLIR, and Triton-C as a frontend for it -- although not yet implemented as such. If you're interested in making this a thing, let me know._
|
||||
|
||||
## <span style="color:darkred"> Vector Addition </span> <a name="vector-addition"></a>
|
||||
|
||||
@@ -81,16 +96,18 @@ _Note: You might be thinking that this is exactly what [MLIR](https://github.com
|
||||
Let's start it off by looking at a simple example. Vector addition, in its most trivial Triton-C implementation, can be written as follows:
|
||||
|
||||
```c
|
||||
// Triton-C
|
||||
// launched on a grid of (N / 32) programs of 1 thread each
|
||||
__global__ void add(int N, float *a, float *b, float* c) {
|
||||
int id = get_program_id(0);
|
||||
int off[32] = id * 32 + (0 ... 32)
|
||||
*(c + off) = *(a + off) + *(b + off)
|
||||
*(c + off) = *(a + off) + *(b + off);
|
||||
}
|
||||
```
|
||||
For reference, here is an equivalent CUDA kernel (nvcc will generate the same PTX code as triton-jit on the above code):
|
||||
For reference, here is an equivalent CUDA kernel (NVCC will generate the same PTX code as Triton-JIT on the above code):
|
||||
|
||||
```c
|
||||
// CUDA
|
||||
// launched on a grid of (N / 32) programs of 32 threads each
|
||||
__global__ void add(int N, float *a, float *b, float *c) {
|
||||
int off = blockIdx.x * 32 + threadIdx.x;
|
||||
@@ -98,10 +115,10 @@ __global__ void add(int N, float *a, float *b, float *c) {
|
||||
}
|
||||
```
|
||||
|
||||
As you can see, there are three main differences between our Triton-C kernel and the equivalent CUDA-C:
|
||||
As you can see, there are three main differences between our Triton-C kernel and the equivalent CUDA:
|
||||
|
||||
- **The programming model is different**.
|
||||
While Triton-C and CUDA-C both use a Single-Program, Multiple-Data (SPMD) programming model, each Triton-C kernel is single-threaded.
|
||||
While Triton-C and CUDA both use a Single-Program, Multiple-Data (SPMD) programming model, each Triton-C kernel is single-threaded.
|
||||
Therefore, `get_program_id({0, 1, 2})` is equivalent to `blockIdx.{x, y, z}`, but there is no such thing as `blockDim` and `threadIdx`.
|
||||
|
||||
- **The semantics of arrays is different**
|
||||
@@ -113,7 +130,7 @@ float* c_ptr[32] = c_broadcast + off; // c_ptr = c + off
|
||||
```
|
||||
|
||||
- **The semantics of the subscript operator is different**.
|
||||
n C/CUDA-C, subscripting can be used to offset and dereference a pointer, but in Triton-C it can only be used to index and broadcast an array (think NumPy).
|
||||
n C/CUDA, subscripting can be used to offset and dereference a pointer, but in Triton-C it can only be used to index and broadcast an array (think NumPy).
|
||||
|
||||
### <span style="color:darkblue"> Advantages over CUDA </span> <a name="advantages-over-cuda"></a>
|
||||
|
||||
@@ -123,6 +140,7 @@ At this point, the advantages of Triton-C over CUDA may not be obvious. But they
|
||||
|
||||
On some hardware architectures, vectorizing load/store operations can lead to better memory utilization and, in turn, noticeable performance gains. In general, 128-bit memory transactions are favored, leading to the following CUDA kernel:
|
||||
```c
|
||||
// CUDA
|
||||
// launched on a grid of (N / 128) programs of 32 threads each
|
||||
__global__ void add(int N, float4 *a, float4 *b, float4 *c) {
|
||||
int off = blockIdx.x * 32 + threadIdx.x;
|
||||
@@ -131,6 +149,7 @@ __global__ void add(int N, float4 *a, float4 *b, float4 *c) {
|
||||
```
|
||||
Or, for half-precision inputs:
|
||||
```c
|
||||
// CUDA
|
||||
// launched on a grid of (N / 256) programs of 32 threads each
|
||||
__global__ void add(int N, half8 *a, half8 *b, half8 *c) {
|
||||
int off = blockIdx.x * 32 + threadIdx.x;
|
||||
@@ -146,11 +165,12 @@ In Triton-C, this is not a problem as the compiler will figure out automatically
|
||||
|
||||
Specifically, the Triton compiler would refuse to 4-way vectorize our above compute kernel because it would require the array `int off[32]` to be distributed over 8 threads, which is less than a warp. Fortunately, it turns out that this problem can be easily solved using preprocessor directrives to _parameterize_ our kernel:
|
||||
```c
|
||||
// Triton-C
|
||||
// launched on a grid of (N / SIZE) programs of 1 thread each
|
||||
__global__ void add(int N, TYPE* a, TYPE* b, TYPE* c) {
|
||||
int id = get_program_id(0);
|
||||
int off[SIZE] = id * SIZE + (0 ... SIZE)
|
||||
*(c + off) = *(a + off) + *(b + off)
|
||||
int off[SIZE] = id * SIZE + (0 ... SIZE);
|
||||
*(c + off) = *(a + off) + *(b + off);
|
||||
}
|
||||
// Not vectorized when compiled with -DSIZE=32 -DTYPE=float
|
||||
// 4-Vectorized when compiled with -DSIZE=128 -DTYPE=float
|
||||
@@ -169,7 +189,7 @@ _Note: Tuning our reference CUDA kernel would be much more cumbersome, as templa
|
||||
|
||||
## <span style="color:darkred"> Matrix Transposition </span> <a name="matrix-transposition"></a>
|
||||
|
||||
Transpositions are (relatively) hard to efficiently write in CUDA as naive implementations typically suffer from _uncoalesced_ memory operations when writing back the transposed matrix to DRAM. Of course, this can be fixed by using shared memory as shown [here](https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/), but this comes at the cost of simplicity and -- more importantly -- interferes with auto-tuning.
|
||||
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from _uncoalesced_ memory operations when writing back the transposed matrix to DRAM. Of course, this can be fixed by using shared memory as shown [here](https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/), but this comes at the cost of simplicity and -- more importantly -- interferes with auto-tuning.
|
||||
|
||||
### <span style="color:darkblue"> Compute Kernel </span> <a name="trans-compute-kernel"></a>
|
||||
|
||||
@@ -192,7 +212,7 @@ __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
|
||||
}
|
||||
```
|
||||
|
||||
At a high level, this kernel loads a `TM x TN` tile from the input matrix `X`, transposes it and writes the resulting `TN x TM` tile to the output matrix `Y`. As a result, transposition of the full input matrix is achieved by launching a grid of `(M / TM) x (N / TN)` programs decomposed as follows:
|
||||
At a high level, this kernel loads a `TM x TN` tile from the input matrix `X`, transposes it and writes the resulting `TN x TM` tile to the output matrix `Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of `(M / TM) x (N / TN)` programs decomposed as follows:
|
||||
|
||||
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
|
||||
```
|
||||
@@ -239,21 +259,22 @@ __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
|
||||
}
|
||||
```
|
||||
|
||||
Here, statements (7a) creates an array of booleans `checkx[TM, TN]` such that `checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Then, both `px` and `py` can be conditionally dereferenced using Triton-C's conditional dereferencing operator `*?(predicate) pointer`.
|
||||
Here, statements (7a) creates an array of booleans `checkx[TM, TN]` such that `checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator `*?(predicate) pointer`.
|
||||
|
||||
|
||||
## <span style="color:darkred"> Matrix Multiplication </span> <a name="matrix-multiplication"></a>
|
||||
|
||||
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA-C kernels (see [CUTLASS](https://github.com/NVIDIA/cutlass)). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
|
||||
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see [CUTLASS](https://github.com/NVIDIA/cutlass)). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
|
||||
|
||||
_Note: Bounds-checking is ommitted for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops._
|
||||
_Note: Bounds-checking is ommitted throughout for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops._
|
||||
|
||||
### <span style="color:darkblue"> Compute Kernel </span> <a name="matmul-compute-kernel"></a>
|
||||
|
||||
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
|
||||
|
||||
```c
|
||||
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
|
||||
// Triton-C
|
||||
// launched on a grid of (M / TM) x (N / TN) programs
|
||||
__global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
|
||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
@@ -308,7 +329,7 @@ Internally, the Triton compiler will perform quite a few optimizations that will
|
||||
|
||||
### <span style="color:darkblue"> Optimizations </span> <a name="optimizations"></a>
|
||||
|
||||
Nonetheless, there are two important optimizations that the Triton compiler does not do at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
|
||||
Nonetheless, there are two important optimizations that the Triton compiler does not do automatically at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
|
||||
|
||||
#### <span style="color:purple"> Pre-Fetching </span> <a name="pre-fetching"></a>
|
||||
|
||||
@@ -344,11 +365,13 @@ TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
|
||||
*pc = c;
|
||||
```
|
||||
|
||||
### <span style="color:darkblue"> Fused Transpositions </span> <a name="fused-trans"></a>
|
||||
### <span style="color:darkblue"> Fused Transpositions and Auto-Tuning </span> <a name="fused-trans-autotuning"></a>
|
||||
|
||||
It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. This is also what is done in the [PyTriton](https://github.com/ptillet/triton/blob/master/python/triton/ops/dot.py) implementation of matrix-multiplication. Fortunately, this can be done by using pre-processors macros for tile shapes and broadcasting directives, leading to the following kernel:
|
||||
|
||||
```c
|
||||
// Triton-C
|
||||
// launched on a grid of (M / TM) x (N / TN) programs
|
||||
void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||
@@ -393,4 +416,9 @@ All matrix multiplications variants can then be retrieved using the following co
|
||||
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb -DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
|
||||
```
|
||||
|
||||
Auto-tuning can also be handled using pre-processor macros:
|
||||
```c
|
||||
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
|
||||
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
|
||||
```
|
||||
|
||||
|
Reference in New Issue
Block a user