Merge branch 'master' into auto-coalesce

This commit is contained in:
Philippe Tillet
2019-10-18 16:21:28 -04:00
22 changed files with 685 additions and 139 deletions

View File

@@ -29,7 +29,7 @@ python dot.py
## Tutorials
- The PyTriton API (coming soon...)
- [The PyTriton API](https://github.com/ptillet/triton/blob/master/docs/pytriton.md)
- [The Triton-C language](https://github.com/ptillet/triton/blob/master/docs/triton-c.md)
- The Triton-IR representation (coming soon...)
- The Triton-JIT compiler (coming soon...)

View File

@@ -1,13 +1,23 @@
#The PyTriton API
# The PyTriton API
## <span style="color:darkred"> Table of Contents </span>
1. [Motivations](#motivations)
2. [Triton Functions](#pytriton-function)
1. [Creation of Triton Kernels](#creation-triton-kernels)
2. [Usage of Triton Kernels](#usage-triton-kernels)
3. [Integration with Automatic Differentiation](#autodiff)
1. [Basics](#autodiff-basics)
2. [Convenience](#autodiff-convenience)
## <span style="color:darkred"> Motivations </span> <a name="motivations"></a>
In this tutorial we assume some basic knowledge of Triton-C, so check out the corresponding [tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md) if you have not already!
The purpose of PyTriton is to provide an API for integrating Triton-C kernels into PyTorch and Tensorflow. The good thing about PyTriton is that it is framework agnostic, in the sense that any custom op written using this API will be transparently compatible with both Tensorflow and PyTorch without any additional effort required. Consider for example the following piece of code:
The purpose of PyTriton is to provide an API for easily executing Triton-C kernels from PyTorch and Tensorflow. One of the main advantages of PyTriton is that it is framework agnostic: any custom op written using this API will be transparently compatible with both Tensorflow and PyTorch without any additional effort required, as will be shown in this tutorial.
Consider for example the following piece of code:
```python
import numpy as np
@@ -46,25 +56,44 @@ def run_torch():
# run_torch()
```
Here, the triton module detects which frameworks are imported when executiong a `triton.op` for the first time, and generates the appropriate framework bindings code accordingly. Specifically, when a Triton custom op is executed for the first time, the following chain of events takes place:
- The imported frameworks are detected
- The C++ code for a Tensorflow or PyTorch generic custom operation -- with the same signature as the provided Triton-C kernel -- is generated, compiled and cached
- The Tensorflow or PyTorch op is dynamically loaded using the generated .so file, and a framework-agnostic wrapper is returned
- The wrapper is called and a tf.tensor or a torch.tensor is returned. In the case of Tensorflow, the gradient is also registered at this point if applicable
PyTriton works by detecting which frameworks are imported and automatically generating and just-in-time compiling C++ binding code for them. Specifically, the following chain of events is triggered when a Triton operation is executed:
1. The imported frameworks are detected
2. C++ binding code for Tensorflow or PyTorch is generated, compiled and cached.
3. The corresponding custom-op is automatically loaded from the generated .so file, and a framework-agnostic wrapper is created.
4. The wrapper is called and a tf.tensor or a torch.tensor is returned. In the case of Tensorflow, the gradient is also registered at this point if applicable
## <span style="color:darkred"> Writing your own custom operation </span> <a name="custom-operation"></a>
The remainder of this tutorial will show you how to re-implement the above `triton.ops.dot` operation from scratch.
In this section we will reimplement the above `dot` function, whose full source-code can be found [here](https://github.com/ptillet/triton/blob/master/python/triton/ops/dot.py).
## <span style="color:darkred"> PyTriton Functions </span> <a name="pytriton-function"></a>
The PyTriton API provides a `triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class
The first thing to do to create a custom op is to declare a class which inherits from `triton.function`.
```python
import triton
# Entry point
class _dot(triton.function):
src = """
@staticmethod
# Forward Pass
def forward(ctx, *args):
#...
@staticmethod
# Backward Pass
def backward(ctx, dy):
#...
```
### <span style="color:darkblue">Creation of Triton Kernels </span> <a name="creation-triton-kernel"></a>
PyTriton also provides a `triton.kernel` class which automatically takes care of interaction with the Triton-JIT as well as the generation and compilation of C++ framework bindings code. For our dot operation we create a kernel from the Triton-C code derived at the end of the [previous tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md)
```
src = """
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
@@ -102,9 +131,16 @@ __global__ void dot(TYPE * A, TYPE * B, TYPE * C,
kernel = triton.kernel(src, ['C'])
```
Here, `src` is the exact Triton-C source-code generated at the end of the aforementioned [tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md) , and `kernel = triton.kernel(src, ['C'])` creates a triton kernel from this source code which returns the tensor whose data points to `C`. At this point, `kernel` is a callable object which takes the same signature as the `dot` function in our source code, except that pointers are treated as tensors: `[tensor, tensor, tensor, int, int, int, int, int, int]`.
Note that the second argument to `triton.kernel` constructors indicates which of the operands our kernel function should return. Here, we only return `C`.
However, in practice only A, B and C are provided by the user, and all the other `int` arguments are deduced from them, hence we create a helper function that extracts shapes from the `A`, `B` and `C` tensor and calls ouer `kernel`:
At this point, `kernel` is a callable object which takes the same signature as the `dot` function in our source code, except that pointers are treated as tensors:
```
[tensor, tensor, tensor, int, int, int, int, int, int]
```
### <span style="color:darkblue">Usage of Triton Kernels </span> <a name="usage-triton-kernels"></a>
However, in practice only A, B are provided by the user, and all the other `int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the `A` and `B` tensors, and then returns the results of a call to `kernel`:
```python
@staticmethod
@@ -150,13 +186,22 @@ However, in practice only A, B and C are provided by the user, and all the other
```
There are a few things to note here:
While this code should be mostly self-explanatory, there are a few of noteworthy things worth pointing out
- `triton.shape` provides a framework-agnostic way to retrieve the shape of a tensor
- `triton.empty` creates an empty tensor of the specified dimensions
- `grid` corresponds to the grid with which our Triton kernel will be launched. Because in our case this grid depends on parametric tile variables, it is supplied as a function of compilation options `opt`, whose compile-time definition can be retrieved using `opt.d(name)`. Here, `opt.d('TM')` and `opt.d('TN')` retrieve the first and second tile dimension our kernel was compiled with. We also provide a helper `triton.cdiv` for ceil divisions.
- `macros` provides a list of preprocessor definitions to compile the kernel with. Alternatively, these can also be supplied as named argument to the `_dot.kernel`. We recall that lists can be supplied to the preprocessor, in which case an auto-tuning procedure will be triggered. Here, the value of `TM` and `TN` are both tuned between 32, 64 and 128.
## <span style="color:darkred"> Compatibility with Automatic Differentiation</span> <a name="autodiff"></a>
At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.
### <span style="color:darkblue"> Basics </span> <a name="autodiff-basics"></a>
PyTriton binds to Tensorflow's and PyTorch's automatic differentiation framework using a single, common API inspired by PyTorch. It consists of two static methods `forward` and `backward` that take a context as their first input:
```
@@ -188,9 +233,11 @@ PyTriton binds to Tensorflow's and PyTorch's automatic differentiation framework
return da, db, None, None, None, None, None, None, None
```
### <span style="color:darkblue">Convenience </span> <a name="autodiff-convenience"></a>
Still like for PyTorch, a callable operation can be created using the `apply` method of our `triton.function` class. We wrap it as a module variable for convenience:
```python
dot = _dot.apply
```
And that's it! Our custom op is now created and ready to be used with both PyTorch and Tensorflow.
And that's it! Our custom op is now created and ready to be used with both PyTorch and Tensorflow.

View File

@@ -2,8 +2,6 @@
## <span style="color:darkred"> Table of Contents </span>
1. [Motivations](#motivations)
1. [Issues of C/C++ for Linear Algebra](#issues-c-c++)
2. [Design Philosophy of Triton-C](#design-philosophy)
2. [Vector Addition](#vector-addition)
1. [Differences with CUDA](#differences-with-cuda)
2. [Advantages over CUDA](#advantages-over-cuda)
@@ -12,56 +10,54 @@
3. [Auto-Tuning](#auto-tuning)
3. [Matrix Transposition](#matrix-transposition)
1. [Compute Kernel](#trans-compute-kernel)
2. [Conditional Dereferencing](#conditional-dereferencing)
2. [The __multipleof Attribute](#trans-multipleof)
3. [Conditional Dereferencing](#conditional-dereferencing)
4. [Matrix Multiplication](#matrix-multiplication)
1. [Compute Kernel](#matmul-compute-kernel)
2. [Optimizations](#optimizations)
1. [Pre-Fetching](#pre-fetching)
1. [Rematerialization](#rematerialization)
3. [Fused Transpositions](#fused-trans)
3. [Fused Transpositions and Auto-Tuning](#fused-trans-autotuning)
## <span style="color:darkred"> Motivations </span> <a name="motivations"></a>
## <span style="color:darkblue"> Issues of C/C++ for Linear Algebra </span> <a name="issues-c-c++"></a>
In C and C++, arrays and pointers have similar semantics. Indeed, there is no way to manipulate statically shaped multi-dimensional arrays (beyond initialization) as a whole without resorting to third-party libraries:
In C and C++, arrays and pointers have similar semantics. Indeed, there is no native way to manipulate statically shaped multi-dimensional arrays (beyond initialization) as a whole:
```c
// C99
float x[16][8] = {3.14};
float y[16][8] = {5.17};
// z = x + y
float z[16][8];
for(int i = 0; i < 16; i ++)
#pragma unroll
for(int i = 0; i < 16; i++)
#pragma unroll
for(int j = 0; j < 8; j++)
z[i][j] = x[i][j] + y[i][j];
```
This issue can be somewhat mitigated using templates metaprogramming in C++:
While it does not seem like a big deal at first sight, there are two issues with this:
- **Ergonomics**: Of course, it is possible to simplify the above code using functions in C
```
float z[16][8];
add(z, x, y, 16, 8);
```
but this would be semantically different as the loops can no longer be unrolled due to their bounds being now dynamic arguments of the add function. This can be mitigated using templates metaprogramming (and operator overloads) in C++:
```c
// C++
template<typename T, int M, int N>
class matrix;
matrix<float, 16, 8> x = {3.14};
matrix<float, 16, 8> y = {5.17};
matrix<float, 16, 8> z = x + y;
```
This is better, but there are still some important issues with this approach:
- The syntax could be better, especially when it comes to broadcasting and reshaping.
- Data-flow information for array operations does not propagate beyond the program's AST, thereby making it difficult for compilers to optimize moderately complicated array programs (i.e., Matrix-Multiplication). This can be worked around using heavy metaprogramming techniques (see [CUTLASS](https://github.com/NVIDIA/cutlass)), but even then programmers still have to allocate and synchronize shared memory manually and endure prohibitively long compilation procedures not easily amenable to auto-tuning.
For these reasons, most Deep-Learning frameworks still rely heavily on highly optimized subroutines (e.g., BLAS), which makes the development of novel custom primitives time-consuming for experts and challenging for others. This is where Triton comes into play.
## <span style="color:darkblue"> Design Philosophy of Triton-C </span> <a name="design-philosophy"></a>
The purpose of Triton is to bring native support for efficient numerical multi-dimensional array operations into a standard procedural languages. We achieve this through:
* **Triton-C**: Syntactic and semantical extensions to the C language. In particular, native support for reshaping, broadcasting, matrix-multiplication, transposition, etc. This is the object of this tutorial.
* **Triton-IR**: An LLVM-like IR for array operations, as well as various (automatic memory coalescing, automatic vectorization, shared memory allocation/synchronization, tensor core instruction selection, etc.). Although our system generates Triton-IR programs from Triton-C source-code, this is beyond the scope of this tutorial. More information can be found [here](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf).
Anyway, the Triton-C code corresponding to the above matrix addition operation can be written and extended as follows:
While this is better and now equivalent to our initial code snippet, the syntax is not quite as ergonomically satisfying as what native syntactic support could provide:
```c
// Triton-C
float x[16, 8] = 3.14;
float y[16, 8] = 5.17;
// float z[8, 8] = x + y; // doesn't compile -- incompatible shapes!
@@ -69,10 +65,29 @@ float z[16, 8] = x + y;
float u[16] = z[:, +]; // sum along the second axis
float v[16, 32] = u[:, newaxis]; // broadcasting along the second axis
```
which is valid _Triton-C_.
- **Portability**: One other issue with our initial C program is that it is not portable. While it will run well on a single CPU thread, the operation `z = x + y` would underutilize a GPU Streaming Processor as it would execute on a single thread only. For this reason, it would have to be rewritten in CUDA as follows:
```
// CUDA
// Launch on a block of 16 x 8 threads
float x = 3.14;
float y = 5.17;
float z = x + y
```
In Triton-C, the same code can be used across many different platforms (only CPUs and GPUs are supported at the moment). Furthermore, Triton-C is single-threaded, hence easier to write than CUDA.
- **Performance**: Another issue with our initial C code snippet is its performance. Although the loops are unrolled, the program does not carry any data-flow information pertaining to array operations. This issue gets more and more problematic as programs get increasingly complex, eventually culminating in matrix multiplication being remarkably hard to optimize.
This can be worked around using heavy metaprogramming techniques (see [CUTLASS](https://github.com/NVIDIA/cutlass)), but even then programmers still have to allocate and synchronize shared memory manually and endure prohibitively long compilation procedures not easily amenable to auto-tuning. For these reasons, most Deep-Learning frameworks still rely heavily on highly optimized subroutines (e.g., BLAS), which makes the development of novel custom primitives time-consuming for experts and almost impossible for others.
Triton addresses this issue by relying on **Triton-IR**, an LLVM-like IR for array operations, and **Triton-JIT**, an optimizing compiler for Triton-IR. These two systems are, however, beyond the scope of this tutorial. More information can be found [here](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf).
_Note: You might be thinking that this is exactly what [MLIR](https://github.com/tensorflow/mlir) was made for... and you're right! You can conceptually think of Triton-IR as a dialect for MLIR, and Triton-C as a frontend for it. I would like to integrate Triton-IR into MLIR in the future; If you're interested in making this a thing, let me know._
Of course, we can do much more than additions, reduction and broadcasting. The purpose of this tutorial is to walk you through all the features of Triton-C, and eventually show you how it can be used to build auto-tuned matrix-multiplication kernels on par with state-of-the-art CUDA-C implementation in less than an afternoon.
_Note: You might be thinking that this is exactly what [MLIR](https://github.com/tensorflow/mlir) was made for... and you're right! You can think of Triton-IR as a dialect for MLIR, and Triton-C as a frontend for it. If you're interested in making this a thing, let me know._
## <span style="color:darkred"> Vector Addition </span> <a name="vector-addition"></a>
@@ -81,16 +96,18 @@ _Note: You might be thinking that this is exactly what [MLIR](https://github.com
Let's start it off by looking at a simple example. Vector addition, in its most trivial Triton-C implementation, can be written as follows:
```c
// Triton-C
// launched on a grid of (N / 32) programs of 1 thread each
__global__ void add(int N, float *a, float *b, float* c) {
int id = get_program_id(0);
int off[32] = id * 32 + (0 ... 32)
*(c + off) = *(a + off) + *(b + off)
*(c + off) = *(a + off) + *(b + off);
}
```
For reference, here is an equivalent CUDA kernel (nvcc will generate the same PTX code as triton-jit on the above code):
For reference, here is an equivalent CUDA kernel (NVCC will generate the same PTX code as Triton-JIT on the above code):
```c
// CUDA
// launched on a grid of (N / 32) programs of 32 threads each
__global__ void add(int N, float *a, float *b, float *c) {
int off = blockIdx.x * 32 + threadIdx.x;
@@ -98,10 +115,10 @@ __global__ void add(int N, float *a, float *b, float *c) {
}
```
As you can see, there are three main differences between our Triton-C kernel and the equivalent CUDA-C:
As you can see, there are three main differences between our Triton-C kernel and the equivalent CUDA:
- **The programming model is different**.
While Triton-C and CUDA-C both use a Single-Program, Multiple-Data (SPMD) programming model, each Triton-C kernel is single-threaded.
While Triton-C and CUDA both use a Single-Program, Multiple-Data (SPMD) programming model, each Triton-C kernel is single-threaded.
Therefore, `get_program_id({0, 1, 2})` is equivalent to `blockIdx.{x, y, z}`, but there is no such thing as `blockDim` and `threadIdx`.
- **The semantics of arrays is different**
@@ -113,7 +130,7 @@ float* c_ptr[32] = c_broadcast + off; // c_ptr = c + off
```
- **The semantics of the subscript operator is different**.
n C/CUDA-C, subscripting can be used to offset and dereference a pointer, but in Triton-C it can only be used to index and broadcast an array (think NumPy).
n C/CUDA, subscripting can be used to offset and dereference a pointer, but in Triton-C it can only be used to index and broadcast an array (think NumPy).
### <span style="color:darkblue"> Advantages over CUDA </span> <a name="advantages-over-cuda"></a>
@@ -123,6 +140,7 @@ At this point, the advantages of Triton-C over CUDA may not be obvious. But they
On some hardware architectures, vectorizing load/store operations can lead to better memory utilization and, in turn, noticeable performance gains. In general, 128-bit memory transactions are favored, leading to the following CUDA kernel:
```c
// CUDA
// launched on a grid of (N / 128) programs of 32 threads each
__global__ void add(int N, float4 *a, float4 *b, float4 *c) {
int off = blockIdx.x * 32 + threadIdx.x;
@@ -131,6 +149,7 @@ __global__ void add(int N, float4 *a, float4 *b, float4 *c) {
```
Or, for half-precision inputs:
```c
// CUDA
// launched on a grid of (N / 256) programs of 32 threads each
__global__ void add(int N, half8 *a, half8 *b, half8 *c) {
int off = blockIdx.x * 32 + threadIdx.x;
@@ -146,11 +165,12 @@ In Triton-C, this is not a problem as the compiler will figure out automatically
Specifically, the Triton compiler would refuse to 4-way vectorize our above compute kernel because it would require the array `int off[32]` to be distributed over 8 threads, which is less than a warp. Fortunately, it turns out that this problem can be easily solved using preprocessor directrives to _parameterize_ our kernel:
```c
// Triton-C
// launched on a grid of (N / SIZE) programs of 1 thread each
__global__ void add(int N, TYPE* a, TYPE* b, TYPE* c) {
int id = get_program_id(0);
int off[SIZE] = id * SIZE + (0 ... SIZE)
*(c + off) = *(a + off) + *(b + off)
int off[SIZE] = id * SIZE + (0 ... SIZE);
*(c + off) = *(a + off) + *(b + off);
}
// Not vectorized when compiled with -DSIZE=32 -DTYPE=float
// 4-Vectorized when compiled with -DSIZE=128 -DTYPE=float
@@ -169,11 +189,11 @@ _Note: Tuning our reference CUDA kernel would be much more cumbersome, as templa
## <span style="color:darkred"> Matrix Transposition </span> <a name="matrix-transposition"></a>
Transpositions are (relatively) hard to efficiently write in CUDA because a naive implementation would lead to _uncoalesced_ memory operations when writing back the transposed matrix to DRAM. Therefore, optimized CUDA implementations require the explicit use of shared memory, as shown [here](https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/).
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from _uncoalesced_ memory operations when writing back the transposed matrix to DRAM. Of course, this can be fixed by using shared memory as shown [here](https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/), but this comes at the cost of simplicity and -- more importantly -- interferes with auto-tuning.
### <span style="color:darkblue"> Compute Kernel </span> <a name="trans-compute-kernel"></a>
In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory. Therefore, an optimal Triton kernel for this operation would look like:
In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory in order to enable shared memory stores/loads. Therefore, an optimal Triton kernel for this operation would look like:
```c
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
@@ -192,20 +212,22 @@ __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
}
```
This kernel loads a `TM x TN` tile from the input matrix `X`, transposes it and write the resulting `TN x TM` tile to the output matrix `Y`. As a result, transposition of the full input matrix is achieved by launching a grid of `(M / TM) x (N / TN)` programs decomposed as follows:
At a high level, this kernel loads a `TM x TN` tile from the input matrix `X`, transposes it and writes the resulting `TN x TM` tile to the output matrix `Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of `(M / TM) x (N / TN)` programs decomposed as follows:
- Statements (1) and (2) extract the location of the program in the grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` will hold the values:
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
```
pidm = 2
pidn = 1
```
- Statements (3) and (4) construct the ranges of indices to read from the first and second axis of X:
- Statements (3) and (4) construct the ranges of indices:
```
rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)]
rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]
```
which will be used in statements (5) and (6) to construct tiles of pointers
- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
```
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
@@ -222,10 +244,22 @@ rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]
```
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
### <span style="color:darkblue"> The __multipleof Attribute </span> <a name="trans-multipleof"></a>
The memory loads and store in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of e.g., 4. This is unfortunate because tensor dimensions can be easily made into nice powers of two in Deep Learning, due to batch-sizes and layer width being flexible.
For this reason, Triton provides a __multipleof(N) attributes for variables that are guaranteed to always be multiple of N. In the case of Matrix Transpositions, vector loads can be enabled by modifying the function's signature as follows:
```c
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx __multipleof(8), int ldy __multipleof(8)) {
// ...
}
```
### <span style="color:darkblue"> Conditional Dereferencing </span> <a name="conditional-dereferencing"></a>
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
```
```c
// launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
// ...
@@ -237,21 +271,22 @@ __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
}
```
Here, statements (7a) creates an array of booleans `checkx[TM, TN]` such that `checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Then, both `px` and `py` can be conditionally dereferenced using Triton-C's conditional dereferencing operator `*?(predicate) pointer`.
Here, statements (7a) creates an array of booleans `checkx[TM, TN]` such that `checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator `*?(predicate) pointer`.
## <span style="color:darkred"> Matrix Multiplication </span> <a name="matrix-multiplication"></a>
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA-C kernels (see [CUTLASS](https://github.com/NVIDIA/cutlass)). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see [CUTLASS](https://github.com/NVIDIA/cutlass)). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
_Note: Bounds-checking is ommitted for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops._
_Note: Bounds-checking is ommitted throughout for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops._
### <span style="color:darkblue"> Compute Kernel </span> <a name="matmul-compute-kernel"></a>
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
```c
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
__global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
@@ -306,7 +341,7 @@ Internally, the Triton compiler will perform quite a few optimizations that will
### <span style="color:darkblue"> Optimizations </span> <a name="optimizations"></a>
Nonetheless, there are two important optimizations that the Triton compiler does not do at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
Nonetheless, there are two important optimizations that the Triton compiler does not do automatically at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
#### <span style="color:purple"> Pre-Fetching </span> <a name="pre-fetching"></a>
@@ -342,11 +377,13 @@ TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
```
### <span style="color:darkblue"> Fused Transpositions </span> <a name="fused-trans"></a>
### <span style="color:darkblue"> Fused Transpositions and Auto-Tuning </span> <a name="fused-trans-autotuning"></a>
It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. This is also what is done in the [PyTriton](https://github.com/ptillet/triton/blob/master/python/triton/ops/dot.py) implementation of matrix-multiplication. Fortunately, this can be done by using pre-processors macros for tile shapes and broadcasting directives, leading to the following kernel:
```c
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
@@ -391,4 +428,9 @@ All matrix multiplications variants can then be retrieved using the following co
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb -DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
```
Auto-tuning can also be handled using pre-processor macros:
```c
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
```

View File

@@ -137,7 +137,7 @@ public:
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
value *create_sqrt(value *A, const std::string &name = "");
value *create_reduce(value *A, unsigned axis, const std::string &name = "");
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");

View File

@@ -667,21 +667,30 @@ public:
};
class reduce_inst: public builtin_inst {
private:
static type* get_res_type(value *arg, unsigned axis);
public:
enum op_t{
ADD, SUB, MAX, MIN,
FADD, FSUB, FMAX, FMIN
};
private:
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
static type* get_res_type(value *arg, unsigned axis);
static std::string to_str(op_t op);
private:
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
std::string repr_impl() const { return "reduce"; }
_TRITON_DEFINE_CLONE(reduce_inst)
_TRITON_DEFINE_ACCEPT(reduce_inst)
public:
static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
private:
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {

View File

@@ -418,22 +418,25 @@ class UnaryOp : public Expr {
friend class LValAssigner;
public:
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
virtual ~UnaryOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal();
::Type *Convert();
static int encodeRed(int ax, int tag);
static void decodeRed(int info, int& ax, int& tag);
void TypeChecking();
void IncDecOpTypeChecking();
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void ReduceOpTypeChecking();
void TransOpTypeChecking();
void UnaryArithmOpTypeChecking();
void CastOpTypeChecking();
protected:
UnaryOp(int op, Expr* operand, QualType type=nullptr)
: Expr(operand->Tok(), type), op_(op) {
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
: Expr(operand->Tok(), type), op_(op), info_(info) {
operand_ = operand;
if (op_ != Token::CAST && op_ != Token::ADDR) {
operand_ = MayCast(operand);
@@ -441,6 +444,7 @@ protected:
}
int op_;
int info_;
Expr* operand_;
};

View File

@@ -131,6 +131,8 @@ public:
// TILE ARITHMETICS BEGIN
NEWAXIS,
MAX,
MIN,
// TILE ARITHMETICS END
ALIGNAS, // _Alignas
@@ -142,7 +144,7 @@ public:
STATIC,
THREAD, // _Thread_local
AUTO,
REGISTER,
GLOBAL,
// STORAGE CLASS SPECIFIER END
BREAK,
@@ -180,6 +182,7 @@ public:
PLUS,
MINUS,
CAST,
REDUCE,
// For preprocessor
PP_IF,
@@ -236,7 +239,7 @@ public:
bool IsIdentifier() const { return IDENTIFIER == tag_; }
bool IsEOF() const { return tag_ == Token::END; }
bool IsTypeSpecQual() const { return CONST <= tag_ && tag_ <= ENUM; }
bool IsDecl() const { return CONST <= tag_ && tag_ <= REGISTER; }
bool IsDecl() const { return CONST <= tag_ && tag_ <= GLOBAL; }
static const char* Lexeme(int tag) {
auto iter = tagLexemeMap_.find(tag);
if (iter == tagLexemeMap_.end())

View File

@@ -40,7 +40,7 @@ enum {
S_STATIC = 0x04,
S_THREAD = 0x08,
S_AUTO = 0x10,
S_REGISTER = 0x20,
S_GLOBAL = 0x20,
// Type specifier
T_SIGNED = 0x40,

View File

@@ -59,7 +59,7 @@ public:
struct options_space_t {
typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines;
std::vector<size_t> num_warps;
std::vector<int> num_warps;
};
struct options_t {

View File

@@ -3,6 +3,7 @@
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h"
@@ -748,8 +749,97 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) {
});
}
void generator::visit_reduce_inst(ir::reduce_inst*) {
void generator::visit_reduce_inst(ir::reduce_inst* x) {
throw std::runtime_error("not implemented");
// std::map<indices_t, Value*> partial;
// ir::value *arg = x->get_operand(0);
// distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
// ir::reduce_inst::op_t op = x->get_op();
// auto accumulate = [&](Value* x, Value *y) -> Value* {
// switch(op) {
// case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
// case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
// case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y);
// case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y);
// case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
// case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
// case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y);
// case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y);
// default: break;
// }
// assert(false);
// return nullptr;
// };
// unsigned axis = x->get_axis();
// // reduce within thread
// arg_tile->for_each([&](indices_t idx) {
// indices_t pidx = idx;
// pidx[axis] = builder_->getInt32(0);
// Value *current = arg_tile->get_value(idx);
// // current partial result is not initialized -- create
// if(partial.find(pidx) == partial.end())
// partial[pidx] = current;
// // current partial result is initialized -- accumulate
// else
// partial[pidx] = accumulate(partial[pidx], current);
// });
// // depth
// unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
// unsigned per_thread = arg_tile->axis(axis).values.size();
// unsigned depth = shape_ax / per_thread;
// // shapes
// auto shared_shapes = arg_tile->get_shapes();
// shared_shapes[axis] = depth;
// // reduce within blocks
// unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
// Type *res_ty = builder_->getFloatTy();
// Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
// for(auto& x: partial) {
// // current element being computed
// Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
// Value *&result = x.second;
// indices_t write_idx = x.first;
// write_idx[axis] = lane;
// // shared memory write pointer
// Value *write_offset = shared_tile::shared_offset(*builder_, shared_shapes, write_idx);
// Value *write_ptr = builder_->CreateGEP(base_ptr, write_offset);
// // initialize shared memory
// tgt_->add_barrier(*mod_, *builder_);
// builder_->CreateStore(result, write_ptr);
// // build result
// for(unsigned i = depth/2; i > 0; i >>= 1){
// // current indices
// indices_t current(write_idx.size(), builder_->getInt32(0));
// current[axis] = builder_->getInt32(i);
// // shared memory offset
// Value *read_offset = shared_tile::shared_offset(*builder_, shared_shapes, current);
// Value *is_active = builder_->CreateICmpULT(lane, builder_->getInt32(i));
// read_offset = builder_->CreateSelect(is_active, read_offset, builder_->getInt32(0));
// // shared memory read pointer
// Value *read_ptr = builder_->CreateGEP(write_ptr, read_offset);
// tgt_->add_barrier(*mod_, *builder_);
// Value *next = builder_->CreateLoad(read_ptr);
// // accumulate
// result = accumulate(result, next);
// // write back
// builder_->CreateStore(result, write_ptr);
// }
// }
// tgt_->add_barrier(*mod_, *builder_);
// distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
// x_tile->for_each([&](indices_t idx) {
// indices_t red_idx = idx;
// red_idx.insert(red_idx.begin() + axis, builder_->getInt32(0));
// Value *read_offset = shared_tile::shared_offset(*builder_, shared_shapes, red_idx);
// Value *read_ptr = builder_->CreateGEP(base_ptr, read_offset);
// x_tile->set_value(idx, builder_->CreateLoad(read_ptr));
// });
}
void generator::visit_select_inst(ir::select_inst* select) {

View File

@@ -330,8 +330,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
return insert(sqrt_inst::create(A, name));
}
value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
return insert(reduce_inst::create(A, axis, name));
value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) {
return insert(reduce_inst::create(A, op, axis, name));
}
value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){

View File

@@ -630,6 +630,23 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
//===----------------------------------------------------------------------===//
// reduce instructions
//===----------------------------------------------------------------------===//
std::string reduce_inst::to_str(op_t op) {
switch (op) {
case ADD: return "+";
case SUB: return "-";
case MAX: return "imax";
case MIN: return "imin";
case FADD: return "+";
case FSUB: return "-";
case FMAX: return "fmax";
case FMIN: return "fmin";
default: break;
}
assert(false);
return "";
}
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
shapes.erase(shapes.begin() + axis);
@@ -640,14 +657,15 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) {
return tile_type::get(scalar_ty, shapes);
}
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next),
op_(op),
axis_(axis){
set_operand(0, arg);
}
instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
return new reduce_inst(arg, axis, name, next);
instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
return new reduce_inst(arg, op, axis, name, next);
}

View File

@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
}
void BinaryOp::MaskedDerefOpTypeChecking() {
// auto lhsTileType = lhs_->Type()->ToTile();
// auto rhsTileType = rhs_->Type()->ToTile();
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsType = lhsScalType->ToArithm();
@@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() {
* Unary Operators
*/
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
ret->pool_ = &unaryOpPool;
ret->TypeChecking();
@@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
}
int UnaryOp::encodeRed(int ax, int tag) {
int result = 0;
result |= ax;
result |= tag << 16;
return result;
}
void UnaryOp::decodeRed(int info, int& ax, int& tag) {
ax = info & 0x0000FFFF;
tag = (info & 0xFFFF0000) >> 16;
}
bool UnaryOp::IsLVal() {
// Only deref('*') could be lvalue;
return op_ == Token::DEREF;
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
case '^':
return TransOpTypeChecking();
case Token::REDUCE:
return ReduceOpTypeChecking();
default:
assert(false);
}
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
}
void UnaryOp::ReduceOpTypeChecking() {
int ax, tag;
decodeRed(info_, ax, tag);
auto tileType = operand_->Type()->ToTile();
if(!tileType)
Error(this, "array expected for reduction operation");
auto shape = tileType->Shape();
shape.erase(shape.begin() + ax);
type_ = TileType::New(shape, tileType->Derived());
}
void UnaryOp::TransOpTypeChecking() {
auto tileType = operand_->Type()->ToTile();

View File

@@ -154,12 +154,24 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
error_not_implemented();
}
ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
using ir::reduce_inst;
switch(tag){
case Token::ADD: return is_float ? reduce_inst::FADD : reduce_inst::ADD;
case Token::SUB: return is_float ? reduce_inst::FSUB : reduce_inst::SUB;
case Token::MAX: return is_float ? reduce_inst::FMAX : reduce_inst::MAX;
case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
default: break;
}
should_not_happen();
return reduce_inst::op_t();
}
void Generator::VisitUnaryOp(UnaryOp* unary) {
// recursion
Visit(unary->operand_);
ir::value* op = ret_;
ir::value* arg = ret_;
ir::type *arg_ty = arg->get_type();
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
// return
switch (unary->op_) {
case Token::PREFIX_INC: return error_not_implemented();
@@ -167,13 +179,20 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
case Token::POSTFIX_INC: return error_not_implemented();
case Token::POSTFIX_DEC: return error_not_implemented();
case Token::ADDR: return error_not_implemented();
case Token::DEREF: return set_ret(bld_->create_load(op));
case Token::DEREF: return set_ret(bld_->create_load(arg));
case Token::PLUS: return error_not_implemented();
case Token::MINUS: return error_not_implemented();
case '~': return set_ret(bld_->create_neg(op));
case '!': return set_ret(bld_->create_not(op));
case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
case '^': return set_ret(bld_->create_trans(op));
case '~': return set_ret(bld_->create_neg(arg));
case '!': return set_ret(bld_->create_not(arg));
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case '^': return set_ret(bld_->create_trans(arg));
case Token::REDUCE: {
int ax, tag;
UnaryOp::decodeRed(unary->info_, ax, tag);
bool is_float = arg_scal_ty->is_floating_point_ty();
ir::reduce_inst::op_t op = reduce_op(tag, is_float);
return set_ret(bld_->create_reduce(arg, op, ax));
}
default: error_not_implemented();
}
return error_not_implemented();
@@ -412,16 +431,41 @@ void Generator::Gen(ir::module *mod) {
ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
if(src->get_type() == dst_ty)
return src;
if(dst_ty->is_tile_ty()) {
ir::type *src_ty = src->get_type();
auto dst_shapes = dst_ty->get_tile_shapes();
if(!src_ty->is_tile_ty())
return bld_->create_splat(src, dst_shapes);
auto src_shapes = src_ty->get_tile_shapes();
if(src_shapes.size() != dst_shapes.size())
return bld_->create_reshape(src, dst_shapes);
else
if(src_shapes.size() != dst_shapes.size()){
unsigned src_numel = 1;
for(unsigned s: src_shapes)
src_numel *= s;
unsigned dst_numel = 1;
for(unsigned s: dst_shapes)
dst_numel *= s;
if(src_numel == dst_numel)
return bld_->create_reshape(src, dst_shapes);
else {
auto padded_shapes = src_shapes;
while(padded_shapes.size() != dst_shapes.size())
padded_shapes.insert(padded_shapes.begin(), 1);
// check that broadcast is legal
for(size_t d = 0; d < padded_shapes.size(); d++){
if(dst_shapes[d] != padded_shapes[d] &&
padded_shapes[d] != 1)
should_not_happen();
}
// pad and broadcast
ir::value *padded = bld_->create_reshape(src, padded_shapes);
return bld_->create_broadcast(padded, dst_shapes);
}
}
else{
return bld_->create_broadcast(src, dst_shapes);
}
}
return src;
}

View File

@@ -453,21 +453,52 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
TileType::ShapeInt shape;
size_t i = 0;
const Token* tok;
std::vector<std::pair<int, int>> redInfo;
do {
tok = ts_.Next();
if(tok->tag_ == ':')
shape.push_back(lhsShape[i++]);
else if(tok->tag_ == Token::NEWAXIS)
shape.push_back(1);
else
Error(tok, "only ':' and newaxis are supported in subscripts");
switch(tok->tag_) {
case ':':
shape.push_back(lhsShape[i++]);
break;
case Token::NEWAXIS:
shape.push_back(1);
break;
case Token::ADD:
case Token::SUB:
case Token::MAX:
case Token::MIN:{
int info = UnaryOp::encodeRed(i, tok->tag_);
redInfo.push_back({i, info});
shape.push_back(lhsShape[i++]);
break;
}
default:
Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
break;
}
}while(ts_.Try(','));
ts_.Expect(']');
if(lhsShape.size() > i)
Error(tok, "broadcasting not using all operand axes");
// create ret tile
TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType);
Expr* res = lhs;
for(auto r: redInfo){
shape.erase(shape.begin() + r.first);
Type *retType;
if(shape.empty())
retType = lhsQual.GetPtr();
else
retType = TileType::New(shape, lhsQual);
res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
}
if(!shape.empty()){
TileType *retType = TileType::New(shape, lhsQual);
res = UnaryOp::New(Token::CAST, res, retType);
}
return res;
}
@@ -1000,6 +1031,10 @@ QualType Parser::ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec)
EnsureAndSetStorageSpec(tok, storageSpec, S_EXTERN);
break;
case Token::GLOBAL:
EnsureAndSetStorageSpec(tok, storageSpec, S_GLOBAL);
break;
case Token::STATIC:
if (!storageSpec)
Error(tok, ERR_FUNC_SPEC);
@@ -1020,10 +1055,6 @@ QualType Parser::ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec)
EnsureAndSetStorageSpec(tok, storageSpec, S_AUTO);
break;
case Token::REGISTER:
EnsureAndSetStorageSpec(tok, storageSpec, S_REGISTER);
break;
// Type qualifier
case Token::CONST: qualSpec |= Qualifier::CONST; break;
case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break;

View File

@@ -7,6 +7,7 @@
static MemPoolImp<Token> tokenPool;
const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "__global__", Token::GLOBAL },
{ "auto", Token::AUTO },
{ "break", Token::BREAK },
{ "case", Token::CASE },
@@ -30,7 +31,6 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "newaxis", Token::NEWAXIS },
{ "signed", Token::SIGNED },
{ "unsigned", Token::UNSIGNED },
{ "register", Token::REGISTER },
{ "restrict", Token::RESTRICT },
{ "return", Token::RETURN },
{ "short", Token::SHORT },
@@ -54,6 +54,8 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "_Noreturn", Token::NORETURN },
{ "_Static_assert", Token::STATIC_ASSERT },
{ "_Thread_local", Token::THREAD },
{ "max", Token::MAX },
{ "min", Token::MIN },
};
const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
@@ -121,6 +123,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
{ Token::EXTERN, "extern" },
{ Token::FLOAT, "float" },
{ Token::FOR, "for" },
{ Token::GLOBAL, "global" },
{ Token::GOTO, "goto" },
{ Token::IF, "if" },
{ Token::INLINE, "inline" },
@@ -129,7 +132,6 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
{ Token::NEWAXIS, "newaxis" },
{ Token::SIGNED, "signed" },
{ Token::UNSIGNED, "unsigned" },
{ Token::REGISTER, "register" },
{ Token::RESTRICT, "restrict" },
{ Token::RETURN, "return" },
{ Token::SHORT, "short" },

View File

@@ -175,11 +175,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
auto ir = make_ir(parser);
// binary code-gen
std::unique_ptr<driver::module> bin;
try{
bin = make_bin(*ir, stream->context(), opt);
}catch(const std::runtime_error& e) {
return;
}
bin = make_bin(*ir, stream->context(), opt);
// kernel uses too much resources
if(!bin)
return;

View File

@@ -66,7 +66,7 @@ template<> struct to_string<double>{
template<class T>
bool triton_dot(drv::stream* stream, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
int32_t TM, int32_t TN, int32_t TK, size_t nwarp,
int32_t TM, int32_t TN, int32_t TK, int32_t nwarp,
const std::vector<int>& a_order, const std::vector<int>& b_order,
run_mode_t mode, std::vector<double>& bench, bool &test){
std::string ty = to_string<T>::value;

27
tests/common/src/reduce.h Normal file
View File

@@ -0,0 +1,27 @@
namespace src {
const char *reduce1d =
R"(
void reduce1d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __readonly __aligned(16),
int N) {
}
)";
const char *reduce2d =
R"(
void reduce2d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int M, int N, int ldx) {
int ridm = get_program_id(0);
int ridn = get_program_id(1);
int rm[TM] = ridm * TM + 0 ... TM;
int rn[TN] = ridn * TN + 0 ... TN;
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
TYPE* py[TY] = Y + RY;
*py = (*px)[RED];
}
)";
}

View File

@@ -10,9 +10,13 @@
namespace drv = triton::driver;
namespace rt = triton::runtime;
/* ------------------------
* Launch Grid
* ------------------------ */
inline size_t ceil(size_t x, size_t y) {
return (x + y - 1) / y;
};
}
inline rt::function::grid_fn_ty grid1d(size_t N) {
return [N](const rt::function::options_t& x) {
@@ -51,12 +55,120 @@ inline std::vector<std::vector<std::string>> tile_nd(size_t rank) {
return {};
}
/* ------------------------
* Tensor Initialization
* ------------------------ */
template<class T>
void init_rand(std::vector<T>& x) {
for(size_t i = 0; i < x.size(); i++)
x[i] = static_cast<T>((double)rand()/RAND_MAX);
}
template<class T>
void init_zeros(std::vector<T>& x) {
for(size_t i = 0; i < x.size(); i++)
x[i] = 0;
}
/* ------------------------
* Loop Nests
* ------------------------ */
void _loop_nest(std::vector<int> const & ranges,
std::function<void(std::vector<int> const &)> const & f){
int D = ranges.size();
std::vector<int> values(D, 0);
// Start with innermost loop
int i = D - 1;
while(true){
// Execute function
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}
/* -----------------------
* TENSOR INDEXING
* ----------------------- */
enum order_t {
ROWMAJOR,
COLMAJOR
};
int offset(const std::vector<int>& idx, const std::vector<int>& shapes) {
int result = idx[0];
for(int i = 1; i < idx.size(); i++)
result += idx[i]*shapes[i-1];
return result;
}
/* -----------------------
* REDUCTION HELPERS
* ----------------------- */
enum reduce_op_t {
ADD,
MAX,
MIN
};
std::string to_str(reduce_op_t op) {
switch (op) {
case ADD: return "+";
case MAX: return "max";
case MIN: return "min";
default: break;
}
assert(false);
return "";
}
template<class T>
std::function<T(T,T)> get_accumulator(reduce_op_t op) {
switch (op) {
case ADD: return [](T x, T y) { return x + y; };
case MAX: return [](T x, T y) { return std::max(x, y); };
case MIN: return [](T x, T y) { return std::min(x, y); };
default: break;
}
assert(false);
return std::function<T(T,T)>();
}
/* -----------------------
* TENSOR COMPARISON
* ----------------------- */
namespace testing {
template<class T>
bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
if(hc.size() != rc.size())
return false;
for(size_t i = 0; i < hc.size(); i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
return false;
}
return true;
}
}
/* -----------------------
* PRETTY PRINTING
* ----------------------- */
namespace aux{
template<std::size_t...> struct seq{};
@@ -83,34 +195,22 @@ auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
}
template<class Ch, class Tr, class T>
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::vector<T> const& t)
-> std::basic_ostream<Ch, Tr>&
{
std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, const std::vector<T>& vec) {
os << "{";
for(size_t i = 0; i < t.size(); i++) {
for(size_t i = 0; i < vec.size(); i++){
if(i > 0)
os << ", ";
os << t[i];
os << vec[i];
}
return os << "}";
os << "}";
return os;
}
template<class Ch, class Tr>
std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, reduce_op_t op) {
return os << to_str(op);
}
namespace testing {
template<class T>
bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
if(hc.size() != rc.size())
return false;
for(size_t i = 0; i < hc.size(); i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
return false;
}
return true;
}
}
#endif

View File

@@ -1,4 +1,4 @@
foreach(PROG dot copy1d copy2d copy3d)
foreach(PROG dot copy1d copy2d copy3d reduce)
set(TARGET unit_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})

106
tests/unit/reduce.cc Normal file
View File

@@ -0,0 +1,106 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include <functional>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/reduce.h"
#include "cuda/cublas.h"
#include "util.h"
namespace drv = triton::driver;
namespace rt = triton::runtime;
template<class T>
void reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
assert(axis <= shapes.size() - 1);
// remove shape at index axis to get outer dimensions
std::vector<int> outer = shapes;
outer.erase(outer.begin() + axis);
// retrieve shape at index axis to get inner dimension
int inner = shapes[axis];
// accumualtion function
auto acc = get_accumulator<T>(op);
// iterate over outer dimensions
_loop_nest(outer, [&](const std::vector<int>& y_idx) {
T ret = 0;
auto x_idx = y_idx;
x_idx.insert(x_idx.begin() + axis, 0);
// accumulate over inner dimensions
for(int z = 0; z < inner; z++){
x_idx[axis] = z;
ret = acc(ret, x[offset(x_idx, shapes)]);
}
y[offset(y_idx, outer)] = ret;
});
}
bool do_test(drv::stream* stream, std::vector<int> shape, int axis, reduce_op_t op, int nwarp){
typedef float NumericT;
std::string ty = "float";
size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context();
size_t axy = (axis == 0) ? 1 : 0;
std::string RY = (axis == 0) ? "rn" : "rm";
std::vector<NumericT> hy(shape[axy]);
std::vector<NumericT> ry(shape[axy]);
std::vector<NumericT> hx(shape[0]*shape[1]);
srand(0);
init_zeros(hy);
init_rand(hx);
auto dy = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hy.size()*dt_nbytes));
auto dx = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hx.size()*dt_nbytes));
stream->write(&*dy, true, 0, hy);
stream->write(&*dx, true, 0, hx);
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"TM", {std::to_string(shape[0])}});
opt.defines.push_back({"TN", {std::to_string(shape[1])}});
opt.defines.push_back({"TY", {std::to_string(shape[axy])}});
opt.defines.push_back({"RY", {RY}});
std::string RED = "";
for(int n = 0; n < 2; n++){
if(n > 0)
RED += ", ";
RED += (n==axis) ? to_str(op) : ":";
}
opt.defines.push_back({"RED", {RED}});
opt.num_warps = {nwarp};
rt::function function(src::reduce2d, opt);
function({&*dx, &*dy, shape[0], shape[1], shape[0]}, grid2d(shape[0], shape[1]), stream);
stream->synchronize();
stream->read(&*dy, true, 0, hy);
reduce_nd(ry, hx, op, axis, shape);
return testing::diff(hy, ry);
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<std::vector<int>, int, reduce_op_t> config_t;
std::vector<config_t> configs = {
config_t{{32, 32}, 0, MAX},
config_t{{32, 32}, 1, ADD},
config_t{{32, 64}, 0, ADD},
config_t{{64, 32}, 1, ADD}
};
// does the work
int axis;
std::vector<int> shape;
reduce_op_t op;
for(const auto& c: configs){
std::tie(shape, axis, op) = c;
std::cout << "Testing " << c << " ... " << std::flush;
if(do_test(stream, shape, axis, op, 1))
std::cout << " Pass! " << std::endl;
else
std::cout << " Fail! " << std::endl;
}
}