Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of Triton's frontend: Triton-C is replaced by a pure Python API in which kernels are defined as @triton.jit-decorated functions. The documentation and tutorials have been updated to reflect these changes; see the documentation for more information on the new API.
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
@@ -1,139 +1,71 @@
import torch
import triton
"""
Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:

- The basic syntax of the Triton programming language
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
- The basic programming model used by Triton
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
- The best practices for validating and benchmarking custom ops against native reference implementations
"""

# %%
# Compute Kernel
# --------------------------
#
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
# on different chunks of data (see the `Single Program, Multiple Data <https://en.wikipedia.org/wiki/SPMD>`_
# programming model for more details).
#
# .. code-block:: C
#
#     __global__ void add(float* z, float* x, float* y, int N){
#         // The `get_program_id(i)` returns the i-th coordinate
#         // of the program in the overarching SPMD context
#         // (a.k.a. launch grid). This is what allows us to process
#         // different chunks of data in parallel.
#         // For those familiar with CUDA, `get_program_id({0,1,2})`
#         // is similar to blockIdx.{x,y,z}
#         int pid = get_program_id(0);
#         // In Triton, arrays are first-class citizens. In other words,
#         // they are primitive data-types and are -- contrary to C and
#         // CUDA -- not implemented as pointers to contiguous chunks of
#         // memory.
#         // In the few lines below, we create an array of `BLOCK` pointers
#         // whose memory values are, e.g.:
#         // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
#         // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
#         int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
#         float* pz [BLOCK] = z + offset;
#         float* px [BLOCK] = x + offset;
#         float* py [BLOCK] = y + offset;
#         // Simple element-wise control-flow for load/store operations can
#         // be achieved using the ternary operator `cond ? val_true : val_false`
#         // or the conditional dereferencing operator `*?(cond)ptr`.
#         // Here, we make sure that we do not access memory out-of-bounds when we
#         // write-back `z`
#         bool check[BLOCK] = offset < N;
#         *?(check)pz = *?(check)px + *?(check)py;
#     }
#
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.


@triton.jit
def _add(
    X,  # *Pointer* to first input vector
    Y,  # *Pointer* to second input vector
    Z,  # *Pointer* to output vector
    N,  # Size of the vector
    **meta  # Optional meta-parameters for the kernel
):
    pid = triton.program_id(0)
    # Create an offset for the blocks of pointers to be
    # processed by this program instance
    offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
    # Create a mask to guard memory operations against
    # out-of-bounds accesses
    mask = offsets < N
    # Load x and y
    x = triton.load(X + offsets, mask=mask)
    y = triton.load(Y + offsets, mask=mask)
    # Write back x + y
    z = x + y
    triton.store(Z + offsets, z)


# %%
# Torch Bindings
# --------------------------
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable Python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
#
# - :code:`source: string`: the source-code of the kernel you want to create
# - :code:`device: torch.device`: the device you want to compile this code for
# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
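# %%
# For instance, given a Triton-C source string (here the hypothetical name :code:`some_src`), these three
# ingredients would be combined roughly as in the sketch below -- a minimal illustration of the API just
# described, not the exact helper used later in this tutorial:
#
# .. code-block:: python
#
#     import torch
#     import triton
#
#     some_src = "..."  # Triton-C source code, e.g. the `add` kernel shown above
#     kernel = triton.kernel(some_src, device=torch.device('cuda:0'), defines={'BLOCK': 1024})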

import torch
import triton

# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
    // program id
    int pid = get_program_id(0);
    // create arrays of pointers
    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float* pz[BLOCK] = z + offset;
    float* px[BLOCK] = x + offset;
    float* py[BLOCK] = y + offset;
    // bounds checking
    bool check[BLOCK] = offset < N;
    // write-back
    *?(check)pz = *?(check)px + *?(check)py;
}
"""
# We can also declare a helper function that handles allocating the output vector
# and enqueueing the kernel.


# This function returns a callable `triton.kernel` object created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device` objects.
# We compile the kernel with -DBLOCK=1024
def make_add_kernel(device):
    cache = make_add_kernel.cache
    if device not in cache:
        defines = {'BLOCK': 1024}
        cache[device] = triton.kernel(_src, device=device, defines=defines)
    return cache[device]


def add(x, y):
    z = torch.empty_like(x)
    N = z.shape[0]
    # The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
    # It is analogous to CUDA launch grids. It can be either a Tuple[int] or a Callable(metaparameters) -> Tuple[int].
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
    # NOTE:
    #  - torch.tensor objects are implicitly converted to pointers to their first element.
    #  - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel.
    #  - don't forget to pass meta-parameters as keyword arguments.
    _add[grid](x, y, z, N, BLOCK=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously.
    return z


make_add_kernel.cache = dict()


# This is a standard torch custom autograd Function;
# the only difference is that we can now use the above kernel in the `forward` and `backward` functions.
class _add(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # constraints of the op
        assert x.dtype == torch.float32
        # *allocate output*
        z = torch.empty_like(x)
        # *create launch grid*:
        # this is a function which takes compilation parameters `opt`
        # as input and returns a tuple of int (i.e., launch grid) for the kernel.
        # triton.cdiv is a shortcut for ceil division:
        # triton.cdiv(a, b) = (a + b - 1) // b
        N = z.shape[0]
        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
        # *launch kernel*:
        # pointers to the data of torch tensors can be retrieved with
        # the `.data_ptr()` method
        kernel = make_add_kernel(z.device)
        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
        return z


# Just like with standard PyTorch ops, we use the :code:`.apply` method to create a callable object for our function
add = _add.apply

# %%
# We can now use the above function to compute the sum of two `torch.tensor` objects:

# %%
# Unit Test
# -----------
#
# Of course, the first thing that we should check is whether the kernel is correct. This is pretty easy to test, as shown below:
# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:

torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
za = x + y
zb = add(x, y)
print(za)
@@ -4,8 +4,7 @@ Fused Softmax
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:

- The benefits of kernel fusion for bandwidth-bound operations.
- The syntax and usage of reduction operators in Triton.
- The automatic vectorization capabilities of the Triton compiler.
- The reduction operators in Triton.
"""

# %%
@@ -36,79 +35,45 @@ def naive_softmax(x):
# %%
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
# In this case, we would be reading and writing back only :math:`MN` elements, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# This solution would require reading and writing back only :math:`MN` elements, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
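# %%
# As a quick sanity check of the arithmetic above (a sketch with illustrative values of M and N, not
# measurements), the expected speed-up is simply the ratio of the two traffic estimates:
#
# .. code-block:: python
#
#     M, N = 4096, 1024  # arbitrary example shape
#     naive_traffic = 7 * M * N + 3 * M * N + 2 * M   # elements read + written by naive_softmax
#     fused_traffic = M * N + M * N                   # fused kernel: read X once, write Y once
#     print(naive_traffic / fused_traffic)            # ~5.0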

# %%
# Compute Kernel
# ----------------
# Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
#
# .. code-block:: C
#
#     __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
#         // row index
#         int m = get_program_id(0);
#         // column indices
#         int n [BLOCK] = 0 ... BLOCK;
#         // the memory address of all the elements
#         // that we want to load can be computed as follows
#         float* px [BLOCK] = X + m*stride_xm + n;
#         // because BLOCK has to be a power of two
#         // (per Triton-C specs), it is important
#         // to guard each memory operation with predicates
#         // or we will read out of bounds
#         bool check[BLOCK] = n < N;
#         float x [BLOCK] = check ? *px : -F32_INFINITY;
#         // syntax for reduction in Triton is:
#         // x[:, :, OPERATOR, :, :]
#         //            ^
#         //          index
#         // where operator is in {min, max, +}
#         // for 1D vectors, this is just x[OPERATOR].
#         float z [BLOCK] = x - x[max];
#         // Note that exponentials in Triton are fast
#         // but approximate (i.e., think __expf in CUDA)
#         float num [BLOCK] = exp(z);
#         float denom = num[+];
#         // The result of the reduction is now stored in y
#         float y [BLOCK] = num / denom;
#         // We write it back
#         float* py [BLOCK] = Y + m*stride_ym + n;
#         *?(check)py = y;
#     }

# %%
# Torch Bindings
# ---------------
# Here our Torch bindings are quite similar to those of the vector addition in the previous tutorial.
# We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
# This means that different values of BLOCK will result in different kernels.

import torch
import triton

# Source code for the Triton kernel
_src = """
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
    int m = get_program_id(0);
    int n [BLOCK] = 0 ... BLOCK;
    float* px [BLOCK] = X + m*stride_xm + n;
    bool check[BLOCK] = n < N;
    float x [BLOCK] = check ? *px : -F32_INFINITY;
    float z [BLOCK] = x - x[max];
    float num [BLOCK] = exp(z);
    float denom = num[+];
    float y [BLOCK] = num / denom;
    float* py [BLOCK] = Y + m*stride_ym + n;
    *?(check)py = y;
}
"""

@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
    # row index
    m = triton.program_id(0)
    # col indices
    n = triton.arange(0, meta['BLOCK'])
    # the memory address of all the elements
    # that we want to load can be computed as follows
    X = X + m * stride_xm + n
    x = triton.load(X, mask=n < N, other=-float('inf'))
    # Subtract maximum for numerical stability
    z = x - triton.max(x, axis=0)
    # Note that exponentials in Triton are fast
    # but approximate (i.e., think __expf in CUDA)
    num = triton.exp(z)
    denom = triton.sum(num, axis=0)
    y = num / denom
    # Write back to Y
    Y = Y + m * stride_ym + n
    triton.store(Y, y, mask=n < N)


# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.


# helper function to get the smallest power of two larger than a given number
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
@@ -120,11 +85,9 @@ def next_power_of_2(n):
    return n
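# %%
# The hunk above elides the remaining shift steps of this helper. For reference, a complete version of this
# standard bit-twiddling idiom (a sketch assuming 32-bit sizes, not necessarily the exact code in the file)
# would look like:
#
# .. code-block:: python
#
#     def next_power_of_2(n):
#         # round n up to the nearest power of two (for n >= 1)
#         n -= 1
#         n |= n >> 1
#         n |= n >> 2
#         n |= n >> 4
#         n |= n >> 8
#         n |= n >> 16
#         n += 1
#         return n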


# kernel caching mechanism
def make_kernel(N, device):
    cache = make_kernel.cache
    # Now our kernels are indexed not only by the provided device but also
    # by the rounded number of columns in the input matrix
def softmax(x):
    M, N = x.shape
    # The block size is the smallest power of two greater than the number of columns in `x`
    BLOCK = next_power_of_2(N)
    # Another trick we can use is to ask the compiler to parallelize each
    # row-normalization more aggressively -- i.e., with more warps -- vectors
@@ -134,37 +97,13 @@ def make_kernel(N, device):
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    # Each (BLOCK, num_warps, device) results in a different kernel
    key = (BLOCK, num_warps, device)
    if key not in cache:
        defines = {'BLOCK': BLOCK}
        cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
    return cache[key]
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, BLOCK=BLOCK)
    return y


make_kernel.cache = dict()


class _softmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # constraints of the op
        assert x.dtype == torch.float32
        y = torch.empty_like(x)
        # The launch grid is simple: we have one kernel instance per row of the input matrix
        M, N = y.shape
        grid = lambda opt: (M, )
        # Launch kernel
        kernel = make_kernel(N, y.device)
        kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
        return y


softmax = _softmax.apply

# %%
# We can use the above softmax function to compute the row-wise softmax of a given matrix.

# %%
# Unit Test
# ----------
@@ -1,10 +1,10 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-line high-performance matrix multiplication kernel that outperforms CUTLASS and falls just short of matching cuBLAS's performance.
In this tutorial, you will write a 25-line high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
You will specifically learn about:

- The block-level matrix multiplication operator `@`
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
@@ -15,7 +15,7 @@ You will specifically learn about:
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be customized to accommodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
@@ -23,322 +23,212 @@ You will specifically learn about:
# .. code-block:: python
#
#     # do in parallel
#     for m in range(0, M, MB):
#     for m in range(0, M, BLOCK_M):
#         # do in parallel
#         for n in range(0, N, NB):
#             acc = zeros((MB, NB), dtype=float32)
#             for k in range(0, K, KB):
#                 acc += A[m : m+MB, k : k+KB] @ B[k : k+KB, n : n+NB]
#             C[m : m+MB, n : n+NB] = acc;
#         for n in range(0, N, BLOCK_N):
#             acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
#             for k in range(0, K, BLOCK_K):
#                 a = A[m : m+BLOCK_M, k : k+BLOCK_K]
#                 b = B[k : k+BLOCK_K, n : n+BLOCK_N]
#                 acc += dot(a, b)
#             C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
#
# where each iteration of the doubly-nested for-loops corresponds to a Triton program instance.
# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.

# %%
# Compute Kernel
# ----------------
#
# The above algorithm is actually fairly straightforward to implement in Triton, as we can simply use the :code:`@` operator for block-level matrix multiplication.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations of the tiles of :code:`A` and :code:`B` that we need to read in the inner loop.
# The above algorithm is actually fairly straightforward to implement in Triton.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
#
# Pointer Arithmetics
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = i + X.stride(0) + j`.
# Therefore, blocks of pointers for :code:`A[m : m+MB, k:k+KB]` and :code:`B[k : k+KB, n : n+NB]` can be defined in pseudo-code as:
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
#     &A[m : m+MB, k:k+KB] = A + (m : m+MB)[:, newaxis]*A.stride(0) + (k : k+KB)[newaxis, :];
#     &B[k : k+KB, n:n+NB] = B + (k : k+KB)[:, newaxis]*B.stride(0) + (n : n+NB)[newaxis, :];
#     &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
#     &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
#
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
#
# .. code-block:: C
# .. code-block:: python
#     :force:
#
#     int rm[MB] = program_id_m * MB + 0 ... MB;
#     int rn[NB] = program_id_n * NB + 0 ... NB;
#     int rk[KB] = 0 ... KB;
#     TYPE *pa[MB, KB] = A + (rm[:, newaxis] * stride_a_0 + rk [newaxis, :] * 1);
#     TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
#     pid_m = triton.program_id(0)
#     pid_n = triton.program_id(1)
#     rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
#     rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
#     rk = triton.arange(0, BLOCK_K)
#     // pointer for A operand
#     pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
#     // pointer for B operand
#     pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
#
# These pointers can then be updated in the inner loop as:
#
# .. code-block:: C
# .. code-block:: python
#
#     pa += KB * 1;
#     pb += KB * ldb;
#     pa += BLOCK_K * stride_a_1;
#     pb += BLOCK_K * stride_b_0;
#
#
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes an :code:`[MB, NB]` block of :code:`C`.
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
# However, the order in which these blocks are computed matters, since it affects the L2 cache hit rate of our program.
# This means that a naive row-major ordering:
#
# .. code-block:: C
# .. code-block:: Python
#
#     int program_id = get_program_id(0);
#     int grid_m = (M + MB - 1) / MB;
#     int grid_n = (N + NB - 1) / NB;
#     int program_id_m = program_id / grid_n;
#     int program_id_n = program_id % grid_n;
#     pid = triton.program_id(0);
#     grid_m = (M + BLOCK_M - 1) / BLOCK_M;
#     grid_n = (N + BLOCK_N - 1) / BLOCK_N;
#     pid_m = pid / grid_n;
#     pid_n = pid % grid_n;
#
# is unlikely to result in optimal performance.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE` before switching to the next column:
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
#
# .. code-block:: C
#
#     int program_id = get_program_id(0);
#     int width = GROUP_SIZE * grid_n;
#     int group_id = pid / width;
#     // we need to handle the case where M % (GROUP_SIZE*BM) != 0
#     int group_size = min(grid_m - group_id * GROUP_SIZE, GROUP_SIZE);
#     int pid_m = group_id * GROUP_SIZE + (pid % group_size);
#     int pid_n = (pid % width) / (group_size);
#     pid = triton.program_id(0);
#     width = GROUP_M * grid_n;
#     group_id = pid / width;
#     # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
#     group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
#     pid_m = group_id * GROUP_M + (pid % group_size);
#     pid_n = (pid % width) / (group_size);
#
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
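# %%
# The re-ordering logic above is plain integer arithmetic, so it can be checked outside of Triton. The sketch
# below (an illustration with a hypothetical `swizzle` helper and a small 4x4 grid of blocks, not code from
# the kernel itself) shows how linear program IDs are remapped so that GROUP_M rows are visited per column:
#
# .. code-block:: python
#
#     def swizzle(pid, grid_m, grid_n, GROUP_M):
#         width = GROUP_M * grid_n
#         group_id = pid // width
#         group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
#         pid_m = group_id * GROUP_M + (pid % group_size)
#         pid_n = (pid % width) // group_size
#         return pid_m, pid_n
#
#     # With grid_m = grid_n = 4 and GROUP_M = 2, blocks are visited column-by-column within
#     # groups of 2 rows: (0,0), (1,0), (0,1), (1,1), ..., then rows 2-3 in the same pattern.
#     print([swizzle(pid, 4, 4, 2) for pid in range(16)])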
#
# Final Result
# ~~~~~~~~~~~~~~
#
# We are now ready to put all these pieces together and write our Triton kernel for matrix multiplication.
# Note that we rematerialize :code:`rm` and :code:`rn` after the inner loop to decrease register pressure.
# This is an optimization that provides an additional 5% performance improvement and cannot currently be done by the Triton compiler.
#
# .. code-block:: C
#     :force:
#
#     #define MAX_GROUP_SIZE 8
#
#     __global__ void dot(TYPE* A, TYPE* B, TYPE* C,
#                         int M, int N, int K,
#                         int stride_a_0, int stride_b_0, int stride_c_0) {
#         // prologue
#         int pid = get_program_id(0);
#         int grid_m = (M + MB - 1) / MB;
#         int grid_n = (N + NB - 1) / NB;
#         // re-order program ID for better L2 performance
#         int width = MAX_GROUP_SIZE * grid_n;
#         int group_id = pid / width;
#         int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
#         int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
#         int pid_n = (pid % width) / (group_size);
#         // pointers to operands
#         // note the parentheses here; they force the offset
#         // computation to happen in typeof(stride_a_0) = int32 rather than
#         // typeof(A) = int64
#         int rm[MB] = pid_m * MB + 0 ... MB;
#         int rn[NB] = pid_n * NB + 0 ... NB;
#         int rk[KB] = 0 ... KB;
#         TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * stride_a_0);
#         TYPE *pb[KB, NB] = B + (rk[:, newaxis] * stride_b_0 + rn [newaxis, :] * 1);
#         // reduction loop
#         float acc[MB, NB] = 0;
#         for (int k = K; k > 0; k -= KB) {
#             acc += (*pa) @ (*pb);
#             pa += KB * 1;
#             pb += KB * stride_b_0;
#         }
#         // pointers to output
#         // here we rematerialize `rm` and `rn` so that they are not live through
#         // the above reduction loop. In the future, the compiler should be able to
#         // do this automatically.
#         rm = pid_m * MB + 0 ... MB;
#         rn = pid_n * NB + 0 ... NB;
#         TYPE *pc[MB, NB] = C + (rm[:, newaxis] * stride_c_0 + rn[newaxis, :]);
#         // we write back using the *?() operator. `acc` gets cast to `float32` implicitly.
#         *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
#     }
#
# Where :code:`TYPE` is the data-type of the input matrices and :code:`MB`, :code:`NB`, :code:`KB` are the block sizes defined in the above pseudo-code.
# Good values for these block sizes are hard to find, hence we will introduce the auto-tuner in the next section of this tutorial.
# If :code:`TYPE` is :code:`half`, then tensor cores will be used automatically provided that :code:`MB`, :code:`NB` and :code:`KB` are multiples of 16.
#

# %%
# Torch Bindings
# ----------------
# Final Result
# -------------
#
# Auto-Tuning
# ~~~~~~~~~~~~~~
#
# In order to use Triton's built-in auto-tuner in the above kernel, we need to define a list of :code:`triton.config` objects that can be constructed as follows:

import torch
import triton

autotune_configs = [
    triton.config(defines={"MB": "128", "NB": "128", "KB": "32"}, num_warps=4),
    triton.config(defines={'MB': '64', 'NB': '128', 'KB': '32'}, num_warps=4),
    triton.config(defines={'MB': '128', 'NB': '64', 'KB': '32'}, num_warps=4),
    triton.config(defines={'MB': '64', 'NB': '64', 'KB': '64'}, num_warps=4),
    triton.config(defines={'MB': '32', 'NB': '128', 'KB': '64'}, num_warps=4),
    triton.config(defines={'MB': '128', 'NB': '32', 'KB': '64'}, num_warps=4),
    triton.config(defines={'MB': '64', 'NB': '32', 'KB': '64'}, num_warps=2),
    triton.config(defines={'MB': '32', 'NB': '64', 'KB': '64'}, num_warps=2)
]
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
#   - An autotuning *key* whose change in values will trigger evaluation of all the provided configs


@triton.jit
def sigmoid(x):
    ret_true = 1 / (1 + triton.exp(-x))
    ret_false = triton.exp(x) / (1 + triton.exp(x))
    return triton.where(x >= 0, ret_true, ret_false)


@triton.jit
def swish(x):
    return x * sigmoid(x)


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above.
@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = triton.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
    rk = triton.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
    for k in range(K, 0, -BLOCK_K):
        a = triton.load(A)
        b = triton.load(B)
        acc += triton.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # Triton can accept an arbitrary activation function
    # via meta-parameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    triton.store(C, acc, mask=mask)


# %%
# We also need to define a list of :code:`string` (i.e., "autotuning key") that specifies the set of argument names whose change in value will trigger the auto-tuner to kick in.
# Here, we want to re-tune our kernel only when the shape of input matrices changes.

autotune_key = ["M", "N", "K"]

# %%
# We can now create an auto-tuned kernel by passing the `autotune_configs` and `autotune_key` lists to the constructor of the :code:`triton.kernel` class.
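# %%
# Concretely, the construction looks roughly like the sketch below -- a minimal illustration of the pattern,
# using the `src` string defined just after this note; the cached helper actually used in this tutorial
# appears further down:
#
# .. code-block:: python
#
#     kernel = triton.kernel(
#         src,                                # Triton-C source code
#         device=torch.device('cuda:0'),      # device to compile for
#         defines={'TYPE': torch.float16},    # pre-processor macros
#         autotune_configs=autotune_configs,  # list of triton.config objects to try
#         autotune_key=autotune_key,          # re-tune when these argument values change
#     )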

src = """
#define MAX_GROUP_SIZE 8

__global__ void dot(TYPE* A, TYPE* B, TYPE* C,
                    int M, int N, int K,
                    int lda, int ldb, int ldc) {
    int pid = get_program_id(0);
    int grid_m = (M + MB - 1) / MB;
    int grid_n = (N + NB - 1) / NB;
    int width = MAX_GROUP_SIZE * grid_n;
    int group_id = pid / width;
    int group_size = min(grid_m - group_id * MAX_GROUP_SIZE, MAX_GROUP_SIZE);
    int pid_m = group_id * MAX_GROUP_SIZE + (pid % group_size);
    int pid_n = (pid % width) / (group_size);
    int rm[MB] = pid_m * MB + 0 ... MB;
    int rn[NB] = pid_n * NB + 0 ... NB;
    int rk[KB] = 0 ... KB;
    TYPE *pa[MB, KB] = A + (rk [newaxis, :] * 1 + rm[:, newaxis] * lda);
    TYPE *pb[KB, NB] = B + (rk[:, newaxis] * ldb + rn [newaxis, :] * 1);
    float acc[MB, NB] = 0;
    for (int k = K; k > 0; k -= KB) {
        acc += (*pa) @ (*pb);
        pa += KB * 1;
        pb += KB * ldb;
    }
    rm = pid_m * MB + 0 ... MB;
    rn = pid_n * NB + 0 ... NB;
    TYPE *pc[MB, NB] = C + (rm[:, newaxis] * ldc + rn[newaxis, :]);
    *? (rm[:, newaxis] < M && rn [newaxis, :] < N) pc = acc;
}
"""
# We can also create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel


def make_kernel(device, dtype):
    key = (device, dtype)
    cache = make_kernel.cache
    if key not in cache:
        defines = {'TYPE': dtype}
        cache[key] = triton.kernel(
            src,
            device=device,
            defines=defines,
            autotune_configs=autotune_configs,
            autotune_key=autotune_key,
        )
    return cache[key]


def matmul(a, b, activation=None):
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a.is_contiguous(), "matrix A must be contiguous"
    assert b.is_contiguous(), "matrix B must be contiguous"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # launch kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    _matmul[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        ACTIVATION=activation
    )
    # return output
    return c


make_kernel.cache = dict()

# %%
# Autograd Function
# ~~~~~~~~~~~~~~~~~~
#
# Now we are ready to expose our auto-tuned kernel as a `torch.autograd.Function`.
# To do so, we just need to define a `forward` function that takes two tensors as input and returns a tensor as output.


class _dot(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        M, Ka = a.shape
        Kb, N = b.shape
        assert Ka == Kb, "incompatible dimensions"
        assert a.is_contiguous() and b.is_contiguous(), "inputs must be contiguous"
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
        kernel = make_kernel(a.device, a.dtype)
        grid = lambda opt: (triton.cdiv(M, opt.MB) * triton.cdiv(N, opt.NB), )
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
               M, N, Ka,
               a.stride(0), b.stride(0), c.stride(0),
               grid=grid)
        return c


dot = _dot.apply

# %%
# Unit Test
# -----------
#
# We can test our custom matrix multiplication operation against cuBLAS (i.e., :code:`torch.matmul`).
# Note that we need to modify the :code:`atol` and :code:`rtol` parameters of `torch.allclose` to account for the fact that we are comparing FP16 tensors.
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + a custom element-wise swish kernel).

a = torch.rand((512, 768), device='cuda', dtype=torch.float16)
b = torch.rand((768, 896), device='cuda', dtype=torch.float16)
c_0 = dot(a, b)
c_1 = torch.matmul(a, b)
# torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=swish)
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
print(c_0)
print(c_1)
print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
print(triton.testing.allclose(c_0, c_1))

# %%
# Benchmark
# --------------
#
# Installing The CUTLASS Bindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The cuBLAS library (used by :code:`torch.matmul`) uses handwritten assembly-level optimizations that cannot be replicated using publicly available tools.
# For this reason, we will instead compare the performance of our kernel against `CUTLASS <https://github.com/NVIDIA/cutlass/>`_, a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves.
# To install CUTLASS, you need a recent version of cmake:
#
# .. code-block:: bash
#
#     cd /path/to/cutlass/
#     git clone https://github.com/NVIDIA/cutlass.git
#     cd cutlass
#     mkdir build
#     cd build
#     wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4-Linux-x86_64.tar.gz
#     tar xzvf *.tar.gz
#
# You can then install CUTLASS as follows for V100:
#
# .. code-block:: bash
#
#     ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=70 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s884gemm_f16_*_align8
#     make -j8 install
#
# Or as follows for A100:
#
# .. code-block:: bash
#
#     ./cmake-3.19.4-Linux-x86_64/bin/cmake ../ -DCUTLASS_NVCC_ARCHS_ENABLED=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s16816gemm_*align8
#     make -j8 install
#
# where you can change CUTLASS_LIBRARY_KERNELS as you desire. Here, we are only interested in FP16 tensor core performance.
# Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables :code:`CUTLASS_INCLUDE_DIR` and :code:`CUTLASS_LIBRARY_DIR` are set during the installation process.
# To re-install Triton with the updated CUTLASS bindings, run the following commands:
#
# .. code-block:: bash
#
#     export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
#     export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
#     pip uninstall -y triton
#     pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
#
# Which we can test as follows:

import triton
c_2 = triton.testing.cutlass_matmul(a, b)
print(c_2)
print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))

# %%
# Note that this wrapper for CUTLASS was written for benchmarking purposes and is probably not production-ready.
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to adapt the script to benchmark any other matrix shape.
@@ -347,29 +237,25 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
        x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
        x_vals=[8192],  # different possible values for `x_name`
        y_name='provider',  # argument name whose value corresponds to a different line in the plot
        y_vals=['cublas', 'triton', 'cutlass'],  # possible keys for `y_name`
        y_lines=["cuBLAS", "Triton", 'CUTLASS'],  # label name for the lines
        y_vals=['cublas', 'triton'],  # possible keys for `y_name`
        y_lines=["cuBLAS", "Triton"],  # label name for the lines
        ylabel="TFLOPS",  # label name for the y-axis
        plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={}
    )
)
def benchmark(M, N, K, provider):
    silu = torch.nn.SiLU()
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    if provider == 'cublas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
    if provider == 'cutlass':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b))
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True)

# %%
# As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.
benchmark.run(print_data=True)