[DOC] Basic Sphynx skeleton

This commit is contained in:
Philippe Tillet
2020-02-10 00:33:03 -05:00
committed by Philippe Tillet
parent d22cf4f717
commit 12a621bb2f
11 changed files with 640 additions and 436 deletions

9
docs/tutorials/index.rst Normal file
View File

@@ -0,0 +1,9 @@
Tutorials
==========
.. toctree::
:maxdepth: 1
triton-vs-cuda
matrix-transposition
matrix-multiplication

View File

@@ -0,0 +1,184 @@
*********************
Matrix Multiplication
*********************
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see `CUTLASS <https://github.com/NVIDIA/cutlass>`_). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
*Note: Bounds-checking is ommitted throughout for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops.*
==============
Compute Kernel
==============
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
.. code-block:: C
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
__global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0); //(1)
int pn = get_program_id(1); //(2)
int rm[TM] = pm * TM + 0 ... TM; //(3)
int rn[TN] = pn * TN + 0 ... TN; //(4)
int rk[TK] = 0 ... TK; //(5)
// initialize accumulator
float c[TM, TN] = 0; //(6)
// pointers to operands
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; //(8)
// reduction loop
for(int k = K; k > 0; k-= TK){
// fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
// matrix-multiply accumulate
c += a @ b; //(11)
// increment pointers
pa = pa + TK * 1; //(12)
pb = pb + TK * ldb; //(13)
}
// epilogue
TYPE* pc[TM, TN] = C + rn[newaxis, :] + rm[:, newaxis] * ldc; //(14)
*pc = c; //(15)
}
Here, each kernel instance produces a :code:`TM x TN` tile of the output matrix C as follows:
- Statements (1) - (2) fetch the id of the current program instance.
- Statements (3) - (4) construct ranges of indices to process for the vertical and horizontal axes of the output matrix :code:`C`
- Statement (5) constructs a range of indices along the reduction axis: :code:`rk = [0, 1, ..., TK - 1]`
- Statement (6) initialize a :code:`TM x TN` array of accumulators to hold the result of :code:`A[rm, :] x B[:, rn]`
- Statements (7) - (8) initializes arrays of pointers :code:`pa` and :code:`pb` to the operands :code:`A` and :code:`B` using logic similar to that of the above transposition kernel
- Statements (9) - (10) load tiles of operands by dereferencing :code:`pa` and :code:`pb`
- Statement (11) performs updates the accumulator array using Triton-C's matrix multiplication operator :code:'@'
- Statements (12) - (13) updates :code:`pa` and :code:`pb`
- Statement (14) creates an array of pointers `pc` to the result matrix :code:`C`
- Statement (15) writes back the accumulator to :code:`C`
Internally, the Triton compiler will perform quite a few optimizations that will ensure good performance for this kernel:
- Automatic coalescing of load/store operations
- Automatic vectorization of load/store operations
- Stashing `a` and `b` to shared memory
- Automatic allocation of shared memory
- Automatic synchronization of shared memory
- Automatic padding of shared memory to avoid bank conflicts
- Automatic usage of tensor cores when TYPE = half and TK % 4 = 0
==============
Optimizations
==============
Nonetheless, there are two important optimizations that the Triton compiler does not do automatically at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
-------------
Pre-Fetching
-------------
The purpose of pre-fetching is to overlap the update of the accumulator `c` with the memory loads for the next tiles that will need to be multiplied. This can be done by modifying the above reduction loop as follows:
.. code-block:: C
// pre-fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
for(int k = K; k > 0; k-= TK){
c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
// don't prefetch last iteration
bool check = k > TK;
// pre-fetch operands
a = check ? *pa : 0;
b = check ? *pb : 0;
}
Note that the Triton-C compiler will now also be able to use double-buffering techniques to make sure that the array `a` can be used and updated at the same time without any memory hazard.
-----------------
Rematerialization
-----------------
`Rematerialization <https://en.wikipedia.org/wiki/Rematerialization>`_ is a compiler optimization which consists in recomputing some values instead of storing and reloading them from (register) memory, so as to decrease register pressure in the compute kernel. Although LLVM does this automatically to some extent, it fails to find good heuristics for the above kernel -- thereby requiring some source code modification to achieve optimal performance. Fortunately, only :code:`rm` and :code:`rn` need to be rematerialized, leading to the following epilogue:
.. code-block:: C
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
------------------------------------
Fused Transpositions and Auto-Tuning
------------------------------------
It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. Fortunately, this can be done by using pre-processors macros for tile shapes and broadcasting directives, leading to the following kernel:
.. code-block:: C
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0);
int pn = get_program_id(1);
int rm[TM] = pm * TM + 0 ... TM;
int rn[TN] = pn * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = (*pa);
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
}
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
}
All matrix multiplications variants can then be retrieved using the following compilation option:
.. code-block:: C
// A is not transposed
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
-DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK
// A is transposed
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
-DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM
// B is not transpose
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
-DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
// B is transpose
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
-DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
Auto-tuning can also be handled using pre-processor macros:
.. code-block:: C
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]

