[DOCS] Improved documentation and integration in CI (#139)

Philippe Tillet
2021-07-22 22:45:19 -07:00
committed by Philippe Tillet
parent 76c6f24fb6
commit b253b77c71
5 changed files with 124 additions and 84 deletions

.ci/build-website.yml (new file)

@@ -0,0 +1,28 @@
trigger: none
pr: none

jobs:
- job: linux
  workspace:
    clean: all
  pool: default
  steps:
  - bash: |
      set -o errexit
      cd python
      python setup.py develop
    displayName: Install dependencies
  - bash: |
      cd docs
      make html
    displayName: Build docs
  - bash: |
      git checkout gh-pages
      sh ./update-website.sh
      git remote set-url origin git@github.com:ptillet/triton.git
      git push
    displayName: Publish docs

python/triton/testing.py

@@ -168,6 +168,8 @@ class Benchmark:
         ylabel='',
         x_log=False,
         y_log=False,
+        color=None,
+        styles=None,
     ):
         """
         Constructor
@@ -202,6 +204,7 @@ class Benchmark:
         self.line_vals = line_vals
         self.line_names = line_names
         self.y_log = y_log
+        self.styles = styles
         # plot info
         self.xlabel = xlabel
         self.ylabel = ylabel
@@ -239,11 +242,13 @@ class Mark:
         plt.figure()
         ax = plt.subplot()
         x = bench.x_names[0]
-        for y in bench.line_names:
+        for i, y in enumerate(bench.line_names):
             y_min, y_max = df[y + '-min'], df[y + '-max']
-            ax.plot(df[x], df[y], label=y)
+            col = bench.styles[i][0] if bench.styles else None
+            sty = bench.styles[i][1] if bench.styles else None
+            ax.plot(df[x], df[y], label=y, color=col, ls=sty)
             if y_min is not None and y_max is not None:
-                ax.fill_between(df[x], y_min, y_max, alpha=0.5)
+                ax.fill_between(df[x], y_min, y_max, alpha=0.15, color=col)
         ax.legend()
         xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
         ax.set_xlabel(xlabel)
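
With this change, each entry of `styles` is a `(color, linestyle)` tuple, matched by index to `line_vals`/`line_names`. A minimal usage sketch (argument values are illustrative and mirror the tutorial benchmarks updated below):

    import triton

    bench = triton.testing.Benchmark(
        x_names=['N'],                            # argument used as the x-axis
        x_vals=[2**i for i in range(12, 20)],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        styles=[('blue', '-'), ('green', '--')],  # blue solid line, green dashed line
        ylabel='GB/s',
        plot_name='example-performance',
        args={},
    )
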

python/tutorials/01-vector-add.py

@@ -3,9 +3,9 @@ Vector Addition
 =================
 In this tutorial, you will write a simple vector addition using Triton and learn about:
-- The basic programming model used by Triton
-- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
-- The best practices for validating and benchmarking custom ops against native reference implementations
+- The basic programming model of Triton
+- The `triton.jit` decorator, which is used to define Triton kernels.
+- The best practices for validating and benchmarking your custom ops against native reference implementations
 """
 # %%
@@ -41,28 +41,28 @@ def _add(
 # %%
-# We can also declara a helper function that handles allocating the output vector
-# and enqueueing the kernel.
+# Let's also declare a helper function that (1) allocates the output vector
+# and (2) enqueues the above kernel.
 def add(x, y):
     z = torch.empty_like(x)
     N = z.shape[0]
-    # The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
+    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
     grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
     # NOTE:
-    # - torch.tensor objects are implicitly converted to pointers to their first element.
-    # - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
+    # - each torch.tensor object is implicitly converted into a pointer to its first element.
+    # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
     # - don't forget to pass meta-parameters as keyword arguments
     _add[grid](x, y, z, N, BLOCK=1024)
     # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
-    # running asynchronously.
+    # running asynchronously at this point.
     return z
 # %%
-# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:
+# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
 torch.manual_seed(0)
 size = 98432
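
The hunk above stops short of the actual check; the correctness test referenced by the next hunk looks roughly like this (a sketch that reuses the `add` wrapper defined above; the variable names `za`/`zb` are illustrative):

    torch.manual_seed(0)
    size = 98432
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    za = x + y        # reference result computed by PyTorch
    zb = add(x, y)    # result computed by the Triton kernel
    print(za)
    print(zb)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(za - zb))}')
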
@@ -81,7 +81,7 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
 # Benchmark
 # -----------
 # We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
-# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
+# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops
 # for different problem sizes.
@@ -91,8 +91,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
         x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_name`
         x_log=True,  # x axis is logarithmic
         line_arg='provider',  # argument name whose value corresponds to a different line in the plot
-        line_vals=['torch', 'triton'],  # possible values for `line_arg`
-        line_names=["Torch", "Triton"],  # label name for the lines
+        line_vals=['triton', 'torch'],  # possible values for `line_arg`
+        line_names=["Triton", "Torch"],  # label name for the lines
+        styles=[('blue', '-'), ('green', '-')],  # line styles
         ylabel="GB/s",  # label name for the y-axis
         plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={}  # values for function arguments not in `x_names` and `y_name`
@@ -112,4 +113,4 @@ def benchmark(size, provider):
 # %%
 # We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or
 # `save_path='/path/to/results/'` to save them to disk along with raw CSV data
-benchmark.run(show_plots=True)
+benchmark.run(print_data=True, show_plots=True)
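
The body of the decorated `benchmark` function is elided from this hunk; a sketch of what it typically contains, using `triton.testing.do_bench` (the 12-bytes-per-element factor assumes three float32 vectors are moved: the two inputs and the output):

    def benchmark(size, provider):
        # `provider` takes the values listed in `line_vals` above.
        x = torch.rand(size, device='cuda', dtype=torch.float32)
        y = torch.rand(size, device='cuda', dtype=torch.float32)
        if provider == 'torch':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
        # convert the measured times (in ms) into an effective bandwidth in GB/s
        gbps = lambda ms: 12 * size / ms * 1e-6
        return gbps(ms), gbps(max_ms), gbps(min_ms)
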

python/tutorials/02-fused-softmax.py

@@ -1,10 +1,11 @@
""" """
Fused Softmax Fused Softmax
================= =================
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about: In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
You will learn about:
- The benefits of kernel fusion for bandwidth-bound operations. - The benefits of kernel fusion for bandwidth-bound operations.
- The reduction operators in Triton. - Reduction operators in Triton.
""" """
# %% # %%
@@ -17,15 +18,16 @@ import torch
 # Compute the row-wise softmax of x
+@torch.jit.script
 def naive_softmax(x):
     # read MN elements ; write M elements
-    x_max = torch.max(x, axis=1)[0]
+    x_max = x.max(dim=1)[0]
     # read 2MN elements ; write MN elements
     z = x - x_max[:, None]
     # read MN elements ; write MN elements
     numerator = torch.exp(z)
     # read MN elements ; write M elements
-    denominator = torch.sum(numerator, axis=1)
+    denominator = numerator.sum(dim=1)
     # read 2MN elements ; write MN elements
     ret = numerator / denominator[:, None]
     # in total: read 7MN elements ; wrote 3MN + 2M elements
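
Summing the per-line annotations above gives the totals quoted in the next paragraph (a quick sanity check of the arithmetic):

    reads  = MN + 2MN + MN + MN + 2MN = 7MN       elements
    writes = M  + MN  + MN + M  + MN  = 3MN + 2M  elements

A fused kernel only needs to read X once (MN elements) and write Y once (MN elements), hence the quoted bound of (10MN + 2M) / 2MN = 5 + 1/N, i.e. roughly 5x for large N.
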
@@ -35,15 +37,15 @@ def naive_softmax(x):
 # %%
 # When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
 # This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
-# This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
-# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
+# Doing so would require reading and writing back only :math:`MN` elements, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
+# The `torch.jit.script` decorator aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
 # %%
 # Compute Kernel
 # ----------------
 # Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
 # Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
-# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
+# so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
 import triton
 import triton.language as tl
@@ -54,6 +56,7 @@ def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
     # row index
     m = tl.program_id(0)
     # col indices
+    # here BLOCK is the smallest power of two greater than `N`
     n = tl.arange(0, meta['BLOCK'])
     # the memory address of all the elements
     # that we want to load can be computed as follows
@@ -90,11 +93,10 @@ def softmax(x):
     M, N = x.shape
     # The block size is the smallest power of two greater than the number of columns in `x`
     BLOCK = next_power_of_2(N)
-    # Another trick we can use is to ask the compiler to parallelize each
-    # row-normalization more aggressively -- i.e., with more warps -- vectors
-    # that are longer
+    # Another trick we can use is to ask the compiler to use more threads per row by
+    # increasing the number of warps (`num_warps`) over which each row is distributed.
     # You will see in the next tutorial how to auto-tune this value in a more natural
-    # way so you don't have to come up with manual heuristics yourself
+    # way so you don't have to come up with manual heuristics yourself.
     num_warps = 4
     if BLOCK >= 2048: num_warps = 8
     if BLOCK >= 4096: num_warps = 16
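
The remainder of the wrapper is not shown in this hunk; it allocates the output and launches one program instance per row. A sketch, assuming the `_softmax(Y, X, stride_xm, stride_ym, M, N, **meta)` signature shown earlier:

        # (continues inside `softmax(x)`) allocate the output and enqueue one kernel instance per row
        y = torch.empty_like(x)
        _softmax[(M, )](
            y, x,                      # output / input tensors (implicitly converted to pointers)
            x.stride(0), y.stride(0),  # row strides, assumed to map to stride_xm / stride_ym
            M, N,
            num_warps=num_warps,       # compilation option picked by the heuristic above
            BLOCK=BLOCK,               # meta-parameter: padded (power-of-two) row length
        )
        return y
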
@@ -132,10 +134,11 @@ print(torch.allclose(y_tri, y_ref))
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=['N'],  # argument names to use as an x-axis for the plot
-        x_vals=[256 * i for i in range(2, 50)],  # different possible values for `x_name`
+        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
         line_arg='provider',  # argument name whose value corresponds to a different line in the plot
-        line_vals=['torch', 'triton', 'naive'],  # possible values for `line_arg`
-        line_names=["Torch", "Triton", 'Naive'],  # label name for the lines
+        line_vals=['triton', 'torch-native', 'torch-jit'],  # possible values for `line_arg`
+        line_names=["Triton", "Torch (native)", "Torch (jit)"],  # label name for the lines
+        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
         ylabel="GB/s",  # label name for the y-axis
         plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
@@ -143,11 +146,11 @@ print(torch.allclose(y_tri, y_ref))
 )
 def benchmark(M, N, provider):
     x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    if provider == 'torch':
+    if provider == 'torch-native':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
-    if provider == 'naive':
+    if provider == 'torch-jit':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
     gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
@@ -158,7 +161,7 @@ benchmark.run(show_plots=True, print_data=True)
 # %%
 # In the above plot, we can see that:
 #
-# - Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.
-# - Triton is significantly faster than :code:`torch.softmax` for very large input matrices. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
-# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of data necessary.
+# - Triton is 2-3x faster than the Torch JIT.
+# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
+# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
 # Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.

python/tutorials/03-matrix-multiplication.py

@@ -1,7 +1,7 @@
""" """
Matrix Multiplication Matrix Multiplication
====================== ======================
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs. In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
You will specifically learn about: You will specifically learn about:
- Block-level matrix multiplications - Block-level matrix multiplications
@@ -14,9 +14,9 @@ You will specifically learn about:
 # Motivations
 # -------------
 # Matrix multiplications are a key building block of most modern high-performance computing systems.
-# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
-# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
-# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
+# They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
+# Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
+# In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
 #
 # Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
 #
@@ -39,8 +39,8 @@ You will specifically learn about:
 # Compute Kernel
 # ----------------
 #
 # The above algorithm is actually fairly straightforward to implement in Triton.
-# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
+# The main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetic.
 #
 # Pointer Arithmetics
 # ~~~~~~~~~~~~~~~~~~~~
@@ -50,10 +50,10 @@ You will specifically learn about:
 #
 # .. code-block:: python
 #
-#    &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
-#    &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
+#    &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
+#    &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
 #
-# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
+# This means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
 #
 # .. code-block:: python
 #
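
As a quick sanity check of the address formula above, the same index arithmetic can be reproduced in NumPy (the sizes, names and row-major strides below are made up purely for illustration):

    import numpy as np

    M, K, BLOCK_M, BLOCK_K = 8, 8, 4, 2
    A = np.arange(M * K).reshape(M, K)     # row-major, so element strides are (K, 1)

    m, k = 2, 4                            # top-left corner of the block
    rm = np.arange(m, m + BLOCK_M)
    rk = np.arange(k, k + BLOCK_K)
    offsets = rm[:, None] * K + rk[None, :]

    # the computed offsets address exactly the [BLOCK_M, BLOCK_K] block
    assert (A.flat[offsets] == A[m:m + BLOCK_M, k:k + BLOCK_K]).all()
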
@@ -67,7 +67,7 @@ You will specifically learn about:
 #    // pointer for B operand
 #    pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
 #
-# These pointers can then be updated in the inner loop as:
+# These pointers can then be updated in the inner loop as follows:
 #
 # .. code-block:: python
 #
@@ -79,8 +79,8 @@ You will specifically learn about:
 # ~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
-# However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
-# This means that a naive row-major ordering:
+# It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
+# Unfortunately, a simple row-major ordering
 #
 # .. code-block:: Python
 #
@@ -90,7 +90,7 @@ You will specifically learn about:
 #    pid_m = pid / grid_n;
 #    pid_n = pid % grid_n;
 #
-# is unlikely to result in optimal performance.
+# is just not going to cut it.
 #
 # One possible solution is to launch blocks in an order that promotes data reuse.
 # This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
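
One way to implement this 'super-grouping' of program IDs, written as plain Python for illustration (variable names mirror the pseudo-code above; the kernel presumably performs the same arithmetic on `tl.program_id(0)`):

    def grouped_pid(pid, grid_m, grid_n, GROUP_M):
        # a 'super-group' covers GROUP_M rows of blocks, each grid_n blocks wide
        width = GROUP_M * grid_n
        group_id = pid // width
        # the last group may contain fewer than GROUP_M rows
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        # column-major order *within* the group keeps blocks that share operands close together
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size
        return pid_m, pid_n

    # e.g. with grid_m=6, grid_n=4, GROUP_M=2, consecutive pids 0..7 cover the
    # first two rows of blocks column by column: (0,0), (1,0), (0,1), (1,1), ...
    print([grouped_pid(pid, 6, 4, 2) for pid in range(8)])
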
@@ -122,22 +122,18 @@ import triton.language as tl
 # - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
 # - An autotuning *key* whose change in values will trigger evaluation of all the provided configs
-@triton.jit
-def sigmoid(x):
-    ret_true = 1 / (1 + tl.exp(-x))
-    ret_false = tl.exp(x) / (1 + tl.exp(x))
-    return tl.where(x >= 0, ret_true, ret_false)
-@triton.jit
-def swish(x):
-    return x * sigmoid(x)
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
         #triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
     ],
     key=['M', 'N', 'K'],
@@ -186,10 +182,14 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     tl.store(C, acc, mask=mask)
-# %%
-# We can also create a convenience wrapper function that only takes two input tensors
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
+# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
+@triton.jit
+def leaky_relu(x):
+    return tl.where(x >= 0, x, 0.01*x)
+# %%
+# We can now create a convenience wrapper function that only takes two input tensors
+# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
 def matmul(a, b, activation=None):
     # checks constraints
@@ -207,8 +207,7 @@ def matmul(a, b, activation=None):
         a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
         ACTIVATION = activation
     )
-    #print(pgm.asm('ttir'))
-    # return output
+    # done; return the output tensor
     return c
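
With the `ACTIVATION` meta-parameter above, fusing an element-wise activation into the matmul epilogue is just a matter of passing the `triton.jit`'ed function, as the updated benchmark further down does. A small usage sketch:

    a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    c = matmul(a, b)                             # plain FP16 matmul
    c_act = matmul(a, b, activation=leaky_relu)  # same kernel, with leaky_relu fused on-chip
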
@@ -216,17 +215,16 @@ def matmul(a, b, activation=None):
 # Unit Test
 # -----------
 #
-# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
+# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
-#torch.manual_seed(0)
-# a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-# b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-# c_0 = matmul(a, b, activation=None)
-# c_1 = torch.matmul(a, b)
-# print(c_0)
-# print(c_1)
-# print(triton.testing.allclose(c_0, c_1))
-# exit()
+torch.manual_seed(0)
+a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
+b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
+c_0 = matmul(a, b, activation=None)
+c_1 = torch.matmul(a, b)
+print(c_0)
+print(c_1)
+print(triton.testing.allclose(c_0, c_1))
 # %%
 # Benchmark
@@ -234,29 +232,34 @@ def matmul(a, b, activation=None):
 #
 # Square Matrix Performance
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~
-# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.
+# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
-        x_vals=[8192],  # different possible values for `x_name`
+        x_vals=[128 * i for i in range(1, 33)],  # different possible values for `x_name`
         line_arg='provider',  # argument name whose value corresponds to a different line in the plot
-        line_vals=['cublas', 'triton'],  # possible values for `line_arg`
-        line_names=["cuBLAS", "Triton"],  # label name for the lines
+        line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],  # possible values for `line_arg`
+        line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],  # label name for the lines
+        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],  # line styles
         ylabel="TFLOPS",  # label name for the y-axis
         plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={}
     )
 )
 def benchmark(M, N, K, provider):
-    silu = torch.nn.SiLU()
     a = torch.randn((M, K), device='cuda', dtype=torch.float16)
     b = torch.randn((K, N), device='cuda', dtype=torch.float16)
     if provider == 'cublas':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
+    if provider == 'cublas + relu':
+        torch_relu = torch.nn.ReLU(inplace=True)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
+    if provider == 'triton + relu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)