[documentation] added pytriton tutorial
## Tutorials

- [The Triton-C language](https://github.com/ptillet/triton/blob/master/docs/triton-c.md)
- [The PyTriton API](https://github.com/ptillet/triton/blob/master/docs/pytriton.md)
- The Triton-IR representation (coming soon...)
- The Triton-JIT compiler (coming soon...)
# The PyTriton API

## <span style="color:darkred"> Table of Contents </span>

This tutorial is the continuation of the [Triton-C tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md), so check it out if you have not already!
1. [Motivations](#motivations)
2. [Triton Functions](#pytriton-function)
    1. [Creation of Triton Kernels](#creation-triton-kernels)
    2. [Usage of Triton Kernels](#usage-triton-kernels)
3. [Integration with Automatic Differentiation](#autodiff)
    1. [Basics](#autodiff:basics)
    2. [Convenience](#autodiff:convenience)
## <span style="color:darkred"> Motivations </span> <a name="motivations"></a>

The purpose of PyTriton is to provide an API for easily executing Triton-C kernels from PyTorch and Tensorflow. One of the main advantages of PyTriton is that it is framework agnostic: any custom op written using this API will be transparently compatible with both Tensorflow and PyTorch without any additional effort required, as will be shown in this tutorial.

Consider for example the following piece of code:
```python
import numpy as np
import triton

def run_torch():
    # body elided in this diff: the example calls the framework-agnostic
    # `triton.ops.dot` on torch tensors (a Tensorflow variant is analogous)
    ...

# run_torch()
```
PyTriton works by detecting which frameworks are imported and automatically generating and just-in-time compiling C++ binding code for them. Specifically, the following chain of events is triggered when a Triton operation is executed:

1. The imported frameworks are detected.
2. C++ binding code for Tensorflow or PyTorch is generated, compiled and cached.
3. The corresponding custom-op is automatically loaded from the generated .so file, and a framework-agnostic wrapper is created.
4. The wrapper is called and a tf.tensor or a torch.tensor is returned. In the case of Tensorflow, the gradient is also registered at this point if applicable.
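For illustration, here is a sketch of what such a framework-agnostic call can look like when PyTorch happens to be the imported framework (tensor shapes and creation details are assumptions):

```python
import torch
import triton

a = torch.rand(128, 128).cuda()
b = torch.rand(128, 128).cuda()
# torch was detected, so the wrapper returns a torch.tensor
c = triton.ops.dot(a, b, False, True)
```

Feeding the same wrapper Tensorflow tensors would, per the mechanism above, yield a tf.tensor instead.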
In the remainder of this tutorial, we will re-implement the above `dot` function from scratch; its full source-code can be found [here](https://github.com/ptillet/triton/blob/master/python/triton/ops/dot.py).

## <span style="color:darkred"> PyTriton Functions </span> <a name="pytriton-function"></a>
The PyTriton API provides a `triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Every differentiable custom operation written with PyTriton should therefore inherit from this class.

Creating a custom op hence starts with declaring a class which inherits from `triton.function`:
```python
import triton

class _dot(triton.function):

  src = """
```
### <span style="color:darkblue"> Creation of Triton Kernels </span> <a name="creation-triton-kernels"></a>

PyTriton also provides a `triton.kernel` class which automatically takes care of the interaction with the Triton-JIT, as well as the generation and compilation of the C++ framework binding code. For our dot operation, we create a kernel from the Triton-C code derived at the end of the [previous tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md):
```python
src = """
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
                    int M, int N, int K,
                    int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
  // ... (kernel body elided in this diff; see the end of the Triton-C tutorial for the full source)
}
"""

kernel = triton.kernel(src, ['C'])
```
Note that the second argument to the `triton.kernel` constructor indicates which of the operands our kernel function should return. Here, we only return `C`.

At this point, `kernel` is a callable object which takes the same signature as the `dot` function in our source code, except that pointers are treated as tensors:

```
[tensor, tensor, tensor, int, int, int, int, int, int]
```
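For illustration only, a direct invocation could then look roughly as follows. The exact argument layout is an assumption, and the launch grid and compile-time macros that a real call would also need are covered in the notes further below:

```python
# hypothetical direct call matching the signature above: three tensors
# followed by the shape and leading-dimension integers
M, N, K = 128, 128, 128
lda, ldb, ldc = K, N, N      # leading dimensions, assuming row-major storage
c = triton.empty([M, N])     # the output operand, allocated up front
c = kernel(a, b, c, M, N, K, lda, ldb, ldc)
```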
### <span style="color:darkblue"> Usage of Triton Kernels </span> <a name="usage-triton-kernels"></a>

However, in practice only `A` and `B` are provided by the user, and all the other `int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the `A` and `B` tensors, and then returns the result of a call to `kernel`:
```python
@staticmethod
def _call(a, b, transpose_a, transpose_b):
    # sketch only: the diff elides this body, and the helper's name is an
    # assumption (see dot.py for the full helper)
    M = triton.shape(a)[0]
    N = triton.shape(b)[1]
    c = triton.empty([M, N])
    # ... the remaining shape arguments are extracted similarly, and `kernel`
    # is then launched with the `grid` and `macros` discussed below
    ...
```
While this code should be mostly self-explanatory, there are a few noteworthy things worth pointing out:
- `triton.shape` provides a framework-agnostic way to retrieve the shape of a tensor.

- `triton.empty` creates an empty tensor of the specified dimensions.

- `grid` corresponds to the grid with which our Triton kernel will be launched. Because in our case this grid depends on parametric tile variables, it is supplied as a function of compilation options `opt`, whose compile-time definition can be retrieved using `opt.d(name)`. Here, `opt.d('TM')` and `opt.d('TN')` retrieve the first and second tile dimensions our kernel was compiled with. We also provide a helper `triton.cdiv` for ceiling divisions.

- `macros` provides a list of preprocessor definitions to compile the kernel with. Alternatively, these can also be supplied as named arguments to `_dot.kernel`. We recall that lists can be supplied to the preprocessor, in which case an auto-tuning procedure will be triggered. Here, the values of `TM` and `TN` are both tuned between 32, 64 and 128; see the sketch after this list.
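To make the last two bullet points concrete, the sketch below shows plausible definitions for `grid` and `macros` (the exact dictionary format is an assumption):

```python
# one kernel instance per [TM, TN] tile of the output matrix; the tile sizes
# are compile-time constants retrieved through `opt.d`
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]

# supplying lists of values triggers auto-tuning: TM and TN are each
# tuned over {32, 64, 128}
macros = {'TM': [32, 64, 128], 'TN': [32, 64, 128]}
```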
## <span style="color:darkred"> Compatibility with Automatic Differentiation </span> <a name="autodiff"></a>

At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.
### <span style="color:darkblue"> Basics </span> <a name="autodiff:basics"></a>

PyTriton binds to Tensorflow's and PyTorch's automatic differentiation framework using a single, common API inspired by PyTorch. It consists of two static methods, `forward` and `backward`, that take a context as their first input:
```python
@staticmethod
def forward(ctx, a, b, transpose_a, transpose_b, *other):
    # sketch: the diff elides this body; `forward` stashes what the backward
    # pass needs in `ctx` and returns the result of the kernel helper
    ...

@staticmethod
def backward(ctx, dy):
    # sketch: the diff elides this body
    da = db = ...  # elided: the two gradients are themselves computed with `dot`
    # one gradient (or None) is returned per argument of `forward`; the
    # trailing Nones match its non-differentiable arguments
    return da, db, None, None, None, None, None, None, None
```
### <span style="color:darkblue"> Convenience </span> <a name="autodiff:convenience"></a>

As in PyTorch, a callable operation can then be created using the `apply` method of our `triton.function` class. We wrap it as a module variable for convenience:
```python
# the diff truncates here; per the sentence above, the wrapper is most
# plausibly obtained from `apply`
dot = _dot.apply
```
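Hypothetically, calling the resulting operation then looks identical under either framework (the argument order is an assumption):

```python
# torch.tensor in, torch.tensor out; or tf.tensor in, tf.tensor out
c = dot(a, b, False, True)
```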