View File

@@ -0,0 +1,113 @@
*********************
Matrix Transpositions
*********************
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity interferes with auto-tuning.
==============
Compute Kernel
==============
In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory. Therefore, an optimal Triton kernel for this operation would look like:
.. code-block:: C
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
__global__ void transpose(TYPE * X, TYPE * Y,
int M, int N, int ldx, int ldy) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6)
// write back using the transposition operator '^'
*py = ^(*px); //(7)
}
At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
.. code-block:: C
pidm = 2
pidn = 1
- Statements (3) and (4) construct the ranges of indices:
.. code-block:: C
rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)]
rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]
which will be used in statements (5) and (6) to construct tiles of pointers
- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
.. code-block:: C
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
│ ⋮ ⋮ │
│ ⋮ ⋮ │
│ X + (pidm*TM + TM - 1) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + TM - 1) + (pidn*TN + TN - 1)*ldx) │
- Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics:
.. code-block:: C
Y + (pidn*TN + 0) + (pidm*TM + 0)*ldy, ..., ..., Y + (pidn*TN + 0) + (pidm*TM + TM - 1)*ldy)
Y + (pidn*TN + TN - 1) + (pidn*TN + 0)*ldy, ..., ..., Y + (pidn*TN + TN - 1) + (pidm*TM + TM - 1)*ldy)
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
==========================
The __multipleof attribute
==========================
The memory loads and store in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of e.g., 4. This is unfortunate because tensor dimensions can be easily made into nice powers of two in Deep Learning, due to batch-sizes and layer width being flexible.
For this reason, Triton provides a __multipleof(N) attributes for variables that are guaranteed to always be multiple of N. In the case of Matrix Transpositions, vector loads can be enabled by modifying the function's signature as follows:
.. code-block:: C
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
int ldx __multipleof(8),
int ldy __multipleof(8)) {
// ...
}
==========================
Bounds Checking
==========================
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
.. code-block:: C
// launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
// ...
// create bounds-checking mask
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
bool checky[TN, TM] = (rm[newaxis, :] < M) && (rn[:, newaxis] < N); //(7b)
// conditional write-back using the conditional dereferencing operatior '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.

243
docs/tutorials/pytriton.md Normal file
View File

