[PYTHON][TESTS][DOC] Various improvement of the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency: > .data_ptr() must now be passed as kernel argument. No more implicit conversion from torch.tensor > compilation options are now constant attributes, i.e., opt.d('VAR') becomes opt.VAR > torch.device must now be passed explicitly to triton.kernel (no longer inferred from torch.tensor arguments) * C++ tests moved to `python/tests/` * C++ tutorial created in `tutorials/` * Python tutorial created in python/tutorials/ * Version changed to 1.0alpha * No longer copying C++ headers into the Python package * added python/triton/ops/ package for pre-written Triton ops
This commit is contained in:
@@ -11,7 +11,6 @@ Welcome to Triton's documentation!
|
||||
:caption: Contents:
|
||||
|
||||
installation/index
|
||||
tutorials/index
|
||||
|
||||
|
||||
Indices and tables
|
||||
|
@@ -1,102 +0,0 @@
|
||||
===========================
|
||||
Writing a Custom Operation
|
||||
===========================
|
||||
|
||||
--------------
|
||||
Compute Kernel
|
||||
--------------
|
||||
|
||||
Let us start with something simple, and see how Triton can be used to create a custom vector addition for PyTorch. The Triton compute kernel for this operation is the following:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// Triton
|
||||
// launch on a grid of (N + TILE - 1) / TILE programs
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// program id
|
||||
int pid = get_program_id(0);
|
||||
// create arrays of pointers
|
||||
int offset[TILE] = pid * TILE + 0 ... TILE;
|
||||
float* pz[TILE] = z + offset;
|
||||
float* px[TILE] = x + offset;
|
||||
float* py[TILE] = y + offset;
|
||||
// bounds checking
|
||||
bool check[TILE] = offset < N;
|
||||
// write-back
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
|
||||
As you can see, arrays are first-class citizen in Triton. This has a number of important advantages that will be highlighted in the next tutorial. For now, let's keep it simple and see how to execute the above operation in PyTorch.
|
||||
|
||||
---------------
|
||||
PyTorch Wrapper
|
||||
---------------
|
||||
|
||||
As you will see, a wrapper for the above Triton function can be created in just a few lines of pure python code.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _add(torch.autograd.Function):
|
||||
# source-code for Triton compute kernel
|
||||
src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// program id
|
||||
int pid = get_program_id(0);
|
||||
// create arrays of pointers
|
||||
int offset[TILE] = pid * TILE + 0 ... TILE;
|
||||
float* pz[TILE] = z + offset;
|
||||
float* px[TILE] = x + offset;
|
||||
float* py[TILE] = y + offset;
|
||||
// bounds checking
|
||||
bool check[TILE] = offset < N;
|
||||
// write-back
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
"""
|
||||
# create callable kernel for the source-code
|
||||
# options: 4 warps and a -DTILE=1024
|
||||
kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
|
||||
|
||||
# Forward pass
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
# type checking
|
||||
assert x.dtype == torch.float32
|
||||
# allocate output
|
||||
z = torch.empty_like(x).cuda()
|
||||
# create launch grid
|
||||
# this is a function of the launch parameters
|
||||
# triton.cdiv indicates ceil division
|
||||
N = x.numel()
|
||||
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
|
||||
# launch kernel
|
||||
_add.kernel(z, x, y, N, grid = grid)
|
||||
# return output
|
||||
return z
|
||||
|
||||
# get callable from Triton function
|
||||
add = _add.apply
|
||||
|
||||
# test
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(98432).cuda()
|
||||
y = torch.rand(98432).cuda()
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
diff = (za - zb).abs().max()
|
||||
print(diff)
|
||||
print(torch.allclose(za,zb))
|
||||
|
||||
Executing the above code will:
|
||||
|
||||
- Generate a .cpp file containing PyTorch bindings for the Triton function
|
||||
- Compile this .cpp file using distutils
|
||||
- Cache the resulting custom op
|
||||
- Call the resulting custom op
|
||||
|
||||
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
|
||||
|
||||
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.
|
@@ -1,10 +0,0 @@
|
||||
Tutorials
|
||||
==========
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
custom-operation
|
||||
triton-vs-cuda
|
||||
matrix-transposition
|
||||
matrix-multiplication
|
@@ -1,186 +0,0 @@
|
||||
*********************
|
||||
Matrix Multiplication
|
||||
*********************
|
||||
|
||||
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see `CUTLASS <https://github.com/NVIDIA/cutlass>`_). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
|
||||
|
||||
*Note: Bounds-checking is ommitted throughout for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops.*
|
||||
|
||||
==============
|
||||
Compute Kernel
|
||||
==============
|
||||
|
||||
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// Triton-C
|
||||
// launched on a grid of (M / TM) x (N / TN) programs
|
||||
__global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
|
||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
int pm = get_program_id(0); //(1)
|
||||
int pn = get_program_id(1); //(2)
|
||||
int rm[TM] = pm * TM + 0 ... TM; //(3)
|
||||
int rn[TN] = pn * TN + 0 ... TN; //(4)
|
||||
int rk[TK] = 0 ... TK; //(5)
|
||||
// initialize accumulator
|
||||
float c[TM, TN] = 0; //(6)
|
||||
// pointers to operands
|
||||
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
|
||||
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; //(8)
|
||||
// reduction loop
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
// fetch operands
|
||||
TYPE a[TM, TK] = *pa; //(9)
|
||||
TYPE b[TK, TN] = *pb; //(10)
|
||||
// matrix-multiply accumulate
|
||||
c += a @ b; //(11)
|
||||
// increment pointers
|
||||
pa = pa + TK * 1; //(12)
|
||||
pb = pb + TK * ldb; //(13)
|
||||
}
|
||||
// epilogue
|
||||
TYPE* pc[TM, TN] = C + rn[newaxis, :] + rm[:, newaxis] * ldc; //(14)
|
||||
*pc = c; //(15)
|
||||
}
|
||||
|
||||
Here, each kernel instance produces a :code:`TM x TN` tile of the output matrix C as follows:
|
||||
|
||||
- Statements (1) - (2) fetch the id of the current program instance.
|
||||
- Statements (3) - (4) construct ranges of indices to process for the vertical and horizontal axes of the output matrix :code:`C`
|
||||
- Statement (5) constructs a range of indices along the reduction axis: :code:`rk = [0, 1, ..., TK - 1]`
|
||||
- Statement (6) initialize a :code:`TM x TN` array of accumulators to hold the result of :code:`A[rm, :] x B[:, rn]`
|
||||
- Statements (7) - (8) initializes arrays of pointers :code:`pa` and :code:`pb` to the operands :code:`A` and :code:`B` using logic similar to that of the above transposition kernel
|
||||
- Statements (9) - (10) load tiles of operands by dereferencing :code:`pa` and :code:`pb`
|
||||
- Statement (11) performs updates the accumulator array using Triton-C's matrix multiplication operator :code:'@'
|
||||
- Statements (12) - (13) updates :code:`pa` and :code:`pb`
|
||||
- Statement (14) creates an array of pointers `pc` to the result matrix :code:`C`
|
||||
- Statement (15) writes back the accumulator to :code:`C`
|
||||
|
||||
Internally, the Triton compiler will perform quite a few optimizations that will ensure good performance for this kernel:
|
||||
|
||||
- Automatic coalescing of load/store operations
|
||||
- Automatic vectorization of load/store operations
|
||||
- Stashing `a` and `b` to shared memory
|
||||
- Automatic allocation of shared memory
|
||||
- Automatic synchronization of shared memory
|
||||
- Automatic padding of shared memory to avoid bank conflicts
|
||||
- Automatic usage of tensor cores when TYPE = half and TK % 4 = 0
|
||||
|
||||
|
||||
==============
|
||||
Optimizations
|
||||
==============
|
||||
|
||||
Nonetheless, there are two important optimizations that the Triton compiler does not do automatically at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
|
||||
|
||||
-------------
|
||||
Pre-Fetching
|
||||
-------------
|
||||
|
||||
The purpose of pre-fetching is to overlap the update of the accumulator `c` with the memory loads for the next tiles that will need to be multiplied. This can be done by modifying the above reduction loop as follows:
|
||||
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// pre-fetch operands
|
||||
TYPE a[TM, TK] = *pa; //(9)
|
||||
TYPE b[TK, TN] = *pb; //(10)
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
c += a @ b;
|
||||
pa = pa + TK * 1;
|
||||
pb = pb + TK * ldb;
|
||||
// don't prefetch last iteration
|
||||
bool check = k > TK;
|
||||
// pre-fetch operands
|
||||
a = check ? *pa : 0;
|
||||
b = check ? *pb : 0;
|
||||
}
|
||||
|
||||
|
||||
Note that the Triton-C compiler will now also be able to use double-buffering techniques to make sure that the array `a` can be used and updated at the same time without any memory hazard.
|
||||
|
||||
-----------------
|
||||
Rematerialization
|
||||
-----------------
|
||||
|
||||
`Rematerialization <https://en.wikipedia.org/wiki/Rematerialization>`_ is a compiler optimization which consists in recomputing some values instead of storing and reloading them from (register) memory, so as to decrease register pressure in the compute kernel. Although LLVM does this automatically to some extent, it fails to find good heuristics for the above kernel -- thereby requiring some source code modification to achieve optimal performance. Fortunately, only :code:`rm` and :code:`rn` need to be rematerialized, leading to the following epilogue:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// epilogue
|
||||
int rcm[TM] = pm * TM + 0 ... TM;
|
||||
int rcn[TN] = pn * TN + 0 ... TN;
|
||||
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
|
||||
*pc = c;
|
||||
|
||||
|
||||
------------------------------------
|
||||
Fused Transpositions and Auto-Tuning
|
||||
------------------------------------
|
||||
|
||||
It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. Fortunately, this can be done by using pre-processors macros for tile shapes and broadcasting directives, leading to the following kernel:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// Triton-C
|
||||
// launched on a grid of (M / TM) x (N / TN) programs
|
||||
void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
int pm = get_program_id(0);
|
||||
int pn = get_program_id(1);
|
||||
int rm[TM] = pm * TM + 0 ... TM;
|
||||
int rn[TN] = pn * TN + 0 ... TN;
|
||||
int rk[TK] = 0 ... TK;
|
||||
float c[TM, TN] = 0;
|
||||
// pointers to operands
|
||||
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
|
||||
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
|
||||
// prefetches operands
|
||||
TYPE a[SHAPE_A] = (*pa);
|
||||
TYPE b[SHAPE_B] = (*pb);
|
||||
// reduction loop
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
c += USE_A @ USE_B;
|
||||
pa = pa + TK * STRIDE_AK;
|
||||
pb = pb + TK * STRIDE_BK;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
// epilogue
|
||||
int rcm[TM] = pm * TM + 0 ... TM;
|
||||
int rcn[TN] = pn * TN + 0 ... TN;
|
||||
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
|
||||
*pc = c;
|
||||
}
|
||||
|
||||
|
||||
All matrix multiplications variants can then be retrieved using the following compilation option:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// A is not transposed
|
||||
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
|
||||
-DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK
|
||||
// A is transposed
|
||||
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
|
||||
-DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM
|
||||
// B is not transpose
|
||||
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
|
||||
-DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
|
||||
// B is transpose
|
||||
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
|
||||
-DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
|
||||
|
||||
|
||||
Auto-tuning can also be handled using pre-processor macros:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
|
||||
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
|
||||
|
||||
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.
|
@@ -1,174 +0,0 @@
|
||||
*********************
|
||||
Matrix Transpositions
|
||||
*********************
|
||||
|
||||
|
||||
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
|
||||
|
||||
Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity interferes with auto-tuning.
|
||||
|
||||
==============
|
||||
Compute Kernel
|
||||
==============
|
||||
|
||||
In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory. Therefore, an optimal Triton kernel for this operation would look like:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
|
||||
__global__ void transpose(TYPE * X, TYPE * Y,
|
||||
int M, int N, int ldx, int ldy) {
|
||||
// extract program ID
|
||||
int pidm = get_program_id(0); //(1)
|
||||
int pidn = get_program_id(1); //(2)
|
||||
// create 1D range along the two matrix's axes
|
||||
int rm[TM] = pidm * TM + 0 ... TM; //(3)
|
||||
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
||||
// create 2D array of pointers
|
||||
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
|
||||
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
|
||||
// write back using the transposition operator '^'
|
||||
*py = ^(*px); //(7)
|
||||
}
|
||||
|
||||
At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:
|
||||
|
||||
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
pidm = 2
|
||||
pidn = 1
|
||||
|
||||
|
||||
- Statements (3) and (4) construct the ranges of indices:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)]
|
||||
rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]
|
||||
|
||||
|
||||
which will be used in statements (5) and (6) to construct tiles of pointers
|
||||
|
||||
- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
|
||||
|
||||
::
|
||||
|
||||
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
|
||||
│ ⋮ ⋮ │
|
||||
│ ⋮ ⋮ │
|
||||
│ X + (pidm*TM + TM - 1) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + TM - 1) + (pidn*TN + TN - 1)*ldx) │
|
||||
|
||||
|
||||
- Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics:
|
||||
|
||||
::
|
||||
|
||||
│ Y + (pidn*TN + 0) + (pidm*TM + 0)*ldy, ..., ..., Y + (pidn*TN + 0) + (pidm*TM + TM - 1)*ldy) │
|
||||
│ ⋮ ⋮ │
|
||||
│ ⋮ ⋮ │
|
||||
│ Y + (pidn*TN + TN - 1) + (pidn*TN + 0)*ldy, ..., ..., Y + (pidn*TN + TN - 1) + (pidm*TM + TM - 1)*ldy) │
|
||||
|
||||
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
|
||||
|
||||
|
||||
==================================
|
||||
A Note on Numpy-style Broadcasting
|
||||
==================================
|
||||
|
||||
The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example.
|
||||
|
||||
First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import numpy as np
|
||||
|
||||
vec = np.linspace(0,11,12)
|
||||
mat = vec.reshape((4,3))
|
||||
|
||||
Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will consider the top half, and tile 1 will consider the bottom).
|
||||
|
||||
::
|
||||
|
||||
[[ 0, 1, 2],
|
||||
[ 3, 4, 5],
|
||||
[ 6, 7, 8],
|
||||
[ 9, 10, 11]]
|
||||
|
||||
Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like for tile 0 to have the values:
|
||||
|
||||
::
|
||||
|
||||
[ 0, 1, 2],
|
||||
[ 3, 4, 5],
|
||||
|
||||
We construct ranges `rm` and `rn` as:
|
||||
::
|
||||
|
||||
rm = [0, 1]
|
||||
rn = [0, 1, 2]
|
||||
|
||||
Using numpy-style broadcasting, we can add these together to create a matrix:
|
||||
|
||||
::
|
||||
|
||||
rm[:, np.newaxis] + rn[np.newaxis, :]
|
||||
|
||||
rn -> [0, 1, 2]
|
||||
rm -> [0., [[0, 1, 2],
|
||||
1.] [1, 2, 3]]
|
||||
|
||||
The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index
|
||||
of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need
|
||||
to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension).
|
||||
In this case this is 3, so what we really need is:
|
||||
|
||||
::
|
||||
|
||||
ldx = 3
|
||||
px = rm[:, np.newaxis] * ldx + rn[np.newaxis,:]
|
||||
|
||||
`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example).
|
||||
|
||||
==========================
|
||||
The __multipleof attribute
|
||||
==========================
|
||||
|
||||
The memory loads and store in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of e.g., 4. This is unfortunate because tensor dimensions can be easily made into nice powers of two in Deep Learning, due to batch-sizes and layer width being flexible.
|
||||
|
||||
For this reason, Triton provides a __multipleof(N) attributes for variables that are guaranteed to always be multiple of N. In the case of Matrix Transpositions, vector loads can be enabled by modifying the function's signature as follows:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
|
||||
int ldx __multipleof(8),
|
||||
int ldy __multipleof(8)) {
|
||||
// ...
|
||||
}
|
||||
|
||||
|
||||
==========================
|
||||
Bounds Checking
|
||||
==========================
|
||||
|
||||
|
||||
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs
|
||||
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
|
||||
// ...
|
||||
// create bounds-checking mask
|
||||
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
|
||||
bool checky[TN, TM] = (rm[newaxis, :] < M) && (rn[:, newaxis] < N); //(7b)
|
||||
// conditional write-back using the conditional dereferencing operatior '*?()'
|
||||
*?(checky)py = ^(*?(checkx)px); //(7)
|
||||
}
|
||||
|
||||
|
||||
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
|
||||
|
||||
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_transpose.py>`_.
|
@@ -1,180 +0,0 @@
|
||||
====================================================
|
||||
Putting It All Together
|
||||
====================================================
|
||||
|
||||
In the previous tutorial, we saw how to write tensor-core-friendly matrix multiplication code competitive with cuBLAS in 20 lines of Triton code. Here, we will see how to wrap it into an automatically differentiable PyTorch functions for easy integration in your Deep Learning pipeline.
|
||||
|
||||
-----------------
|
||||
PyTriton Function
|
||||
-----------------
|
||||
|
||||
The PyTriton API provides a :code:`triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import triton
|
||||
|
||||
# Entry point
|
||||
class _dot(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
# Forward Pass
|
||||
def forward(ctx, *args):
|
||||
#...
|
||||
|
||||
@staticmethod
|
||||
# Backward Pass
|
||||
def backward(ctx, dy):
|
||||
#...
|
||||
|
||||
-----------------
|
||||
PyTriton Kernels
|
||||
-----------------
|
||||
|
||||
|
||||
PyTriton also provides a :code:`triton.kernel` class which automatically takes care of interaction with the Triton-JIT as well as the generation and compilation of C++ framework bindings code. For our dot operation we create a kernel from the Triton code shown at the end of the previous tutorial.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
src = """
|
||||
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
int pm = get_program_id(0);
|
||||
int pn = get_program_id(1);
|
||||
int rm[TM] = pm * TM + 0 ... TM;
|
||||
int rn[TN] = pn * TN + 0 ... TN;
|
||||
int rk[TK] = 0 ... TK;
|
||||
float c[TM, TN] = 0;
|
||||
// pointers to operands
|
||||
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
|
||||
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
|
||||
// prefetches operands
|
||||
TYPE a[SHAPE_A] = (*pa);
|
||||
TYPE b[SHAPE_B] = (*pb);
|
||||
// reduction loop
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
c += USE_A @ USE_B;
|
||||
pa = pa + TK * STRIDE_AK;
|
||||
pb = pb + TK * STRIDE_BK;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
// epilogue
|
||||
int rcm[TM] = pm * TM + 0 ... TM;
|
||||
int rcn[TN] = pn * TN + 0 ... TN;
|
||||
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
|
||||
*pc = c;
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = triton.kernel(src)
|
||||
|
||||
|
||||
At this point, `kernel` is a callable object which takes the same signature as the :code:`dot` function in our source code, except that pointers are treated as tensors: :code:`[tensor, tensor, tensor, int, int, int, int, int, int]`.
|
||||
|
||||
-----------------------
|
||||
Using PyTriton Kernels
|
||||
-----------------------
|
||||
|
||||
|
||||
However, in practice only A, B are provided by the user, and all the other :code:`int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the :code:`A` and :code:`B` tensors, and then returns the results of a call to :code:`kernel`:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b, transpose_a, transpose_b):
|
||||
# extract shapes
|
||||
shape_a = a.shape
|
||||
shape_b = b.shape
|
||||
M, Ka = shape_a[0], shape_a[1]
|
||||
Kb, N = shape_b[0], shape_b[1]
|
||||
# transpose shapes
|
||||
if transpose_a:
|
||||
M, Ka = Ka, M
|
||||
if transpose_b:
|
||||
Kb, N = N, Kb
|
||||
# contiguous dimensions
|
||||
lda = M if transpose_a else Ka
|
||||
ldb = Kb if transpose_b else N
|
||||
ldc = N
|
||||
# data-type
|
||||
dtype = a.dtype
|
||||
# allocate output
|
||||
c = triton.empty([M, N], dtype = dtype)
|
||||
# launch grid
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||
# pre-processor definitions
|
||||
defines = {# tile sizes
|
||||
'TYPE' : dtype,
|
||||
'AT' : transpose_a,
|
||||
'BT' : transpose_b,
|
||||
'TM' : [32, 64, 128],
|
||||
'TN' : [32, 64, 128],
|
||||
'TK' : [8],
|
||||
# handle A transposition
|
||||
'USE_A' : '^a' if transpose_a else 'a',
|
||||
'STRIDE_AK' : 'lda' if transpose_a else '1',
|
||||
'STRIDE_AM' : '1' if transpose_a else 'lda',
|
||||
'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
|
||||
'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
|
||||
'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK',
|
||||
# handle B transposition
|
||||
'USE_B' : '^b' if transpose_b else 'b',
|
||||
'STRIDE_BK' : '1' if transpose_b else 'ldb',
|
||||
'STRIDE_BN' : 'ldb' if transpose_b else '1',
|
||||
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
|
||||
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
|
||||
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
|
||||
return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
|
||||
grid=grid, num_warps=4, defines=defines)
|
||||
|
||||
|
||||
--------------------------------------------
|
||||
Automatic Differentiation
|
||||
--------------------------------------------
|
||||
|
||||
At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.
|
||||
|
||||
|
||||
Creating custom operations for Triton and PyTorch is very similar; programmers have to provide two static methods :code:`forward` and :code:`backward` that take a context as their first input:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.t_a = transpose_a
|
||||
ctx.t_b = transpose_b
|
||||
return _dot._call(a, b, transpose_a, transpose_b)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
a, b = ctx.saved_tensors
|
||||
t_a, t_b = ctx.t_a, ctx.t_b
|
||||
if not t_a and not t_b:
|
||||
da = _dot._call(dy, b, False, True)
|
||||
db = _dot._call(a, dy, True, False)
|
||||
elif not t_a and t_b:
|
||||
da = _dot._call(dy, b, False, False)
|
||||
db = _dot._call(dy, a, True, False)
|
||||
elif t_a and not t_b:
|
||||
da = _dot._call(b, dy, False, True)
|
||||
db = _dot._call(a, dy, False, False)
|
||||
elif t_a and t_b:
|
||||
da = _dot._call(b, dy, True, True)
|
||||
db = _dot._call(dy, a, True, True)
|
||||
else:
|
||||
assert False
|
||||
return da, db, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class.
|
||||
|
||||
.. code:: python
|
||||
|
||||
dot = _dot.apply
|
||||
|
||||
|
||||
And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~
|
@@ -1,109 +0,0 @@
|
||||
***************
|
||||
Triton vs. CUDA
|
||||
***************
|
||||
|
||||
|
||||
|
||||
The purpose of this tutorial is to explore in more depth the major differences between Triton and CUDA. To keep things simple, we will still be focusing on the following vector addition code:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// Triton
|
||||
// launch on a grid of (N + TILE - 1) / TILE programs
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
int offset[TILE] = get_program_id(0) * TILE + 0 ... TILE;
|
||||
bool check[TILE] = offset < N;
|
||||
float* pz[TILE] = z + offset;
|
||||
float* px[TILE] = x + offset;
|
||||
float* py[TILE] = y + offset;
|
||||
*?(check)pz = *?(check)*px + *?(check)py;
|
||||
}
|
||||
|
||||
And its CUDA equivalent:
|
||||
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// CUDA
|
||||
// launch on a grid of (N + TILE - 1) / TILE programs
|
||||
__global__ void add(float *z, float *x, float *y, int N) {
|
||||
int off = blockIdx.x * TILE + threadIdx.x;
|
||||
if(off < N){
|
||||
float *pz = z + off;
|
||||
float *px = x + off;
|
||||
float *py = y + off;
|
||||
*pz = *px + *py
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
==========================
|
||||
Automatic parallelization
|
||||
==========================
|
||||
|
||||
While the two above pieces of code may look at first sight similar, a closer look reveals one *fundamental* difference: While CUDA kernels are launched on a cooperative array of threads, **Triton kernel are single-threaded and automatically parallelized**.
|
||||
|
||||
This is a major difference in programming model, which not only makes your life much easier as a programmer, but also allows the Triton compiler to automatically do all sorts of nice optimizations:
|
||||
|
||||
- *Automatic shared memory allocation and synchronization*
|
||||
|
||||
That's right; programmers don't need to worry about shared memory allocation, usage and synchronization. Instead, the Triton compiler will use complex program analysis techniques to determine when shared memory should be used, where it should be synchronized and how threads should access it to avoid memory bank conflicts.
|
||||
|
||||
- *Automatic memory coalescing*
|
||||
|
||||
When you write Triton code, you also don't need to worry about memory coalescing. The compiler will arrange threads so that global memory accesses are coalesced when possible.
|
||||
|
||||
- *Automatic tensor core utilization*
|
||||
|
||||
Using tensor cores on Volta and Turing is notoriously difficult. Code is hard to write and even harder to optimize. Fortunately, the Triton compiler can also generate very efficient tensor core instructions (e.g., :code:`mma.sync.m8n8k4`) when low-precision matrices are multiplied together:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
half A[16, 8] = ... // initialize A
|
||||
half B[8, 16] = ... // initialize B
|
||||
float C[16, 16] = dot(A, B); // uses Tensor Cores!
|
||||
|
||||
|
||||
- *Automatic instruction predication*
|
||||
|
||||
Contrary to CUDA, Triton directly exposes predicated instruction through masked load/store instructions. This enables the Triton compiler to generate predicated instructions in PTX directly, resulting in sometimes better performance than I/O operations wrapped inside conditionals.
|
||||
|
||||
===========================
|
||||
Vector Addition - Revisited
|
||||
===========================
|
||||
|
||||
In light of these optimizations, it turns out that the GPU code generated by our Triton-C vector addition code is actually more analogous to the following:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// CUDA
|
||||
// launch on a grid of (N + TILE - 1) / TILE programs
|
||||
__global__ void add(float *z, float *x, float *y, int N) {
|
||||
int off[4];
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k++)
|
||||
off[k] = blockIdx.x * TILE + threadIdx.x + k * blockSize.x;
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k++)
|
||||
z[off[0]] = x[off[0]] + y[off[0]]
|
||||
}
|
||||
|
||||
This code is actually more complicated when x, y and z have :code:`half` type, because then the Triton compiler automatically vectorizes data accesses using :code:`half2` to ensure memory transactions of 32-bits per thread.
|
||||
|
||||
============================
|
||||
Auto-Tuning
|
||||
============================
|
||||
|
||||
Now assume that you want to tune the above code for different data types, tile sizes and thread block sizes. This is doable in CUDA but would require you to write cumbersome machinery to handle different vector sizes and loop unrolling factors. In Triton, this can be trivially done by adjusting some compilation parameters. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
kernel = triton.kernel(src, defines = {'TILE': [256, 512, 1024]}, num_warps = [2, 4, 8])
|
||||
|
||||
would benchmark our above triton source-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.
|
||||
|
||||
=============================
|
||||
Going Further
|
||||
=============================
|
||||
|
||||
The benefits of Triton become more and more pronounced as compute kernels get more and more complex. In the next few tutorials, you will see how to implement transposition and tensor-core-compatible matrix multiplication routine on par with cuBLAS and CUTLASS without having to know anything about GPU micro-architecture!
|
Reference in New Issue
Block a user