triton/docs/tutorials/custom-operation.rst
2021-07-27 12:38:48 -07:00


===========================
Writing a Custom Operation
===========================
--------------
Compute Kernel
--------------
Let us start with something simple and see how Triton can be used to create a custom vector addition for PyTorch. The Triton compute kernel for this operation is the following:

.. code-block:: C

    // Triton
    // launch on a grid of (N + TILE - 1) / TILE programs
    __global__ void add(float* z, float* x, float* y, int N){
        // program id
        int pid = get_program_id(0);
        // create arrays of pointers
        int offset[TILE] = pid * TILE + 0 ... TILE;
        float* pz[TILE] = z + offset;
        float* px[TILE] = x + offset;
        float* py[TILE] = y + offset;
        // bounds checking
        bool check[TILE] = offset < N;
        // write-back
        *?(check)pz = *?(check)px + *?(check)py;
    }

As you can see, arrays are first-class citizens in Triton. This has a number of important advantages that will be highlighted in the next tutorial. For now, let's keep it simple and see how to execute the above operation in PyTorch.
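For intuition, the tiling and bounds check performed by the kernel can be emulated in plain Python. The following is only a sketch of the indexing logic, not Triton code: the helper names ``cdiv``, ``add_program`` and ``add`` and the small tile size are illustrative assumptions, and the "programs" run sequentially here rather than in parallel on a GPU.

```python
# Plain-Python emulation of the tiled, masked vector addition.
# All names below are illustrative helpers, not Triton APIs.


def cdiv(a, b):
    """Ceiling division for positive integers: (a + b - 1) / b rounded down."""
    return (a + b - 1) // b


def add_program(pid, z, x, y, n, tile):
    """Emulate one program instance, which owns elements [pid*tile, (pid+1)*tile)."""
    # int offset[TILE] = pid * TILE + 0 ... TILE;
    offsets = [pid * tile + i for i in range(tile)]
    # bool check[TILE] = offset < N;
    check = [off < n for off in offsets]
    # masked load, add and store: lanes that fall outside the vector are skipped
    for off, ok in zip(offsets, check):
        if ok:
            z[off] = x[off] + y[off]


def add(x, y, tile=4):
    """Run every program of the launch grid, one after another."""
    n = len(x)
    z = [0.0] * n
    # launch grid: one program per tile, rounded up
    for pid in range(cdiv(n, tile)):
        add_program(pid, z, x, y, n, tile)
    return z


x = [float(i) for i in range(10)]
y = [2.0 * i for i in range(10)]
# 10 elements with a tile of 4 gives 3 programs; the last one is partly masked
z = add(x, y)
```

With 10 elements and a tile of 4, the last program covers offsets 8 to 11, and the bounds check is what keeps offsets 10 and 11 from being read or written.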
---------------
PyTorch Wrapper
---------------
As you will see, a wrapper for the above Triton function can be created in just a few lines of pure Python code.

.. code-block:: python

    import torch
    import triton

    class _add(torch.autograd.Function):
        # source code for the Triton compute kernel
        src = """
    __global__ void add(float* z, float* x, float* y, int N){
        // program id
        int pid = get_program_id(0);
        // create arrays of pointers
        int offset[TILE] = pid * TILE + 0 ... TILE;
        float* pz[TILE] = z + offset;
        float* px[TILE] = x + offset;
        float* py[TILE] = y + offset;
        // bounds checking
        bool check[TILE] = offset < N;
        // write-back
        *?(check)pz = *?(check)px + *?(check)py;
    }
        """
        # create a callable kernel from the source code
        # options: 4 warps and -DTILE=1024
        kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])

        # forward pass
        @staticmethod
        def forward(ctx, x, y):
            # type checking
            assert x.dtype == torch.float32
            # allocate output
            z = torch.empty_like(x).cuda()
            # create launch grid
            # this is a function of the launch parameters
            # triton.cdiv performs ceiling division
            N = x.numel()
            grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
            # launch kernel
            _add.kernel(z, x, y, N, grid = grid)
            # return output
            return z

    # get callable from Triton function
    add = _add.apply

    # test
    torch.manual_seed(0)
    x = torch.rand(98432).cuda()
    y = torch.rand(98432).cuda()
    za = x + y
    zb = add(x, y)
    diff = (za - zb).abs().max()
    print(diff)
    print(torch.allclose(za,zb))

Executing the above code will:

- Generate a .cpp file containing PyTorch bindings for the Triton function
- Compile this .cpp file using distutils
- Cache the resulting custom op
- Call the resulting custom op

In other words, the first run of the program will generate and cache several files in ``$HOME/.triton/cache``, but subsequent runs should be just as fast as using a handwritten custom operation.

A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.
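The generate, compile and cache steps described above follow a common build-once pattern, which can be sketched in a few lines. This is an illustration under stated assumptions, not Triton's actual implementation: the key scheme, the ``.bin`` file name and the helpers ``cache_key``, ``get_or_build`` and ``fake_build`` are all hypothetical, and a temporary directory stands in for the real cache location.

```python
import hashlib
import os
import tempfile


def cache_key(src, defines, num_warps):
    """Hash the kernel source and its compile options into a stable key."""
    payload = repr((src, sorted(defines.items()), num_warps))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def get_or_build(cache_dir, src, defines, num_warps, build):
    """Return (path, was_cached); `build` is only called on a cache miss."""
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, cache_key(src, defines, num_warps) + ".bin")
    if os.path.exists(path):
        return path, True
    with open(path, "wb") as f:
        f.write(build(src))
    return path, False


def fake_build(src):
    """Hypothetical stand-in for the expensive generate-and-compile step."""
    return src.encode("utf-8")


cache_dir = tempfile.mkdtemp()  # temporary directory used as the cache
src = "__global__ void add(float* z, float* x, float* y, int N){}"
first = get_or_build(cache_dir, src, {"TILE": 1024}, 4, fake_build)
second = get_or_build(cache_dir, src, {"TILE": 1024}, 4, fake_build)
```

In this sketch the first call builds and stores the artifact, while the second call with the same source and options finds it on disk and skips the build. Changing the source or any option yields a different key and therefore a fresh build.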