
====================================================
Putting It All Together
====================================================
In the previous tutorial, we saw how to write tensor-core-friendly matrix multiplication code competitive with cuBLAS in 20 lines of Triton code. Here, we will see how to wrap it into an automatically differentiable PyTorch function for easy integration into your Deep Learning pipeline.

-----------------
PyTriton Function
-----------------

The PyTriton API automatically handles the interaction with the automatic differentiation engine of whichever framework was detected, PyTorch in our case. Every differentiable custom operation is therefore written as a subclass of :code:`torch.autograd.Function` that provides a :code:`forward` and a :code:`backward` static method:

.. code-block:: python

    import torch
    import triton

    # Entry point
    class _dot(torch.autograd.Function):

        @staticmethod
        def forward(ctx, *args):
            # forward pass
            ...

        @staticmethod
        def backward(ctx, dy):
            # backward pass
            ...

-----------------
PyTriton Kernels
-----------------

PyTriton also provides a :code:`triton.kernel` class which automatically takes care of the interaction with the Triton JIT, as well as the generation and compilation of the C++ framework binding code. For our dot operation, we create a kernel from the Triton code shown at the end of the previous tutorial:

.. code-block:: python

    src = """
    __global__ void dot(TYPE * A, TYPE * B, TYPE * C,
                        int M, int N, int K,
                        int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
        // prologue
        int pm = get_program_id(0);
        int pn = get_program_id(1);
        int rm[TM] = pm * TM + 0 ... TM;
        int rn[TN] = pn * TN + 0 ... TN;
        int rk[TK] = 0 ... TK;
        float c[TM, TN] = 0;
        // pointers to operands
        TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
        TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
        // prefetches operands
        TYPE a[SHAPE_A] = (*pa);
        TYPE b[SHAPE_B] = (*pb);
        // reduction loop
        for(int k = K; k > 0; k -= TK){
            c += USE_A @ USE_B;
            pa = pa + TK * STRIDE_AK;
            pb = pb + TK * STRIDE_BK;
            a = *pa;
            b = *pb;
        }
        // epilogue
        int rcm[TM] = pm * TM + 0 ... TM;
        int rcn[TN] = pn * TN + 0 ... TN;
        TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
        *pc = c;
    }
    """
    kernel = triton.kernel(src)

At this point, :code:`kernel` is a callable object which takes the same signature as the :code:`dot` function in our source code, except that pointers are treated as tensors: :code:`[tensor, tensor, tensor, int, int, int, int, int, int]`.

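For instance, a direct invocation could look like the following sketch (hypothetical: :code:`c` is assumed to be a pre-allocated output tensor, and the :code:`grid`, :code:`num_warps` and :code:`defines` meta-parameters are covered in the next subsection):

.. code:: python

    # Hypothetical direct call: a, b and c are CUDA tensors; M, N, K and the
    # leading dimensions lda, ldb, ldc are plain Python integers.
    kernel(a, b, c, M, N, K, lda, ldb, ldc,
           grid=grid, num_warps=4, defines=defines)
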
-----------------------
Using PyTriton Kernels
-----------------------

However, in practice only :code:`A` and :code:`B` are provided by the user, and all the other :code:`int` arguments can be derived from these operands alone. Hence, we create a helper function that extracts shapes from the :code:`A` and :code:`B` tensors and returns the result of a call to :code:`kernel`:

.. code:: python

    @staticmethod
    def _call(a, b, transpose_a, transpose_b):
        # extract shapes
        shape_a = a.shape
        shape_b = b.shape
        M, Ka = shape_a[0], shape_a[1]
        Kb, N = shape_b[0], shape_b[1]
        # transpose shapes
        if transpose_a:
            M, Ka = Ka, M
        if transpose_b:
            Kb, N = N, Kb
        # contiguous dimensions
        lda = M if transpose_a else Ka
        ldb = Kb if transpose_b else N
        ldc = N
        # data-type
        dtype = a.dtype
        # allocate output
        c = triton.empty([M, N], dtype = dtype)
        # launch grid
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
        # pre-processor definitions
        defines = {'TYPE'        : dtype,
                   'AT'          : transpose_a,
                   'BT'          : transpose_b,
                   # tile sizes
                   'TM'          : [32, 64, 128],
                   'TN'          : [32, 64, 128],
                   'TK'          : [8],
                   # handle A transposition
                   'USE_A'       : '^a' if transpose_a else 'a',
                   'STRIDE_AK'   : 'lda' if transpose_a else '1',
                   'STRIDE_AM'   : '1' if transpose_a else 'lda',
                   'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
                   'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
                   'SHAPE_A'     : 'TK, TM' if transpose_a else 'TM, TK',
                   # handle B transposition
                   'USE_B'       : '^b' if transpose_b else 'b',
                   'STRIDE_BK'   : '1' if transpose_b else 'ldb',
                   'STRIDE_BN'   : 'ldb' if transpose_b else '1',
                   'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
                   'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
                   'SHAPE_B'     : 'TN, TK' if transpose_b else 'TK, TN'}
        return kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
                      grid=grid, num_warps=4, defines=defines)

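To make the :code:`grid` lambda concrete, here is a small, purely illustrative computation (the sizes are hypothetical): with :code:`M = N = 1024` and auto-tuned tile sizes :code:`TM = TN = 128`, the kernel is launched with one program instance per 128x128 tile of :code:`C`:

.. code:: python

    import triton

    # Illustration only: with M = N = 1024 and tile sizes TM = TN = 128,
    # the launch grid contains 8 x 8 program instances, one per output tile.
    grid = [triton.cdiv(1024, 128), triton.cdiv(1024, 128)]   # -> [8, 8]
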
--------------------------------------------
Automatic Differentiation
--------------------------------------------

At this point, our custom operation only takes two tensor arguments and transposition information, which is much more convenient. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.

Creating a custom operation for Triton is very similar to doing it in plain PyTorch: programmers have to provide two static methods, :code:`forward` and :code:`backward`, that take a context object as their first argument:

.. code:: python

    @staticmethod
    def forward(ctx, a, b, transpose_a = False, transpose_b = False):
        ctx.save_for_backward(a, b)
        ctx.t_a = transpose_a
        ctx.t_b = transpose_b
        return _dot._call(a, b, transpose_a, transpose_b)

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        t_a, t_b = ctx.t_a, ctx.t_b
        if not t_a and not t_b:
            da = _dot._call(dy, b, False, True)
            db = _dot._call(a, dy, True, False)
        elif not t_a and t_b:
            da = _dot._call(dy, b, False, False)
            db = _dot._call(dy, a, True, False)
        elif t_a and not t_b:
            da = _dot._call(b, dy, False, True)
            db = _dot._call(a, dy, False, False)
        elif t_a and t_b:
            da = _dot._call(b, dy, True, True)
            db = _dot._call(dy, a, True, True)
        else:
            assert False
        # one gradient per input of forward(): a, b, transpose_a, transpose_b
        return da, db, None, None

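The four branches above implement the standard matrix-multiplication gradients. For the untransposed case :math:`C = AB`, the chain rule gives

.. math::

    \nabla_A L = \nabla_C L \, B^\top, \qquad \nabla_B L = A^\top \, \nabla_C L

and the remaining branches are obtained by transposing the appropriate operands, which is why each of them calls :code:`_dot._call` with different transposition flags.
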
A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class.

.. code:: python

    dot = _dot.apply

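As a quick sanity check, the following hypothetical snippet (the sizes, data type and tolerances are illustrative, and a CUDA-capable device is assumed) compares the result against :code:`torch.matmul` and runs a backward pass through the custom operation:

.. code:: python

    import torch

    # illustrative inputs (hypothetical sizes); half precision on a CUDA device
    a = torch.randn(256, 384, dtype=torch.float16, device='cuda', requires_grad=True)
    b = torch.randn(384, 512, dtype=torch.float16, device='cuda', requires_grad=True)
    # forward pass through the Triton kernel
    c = dot(a, b)
    # compare against the PyTorch reference implementation
    print(torch.allclose(c, torch.matmul(a, b), atol=1e-2, rtol=0.))
    # backward pass through _dot.backward, populating a.grad and b.grad
    c.sum().backward()
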
And that's it! In just ~100 lines of pure Python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source!