[GH-PAGES] Updated website
0
.cmake/api/v1/query/cache-v2
Normal file
0
.cmake/api/v1/query/cmakeFiles-v1
Normal file
0
.cmake/api/v1/query/codemodel-v2
Normal file
@@ -15,7 +15,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- The syntax and usage of reduction operators in Triton.\n- The automatic vectorization capabilities of the Triton compiler.\n"
|
||||
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- The reduction operators in Triton.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -40,21 +40,14 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads X once and does all the necessary computations on-chip.\nIn this case, we would be reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nIn practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.\n\n"
|
||||
"When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads X once and does all the necessary computations on-chip.\nThis solution would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nIn practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a power-of-two number of elements,\nso we need to internally \"pad\" tiles and guard the memory operations properly if we want to handle any possible input shapes:\n\n .. code-block:: C\n\n __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){\n // row index\n int m = get_program_id(0);\n // column indices\n int n [BLOCK] = 0 ... BLOCK;\n // the memory address of all the elements\n // that we want to load can be computed as follows\n float* px [BLOCK] = X + m*stride_xm + n;\n // because BLOCK has to be a power of two\n // (per Triton-C specs), it is important\n // to guard each memory operation with predicates\n // or we will read out of bounds\n bool check[BLOCK] = n < N;\n float x [BLOCK] = check ? *px : -F32_INFINITY;\n // syntax for reduction in Triton is:\n // x[:, :, OPERATOR, :, :]\n // ^\n // index\n // where operator is in {min, max, +}\n // for 1D vectors, this is just x[OPERATOR].\n float z [BLOCK] = x - x[max];\n // Note that exponentials in Triton are fast\n // but approximate (i.e., think __expf in CUDA)\n float num [BLOCK] = exp(z);\n float denom = num[+];\n // The result of the reduction is now stored in y\n float y [BLOCK] = num / denom;\n // We write it back\n float* py [BLOCK] = Y + m*stride_ym + n;\n *?(check)py = y;\n }\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Torch Bindings\nHere our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial.\nWe just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.\nThis means that different values of BLOCK will result in different kernels\n\n"
|
||||
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a power-of-two number of elements,\nso we need to internally \"pad\" tiles and guard the memory operations properly if we want to handle any possible input shapes:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -65,14 +58,25 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport triton\n\n# Source code for the Triton kernel\n_src = \"\"\"\n__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n int m = get_program_id(0);\n int n [BLOCK] = 0 ... BLOCK;\n float* px [BLOCK] = X + m*stride_xm + n;\n bool check[BLOCK] = n < N;\n float x [BLOCK] = check ? *px : -F32_INFINITY;\n float z [BLOCK] = x - x[max];\n float num [BLOCK] = exp(z);\n float denom = num[+];\n float y [BLOCK] = num / denom;\n float* py [BLOCK] = Y + m*stride_ym + n;\n *?(check)py = y; \n}\n\"\"\"\n\n\n# helper function to get the smaller power-of-two larger than a given number\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\n# kernel caching mechanism\ndef make_kernel(N, device):\n cache = make_kernel.cache\n # Now are kernels are indexed not only by the provided device but also\n # by the rounded number of columns in the input matrix\n BLOCK = next_power_of_2(N)\n # Another trick we can use is to ask the compiler to parallelize each\n # row-normalization more aggressively -- i.e., with more warps -- vectors\n # that are longer\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself\n num_warps = 4\n if BLOCK >= 2048: num_warps = 8\n if BLOCK >= 4096: num_warps = 16\n # Each (BLOCK, num_warps, device) results in a different kernel\n key = (BLOCK, num_warps, device)\n if key not in cache:\n defines = {'BLOCK': BLOCK}\n cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)\n return cache[key]\n\n\nmake_kernel.cache = dict()\n\n\nclass _softmax(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x):\n # constraints of the op\n assert x.dtype == torch.float32\n y = torch.empty_like(x)\n # The launch grid is simple: we have one kernel instance per row of the input matrix\n M, N = y.shape\n grid = lambda opt: (M, )\n # Launch kernel\n kernel = make_kernel(N, y.device)\n kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)\n return y\n\n\nsoftmax = _softmax.apply"
|
||||
"import triton\n\n\n@triton.jit\ndef _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):\n # row index\n m = triton.program_id(0)\n # col indices\n n = triton.arange(0, meta['BLOCK'])\n # the memory address of all the elements\n # that we want to load can be computed as follows\n X = X + m * stride_xm + n\n x = triton.load(X, mask=n < N, other=-float('inf'))\n # Substract maximum for numerical stability\n z = x - triton.max(x, axis=0)\n # Note that exponentials in Triton are fast\n # but approximate (i.e., think __expf in CUDA)\n num = triton.exp(z)\n denom = triton.sum(num, axis=0)\n y = num / denom\n # Write back to Y\n Y = Y + m * stride_ym + n\n triton.store(Y, y, mask=n < N)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can use the above softmax function to compute the row-wise softmax of a given matrix.\n\n"
|
||||
"We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\ndef softmax(x):\n M, N = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK = next_power_of_2(N)\n # Another trick we can use is to ask the compiler to parallelize each\n # row-normalization more aggressively -- i.e., with more warps -- vectors\n # that are longer\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself\n num_warps = 4\n if BLOCK >= 2048: num_warps = 8\n if BLOCK >= 4096: num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix\n _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)\n return y"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -3,137 +3,70 @@ Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple vector addition using Triton and learn about:
|
||||
|
||||
- The basic syntax of the Triton programming language
|
||||
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
|
||||
- The basic programming model used by Triton
|
||||
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
|
||||
- The best practices for validating and benchmarking custom ops against native reference implementations
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# --------------------------
|
||||
#
|
||||
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
|
||||
# on different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)
|
||||
# programming model for more details).
|
||||
#
|
||||
# .. code-block:: C
|
||||
#
|
||||
# __global__ void add(float* z, float* x, float* y, int N){
|
||||
# // The `get_program_id(i)` returns the i-th coordinate
|
||||
# // of the program in the overaching SPMD context
|
||||
# // (a.k.a launch grid). This is what allows us to process
|
||||
# // different chunks of data in parallel.
|
||||
# // For those similar with CUDA, `get_program_id({0,1,2})`
|
||||
# // is similar to blockIdx.{x,y,z}
|
||||
# int pid = get_program_id(0);
|
||||
# // In Triton, arrays are first-class citizen. In other words,
|
||||
# // they are primitives data-types and are -- contrary to C and
|
||||
# // CUDA -- not implemented as pointers to contiguous chunks of
|
||||
# // memory.
|
||||
# // In the few lines below, we create an array of `BLOCK` pointers
|
||||
# // whose memory values are, e.g.:
|
||||
# // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
|
||||
# // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
|
||||
# int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
# float* pz [BLOCK] = z + offset;
|
||||
# float* px [BLOCK] = x + offset;
|
||||
# float* py [BLOCK] = y + offset;
|
||||
# // Simple element-wise control-flow for load/store operations can
|
||||
# // be achieved using the the ternary operator `cond ? val_true : val_false`
|
||||
# // or the conditional dereferencing operator `*?(cond)ptr
|
||||
# // Here, we make sure that we do not access memory out-of-bounds when we
|
||||
# // write-back `z`
|
||||
# bool check[BLOCK] = offset < N;
|
||||
# *?(check)pz = *?(check)px + *?(check)py;
|
||||
# }
|
||||
#
|
||||
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
|
||||
|
||||
# %%
|
||||
# Torch Bindings
|
||||
# --------------------------
|
||||
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
|
||||
#
|
||||
# - :code:`source: string`: the source-code of the kernel you want to create
|
||||
# - :code:`device: torch.device`: the device you want to compile this code for
|
||||
# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# source-code for Triton compute kernel
|
||||
# here we just copy-paste the above code without the extensive comments.
|
||||
# you may prefer to store it in a .c file and load it from there instead.
|
||||
_src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// program id
|
||||
int pid = get_program_id(0);
|
||||
// create arrays of pointers
|
||||
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
float* pz[BLOCK] = z + offset;
|
||||
float* px[BLOCK] = x + offset;
|
||||
float* py[BLOCK] = y + offset;
|
||||
// bounds checking
|
||||
bool check[BLOCK] = offset < N;
|
||||
// write-back
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
"""
|
||||
|
||||
@triton.jit
|
||||
def _add(
|
||||
X, # *Pointer* to first input vector
|
||||
Y, # *Pointer* to second input vector
|
||||
Z, # *Pointer* to output vector
|
||||
N, # Size of the vector
|
||||
**meta # Optional meta-parameters for the kernel
|
||||
):
|
||||
pid = triton.program_id(0)
|
||||
# Create an offset for the blocks of pointers to be
|
||||
# processed by this program instance
|
||||
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
|
||||
# Create a mask to guard memory operations against
|
||||
# out-of-bounds accesses
|
||||
mask = offsets < N
|
||||
# Load x
|
||||
x = triton.load(X + offsets, mask=mask)
|
||||
y = triton.load(Y + offsets, mask=mask)
|
||||
# Write back x + y
|
||||
z = x + y
|
||||
triton.store(Z + offsets, z)
|
||||
|
||||
# This function returns a callable `triton.kernel` object created from the above source code.
|
||||
# For portability, we maintain a cache of kernels for different `torch.device`
|
||||
# We compile the kernel with -DBLOCK=1024
|
||||
def make_add_kernel(device):
|
||||
cache = make_add_kernel.cache
|
||||
if device not in cache:
|
||||
defines = {'BLOCK': 1024}
|
||||
cache[device] = triton.kernel(_src, device=device, defines=defines)
|
||||
return cache[device]
|
||||
|
||||
|
||||
make_add_kernel.cache = dict()
|
||||
|
||||
|
||||
# This is a standard torch custom autograd Function;
|
||||
# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`
|
||||
class _add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
# *allocate output*
|
||||
z = torch.empty_like(x)
|
||||
# *create launch grid*:
|
||||
# this is a function which takes compilation parameters `opt`
|
||||
# as input and returns a tuple of int (i.e., launch grid) for the kernel.
|
||||
# triton.cdiv is a shortcut for ceil division:
|
||||
# triton.cdiv(a, b) = (a + b - 1) // b
|
||||
N = z.shape[0]
|
||||
grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
|
||||
# *launch kernel*:
|
||||
# pointer to the data of torch tensors can be retrieved with
|
||||
# the `.data_ptr()` method
|
||||
kernel = make_add_kernel(z.device)
|
||||
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
|
||||
return z
|
||||
|
||||
|
||||
# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function
|
||||
add = _add.apply
|
||||
|
||||
# %%
|
||||
# We can now use the above function to compute the sum of two `torch.tensor` objects:
|
||||
# We can also declara a helper function that handles allocating the output vector
|
||||
# and enqueueing the kernel.
|
||||
|
||||
|
||||
def add(x, y):
|
||||
z = torch.empty_like(x)
|
||||
N = z.shape[0]
|
||||
# The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
|
||||
# NOTE:
|
||||
# - torch.tensor objects are implicitly converted to pointers to their first element.
|
||||
# - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keywords arguments
|
||||
_add[grid](x, y, z, N, BLOCK=1024)
|
||||
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously.
|
||||
return z
|
||||
|
||||
|
||||
# %%
|
||||
# Unit Test
|
||||
# -----------
|
||||
#
|
||||
# Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:
|
||||
# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(98432, device='cuda')
|
||||
y = torch.rand(98432, device='cuda')
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
y = torch.rand(size, device='cuda')
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(za)
|
||||
|
@@ -1,10 +1,10 @@
|
||||
"""
|
||||
Matrix Multiplication
|
||||
======================
|
||||
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
|
||||
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
|
||||
You will specifically learn about:
|
||||
|
||||
- The block-level matrix multiplication operator `@`
|
||||
- Block-level matrix multiplications
|
||||
- Multi-dimensional pointer arithmetic
|
||||
- Program re-ordering for improved L2 cache hit rate
|
||||
- Automatic performance tuning
|
||||
@@ -15,7 +15,7 @@ You will specifically learn about:
|
||||
# -------------
|
||||
# Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
# Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||||
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||||
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
|
||||
#
|
||||
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
|
||||
@@ -23,322 +23,212 @@ You will specifically learn about:
|
||||
# .. code-block:: python
|
||||
#
|
||||
# # do in parallel
|
||||
# for m in range(0, M, MB):
|
||||
# for m in range(0, M, BLOCK_M):
|
||||
# # do in parallel
|
||||
# for n in range(0, N, NB):
|
||||
# acc = zeros((MB, NB), dtype=float32)
|
||||
# for k in range(0, K, KB):
|
||||
# acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
|
||||
# C[m : m+MB, n : n+NB] = acc;
|
||||
# for n in range(0, N, BLOCK_N):
|
||||
# acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
|
||||
# for k in range(0, K, BLOCK_K):
|
||||
# a = A[m : m+BLOCK_M, k : k+BLOCK_K]
|
||||
# b = B[k : k+BLOCK_K, n : n+BLOCK_N]
|
||||
# acc += dot(a, b)
|
||||
# C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
|
||||
#
|
||||
# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
|
||||
# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# ----------------
|
||||
#
|
||||
# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
|
||||
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
|
||||
# The above algorithm is actually fairly straightforward to implement in Triton.
|
||||
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
|
||||
#
|
||||
# Pointer Arithmetics
|
||||
# ~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
|
||||
# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
|
||||
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
||||
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
# &A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
|
||||
# &B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
|
||||
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
|
||||
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
|
||||
#
|
||||
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
|
||||
#
|
||||
# .. code-block:: C
|
||||
# :force:
|
||||
# .. code-block:: python
|
||||
#
|
||||
# int rm[MB] = program_id_m * MB + 0 ... MB;
|
||||
# int rn[NB] = program_id_n * NB + 0 ... NB;
|
||||
# int rk[KB] = 0 ... KB;
|
||||
# TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
|
||||
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
|
||||
# pid_m = triton.program_id(0)
|
||||
# pid_n = triton.program_id(1)
|
||||
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
# rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
# rk = triton.arange(0, BLOCK_K)
|
||||
# // pointer for A operand
|
||||
# pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
|
||||
# // pointer for B operand
|
||||
# pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
|
||||
#
|
||||
# These pointers can then be updated in the inner loop as:
|
||||
#
|
||||
# .. code-block:: C
|
||||
# .. code-block:: python
|
||||
#
|
||||
# pa += KB * 1;
|
||||
# pb += KB * ldb;
|
||||
# pa += BLOCK_K * stride_a_1;
|
||||
# pb += BLOCK_K * stride_b_0;
|
||||
#
|
||||
#
|
||||
# L2 Cache Optimizations
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
|
||||
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
|
||||
# However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
|
||||
# This means that a naive row-major ordering:
|
||||
#
|
||||
# .. code-block:: C
|
||||
# .. code-block:: Python
|
||||
#
|
||||
# int program_id = get_program_id(0);
|
||||
# int grid_m = (M + MB - 1) / MB;
|
||||
# int grid_n = (N + NB - 1) / NB;
|
||||
# int program_id_m = program_id / grid_n;
|
||||
# int program_id_n = program_id % grid_n;
|
||||
# pid = triton.program_id(0);
|
||||
# grid_m = (M + BLOCK_M - 1) // BLOCK_M;
|
||||
# grid_n = (N + BLOCK_N - 1) // BLOCK_N;
|
||||
# pid_m = pid / grid_n;
|
||||
# pid_n = pid % grid_n;
|
||||
#
|
||||
# is unlikely to result in optimal performance.
|
||||
#
|
||||
# One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
|
||||
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
||||
#
|
||||
# .. code-block:: C
|
||||
# .. code-block:: python
|
||||
#
|
||||
# int program_id = get_program_id(0);
|
||||
# int width = GROUP_SIZE * grid_n;
|
||||
# int group_id = pid / width;
|
||||
# // we need to handle the case where M % (GROUP_SIZE*BM) != 0
|
||||
# int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
|
||||
# int pid_m = group_id * GROUP_SIZE + (pid % group_size);
|
||||
# int pid_n = (pid % width) / (group_size);
|
||||
# pid = triton.program_id(0);
|
||||
# width = GROUP_M * grid_n;
|
||||
# group_id = pid // width;
|
||||
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
|
||||
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
||||
# pid_m = group_id * GROUP_M + (pid % group_size);
|
||||
# pid_n = (pid % width) // (group_size);
|
||||
#
|
||||
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||
#
|
||||
# Final Result
|
||||
# ~~~~~~~~~~~~~~
|
||||
#
|
||||
# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
|
||||
# Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure.
|
||||
# This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.
|
||||
#
|
||||
# .. code-block:: C
|
||||
# :force:
|
||||
#
|
||||
# #define MAX_GROUP_SIZE 8
|
||||
#
|
||||
# __global__ void dot(TYPE* A, TYPE* B, TYPE* C,
|
||||
# int M, int N, int K,
|
||||
# int stride_a_0, int stride_b_0, int stride_c_0) {
|
||||
# // prologue
|
||||
# int pid = get_program_id(0);
|
||||
# int grid_m = (M + MB - 1) / MB;
|
||||
# int grid_n = (N + NB - 1) / NB;
|
||||
# // re-order program ID for better L2 performance
|
||||
# int width = MAX_GROUP_SIZE * grid_n;
|
||||
# int group_id = pid / width;
|
||||
# int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
|
||||
# int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
|
||||
# int pid_n = (pid % width) / (group_size);
|
||||
# // pointers to operands
|
||||
# // note the parentheses here; they force the offset
|
||||
# // computation to happen in typeof(stride_a_0) = int32 rather than
|
||||
# // typeof(A) = int64
|
||||
# int rm[MB] = pid_m * MB + 0 ... MB;
|
||||
# int rn[NB] = pid_n * NB + 0 ... NB;
|
||||
# int rk[KB] = 0 ... KB;
|
||||
# TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
|
||||
# TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
|
||||
# // reduction loop
|
||||
# float acc[MB, NB] = 0;
|
||||
# for (int k = K; k > 0; k -= KB) {
|
||||
# acc += (*pa) @ (*pb);
|
||||
# pa += KB * 1;
|
||||
# pb += KB * stride_b_0;
|
||||
# }
|
||||
# // pointers to output
|
||||
# // here we rematerialize `rm` and `rn` so that they are not live through
|
||||
# // the above reduction loop. In the future, the compiler should be able to
|
||||
# // do this automatically.
|
||||
# rm = pid_m * MB + 0 ... MB;
|
||||
# rn = pid_n * NB + 0 ... NB;
|
||||
# TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
|
||||
# // we write back using *?() operator. `acc` gets casted to `float32` implicitly.
|
||||
# *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
|
||||
# }
|
||||
#
|
||||
# Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
|
||||
# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
|
||||
# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
|
||||
#
|
||||
|
||||
# %%
|
||||
# Torch Bindings
|
||||
# ----------------
|
||||
# Final Result
|
||||
# -------------
|
||||
#
|
||||
# Auto-Tuning
|
||||
# ~~~~~~~~~~~~~~
|
||||
#
|
||||
# In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows:
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
autotune_configs = [
|
||||
triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
|
||||
triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
|
||||
triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
|
||||
triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
|
||||
triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
|
||||
triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
|
||||
triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
|
||||
triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
|
||||
]
|
||||
# %
|
||||
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
|
||||
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def sigmoid(x):
|
||||
ret_true = 1 / (1 + triton.exp(-x))
|
||||
ret_false = triton.exp(x) / (1 + triton.exp(x))
|
||||
return triton.where(x >= 0, ret_true, ret_false)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def swish(x):
|
||||
return x * sigmoid(x)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
# %
|
||||
# We can now define our kernel as normal, using all the techniques presented above
|
||||
@triton.jit
|
||||
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
|
||||
# extract meta-parameters
|
||||
BLOCK_M = META['BLOCK_M']
|
||||
BLOCK_N = META['BLOCK_N']
|
||||
BLOCK_K = META['BLOCK_K']
|
||||
GROUP_M = 8
|
||||
# matrix multiplication
|
||||
pid = triton.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
rk = triton.arange(0, BLOCK_K)
|
||||
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
|
||||
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
|
||||
for k in range(K, 0, -BLOCK_K):
|
||||
a = triton.load(A)
|
||||
b = triton.load(B)
|
||||
acc += triton.dot(a, b)
|
||||
A += BLOCK_K * stride_ak
|
||||
B += BLOCK_K * stride_bk
|
||||
# triton can accept arbitrary activation function
|
||||
# via metaparameters!
|
||||
if META['ACTIVATION']:
|
||||
acc = META['ACTIVATION'](acc)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm[:, None] < M) & (rn[None, :] < N)
|
||||
triton.store(C, acc, mask=mask)
|
||||
|
||||
|
||||
# %%
|
||||
# we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
|
||||
# Here, we want to re-tune our kernel only when the shape of input matrices changes.
|
||||
|
||||
autotune_key = ["M", "N", "K"]
|
||||
|
||||
# %%
|
||||
# We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
|
||||
|
||||
src = """
|
||||
#define MAX_GROUP_SIZE 8
|
||||
|
||||
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
|
||||
int M, int N, int K,
|
||||
int lda, int ldb, int ldc) {
|
||||
int pid = get_program_id(0);
|
||||
int grid_m = (M + MB - 1) / MB;
|
||||
int grid_n = (N + NB - 1) / NB;
|
||||
int width = MAX_GROUP_SIZE * grid_n;
|
||||
int group_id = pid / width;
|
||||
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
|
||||
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
|
||||
int pid_n = (pid % width) / (group_size);
|
||||
int rm[MB] = pid_m * MB + 0 ... MB;
|
||||
int rn[NB] = pid_n * NB + 0 ... NB;
|
||||
int rk[KB] = 0 ... KB;
|
||||
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
|
||||
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
|
||||
float acc[MB, NB] = 0;
|
||||
for (int k = K; k > 0; k -= KB) {
|
||||
acc += (*pa) @ (*pb);
|
||||
pa += KB * 1;
|
||||
pb += KB * ldb;
|
||||
}
|
||||
rm = pid_m * MB + 0 ... MB;
|
||||
rn = pid_n * NB + 0 ... NB;
|
||||
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
|
||||
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
|
||||
}
|
||||
"""
|
||||
# We can also create a convenience wrapper function that only takes two input tensors
|
||||
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
|
||||
|
||||
|
||||
def make_kernel(device, dtype):
|
||||
key = (device, dtype)
|
||||
cache = make_kernel.cache
|
||||
if key not in cache:
|
||||
defines = {'TYPE': dtype}
|
||||
cache[key] = triton.kernel(
|
||||
src,
|
||||
device=device,
|
||||
defines=defines,
|
||||
autotune_configs=autotune_configs,
|
||||
autotune_key=autotune_key,
|
||||
)
|
||||
return cache[key]
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
assert a.is_contiguous(), "matrix A must be contiguous"
|
||||
assert b.is_contiguous(), "matrix B must be contiguous"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
_matmul[grid](
|
||||
a, b, c, M, N, K, \
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||
ACTIVATION = activation
|
||||
)
|
||||
# return output
|
||||
return c
|
||||
|
||||
|
||||
make_kernel.cache = dict()
|
||||
|
||||
# %%
|
||||
# Autograd Function
|
||||
# ~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
|
||||
# To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output.
|
||||
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
M, Ka = a.shape
|
||||
Kb, N = b.shape
|
||||
assert Ka == Kb, "incompatible dimensions"
|
||||
assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
kernel = make_kernel(a.device, a.dtype)
|
||||
grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
|
||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
|
||||
M, N, Ka, \
|
||||
a.stride(0), b.stride(0), c.stride(0), \
|
||||
grid=grid)
|
||||
return c
|
||||
|
||||
|
||||
dot = _dot.apply
|
||||
|
||||
# %%
|
||||
# Unit Test
|
||||
# -----------
|
||||
#
|
||||
# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
|
||||
# Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
|
||||
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
|
||||
|
||||
a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
|
||||
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
|
||||
c_0 = dot(a, b)
|
||||
c_1 = torch.matmul(a, b)
|
||||
#torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
c_0 = matmul(a, b, activation=swish)
|
||||
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
|
||||
print(c_0)
|
||||
print(c_1)
|
||||
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
|
||||
print(triton.testing.allclose(c_0, c_1))
|
||||
|
||||
# %%
|
||||
# Benchmark
|
||||
# --------------
|
||||
#
|
||||
# Installing The CUTLASS Bindings
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
|
||||
# For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
|
||||
# To install CUTLASS, you need a recent version of cmake:
|
||||
#
|
||||
# .. code-block:: bash
|
||||
#
|
||||
# cd /path/to/cutlass/
|
||||
# git clone https://github.com/NVIDIA/cutlass.git
|
||||
# cd cutlass
|
||||
# mkdir build
|
||||
# cd build
|
||||
# wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
|
||||
# tar xzvf *.tar.gz
|
||||
#
|
||||
# You can then install CUTLASS as follows for V100
|
||||
#
|
||||
# .. code-block:: bash
|
||||
#
|
||||
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
|
||||
# make -j8 install
|
||||
#
|
||||
# Or as follows for A100:
|
||||
#
|
||||
# .. code-block:: bash
|
||||
#
|
||||
# ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
|
||||
# make -j8 install
|
||||
#
|
||||
# Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
|
||||
# Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
|
||||
# To re-install Triton with the updated CUTLASS bindings, run the following command:
|
||||
#
|
||||
# .. code-block:: bash
|
||||
#
|
||||
# export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
|
||||
# export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
|
||||
# pip uninstall -y triton
|
||||
# pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
|
||||
#
|
||||
# Which we can test as follows:
|
||||
|
||||
import triton
|
||||
c_2 = triton.testing.cutlass_matmul(a, b)
|
||||
print(c_2)
|
||||
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
|
||||
|
||||
# %%
|
||||
# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
|
||||
#
|
||||
# Square Matrix Performance
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
|
||||
@@ -349,27 +239,23 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
|
||||
y_name='provider', # argument name whose value corresponds to a different line in the plot
|
||||
y_vals=['cublas', 'triton', 'cutlass'], # possible keys for `y_name`
|
||||
y_lines=["cuBLAS", "Triton", 'CUTLASS'], # label name for the lines
|
||||
y_vals=['cublas', 'triton'], # possible keys for `y_name`
|
||||
y_lines=["cuBLAS", "Triton"], # label name for the lines
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
silu = torch.nn.SiLU()
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
|
||||
if provider == 'cutlass':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
|
||||
# %%
|
||||
# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
|
||||
benchmark.run(print_data=True)
|
@@ -4,8 +4,7 @@ Fused Softmax
|
||||
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- The syntax and usage of reduction operators in Triton.
|
||||
- The automatic vectorization capabilities of the Triton compiler.
|
||||
- The reduction operators in Triton.
|
||||
"""
|
||||
|
||||
# %%
|
||||
@@ -36,79 +35,45 @@ def naive_softmax(x):
|
||||
# %%
|
||||
# When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
|
||||
# In this case, we would be reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
# This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# ----------------
|
||||
# Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
|
||||
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
#
|
||||
# .. code-block:: C
|
||||
#
|
||||
# __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
|
||||
# // row index
|
||||
# int m = get_program_id(0);
|
||||
# // column indices
|
||||
# int n [BLOCK] = 0 ... BLOCK;
|
||||
# // the memory address of all the elements
|
||||
# // that we want to load can be computed as follows
|
||||
# float* px [BLOCK] = X + m*stride_xm + n;
|
||||
# // because BLOCK has to be a power of two
|
||||
# // (per Triton-C specs), it is important
|
||||
# // to guard each memory operation with predicates
|
||||
# // or we will read out of bounds
|
||||
# bool check[BLOCK] = n < N;
|
||||
# float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
# // syntax for reduction in Triton is:
|
||||
# // x[:, :, OPERATOR, :, :]
|
||||
# // ^
|
||||
# // index
|
||||
# // where operator is in {min, max, +}
|
||||
# // for 1D vectors, this is just x[OPERATOR].
|
||||
# float z [BLOCK] = x - x[max];
|
||||
# // Note that exponentials in Triton are fast
|
||||
# // but approximate (i.e., think __expf in CUDA)
|
||||
# float num [BLOCK] = exp(z);
|
||||
# float denom = num[+];
|
||||
# // The result of the reduction is now stored in y
|
||||
# float y [BLOCK] = num / denom;
|
||||
# // We write it back
|
||||
# float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
# *?(check)py = y;
|
||||
# }
|
||||
|
||||
# %%
|
||||
# Torch Bindings
|
||||
# ---------------
|
||||
# Here our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial.
|
||||
# We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
|
||||
# This means that different values of BLOCK will result in different kernels
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# Source code for the Triton kernel
|
||||
_src = """
|
||||
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
|
||||
int m = get_program_id(0);
|
||||
int n [BLOCK] = 0 ... BLOCK;
|
||||
float* px [BLOCK] = X + m*stride_xm + n;
|
||||
bool check[BLOCK] = n < N;
|
||||
float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
float z [BLOCK] = x - x[max];
|
||||
float num [BLOCK] = exp(z);
|
||||
float denom = num[+];
|
||||
float y [BLOCK] = num / denom;
|
||||
float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
*?(check)py = y;
|
||||
}
|
||||
"""
|
||||
|
||||
@triton.jit
|
||||
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
|
||||
# row index
|
||||
m = triton.program_id(0)
|
||||
# col indices
|
||||
n = triton.arange(0, meta['BLOCK'])
|
||||
# the memory address of all the elements
|
||||
# that we want to load can be computed as follows
|
||||
X = X + m * stride_xm + n
|
||||
x = triton.load(X, mask=n < N, other=-float('inf'))
|
||||
# Substract maximum for numerical stability
|
||||
z = x - triton.max(x, axis=0)
|
||||
# Note that exponentials in Triton are fast
|
||||
# but approximate (i.e., think __expf in CUDA)
|
||||
num = triton.exp(z)
|
||||
denom = triton.sum(num, axis=0)
|
||||
y = num / denom
|
||||
# Write back to Y
|
||||
Y = Y + m * stride_ym + n
|
||||
triton.store(Y, y, mask=n < N)
|
||||
|
||||
|
||||
# %%
|
||||
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
|
||||
|
||||
|
||||
# helper function to get the smaller power-of-two larger than a given number
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
@@ -120,11 +85,9 @@ def next_power_of_2(n):
|
||||
return n
|
||||
|
||||
|
||||
# kernel caching mechanism
|
||||
def make_kernel(N, device):
|
||||
cache = make_kernel.cache
|
||||
# Now are kernels are indexed not only by the provided device but also
|
||||
# by the rounded number of columns in the input matrix
|
||||
def softmax(x):
|
||||
M, N = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
BLOCK = next_power_of_2(N)
|
||||
# Another trick we can use is to ask the compiler to parallelize each
|
||||
# row-normalization more aggressively -- i.e., with more warps -- vectors
|
||||
@@ -134,37 +97,13 @@ def make_kernel(N, device):
|
||||
num_warps = 4
|
||||
if BLOCK >= 2048: num_warps = 8
|
||||
if BLOCK >= 4096: num_warps = 16
|
||||
# Each (BLOCK, num_warps, device) results in a different kernel
|
||||
key = (BLOCK, num_warps, device)
|
||||
if key not in cache:
|
||||
defines = {'BLOCK': BLOCK}
|
||||
cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
|
||||
return cache[key]
|
||||
# Allocate output
|
||||
y = torch.empty_like(x)
|
||||
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
|
||||
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
|
||||
return y
|
||||
|
||||
|
||||
make_kernel.cache = dict()
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
y = torch.empty_like(x)
|
||||
# The launch grid is simple: we have one kernel instance per row of the input matrix
|
||||
M, N = y.shape
|
||||
grid = lambda opt: (M, )
|
||||
# Launch kernel
|
||||
kernel = make_kernel(N, y.device)
|
||||
kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
|
||||
return y
|
||||
|
||||
|
||||
softmax = _softmax.apply
|
||||
|
||||
# %%
|
||||
# We can use the above softmax function to compute the row-wise softmax of a given matrix.
|
||||
|
||||
# %%
|
||||
# Unit Test
|
||||
# ----------
|
||||
|
@@ -15,21 +15,14 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Vector Addition\nIn this tutorial, you will write a simple vector addition using Triton and learn about:\n\n- The basic syntax of the Triton programming language\n- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API\n- The best practices for validating and benchmarking custom ops against native reference implementations\n"
|
||||
"\n# Vector Addition\nIn this tutorial, you will write a simple vector addition using Triton and learn about:\n\n- The basic programming model used by Triton\n- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.\n- The best practices for validating and benchmarking custom ops against native reference implementations\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\n\nEach compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel\non different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)\nprogramming model for more details).\n\n .. code-block:: C\n\n __global__ void add(float* z, float* x, float* y, int N){\n // The `get_program_id(i)` returns the i-th coordinate\n // of the program in the overaching SPMD context\n // (a.k.a launch grid). This is what allows us to process\n // different chunks of data in parallel.\n // For those similar with CUDA, `get_program_id({0,1,2})`\n // is similar to blockIdx.{x,y,z}\n int pid = get_program_id(0);\n // In Triton, arrays are first-class citizen. In other words,\n // they are primitives data-types and are -- contrary to C and\n // CUDA -- not implemented as pointers to contiguous chunks of\n // memory.\n // In the few lines below, we create an array of `BLOCK` pointers\n // whose memory values are, e.g.:\n // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n float* pz [BLOCK] = z + offset;\n float* px [BLOCK] = x + offset;\n float* py [BLOCK] = y + offset;\n // Simple element-wise control-flow for load/store operations can\n // be achieved using the the ternary operator `cond ? val_true : val_false`\n // or the conditional dereferencing operator `*?(cond)ptr\n // Here, we make sure that we do not access memory out-of-bounds when we\n // write-back `z`\n bool check[BLOCK] = offset < N;\n *?(check)pz = *?(check)px + *?(check)py;\n }\n\nThe existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Torch Bindings\nThe only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:\n\n- :code:`source: string`: the source-code of the kernel you want to create\n- :code:`device: torch.device`: the device you want to compile this code for\n- :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n\n"
|
||||
"## Compute Kernel\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -40,21 +33,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport triton\n\n# source-code for Triton compute kernel\n# here we just copy-paste the above code without the extensive comments.\n# you may prefer to store it in a .c file and load it from there instead.\n_src = \"\"\"\n__global__ void add(float* z, float* x, float* y, int N){\n // program id\n int pid = get_program_id(0);\n // create arrays of pointers\n int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n float* pz[BLOCK] = z + offset;\n float* px[BLOCK] = x + offset;\n float* py[BLOCK] = y + offset;\n // bounds checking\n bool check[BLOCK] = offset < N;\n // write-back\n *?(check)pz = *?(check)px + *?(check)py;\n}\n \"\"\"\n\n\n# This function returns a callable `triton.kernel` object created from the above source code.\n# For portability, we maintain a cache of kernels for different `torch.device`\n# We compile the kernel with -DBLOCK=1024\ndef make_add_kernel(device):\n cache = make_add_kernel.cache\n if device not in cache:\n defines = {'BLOCK': 1024}\n cache[device] = triton.kernel(_src, device=device, defines=defines)\n return cache[device]\n\n\nmake_add_kernel.cache = dict()\n\n\n# This is a standard torch custom autograd Function;\n# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`\nclass _add(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, y):\n # constraints of the op\n assert x.dtype == torch.float32\n # *allocate output*\n z = torch.empty_like(x)\n # *create launch grid*:\n # this is a function which takes compilation parameters `opt`\n # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n # triton.cdiv is a shortcut for ceil division:\n # triton.cdiv(a, b) = (a + b - 1) // b\n N = z.shape[0]\n grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n # *launch kernel*:\n # pointer to the data of torch tensors can be retrieved with\n # the `.data_ptr()` method\n kernel = make_add_kernel(z.device)\n kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)\n return z\n\n\n# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function\nadd = _add.apply"
|
||||
"import torch\nimport triton\n\n\n@triton.jit\ndef _add(\n X, # *Pointer* to first input vector\n Y, # *Pointer* to second input vector\n Z, # *Pointer* to output vector\n N, # Size of the vector\n **meta # Optional meta-parameters for the kernel\n):\n pid = triton.program_id(0)\n # Create an offset for the blocks of pointers to be\n # processed by this program instance\n offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])\n # Create a mask to guard memory operations against\n # out-of-bounds accesses\n mask = offsets < N\n # Load x\n x = triton.load(X + offsets, mask=mask)\n y = triton.load(Y + offsets, mask=mask)\n # Write back x + y\n z = x + y\n triton.store(Z + offsets, z)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now use the above function to compute the sum of two `torch.tensor` objects:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Unit Test\n\nOf course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:\n\n"
|
||||
"We can also declara a helper function that handles allocating the output vector\nand enqueueing the kernel.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -65,7 +51,25 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\nx = torch.rand(98432, device='cuda')\ny = torch.rand(98432, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')"
|
||||
"def add(x, y):\n z = torch.empty_like(x)\n N = z.shape[0]\n # The SPMD launch grid denotes the number of kernel instances that should execute in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )\n # NOTE:\n # - torch.tensor objects are implicitly converted to pointers to their first element.\n # - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n _add[grid](x, y, z, N, BLOCK=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously.\n return z"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
Before Width: | Height: | Size: 2.9 KiB |
Before Width: | Height: | Size: 3.6 KiB |
Before Width: | Height: | Size: 9.5 KiB |
Before Width: | Height: | Size: 12 KiB |
Before Width: | Height: | Size: 59 KiB |
Before Width: | Height: | Size: 27 KiB After Width: | Height: | Size: 27 KiB |
Before Width: | Height: | Size: 17 KiB After Width: | Height: | Size: 17 KiB |
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 36 KiB |
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 23 KiB |
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 36 KiB |
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 3.0 KiB |
@@ -22,63 +22,16 @@ Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple vector addition using Triton and learn about:
|
||||
|
||||
- The basic syntax of the Triton programming language
|
||||
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
|
||||
- The basic programming model used by Triton
|
||||
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
|
||||
- The best practices for validating and benchmarking custom ops against native reference implementations
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-51
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-14
|
||||
|
||||
Compute Kernel
|
||||
--------------------------
|
||||
|
||||
Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
|
||||
on different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)
|
||||
programming model for more details).
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// The `get_program_id(i)` returns the i-th coordinate
|
||||
// of the program in the overaching SPMD context
|
||||
// (a.k.a launch grid). This is what allows us to process
|
||||
// different chunks of data in parallel.
|
||||
// For those similar with CUDA, `get_program_id({0,1,2})`
|
||||
// is similar to blockIdx.{x,y,z}
|
||||
int pid = get_program_id(0);
|
||||
// In Triton, arrays are first-class citizen. In other words,
|
||||
// they are primitives data-types and are -- contrary to C and
|
||||
// CUDA -- not implemented as pointers to contiguous chunks of
|
||||
// memory.
|
||||
// In the few lines below, we create an array of `BLOCK` pointers
|
||||
// whose memory values are, e.g.:
|
||||
// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
|
||||
// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
|
||||
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
float* pz [BLOCK] = z + offset;
|
||||
float* px [BLOCK] = x + offset;
|
||||
float* py [BLOCK] = y + offset;
|
||||
// Simple element-wise control-flow for load/store operations can
|
||||
// be achieved using the the ternary operator `cond ? val_true : val_false`
|
||||
// or the conditional dereferencing operator `*?(cond)ptr
|
||||
// Here, we make sure that we do not access memory out-of-bounds when we
|
||||
// write-back `z`
|
||||
bool check[BLOCK] = offset < N;
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
|
||||
The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 53-60
|
||||
|
||||
Torch Bindings
|
||||
--------------------------
|
||||
The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
|
||||
|
||||
- :code:`source: string`: the source-code of the kernel you want to create
|
||||
- :code:`device: torch.device`: the device you want to compile this code for
|
||||
- :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 60-125
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-42
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -86,66 +39,28 @@ The only thing that matters when it comes to Triton and Torch is the :code:`trit
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# source-code for Triton compute kernel
|
||||
# here we just copy-paste the above code without the extensive comments.
|
||||
# you may prefer to store it in a .c file and load it from there instead.
|
||||
_src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// program id
|
||||
int pid = get_program_id(0);
|
||||
// create arrays of pointers
|
||||
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
float* pz[BLOCK] = z + offset;
|
||||
float* px[BLOCK] = x + offset;
|
||||
float* py[BLOCK] = y + offset;
|
||||
// bounds checking
|
||||
bool check[BLOCK] = offset < N;
|
||||
// write-back
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# This function returns a callable `triton.kernel` object created from the above source code.
|
||||
# For portability, we maintain a cache of kernels for different `torch.device`
|
||||
# We compile the kernel with -DBLOCK=1024
|
||||
def make_add_kernel(device):
|
||||
cache = make_add_kernel.cache
|
||||
if device not in cache:
|
||||
defines = {'BLOCK': 1024}
|
||||
cache[device] = triton.kernel(_src, device=device, defines=defines)
|
||||
return cache[device]
|
||||
|
||||
|
||||
make_add_kernel.cache = dict()
|
||||
|
||||
|
||||
# This is a standard torch custom autograd Function;
|
||||
# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`
|
||||
class _add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
# *allocate output*
|
||||
z = torch.empty_like(x)
|
||||
# *create launch grid*:
|
||||
# this is a function which takes compilation parameters `opt`
|
||||
# as input and returns a tuple of int (i.e., launch grid) for the kernel.
|
||||
# triton.cdiv is a shortcut for ceil division:
|
||||
# triton.cdiv(a, b) = (a + b - 1) // b
|
||||
N = z.shape[0]
|
||||
grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
|
||||
# *launch kernel*:
|
||||
# pointer to the data of torch tensors can be retrieved with
|
||||
# the `.data_ptr()` method
|
||||
kernel = make_add_kernel(z.device)
|
||||
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
|
||||
return z
|
||||
|
||||
|
||||
# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function
|
||||
add = _add.apply
|
||||
@triton.jit
|
||||
def _add(
|
||||
X, # *Pointer* to first input vector
|
||||
Y, # *Pointer* to second input vector
|
||||
Z, # *Pointer* to output vector
|
||||
N, # Size of the vector
|
||||
**meta # Optional meta-parameters for the kernel
|
||||
):
|
||||
pid = triton.program_id(0)
|
||||
# Create an offset for the blocks of pointers to be
|
||||
# processed by this program instance
|
||||
offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
|
||||
# Create a mask to guard memory operations against
|
||||
# out-of-bounds accesses
|
||||
mask = offsets < N
|
||||
# Load x
|
||||
x = triton.load(X + offsets, mask=mask)
|
||||
y = triton.load(Y + offsets, mask=mask)
|
||||
# Write back x + y
|
||||
z = x + y
|
||||
triton.store(Z + offsets, z)
|
||||
|
||||
|
||||
|
||||
@@ -154,25 +69,54 @@ The only thing that matters when it comes to Triton and Torch is the :code:`trit
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 126-127
|
||||
|
||||
We can now use the above function to compute the sum of two `torch.tensor` objects:
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 43-45
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 129-133
|
||||
We can also declara a helper function that handles allocating the output vector
|
||||
and enqueueing the kernel.
|
||||
|
||||
Unit Test
|
||||
-----------
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 45-63
|
||||
|
||||
Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:
|
||||
.. code-block:: default
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 133-143
|
||||
|
||||
|
||||
def add(x, y):
|
||||
z = torch.empty_like(x)
|
||||
N = z.shape[0]
|
||||
# The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
|
||||
# NOTE:
|
||||
# - torch.tensor objects are implicitly converted to pointers to their first element.
|
||||
# - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keywords arguments
|
||||
_add[grid](x, y, z, N, BLOCK=1024)
|
||||
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously.
|
||||
return z
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 64-65
|
||||
|
||||
We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 65-76
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(98432, device='cuda')
|
||||
y = torch.rand(98432, device='cuda')
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
y = torch.rand(size, device='cuda')
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(za)
|
||||
@@ -196,11 +140,11 @@ Of course, the first thing that we should check is that whether kernel is correc
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 144-145
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 77-78
|
||||
|
||||
Seems like we're good to go!
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 147-152
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 80-85
|
||||
|
||||
Benchmark
|
||||
-----------
|
||||
@@ -208,7 +152,7 @@ We can now benchmark our custom op for vectors of increasing sizes to get a sens
|
||||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
|
||||
for different problem sizes.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 152-178
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 85-111
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -245,12 +189,12 @@ for different problem sizes.
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 179-181
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 112-114
|
||||
|
||||
We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or
|
||||
`save_path='/path/to/results/' to save them to disk along with raw CSV data
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 181-181
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 114-114
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -268,7 +212,7 @@ We can now run the decorated function above. Pass `show_plots=True` to see the p
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 9.497 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 5.812 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
@@ -23,17 +23,16 @@ Fused Softmax
|
||||
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- The syntax and usage of reduction operators in Triton.
|
||||
- The automatic vectorization capabilities of the Triton compiler.
|
||||
- The reduction operators in Triton.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-16
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 11-15
|
||||
|
||||
Motivations
|
||||
------------
|
||||
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 16-36
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 15-35
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -64,90 +63,68 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 37-41
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 36-40
|
||||
|
||||
When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
|
||||
In this case, we would be reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 43-82
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 42-47
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
|
||||
Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
__global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
|
||||
// row index
|
||||
int m = get_program_id(0);
|
||||
// column indices
|
||||
int n [BLOCK] = 0 ... BLOCK;
|
||||
// the memory address of all the elements
|
||||
// that we want to load can be computed as follows
|
||||
float* px [BLOCK] = X + m*stride_xm + n;
|
||||
// because BLOCK has to be a power of two
|
||||
// (per Triton-C specs), it is important
|
||||
// to guard each memory operation with predicates
|
||||
// or we will read out of bounds
|
||||
bool check[BLOCK] = n < N;
|
||||
float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
// syntax for reduction in Triton is:
|
||||
// x[:, :, OPERATOR, :, :]
|
||||
// ^
|
||||
// index
|
||||
// where operator is in {min, max, +}
|
||||
// for 1D vectors, this is just x[OPERATOR].
|
||||
float z [BLOCK] = x - x[max];
|
||||
// Note that exponentials in Triton are fast
|
||||
// but approximate (i.e., think __expf in CUDA)
|
||||
float num [BLOCK] = exp(z);
|
||||
float denom = num[+];
|
||||
// The result of the reduction is now stored in y
|
||||
float y [BLOCK] = num / denom;
|
||||
// We write it back
|
||||
float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
*?(check)py = y;
|
||||
}
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 84-89
|
||||
|
||||
Torch Bindings
|
||||
---------------
|
||||
Here our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial.
|
||||
We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
|
||||
This means that different values of BLOCK will result in different kernels
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 89-165
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 47-73
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# Source code for the Triton kernel
|
||||
_src = """
|
||||
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
|
||||
int m = get_program_id(0);
|
||||
int n [BLOCK] = 0 ... BLOCK;
|
||||
float* px [BLOCK] = X + m*stride_xm + n;
|
||||
bool check[BLOCK] = n < N;
|
||||
float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
float z [BLOCK] = x - x[max];
|
||||
float num [BLOCK] = exp(z);
|
||||
float denom = num[+];
|
||||
float y [BLOCK] = num / denom;
|
||||
float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
*?(check)py = y;
|
||||
}
|
||||
"""
|
||||
|
||||
@triton.jit
|
||||
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
|
||||
# row index
|
||||
m = triton.program_id(0)
|
||||
# col indices
|
||||
n = triton.arange(0, meta['BLOCK'])
|
||||
# the memory address of all the elements
|
||||
# that we want to load can be computed as follows
|
||||
X = X + m * stride_xm + n
|
||||
x = triton.load(X, mask=n < N, other=-float('inf'))
|
||||
# Substract maximum for numerical stability
|
||||
z = x - triton.max(x, axis=0)
|
||||
# Note that exponentials in Triton are fast
|
||||
# but approximate (i.e., think __expf in CUDA)
|
||||
num = triton.exp(z)
|
||||
denom = triton.sum(num, axis=0)
|
||||
y = num / denom
|
||||
# Write back to Y
|
||||
Y = Y + m * stride_ym + n
|
||||
triton.store(Y, y, mask=n < N)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 74-75
|
||||
|
||||
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 75-107
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
# helper function to get the smaller power-of-two larger than a given number
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
@@ -159,11 +136,9 @@ This means that different values of BLOCK will result in different kernels
|
||||
return n
|
||||
|
||||
|
||||
# kernel caching mechanism
|
||||
def make_kernel(N, device):
|
||||
cache = make_kernel.cache
|
||||
# Now are kernels are indexed not only by the provided device but also
|
||||
# by the rounded number of columns in the input matrix
|
||||
def softmax(x):
|
||||
M, N = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
BLOCK = next_power_of_2(N)
|
||||
# Another trick we can use is to ask the compiler to parallelize each
|
||||
# row-normalization more aggressively -- i.e., with more warps -- vectors
|
||||
@@ -173,33 +148,11 @@ This means that different values of BLOCK will result in different kernels
|
||||
num_warps = 4
|
||||
if BLOCK >= 2048: num_warps = 8
|
||||
if BLOCK >= 4096: num_warps = 16
|
||||
# Each (BLOCK, num_warps, device) results in a different kernel
|
||||
key = (BLOCK, num_warps, device)
|
||||
if key not in cache:
|
||||
defines = {'BLOCK': BLOCK}
|
||||
cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
|
||||
return cache[key]
|
||||
|
||||
|
||||
make_kernel.cache = dict()
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
y = torch.empty_like(x)
|
||||
# The launch grid is simple: we have one kernel instance per row of the input matrix
|
||||
M, N = y.shape
|
||||
grid = lambda opt: (M, )
|
||||
# Launch kernel
|
||||
kernel = make_kernel(N, y.device)
|
||||
kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
|
||||
return y
|
||||
|
||||
|
||||
softmax = _softmax.apply
|
||||
# Allocate output
|
||||
y = torch.empty_like(x)
|
||||
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
|
||||
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
|
||||
return y
|
||||
|
||||
|
||||
|
||||
@@ -208,21 +161,18 @@ This means that different values of BLOCK will result in different kernels
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 166-167
|
||||
|
||||
We can use the above softmax function to compute the row-wise softmax of a given matrix.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 169-171
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 108-110
|
||||
|
||||
Unit Test
|
||||
----------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 173-175
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 112-114
|
||||
|
||||
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
|
||||
This will allow us to verify that our padding mechanism works.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 175-182
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 114-121
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -248,18 +198,18 @@ This will allow us to verify that our padding mechanism works.
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 183-184
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 122-123
|
||||
|
||||
As expected, the results are identical.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 186-190
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 125-129
|
||||
|
||||
Benchmark
|
||||
-------------
|
||||
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
|
||||
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 190-218
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 129-157
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -302,7 +252,7 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 219-224
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 158-163
|
||||
|
||||
In the above plot, we can see that:
|
||||
|
||||
@@ -314,7 +264,7 @@ In the above plot, we can see that:
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 25.654 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 20.767 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
@@ -20,21 +20,21 @@
|
||||
|
||||
Matrix Multiplication
|
||||
======================
|
||||
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
|
||||
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
|
||||
You will specifically learn about:
|
||||
|
||||
- The block-level matrix multiplication operator `@`
|
||||
- Block-level matrix multiplications
|
||||
- Multi-dimensional pointer arithmetic
|
||||
- Program re-ordering for improved L2 cache hit rate
|
||||
- Automatic performance tuning
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-35
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-37
|
||||
|
||||
Motivations
|
||||
-------------
|
||||
Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||||
For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
|
||||
|
||||
Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
|
||||
@@ -42,154 +42,99 @@ Roughly speaking, the kernel that we will write will implement the following blo
|
||||
.. code-block:: python
|
||||
|
||||
# do in parallel
|
||||
for m in range(0, M, MB):
|
||||
for m in range(0, M, BLOCK_M):
|
||||
# do in parallel
|
||||
for n in range(0, N, NB):
|
||||
acc = zeros((MB, NB), dtype=float32)
|
||||
for k in range(0, K, KB):
|
||||
acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
|
||||
C[m : m+MB, n : n+NB] = acc;
|
||||
for n in range(0, N, BLOCK_N):
|
||||
acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
|
||||
for k in range(0, K, BLOCK_K):
|
||||
a = A[m : m+BLOCK_M, k : k+BLOCK_K]
|
||||
b = B[k : k+BLOCK_K, n : n+BLOCK_N]
|
||||
acc += dot(a, b)
|
||||
C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
|
||||
|
||||
where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
|
||||
where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 37-161
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 39-110
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
|
||||
The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
|
||||
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
|
||||
The above algorithm is actually fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
|
||||
|
||||
Pointer Arithmetics
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
|
||||
Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
|
||||
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
||||
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
&A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
|
||||
&B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
|
||||
&A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
|
||||
&B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
|
||||
|
||||
Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
|
||||
|
||||
.. code-block:: C
|
||||
:force:
|
||||
.. code-block:: python
|
||||
|
||||
int rm[MB] = program_id_m * MB + 0 ... MB;
|
||||
int rn[NB] = program_id_n * NB + 0 ... NB;
|
||||
int rk[KB] = 0 ... KB;
|
||||
TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
|
||||
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
|
||||
pid_m = triton.program_id(0)
|
||||
pid_n = triton.program_id(1)
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
rk = triton.arange(0, BLOCK_K)
|
||||
// pointer for A operand
|
||||
pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
|
||||
// pointer for B operand
|
||||
pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
|
||||
|
||||
These pointers can then be updated in the inner loop as:
|
||||
|
||||
.. code-block:: C
|
||||
.. code-block:: python
|
||||
|
||||
pa += KB * 1;
|
||||
pb += KB * ldb;
|
||||
pa += BLOCK_K * stride_a_1;
|
||||
pb += BLOCK_K * stride_b_0;
|
||||
|
||||
|
||||
L2 Cache Optimizations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
|
||||
As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
|
||||
However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
|
||||
This means that a naive row-major ordering:
|
||||
|
||||
.. code-block:: C
|
||||
.. code-block:: Python
|
||||
|
||||
int program_id = get_program_id(0);
|
||||
int grid_m = (M + MB - 1) / MB;
|
||||
int grid_n = (N + NB - 1) / NB;
|
||||
int program_id_m = program_id / grid_n;
|
||||
int program_id_n = program_id % grid_n;
|
||||
pid = triton.program_id(0);
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M;
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N;
|
||||
pid_m = pid / grid_n;
|
||||
pid_n = pid % grid_n;
|
||||
|
||||
is unlikely to result in optimal performance.
|
||||
|
||||
One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
|
||||
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
||||
|
||||
.. code-block:: C
|
||||
.. code-block:: python
|
||||
|
||||
int program_id = get_program_id(0);
|
||||
int width = GROUP_SIZE * grid_n;
|
||||
int group_id = pid / width;
|
||||
// we need to handle the case where M % (GROUP_SIZE*BM) != 0
|
||||
int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
|
||||
int pid_m = group_id * GROUP_SIZE + (pid % group_size);
|
||||
int pid_n = (pid % width) / (group_size);
|
||||
pid = triton.program_id(0);
|
||||
width = GROUP_M * grid_n;
|
||||
group_id = pid // width;
|
||||
# we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
||||
pid_m = group_id * GROUP_M + (pid % group_size);
|
||||
pid_n = (pid % width) // (group_size);
|
||||
|
||||
In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 112-115
|
||||
|
||||
Final Result
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
|
||||
Note that we rematerialize :code:`rm` and :code:`rn:` after the inner loop to decrease register pressure.
|
||||
This is an optimization that provides an additional 5% performance improvement and cannot be currently done by the Triton compiler.
|
||||
|
||||
.. code-block:: C
|
||||
:force:
|
||||
|
||||
#define MAX_GROUP_SIZE 8
|
||||
|
||||
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
|
||||
int M, int N, int K,
|
||||
int stride_a_0, int stride_b_0, int stride_c_0) {
|
||||
// prologue
|
||||
int pid = get_program_id(0);
|
||||
int grid_m = (M + MB - 1) / MB;
|
||||
int grid_n = (N + NB - 1) / NB;
|
||||
// re-order program ID for better L2 performance
|
||||
int width = MAX_GROUP_SIZE * grid_n;
|
||||
int group_id = pid / width;
|
||||
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
|
||||
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
|
||||
int pid_n = (pid % width) / (group_size);
|
||||
// pointers to operands
|
||||
// note the parentheses here; they force the offset
|
||||
// computation to happen in typeof(stride_a_0) = int32 rather than
|
||||
// typeof(A) = int64
|
||||
int rm[MB] = pid_m * MB + 0 ... MB;
|
||||
int rn[NB] = pid_n * NB + 0 ... NB;
|
||||
int rk[KB] = 0 ... KB;
|
||||
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
|
||||
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
|
||||
// reduction loop
|
||||
float acc[MB, NB] = 0;
|
||||
for (int k = K; k > 0; k -= KB) {
|
||||
acc += (*pa) @ (*pb);
|
||||
pa += KB * 1;
|
||||
pb += KB * stride_b_0;
|
||||
}
|
||||
// pointers to output
|
||||
// here we rematerialize `rm` and `rn` so that they are not live through
|
||||
// the above reduction loop. In the future, the compiler should be able to
|
||||
// do this automatically.
|
||||
rm = pid_m * MB + 0 ... MB;
|
||||
rn = pid_n * NB + 0 ... NB;
|
||||
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
|
||||
// we write back using *?() operator. `acc` gets casted to `float32` implicitly.
|
||||
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
|
||||
}
|
||||
|
||||
Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
|
||||
Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
|
||||
If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
|
||||
-------------
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 163-170
|
||||
|
||||
Torch Bindings
|
||||
----------------
|
||||
|
||||
Auto-Tuning
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects. that can be constructed as follows:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 170-185
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 115-188
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -197,16 +142,73 @@ In order to use Triton's built-in auto-tuner in the above kernel, we need to def
|
||||
import torch
|
||||
import triton
|
||||
|
||||
autotune_configs = [
|
||||
triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
|
||||
triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
|
||||
triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
|
||||
triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
|
||||
triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
|
||||
triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
|
||||
triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
|
||||
triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
|
||||
]
|
||||
# %
|
||||
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
|
||||
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def sigmoid(x):
|
||||
ret_true = 1 / (1 + triton.exp(-x))
|
||||
ret_false = triton.exp(x) / (1 + triton.exp(x))
|
||||
return triton.where(x >= 0, ret_true, ret_false)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def swish(x):
|
||||
return x * sigmoid(x)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
# %
|
||||
# We can now define our kernel as normal, using all the techniques presented above
|
||||
@triton.jit
|
||||
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
|
||||
# extract meta-parameters
|
||||
BLOCK_M = META['BLOCK_M']
|
||||
BLOCK_N = META['BLOCK_N']
|
||||
BLOCK_K = META['BLOCK_K']
|
||||
GROUP_M = 8
|
||||
# matrix multiplication
|
||||
pid = triton.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
rk = triton.arange(0, BLOCK_K)
|
||||
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
|
||||
acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
|
||||
for k in range(K, 0, -BLOCK_K):
|
||||
a = triton.load(A)
|
||||
b = triton.load(B)
|
||||
acc += triton.dot(a, b)
|
||||
A += BLOCK_K * stride_ak
|
||||
B += BLOCK_K * stride_bk
|
||||
# triton can accept arbitrary activation function
|
||||
# via metaparameters!
|
||||
if META['ACTIVATION']:
|
||||
acc = META['ACTIVATION'](acc)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm[:, None] < M) & (rn[None, :] < N)
|
||||
triton.store(C, acc, mask=mask)
|
||||
|
||||
|
||||
|
||||
@@ -215,123 +217,36 @@ In order to use Triton's built-in auto-tuner in the above kernel, we need to def
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 186-188
|
||||
|
||||
we also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
|
||||
Here, we want to re-tune our kernel only when the shape of input matrices changes.
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 189-191
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 188-191
|
||||
We can also create a convenience wrapper function that only takes two input tensors
|
||||
and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
autotune_key = ["M", "N", "K"]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 192-193
|
||||
|
||||
We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 193-244
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
src = """
|
||||
#define MAX_GROUP_SIZE 8
|
||||
|
||||
__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
|
||||
int M, int N, int K,
|
||||
int lda, int ldb, int ldc) {
|
||||
int pid = get_program_id(0);
|
||||
int grid_m = (M + MB - 1) / MB;
|
||||
int grid_n = (N + NB - 1) / NB;
|
||||
int width = MAX_GROUP_SIZE * grid_n;
|
||||
int group_id = pid / width;
|
||||
int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
|
||||
int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
|
||||
int pid_n = (pid % width) / (group_size);
|
||||
int rm[MB] = pid_m * MB + 0 ... MB;
|
||||
int rn[NB] = pid_n * NB + 0 ... NB;
|
||||
int rk[KB] = 0 ... KB;
|
||||
TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
|
||||
TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
|
||||
float acc[MB, NB] = 0;
|
||||
for (int k = K; k > 0; k -= KB) {
|
||||
acc += (*pa) @ (*pb);
|
||||
pa += KB * 1;
|
||||
pb += KB * ldb;
|
||||
}
|
||||
rm = pid_m * MB + 0 ... MB;
|
||||
rn = pid_n * NB + 0 ... NB;
|
||||
TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
|
||||
*? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def make_kernel(device, dtype):
|
||||
key = (device, dtype)
|
||||
cache = make_kernel.cache
|
||||
if key not in cache:
|
||||
defines = {'TYPE': dtype}
|
||||
cache[key] = triton.kernel(
|
||||
src,
|
||||
device=device,
|
||||
defines=defines,
|
||||
autotune_configs=autotune_configs,
|
||||
autotune_key=autotune_key,
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
|
||||
make_kernel.cache = dict()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 245-250
|
||||
|
||||
Autograd Function
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
|
||||
To do so, we just need to define a `forward` function that takes a two tensors as input and returns a tensor as output.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 250-271
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 191-213
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
M, Ka = a.shape
|
||||
Kb, N = b.shape
|
||||
assert Ka == Kb, "incompatible dimensions"
|
||||
assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
kernel = make_kernel(a.device, a.dtype)
|
||||
grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
|
||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
|
||||
M, N, Ka, \
|
||||
a.stride(0), b.stride(0), c.stride(0), \
|
||||
grid=grid)
|
||||
return c
|
||||
|
||||
|
||||
dot = _dot.apply
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
assert a.is_contiguous(), "matrix A must be contiguous"
|
||||
assert b.is_contiguous(), "matrix B must be contiguous"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
_matmul[grid](
|
||||
a, b, c, M, N, K, \
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||
ACTIVATION = activation
|
||||
)
|
||||
# return output
|
||||
return c
|
||||
|
||||
|
||||
|
||||
@@ -340,26 +255,27 @@ To do so, we just need to define a `forward` function that takes a two tensors a
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 272-277
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 214-218
|
||||
|
||||
Unit Test
|
||||
-----------
|
||||
|
||||
We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
|
||||
Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
|
||||
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 277-286
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 218-228
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
|
||||
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
|
||||
c_0 = dot(a, b)
|
||||
c_1 = torch.matmul(a, b)
|
||||
#torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
c_0 = matmul(a, b, activation=swish)
|
||||
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
|
||||
print(c_0)
|
||||
print(c_1)
|
||||
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
|
||||
print(triton.testing.allclose(c_0, c_1))
|
||||
|
||||
|
||||
|
||||
@@ -371,118 +287,47 @@ Note that we need to modify the :code`atol` and :code:`rtol` parameters of `torc
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
||||
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
||||
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
||||
tensor([[-0.0000e+00, 2.9438e+01, -1.3113e-06, ..., 9.7266e+00,
|
||||
-3.4237e-04, -0.0000e+00],
|
||||
[-1.7615e-01, -0.0000e+00, 6.1914e+00, ..., 3.7562e+01,
|
||||
-0.0000e+00, -0.0000e+00],
|
||||
[ 9.9531e+00, 1.9078e+01, -0.0000e+00, ..., 3.6934e+00,
|
||||
1.6578e+01, 2.1031e+01],
|
||||
...,
|
||||
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
||||
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
||||
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
||||
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
||||
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
||||
[ 2.6547e+01, -1.1802e-05, 7.7852e+00, ..., 5.2156e+01,
|
||||
3.5469e+01, 1.5602e+01],
|
||||
[-0.0000e+00, -0.0000e+00, 1.6531e+01, ..., 2.1211e+00,
|
||||
1.7412e+00, 1.1422e+01],
|
||||
[-2.6550e-02, -1.1325e-05, 3.0344e+01, ..., -9.1248e-03,
|
||||
-1.5199e-05, 3.8164e+00]], device='cuda:0', dtype=torch.float16)
|
||||
tensor([[-0.0000e+00, 2.9438e+01, -1.3113e-06, ..., 9.7266e+00,
|
||||
-3.4261e-04, -0.0000e+00],
|
||||
[-1.7615e-01, -0.0000e+00, 6.1914e+00, ..., 3.7562e+01,
|
||||
-0.0000e+00, -0.0000e+00],
|
||||
[ 9.9531e+00, 1.9078e+01, -0.0000e+00, ..., 3.6934e+00,
|
||||
1.6578e+01, 2.1031e+01],
|
||||
...,
|
||||
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
||||
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
||||
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
True
|
||||
[ 2.6547e+01, -1.1802e-05, 7.7852e+00, ..., 5.2156e+01,
|
||||
3.5469e+01, 1.5602e+01],
|
||||
[-0.0000e+00, -0.0000e+00, 1.6531e+01, ..., 2.1211e+00,
|
||||
1.7412e+00, 1.1422e+01],
|
||||
[-2.6550e-02, -1.1325e-05, 3.0344e+01, ..., -9.1324e-03,
|
||||
-1.5199e-05, 3.8164e+00]], device='cuda:0', dtype=torch.float16)
|
||||
tensor(True, device='cuda:0')
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 287-333
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 229-235
|
||||
|
||||
Benchmark
|
||||
--------------
|
||||
|
||||
Installing The CUTLASS Bindings
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
|
||||
For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_ , a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves._
|
||||
To install CUTLASS, you need a recent version of cmake:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /path/to/cutlass/
|
||||
git clone https://github.com/NVIDIA/cutlass.git
|
||||
cd cutlass
|
||||
mkdir build
|
||||
cd build
|
||||
wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
|
||||
tar xzvf *.tar.gz
|
||||
|
||||
You can then install CUTLASS as follows for V100
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
|
||||
make -j8 install
|
||||
|
||||
Or as follows for A100:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
|
||||
make -j8 install
|
||||
|
||||
Where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
|
||||
Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
|
||||
To re-install Triton with the updated CUTLASS bindings, run the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
|
||||
export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
|
||||
pip uninstall -y triton
|
||||
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
|
||||
|
||||
Which we can test as follows:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 333-339
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import triton
|
||||
c_2 = triton.testing.cutlass_matmul(a, b)
|
||||
print(c_2)
|
||||
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
||||
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
||||
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
||||
...,
|
||||
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
||||
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
||||
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
True
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 340-345
|
||||
|
||||
Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
|
||||
|
||||
Square Matrix Performance
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 345-374
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 235-261
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -493,29 +338,26 @@ We can now compare the performance of our kernel against CUTLASS. Here we focus
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
|
||||
y_name='provider', # argument name whose value corresponds to a different line in the plot
|
||||
y_vals=['cublas', 'triton', 'cutlass'], # possible keys for `y_name`
|
||||
y_lines=["cuBLAS", "Triton", 'CUTLASS'], # label name for the lines
|
||||
y_vals=['cublas', 'triton'], # possible keys for `y_name`
|
||||
y_lines=["cuBLAS", "Triton"], # label name for the lines
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
silu = torch.nn.SiLU()
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
|
||||
if provider == 'cutlass':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
|
||||
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
|
||||
@@ -523,17 +365,52 @@ We can now compare the performance of our kernel against CUTLASS. Here we focus
|
||||
:class: sphx-glr-single-img
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
M cuBLAS Triton
|
||||
0 512.0 20.164923 15.420235
|
||||
1 768.0 58.982401 42.130286
|
||||
2 1024.0 91.180520 72.315584
|
||||
3 1280.0 157.538463 117.028568
|
||||
4 1536.0 150.593357 147.455995
|
||||
5 1792.0 212.064605 193.783168
|
||||
6 2048.0 197.379013 151.146088
|
||||
7 2304.0 243.753804 179.608068
|
||||
8 2560.0 237.449270 217.006622
|
||||
9 2816.0 233.231062 200.987140
|
||||
10 3072.0 236.916752 221.184001
|
||||
11 3328.0 234.499328 210.500857
|
||||
12 3584.0 248.385067 230.552287
|
||||
13 3840.0 252.493157 223.418188
|
||||
14 4096.0 263.689066 244.922869
|
||||
15 4352.0 247.295210 231.639115
|
||||
16 4608.0 274.573240 254.803966
|
||||
17 4864.0 266.298229 245.366501
|
||||
18 5120.0 259.548513 238.312729
|
||||
19 5376.0 252.676487 237.081606
|
||||
20 5632.0 270.685535 249.046163
|
||||
21 5888.0 264.382140 242.069377
|
||||
22 6144.0 262.447761 240.565495
|
||||
23 6400.0 257.028108 235.078047
|
||||
24 6656.0 254.386204 232.699140
|
||||
25 6912.0 252.040861 232.926171
|
||||
26 7168.0 253.193644 231.815375
|
||||
27 7424.0 251.789150 232.860938
|
||||
28 7680.0 250.988932 231.727608
|
||||
29 7936.0 253.622108 232.094986
|
||||
30 8192.0 253.121589 231.859598
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 375-375
|
||||
|
||||
As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 5.861 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 36.230 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
||||
|
@@ -12,7 +12,7 @@ Below is a gallery of tutorials for writing various basic operations with Triton
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic syntax of the Triton programming language - The best practices for creating PyTorch...">
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model used by Triton - The triton.jit decorator, which constitutes the ...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
@@ -33,7 +33,7 @@ Below is a gallery of tutorials for writing various basic operations with Triton
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - The syntax and usage of reduc...">
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - The reduction operators in Tr...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
@@ -54,7 +54,7 @@ Below is a gallery of tutorials for writing various basic operations with Triton
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The block-level matrix multiplication operator @ - Multi-dimensional pointer arithmetic - Pro...">
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- Block-level matrix multiplications - Multi-dimensional pointer arithmetic - Program re-orderi...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
|
@@ -5,12 +5,12 @@
|
||||
|
||||
Computation times
|
||||
=================
|
||||
**00:25.654** total execution time for **getting-started_tutorials** files:
|
||||
**00:36.230** total execution time for **getting-started_tutorials** files:
|
||||
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:25.654 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 00:36.230 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:00.000 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 00:00.000 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:00.000 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
@@ -17,15 +17,27 @@ Getting Started
|
||||
getting-started/installation
|
||||
getting-started/tutorials/index
|
||||
|
||||
Programming Guide
|
||||
Language Reference
|
||||
-------------------
|
||||
|
||||
- Checkout the :doc:`Python API Documentation <language-reference/python-api/index>`
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Language Reference
|
||||
:hidden:
|
||||
|
||||
language-reference/python-api/index
|
||||
|
||||
|
||||
Going Further
|
||||
------------------
|
||||
|
||||
Check out the following documents to learn more about Triton and how it compares against other DSLs for DNNs:
|
||||
|
||||
- Chapter 1: :doc:`Introduction <programming-guide/chapter-1/introduction>`
|
||||
- Chapter 2: :doc:`Related Work <programming-guide/chapter-2/related-work>`
|
||||
- Chapter 3: :doc:`The Triton-C Language <programming-guide/chapter-3/triton-c>`
|
||||
- Chapter 4: :doc:`The Triton-IR Intermediate Representation <programming-guide/chapter-4/triton-ir>`
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
@@ -33,6 +45,4 @@ Check out the following documents to learn more about Triton and how it compares
|
||||
:hidden:
|
||||
|
||||
programming-guide/chapter-1/introduction
|
||||
programming-guide/chapter-2/related-work
|
||||
programming-guide/chapter-3/triton-c
|
||||
programming-guide/chapter-4/triton-ir
|
||||
programming-guide/chapter-2/related-work
|
@@ -0,0 +1,6 @@
|
||||
triton.arange
|
||||
=============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: arange
|
@@ -0,0 +1,6 @@
|
||||
triton.atomic\_cas
|
||||
==================
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: atomic_cas
|
@@ -0,0 +1,6 @@
|
||||
triton.atomic\_xchg
|
||||
===================
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: atomic_xchg
|
@@ -0,0 +1,6 @@
|
||||
triton.broadcast\_to
|
||||
====================
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: broadcast_to
|
@@ -0,0 +1,6 @@
|
||||
triton.dot
|
||||
==========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: dot
|
@@ -0,0 +1,6 @@
|
||||
triton.exp
|
||||
==========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: exp
|
@@ -0,0 +1,6 @@
|
||||
triton.load
|
||||
===========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: load
|
@@ -0,0 +1,6 @@
|
||||
triton.log
|
||||
==========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: log
|
@@ -0,0 +1,6 @@
|
||||
triton.max
|
||||
==========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: max
|
@@ -0,0 +1,6 @@
|
||||
triton.maximum
|
||||
==============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autodata:: maximum
|
@@ -0,0 +1,6 @@
|
||||
triton.min
|
||||
==========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: min
|
@@ -0,0 +1,6 @@
|
||||
triton.minimum
|
||||
==============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autodata:: minimum
|
@@ -0,0 +1,6 @@
|
||||
triton.multiple\_of
|
||||
===================
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: multiple_of
|
@@ -0,0 +1,6 @@
|
||||
triton.num\_programs
|
||||
====================
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: num_programs
|
@@ -0,0 +1,6 @@
|
||||
triton.program\_id
|
||||
==================
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: program_id
|
@@ -0,0 +1,6 @@
|
||||
triton.ravel
|
||||
============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autodata:: ravel
|
@@ -0,0 +1,6 @@
|
||||
triton.reshape
|
||||
==============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: reshape
|
@@ -0,0 +1,6 @@
|
||||
triton.sigmoid
|
||||
==============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autodata:: sigmoid
|
@@ -0,0 +1,6 @@
|
||||
triton.softmax
|
||||
==============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autodata:: softmax
|
@@ -0,0 +1,6 @@
|
||||
triton.store
|
||||
============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: store
|
@@ -0,0 +1,6 @@
|
||||
triton.sum
|
||||
==========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: sum
|
@@ -0,0 +1,6 @@
|
||||
triton.where
|
||||
============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: where
|
@@ -0,0 +1,6 @@
|
||||
triton.zeros
|
||||
============
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
.. autofunction:: zeros
|
117
_sources/language-reference/python-api/index.rst.txt
Normal file
@@ -0,0 +1,117 @@
|
||||
Python API
|
||||
===========
|
||||
|
||||
.. currentmodule:: triton
|
||||
|
||||
|
||||
Programming Model
|
||||
-------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
program_id
|
||||
num_programs
|
||||
|
||||
|
||||
Creation Ops
|
||||
-------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
arange
|
||||
zeros
|
||||
|
||||
|
||||
Shape Manipulation Ops
|
||||
-----------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
broadcast_to
|
||||
reshape
|
||||
ravel
|
||||
|
||||
|
||||
|
||||
Linear Algebra Ops
|
||||
-------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
dot
|
||||
|
||||
Memory Ops
|
||||
--------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
load
|
||||
store
|
||||
atomic_cas
|
||||
atomic_xchg
|
||||
|
||||
|
||||
Indexing Ops
|
||||
--------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
where
|
||||
|
||||
|
||||
Math Ops
|
||||
----------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
exp
|
||||
log
|
||||
sigmoid
|
||||
softmax
|
||||
|
||||
|
||||
Reduction Ops
|
||||
---------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
max
|
||||
min
|
||||
sum
|
||||
|
||||
|
||||
Comparison ops
|
||||
---------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
minimum
|
||||
maximum
|
||||
|
||||
|
||||
Compiler Hint Ops
|
||||
-------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
multiple_of
|
@@ -1,84 +0,0 @@
|
||||
=======================
|
||||
The Triton-C Language
|
||||
=======================
|
||||
|
||||
In the introduction, we stressed the importance of blocked algorithms and described their core principles in pseudo-code. To facilitate their implementation on modern GPU hardware, we present Triton-C, a single-threaded imperative kernel language in which block variables are first-class citizen. This language may be used either directly by developers familiar with C, or as an intermediate language for existing (and future) transcompilers. In this chapter, we describe its differences with C, its Numpy-like semantics and its "Single-Program, Multiple-Data" (SPMD) programming model.
|
||||
|
||||
-------------------
|
||||
Differences with C
|
||||
-------------------
|
||||
|
||||
The syntax of Triton-C is based on that of ANSI C, but was modified and extended to accomodate the semantics and programming model described in the next two subsections. These changes fall into the following categories:
|
||||
|
||||
+++++++++++
|
||||
Extensions
|
||||
+++++++++++
|
||||
|
||||
**Variable declarations**: Triton adds special-purpose syntax for multi-dimensional array declarations (e.g., :code:`int block[16, 16]`), which purposely differs from that of nested arrays (i.e., arrays of pointers) found in ANSI C (e.g., :code:`int block[16][16]`). Block dimensions must be constant but can also be made parametric with the use of pre-processor macros. One-dimensional blocks of integers may be initialized using ellipses (e.g., :code:`int range[16] = 0 ... 16`).
|
||||
|
||||
**Primitive types**: Triton-C supports the following primitive data-types: :code:`bool`, :code:`uint8`, :code:`uint16`, :code:`uint32`, :code:`uint64`, :code:`int8`, :code:`int16`, :code:`int32`, :code:`int64`, :code:`half`, :code:`float`, :code:`double`.
|
||||
|
||||
**Operators and built-in function**: The usual C operators were extended to support element-wise array operations (:code:`+`, :code:`-`, :code:`&&`, :code:`*`, etc.) and complex array operations(:code:`@` for matrix multiplication). Additionally, some built-in functions were added for concurrency (:code:`get_program_id`, :code:`atomic_add`).
|
||||
|
||||
**Slicing and broadcasting**: Multi-dimensional blocks can be broadcast along any particular dimension using numpy-like slicing syntax (e.g., :code:`int array[8, 8] = range[:, newaxis]` for stacking columns). Note that, as of now, slicing blocks to retrieve sub-blocks (or scalars) is forbidden as it is incompatible with the automatic parallelization methods used by our JIT. Reductions can be achieved using a syntax similar to slicing (e.g., :code:`array[+]` for summing an array, or :code:`array[:, max]` for row-wise maximum). Currently supported reduction operators are :code:`+`, :code:`min`, :code:`max`.
|
||||
|
||||
**Masked pointer dereferencement**: Block-level operations in Triton-C are "atomic", in the sense that they execute either completely or not at all. Basic element-wise control-flow for block-level operations can nonetheless be achieved using ternary operators and the *masked pointer dereferencement* operator exemplified below:
|
||||
|
||||
.. code-block:: C
|
||||
:force:
|
||||
|
||||
// create mask
|
||||
bool mask[16, 16] = ...;
|
||||
// conditional addition
|
||||
float x[16, 16] = mask ? a + b : 0;
|
||||
// conditional load
|
||||
float y[16] 16] = mask ? *ptr : 0;
|
||||
// conditional store
|
||||
*?(mask)ptr = y;
|
||||
\end{lstlisting}
|
||||
|
||||
|
||||
+++++++++++++
|
||||
Restrictions
|
||||
+++++++++++++
|
||||
|
||||
The Triton project is still in its infancy. As such, there are quite a few features of ANSI C that are not supported:
|
||||
|
||||
**Non-kernel functions**: Right now, all function definitions must be kernels, i.e. be preceded with the :code:`__global__` attribute. We are aware that this is a severe limitations, and the reason why it exists is because our automatic parallelization engine would not be capable of handling array parameter arguments.
|
||||
|
||||
**Non-primitive types**: Non-primitive types defined with :code:`struct` and :code:`union` are currently not supported, again because it is unclear at this point how these constructs would hook into our block-level data-flow analysis passes.
|
||||
|
||||
**While loops**: We just haven't had time to implement those yet.
|
||||
|
||||
----------------
|
||||
Semantics
|
||||
----------------
|
||||
|
||||
The existence of built-in **blocked** types, variable and operations in Triton-C offers two main benefits. First, it simplifies the structure of blocked programs by hiding important details pertaining to concurrent programming such as memory coalescing, cache management and specialized tensor instrinsics. Second, it opens the door for compilers to perform these optimizations automatically. However, it also means that programs have some kind of *block-level semantics* that does not exist in C. Though some aspects of it (e.g., the :code:`@` operator) are pretty intuitive, one in particular might be puzzling to some GPU programmers: broadcasting semantics.
|
||||
|
||||
+++++++++++++++++++++++
|
||||
Broadcasting Semantics
|
||||
+++++++++++++++++++++++
|
||||
|
||||
|
||||
Block variables in Triton are strongly typed, meaning that certain instructions statically require their operands to satisfy strict shape constraints. For example, a scalar may not be added to an array unless it is first appropriately broadcast. *Broadcasting semantics* (first introduced in `Numpy <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_) provides two formal rules for performing these conversions automatically in the case of binary operators: (1) the shape of the lowest-dimension operand is left-padded with ones until both operands have the same dimensionality; and (2) the content of both operands is replicated as many times as needed until their shape is identical. An error is emitted if this cannot be done.
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
int a[16], b[32, 16], c[16, 1];
|
||||
// a is first reshaped to [1, 16]
|
||||
// and then broadcast to [32, 16]
|
||||
int x_1[32, 16] = a[newaxis, :] + b;
|
||||
// Same as above but implicitly
|
||||
int x_2[32, 16] = a + b;
|
||||
// a is first reshaped to [1, 16]
|
||||
// a is broadcast to [16, 16]
|
||||
// c is broadcast to [16, 16]
|
||||
int y[16, 16] = a + c;
|
||||
|
||||
------------------
|
||||
Programming Model
|
||||
------------------
|
||||
|
||||
As discussed in the `CUDA documentation <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_, The execution of CUDA code on GPUs is supported by an `SPMD <https://en.wikipedia.org/wiki/SPMD>`_ programming model in which each kernel instance is associated with an identifiable *thread-block*, itself decomposed into *warps* of 32 *threads*. The Triton programming model is similar, but each kernel is *single-threaded* -- though automatically parallelized -- and associated with a global :code:`program id` which varies from instance to instance. This approach leads to simpler kernels in which CUDA-like concurrency primitives (shared memory synchronization, inter-thread communication, etc.) do not exist. The global program ids associated with each kernel instance can be queried using the :code:`get_program_id(axis)` built-in function where :code:`0 <= axis <= 2`. This is, for example, useful to create e.g., blocks of pointers as shown in the tutorials.
|
||||
|
@@ -1,82 +0,0 @@
|
||||
==========================================
|
||||
The Triton-IR Intermediate Representation
|
||||
==========================================
|
||||
|
||||
Triton-IR is an LLVM-based Intermediate Representation (IR) whose purpose is to provide an environment suitable for block-level program analysis, transformation and optimization.
|
||||
In our implementation, Triton-IR programs are constructed directly from Triton-C after parsing, but they could also be formed directly by higher-level DSLs in the future.
|
||||
Triton-IR and LLVM-IR programs share the same high-level structure, but the former also includes a number of extensions necessary for block-level data-flow analysis.
|
||||
These extensions are crucial for carrying out the optimizations outlined in the next chapter of this document.
|
||||
|
||||
---------------------------------
|
||||
Structure of a Triton-IR Program
|
||||
---------------------------------
|
||||
|
||||
++++++++
|
||||
Modules
|
||||
++++++++
|
||||
|
||||
At the highest level, Triton-IR programs consist of one or multiple basic units of compilation known as *modules*. These modules are compiled independently from one another, and eventually aggregated by a linker whose role is to resolve forward declarations and adequately merge global definitions. Each module itself is composed of functions, global variables, constants and other miscellaneous symbols such as metadata and attributes.
|
||||
|
||||
++++++++++
|
||||
Functions
|
||||
++++++++++
|
||||
|
||||
Triton-IR function definitions consist of a return type, a name and a potentially empty arguments list. Additional visibility, alignment and linkage specifiers can be added if desired. Function attributes (such as inlining hints) and parameter attributes (such as "readonly", aliasing hints) can also be specified, allowing compiler backends to perform more aggressive optimizations by, for instance, making better use of non-coherent caches found on NVIDIA GPUs. This header is followed by a body composed of a list of basic blocks whose interdependencies form the Control Flow Graph (CFG) of the function.
|
||||
|
||||
+++++++++++++
|
||||
Basic Blocks
|
||||
+++++++++++++
|
||||
|
||||
Basic blocks are straight-line code sequences that may only contain so-called *terminator* instructions (i.e., branching, return) at their end. To simplify program analysis, Triton-IR uses the Static Single Assignment (SSA) form, meaning that each variable in each basic block must be (1) assigned to only once and (2) defined before being used. In so doing, each basic block implicitly defines a Data-Flow Graph (DFG). In our case, the SSA form is created directly from Triton-C's Abstract Syntax Trees (ASTs) using an algorithm from the literature [BRAUN13]_.
|
||||
|
||||
---------------------------------
|
||||
Block-Level Dataflow Analysis
|
||||
---------------------------------
|
||||
|
||||
+++++++
|
||||
Types
|
||||
+++++++
|
||||
|
||||
Multi-dimensional blocks are at the center of data-flow analysis in Triton-JIT. They can be declared using syntax similar to vector declarations in LLVM-IR. For example, :code:`i32<8, 8>` is the type corresponding to :math:`8 \times 8` blocks of 32-bit integers. Note that there is no preprocessor in Triton-IR, hence parametric shape values must be resolved before programs are generated. In our case, this is done by Triton-JIT's auto-tuner.
|
||||
|
||||
+++++++++++++
|
||||
Instructions
|
||||
+++++++++++++
|
||||
|
||||
Triton-IR introduces a set of *reblocking* instructions whose purpose is to support broadcasting semantics as described in the previous chapter. The :code:`reshape` instruction creates a block of the specified shape using the raw data from its input argument. This is particularly useful to re-interpret variables as higher-dimensional arrays by padding their input shapes with ones in preparation for broadcasting. The :code:`broadcast` instruction creates a block of the specified shapes by replicating its input argument as many times as necessary along dimensions of size 1 -- as shown below for the :code:`broadcast<3,3>` instruction.
|
||||
|
||||
|pic1| and |pic2|
|
||||
|
||||
.. |pic1| image:: broadcast-1.png
|
||||
:width: 40%
|
||||
|
||||
.. |pic2| image:: broadcast-2.png
|
||||
:width: 40%
|
||||
|
||||
Usual scalar instructions (:code:`cmp`, :code:`getelementptr`, :code:`add`, :code:`load`...) were preserved and extended to signify element-wise operations when applicable. Finally, Triton-IR also exposes specialized arithmetic instructions for reductions (:code:`reduce`) and matrix multiplications (:code:`dot`).
|
||||
|
||||
----------------------------------
|
||||
Block-Level Control Flow Analysis
|
||||
----------------------------------
|
||||
|
||||
In Triton-IR, operations on block variables are atomic: they execute either in full or not at all. As a result, traditional control flow structures (e.g., conditional, loops) are not applicable to individual block elements. This is problematic, since a program may need to e.g., partially guard blocked loads against memory access violations.
|
||||
|
||||
This could be potentially solved through the use of the Predicated SSA (PSSA) [CARTER99]_ [STOUTCHININ01]_ form for Triton-IR. However, this would create a lot of unnecessary complexity for GPUs, where the benefits of PSSA are close to none as divergent program paths within warps are serialized anyway. Therefore, recent versions of Triton handle intra-block control flow in a much simpler way, using conditional instructions such as :code:`select`, :code:`masked_load` and :code:`masked_store`:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// For all indices [idx], return cond[idx] ? true_value[idx] : false_value[idx];
|
||||
select TYPE<TS1, ..., TSN> cond, true_value, false_value;
|
||||
// For all indices [idx], return cond[idx] ? *true_addr[idx] : false_value[idx];
|
||||
masked_load TYPE<TS1, ..., TSN> cond, true_addr, false_value;
|
||||
// For all indices [idx], execute *true_addr[idx] = true_value[idx] if cond[idx]
|
||||
masked_store TYPE<TS1, ..., TSN> cond, true_addr, true_value;
|
||||
|
||||
|
||||
------------
|
||||
References
|
||||
------------
|
||||
|
||||
.. [BRAUN13] M. Braun et al., "Simple and Efficient Construction of Static Single Assignment Form", CC 2013
|
||||
.. [CARTER99] L. Carter et al., "Predicated Static Single Assignment", PACT 1999
|
||||
.. [STOUTCHININ01] A. Stoutchinin et al., "Efficient Static Single Assignment Form for Predication", MICRO 2001
|
@@ -1,69 +0,0 @@
|
||||
==============
|
||||
Introduction
|
||||
==============
|
||||
|
||||
--------------
|
||||
Motivations
|
||||
--------------
|
||||
|
||||
Over the past decade, Deep Neural Networks (DNNs) have emerged as an important class of Machine Learning (ML) models, capable of achieving state-of-the-art performance across many domains ranging from natural language processing [1]_ to computer vision [2]_ to computational neuroscience [3]_. The strength of these models lies in their hierarchical structure, composed of a sequence of parametric (e.g., convolutional) and non-parametric (e.g., rectified linearity) *layers*. This pattern, though notoriously computationally expensive, also generates a large amount of highly parallelizable work particularly well suited for multi- and many- core processors.
|
||||
|
||||
As a consequence, Graphics Processing Units (GPUs) have become a cheap and accessible resource for exploring and/or deploying novel research ideas in the field. This trend has been accelerated by the release of several frameworks for General-Purpose GPU (GPGPU) computing, such as CUDA and OpenCL, which have made the development of high-performance programs easier. Yet, GPUs remain incredibly challenging to optimize for locality and parallelism, especially for computations that cannot be efficiently implemented using a combination of pre-existing optimized primitives. To make matters worse, GPU architectures are also rapidly evolving and specializing, as evidenced by the addition of tensor cores to NVIDIA (and more recently AMD) micro-architectures.
|
||||
|
||||
This tension between the computational opportunities offered by DNNs and the practical difficulty of GPU programming has created substantial academic and industrial interest for Domain-Specific Languages (DSLs) and compilers. Regrettably, these systems -- whether they be based on polyhedral machinery (*e.g.*, Tiramisu [4]_, Tensor Comprehensions [5]_) or scheduling languages (*e.g.*, Halide [6]_, TVM [7]_) -- remain less flexible and (for the same algorithm) markedly slower than the best handwritten compute kernels available in libraries like `cuBLAS <https://docs.nvidia.com/cuda/cublas/index.html>`_, `cuDNN <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>`_ or `TensorRT <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html>`_.
|
||||
|
||||
The main premise of this project is the following: programming paradigms based on blocked algorithms [8]_ can facilitate the construction of high-performance compute kernels for neural networks. We specifically revisit traditional "Single Program, Multiple Data" (SPMD [9]_) execution models for GPUs, and propose a variant in which programs -- rather than threads -- are blocked. For example, in the case of matrix multiplication, CUDA and Triton differ as follows:
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| CUDA Programming Model | Triton Programming Model |
|
||||
| | |
|
||||
| (Scalar Program, Blocked Threads) | (Blocked Program, Scalar Threads) |
|
||||
+=====================================================+=====================================================+
|
||||
| | |
|
||||
|.. code-block:: C |.. code-block:: C |
|
||||
| | :force: |
|
||||
| | |
|
||||
| #pragma parallel | #pragma parallel |
|
||||
| for(int m = 0; i < M; m++) | for(int m = 0; m < M; m += MB) |
|
||||
| #pragma parallel | #pragma parallel |
|
||||
| for(int n = 0; j < N; n++){ | for(int n = 0; n < N; n += NB){ |
|
||||
| float acc = 0; | float acc[MB, NB] = 0; |
|
||||
| for(int k = 0; k < K;k ++) | for(int k = 0; k < K; k += KB) |
|
||||
| acc += A[i, k]* B[k, j]; | acc += A[m:m+MB, k:k+KB] |
|
||||
| | @ B[k:k+KB, n:n+NB]; |
|
||||
| C[i, j] = acc; | C[m:m+MB, n:n+NB] = acc; |
|
||||
| } | } |
|
||||
| | |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| |pic1| | |pic2| |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
|
||||
|
||||
.. |pic1| image:: cuda-parallel-matmul.png
|
||||
|
||||
.. |pic2| image:: triton-parallel-matmul.png
|
||||
|
||||
A key benefit of this approach is that it leads to block-structured iteration spaces that offer programmers more flexibility than existing DSLs when implementing sparse operations, all while allowing compilers to aggressively optimize programs for data locality and parallelism.
|
||||
|
||||
--------------
|
||||
Challenges
|
||||
--------------
|
||||
|
||||
The main challenge posed by our proposed paradigm is that of work scheduling, i.e., how the work done by each program instance should be partitioned for efficient execution on modern GPUs. To address this issue, the Triton compiler makes heavy use of *block-level data-flow analysis*, a technique for scheduling iteration blocks statically based on the control- and data-flow structure of the target program. The resulting system actually works surprisingly well: our compiler manages to apply a broad range of interesting optimization automatically (e.g., automatic coalescing, thread swizzling, pre-fetching, automatic vectorization, tensor core-aware instruction selection, shared memory allocation/synchronization, asynchronous copy scheduling). Of course doing all this is not trivial; one of the purposes of this guide is to give you a sense of how it works.
|
||||
|
||||
--------------
|
||||
References
|
||||
--------------
|
||||
|
||||
.. [1] Sutskever et al., "Sequence to Sequence Learning with Neural Networks", NIPS 2014
|
||||
.. [2] Redmon et al., "You Only Look Once: Unified, Real-Time Object Detection", CVPR 2016
|
||||
.. [3] Lee et al., "Superhuman Accuracy on the SNEMI3D Connectomics Challenge", ArXiV 2017
|
||||
.. [4] Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021
|
||||
.. [5] Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiV 2018
|
||||
.. [6] Ragan-Kelley et al., "Halide: A Language and Compiler for Optimizing Parallelism, Locality, and Recomputation in Image Processing Pipelines", PLDI 2013
|
||||
.. [7] Chen et al., "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning", OSDI 2018
|
||||
.. [8] Lam et al., "The Cache Performance and Optimizations of Blocked Algorithms", ASPLOS 1991
|
||||
.. [9] Auguin et al., "Opsila: an advanced SIMD for numerical analysis and signal processing", EUROMICRO 1983
|
@@ -1,209 +0,0 @@
|
||||
==============
|
||||
Related Work
|
||||
==============
|
||||
|
||||
At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlights its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.
|
||||
|
||||
-----------------------
|
||||
Polyhedral Compilation
|
||||
-----------------------
|
||||
|
||||
Traditional compilers typically rely on intermediate representations, such as LLVM-IR [1]_, that encode control flow information using (un)conditional branches. This relatively low-level format makes it difficult to statically analyze the runtime behavior (e.g., cache misses) of input programs, and to automatically optimize loops accordingly through the use of tiling [2]_, fusion [3]_ and interchange [4]_. To solve this issue, polyhedral compilers [5]_ rely on program representations that have statically predictable control flow, thereby enabling aggressive compile-time program transformations for data locality and parallelism. Though this strategy has been adopted by many languages and compilers for DNNs such as Tiramisu [6]_, Tensor Comprehensions [7]_, Diesel [8]_ and the Affine dialect in MLIR [9]_, it also comes with a number of limitations that will be described later.
|
||||
|
||||
+++++++++++++++++++++++
|
||||
Program Representation
|
||||
+++++++++++++++++++++++
|
||||
|
||||
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| | |
|
||||
|.. code-block:: C | |pic1| |
|
||||
| | |
|
||||
| for(int i = 0; i < 3; i++) | |
|
||||
| for(int j = i; j < 5; j++) | |
|
||||
| A[i][j] = 0; | |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
|
||||
.. |pic1| image:: polyhedral-iteration.png
|
||||
:width: 300
|
||||
|
||||
Polyhedral compilers focus on a class of programs commonly known as **Static Control Parts** (SCoP), *i.e.*, maximal sets of consecutive statements in which conditionals and loop bounds are affine functions of surrounding loop indices and global invariant parameters. As shown above, programs in this format always lead to iteration domains that are bounded by affine inequalities, i.e., polyhedral. These polyhedra can also be defined algebraically; for the above example:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathcal{P} = \{ i, j \in \mathbb{Z}^2
|
||||
~|~
|
||||
\begin{pmatrix}
|
||||
1 & 0 \\
|
||||
-1 & 0 \\
|
||||
-1 & 1 \\
|
||||
0 & -1 \\
|
||||
\end{pmatrix}
|
||||
\begin{pmatrix}
|
||||
i \\
|
||||
j
|
||||
\end{pmatrix}
|
||||
+
|
||||
\begin{pmatrix}
|
||||
0 \\
|
||||
2 \\
|
||||
0 \\
|
||||
4
|
||||
\end{pmatrix}
|
||||
\geq
|
||||
0
|
||||
\}
|
||||
|
||||
|
||||
Each point :math:`(i, j)` in :math:`\mathcal{P}` represents a *polyhedral statement*, that is a program statement which (1) does not induce control-flow side effects (e.g., :code:`for`, :code:`if`, :code:`break`) and (2) contains only affine functions of loop indices and global parameters in array accesses. To facilitate alias analysis, array accesses are also mathematically abstracted, using so-called *access function*. In other words, :code:`A[i][j]` is simply :code:`A[f(i,j)]` where the access function :math:`f` is defined by:
|
||||
|
||||
.. math::
|
||||
|
||||
f(i, j) = \begin{pmatrix}
|
||||
1 & 0\\
|
||||
0 & 1\\
|
||||
\end{pmatrix}
|
||||
\begin{pmatrix}
|
||||
i\\
|
||||
j
|
||||
\end{pmatrix}
|
||||
=
|
||||
(i, j)
|
||||
|
||||
|
||||
Note that the iteration domains of an SCoP does not specify the order in which its statements shall execute. In fact, this iteration domain may be traversed in many different possible legal orders, i.e. *schedules*. Formally, a schedule is defined as a p-dimensional affine transformation :math:`\Theta` of loop indices :math:`\mathbf{x}` and global invariant parameters :math:`\mathbf{g}`:
|
||||
|
||||
.. math::
|
||||
\Theta_S(\mathbf{x}) = T_S \begin{pmatrix}
|
||||
\vec{x}\\
|
||||
\vec{g}\\
|
||||
1
|
||||
\end{pmatrix}
|
||||
\qquad
|
||||
T_S \in \mathbb{Z} ^{p \times (\text{dim}(\mathbf{x}) + \text{dim}(\mathbf{g}) + 1)}
|
||||
|
||||
|
||||
Where :math:`\Theta_S(\mathbf{x})` is a p-dimensional vector representing the slowest to fastest growing indices (from left to right) when traversing the loop nest surrounding :math:`S`. For the code shown above, the original schedule defined by the loop nest in C can be retrieved by using:
|
||||
|
||||
.. math::
|
||||
\Theta_S(\mathbf{x}) = \begin{pmatrix}
|
||||
1 & 0 \\
|
||||
0 & 1 \\
|
||||
\end{pmatrix}
|
||||
\begin{pmatrix}
|
||||
i & j
|
||||
\end{pmatrix}^T
|
||||
=
|
||||
\begin{pmatrix}
|
||||
i & j
|
||||
\end{pmatrix}^T
|
||||
|
||||
|
||||
where :math:`i` and :math:`j` are respectively the slowest and fastest growing loop indices in the nest. If :math:`T_S` is a vector (resp. tensor), then :math:`\Theta_S` is a said to be one-dimensional (resp. multi-dimensional).
|
||||
|
||||
+++++++++++
|
||||
Advantages
|
||||
+++++++++++
|
||||
|
||||
Programs amenable to polyhedral compilation can be aggressively transformed and optimized. Most of these transformations actually boil down to the production of schedules and iteration domains that enable loop transformations promoting parallelism and spatial/temporal data locality (e.g., fusion, interchange, tiling, parallelization).
|
||||
|
||||
Polyhedral compilers can also automatically go through complex verification processes to ensure that the semantics of their input program is preserved throughout this optimization phase. Note that polyhedral optimizers are not incompatible with more standard optimization techniques. In fact, it is not uncommon for these systems to be implemented as a set of LLVM passes that can be run ahead of more traditional compilation techniques [10]_.
|
||||
|
||||
All in all, polyhedral machinery is extremely powerful, when applicable. It has been shown to support most common loop transformations, and has indeed achieved performance comparable to state-of-the-art GPU libraries for dense matrix multiplication [8]_. Additionally, it is also fully automatic and doesn't require any hint from programmers apart from source-code in a C-like format.
|
||||
|
||||
++++++++++++
|
||||
Limitations
|
||||
++++++++++++
|
||||
|
||||
Unfortunately, polyhedral compilers suffer from two major limitations that have prevented its adoption as a universal method for code generation in neural networks.
|
||||
|
||||
First, the set of possible program transformations $\Omega = \{ \Theta_S ~|~ S \in \text{program} \}$ is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [11]_.
|
||||
|
||||
Second, the polyhedral framework is not very generally applicable; SCoPs are relatively common [12]_ but require loop bounds and array subscripts to be affine functions of loop indices, which typically only occurs in regular, dense computations. For this reason, this framework still has to be successfully applied to sparse -- or even structured-sparse -- neural networks, whose importance has been rapidly rising over the past few years.
|
||||
|
||||
On the other hand, blocked program representations advocated by this dissertation are less restricted in scope and can achieve close to peak performance using standard dataflow analysis.
|
||||
|
||||
-----------------------
|
||||
Scheduling Languages
|
||||
-----------------------
|
||||
|
||||
Separation of concerns \cite{dijkstra82} is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently.
|
||||
|
||||
.. code-block:: python
|
||||
:linenos:
|
||||
|
||||
// algorithm
|
||||
Var x("x"), y("y");
|
||||
Func matmul("matmul");
|
||||
RDom k(0, matrix_size);
|
||||
RVar ki;
|
||||
matmul(x, y) = 0.0f;
|
||||
matmul(x, y) += A(k, y) * B(x, k);
|
||||
// schedule
|
||||
Var xi("xi"), xo("xo"), yo("yo"), yi("yo"), yii("yii"), xii("xii");
|
||||
matmul.vectorize(x, 8);
|
||||
matmul.update(0)
|
||||
.split(x, x, xi, block_size).split(xi, xi, xii, 8)
|
||||
.split(y, y, yi, block_size).split(yi, yi, yii, 4)
|
||||
.split(k, k, ki, block_size)
|
||||
.reorder(xii, yii, xi, ki, yi, k, x, y)
|
||||
.parallel(y).vectorize(xii).unroll(xi).unroll(yii);
|
||||
|
||||
|
||||
The resulting code may however not be completely portable, as schedules can sometimes rely on execution models (e.g., SPMD) or hardware intrinsics (e.g., matrix-multiply-accumulate) that are not widely available. This issue can be mitigated by auto-scheduling mechanisms [13]_.
|
||||
|
||||
+++++++++++
|
||||
Advantages
|
||||
+++++++++++
|
||||
|
||||
The main advantage of this approach is that it allows programmers to write an algorithm *only once*, and focus on performance optimization separately. It makes it possible to manually specify optimizations that a polyhedral compiler wouldn't be able to figure out automatically using static data-flow analysis.
|
||||
|
||||
Scheduling languages are, without a doubt, one of the most popular approaches for neural network code generation. The most popular system for this purpose is probably TVM, which provides good performance across a wide range of platforms as well as built-in automatic scheduling mechanisms.
|
||||
|
||||
++++++++++++
|
||||
Limitations
|
||||
++++++++++++
|
||||
|
||||
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse com-putations, whose iteration spaces may be irregular.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| | |
|
||||
|.. code-block:: C | |pic2| |
|
||||
| | |
|
||||
| for(int i = 0; i < 4; i++) | |
|
||||
| for(int j = 0; j < 4; j++) | |
|
||||
| float acc = 0; | |
|
||||
| for(int k = 0; k < K[i]; k++) | |
|
||||
| acc += A[i][col[i,k]]*B[k][j] | |
|
||||
| C[i][j] = acc; | |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
|
||||
.. |pic2| image:: halide-iteration.png
|
||||
:width: 300
|
||||
|
||||
On the other hand, the block-based program representation that we advocate for through this work allows for block-structured iteration spaces and allows programmers to manually handle load-balancing as they wish.
|
||||
|
||||
--------------
|
||||
References
|
||||
--------------
|
||||
|
||||
.. [1] Lattner et al., "LLVM: a compilation framework for lifelong program analysis transformation"
|
||||
.. [2] Wolfe, "More Iteration Space Tiling", SC 1989
|
||||
.. [3] Darte, "On the Complexity of Loop Fusion", PACT 1999
|
||||
.. [4] Allen et al., "Automatic Loop Interchange", SIGPLAN Notices 1984
|
||||
.. [5] Ancourt et al., "Scanning Polyhedra with DO Loops", PPoPP 1991
|
||||
.. [6] Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021
|
||||
.. [7] Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiV 2018
|
||||
.. [8] Elango et al. "Diesel: DSL for Linear Algebra and Neural Net Computations on GPUs", MAPL 2018
|
||||
.. [9] Lattner et al., "MLIR Primer: A Compiler Infrastructure for the End of Moore’s Law", Arxiv 2019
|
||||
.. [10] Grosser et al., "Polly - Performing Polyhedral Optimizations on a Low-Level Intermediate Representation", Parallel Processing Letters 2012
|
||||
.. [11] Sato et al., "An Autotuning Framework for Scalable Execution of Tiled Code via Iterative Polyhedral Compilation", TACO 2019
|
||||
.. [12] Girbal et al., "Semi-Automatic Composition of Loop Transformations for Deep Parallelism and Memory Hierarchies", International Journal of Parallel Programming 2006
|
||||
.. [13] Mullapudi et al., "Automatically scheduling halide image processing pipelines", TOG 2016
|
@@ -1,83 +0,0 @@
|
||||
=======================
|
||||
The Triton-C Language
|
||||
=======================
|
||||
|
||||
In the introduction, we stressed the importance of blocked algorithms and described their core principles in pseudo-code. To facilitate their implementation on modern GPU hardware, we present Triton-C, a single-threaded imperative kernel language in which block variables are first-class citizen. This language may be used either directly by developers familiar with C, or as an intermediate language for existing (and future) transcompilers. In this chapter, we describe its differences with C, its Numpy-like semantics and its "Single-Program, Multiple-Data" (SPMD) programming model.
|
||||
|
||||
-------------------
|
||||
Differences with C
|
||||
-------------------
|
||||
|
||||
The syntax of Triton-C is based on that of ANSI C, but was modified and extended to accomodate the semantics and programming model described in the next two subsections. These changes fall into the following categories:
|
||||
|
||||
+++++++++++
|
||||
Extensions
|
||||
+++++++++++
|
||||
|
||||
**Variable declarations**: Triton adds special-purpose syntax for multi-dimensional array declarations (e.g., :code:`int block[16, 16]`), which purposely differs from that of nested arrays (i.e., arrays of pointers) found in ANSI C (e.g., :code:`int block[16][16]`). Block dimensions must be constant but can also be made parametric with the use of pre-processor macros. One-dimensional blocks of integers may be initialized using ellipses (e.g., :code:`int range[16] = 0 ... 16`).
|
||||
|
||||
**Primitive types**: Triton-C supports the following primitive data-types: :code:`bool`, :code:`uint8`, :code:`uint16`, :code:`uint32`, :code:`uint64`, :code:`int8`, :code:`int16`, :code:`int32`, :code:`int64`, :code:`half`, :code:`float`, :code:`double`.
|
||||
|
||||
**Operators and built-in function**: The usual C operators were extended to support element-wise array operations (:code:`+`, :code:`-`, :code:`&&`, :code:`*`, etc.) and complex array operations(:code:`@` for matrix multiplication). Additionally, some built-in functions were added for concurrency (:code:`get_program_id`, :code:`atomic_add`).
|
||||
|
||||
**Slicing and broadcasting**: Multi-dimensional blocks can be broadcast along any particular dimension using numpy-like slicing syntax (e.g., :code:`int array[8, 8] = range[:, newaxis]` for stacking columns). Note that, as of now, slicing blocks to retrieve sub-blocks (or scalars) is forbidden as it is incompatible with the automatic parallelization methods used by our JIT. Reductions can be achieved using a syntax similar to slicing (e.g., :code:`array[+]` for summing an array, or :code:`array[:, max]` for row-wise maximum). Currently supported reduction operators are :code:`+`, :code:`min`, :code:`max`.
|
||||
|
||||
**Masked pointer dereferencement**: Block-level operations in Triton-C are "atomic", in the sense that they execute either completely or not at all. Basic element-wise control-flow for block-level operations can nonetheless be achieved using ternary operators and the *masked pointer dereferencement* operator exemplified below:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// create mask
|
||||
bool mask[16, 16] = ...;
|
||||
// conditional addition
|
||||
float x[16, 16] = mask ? a + b : 0;
|
||||
// conditional load
|
||||
float y[16] 16] = mask ? *ptr : 0;
|
||||
// conditional store
|
||||
*?(mask)ptr = y;
|
||||
\end{lstlisting}
|
||||
|
||||
|
||||
+++++++++++++
|
||||
Restrictions
|
||||
+++++++++++++
|
||||
|
||||
The Triton project is still in its infancy. As such, there are quite a few features of ANSI C that are not supported:
|
||||
|
||||
**Non-kernel functions**: Right now, all function definitions must be kernels, i.e. be preceded with the :code:`__global__` attribute. We are aware that this is a severe limitations, and the reason why it exists is because our automatic parallelization engine would not be capable of handling array parameter arguments.
|
||||
|
||||
**Non-primitive types**: Non-primitive types defined with :code:`struct` and :code:`union` are currently not supported, again because it is unclear at this point how these constructs would hook into our block-level data-flow analysis passes.
|
||||
|
||||
**While loops**: We just haven't had time to implement those yet.
|
||||
|
||||
----------------
|
||||
Semantics
|
||||
----------------
|
||||
|
||||
The existence of built-in **blocked** types, variable and operations in Triton-C offers two main benefits. First, it simplifies the structure of blocked programs by hiding important details pertaining to concurrent programming such as memory coalescing, cache management and specialized tensor instrinsics. Second, it opens the door for compilers to perform these optimizations automatically. However, it also means that programs have some kind of *block-level semantics* that does not exist in C. Though some aspects of it (e.g., the :code:`@` operator) are pretty intuitive, one in particular might be puzzling to some GPU programmers: broadcasting semantics.
|
||||
|
||||
+++++++++++++++++++++++
|
||||
Broadcasting Semantics
|
||||
+++++++++++++++++++++++
|
||||
|
||||
|
||||
Block variables in Triton are strongly typed, meaning that certain instructions statically require their operands to satisfy strict shape constraints. For example, a scalar may not be added to an array unless it is first appropriately broadcast. *Broadcasting semantics* (first introduced in `Numpy <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_) provides two formal rules for performing these conversions automatically in the case of binary operators: (1) the shape of the lowest-dimension operand is left-padded with ones until both operands have the same dimensionality; and (2) the content of both operands is replicated as many times as needed until their shape is identical. An error is emitted if this cannot be done.
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
int a[16], b[32, 16], c[16, 1];
|
||||
// a is first reshaped to [1, 16]
|
||||
// and then broadcast to [32, 16]
|
||||
int x_1[32, 16] = a[newaxis, :] + b;
|
||||
// Same as above but implicitly
|
||||
int x_2[32, 16] = a + b;
|
||||
// a is first reshaped to [1, 16]
|
||||
// a is broadcast to [16, 16]
|
||||
// c is broadcast to [16, 16]
|
||||
int y[16, 16] = a + c;
|
||||
|
||||
------------------
|
||||
Programming Model
|
||||
------------------
|
||||
|
||||
As discussed in the `CUDA documentation <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_, The execution of CUDA code on GPUs is supported by an `SPMD <https://en.wikipedia.org/wiki/SPMD>`_ programming model in which each kernel instance is associated with an identifiable *thread-block*, itself decomposed into *warps* of 32 *threads*. The Triton programming model is similar, but each kernel is *single-threaded* -- though automatically parallelized -- and associated with a global :code:`program id` which varies from instance to instance. This approach leads to simpler kernels in which CUDA-like concurrency primitives (shared memory synchronization, inter-thread communication, etc.) do not exist. The global program ids associated with each kernel instance can be queried using the :code:`get_program_id(axis)` built-in function where :code:`0 <= axis <= 2`. This is, for example, useful to create e.g., blocks of pointers as shown in the tutorials.
|
||||
|
@@ -277,25 +277,25 @@ p.rubric {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
img.align-left, .figure.align-left, object.align-left {
|
||||
img.align-left, figure.align-left, .figure.align-left, object.align-left {
|
||||
clear: left;
|
||||
float: left;
|
||||
margin-right: 1em;
|
||||
}
|
||||
|
||||
img.align-right, .figure.align-right, object.align-right {
|
||||
img.align-right, figure.align-right, .figure.align-right, object.align-right {
|
||||
clear: right;
|
||||
float: right;
|
||||
margin-left: 1em;
|
||||
}
|
||||
|
||||
img.align-center, .figure.align-center, object.align-center {
|
||||
img.align-center, figure.align-center, .figure.align-center, object.align-center {
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
}
|
||||
|
||||
img.align-default, .figure.align-default {
|
||||
img.align-default, figure.align-default, .figure.align-default {
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
@@ -319,7 +319,8 @@ img.align-default, .figure.align-default {
|
||||
|
||||
/* -- sidebars -------------------------------------------------------------- */
|
||||
|
||||
div.sidebar {
|
||||
div.sidebar,
|
||||
aside.sidebar {
|
||||
margin: 0 0 0.5em 1em;
|
||||
border: 1px solid #ddb;
|
||||
padding: 7px;
|
||||
@@ -377,12 +378,14 @@ div.body p.centered {
|
||||
/* -- content of sidebars/topics/admonitions -------------------------------- */
|
||||
|
||||
div.sidebar > :last-child,
|
||||
aside.sidebar > :last-child,
|
||||
div.topic > :last-child,
|
||||
div.admonition > :last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
div.sidebar::after,
|
||||
aside.sidebar::after,
|
||||
div.topic::after,
|
||||
div.admonition::after,
|
||||
blockquote::after {
|
||||
@@ -455,20 +458,22 @@ td > :last-child {
|
||||
|
||||
/* -- figures --------------------------------------------------------------- */
|
||||
|
||||
div.figure {
|
||||
div.figure, figure {
|
||||
margin: 0.5em;
|
||||
padding: 0.5em;
|
||||
}
|
||||
|
||||
div.figure p.caption {
|
||||
div.figure p.caption, figcaption {
|
||||
padding: 0.3em;
|
||||
}
|
||||
|
||||
div.figure p.caption span.caption-number {
|
||||
div.figure p.caption span.caption-number,
|
||||
figcaption span.caption-number {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
div.figure p.caption span.caption-text {
|
||||
div.figure p.caption span.caption-text,
|
||||
figcaption span.caption-text {
|
||||
}
|
||||
|
||||
/* -- field list styles ----------------------------------------------------- */
|
||||
|
BIN
_static/fonts/Inconsolata-Bold.ttf
Normal file
BIN
_static/fonts/Inconsolata-Regular.ttf
Normal file
BIN
_static/fonts/Inconsolata.ttf
Normal file
BIN
_static/fonts/Lato-Bold.ttf
Normal file
BIN
_static/fonts/Lato-Regular.ttf
Normal file
BIN
_static/fonts/Lato/lato-bold.eot
Normal file
BIN
_static/fonts/Lato/lato-bold.ttf
Normal file
BIN
_static/fonts/Lato/lato-bold.woff
Normal file
BIN
_static/fonts/Lato/lato-bold.woff2
Normal file
BIN
_static/fonts/Lato/lato-bolditalic.eot
Normal file
BIN
_static/fonts/Lato/lato-bolditalic.ttf
Normal file
BIN
_static/fonts/Lato/lato-bolditalic.woff
Normal file
BIN
_static/fonts/Lato/lato-bolditalic.woff2
Normal file
BIN
_static/fonts/Lato/lato-italic.eot
Normal file
BIN
_static/fonts/Lato/lato-italic.ttf
Normal file
BIN
_static/fonts/Lato/lato-italic.woff
Normal file
BIN
_static/fonts/Lato/lato-italic.woff2
Normal file
BIN
_static/fonts/Lato/lato-regular.eot
Normal file
BIN
_static/fonts/Lato/lato-regular.ttf
Normal file
BIN
_static/fonts/Lato/lato-regular.woff
Normal file
BIN
_static/fonts/Lato/lato-regular.woff2
Normal file
BIN
_static/fonts/RobotoSlab-Bold.ttf
Normal file
BIN
_static/fonts/RobotoSlab-Regular.ttf
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff
Normal file
BIN
_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2
Normal file
BIN
_static/fonts/fontawesome-webfont.eot
Normal file
2671
_static/fonts/fontawesome-webfont.svg
Normal file
After Width: | Height: | Size: 434 KiB |
BIN
_static/fonts/fontawesome-webfont.ttf
Normal file
BIN
_static/fonts/fontawesome-webfont.woff
Normal file
BIN
_static/fonts/fontawesome-webfont.woff2
Normal file
4
_static/js/modernizr.min.js
vendored
Normal file
146
genindex.html
@@ -92,12 +92,14 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="language-reference/python-api/index.html">Python API</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
@@ -167,8 +169,148 @@
|
||||
<h1 id="index">Index</h1>
|
||||
|
||||
<div class="genindex-jumpbox">
|
||||
<a href="#A"><strong>A</strong></a>
|
||||
| <a href="#B"><strong>B</strong></a>
|
||||
| <a href="#D"><strong>D</strong></a>
|
||||
| <a href="#E"><strong>E</strong></a>
|
||||
| <a href="#L"><strong>L</strong></a>
|
||||
| <a href="#M"><strong>M</strong></a>
|
||||
| <a href="#N"><strong>N</strong></a>
|
||||
| <a href="#P"><strong>P</strong></a>
|
||||
| <a href="#R"><strong>R</strong></a>
|
||||
| <a href="#S"><strong>S</strong></a>
|
||||
| <a href="#W"><strong>W</strong></a>
|
||||
| <a href="#Z"><strong>Z</strong></a>
|
||||
|
||||
</div>
|
||||
<h2 id="A">A</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.arange.html#triton.arange">arange() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.atomic_cas.html#triton.atomic_cas">atomic_cas() (in module triton)</a>
|
||||
</li>
|
||||
<li><a href="language-reference/python-api/generated/triton.atomic_xchg.html#triton.atomic_xchg">atomic_xchg() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="B">B</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.broadcast_to.html#triton.broadcast_to">broadcast_to() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="D">D</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.dot.html#triton.dot">dot() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="E">E</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.exp.html#triton.exp">exp() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="L">L</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.load.html#triton.load">load() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.log.html#triton.log">log() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="M">M</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.max.html#triton.max">max() (in module triton)</a>
|
||||
</li>
|
||||
<li><a href="language-reference/python-api/generated/triton.maximum.html#triton.maximum">maximum (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.min.html#triton.min">min() (in module triton)</a>
|
||||
</li>
|
||||
<li><a href="language-reference/python-api/generated/triton.minimum.html#triton.minimum">minimum (in module triton)</a>
|
||||
</li>
|
||||
<li><a href="language-reference/python-api/generated/triton.multiple_of.html#triton.multiple_of">multiple_of() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="N">N</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.num_programs.html#triton.num_programs">num_programs() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="P">P</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.program_id.html#triton.program_id">program_id() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="R">R</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.ravel.html#triton.ravel">ravel (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.reshape.html#triton.reshape">reshape() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="S">S</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.sigmoid.html#triton.sigmoid">sigmoid (in module triton)</a>
|
||||
</li>
|
||||
<li><a href="language-reference/python-api/generated/triton.softmax.html#triton.softmax">softmax (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.store.html#triton.store">store() (in module triton)</a>
|
||||
</li>
|
||||
<li><a href="language-reference/python-api/generated/triton.sum.html#triton.sum">sum() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="W">W</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.where.html#triton.where">where() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
<h2 id="Z">Z</h2>
|
||||
<table style="width: 100%" class="indextable genindextable"><tr>
|
||||
<td style="width: 33%; vertical-align: top;"><ul>
|
||||
<li><a href="language-reference/python-api/generated/triton.zeros.html#triton.zeros">zeros() (in module triton)</a>
|
||||
</li>
|
||||
</ul></td>
|
||||
</tr></table>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
@@ -102,12 +102,14 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../language-reference/python-api/index.html">Python API</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
@@ -95,8 +95,6 @@
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
@@ -105,12 +103,14 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
@@ -189,130 +189,62 @@ to download the full example code</p>
|
||||
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, you will write a simple vector addition using Triton and learn about:</p>
|
||||
<ul class="simple">
|
||||
<li><p>The basic syntax of the Triton programming language</p></li>
|
||||
<li><p>The best practices for creating PyTorch custom operators using the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API</p></li>
|
||||
<li><p>The basic programming model used by Triton</p></li>
|
||||
<li><p>The <cite>triton.jit</cite> decorator, which constitutes the main entry point for writing Triton kernels.</p></li>
|
||||
<li><p>The best practices for validating and benchmarking custom ops against native reference implementations</p></li>
|
||||
</ul>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Each compute kernel is declared using the <code class="code docutils literal notranslate"><span class="pre">__global__</span></code> attribute, and executed many times in parallel
|
||||
on different chunks of data (See the <a class="reference external" href="(https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data</a>)
|
||||
programming model for more details).</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
|
||||
<span class="c1">// of the program in the overaching SPMD context</span>
|
||||
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
|
||||
<span class="c1">// different chunks of data in parallel.</span>
|
||||
<span class="c1">// For those similar with CUDA, `get_program_id({0,1,2})`</span>
|
||||
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
|
||||
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// In Triton, arrays are first-class citizen. In other words,</span>
|
||||
<span class="c1">// they are primitives data-types and are -- contrary to C and</span>
|
||||
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
|
||||
<span class="c1">// memory.</span>
|
||||
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
|
||||
<span class="c1">// whose memory values are, e.g.:</span>
|
||||
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
|
||||
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
|
||||
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
|
||||
<span class="c1">// be achieved using the the ternary operator `cond ? val_true : val_false`</span>
|
||||
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr</span>
|
||||
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
|
||||
<span class="c1">// write-back `z`</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL’2019 Triton paper</a>.</p>
|
||||
</div>
|
||||
<div class="section" id="torch-bindings">
|
||||
<h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The only thing that matters when it comes to Triton and Torch is the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="code docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects. To create a <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
|
||||
<ul class="simple">
|
||||
<li><p><code class="code docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</p></li>
|
||||
<li><p><code class="code docutils literal notranslate"><span class="pre">device:</span> <span class="pre">torch.device</span></code>: the device you want to compile this code for</p></li>
|
||||
<li><p><code class="code docutils literal notranslate"><span class="pre">defines:</span> <span class="pre">dict</span></code>: the set of macros that you want the pre-processor to <cite>#define</cite> for you</p></li>
|
||||
</ul>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># source-code for Triton compute kernel</span>
|
||||
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
|
||||
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
|
||||
<span class="s2"> // program id</span>
|
||||
<span class="s2"> int pid = get_program_id(0);</span>
|
||||
<span class="s2"> // create arrays of pointers</span>
|
||||
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
|
||||
<span class="s2"> float* px[BLOCK] = x + offset;</span>
|
||||
<span class="s2"> float* py[BLOCK] = y + offset;</span>
|
||||
<span class="s2"> // bounds checking</span>
|
||||
<span class="s2"> bool check[BLOCK] = offset < N;</span>
|
||||
<span class="s2"> // write-back</span>
|
||||
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2"> """</span>
|
||||
|
||||
|
||||
<span class="c1"># This function returns a callable `triton.kernel` object created from the above source code.</span>
|
||||
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
|
||||
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
|
||||
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
|
||||
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span>
|
||||
<span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
|
||||
<span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
|
||||
|
||||
|
||||
<span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
|
||||
|
||||
<span class="c1"># This is a standard torch custom autograd Function;</span>
|
||||
<span class="c1"># The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`</span>
|
||||
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="c1"># *allocate output*</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># *create launch grid*:</span>
|
||||
<span class="c1"># this is a function which takes compilation parameters `opt`</span>
|
||||
<span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
|
||||
<span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
|
||||
<span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
|
||||
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
|
||||
<span class="c1"># *launch kernel*:</span>
|
||||
<span class="c1"># pointer to the data of torch tensors can be retrieved with</span>
|
||||
<span class="c1"># the `.data_ptr()` method</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">z</span>
|
||||
|
||||
|
||||
<span class="c1"># Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function</span>
|
||||
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_add</span><span class="p">(</span>
|
||||
<span class="n">X</span><span class="p">,</span> <span class="c1"># *Pointer* to first input vector</span>
|
||||
<span class="n">Y</span><span class="p">,</span> <span class="c1"># *Pointer* to second input vector</span>
|
||||
<span class="n">Z</span><span class="p">,</span> <span class="c1"># *Pointer* to output vector</span>
|
||||
<span class="n">N</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
|
||||
<span class="o">**</span><span class="n">meta</span> <span class="c1"># Optional meta-parameters for the kernel</span>
|
||||
<span class="p">):</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># Create an offset for the blocks of pointers to be</span>
|
||||
<span class="c1"># processed by this program instance</span>
|
||||
<span class="n">offsets</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">]</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="c1"># Create a mask to guard memory operations against</span>
|
||||
<span class="c1"># out-of-bounds accesses</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">N</span>
|
||||
<span class="c1"># Load x</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="c1"># Write back x + y</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Z</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can now use the above function to compute the sum of two <cite>torch.tensor</cite> objects:</p>
|
||||
<p>We can also declara a helper function that handles allocating the output vector
|
||||
and enqueueing the kernel.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="c1"># The SPMD launch grid denotes the number of kernel instances that should execute in parallel.</span>
|
||||
<span class="c1"># It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">]),</span> <span class="p">)</span>
|
||||
<span class="c1"># NOTE:</span>
|
||||
<span class="c1"># - torch.tensor objects are implicitly converted to pointers to their first element.</span>
|
||||
<span class="c1"># - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel</span>
|
||||
<span class="c1"># - don't forget to pass meta-parameters as keywords arguments</span>
|
||||
<span class="n">_add</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
|
||||
<span class="c1"># We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still</span>
|
||||
<span class="c1"># running asynchronously.</span>
|
||||
<span class="k">return</span> <span class="n">z</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="section" id="unit-test">
|
||||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:</p>
|
||||
<p>We can now use the above function to compute the sum of two <cite>torch.tensor</cite> objects and test our results:</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
|
||||
@@ -363,7 +295,7 @@ for different problem sizes.</p>
|
||||
</pre></div>
|
||||
</div>
|
||||
<img alt="01 vector add" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 9.497 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 5.812 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||||
|
@@ -98,7 +98,6 @@
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Fused Softmax</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
|
||||
</ul>
|
||||
@@ -107,12 +106,14 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
@@ -192,8 +193,7 @@ to download the full example code</p>
|
||||
<p>In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:</p>
|
||||
<ul class="simple">
|
||||
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
|
||||
<li><p>The syntax and usage of reduction operators in Triton.</p></li>
|
||||
<li><p>The automatic vectorization capabilities of the Triton compiler.</p></li>
|
||||
<li><p>The reduction operators in Triton.</p></li>
|
||||
</ul>
|
||||
<div class="section" id="motivations">
|
||||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||||
@@ -220,78 +220,41 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
</div>
|
||||
<p>When implemented naively in pytorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
|
||||
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip.
|
||||
In this case, we would be reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
|
||||
This solution would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
|
||||
In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.</p>
|
||||
</div>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
|
||||
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
so we need to internally “pad” tiles and guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">softmax</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">Y</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">X</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// row index</span>
|
||||
<span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// column indices</span>
|
||||
<span class="kt">int</span> <span class="n">n</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="c1">// the memory address of all the elements</span>
|
||||
<span class="c1">// that we want to load can be computed as follows</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="c1">// because BLOCK has to be a power of two</span>
|
||||
<span class="c1">// (per Triton-C specs), it is important</span>
|
||||
<span class="c1">// to guard each memory operation with predicates</span>
|
||||
<span class="c1">// or we will read out of bounds</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="kt">float</span> <span class="n">x</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">check</span> <span class="o">?</span> <span class="o">*</span><span class="nl">px</span> <span class="p">:</span> <span class="o">-</span><span class="n">F32_INFINITY</span><span class="p">;</span>
|
||||
<span class="c1">// syntax for reduction in Triton is:</span>
|
||||
<span class="c1">// x[:, :, OPERATOR, :, :]</span>
|
||||
<span class="c1">// ^</span>
|
||||
<span class="c1">// index</span>
|
||||
<span class="c1">// where operator is in {min, max, +}</span>
|
||||
<span class="c1">// for 1D vectors, this is just x[OPERATOR].</span>
|
||||
<span class="kt">float</span> <span class="n">z</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="p">[</span><span class="n">max</span><span class="p">];</span>
|
||||
<span class="c1">// Note that exponentials in Triton are fast</span>
|
||||
<span class="c1">// but approximate (i.e., think __expf in CUDA)</span>
|
||||
<span class="kt">float</span> <span class="n">num</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">);</span>
|
||||
<span class="kt">float</span> <span class="n">denom</span> <span class="o">=</span> <span class="n">num</span><span class="p">[</span><span class="o">+</span><span class="p">];</span>
|
||||
<span class="c1">// The result of the reduction is now stored in y</span>
|
||||
<span class="kt">float</span> <span class="n">y</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span><span class="p">;</span>
|
||||
<span class="c1">// We write it back</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span> <span class="o">=</span> <span class="n">y</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_softmax</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="o">**</span><span class="n">meta</span><span class="p">):</span>
|
||||
<span class="c1"># row index</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># col indices</span>
|
||||
<span class="n">n</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="c1"># the memory address of all the elements</span>
|
||||
<span class="c1"># that we want to load can be computed as follows</span>
|
||||
<span class="n">X</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s1">'inf'</span><span class="p">))</span>
|
||||
<span class="c1"># Substract maximum for numerical stability</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">triton</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># Note that exponentials in Triton are fast</span>
|
||||
<span class="c1"># but approximate (i.e., think __expf in CUDA)</span>
|
||||
<span class="n">num</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
<span class="n">denom</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">num</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span>
|
||||
<span class="c1"># Write back to Y</span>
|
||||
<span class="n">Y</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
</div>
|
||||
<div class="section" id="torch-bindings">
|
||||
<h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Here our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial.
|
||||
We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
|
||||
This means that different values of BLOCK will result in different kernels</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># Source code for the Triton kernel</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){</span>
|
||||
<span class="s2"> int m = get_program_id(0);</span>
|
||||
<span class="s2"> int n [BLOCK] = 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* px [BLOCK] = X + m*stride_xm + n;</span>
|
||||
<span class="s2"> bool check[BLOCK] = n < N;</span>
|
||||
<span class="s2"> float x [BLOCK] = check ? *px : -F32_INFINITY;</span>
|
||||
<span class="s2"> float z [BLOCK] = x - x[max];</span>
|
||||
<span class="s2"> float num [BLOCK] = exp(z);</span>
|
||||
<span class="s2"> float denom = num[+];</span>
|
||||
<span class="s2"> float y [BLOCK] = num / denom;</span>
|
||||
<span class="s2"> float* py [BLOCK] = Y + m*stride_ym + n;</span>
|
||||
<span class="s2"> *?(check)py = y;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
|
||||
<span class="c1"># helper function to get the smaller power-of-two larger than a given number</span>
|
||||
<span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
|
||||
<p>We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
|
||||
<span class="n">n</span> <span class="o">-=</span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">2</span>
|
||||
@@ -302,11 +265,9 @@ This means that different values of BLOCK will result in different kernels</p>
|
||||
<span class="k">return</span> <span class="n">n</span>
|
||||
|
||||
|
||||
<span class="c1"># kernel caching mechanism</span>
|
||||
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
|
||||
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span>
|
||||
<span class="c1"># Now are kernels are indexed not only by the provided device but also</span>
|
||||
<span class="c1"># by the rounded number of columns in the input matrix</span>
|
||||
<span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="c1"># The block size is the smallest power of two greater than the number of columns in `x`</span>
|
||||
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
|
||||
<span class="c1"># Another trick we can use is to ask the compiler to parallelize each</span>
|
||||
<span class="c1"># row-normalization more aggressively -- i.e., with more warps -- vectors</span>
|
||||
@@ -316,36 +277,13 @@ This means that different values of BLOCK will result in different kernels</p>
|
||||
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
|
||||
<span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">2048</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
|
||||
<span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
|
||||
<span class="c1"># Each (BLOCK, num_warps, device) results in a different kernel</span>
|
||||
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">num_warps</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span>
|
||||
<span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
|
||||
|
||||
|
||||
<span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">_softmax</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># The launch grid is simple: we have one kernel instance per row of the input matrix</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="p">)</span>
|
||||
<span class="c1"># Launch kernel</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">y</span>
|
||||
|
||||
|
||||
<span class="n">softmax</span> <span class="o">=</span> <span class="n">_softmax</span><span class="o">.</span><span class="n">apply</span>
|
||||
<span class="c1"># Allocate output</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix</span>
|
||||
<span class="n">_softmax</span><span class="p">[(</span><span class="n">M</span><span class="p">,</span> <span class="p">)](</span><span class="n">y</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">y</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can use the above softmax function to compute the row-wise softmax of a given matrix.</p>
|
||||
</div>
|
||||
<div class="section" id="unit-test">
|
||||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||||
@@ -405,7 +343,7 @@ This means that – when temporary data is too large to fit entirely in the GPU
|
||||
Note that our Triton kernel is not only faster than PyTorch’s CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
|
||||
</ul>
|
||||
</div></blockquote>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 25.654 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 20.767 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||
|