@@ -0,0 +1,243 @@
# The PyTriton API
## <span style="color:darkred"> Table of Contents </span>
1. [Motivations](#motivations)
2. [Triton Functions](#pytriton-function)
1. [Creation of Triton Kernels](#creation-triton-kernels)
2. [Usage of Triton Kernels](#usage-triton-kernels)
3. [Integration with Automatic Differentiation](#autodiff)
1. [Basics](#autodiff-basics)
2. [Convenience](#autodiff-convenience)
## <span style="color:darkred"> Motivations </span> <a name="motivations"></a>
The purpose of PyTriton is to provide an API for easily executing Triton-C kernels from PyTorch and Tensorflow. One of the main advantages of PyTriton is that it is framework agnostic: any custom op written using this API will be transparently compatible with both Tensorflow and PyTorch without any additional effort required, as will be shown in this tutorial.
Consider for example the following piece of code:
```python
import numpy as np
import triton
def run_tf():
M, N, K = 128, 128, 128
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
da, db = tf.gradients(c, [a, b])
# Run
ha = np.random.rand(M, K).astype(np.float32)
hb = np.random.rand(K, N).astype(np.float32)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([da], feed_dict = {a: ha, b: hb})
def run_torch():
M, N, K = 128, 128, 128
a = torch.randn(M, K).cuda()
b = torch.randn(K, N).cuda()
a.requires_grad_(True)
b.requires_grad_(True)
c = triton.ops.dot(a, b, False, True)
c.backward()
da = a.grad.clone()
db = b.grad.clone()
## Run on tensorflow
# import tensorflow as tf
# run_tf()
## Run on pytorch
# import torch
# run_torch()
```
PyTriton works by detecting which frameworks are imported and automatically generating and just-in-time compiling C++ binding code for them. Specifically, the following chain of events is triggered when a Triton operation is executed:
1. The imported frameworks are detected
2. C++ binding code for Tensorflow or PyTorch is generated, compiled and cached.
3. The corresponding custom-op is automatically loaded from the generated .so file, and a framework-agnostic wrapper is created.
4. The wrapper is called and a tf.tensor or a torch.tensor is returned. In the case of Tensorflow, the gradient is also registered at this point if applicable
The remainder of this tutorial will show you how to re-implement the above `triton.ops.dot` operation from scratch.
## <span style="color:darkred"> PyTriton Functions </span> <a name="pytriton-function"></a>
The PyTriton API provides a `triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class
```python
import triton
# Entry point
class _dot(triton.function):
@staticmethod
# Forward Pass
def forward(ctx, *args):
#...
@staticmethod
# Backward Pass
def backward(ctx, dy):
#...
```
### <span style="color:darkblue">Creation of Triton Kernels </span> <a name="creation-triton-kernel"></a>
PyTriton also provides a `triton.kernel` class which automatically takes care of interaction with the Triton-JIT as well as the generation and compilation of C++ framework bindings code. For our dot operation we create a kernel from the Triton-C code derived at the end of the [previous tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md)
```
src = """
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0);
int pn = get_program_id(1);
int rm[TM] = pm * TM + 0 ... TM;
int rn[TN] = pn * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = (*pa);
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
}
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
}
}
"""
kernel = triton.kernel(src, ['C'])
```
Note that the second argument to `triton.kernel` constructors indicates which of the operands our kernel function should return. Here, we only return `C`.
At this point, `kernel` is a callable object which takes the same signature as the `dot` function in our source code, except that pointers are treated as tensors:
```
[tensor, tensor, tensor, int, int, int, int, int, int]
```
### <span style="color:darkblue">Usage of Triton Kernels </span> <a name="usage-triton-kernels"></a>
However, in practice only A, B are provided by the user, and all the other `int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the `A` and `B` tensors, and then returns the results of a call to `kernel`:
```python
@staticmethod
def _call(a, b, transpose_a, transpose_b):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
M, Ka = shape_a[0], shape_a[1]
Kb, N = shape_b[0], shape_b[1]
# transpose shapes
if transpose_a:
M, Ka = Ka, M
if transpose_b:
Kb, N = N, Kb
# contiguous dimensions
lda = M if transpose_a else Ka
ldb = Kb if transpose_b else N
ldc = N
# data-type
dtype = a.dtype
# allocate output
c = triton.empty([M, N], dtype = dtype)
# compute
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
# macros -- not necessary but makes kernel source-code simpler
macros = {# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
'STRIDE_AM' : '1' if transpose_a else 'lda',
'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK',
# handle B transposition
'USE_B' : '^b' if transpose_b else 'b',
'STRIDE_BK' : '1' if transpose_b else 'ldb',
'STRIDE_BN' : 'ldb' if transpose_b else '1',
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
AT = transpose_a, BT = transpose_b, TYPE = dtype,
TM = [32, 64, 128], TN = [32, 64, 128], TK = [8], **macros)
```
While this code should be mostly self-explanatory, there are a few of noteworthy things worth pointing out
- `triton.shape` provides a framework-agnostic way to retrieve the shape of a tensor
- `triton.empty` creates an empty tensor of the specified dimensions
- `grid` corresponds to the grid with which our Triton kernel will be launched. Because in our case this grid depends on parametric tile variables, it is supplied as a function of compilation options `opt`, whose compile-time definition can be retrieved using `opt.d(name)`. Here, `opt.d('TM')` and `opt.d('TN')` retrieve the first and second tile dimension our kernel was compiled with. We also provide a helper `triton.cdiv` for ceil divisions.
- `macros` provides a list of preprocessor definitions to compile the kernel with. Alternatively, these can also be supplied as named argument to the `_dot.kernel`. We recall that lists can be supplied to the preprocessor, in which case an auto-tuning procedure will be triggered. Here, the value of `TM` and `TN` are both tuned between 32, 64 and 128.
## <span style="color:darkred"> Compatibility with Automatic Differentiation</span> <a name="autodiff"></a>
At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.
### <span style="color:darkblue"> Basics </span> <a name="autodiff-basics"></a>
PyTriton binds to Tensorflow's and PyTorch's automatic differentiation framework using a single, common API inspired by PyTorch. It consists of two static methods `forward` and `backward` that take a context as their first input:
```
@staticmethod
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
return _dot._call(a, b, transpose_a, transpose_b)
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
t_a, t_b = ctx.t_a, ctx.t_b
if not t_a and not t_b:
da = _dot._call(dy, b, False, True)
db = _dot._call(a, dy, True, False)
elif not t_a and t_b:
da = _dot._call(dy, b, False, False)
db = _dot._call(dy, a, True, False)
elif t_a and not t_b:
da = _dot._call(b, dy, False, True)
db = _dot._call(a, dy, False, False)
elif t_a and t_b:
da = _dot._call(b, dy, True, True)
db = _dot._call(dy, a, True, True)
else:
assert False
return da, db, None, None, None, None, None, None, None
```
### <span style="color:darkblue">Convenience </span> <a name="autodiff-convenience"></a>
Still like for PyTorch, a callable operation can be created using the `apply` method of our `triton.function` class. We wrap it as a module variable for convenience:
```python
dot = _dot.apply
```
And that's it! Our custom op is now created and ready to be used with both PyTorch and Tensorflow.

View File

@@ -0,0 +1,111 @@
***************
Triton vs. CUDA
***************
The purpose of this tutorial is to explore in more depth the major differences between Triton and CUDA. To keep things simple, we will still be focusing on the following vector addition code:
.. code-block:: C
// Triton
// launch on a grid of (N + TILE - 1) / TILE programs
__global__ void add(float* z, float* x, float* y, int N){
int offset[TILE] = get_program_id(0) * TILE + 0 ... TILE;
bool check[TILE] = offset < N;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
*?(check)pz = *?(check)*px + *?(check)py;
}
And its CUDA equivalent:
.. code-block:: C
// CUDA
// launch on a grid of (N + TILE - 1) / TILE programs
__global__ void add(float *z, float *x, float *y, int N) {
int off = blockIdx.x * TILE + threadIdx.x;
if(off < N){
float *pz = z + off;
float *px = x + off;
float *py = y + off;
*pz = *px + *py
}
}
==========================
Automatic parallelization
==========================
While the two above pieces of code may look at first sight similar, a closer look reveals one *fundamental* difference: While CUDA kernels are launched on a cooperative array of threads, **Triton kernel are single-threaded and automatically parallelized**.
This is a major difference in programming model, which not only makes your life much easier as a programmer, but also allows the Triton compiler to automatically do all sorts of nice optimizations:
- *Automatic shared memory allocation and synchronization*
That's right; programmers don't need to worry about shared memory allocation, usage and synchronization. Instead, the Triton compiler will use complex program analysis techniques to determine when shared memory should be used, where it should be synchronized and how threads should access it to avoid memory bank conflicts.
- *Automatic memory coalescing*
When you write Triton code, you also don't need to worry about memory coalescing. The compiler will arrange threads so that global memory accesses are coalesced when possible.
- *Automatic tensor core utilization*
Using tensor cores on Volta and Turing is notoriously difficult. Code is hard to write and even harder to optimize. Fortunately, the Triton compiler can also generate very efficient tensor core instructions (e.g., :code:`mma.sync.m8n8k4`) when low-precision matrices are multiplied together:
.. code-block:: C
half A[16, 8] = ... // initialize A
half B[8, 16] = ... // initialize B
float C[16, 16] = dot(A, B); // uses Tensor Cores!
- *Automatic instruction predication*
Contrary to CUDA, Triton directly exposes predicated instruction through masked load/store instructions. This enables the Triton compiler to generate predicated instructions in PTX directly, resulting in sometimes better performance than I/O operations wrapped inside conditionals.
===========================
Vector Addition - Revisited
===========================
In light of these optimizations, it turns out that the GPU code generated by our Triton-C vector addition code is actually more analogous to the following:
.. code-block:: C
// CUDA
// launch on a grid of (N + TILE - 1) / TILE programs
__global__ void add(float *z, float *x, float *y, int N) {
int off[4];
#pragma unroll
for(int k = 0; k < 4; k++)
off[k] = blockIdx.x * TILE + threadIdx.x + k * blockSize.x;
#pragma unroll
for(int k = 0; k < 4; k++)
z[off[0]] = x[off[0]] + y[off[0]]
}
This code is actually more complicated when x, y and z have :code:`half` type, because then the Triton compiler automatically vectorizes data accesses using :code:`half2` to ensure memory transactions of 32-bits per thread.
============================
Auto-Tuning
============================
Now assume that you want to tune the above code for different data types, tile sizes and thread block sizes. This is doable in CUDA but would require you to write cumbersome machinery to handle different vector sizes and loop unrolling factors. In Triton, this can be trivially done by adjusting some compilation parameters. For example:
.. code-block:: python
_vector_add.kernel(y, x, N, grid=grid,
defines={'TILE': [256, 512, 1024]},
num_warps = [2, 4, 8])
would benchmark our above triton-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.
=============================
Going Further
=============================
The benefits of Triton become more and more pronounced as compute kernels get more and more complex. In the next few tutorials, you will see how to implement transposition and tensor-core-compatible matrix multiplication routine on par with cuBLAS and CUTLASS without having to know anything about GPU micro-architecture!