The purpose of PyTriton is to provide an API for easily executing Triton-C kernels from PyTorch and TensorFlow. One of its main advantages is that it is framework-agnostic: as this tutorial will show, any custom op written with this API is transparently compatible with both TensorFlow and PyTorch, with no additional effort required.
PyTriton works by detecting which frameworks are imported and automatically generating and just-in-time compiling C++ binding code for them. Specifically, the following chain of events is triggered when a Triton operation is executed:
1. The imported frameworks are detected.
2. C++ binding code for TensorFlow or PyTorch is generated, compiled and cached.
3. The corresponding custom op is automatically loaded from the generated `.so` file, and a framework-agnostic wrapper is created.
4. The wrapper is called and a `tf.Tensor` or a `torch.Tensor` is returned. In the case of TensorFlow, the gradient is also registered at this point, if applicable.
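Step 1 of this chain can be illustrated with a short sketch. It is hypothetical and not PyTriton's actual implementation: it assumes that "imported" means "already present in `sys.modules`", and the helper name `detect_frameworks` is made up for this example.

```python
import sys


def detect_frameworks(modules=None):
    """Return the supported frameworks that have already been imported.

    Hypothetical helper: PyTriton's real detection logic may differ.
    `modules` defaults to `sys.modules`, the interpreter's table of loaded modules.
    """
    if modules is None:
        modules = sys.modules
    # Only report frameworks the user already imported; never import them here,
    # since importing TensorFlow or PyTorch as a side effect would be expensive.
    supported = ("tensorflow", "torch")
    return [name for name in supported if name in modules]


if __name__ == "__main__":
    found = detect_frameworks()
    if not found:
        print("no supported framework imported; nothing to bind")
    for name in found:
        # Steps 2-4 (generate, compile and load the bindings) would follow here.
        print(f"would generate C++ bindings for {name}")
```

Checking `sys.modules` rather than attempting an import keeps detection cheap and side-effect free, which matters because binding code only needs to be generated for frameworks the user actually loaded.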
The PyTriton API provides a `triton.function` class, which automatically handles the interaction with the automatic differentiation engine of whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class.
### <span style="color:darkblue">Creation of Triton Kernels </span> <a name="creation-triton-kernel"></a>
PyTriton also provides a `triton.kernel` class, which automatically takes care of the interaction with the Triton-JIT as well as the generation and compilation of the C++ framework binding code. For our dot operation, we create a kernel from the Triton-C code derived at the end of the [previous tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md).
Note that the second argument to the `triton.kernel` constructor indicates which of the operands our kernel function should return. Here, we only return `C`.
At this point, `kernel` is a callable object with the same signature as the `dot` function in our source code, except that pointers are treated as tensors.
### <span style="color:darkblue">Usage of Triton Kernels </span> <a name="usage-triton-kernels"></a>
However, in practice only `A` and `B` are provided by the user, and all the other `int` arguments should be derived from these operands. Hence, we create a helper function that extracts shapes from the `A` and `B` tensors and then returns the result of a call to `kernel`. In this helper:
- `grid` corresponds to the grid with which our Triton kernel will be launched. Because in our case this grid depends on parametric tile variables, it is supplied as a function of the compilation options `opt`, whose compile-time definitions can be retrieved using `opt.d(name)`. Here, `opt.d('TM')` and `opt.d('TN')` retrieve the first and second tile dimensions our kernel was compiled with. We also provide a helper, `triton.cdiv`, for ceiling division.
- `macros` provides a list of preprocessor definitions to compile the kernel with. Alternatively, these can also be supplied as named arguments to `_dot.kernel`. Recall that lists of values can be supplied to the preprocessor, in which case an auto-tuning procedure is triggered. Here, `TM` and `TN` are each tuned over the values 32, 64 and 128.
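The shape bookkeeping and grid computation described above can be sketched in plain Python. This is a hypothetical illustration, not PyTriton code: `infer_dot_shapes`, the local `cdiv` and the `TileOptions` stand-in (whose `d(name)` method mimics the `opt.d(name)` accessor) are assumptions, and tensors are represented only by their 2-D shapes.

```python
import itertools


def cdiv(x, y):
    """Ceiling division of positive integers (local stand-in for `triton.cdiv`)."""
    return (x + y - 1) // y


def infer_dot_shapes(shape_a, shape_b, transpose_a=False, transpose_b=False):
    """Derive (M, N, K) for C = op(A) @ op(B) from the stored operand shapes."""
    M, Ka = shape_a
    Kb, N = shape_b
    if transpose_a:  # A is stored as K x M
        M, Ka = Ka, M
    if transpose_b:  # B is stored as N x K
        Kb, N = N, Kb
    if Ka != Kb:
        raise ValueError(f"inner dimensions differ: {Ka} != {Kb}")
    return M, N, Ka


class TileOptions:
    """Hypothetical stand-in for the compilation options `opt`."""

    def __init__(self, **defines):
        self.defines = defines

    def d(self, name):
        """Return the compile-time value of macro `name`."""
        return self.defines[name]


def launch_grid(M, N):
    """Grid as a function of `opt`: one cell per TM x TN tile of C, rounded up."""
    return lambda opt: [cdiv(M, opt.d("TM")), cdiv(N, opt.d("TN"))]


if __name__ == "__main__":
    M, N, K = infer_dot_shapes((1000, 512), (512, 300))
    grid = launch_grid(M, N)
    # Tuning TM and TN over {32, 64, 128} gives 3 x 3 = 9 candidate configurations.
    for tm, tn in itertools.product([32, 64, 128], repeat=2):
        print(f"TM={tm:3} TN={tn:3} grid={grid(TileOptions(TM=tm, TN=tn))}")
```

For the 1000 x 512 by 512 x 300 product above with `TM = TN = 128`, the grid is `[8, 3]`: rounding up ensures that the partial tiles at the edges of `C` are still covered. Making the grid a function of `opt`, rather than a fixed list, is what lets the same helper serve every candidate configuration during auto-tuning.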
## <span style="color:darkred"> Compatibility with Automatic Differentiation</span> <a name="autodiff"></a>
At this point, our custom operation takes only two tensor arguments and transposition information, as intended. However, it is not yet compatible with the automatic differentiation engines of PyTorch or TensorFlow, and a small amount of additional work is needed.
PyTriton binds to the automatic differentiation frameworks of TensorFlow and PyTorch through a single, common API inspired by PyTorch. It consists of two static methods, `forward` and `backward`, that each take a context as their first argument:
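Making the op differentiable amounts to computing dA and dB from the upstream gradient dC. For C = AB these are dA = dC Bᵀ and dB = Aᵀ dC; when an operand is transposed in the forward pass, the same identities apply to the transposed operand and the result is transposed back, which yields four cases. The following pure-Python sketch checks those four cases against central finite differences on the CPU. It is not part of PyTriton: `reference_dot`, `dot_grads` and the list-of-lists matrices are hypothetical names and representations chosen for illustration, with `reference_dot` standing in for the Triton kernel.

```python
import itertools
import random


def transpose(x):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*x)]


def matmul(x, y):
    """Naive (M x K) @ (K x N) matrix product."""
    cols = transpose(y)
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in x]


def reference_dot(a, b, t_a=False, t_b=False):
    """op(A) @ op(B), where op transposes its operand when the flag is set."""
    return matmul(transpose(a) if t_a else a, transpose(b) if t_b else b)


def dot_grads(a, b, dc, t_a, t_b):
    """Gradients (dA, dB) of C = op(A) @ op(B) for an upstream gradient dC."""
    if not t_a and not t_b:  # C = A B
        return reference_dot(dc, b, False, True), reference_dot(a, dc, True, False)
    if not t_a and t_b:      # C = A B^T
        return reference_dot(dc, b, False, False), reference_dot(dc, a, True, False)
    if t_a and not t_b:      # C = A^T B
        return reference_dot(b, dc, False, True), reference_dot(a, dc, False, False)
    # C = A^T B^T
    return reference_dot(b, dc, True, True), reference_dot(dc, a, True, True)


def loss(a, b, dc, t_a, t_b):
    """Scalar sum(dC * C); its gradient with respect to C is exactly dC."""
    c = reference_dot(a, b, t_a, t_b)
    return sum(g * v for g_row, c_row in zip(dc, c) for g, v in zip(g_row, c_row))


def numeric_grad(x, f, eps=1e-5):
    """Central finite differences of f() with respect to each entry of x."""
    grad = [[0.0] * len(x[0]) for _ in x]
    for i, j in itertools.product(range(len(x)), range(len(x[0]))):
        old = x[i][j]
        x[i][j] = old + eps
        up = f()
        x[i][j] = old - eps
        down = f()
        x[i][j] = old  # restore the perturbed entry
        grad[i][j] = (up - down) / (2 * eps)
    return grad


def max_error(t_a, t_b, M=3, N=4, K=5, seed=0):
    """Largest gap between analytic and numeric gradients for one case."""
    rng = random.Random(seed)

    def rand(rows, cols):
        return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

    # Stored shapes follow the flags so that op(A) is M x K and op(B) is K x N.
    a = rand(K, M) if t_a else rand(M, K)
    b = rand(N, K) if t_b else rand(K, N)
    dc = rand(M, N)
    da, db = dot_grads(a, b, dc, t_a, t_b)

    def f():
        return loss(a, b, dc, t_a, t_b)

    errors = []
    for analytic, numeric in ((da, numeric_grad(a, f)), (db, numeric_grad(b, f))):
        errors += [abs(p - q) for ra, rn in zip(analytic, numeric) for p, q in zip(ra, rn)]
    return max(errors)


if __name__ == "__main__":
    for t_a, t_b in itertools.product([False, True], repeat=2):
        print(f"transpose_a={t_a}, transpose_b={t_b}: max error {max_error(t_a, t_b):.1e}")
```

The four branches of `dot_grads` use the same operand order and transposition flags as the calls to `_dot._call` in the `backward` method of the following listing. Because the loss is bilinear in A and B, central differences are exact up to floating-point rounding, so the reported errors should be close to machine precision when the formulas are right.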
```python
import triton


class _dot(triton.function):
    # `kernel` and the `_call` helper described above are omitted here.

    @staticmethod
    def forward(ctx, a, b, transpose_a=False, transpose_b=False):
        # Save the operands and transposition flags for the backward pass
        ctx.save_for_backward(a, b)
        ctx.t_a = transpose_a
        ctx.t_b = transpose_b
        return _dot._call(a, b, transpose_a, transpose_b)

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        t_a, t_b = ctx.t_a, ctx.t_b
        # Gradients of C = op(A) op(B), where op transposes its operand if flagged
        if not t_a and not t_b:  # C = A B
            da = _dot._call(dy, b, False, True)
            db = _dot._call(a, dy, True, False)
        elif not t_a and t_b:    # C = A B^T
            da = _dot._call(dy, b, False, False)
            db = _dot._call(dy, a, True, False)
        elif t_a and not t_b:    # C = A^T B
            da = _dot._call(b, dy, False, True)
            db = _dot._call(a, dy, False, False)
        else:                    # C = A^T B^T
            da = _dot._call(b, dy, True, True)
            db = _dot._call(dy, a, True, True)
        # Only `a` and `b` receive gradients; the remaining entries are None
        return da, db, None, None, None, None, None, None, None
```
As in PyTorch, a callable operation can be created using the `apply` method of our `triton.function` subclass. We wrap it as a module variable for convenience: