[DOCS] Improved documentation and integration in CI (#139)

This commit is contained in:
Philippe Tillet
2021-07-22 22:45:19 -07:00
committed by Philippe Tillet
parent 76c6f24fb6
commit b253b77c71
5 changed files with 124 additions and 84 deletions


@@ -1,7 +1,7 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-line high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
You will specifically learn about:
- Block-level matrix multiplications
@@ -14,9 +14,9 @@ You will specifically learn about:
# Motivations
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
# They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
# In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
#
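# In pseudo-code (a sketch; :code:`BLOCK_M`, :code:`BLOCK_N` and :code:`BLOCK_K` denote the block sizes used throughout this tutorial):
#
# .. code-block:: python
#
#    # every (m, n) output block can be computed independently (in parallel)
#    for m in range(0, M, BLOCK_M):
#      for n in range(0, N, BLOCK_N):
#        acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
#        for k in range(0, K, BLOCK_K):
#          acc += A[m : m+BLOCK_M, k : k+BLOCK_K] @ B[k : k+BLOCK_K, n : n+BLOCK_N]
#        C[m : m+BLOCK_M, n : n+BLOCK_N] = acc
#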
@@ -39,8 +39,8 @@ You will specifically learn about:
# Compute Kernel
# ----------------
#
# The above algorithm is actually fairly straightforward to implement in Triton.
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
# The above algorithm is, actually, fairly straightforward to implement in Triton.
# The main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetic.
#
# Pointer Arithmetic
# ~~~~~~~~~~~~~~~~~~~~
@@ -50,10 +50,10 @@ You will specifically learn about:
#
# .. code-block:: python
#
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
#
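# For intuition, the same stride-based offset computation can be checked in plain
# PyTorch (a small sketch that is not part of the kernel itself):
#
# .. code-block:: python
#
#    import torch
#    X = torch.arange(12).reshape(3, 4)           # strides are (4, 1)
#    rm = torch.arange(0, 2)[:, None]             # rows  m : m+BLOCK_M
#    rk = torch.arange(0, 3)[None, :]             # cols  k : k+BLOCK_K
#    offs = rm * X.stride(0) + rk * X.stride(1)   # [BLOCK_M, BLOCK_K] flat offsets
#    assert torch.equal(X.flatten()[offs], X[0:2, 0:3])
#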
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
# Which means that, at :code:`k=0`, pointers for blocks of A and B can be initialized in Triton as:
#
# .. code-block:: python
#
@@ -67,7 +67,7 @@ You will specifically learn about:
# // pointer for B operand
# pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
#
# These pointers can then be updated in the inner loop as:
# And then updated in the inner loop as follows:
#
# .. code-block:: python
#
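#    # one possible update, consistent with the initialization above: each iteration
#    # advances both blocks by one BLOCK_K step along the K dimension
#    pa += BLOCK_K * stride_a_1;
#    pb += BLOCK_K * stride_b_0;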
@@ -79,8 +79,8 @@ You will specifically learn about:
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
# However, the order in which these blocks are computed matters, since it affects the L2 cache hit rate of our program.
# This means that a naive row-major ordering:
# It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
# And unfortunately, a simple row-major ordering
#
# .. code-block:: Python
#
@@ -90,7 +90,7 @@ You will specifically learn about:
# pid_m = pid / grid_n;
# pid_n = pid % grid_n;
#
# is unlikely to result in optimal performance.
# is just not going to cut it.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
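# One way to write such a grouped ordering is sketched below (assuming :code:`grid_m` and :code:`grid_n` hold the number of blocks along the M and N axes, as above):
#
# .. code-block:: Python
#
#    pid = tl.program_id(0);
#    width = GROUP_M * grid_n;
#    group_id = pid // width;
#    # the last group may contain fewer than GROUP_M rows of blocks
#    group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
#    pid_m = group_id * GROUP_M + (pid % group_size);
#    pid_n = (pid % width) // group_size;
#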
@@ -122,22 +122,18 @@ import triton.language as tl
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
# - An autotuning *key* whose change in values will trigger evaluation of all the provided configs
@triton.jit
def sigmoid(x):
ret_true = 1 / (1 + tl.exp(-x))
ret_false = tl.exp(x) / (1 + tl.exp(x))
return tl.where(x >= 0, ret_true, ret_false)
@triton.jit
def swish(x):
return x * sigmoid(x)
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
],
key=['M', 'N', 'K'],
@@ -186,10 +182,14 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
tl.store(C, acc, mask=mask)
# %%
# We can also create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`:
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01*x)
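# it can then be passed to the wrapper defined below, e.g. `matmul(a, b, activation=leaky_relu)`,
# as done in the benchmark at the end of this tutorial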
# %%
# We can now create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
def matmul(a, b, activation=None):
# checks constraints
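# roughly: assert that a.shape[1] == b.shape[0] and that both inputs are contiguous,
# read M, K, N off the shapes, then allocate c = torch.empty((M, N), device=a.device, dtype=a.dtype)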
@@ -207,8 +207,7 @@ def matmul(a, b, activation=None):
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
ACTIVATION = activation
)
#print(pgm.asm('ttir'))
# return output
# done; return the output tensor
return c
@@ -216,17 +215,16 @@ def matmul(a, b, activation=None):
# Unit Test
# -----------
#
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
#torch.manual_seed(0)
# a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# c_0 = matmul(a, b, activation=None)
# c_1 = torch.matmul(a, b)
# print(c_0)
# print(c_1)
# print(triton.testing.allclose(c_0, c_1))
# exit()
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=None)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))
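# if the kernel is correct, the two matrices printed above should match and the final
# `allclose` check should succeed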
# %%
# Benchmark
@@ -234,31 +232,36 @@ def matmul(a, b, activation=None):
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.
# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[8192], # different possible values for `x_name`
x_vals=[128 * i for i in range(1, 33)], # different possible values for `x_name`
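# i.e., square matrices with M = N = K ranging from 128 to 4096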
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['cublas', 'triton'], # possible values for `line_arg`
line_names=["cuBLAS", "Triton"], # label name for the lines
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], # possible values for `line_arg`
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], # label name for the lines
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], # line styles
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={}
)
)
def benchmark(M, N, K, provider):
silu = torch.nn.SiLU()
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
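# a matmul performs 2 * M * N * K floating-point operations (one multiply and one add
# per term of each dot-product); `ms` is in milliseconds, hence the conversion to TFLOP/s below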
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True, print_data=True)