[GH-PAGES] Updated website
@@ -15,7 +15,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.\nYou will learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- Reduction operators in Triton.\n"
|
||||
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation that is significantly faster\nthan PyTorch's native op for a particular class of matrices: those whose rows can fit in\nthe GPU's SRAM.\nYou will learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- Reduction operators in Triton.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -33,21 +33,21 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n\n\n# Compute the row-wise softmax of x\n@torch.jit.script\ndef naive_softmax(x):\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read 2MN elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(z)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read 2MN elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 7MN elements ; wrote 3MN + 2M elements\n return ret"
|
||||
"import torch\n\n\n@torch.jit.script\ndef naive_softmax(x):\n \"\"\"Compute row-wise softmax of X using native pytorch\n\n We subtract the maximum element in order to avoid overflows. Softmax is invariant to\n this shift.\n \"\"\"\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read 2MN elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(z)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read 2MN elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 7MN elements ; wrote 3MN + 2M elements\n return ret"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads X once and does all the necessary computations on-chip.\nDoing so would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nThe `torch.jit.script` flags aims to perform this kind of \"kernel fusion\" automatically but, as we will see later, it is still far from ideal.\n\n"
|
||||
"When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$\nrequires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads\nX once and does all the necessary computations on-chip.\nDoing so would require reading and writing back only $MN$ bytes, so we could\nexpect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nThe `torch.jit.script` flags aims to perform this kind of \"kernel fusion\" automatically\nbut, as we will see later, it is still far from ideal.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a power-of-two number of elements,\nso we need to internally \"pad\" each row and guard the memory operations properly if we want to handle any possible input shapes:\n\n"
|
||||
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input matrix X,\nnormalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a\npower-of-two number of elements, so we need to internally \"pad\" each row and guard the\nmemory operations properly if we want to handle any possible input shapes:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -58,7 +58,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):\n # row index\n m = tl.program_id(0)\n # col indices\n # here BLOCK is the smallest power of two greater than `N`\n n = tl.arange(0, meta['BLOCK'])\n # the memory address of all the elements\n # that we want to load can be computed as follows\n X = X + m * stride_xm + n\n x = tl.load(X, mask=n < N, other=-float('inf'))\n # Substract maximum for numerical stability\n z = x - tl.max(x, axis=0)\n # Note that exponentials in Triton are fast\n # but approximate (i.e., think __expf in CUDA)\n num = tl.exp(z)\n denom = tl.sum(num, axis=0)\n y = num / denom\n # Write back to Y\n Y = Y + m * stride_ym + n\n tl.store(Y, y, mask=n < N)"
|
||||
"import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef softmax_kernel(\n output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta\n):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n BLOCK_SIZE = meta['BLOCK_SIZE']\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n # Substract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -76,7 +76,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\ndef softmax(x):\n M, N = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK = next_power_of_2(N)\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 4\n if BLOCK >= 2048: num_warps = 8\n if BLOCK >= 4096: num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix\n _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)\n return y"
|
||||
"def next_power_of_2(n):\n \"\"\"Return the smallest power of 2 greater than or equal to n\"\"\"\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\ndef softmax(x):\n n_rows, n_cols = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK_SIZE = next_power_of_2(n_cols)\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o\n # f the input matrix\n softmax_kernel[(n_rows,)](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return y"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -101,7 +101,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\nx = torch.randn(1823, 781, device='cuda')\ny_tri = softmax(x)\ny_ref = torch.softmax(x, axis=1)\nprint(torch.allclose(y_tri, y_ref))"
|
||||
"torch.manual_seed(0)\nx = torch.randn(1823, 781, device='cuda')\ny_triton = softmax(x)\ny_torch = torch.softmax(x, axis=1)\nprint(torch.allclose(y_triton, y_torch))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -126,14 +126,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'], # argument names to use as an x-axis for the plot\n x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``\n line_names=[\"Triton\", \"Torch (native)\", \"Torch (jit)\"], # label name for the lines\n styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles\n ylabel=\"GB/s\", # label name for the y-axis\n plot_name=\"softmax-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={'M': 4096} # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(M, N, provider):\n x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n if provider == 'torch-native':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))\n if provider == 'torch-jit':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))\n gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbenchmark.run(show_plots=True, print_data=True)"
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'], # argument names to use as an x-axis for the plot\n x_vals=[\n 128 * i for i in range(2, 100)\n ], # different possible values for `x_name`\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=[\n 'triton',\n 'torch-native',\n 'torch-jit',\n ], # possible values for `line_arg``\n line_names=[\n \"Triton\",\n \"Torch (native)\",\n \"Torch (jit)\",\n ], # label name for the lines\n styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles\n ylabel=\"GB/s\", # label name for the y-axis\n plot_name=\"softmax-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(M, N, provider):\n x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n if provider == 'torch-native':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))\n if provider == 'torch-jit':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))\n gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbenchmark.run(show_plots=True, print_data=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In the above plot, we can see that:\n\n - Triton is 2-3x faster than the Torch JIT.\n - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.\n This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.\n Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.\n"
|
||||
"In the above plot, we can see that:\n\n - Triton is 2-3x faster than the Torch JIT.\n - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.\n This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.\n Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.\n\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@@ -13,31 +13,37 @@ In this tutorial, you will write a simple vector addition using Triton and learn
|
||||
# --------------------------
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(
|
||||
X, # *Pointer* to first input vector
|
||||
Y, # *Pointer* to second input vector
|
||||
Z, # *Pointer* to output vector
|
||||
N, # Size of the vector
|
||||
**meta # Optional meta-parameters for the kernel
|
||||
def add_kernel(
|
||||
x_ptr, # *Pointer* to first input vector
|
||||
y_ptr, # *Pointer* to second input vector
|
||||
output_ptr, # *Pointer* to output vector
|
||||
n_elements, # Size of the vector
|
||||
**meta, # Optional meta-parameters for the kernel
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
# Create an offset for the blocks of pointers to be
|
||||
# processed by this program instance
|
||||
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
|
||||
# Create a mask to guard memory operations against
|
||||
# out-of-bounds accesses
|
||||
mask = offsets < N
|
||||
# Load x
|
||||
x = tl.load(X + offsets, mask=mask)
|
||||
y = tl.load(Y + offsets, mask=mask)
|
||||
# Write back x + y
|
||||
z = x + y
|
||||
tl.store(Z + offsets, z)
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
|
||||
# There are multiple 'programs' processing different data. We identify which program
|
||||
# we are here
|
||||
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
|
||||
# This program will process inputs that are offset from the initial data.
|
||||
# For instance, if you had a vector of length 256 and block_size of 64, the programs
|
||||
# would each access the elements [0:64, 64:128, 128:192, 192:256].
|
||||
# Note that offsets is a list of pointers
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
# Create a mask to guard memory operations against out-of-bounds accesses
|
||||
mask = offsets < n_elements
|
||||
# Load x and y from DRAM, masking out any extra elements in case the input is not a
|
||||
# multiple of the block size
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
# Write x + y back to DRAM
|
||||
tl.store(output_ptr + offsets, output)
|
||||
|
||||
|
||||
# %%
|
||||
@@ -45,20 +51,23 @@ def _add(
|
||||
# and (2) enqueue the above kernel with appropriate grid/block sizes.
|
||||
|
||||
|
||||
def add(x, y):
|
||||
z = torch.empty_like(x)
|
||||
N = z.shape[0]
|
||||
def add(x: torch.Tensor, y: torch.Tensor):
|
||||
# We need to preallocate the output
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_cuda and y.is_cuda and output.is_cuda
|
||||
n_elements = output.shape[0]
|
||||
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
|
||||
# In this case, we use a 1D grid where the size is the number of blocks
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
# NOTE:
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keyword arguments
|
||||
_add[grid](x, y, z, N, BLOCK=1024)
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
|
||||
# We return a handle to the output but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously at this point.
|
||||
return z
|
||||
return output
|
||||
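# %
# As a quick aside (an illustrative sketch, not part of the original tutorial):
# the two accepted grid forms launch identically here, but the callable form
# adapts automatically to whatever meta-parameters the kernel is compiled with.
#
# .. code-block:: python
#
#     n_elements = 98432
#     # static form: the block size is baked into the grid
#     grid_static = (triton.cdiv(n_elements, 1024),)
#     # callable form: the grid is recomputed from the meta-parameters
#     grid_dynamic = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
#     # either one can be used to index the jit'ed kernel:
#     # add_kernel[grid_static](x, y, output, n_elements, BLOCK_SIZE=1024)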
|
||||
|
||||
# %%
|
||||
@@ -68,11 +77,14 @@ torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
y = torch.rand(size, device='cuda')
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(za)
|
||||
print(zb)
|
||||
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
|
||||
output_torch = x + y
|
||||
output_triton = add(x, y)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
# %%
|
||||
# Seems like we're good to go!
|
||||
@@ -88,15 +100,17 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['size'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
|
||||
x_vals=[
|
||||
2 ** i for i in range(12, 28, 1)
|
||||
], # different possible values for `x_name`
|
||||
x_log=True, # x axis is logarithmic
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['triton', 'torch'], # possible values for `line_arg`
|
||||
line_names=["Triton", "Torch"], # label name for the lines
|
||||
line_names=['Triton', 'Torch'], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={} # values for function arguments not in `x_names` and `y_name`
|
||||
ylabel='GB/s', # label name for the y-axis
|
||||
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(size, provider):
|
||||
|
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
Matrix Multiplication
|
||||
======================
|
||||
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
|
||||
In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
|
||||
kernel that achieves performance on par with cuBLAS.
|
||||
You will specifically learn about:
|
||||
|
||||
- Block-level matrix multiplications
|
||||
@@ -14,24 +15,28 @@ You will specifically learn about:
|
||||
# Motivations
|
||||
# -------------
|
||||
# Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
# They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
# In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
|
||||
# They are notoriously hard to optimize, hence their implementation is generally done by
|
||||
# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
# Unfortunately, these libraries are often proprietary and cannot be easily customized
|
||||
# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
# In this tutorial, you will learn how to implement efficient matrix multiplications by
|
||||
# yourself with Triton, in a way that is easy to customize and extend.
|
||||
#
|
||||
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
|
||||
# Roughly speaking, the kernel that we will write will implement the following blocked
|
||||
# algorithm to multiply a (MxK) by a (KxN) matrix:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
# # do in parallel
|
||||
# for m in range(0, M, BLOCK_M):
|
||||
# for m in range(0, M, BLOCK_SIZE_M):
|
||||
# # do in parallel
|
||||
# for n in range(0, N, BLOCK_N):
|
||||
# acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
|
||||
# for k in range(0, K, BLOCK_K):
|
||||
# a = A[m : m+BLOCK_M, k : k+BLOCK_K]
|
||||
# b = B[k : k+BLOCK_K, n : n+BLOCK_N]
|
||||
# for n in range(0, N, BLOCK_SIZE_N):
|
||||
# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
|
||||
# for k in range(0, K, BLOCK_SIZE_K):
|
||||
# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
|
||||
# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
|
||||
# acc += dot(a, b)
|
||||
# C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
|
||||
# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
|
||||
#
|
||||
# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
|
||||
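#
# For example (illustrative numbers, not from the original text): with M = N = 1024
# and BLOCK_SIZE_M = BLOCK_SIZE_N = 128, the doubly-nested loop above corresponds to
# (1024 // 128) * (1024 // 128) = 64 program instances, each producing one
# 128x128 tile of C.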
|
||||
@@ -40,18 +45,22 @@ You will specifically learn about:
|
||||
# ----------------
|
||||
#
|
||||
# The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
# The main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.
|
||||
# The main difficulty comes from the computation of the memory locations at which blocks
|
||||
# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
|
||||
# multi-dimensional pointer arithmetics.
|
||||
#
|
||||
# Pointer Arithmetics
|
||||
# ~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
||||
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
|
||||
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by
|
||||
# :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
||||
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
|
||||
# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
|
||||
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
|
||||
# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
|
||||
# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
|
||||
#
|
||||
# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
||||
#
|
||||
@@ -59,9 +68,9 @@ You will specifically learn about:
|
||||
#
|
||||
# pid_m = triton.program_id(0)
|
||||
# pid_n = triton.program_id(1)
|
||||
# rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
# rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
# rk = triton.arange(0, BLOCK_K)
|
||||
# rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
|
||||
# rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
|
||||
# rk = triton.arange(0, BLOCK_SIZE_K)
|
||||
# // pointer for A operand
|
||||
# pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
|
||||
# // pointer for B operand
|
||||
@@ -71,41 +80,51 @@ You will specifically learn about:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
# pa += BLOCK_K * stride_a_1;
|
||||
# pb += BLOCK_K * stride_b_0;
|
||||
# pa += BLOCK_SIZE_K * stride_a_1;
|
||||
# pb += BLOCK_SIZE_K * stride_b_0;
|
||||
#
|
||||
#
|
||||
# L2 Cache Optimizations
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
|
||||
# It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
|
||||
# And unfortunately, a simple row-major ordering
|
||||
# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
|
||||
# block of :code:`C`.
|
||||
# It is important to remember that the order in which these blocks are computed does
|
||||
# matter, since it affects the L2 cache hit rate of our program. And unfortunately,
|
||||
# a simple row-major ordering
|
||||
#
|
||||
# .. code-block:: Python
|
||||
#
|
||||
# pid = triton.program_id(0);
|
||||
# grid_m = (M + BLOCK_M - 1) // BLOCK_M;
|
||||
# grid_n = (N + BLOCK_N - 1) // BLOCK_N;
|
||||
# grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
|
||||
# grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
|
||||
# pid_m = pid / grid_n;
|
||||
# pid_n = pid % grid_n;
|
||||
#
|
||||
# is just not going to cut it.
|
||||
#
|
||||
# One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
||||
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before
|
||||
# switching to the next column:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
# pid = triton.program_id(0);
|
||||
# width = GROUP_M * grid_n;
|
||||
# group_id = pid // width;
|
||||
# # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
|
||||
# # we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0
|
||||
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
||||
# pid_m = group_id * GROUP_M + (pid % group_size);
|
||||
# pid_n = (pid % width) // (group_size);
|
||||
|
||||
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||
# we can see that if we compute the output in row-major ordering, we need to load 90
|
||||
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||||
# ordering, we only need to load 54 blocks.
|
||||
# .. image:: grouped_vs_row_major_ordering.png
|
||||
#
|
||||
# In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||
# In practice, this can improve the performance of our matrix multiplication kernel by
|
||||
# more than 10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
|
||||
#
|
||||
|
||||
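# %
# The 90-versus-54 count above can be reproduced with a few lines of plain Python
# (an illustrative sketch, not part of the original tutorial; GROUP_M = 3 and the
# helper names are assumptions made for this example):
#
# .. code-block:: python
#
#     GRID_M = GRID_N = 9  # the output is 9 blocks by 9 blocks
#     GROUP_M = 3          # 'super-grouping' factor assumed for illustration
#
#     def row_major(pid):
#         return pid // GRID_N, pid % GRID_N
#
#     def grouped(pid):
#         width = GROUP_M * GRID_N
#         group_id = pid // width
#         group_size = min(GRID_M - group_id * GROUP_M, GROUP_M)
#         return group_id * GROUP_M + pid % group_size, (pid % width) // group_size
#
#     def blocks_loaded(order, n_programs=9):
#         loads = set()
#         for pid in range(n_programs):
#             pid_m, pid_n = order(pid)
#             # each program reads one row-panel of A (9 blocks along K)
#             # and one column-panel of B (9 blocks along K)
#             loads |= {('A', pid_m, k) for k in range(9)}
#             loads |= {('B', k, pid_n) for k in range(9)}
#         return len(loads)
#
#     print(blocks_loaded(row_major))  # 90
#     print(blocks_loaded(grouped))    # 54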
# %%
|
||||
@@ -118,96 +137,165 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
# %
|
||||
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
|
||||
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
|
||||
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
|
||||
# decorator, which consumes:
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of
|
||||
# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
|
||||
# - An autotuning *key* whose change in values will trigger evaluation of all the
|
||||
# provided configs
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\
|
||||
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
|
||||
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
# %
|
||||
# We can now define our kernel as normal, using all the techniques presented above
|
||||
@triton.jit
|
||||
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
|
||||
def matmul_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
# Matrix dimensions
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
||||
# by to get the element one row down (A has M rows)
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
**meta,
|
||||
):
|
||||
"""Kernel for computing the matmul AB = C
|
||||
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
# extract meta-parameters
|
||||
BLOCK_M = META['BLOCK_M']
|
||||
BLOCK_N = META['BLOCK_N']
|
||||
BLOCK_K = META['BLOCK_K']
|
||||
GROUP_M = 8
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
rk = tl.arange(0, BLOCK_K)
|
||||
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(K, 0, -BLOCK_K):
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * stride_ak
|
||||
B += BLOCK_K * stride_bk
|
||||
# triton can accept arbitrary activation function
|
||||
# via metaparameters!
|
||||
if META['ACTIVATION']:
|
||||
acc = META['ACTIVATION'](acc)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm[:, None] < M) & (rn[None, :] < N)
|
||||
tl.store(C, acc, mask=mask)
|
||||
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
|
||||
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
||||
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
|
||||
GROUP_SIZE_M = 8
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
# The number of blocks is ceil(M / BLOCK_SIZE_M), since we need an extra block when M is not a multiple of BLOCK_SIZE_M.
|
||||
# Note that this will lead to some quantization in performance, where the time taken jumps
|
||||
# when you need to add a new block
|
||||
n_blocks_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
|
||||
n_blocks_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
|
||||
|
||||
# Map PIDs to the block they should compute. This is done in a grouped ordering
|
||||
# to promote L2 cache reuse.
|
||||
n_output_blocks_in_group = GROUP_SIZE_M * n_blocks_n
|
||||
group_id = pid // n_output_blocks_in_group
|
||||
first_m_block_in_group = group_id * GROUP_SIZE_M
|
||||
|
||||
# If the number of blocks is not divisible by the group size, the last group is smaller
|
||||
group_size_m = min(n_blocks_m - first_m_block_in_group, GROUP_SIZE_M)
|
||||
|
||||
# Within a group, we compute in col-major ordering; block_m and block_n are the
|
||||
# output row and col that this program is computing in terms of blocks
|
||||
block_m = first_m_block_in_group + (pid % group_size_m)
|
||||
block_n = (pid % n_output_blocks_in_group) // group_size_m
|
||||
|
||||
# Convert from block indices back to element indices
|
||||
m_start = block_m * BLOCK_SIZE_M
|
||||
n_start = block_n * BLOCK_SIZE_N
|
||||
|
||||
# Expand out to all the offsets for each of the elements in this block.
|
||||
m_offsets_a = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
||||
n_offsets_b = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
||||
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
# Get the pointers for the first block of each. We will advance this pointer
|
||||
# as we move in the K direction and accumulate.
|
||||
# a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers
|
||||
a_ptrs = a_ptr + (stride_am * m_offsets_a + stride_ak * k_offsets[None, :])
|
||||
# b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers
|
||||
b_ptrs = b_ptr + (stride_bk * k_offsets[:, None] + stride_bn * n_offsets_b)
|
||||
# We accumulate internally in fp32, but the output is written out in the dtype
|
||||
# of the tensor when it is stored
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
# Note that for simplicity, we don't apply a mask here. This means that if K is
|
||||
# not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and
|
||||
# accumulate it incorrectly.
|
||||
a = tl.load(a_ptrs)
|
||||
b = tl.load(b_ptrs)
|
||||
# We accumulate along the K dimension
|
||||
accumulator += tl.dot(a, b)
|
||||
|
||||
# Advance the ptrs to the next K block
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
# Triton can accept an arbitrary activation function via meta-parameters!
|
||||
if meta['ACTIVATION']:
|
||||
accumulator = meta['ACTIVATION'](accumulator)
|
||||
|
||||
m_offsets_c = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
||||
n_offsets_c = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
||||
c_ptrs = c_ptr + stride_cm * m_offsets_c + stride_cn * n_offsets_c
|
||||
mask = (m_offsets_c < M) & (n_offsets_c < N)
|
||||
tl.store(c_ptrs, accumulator, mask=mask)
|
||||
|
||||
|
||||
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`
|
||||
@triton.jit
|
||||
def leaky_relu(x):
|
||||
return tl.where(x >= 0, x, 0.01*x)
|
||||
return tl.where(x >= 0, x, 0.01 * x)
|
||||
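# %
# Any other `triton.jit`'ed elementwise function could be fused in the same way;
# for instance (an illustrative sketch, not part of the original tutorial):
#
# .. code-block:: python
#
#     @triton.jit
#     def squared_relu(x):
#         x = tl.where(x >= 0, x, 0.)
#         return x * x
#
#     # c = matmul(a, b, activation=squared_relu)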
|
||||
|
||||
# %%
|
||||
# We can now create a convenience wrapper function that only takes two input tensors
|
||||
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
||||
|
||||
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
assert a.is_contiguous(), "matrix A must be contiguous"
|
||||
assert b.is_contiguous(), "matrix B must be contiguous"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
K, N = b.shape
|
||||
assert (
|
||||
K % 32 == 0
|
||||
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
pgm = _matmul[grid](
|
||||
a, b, c, M, N, K, \
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||
ACTIVATION = activation
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
||||
)
|
||||
matmul_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
ACTIVATION=activation,
|
||||
)
|
||||
# done; return the output tensor
|
||||
return c
|
||||
|
||||
|
||||
@@ -220,11 +308,14 @@ def matmul(a, b, activation=None):
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
c_0 = matmul(a, b, activation=None)
|
||||
c_1 = torch.matmul(a, b)
|
||||
print(c_0)
|
||||
print(c_1)
|
||||
print(triton.testing.allclose(c_0, c_1))
|
||||
triton_output = matmul(a, b, activation=None)
|
||||
torch_output = torch.matmul(a, b)
|
||||
print(f"{triton_output=}")
|
||||
print(f"{torch_output=}")
|
||||
if triton.testing.allclose(triton_output, torch_output):
|
||||
print("✅ Triton and Torch match")
|
||||
else:
|
||||
print("❌ Triton and Torch differ")
|
||||
|
||||
# %%
|
||||
# Benchmark
|
||||
@@ -238,14 +329,19 @@ print(triton.testing.allclose(c_0, c_1))
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[128 * i for i in range(1, 33)], # different possible values for `x_name`
|
||||
x_vals=[
|
||||
128 * i for i in range(1, 33)
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], # possible values for `line_arg``
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], # label name for the lines
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], # line styles
|
||||
# possible values for `line_arg``
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
@@ -257,9 +353,13 @@ def benchmark(M, N, K, provider):
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
if provider == 'cublas + relu':
|
||||
torch_relu = torch.nn.ReLU(inplace=True)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch_relu(torch.matmul(a, b))
|
||||
)
|
||||
if provider == 'triton + relu':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b, activation=leaky_relu)
|
||||
)
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster
|
||||
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
|
||||
the GPU's SRAM.
|
||||
You will learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
@@ -17,9 +19,13 @@ You will learn about:
|
||||
import torch
|
||||
|
||||
|
||||
# Compute the row-wise softmax of x
|
||||
@torch.jit.script
|
||||
def naive_softmax(x):
|
||||
"""Compute row-wise softmax of X using native pytorch
|
||||
|
||||
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
|
||||
this shift.
|
||||
"""
|
||||
# read MN elements ; write M elements
|
||||
x_max = x.max(dim=1)[0]
|
||||
# read 2MN elements ; write MN elements
|
||||
@@ -35,43 +41,54 @@ def naive_softmax(x):
|
||||
|
||||
|
||||
# %%
|
||||
# When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
|
||||
# Doing so would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
|
||||
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
|
||||
# requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
|
||||
# X once and does all the necessary computations on-chip.
|
||||
# Doing so would require reading and writing back only :math:`MN` elements, so we could
|
||||
# expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
# The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
|
||||
# but, as we will see later, it is still far from ideal.
|
||||
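# %
# To make the arithmetic above concrete (illustrative numbers, not from the
# original text): for :math:`M = 4096` and :math:`N = 1024`,
#
# .. code-block:: python
#
#     M, N = 4096, 1024
#     naive_traffic = 7 * M * N + 3 * M * N + 2 * M  # elements read + written
#     fused_traffic = 2 * M * N                      # read X once, write Y once
#     print(naive_traffic / fused_traffic)           # ~5.001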
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# ----------------
|
||||
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
# so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
# Our softmax kernel works as follows: each program loads a row of the input matrix X,
|
||||
# normalizes it and writes back the result to the output Y.
|
||||
# Note that one important limitation of Triton is that each block must have a
|
||||
# power-of-two number of elements, so we need to internally "pad" each row and guard the
|
||||
# memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
|
||||
# row index
|
||||
m = tl.program_id(0)
|
||||
# col indices
|
||||
# here BLOCK is the smallest power of two greater than `N`
|
||||
n = tl.arange(0, meta['BLOCK'])
|
||||
# the memory address of all the elements
|
||||
# that we want to load can be computed as follows
|
||||
X = X + m * stride_xm + n
|
||||
x = tl.load(X, mask=n < N, other=-float('inf'))
|
||||
def softmax_kernel(
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
|
||||
):
|
||||
# The rows of the softmax are independent, so we parallelize across those
|
||||
row_idx = tl.program_id(0)
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE']
|
||||
# The stride represents how much we need to increase the pointer to advance 1 row
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
|
||||
# The block size is the next power of two greater than n_cols, so we can fit each
|
||||
# row in a single block
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + col_offsets
|
||||
# Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
|
||||
# Subtract maximum for numerical stability
|
||||
z = x - tl.max(x, axis=0)
|
||||
# Note that exponentials in Triton are fast
|
||||
# but approximate (i.e., think __expf in CUDA)
|
||||
num = tl.exp(z)
|
||||
denom = tl.sum(num, axis=0)
|
||||
y = num / denom
|
||||
# Write back to Y
|
||||
Y = Y + m * stride_ym + n
|
||||
tl.store(Y, y, mask=n < N)
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
# Write back output to DRAM
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
output_ptrs = output_row_start_ptr + col_offsets
|
||||
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
|
||||
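# %
# A toy illustration of the padding-and-masking scheme above (plain Python, not
# Triton; the numbers are assumptions made for this example): a row of length 6
# is padded to BLOCK_SIZE = 8, and the two out-of-bounds lanes are filled with
# -inf so that they contribute exp(-inf) = 0 to the row's sum:
#
# .. code-block:: python
#
#     n_cols, BLOCK_SIZE = 6, 8
#     col_offsets = list(range(BLOCK_SIZE))
#     mask = [c < n_cols for c in col_offsets]  # [True]*6 + [False]*2
#     # lanes where the mask is False are loaded as -inf, and
#     # exp(-inf - row_max) == 0.0, so they do not affect the result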
|
||||
|
||||
# %%
|
||||
@@ -79,6 +96,7 @@ def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
"""Return the smallest power of 2 greater than or equal to n"""
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
@@ -90,20 +108,31 @@ def next_power_of_2(n):
|
||||
|
||||
|
||||
def softmax(x):
|
||||
M, N = x.shape
|
||||
n_rows, n_cols = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
BLOCK = next_power_of_2(N)
|
||||
BLOCK_SIZE = next_power_of_2(n_cols)
|
||||
# Another trick we can use is to ask the compiler to use more threads per row by
|
||||
# increasing the number of warps (`num_warps`) over which each row is distributed.
|
||||
# You will see in the next tutorial how to auto-tune this value in a more natural
|
||||
# way so you don't have to come up with manual heuristics yourself.
|
||||
num_warps = 4
|
||||
if BLOCK >= 2048: num_warps = 8
|
||||
if BLOCK >= 4096: num_warps = 16
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
# Allocate output
|
||||
y = torch.empty_like(x)
|
||||
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
|
||||
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
|
||||
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of
|
||||
# the input matrix
|
||||
softmax_kernel[(n_rows,)](
|
||||
y,
|
||||
x,
|
||||
x.stride(0),
|
||||
y.stride(0),
|
||||
n_cols,
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
@@ -117,9 +146,9 @@ def softmax(x):
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(1823, 781, device='cuda')
|
||||
y_tri = softmax(x)
|
||||
y_ref = torch.softmax(x, axis=1)
|
||||
print(torch.allclose(y_tri, y_ref))
|
||||
y_triton = softmax(x)
|
||||
y_torch = torch.softmax(x, axis=1)
|
||||
print(torch.allclose(y_triton, y_torch))
|
||||
|
||||
#%%
|
||||
# As expected, the results are identical.
|
||||
@@ -134,14 +163,24 @@ print(torch.allclose(y_tri, y_ref))
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 100)
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``
|
||||
line_names=["Triton", "Torch (native)", "Torch (jit)"], # label name for the lines
|
||||
line_vals=[
|
||||
'triton',
|
||||
'torch-native',
|
||||
'torch-jit',
|
||||
], # possible values for `line_arg``
|
||||
line_names=[
|
||||
"Triton",
|
||||
"Torch (native)",
|
||||
"Torch (jit)",
|
||||
], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
|
||||
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
|
@@ -33,7 +33,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport triton.language as tl\nimport triton\n\n\n@triton.jit\ndef _add(\n X, # *Pointer* to first input vector\n Y, # *Pointer* to second input vector\n Z, # *Pointer* to output vector\n N, # Size of the vector\n **meta # Optional meta-parameters for the kernel\n):\n pid = tl.program_id(0)\n # Create an offset for the blocks of pointers to be\n # processed by this program instance\n offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])\n # Create a mask to guard memory operations against\n # out-of-bounds accesses\n mask = offsets < N\n # Load x\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n # Write back x + y\n z = x + y\n tl.store(Z + offsets, z)"
|
||||
"import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n **meta, # Optional meta-parameters for the kernel\n):\n BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process\n # There are multiple 'program's processing different data. We identify which program\n # we are here\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n # This program will process inputs that are offset from the initial data.\n # for instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extar elements in case the input is not a\n # multiple of the block size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM\n tl.store(output_ptr + offsets, output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -51,7 +51,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def add(x, y):\n z = torch.empty_like(x)\n N = z.shape[0]\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )\n # NOTE:\n # - each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n _add[grid](x, y, z, N, BLOCK=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return z"
|
||||
"def add(x: torch.Tensor, y: torch.Tensor):\n # We need to preallocate the output\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.shape[0]\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n # In this case, we use a 1D grid where the size is the number of blocks\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n # NOTE:\n # - each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return output"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -69,7 +69,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')"
|
||||
"torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -94,7 +94,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['size'], # argument names to use as an x-axis for the plot\n x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`\n x_log=True, # x axis is logarithmic\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['triton', 'torch'], # possible values for `line_arg`\n line_names=[\"Triton\", \"Torch\"], # label name for the lines\n styles=[('blue', '-'), ('green', '-')], # line styles\n ylabel=\"GB/s\", # label name for the y-axis\n plot_name=\"vector-add-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={} # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(size, provider):\n x = torch.rand(size, device='cuda', dtype=torch.float32)\n y = torch.rand(size, device='cuda', dtype=torch.float32)\n if provider == 'torch':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))\n gbps = lambda ms: 12 * size / ms * 1e-6\n return gbps(ms), gbps(max_ms), gbps(min_ms)"
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['size'], # argument names to use as an x-axis for the plot\n x_vals=[\n 2 ** i for i in range(12, 28, 1)\n ], # different possible values for `x_name`\n x_log=True, # x axis is logarithmic\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['triton', 'torch'], # possible values for `line_arg`\n line_names=['Triton', 'Torch'], # label name for the lines\n styles=[('blue', '-'), ('green', '-')], # line styles\n ylabel='GB/s', # label name for the y-axis\n plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.\n args={}, # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(size, provider):\n x = torch.rand(size, device='cuda', dtype=torch.float32)\n y = torch.rand(size, device='cuda', dtype=torch.float32)\n if provider == 'torch':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))\n gbps = lambda ms: 12 * size / ms * 1e-6\n return gbps(ms), gbps(max_ms), gbps(min_ms)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
(six tutorial plot images regenerated; file sizes unchanged: 24, 15, 37, 24, 55, 32 KiB)
@@ -31,37 +31,43 @@ In this tutorial, you will write a simple vector addition using Triton and learn
|
||||
Compute Kernel
|
||||
--------------------------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-43
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-49
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(
|
||||
X, # *Pointer* to first input vector
|
||||
Y, # *Pointer* to second input vector
|
||||
Z, # *Pointer* to output vector
|
||||
N, # Size of the vector
|
||||
**meta # Optional meta-parameters for the kernel
|
||||
def add_kernel(
|
||||
x_ptr, # *Pointer* to first input vector
|
||||
y_ptr, # *Pointer* to second input vector
|
||||
output_ptr, # *Pointer* to output vector
|
||||
n_elements, # Size of the vector
|
||||
**meta, # Optional meta-parameters for the kernel
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
# Create an offset for the blocks of pointers to be
|
||||
# processed by this program instance
|
||||
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
|
||||
# Create a mask to guard memory operations against
|
||||
# out-of-bounds accesses
|
||||
mask = offsets < N
|
||||
# Load x
|
||||
x = tl.load(X + offsets, mask=mask)
|
||||
y = tl.load(Y + offsets, mask=mask)
|
||||
# Write back x + y
|
||||
z = x + y
|
||||
tl.store(Z + offsets, z)
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
|
||||
# There are multiple 'programs' processing different data. We identify which program
|
||||
# we are here
|
||||
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
|
||||
# This program will process inputs that are offset from the initial data.
|
||||
# For instance, if you had a vector of length 256 and block_size of 64, the programs
|
||||
# would each access the elements [0:64, 64:128, 128:192, 192:256].
|
||||
# Note that offsets is a list of pointers
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
# Create a mask to guard memory operations against out-of-bounds accesses
|
||||
mask = offsets < n_elements
|
||||
# Load x and y from DRAM, masking out any extra elements in case the input is not a
|
||||
# multiple of the block size
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
# Write x + y back to DRAM
|
||||
tl.store(output_ptr + offsets, output)
|
||||
|
||||
|
||||
|
||||
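To make the masked-block pattern above concrete, here is a minimal NumPy sketch (not part of the tutorial code; sizes are made up) of what a single program instance computes when the vector length is not a multiple of the block size:

.. code-block:: python

    import numpy as np

    # One program instance: pid=3, BLOCK_SIZE=64, vector of 200 elements
    n_elements, BLOCK_SIZE, pid = 200, 64, 3
    x = np.random.rand(n_elements).astype(np.float32)
    y = np.random.rand(n_elements).astype(np.float32)
    output = np.empty_like(x)

    offsets = pid * BLOCK_SIZE + np.arange(BLOCK_SIZE)  # elements 192..255
    mask = offsets < n_elements                         # only 192..199 are in bounds
    valid = offsets[mask]
    output[valid] = x[valid] + y[valid]                 # masked load, add, masked store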
@@ -71,31 +77,34 @@ Compute Kernel
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-46
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 50-52
|
||||
|
||||
Let's also declare a helper function to (1) allocate the `z` tensor
|
||||
and (2) enqueue the above kernel with appropriate grid/block sizes.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 46-64
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 52-73
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def add(x, y):
|
||||
z = torch.empty_like(x)
|
||||
N = z.shape[0]
|
||||
def add(x: torch.Tensor, y: torch.Tensor):
|
||||
# We need to preallocate the output
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_cuda and y.is_cuda and output.is_cuda
|
||||
n_elements = output.shape[0]
|
||||
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
|
||||
# In this case, we use a 1D grid where the size is the number of blocks
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
# NOTE:
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keyword arguments
|
||||
_add[grid](x, y, z, N, BLOCK=1024)
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
|
||||
# We return a handle to the output tensor but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously at this point.
|
||||
return z
|
||||
return output
|
||||
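As an aside, the `triton.cdiv` used in the grid lambda is a plain ceiling division; a quick sanity check of the resulting grid size for the vector used below (a sketch, assuming `cdiv(a, b) == (a + b - 1) // b`):

.. code-block:: python

    def cdiv(a, b):
        # ceiling division: number of blocks needed to cover `a` elements
        return (a + b - 1) // b

    assert cdiv(98432, 1024) == 97        # 97 programs; the last one is partially masked
    assert 96 * 1024 < 98432 <= 97 * 1024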
|
||||
|
||||
|
||||
@@ -105,11 +114,11 @@ and (2) enqueue the above kernel with appropriate grid/block sizes.
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 65-66
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 74-75
|
||||
|
||||
We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 66-77
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 75-89
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -118,11 +127,14 @@ We can now use the above function to compute the element-wise sum of two `torch.
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
y = torch.rand(size, device='cuda')
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(za)
|
||||
print(zb)
|
||||
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
|
||||
output_torch = x + y
|
||||
output_triton = add(x, y)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -141,11 +153,11 @@ We can now use the above function to compute the element-wise sum of two `torch.
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 78-79
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 90-91
|
||||
|
||||
Seems like we're good to go!
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 81-86
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 93-98
|
||||
|
||||
Benchmark
|
||||
-----------
|
||||
@@ -153,7 +165,7 @@ We can now benchmark our custom op on vectors of increasing sizes to get a sense
|
||||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops
|
||||
for different problem sizes.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 86-113
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 98-127
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -162,15 +174,17 @@ for different problem sizes.
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['size'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name`
|
||||
x_vals=[
|
||||
2 ** i for i in range(12, 28, 1)
|
||||
], # different possible values for `x_name`
|
||||
x_log=True, # x axis is logarithmic
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['triton', 'torch'], # possible values for `line_arg`
|
||||
line_names=["Triton", "Torch"], # label name for the lines
|
||||
line_names=['Triton', 'Torch'], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={} # values for function arguments not in `x_names` and `y_name`
|
||||
ylabel='GB/s', # label name for the y-axis
|
||||
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(size, provider):
|
||||
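The benchmark's `gbps` conversion (visible in the notebook cell above) counts 12 bytes of traffic per element (two float32 loads plus one float32 store); a quick check of the unit arithmetic with a hypothetical runtime:

.. code-block:: python

    size = 2 ** 27                         # largest size in the sweep
    ms = 1.0                               # hypothetical runtime in milliseconds
    bytes_moved = 12 * size                # 3 float32 tensors x 4 bytes per element
    gbs = bytes_moved / (ms * 1e-3) / 1e9  # bytes / seconds / 1e9 = GB/s
    assert abs(gbs - 12 * size / ms * 1e-6) < 1e-6   # matches the lambda in the code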
@@ -191,18 +205,19 @@ for different problem sizes.
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 114-116
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 128-130
|
||||
|
||||
We can now run the decorated function above. Pass `print_data=True` to see the performance numbers, `show_plots=True` to plot them, and/or
|
||||
`save_path='/path/to/results/'` to save them to disk along with the raw CSV data.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 116-116
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 130-131
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
|
||||
:alt: 01 vector add
|
||||
:class: sphx-glr-single-img
|
||||
@@ -218,16 +233,16 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
size Triton Torch
|
||||
0 4096.0 9.600000 9.600000
|
||||
1 8192.0 19.200000 19.200000
|
||||
2 16384.0 38.400001 38.400001
|
||||
3 32768.0 76.800002 76.800002
|
||||
2 16384.0 31.999999 31.999999
|
||||
3 32768.0 63.999998 76.800002
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 384.000001 384.000001
|
||||
6 262144.0 341.333321 384.000001
|
||||
7 524288.0 472.615390 472.615390
|
||||
8 1048576.0 614.400016 614.400016
|
||||
9 2097152.0 722.823517 722.823517
|
||||
10 4194304.0 780.190482 780.190482
|
||||
11 8388608.0 819.200021 812.429770
|
||||
11 8388608.0 812.429770 812.429770
|
||||
12 16777216.0 833.084721 833.084721
|
||||
13 33554432.0 843.811163 843.811163
|
||||
14 67108864.0 849.278610 848.362445
|
||||
@@ -239,7 +254,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 10.996 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 11.055 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
@@ -20,20 +20,22 @@
|
||||
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster
|
||||
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
|
||||
the GPU's SRAM.
|
||||
You will learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- Reduction operators in Triton.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-16
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-18
|
||||
|
||||
Motivations
|
||||
------------
|
||||
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 16-37
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 18-43
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -41,9 +43,13 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
import torch
|
||||
|
||||
|
||||
# Compute the row-wise softmax of x
|
||||
@torch.jit.script
|
||||
def naive_softmax(x):
|
||||
"""Compute row-wise softmax of X using native pytorch
|
||||
|
||||
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
|
||||
this shift.
|
||||
"""
|
||||
# read MN elements ; write M elements
|
||||
x_max = x.max(dim=1)[0]
|
||||
# read 2MN elements ; write MN elements
|
||||
@@ -65,22 +71,28 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 38-42
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-52
|
||||
|
||||
When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
|
||||
Doing so would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
|
||||
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
|
||||
requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
|
||||
X once and does all the necessary computations on-chip.
|
||||
Doing so would require reading :math:`MN` elements and writing back :math:`MN` elements, so we could
|
||||
expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
|
||||
but, as we will see later, it is still far from ideal.
|
||||
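Plugging concrete numbers into that bound makes it tangible (a quick sketch, not part of the tutorial code):

.. code-block:: python

    M, N = 4096, 1024
    naive_elems = 10 * M * N + 2 * M   # total reads + writes of naive_softmax
    fused_elems = 2 * M * N            # fused kernel: read X once, write Y once
    print(naive_elems / fused_elems)   # ~5.001, the theoretical speed-up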
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-49
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 54-61
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
Our softmax kernel works as follows: each program loads a row of the input matrix X,
|
||||
normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a
|
||||
power-of-two number of elements, so we need to internally "pad" each row and guard the
|
||||
memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 49-77
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 61-94
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -90,26 +102,31 @@ so we need to internally "pad" each row and guard the memory operations properly
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
|
||||
# row index
|
||||
m = tl.program_id(0)
|
||||
# col indices
|
||||
# here BLOCK is the smallest power of two greater than `N`
|
||||
n = tl.arange(0, meta['BLOCK'])
|
||||
# the memory address of all the elements
|
||||
# that we want to load can be computed as follows
|
||||
X = X + m * stride_xm + n
|
||||
x = tl.load(X, mask=n < N, other=-float('inf'))
|
||||
def softmax_kernel(
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
|
||||
):
|
||||
# The rows of the softmax are independent, so we parallelize across those
|
||||
row_idx = tl.program_id(0)
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE']
|
||||
# The stride represents how much we need to increase the pointer to advance 1 row
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
|
||||
# The block size is the next power of two greater than n_cols, so we can fit each
|
||||
# row in a single block
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + col_offsets
|
||||
# Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
|
||||
# Subtract the maximum for numerical stability
|
||||
z = x - tl.max(x, axis=0)
|
||||
# Note that exponentials in Triton are fast
|
||||
# but approximate (i.e., think __expf in CUDA)
|
||||
num = tl.exp(z)
|
||||
denom = tl.sum(num, axis=0)
|
||||
y = num / denom
|
||||
# Write back to Y
|
||||
Y = Y + m * stride_ym + n
|
||||
tl.store(Y, y, mask=n < N)
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
# Write back output to DRAM
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
output_ptrs = output_row_start_ptr + col_offsets
|
||||
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
|
||||
|
||||
|
||||
|
||||
@@ -119,17 +136,18 @@ so we need to internally "pad" each row and guard the memory operations properly
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 78-79
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 95-96
|
||||
|
||||
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 79-110
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 96-139
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
"""Return the smallest power of 2 greater than or equal to n"""
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
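The diff shows only the first shifts of this standard bit trick; the complete helper presumably continues through a shift by 16 for 32-bit sizes, as in the following sketch:

.. code-block:: python

    def next_power_of_2(n):
        """Return the smallest power of 2 greater than or equal to n."""
        n -= 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16   # enough for 32-bit sizes
        return n + 1

    assert next_power_of_2(781) == 1024    # the column count used in the unit test below
    assert next_power_of_2(1024) == 1024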
@@ -141,20 +159,31 @@ We can create a helper function that enqueues the kernel and its (meta-)argument
|
||||
|
||||
|
||||
def softmax(x):
|
||||
M, N = x.shape
|
||||
n_rows, n_cols = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
BLOCK = next_power_of_2(N)
|
||||
BLOCK_SIZE = next_power_of_2(n_cols)
|
||||
# Another trick we can use is to ask the compiler to use more threads per row by
|
||||
# increasing the number of warps (`num_warps`) over which each row is distributed.
|
||||
# You will see in the next tutorial how to auto-tune this value in a more natural
|
||||
# way so you don't have to come up with manual heuristics yourself.
|
||||
num_warps = 4
|
||||
if BLOCK >= 2048: num_warps = 8
|
||||
if BLOCK >= 4096: num_warps = 16
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
# Allocate output
|
||||
y = torch.empty_like(x)
|
||||
# Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
|
||||
_softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
|
||||
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row
|
||||
# of the input matrix
|
||||
softmax_kernel[(n_rows,)](
|
||||
y,
|
||||
x,
|
||||
x.stride(0),
|
||||
y.stride(0),
|
||||
n_cols,
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return y
|
||||
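To see why padding with `-float('inf')` is harmless, here is a NumPy sketch (not part of the tutorial code) of one program instance handling a 781-column row padded to a 1024-wide block:

.. code-block:: python

    import numpy as np

    n_cols, BLOCK_SIZE = 781, 1024
    row = np.random.randn(n_cols).astype(np.float32)

    padded = np.full(BLOCK_SIZE, -np.inf, dtype=np.float32)
    padded[:n_cols] = row                  # what the masked tl.load produces

    z = padded - padded.max()              # max() ignores the -inf padding
    numerator = np.exp(z)                  # exp(-inf) == 0: padding adds nothing
    softmax = numerator / numerator.sum()

    reference = np.exp(row - row.max())
    reference /= reference.sum()
    assert np.allclose(softmax[:n_cols], reference)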
|
||||
|
||||
@@ -165,26 +194,26 @@ We can create a helper function that enqueues the kernel and its (meta-)argument
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 111-113
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 140-142
|
||||
|
||||
Unit Test
|
||||
----------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 115-117
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 144-146
|
||||
|
||||
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
|
||||
This will allow us to verify that our padding mechanism works.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 117-124
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 146-153
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(1823, 781, device='cuda')
|
||||
y_tri = softmax(x)
|
||||
y_ref = torch.softmax(x, axis=1)
|
||||
print(torch.allclose(y_tri, y_ref))
|
||||
y_triton = softmax(x)
|
||||
y_torch = torch.softmax(x, axis=1)
|
||||
print(torch.allclose(y_triton, y_torch))
|
||||
|
||||
|
||||
|
||||
@@ -201,18 +230,18 @@ This will allow us to verify that our padding mechanism works.
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 125-126
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 154-155
|
||||
|
||||
As expected, the results are identical.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 128-132
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 157-161
|
||||
|
||||
Benchmark
|
||||
-------------
|
||||
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
|
||||
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 132-161
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 161-200
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -221,14 +250,24 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 100)
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``
|
||||
line_names=["Triton", "Torch (native)", "Torch (jit)"], # label name for the lines
|
||||
line_vals=[
|
||||
'triton',
|
||||
'torch-native',
|
||||
'torch-jit',
|
||||
], # possible values for `line_arg`
|
||||
line_names=[
|
||||
"Triton",
|
||||
"Torch (native)",
|
||||
"Torch (jit)",
|
||||
], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
|
||||
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
@@ -263,22 +302,22 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
N Triton Torch (native) Torch (jit)
|
||||
0 256.0 512.000001 546.133347 186.181817
|
||||
1 384.0 585.142862 585.142862 153.600004
|
||||
2 512.0 630.153853 585.142849 154.566038
|
||||
3 640.0 682.666684 640.000002 160.000000
|
||||
2 512.0 630.153853 606.814814 154.566038
|
||||
3 640.0 660.645170 640.000002 160.000000
|
||||
4 768.0 702.171410 664.216187 163.839992
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 406.179533 199.140227
|
||||
94 12288.0 812.429770 415.661740 199.399583
|
||||
95 12416.0 810.840807 412.149375 199.054102
|
||||
96 12544.0 810.925276 412.971190 199.308841
|
||||
97 12672.0 811.007961 412.097543 199.264875
|
||||
93 12160.0 812.359066 406.179533 199.038365
|
||||
94 12288.0 812.429770 415.222812 199.298541
|
||||
95 12416.0 810.840807 412.149375 198.854847
|
||||
96 12544.0 810.925276 412.971190 199.209928
|
||||
97 12672.0 809.389265 412.097543 199.167004
|
||||
|
||||
[98 rows x 4 columns]
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 162-167
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 201-207
|
||||
|
||||
In the above plot, we can see that:
|
||||
|
||||
@@ -290,7 +329,7 @@ In the above plot, we can see that:
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 12.626 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 13.186 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
@@ -20,7 +20,8 @@
|
||||
|
||||
Matrix Multiplication
|
||||
======================
|
||||
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
|
||||
In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
|
||||
kernel that achieves performance on par with cuBLAS.
|
||||
You will specifically learn about:
|
||||
|
||||
- Block-level matrix multiplications
|
||||
@@ -28,50 +29,58 @@ You will specifically learn about:
|
||||
- Program re-ordering for improved L2 cache hit rate
|
||||
- Automatic performance tuning
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-37
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 15-42
|
||||
|
||||
Motivations
|
||||
-------------
|
||||
Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
|
||||
They are notoriously hard to optimize, hence their implementation is generally done by
|
||||
hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized
|
||||
to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
In this tutorial, you will learn how to implement efficient matrix multiplications by
|
||||
yourself with Triton, in a way that is easy to customize and extend.
|
||||
|
||||
Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
|
||||
Roughly speaking, the kernel that we will write will implement the following blocked
|
||||
algorithm to multiply a (MxK) by a (KxN) matrix:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# do in parallel
|
||||
for m in range(0, M, BLOCK_M):
|
||||
for m in range(0, M, BLOCK_SIZE_M):
|
||||
# do in parallel
|
||||
for n in range(0, N, BLOCK_N):
|
||||
acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
|
||||
for k in range(0, K, BLOCK_K):
|
||||
a = A[m : m+BLOCK_M, k : k+BLOCK_K]
|
||||
b = B[k : k+BLOCK_K, n : n+BLOCK_N]
|
||||
for n in range(0, N, BLOCK_SIZE_N):
|
||||
acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
|
||||
b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
|
||||
acc += dot(a, b)
|
||||
C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
|
||||
C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
|
||||
|
||||
where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 39-110
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-119
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
|
||||
The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.
|
||||
The main difficulty comes from the computation of the memory locations at which blocks
|
||||
of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
|
||||
multi-dimensional pointer arithmetics.
|
||||
|
||||
Pointer Arithmetics
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
||||
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:
|
||||
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by
|
||||
:code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
||||
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
|
||||
:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
&A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
|
||||
&B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
|
||||
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
|
||||
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
|
||||
|
||||
Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
||||
|
||||
@@ -79,9 +88,9 @@ Which means that pointers for blocks of A and B can be initialized (i.e., :code:
|
||||
|
||||
pid_m = triton.program_id(0)
|
||||
pid_n = triton.program_id(1)
|
||||
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
|
||||
rk = triton.arange(0, BLOCK_K)
|
||||
rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
|
||||
rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
|
||||
rk = triton.arange(0, BLOCK_SIZE_K)
|
||||
// pointer for A operand
|
||||
pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
|
||||
// pointer for B operand
|
||||
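The `[:, None]` / `[None, :]` expressions are just an outer sum of row and column offsets; a NumPy sketch with small, assumed block sizes and strides:

.. code-block:: python

    import numpy as np

    BLOCK_SIZE_M, BLOCK_SIZE_K = 4, 8
    stride_a_0, stride_a_1 = 1024, 1       # row-major A with 1024 columns (assumed)
    m, k = 12, 32                          # top-left corner of this block

    rm = m + np.arange(BLOCK_SIZE_M)       # row indices covered by the block
    rk = k + np.arange(BLOCK_SIZE_K)       # column indices covered by the block
    offs_a = rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1
    assert offs_a.shape == (BLOCK_SIZE_M, BLOCK_SIZE_K)
    assert offs_a[0, 0] == 12 * 1024 + 32  # element offset of A[12, 32]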
@@ -91,50 +100,72 @@ And then updated in the inner loop as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pa += BLOCK_K * stride_a_1;
|
||||
pb += BLOCK_K * stride_b_0;
|
||||
pa += BLOCK_SIZE_K * stride_a_1;
|
||||
pb += BLOCK_SIZE_K * stride_b_0;
|
||||
|
||||
|
||||
L2 Cache Optimizations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
|
||||
It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
|
||||
And unfortunately, a simple row-major ordering
|
||||
As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
|
||||
block of :code:`C`.
|
||||
It is important to remember that the order in which these blocks are computed does
|
||||
matter, since it affects the L2 cache hit rate of our program. And unfortunately,
|
||||
a simple row-major ordering
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
pid = triton.program_id(0);
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M;
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N;
|
||||
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
|
||||
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
|
||||
pid_m = pid / grid_n;
|
||||
pid_n = pid % grid_n;
|
||||
|
||||
is just not going to cut it.
|
||||
|
||||
One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
||||
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before
|
||||
switching to the next column:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pid = triton.program_id(0);
|
||||
width = GROUP_M * grid_n;
|
||||
group_id = pid // width;
|
||||
# we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
|
||||
# we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
||||
pid_m = group_id * GROUP_M + (pid % group_size);
|
||||
pid_n = (pid % width) // (group_size);
|
||||
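Enumerating that mapping for a small grid shows the super-grouping at work (a sketch with `grid_m = grid_n = 4` and `GROUP_M = 2`):

.. code-block:: python

    GROUP_M, grid_m, grid_n = 2, 4, 4

    def grouped(pid):
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size
        return pid_m, pid_n

    # The first 8 programs sweep a 2-row super-group column by column,
    # so consecutive programs keep reusing the same blocks of A and B:
    print([grouped(p) for p in range(8)])
    # -> [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]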
|
||||
In practice, this can improve the performance of our matrix multiplication kernel by >10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 119-130
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 112-115
|
||||
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||
# we can see that if we compute the output in row-major ordering, we need to load 90
|
||||
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||||
# ordering, we only need to load 54 blocks.
|
||||
# .. image:: grouped_vs_row_major_ordering.png
|
||||
#
|
||||
# In practice, this can improve the performance of our matrix multiplication kernel by
|
||||
# more than 10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
|
||||
#
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 131-134
|
||||
|
||||
Final Result
|
||||
-------------
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 115-190
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 134-263
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -144,74 +175,127 @@ Final Result
|
||||
import triton.language as tl
|
||||
|
||||
# %
|
||||
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
|
||||
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
|
||||
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
|
||||
# decorator, which consumes:
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of
|
||||
# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
|
||||
# - An autotuning *key* whose change in values will trigger evaluation of all the
|
||||
# provided configs
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\
|
||||
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
|
||||
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
# %
|
||||
# We can now define our kernel as normal, using all the techniques presented above
|
||||
@triton.jit
|
||||
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
|
||||
def matmul_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
# Matrix dimensions
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
||||
# by to get the element one row down (A has M rows)
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
**meta,
|
||||
):
|
||||
"""Kernel for computing the matmul AB = C
|
||||
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
# extract meta-parameters
|
||||
BLOCK_M = META['BLOCK_M']
|
||||
BLOCK_N = META['BLOCK_N']
|
||||
BLOCK_K = META['BLOCK_K']
|
||||
GROUP_M = 8
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
rk = tl.arange(0, BLOCK_K)
|
||||
A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(K, 0, -BLOCK_K):
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * stride_ak
|
||||
B += BLOCK_K * stride_bk
|
||||
# triton can accept arbitrary activation function
|
||||
# via metaparameters!
|
||||
if META['ACTIVATION']:
|
||||
acc = META['ACTIVATION'](acc)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm[:, None] < M) & (rn[None, :] < N)
|
||||
tl.store(C, acc, mask=mask)
|
||||
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
|
||||
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
||||
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
|
||||
GROUP_SIZE_M = 8
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
# the number of blocks is ceil(M / BLOCK_SIZE_M), since a partial remainder needs an extra block
|
||||
# Note that this will lead to some quantization in performance, where the time taken jumps
|
||||
# when you need to add a new block
|
||||
n_blocks_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
|
||||
n_blocks_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
|
||||
|
||||
# Map PIDs to the block they should compute. This is done in a grouped ordering
|
||||
# to promote L2 cache reuse.
|
||||
n_output_blocks_in_group = GROUP_SIZE_M * n_blocks_n
|
||||
group_id = pid // n_output_blocks_in_group
|
||||
first_m_block_in_group = group_id * GROUP_SIZE_M
|
||||
|
||||
# If the number of blocks is not divisible by the group size, the last group is smaller
|
||||
group_size_m = min(n_blocks_m - first_m_block_in_group, GROUP_SIZE_M)
|
||||
|
||||
# Within a group, we compute in col-major ordering, block_m and block_n are the
|
||||
# output row and col that this program is computing in terms of blocks
|
||||
block_m = first_m_block_in_group + (pid % group_size_m)
|
||||
block_n = (pid % n_output_blocks_in_group) // group_size_m
|
||||
|
||||
# Convert from block indices back to element indices
|
||||
m_start = block_m * BLOCK_SIZE_M
|
||||
n_start = block_n * BLOCK_SIZE_N
|
||||
|
||||
# Expand out to all the offsets for each of the elements in this block.
|
||||
m_offsets_a = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
||||
n_offsets_b = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
||||
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
# Get the pointers for the first block of each. We will advance this pointer
|
||||
# as we move in the K direction and accumulate.
|
||||
# a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers
|
||||
a_ptrs = a_ptr + (stride_am * m_offsets_a + stride_ak * k_offsets[None, :])
|
||||
# b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers
|
||||
b_ptrs = b_ptr + (stride_bk * k_offsets[:, None] + stride_bn * n_offsets_b)
|
||||
# We accumulate internally in fp32, but the output is written out in the dtype
|
||||
# of the tensor when it is stored
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
# Note that for simplicity, we don't apply a mask here. This means that if K is
|
||||
# not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and
|
||||
# accumulate it incorrectly.
|
||||
a = tl.load(a_ptrs)
|
||||
b = tl.load(b_ptrs)
|
||||
# We accumulate along the K dimension
|
||||
accumulator += tl.dot(a, b)
|
||||
|
||||
# Advance the ptrs to the next K block
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
# triton can accept an arbitrary activation function via meta-parameters!
|
||||
if meta['ACTIVATION']:
|
||||
accumulator = meta['ACTIVATION'](accumulator)
|
||||
|
||||
m_offsets_c = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
||||
n_offsets_c = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
||||
c_ptrs = c_ptr + stride_cm * m_offsets_c + stride_cn * n_offsets_c
|
||||
mask = (m_offsets_c < M) & (n_offsets_c < N)
|
||||
tl.store(c_ptrs, accumulator, mask=mask)
|
||||
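The store mask combines a column vector of row checks with a row vector of column checks by broadcasting; a NumPy sketch for a block hanging off the bottom-right corner of the output (sizes are made up):

.. code-block:: python

    import numpy as np

    M, N = 10, 30                            # output matrix is 10 x 30
    m_offsets = (8 + np.arange(4))[:, None]  # block rows 8..11; rows 10, 11 out of bounds
    n_offsets = (24 + np.arange(8))[None, :] # block cols 24..31; cols 30, 31 out of bounds
    mask = (m_offsets < M) & (n_offsets < N) # broadcasts to a (4, 8) boolean block
    assert mask.sum() == 2 * 6               # only the in-bounds 2 x 6 corner gets stored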
|
||||
|
||||
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`
|
||||
@triton.jit
|
||||
def leaky_relu(x):
|
||||
return tl.where(x >= 0, x, 0.01*x)
|
||||
return tl.where(x >= 0, x, 0.01 * x)
|
||||
|
||||
|
||||
|
||||
@@ -220,33 +304,49 @@ Final Result
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 191-193
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 264-266
|
||||
|
||||
We can now create a convenience wrapper function that only takes two input tensors
|
||||
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 193-214
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 266-302
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
assert a.is_contiguous(), "matrix A must be contiguous"
|
||||
assert b.is_contiguous(), "matrix B must be contiguous"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
K, N = b.shape
|
||||
assert (
|
||||
K % 32 == 0
|
||||
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
pgm = _matmul[grid](
|
||||
a, b, c, M, N, K, \
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||
ACTIVATION = activation
|
||||
# 1D launch grid where each block gets its own program.
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
||||
)
|
||||
matmul_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
ACTIVATION=activation,
|
||||
)
|
||||
# done; return the output tensor
|
||||
return c
|
||||
|
||||
|
||||
@@ -257,14 +357,14 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 215-219
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 303-307
|
||||
|
||||
Unit Test
|
||||
-----------
|
||||
|
||||
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 219-229
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 307-320
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -272,11 +372,14 @@ We can test our custom matrix multiplication operation against a native torch im
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
c_0 = matmul(a, b, activation=None)
|
||||
c_1 = torch.matmul(a, b)
|
||||
print(c_0)
|
||||
print(c_1)
|
||||
print(triton.testing.allclose(c_0, c_1))
|
||||
triton_output = matmul(a, b, activation=None)
|
||||
torch_output = torch.matmul(a, b)
|
||||
print(f"{triton_output=}")
|
||||
print(f"{torch_output=}")
|
||||
if triton.testing.allclose(triton_output, torch_output):
|
||||
print("✅ Triton and Torch match")
|
||||
else:
|
||||
print("❌ Triton and Torch differ")
|
||||
|
||||
|
||||
|
||||
@@ -288,7 +391,7 @@ We can test our custom matrix multiplication operation against a native torch im
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||||
triton_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||||
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
|
||||
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
|
||||
...,
|
||||
@@ -296,7 +399,7 @@ We can test our custom matrix multiplication operation against a native torch im
|
||||
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||||
torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||||
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
|
||||
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
|
||||
...,
|
||||
@@ -304,12 +407,12 @@ We can test our custom matrix multiplication operation against a native torch im
|
||||
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor(True, device='cuda:0')
|
||||
✅ Triton and Torch match
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 230-236
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 321-327
|
||||
|
||||
Benchmark
|
||||
--------------
|
||||
@@ -318,7 +421,7 @@ Square Matrix Performance
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 236-268
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 327-368
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -327,14 +430,19 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[128 * i for i in range(1, 33)], # different possible values for `x_name`
|
||||
x_vals=[
|
||||
128 * i for i in range(1, 33)
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], # possible values for `line_arg``
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], # label name for the lines
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], # line styles
|
||||
# possible values for `line_arg`
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
@@ -346,9 +454,13 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
if provider == 'cublas + relu':
|
||||
torch_relu = torch.nn.ReLU(inplace=True)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch_relu(torch.matmul(a, b))
|
||||
)
|
||||
if provider == 'triton + relu':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b, activation=leaky_relu)
|
||||
)
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
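The `perf` lambda counts `2 * M * N * K` floating-point operations (one multiply and one add per term of each inner product); a quick check of the conversion with a hypothetical runtime:

.. code-block:: python

    M = N = K = 4096
    ms = 1.5                                  # hypothetical runtime in milliseconds
    flops = 2 * M * N * K                     # one fused multiply-add = 2 flops per term
    tflops = flops / (ms * 1e-3) / 1e12       # flops / seconds / 1e12
    assert abs(tflops - 2 * M * N * K * 1e-12 / (ms * 1e-3)) < 1e-9
    print(round(tflops, 1))                   # ~91.6, consistent with the 4096 row below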
|
||||
@@ -371,37 +483,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
matmul-performance:
|
||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||
0 128.0 0.455111 ... 0.512000 0.512000
|
||||
1 256.0 2.730667 ... 2.978909 2.978909
|
||||
2 384.0 7.372800 ... 7.899428 8.507077
|
||||
3 512.0 14.563555 ... 16.384000 16.384000
|
||||
1 256.0 2.978909 ... 2.978909 2.978909
|
||||
2 384.0 7.372800 ... 8.507077 7.899428
|
||||
3 512.0 14.563555 ... 16.384000 15.420235
|
||||
4 640.0 22.260869 ... 24.380953 24.380953
|
||||
5 768.0 32.768000 ... 34.028308 34.028308
|
||||
6 896.0 39.025776 ... 39.025776 39.025776
|
||||
6 896.0 39.025776 ... 39.025776 35.123201
|
||||
7 1024.0 49.932191 ... 52.428801 52.428801
|
||||
8 1152.0 45.242181 ... 46.656000 45.938215
|
||||
9 1280.0 51.200001 ... 56.109587 56.109587
|
||||
10 1408.0 64.138541 ... 65.684049 58.640951
|
||||
11 1536.0 79.526831 ... 75.296679 75.296679
|
||||
12 1664.0 63.372618 ... 61.636381 62.061463
|
||||
13 1792.0 72.983276 ... 69.379162 68.953520
|
||||
14 1920.0 69.467336 ... 67.434145 70.172588
|
||||
15 2048.0 73.584279 ... 75.573044 74.898285
|
||||
16 2176.0 83.155572 ... 80.817862 77.398646
|
||||
17 2304.0 68.251065 ... 72.828879 73.051599
|
||||
18 2432.0 71.305746 ... 80.963875 80.963875
|
||||
19 2560.0 77.649287 ... 75.676673 74.983980
|
||||
20 2688.0 83.186525 ... 84.671999 82.823267
|
||||
21 2816.0 82.916747 ... 76.115547 79.733474
|
||||
22 2944.0 82.237674 ... 80.771529 78.358539
|
||||
23 3072.0 82.062468 ... 84.892208 82.782312
|
||||
24 3200.0 84.544253 ... 88.397792 89.385477
|
||||
25 3328.0 79.812967 ... 80.617354 81.071278
|
||||
26 3456.0 81.518272 ... 86.970406 81.600781
|
||||
27 3584.0 87.042978 ... 96.372338 90.640517
|
||||
28 3712.0 84.230479 ... 82.764991 82.423549
|
||||
29 3840.0 80.255442 ... 81.377484 80.783056
|
||||
30 3968.0 89.329379 ... 85.932350 87.347124
|
||||
31 4096.0 93.531519 ... 85.816960 91.056800
|
||||
8 1152.0 44.566925 ... 46.656000 46.656000
|
||||
9 1280.0 51.200001 ... 56.888887 56.109587
|
||||
10 1408.0 64.138541 ... 64.902096 64.902096
|
||||
11 1536.0 78.643199 ... 76.106321 76.106321
|
||||
12 1664.0 62.929456 ... 62.061463 62.061463
|
||||
13 1792.0 72.983276 ... 69.810085 69.379162
|
||||
14 1920.0 67.764707 ... 70.530615 70.530615
|
||||
15 2048.0 73.908442 ... 75.234154 74.898285
|
||||
16 2176.0 83.500614 ... 81.143743 81.143743
|
||||
17 2304.0 68.446623 ... 73.501144 73.501144
|
||||
18 2432.0 71.305746 ... 82.147552 82.147552
|
||||
19 2560.0 77.833728 ... 77.283019 77.101175
|
||||
20 2688.0 81.053536 ... 81.928846 83.922689
|
||||
21 2816.0 81.981598 ... 79.443003 80.320825
|
||||
22 2944.0 82.373605 ... 77.385141 78.112900
|
||||
23 3072.0 81.472093 ... 83.761985 79.638683
|
||||
24 3200.0 84.768213 ... 88.888888 85.561498
|
||||
25 3328.0 83.905938 ... 87.794262 87.156532
|
||||
26 3456.0 80.220468 ... 85.676480 84.068369
|
||||
27 3584.0 86.707226 ... 95.553020 94.847460
|
||||
28 3712.0 83.247783 ... 84.303780 85.309435
|
||||
29 3840.0 80.255442 ... 83.339866 85.005380
|
||||
30 3968.0 88.938731 ... 87.409694 87.159957
|
||||
31 4096.0 91.616198 ... 89.597949 89.538177
|
||||
|
||||
[32 rows x 5 columns]
|
||||
|
||||
@@ -411,7 +523,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we

.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 2 minutes 14.738 seconds)
**Total running time of the script:** ( 2 minutes 30.425 seconds)


.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:

@@ -5,12 +5,12 @@

Computation times
=================
**03:38.360** total execution time for **getting-started_tutorials** files:
**03:54.665** total execution time for **getting-started_tutorials** files:

+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:14.738 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:30.425 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``)                 | 01:12.626 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``)                 | 01:13.186 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``)                       | 00:10.996 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``)                       | 00:11.055 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+

@@ -201,49 +201,58 @@ to download the full example code</p>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>


<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_add</span><span class="p">(</span>
    <span class="n">X</span><span class="p">,</span> <span class="c1"># *Pointer* to first input vector</span>
    <span class="n">Y</span><span class="p">,</span> <span class="c1"># *Pointer* to second input vector</span>
    <span class="n">Z</span><span class="p">,</span> <span class="c1"># *Pointer* to output vector</span>
    <span class="n">N</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
    <span class="o">**</span><span class="n">meta</span> <span class="c1"># Optional meta-parameters for the kernel</span>
<span class="k">def</span> <span class="nf">add_kernel</span><span class="p">(</span>
    <span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to first input vector</span>
    <span class="n">y_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to second input vector</span>
    <span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to output vector</span>
    <span class="n">n_elements</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
    <span class="o">**</span><span class="n">meta</span><span class="p">,</span> <span class="c1"># Optional meta-parameters for the kernel</span>
<span class="p">):</span>
    <span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="c1"># Create an offset for the blocks of pointers to be</span>
    <span class="c1"># processed by this program instance</span>
    <span class="n">offsets</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">]</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
    <span class="c1"># Create a mask to guard memory operations against</span>
    <span class="c1"># out-of-bounds accesses</span>
    <span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">N</span>
    <span class="c1"># Load x</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
    <span class="c1"># Write back x + y</span>
    <span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Z</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
    <span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]</span> <span class="c1"># How many inputs each program should process</span>
    <span class="c1"># There are multiple 'program's processing different data. We identify which program</span>
    <span class="c1"># we are here</span>
    <span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># We use a 1D launch grid so axis is 0</span>
    <span class="c1"># This program will process inputs that are offset from the initial data.</span>
    <span class="c1"># for instance, if you had a vector of length 256 and block_size of 64, the programs</span>
    <span class="c1"># would each access the elements [0:64, 64:128, 128:192, 192:256].</span>
    <span class="c1"># Note that offsets is a list of pointers</span>
    <span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
    <span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
    <span class="c1"># Create a mask to guard memory operations against out-of-bounds accesses</span>
    <span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">n_elements</span>
<span class="c1"># Load x and y from DRAM, masking out any extar elements in case the input is not a</span>
|
||||
<span class="c1"># multiple of the block size</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
    <span class="c1"># Write x + y back to DRAM</span>
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">)</span>
</pre></div>
</div>
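<p>To make the offset arithmetic above concrete, here is a small host-side sketch (illustrative only, not part of the tutorial code): with <cite>BLOCK_SIZE = 64</cite>, the program instance with <cite>pid = 1</cite> covers elements 64 through 127.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Host-side sketch of the block/offset arithmetic (hypothetical values)
import torch

BLOCK_SIZE, pid = 64, 1
block_start = pid * BLOCK_SIZE
offsets = block_start + torch.arange(0, BLOCK_SIZE)  # mimics tl.arange on the host
print(offsets.min().item(), offsets.max().item())  # 64 127
</pre></div>
</div>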
<p>Let’s also declare a helper function to (1) allocate the <cite>z</cite> tensor
and (2) enqueue the above kernel with appropriate grid/block sizes.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
    <span class="c1"># We need to preallocate the output</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">y</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">output</span><span class="o">.</span><span class="n">is_cuda</span>
    <span class="n">n_elements</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="c1"># The SPMD launch grid denotes the number of kernel instances that run in parallel.</span>
    <span class="c1"># It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]</span>
    <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">]),</span> <span class="p">)</span>
    <span class="c1"># In this case, we use a 1D grid where the size is the number of blocks</span>
    <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]),)</span>
    <span class="c1"># NOTE:</span>
    <span class="c1"># - each torch.tensor object is implicitly converted into a pointer to its first element.</span>
<span class="c1"># - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel</span>
|
||||
<span class="c1"># - don't forget to pass meta-parameters as keywords arguments</span>
    <span class="n">_add</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
    <span class="n">add_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
    <span class="c1"># We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still</span>
    <span class="c1"># running asynchronously at this point.</span>
    <span class="k">return</span> <span class="n">z</span>
    <span class="k">return</span> <span class="n">output</span>
</pre></div>
</div>
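<p>As a quick sanity check on the launch-grid arithmetic (a sketch under the same assumptions as the code above): <cite>triton.cdiv</cite> is a ceiling division, so the last, possibly partial, block still gets its own program instance.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Ceiling division behind the launch grid: cdiv(a, b) == (a + b - 1) // b
import triton

n_elements, BLOCK_SIZE = 98432, 1024
print(triton.cdiv(n_elements, BLOCK_SIZE))  # 97; the 97th program is masked down to the last 128 elements
</pre></div>
</div>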
<p>We can now use the above function to compute the element-wise sum of two <cite>torch.tensor</cite> objects and test its correctness:</p>
@@ -251,11 +260,14 @@ and (2) enqueue the above kernel with appropriate grid/block sizes.</p>
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="n">output_torch</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">output_triton</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_torch</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_triton</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span>
    <span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span>
    <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span>
<span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
@@ -274,15 +286,17 @@ for different problem sizes.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
        <span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'size'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
        <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
        <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span>
            <span class="mi">2</span> <span class="o">**</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="p">],</span> <span class="c1"># different possible values for `x_name`</span>
        <span class="n">x_log</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># x axis is logarithmic</span>
        <span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
        <span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'torch'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg`</span>
        <span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"Triton"</span><span class="p">,</span> <span class="s2">"Torch"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
        <span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'Triton'</span><span class="p">,</span> <span class="s1">'Torch'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
        <span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">)],</span> <span class="c1"># line styles</span>
        <span class="n">ylabel</span><span class="o">=</span><span class="s2">"GB/s"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
        <span class="n">plot_name</span><span class="o">=</span><span class="s2">"vector-add-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
        <span class="n">args</span><span class="o">=</span><span class="p">{}</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
        <span class="n">ylabel</span><span class="o">=</span><span class="s1">'GB/s'</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
        <span class="n">plot_name</span><span class="o">=</span><span class="s1">'vector-add-performance'</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
        <span class="n">args</span><span class="o">=</span><span class="p">{},</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
    <span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
@@ -307,23 +321,23 @@ for different problem sizes.</p>
size Triton Torch
0 4096.0 9.600000 9.600000
1 8192.0 19.200000 19.200000
2 16384.0 38.400001 38.400001
3 32768.0 76.800002 76.800002
2 16384.0 31.999999 31.999999
3 32768.0 63.999998 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 384.000001 384.000001
6 262144.0 341.333321 384.000001
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
10 4194304.0 780.190482 780.190482
11 8388608.0 819.200021 812.429770
11 8388608.0 812.429770 812.429770
12 16777216.0 833.084721 833.084721
13 33554432.0 843.811163 843.811163
14 67108864.0 849.278610 848.362445
15 134217728.0 851.577704 850.656574
</pre></div>
</div>
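<p>For reference, GB/s figures like the ones above are conventionally derived from three float32 tensors of traffic per element (read <cite>x</cite>, read <cite>y</cite>, write the output). The benchmark body is elided from this diff, so the following is only a sketch of that convention, not the tutorial’s exact code:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Assumed bandwidth model for vector add: 3 tensors x 4 bytes per element
def gbps(size, ms):
    return 3 * size * 4 * 1e-9 / (ms * 1e-3)

print(gbps(size=134217728, ms=1.891))  # ~852 GB/s, in line with the last row of the table above
</pre></div>
</div>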
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 10.996 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 11.055 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>

@@ -195,7 +195,9 @@ to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="fused-softmax">
<span id="sphx-glr-getting-started-tutorials-02-fused-softmax-py"></span><h1>Fused Softmax<a class="headerlink" href="#fused-softmax" title="Permalink to this headline">¶</a></h1>
<p>In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
<p>In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch’s native op for a particular class of matrices: those whose rows can fit in
the GPU’s SRAM.
You will learn about:</p>
<ul class="simple">
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
@@ -208,9 +210,13 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>


<span class="c1"># Compute the row-wise softmax of x</span>
<span class="nd">@torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="sd">"""Compute row-wise softmax of X using native pytorch</span>

<span class="sd">    We subtract the maximum element in order to avoid overflows. Softmax is invariant to</span>
<span class="sd">    this shift.</span>
<span class="sd">    """</span>
    <span class="c1"># read MN elements ; write M elements</span>
    <span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    <span class="c1"># read 2MN elements ; write MN elements</span>
@@ -225,45 +231,57 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
    <span class="k">return</span> <span class="n">ret</span>
</pre></div>
</div>
<p>When implemented naively in pytorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
The <cite>torch.jit.script</cite> flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.</p>
<p>When implemented naively in PyTorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span>
requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could
expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
The <cite>torch.jit.script</cite> flag aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.</p>
</div>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:</p>
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally “pad” each row and guard the
memory operations properly if we want to handle any possible input shapes:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>


<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_softmax</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="o">**</span><span class="n">meta</span><span class="p">):</span>
    <span class="c1"># row index</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="c1"># col indices</span>
    <span class="c1"># here BLOCK is the smallest power of two greater than `N`</span>
    <span class="n">n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
    <span class="c1"># the memory address of all the elements</span>
    <span class="c1"># that we want to load can be computed as follows</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s1">'inf'</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">softmax_kernel</span><span class="p">(</span>
    <span class="n">output_ptr</span><span class="p">,</span> <span class="n">input_ptr</span><span class="p">,</span> <span class="n">input_row_stride</span><span class="p">,</span> <span class="n">output_row_stride</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="o">**</span><span class="n">meta</span>
<span class="p">):</span>
    <span class="c1"># The rows of the softmax are independent, so we parallelize across those</span>
    <span class="n">row_idx</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE'</span><span class="p">]</span>
    <span class="c1"># The stride represents how much we need to increase the pointer to advance 1 row</span>
    <span class="n">row_start_ptr</span> <span class="o">=</span> <span class="n">input_ptr</span> <span class="o">+</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">input_row_stride</span>

    <span class="c1"># The block size is the next power of two greater than n_cols, so we can fit each</span>
    <span class="c1"># row in a single block</span>
    <span class="n">col_offsets</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
    <span class="n">input_ptrs</span> <span class="o">=</span> <span class="n">row_start_ptr</span> <span class="o">+</span> <span class="n">col_offsets</span>
    <span class="c1"># Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols</span>
    <span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">input_ptrs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">col_offsets</span> <span class="o"><</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s1">'inf'</span><span class="p">))</span>
<span class="c1"># Substract maximum for numerical stability</span>
    <span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="c1"># Note that exponentials in Triton are fast</span>
    <span class="c1"># but approximate (i.e., think __expf in CUDA)</span>
    <span class="n">num</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
    <span class="n">denom</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">num</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span>
    <span class="c1"># Write back to Y</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span>
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
    <span class="n">row_minus_max</span> <span class="o">=</span> <span class="n">row</span> <span class="o">-</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">row</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="c1"># Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)</span>
    <span class="n">numerator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">row_minus_max</span><span class="p">)</span>
    <span class="n">denominator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">softmax_output</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span>
    <span class="c1"># Write back output to DRAM</span>
    <span class="n">output_row_start_ptr</span> <span class="o">=</span> <span class="n">output_ptr</span> <span class="o">+</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">output_row_stride</span>
    <span class="n">output_ptrs</span> <span class="o">=</span> <span class="n">output_row_start_ptr</span> <span class="o">+</span> <span class="n">col_offsets</span>
    <span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptrs</span><span class="p">,</span> <span class="n">softmax_output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">col_offsets</span> <span class="o"><</span> <span class="n">n_cols</span><span class="p">)</span>
</pre></div>
</div>
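<p>Why pad with <cite>-inf</cite>? A minimal PyTorch sketch (illustrative only, not part of the tutorial code): after the max-shift, padded positions become <cite>exp(-inf) = 0</cite>, so they add nothing to the denominator and receive probability 0.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Effect of -inf padding on one row (hypothetical 2-element row padded to 4 slots)
import torch

row = torch.tensor([1.0, 2.0, float('-inf'), float('-inf')])
numerator = torch.exp(row - row.max())
print(numerator / numerator.sum())  # tensor([0.2689, 0.7311, 0.0000, 0.0000])
</pre></div>
</div>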
<p>We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
    <span class="sd">"""Return the smallest power of 2 greater than or equal to n"""</span>
    <span class="n">n</span> <span class="o">-=</span> <span class="mi">1</span>
    <span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">1</span>
    <span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">2</span>
@@ -275,20 +293,31 @@ so we need to internally “pad” each row and guard the memory operations prop


<span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
    <span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
    <span class="c1"># The block size is the smallest power of two greater than the number of columns in `x`</span>
    <span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
    <span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">n_cols</span><span class="p">)</span>
    <span class="c1"># Another trick we can use is to ask the compiler to use more threads per row by</span>
    <span class="c1"># increasing the number of warps (`num_warps`) over which each row is distributed.</span>
    <span class="c1"># You will see in the next tutorial how to auto-tune this value in a more natural</span>
    <span class="c1"># way so you don't have to come up with manual heuristics yourself.</span>
    <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
    <span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">2048</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
    <span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
    <span class="k">if</span> <span class="n">BLOCK_SIZE</span> <span class="o">>=</span> <span class="mi">2048</span><span class="p">:</span>
        <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
    <span class="k">if</span> <span class="n">BLOCK_SIZE</span> <span class="o">>=</span> <span class="mi">4096</span><span class="p">:</span>
        <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
    <span class="c1"># Allocate output</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="c1"># Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix</span>
    <span class="n">_softmax</span><span class="p">[(</span><span class="n">M</span><span class="p">,</span> <span class="p">)](</span><span class="n">y</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">,</span> <span class="n">BLOCK</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">)</span>
<span class="c1"># Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o</span>
|
||||
<span class="c1"># f the input matrix</span>
    <span class="n">softmax_kernel</span><span class="p">[(</span><span class="n">n_rows</span><span class="p">,)](</span>
        <span class="n">y</span><span class="p">,</span>
        <span class="n">x</span><span class="p">,</span>
        <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
        <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
        <span class="n">n_cols</span><span class="p">,</span>
        <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">,</span>
        <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">y</span>
</pre></div>
</div>
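<p>For the 1823x781 test matrix used just below, the heuristic above resolves as follows (a worked sketch under the same assumptions):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>n_cols = 781
BLOCK_SIZE = next_power_of_2(n_cols)  # 1024: each row is padded from 781 to 1024 slots
num_warps = 4  # BLOCK_SIZE < 2048, so the default of 4 warps is kept
</pre></div>
</div>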
@@ -299,9 +328,9 @@ so we need to internally “pad” each row and guard the memory operations prop
This will allow us to verify that our padding mechanism works.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">))</span>
<span class="n">y_triton</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y_torch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
@@ -317,14 +346,24 @@ We will then compare its performance against (1) <code class="code docutils lite
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
        <span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'N'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
        <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">100</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
        <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span>
            <span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
        <span class="p">],</span> <span class="c1"># different possible values for `x_name`</span>
        <span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
        <span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'torch-native'</span><span class="p">,</span> <span class="s1">'torch-jit'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
        <span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"Triton"</span><span class="p">,</span> <span class="s2">"Torch (native)"</span><span class="p">,</span> <span class="s2">"Torch (jit)"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
        <span class="n">line_vals</span><span class="o">=</span><span class="p">[</span>
            <span class="s1">'triton'</span><span class="p">,</span>
            <span class="s1">'torch-native'</span><span class="p">,</span>
            <span class="s1">'torch-jit'</span><span class="p">,</span>
        <span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
        <span class="n">line_names</span><span class="o">=</span><span class="p">[</span>
            <span class="s2">"Triton"</span><span class="p">,</span>
            <span class="s2">"Torch (native)"</span><span class="p">,</span>
            <span class="s2">"Torch (jit)"</span><span class="p">,</span>
        <span class="p">],</span> <span class="c1"># label name for the lines</span>
        <span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">)],</span> <span class="c1"># line styles</span>
        <span class="n">ylabel</span><span class="o">=</span><span class="s2">"GB/s"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
        <span class="n">plot_name</span><span class="o">=</span><span class="s2">"softmax-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
        <span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">'M'</span><span class="p">:</span> <span class="mi">4096</span><span class="p">}</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
        <span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">'M'</span><span class="p">:</span> <span class="mi">4096</span><span class="p">},</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
    <span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
@@ -348,15 +387,15 @@ We will then compare its performance against (1) <code class="code docutils lite
N Triton Torch (native) Torch (jit)
0 256.0 512.000001 546.133347 186.181817
1 384.0 585.142862 585.142862 153.600004
2 512.0 630.153853 585.142849 154.566038
3 640.0 682.666684 640.000002 160.000000
2 512.0 630.153853 606.814814 154.566038
3 640.0 660.645170 640.000002 160.000000
4 768.0 702.171410 664.216187 163.839992
.. ... ... ... ...
93 12160.0 812.359066 406.179533 199.140227
94 12288.0 812.429770 415.661740 199.399583
95 12416.0 810.840807 412.149375 199.054102
96 12544.0 810.925276 412.971190 199.308841
97 12672.0 811.007961 412.097543 199.264875
93 12160.0 812.359066 406.179533 199.038365
94 12288.0 812.429770 415.222812 199.298541
95 12416.0 810.840807 412.149375 198.854847
96 12544.0 810.925276 412.971190 199.209928
97 12672.0 809.389265 412.097543 199.167004

[98 rows x 4 columns]
</pre></div>
@@ -370,7 +409,7 @@ This means that – when temporary data is too large to fit entirely in the GPU
Note that our Triton kernel is not only faster than PyTorch’s CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
</ul>
</div></blockquote>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.626 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 13.186 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>

@@ -202,7 +202,8 @@ to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="matrix-multiplication">
<span id="sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"></span><h1>Matrix Multiplication<a class="headerlink" href="#matrix-multiplication" title="Permalink to this headline">¶</a></h1>
<p>In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
<p>In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
kernel that achieves performance on par with cuBLAS.
You will specifically learn about:</p>
<ul class="simple">
<li><p>Block-level matrix multiplications</p></li>
@@ -213,21 +214,25 @@ You will specifically learn about:</p>
<div class="section" id="motivations">
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
<p>Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.</p>
<p>Roughly speaking, the kernel that we will write will implement the following blocked algorithm:</p>
They are notoriously hard to optimize, hence their implementation is generally done by
hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized
to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
In this tutorial, you will learn how to implement efficient matrix multiplications by
|
||||
yourself with Triton, in a way that is easy to customize and extend.</p>
|
||||
<p>Roughly speaking, the kernel that we will write will implement the following blocked
|
||||
algorithm to multiply a (MxK) by a (KxN) matrix:</p>
|
||||
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span># do in parallel
for m in range(0, M, BLOCK_SIZE_M):
    # do in parallel
    for n in range(0, N, BLOCK_SIZE_N):
        acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
        for k in range(0, K, BLOCK_SIZE_K):
            a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
            b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
            acc += dot(a, b)
        C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
</pre></div>
</div>
</div></blockquote>
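<p>As a sanity check, the blocked loop above can be mirrored with a short NumPy sketch.
This is illustrative only and is not part of the tutorial's code; the function name,
the default block sizes, and the test shapes below are assumptions chosen for the example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np

# Minimal sketch of the blocked algorithm above, assuming M, N, K are
# multiples of the block sizes for simplicity.
def blocked_matmul(A, B, BLOCK_SIZE_M=2, BLOCK_SIZE_N=2, BLOCK_SIZE_K=2):
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=np.float32)
    for m in range(0, M, BLOCK_SIZE_M):          # "do in parallel" in the kernel
        for n in range(0, N, BLOCK_SIZE_N):      # "do in parallel" in the kernel
            acc = np.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=np.float32)
            for k in range(0, K, BLOCK_SIZE_K):
                a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
                b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
                acc += a @ b
            C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
    return C

A = np.random.randn(4, 6).astype(np.float32)
B = np.random.randn(6, 8).astype(np.float32)
assert np.allclose(blocked_matmul(A, B), A @ B, atol=1e-4)
</pre></div>
</div>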
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the computation of the memory locations at which blocks of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.</p>
|
||||
The main difficulty comes from the computation of the memory locations at which blocks</p>
|
||||
<blockquote>
|
||||
<div><p>of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need</p>
|
||||
</div></blockquote>
|
||||
<p>multi-dimensional pointer arithmetics.</p>
|
||||
<div class="section" id="pointer-arithmetics">
|
||||
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
||||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given by <code class="code docutils literal notranslate"><span class="pre">&X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_x_0</span> <span class="pre">+</span> <span class="pre">j*stride_x_1</span></code>.
|
||||
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_M,</span> <span class="pre">k:k+BLOCK_K]</span></code> and <code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_N]</span></code> can be defined in pseudo-code as:</p>
|
||||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given b
|
||||
y <code class="code docutils literal notranslate"><span class="pre">&X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_x_0</span> <span class="pre">+</span> <span class="pre">j*stride_x_1</span></code>.
|
||||
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_SIZE_M,</span> <span class="pre">k:k+BLOCK_SIZE_K]</span></code> and
|
||||
<code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_SIZE_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_SIZE_N]</span></code> can be defined in pseudo-code as:</p>
|
||||
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span>&amp;A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1)
&amp;B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1)
</pre></div>
</div>
</div></blockquote>
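<p>The broadcasting trick behind these pointer blocks is easy to check on the CPU.
The following sketch is illustrative only, using NumPy in place of Triton with shapes
and a block anchor chosen for the demonstration; it builds the same 2D grid of offsets
and verifies that gathering through it reproduces a block slice:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np

# A row-major (M, K) tensor and a block anchored at (m, k); the sizes
# here are arbitrary values chosen for the demonstration.
M, K = 8, 16
X = np.arange(M * K, dtype=np.float32).reshape(M, K)
m, k, BLOCK_SIZE_M, BLOCK_SIZE_K = 2, 3, 4, 5

stride_0, stride_1 = K, 1  # row-major strides, in elements
rows = (m + np.arange(BLOCK_SIZE_M))[:, None]   # shape (BLOCK_SIZE_M, 1)
cols = (k + np.arange(BLOCK_SIZE_K))[None, :]   # shape (1, BLOCK_SIZE_K)

# Broadcasting yields a (BLOCK_SIZE_M, BLOCK_SIZE_K) grid of flat offsets,
# exactly the set of addresses &amp;X[i, j] = X + i*stride_0 + j*stride_1.
offsets = rows * stride_0 + cols * stride_1
block = X.ravel()[offsets]
assert np.array_equal(block, X[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K])
</pre></div>
</div>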
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span>pid_m = triton.program_id(0)
pid_n = triton.program_id(1)
rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
rk = triton.arange(0, BLOCK_SIZE_K)
# pointer for A operand
pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1)
# pointer for B operand
</div></blockquote>
<p>And then updated in the inner loop as follows:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span>pa += BLOCK_SIZE_K * stride_a_1
pb += BLOCK_SIZE_K * stride_b_0
</pre></div>
</div>
</div></blockquote>
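<p>Continuing the NumPy sketch from above (again illustrative only, not the tutorial's API):
advancing the flat offsets by <code class="code docutils literal notranslate"><span class="pre">BLOCK_SIZE_K * stride</span></code> is equivalent to slicing the next
K-block, which is exactly what the inner loop relies on:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Reusing X, offsets, stride_1, m, k and the block sizes from the previous sketch:
# one pointer update steps the whole offset grid to the next block along K.
offsets_next = offsets + BLOCK_SIZE_K * stride_1
expected = X[m : m+BLOCK_SIZE_M, k+BLOCK_SIZE_K : k+2*BLOCK_SIZE_K]
assert np.array_equal(X.ravel()[offsets_next], expected)
</pre></div>
</div>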
</div>
<div class="section" id="l2-cache-optimizations">
|
||||
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline">¶</a></h3>
|
||||
<p>As mentioned above, each program instance computes an <code class="code docutils literal notranslate"><span class="pre">[BLOCK_M,</span> <span class="pre">BLOCK_N]</span></code> block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
|
||||
It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
|
||||
And unfortunately, a simple row-major ordering</p>
|
||||
<dl class="simple">
|
||||
<dt>As mentioned above, each program instance computes a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_SIZE_M,</span> <span class="pre">BLOCK_SIZE_N]</span></code></dt><dd><p>block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.</p>
|
||||
</dd>
|
||||
</dl>
|
||||
<p>It is important to remember that the order in which these blocks are computed does
|
||||
matter, since it affects the L2 cache hit rate of our program. and unfortunately, a
|
||||
a simple row-major ordering</p>
|
||||
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span>pid = triton.program_id(0)
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
pid_m = pid // grid_n
pid_n = pid % grid_n
</pre></div>
</div>
</div></blockquote>
<p>is just not going to cut it.</p>
<p>One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by ‘super-grouping’ blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_SIZE_M</span></code> rows before
switching to the next column:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span>pid = triton.program_id(0)
width = GROUP_SIZE_M * grid_n
group_id = pid // width
# we need to handle the case where M % (GROUP_SIZE_M*BLOCK_SIZE_M) != 0
group_size = min(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
pid_n = (pid % width) // group_size
</pre></div>
</div>
</div></blockquote>
<p>For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
we can see that if we compute the output in row-major ordering, we need to load 90
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
ordering, we only need to load 54 blocks.</p>
<img alt="" src="grouped_vs_row_major_ordering.png" />
<p>In practice, this can improve the performance of our matrix multiplication kernel by
more than 10% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).</p>
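<p>The 90-versus-54 figure is easy to reproduce with a short simulation. The sketch below
is illustrative only; the 9x9x9 grid of blocks and a group size of 3 are assumptions
chosen to match the example above, not values taken from the kernel. It maps the first
9 program IDs to output blocks under each ordering and counts the distinct A- and
B-blocks they touch:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Hypothetical sketch: count distinct A- and B-blocks loaded by the first
# 9 program instances under each ordering.
GRID_M, GRID_N, GRID_K = 9, 9, 9
GROUP_SIZE_M = 3

def row_major(pid):
    return pid // GRID_N, pid % GRID_N

def grouped(pid):
    width = GROUP_SIZE_M * GRID_N
    group_id = pid // width
    first_m = group_id * GROUP_SIZE_M
    group_size = min(GRID_M - first_m, GROUP_SIZE_M)
    return first_m + pid % group_size, (pid % width) // group_size

def distinct_blocks_loaded(ordering, n_pids=9):
    loads = set()
    for pid in range(n_pids):
        pid_m, pid_n = ordering(pid)
        loads |= {('A', pid_m, k) for k in range(GRID_K)}  # one row of A-blocks
        loads |= {('B', k, pid_n) for k in range(GRID_K)}  # one column of B-blocks
    return len(loads)

print(distinct_blocks_loaded(row_major))  # 90
print(distinct_blocks_loaded(grouped))    # 54
</pre></div>
</div>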
</div>
</div>
<div class="section" id="final-result">
<h2>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline">¶</a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||||
|
||||
<span class="c1"># %</span>
|
||||
<span class="c1"># :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:</span>
|
||||
<span class="c1"># - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try</span>
|
||||
<span class="c1"># - A autotuning *key* whose change in values will trigger evaluation of all the provided configs</span>
|
||||
<span class="c1"># :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`</span>
|
||||
<span class="c1"># decorator, which consumes:</span>
|
||||
<span class="c1"># - A list of :code:`triton.Config` objects that define different configurations of</span>
|
||||
<span class="c1"># meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try</span>
|
||||
<span class="c1"># - An autotuning *key* whose change in values will trigger evaluation of all the</span>
|
||||
<span class="c1"># provided configs</span>
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span>
|
||||
<span class="n">configs</span><span class="o">=</span><span class="p">[</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||||
<span class="c1">#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_SIZE_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_SIZE_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||||
<span class="p">],</span>
|
||||
<span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span>
|
||||
<span class="p">)</span>
|
||||
<span class="c1"># %</span>
|
||||
<span class="c1"># We can now define our kernel as normal, using all the techniques presented above</span>
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_matmul</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">stride_am</span><span class="p">,</span> <span class="n">stride_ak</span><span class="p">,</span> <span class="n">stride_bk</span><span class="p">,</span> <span class="n">stride_bn</span><span class="p">,</span> <span class="n">stride_cm</span><span class="p">,</span> <span class="n">stride_cn</span><span class="p">,</span> <span class="o">**</span><span class="n">META</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">matmul_kernel</span><span class="p">(</span>
|
||||
<span class="c1"># Pointers to matrices</span>
|
||||
<span class="n">a_ptr</span><span class="p">,</span>
|
||||
<span class="n">b_ptr</span><span class="p">,</span>
|
||||
<span class="n">c_ptr</span><span class="p">,</span>
|
||||
<span class="c1"># Matrix dimensions</span>
|
||||
<span class="n">M</span><span class="p">,</span>
|
||||
<span class="n">N</span><span class="p">,</span>
|
||||
<span class="n">K</span><span class="p">,</span>
|
||||
<span class="c1"># The stride variables represent how much to increase the ptr by when moving by 1</span>
|
||||
<span class="c1"># element in a particular dimension. E.g. stride_am is how much to increase a_ptr</span>
|
||||
<span class="c1"># by to get the element one row down (A has M rows)</span>
|
||||
<span class="n">stride_am</span><span class="p">,</span>
|
||||
<span class="n">stride_ak</span><span class="p">,</span>
|
||||
<span class="n">stride_bk</span><span class="p">,</span>
|
||||
<span class="n">stride_bn</span><span class="p">,</span>
|
||||
<span class="n">stride_cm</span><span class="p">,</span>
|
||||
<span class="n">stride_cn</span><span class="p">,</span>
|
||||
<span class="o">**</span><span class="n">meta</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="sd">"""Kernel for computing the matmul AB = C</span>
|
||||
|
||||
<span class="sd"> A has shape (M, K), B has shape (K, N) and C has shape (M, N)</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="c1"># extract meta-parameters</span>
|
||||
<span class="n">BLOCK_M</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_M'</span><span class="p">]</span>
|
||||
<span class="n">BLOCK_N</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_N'</span><span class="p">]</span>
|
||||
<span class="n">BLOCK_K</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_K'</span><span class="p">]</span>
|
||||
<span class="n">GROUP_M</span> <span class="o">=</span> <span class="mi">8</span>
|
||||
<span class="c1"># matrix multiplication</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_M</span>
|
||||
<span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_N</span>
|
||||
<span class="c1"># re-order program ID for better L2 performance</span>
|
||||
<span class="n">width</span> <span class="o">=</span> <span class="n">GROUP_M</span> <span class="o">*</span> <span class="n">grid_n</span>
|
||||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">width</span>
|
||||
<span class="n">group_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span><span class="p">,</span> <span class="n">GROUP_M</span><span class="p">)</span>
|
||||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">)</span>
|
||||
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">group_size</span><span class="p">)</span>
|
||||
<span class="c1"># do matrix multiplication</span>
|
||||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">rk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span>
|
||||
<span class="n">A</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_am</span> <span class="o">+</span> <span class="n">rk</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">)</span>
|
||||
<span class="n">B</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_bk</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_bn</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="n">BLOCK_K</span><span class="p">):</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||
<span class="n">A</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
|
||||
<span class="n">B</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
|
||||
<span class="c1"># triton can accept arbitrary activation function</span>
|
||||
<span class="c1"># via metaparameters!</span>
|
||||
<span class="k">if</span> <span class="n">META</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">]:</span>
|
||||
<span class="n">acc</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">](</span><span class="n">acc</span><span class="p">)</span>
|
||||
<span class="c1"># rematerialize rm and rn to save registers</span>
|
||||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">C</span> <span class="o">=</span> <span class="n">C</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_cm</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_cn</span><span class="p">)</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="n">acc</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="n">BLOCK_SIZE_M</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">]</span>
|
||||
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">]</span>
|
||||
<span class="n">BLOCK_SIZE_K</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_K'</span><span class="p">]</span>
|
||||
<span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">8</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># the number of blocks is the ceil(M / BLOCK_SIZE_M) since we need an extra block</span>
|
||||
<span class="c1"># Note that this will lead to some quantization in performance where time-taken jumps</span>
|
||||
<span class="c1"># when you need to add a new block</span>
|
||||
<span class="n">n_blocks_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_M</span>
|
||||
<span class="n">n_blocks_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_N</span>
|
||||
|
||||
<span class="c1"># Map PIDs to the block they should compute. This is done in a grouped ordering</span>
|
||||
<span class="c1"># to promote L2 cache reuse.</span>
|
||||
<span class="n">n_output_blocks_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">n_blocks_n</span>
|
||||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">n_output_blocks_in_group</span>
|
||||
<span class="n">first_m_block_in_group</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
|
||||
|
||||
<span class="c1"># If the number of blocks is not divisible by the group size, the last group is smaller</span>
|
||||
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">n_blocks_m</span> <span class="o">-</span> <span class="n">first_m_block_in_group</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Within a group, we compute in col-major ordering, block_m and block_n are the</span>
|
||||
<span class="c1"># output row and col that this program is computing in terms of blocks</span>
|
||||
<span class="n">block_m</span> <span class="o">=</span> <span class="n">first_m_block_in_group</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
|
||||
<span class="n">block_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">n_output_blocks_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
|
||||
|
||||
<span class="c1"># Convert from block indices back to element indices</span>
|
||||
<span class="n">m_start</span> <span class="o">=</span> <span class="n">block_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span>
|
||||
<span class="n">n_start</span> <span class="o">=</span> <span class="n">block_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span>
|
||||
|
||||
<span class="c1"># Expand out to all the offsets for each of the elements in this block.</span>
|
||||
<span class="n">m_offsets_a</span> <span class="o">=</span> <span class="p">(</span><span class="n">m_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">))[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="n">n_offsets_b</span> <span class="o">=</span> <span class="p">(</span><span class="n">n_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">))[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
||||
<span class="n">k_offsets</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Get the pointers for the first block of each. We will advance this pointer</span>
|
||||
<span class="c1"># as we move in the K direction and accumulate.</span>
|
||||
<span class="c1"># a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers</span>
|
||||
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">stride_am</span> <span class="o">*</span> <span class="n">m_offsets_a</span> <span class="o">+</span> <span class="n">stride_ak</span> <span class="o">*</span> <span class="n">k_offsets</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span>
|
||||
<span class="c1"># b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers</span>
|
||||
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">stride_bk</span> <span class="o">*</span> <span class="n">k_offsets</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">stride_bn</span> <span class="o">*</span> <span class="n">n_offsets_b</span><span class="p">)</span>
|
||||
<span class="c1"># We accumulate internally in fp32, but the output is written out in the dtype</span>
|
||||
<span class="c1"># of the tensor when it is stored</span>
|
||||
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">):</span>
|
||||
<span class="c1"># Note that for simplicity, we don't apply a mask here. This means that if K is</span>
|
||||
<span class="c1"># not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and</span>
|
||||
<span class="c1"># accumulate it incorrectly.</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">a_ptrs</span><span class="p">)</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">b_ptrs</span><span class="p">)</span>
|
||||
<span class="c1"># We accumulate along the K dimension</span>
|
||||
<span class="n">accumulator</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Advance the ptrs to the next K block</span>
|
||||
<span class="n">a_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
|
||||
<span class="n">b_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
|
||||
<span class="c1"># triton can accept arbitrary activation function via metaparameters!</span>
|
||||
<span class="k">if</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">]:</span>
|
||||
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">](</span><span class="n">accumulator</span><span class="p">)</span>
|
||||
|
||||
<span class="n">m_offsets_c</span> <span class="o">=</span> <span class="p">(</span><span class="n">m_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">))[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="n">n_offsets_c</span> <span class="o">=</span> <span class="p">(</span><span class="n">n_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">))[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
||||
<span class="n">c_ptrs</span> <span class="o">=</span> <span class="n">c_ptr</span> <span class="o">+</span> <span class="n">stride_cm</span> <span class="o">*</span> <span class="n">m_offsets_c</span> <span class="o">+</span> <span class="n">stride_cn</span> <span class="o">*</span> <span class="n">n_offsets_c</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">m_offsets_c</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">n_offsets_c</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">c_ptrs</span><span class="p">,</span> <span class="n">accumulator</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="c1"># we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`</span>
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mf">0.01</span><span class="o">*</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mf">0.01</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
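To make the grouped ordering concrete, here is a small plain-Python sketch (ours, not
part of the tutorial) that reproduces the pid -> (block_m, block_n) arithmetic above
and prints the order in which output tiles are visited; the function and parameter
names are our own:

def tile_for_pid(pid, n_blocks_m, n_blocks_n, group_size_m_max=2):
    # Same arithmetic as the kernel, with GROUP_SIZE_M shrunk to 2 for readability.
    n_output_blocks_in_group = group_size_m_max * n_blocks_n
    group_id = pid // n_output_blocks_in_group
    first_m_block_in_group = group_id * group_size_m_max
    group_size_m = min(n_blocks_m - first_m_block_in_group, group_size_m_max)
    block_m = first_m_block_in_group + (pid % group_size_m)
    block_n = (pid % n_output_blocks_in_group) // group_size_m
    return block_m, block_n

# For a 4x4 grid of output tiles, consecutive pids sweep a 2-row band column by
# column, so the same rows of A and columns of B stay hot in L2:
print([tile_for_pid(p, 4, 4) for p in range(8)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]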
We can now create a convenience wrapper function that only takes two input tensors
and (1) checks any shape constraints; (2) allocates the output; (3) launches the kernel.

@@ -385,17 +464,31 @@
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">"matrix A must be contiguous"</span>
|
||||
<span class="k">assert</span> <span class="n">b</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">"matrix B must be contiguous"</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">_</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">K</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="k">assert</span> <span class="p">(</span>
|
||||
<span class="n">K</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">==</span> <span class="mi">0</span>
|
||||
<span class="p">),</span> <span class="s2">"We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"</span>
|
||||
<span class="c1"># allocates output</span>
|
||||
<span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||||
<span class="c1"># launch kernel</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">META</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_M'</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_N'</span><span class="p">]),</span> <span class="p">)</span>
|
||||
<span class="n">pgm</span> <span class="o">=</span> <span class="n">_matmul</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||||
<span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> \
|
||||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>\
|
||||
<span class="n">ACTIVATION</span> <span class="o">=</span> <span class="n">activation</span>
|
||||
<span class="c1"># 1D launch kernel where each block gets its own program.</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">META</span><span class="p">:</span> <span class="p">(</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">]),</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">matmul_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||||
<span class="n">a</span><span class="p">,</span>
|
||||
<span class="n">b</span><span class="p">,</span>
|
||||
<span class="n">c</span><span class="p">,</span>
|
||||
<span class="n">M</span><span class="p">,</span>
|
||||
<span class="n">N</span><span class="p">,</span>
|
||||
<span class="n">K</span><span class="p">,</span>
|
||||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||||
<span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||||
<span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||||
<span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||||
<span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||||
<span class="n">ACTIVATION</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="c1"># done; return the output tensor</span>
|
||||
<span class="k">return</span> <span class="n">c</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
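The kernel deliberately skips masking along K, which is why the wrapper asserts
`K % 32 == 0`. If you wanted to lift that restriction, the inner loop could use masked
loads instead. The following is a sketch of ours, not part of the tutorial, and trades
a little speed for safety: out-of-range lanes load 0.0, which contributes nothing to
the dot product.

    for k in range(0, K, BLOCK_SIZE_K):
        k_mask = (k + k_offsets) < K  # shape (BLOCK_SIZE_K,)
        a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0)
        b = tl.load(b_ptrs, mask=k_mask[:, None], other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk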
@@ -406,15 +499,18 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b, activation=None)
torch_output = torch.matmul(a, b)
print(f"{triton_output=}")
print(f"{torch_output=}")
if triton.testing.allclose(triton_output, torch_output):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

Out:
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>triton_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||||
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
|
||||
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
|
||||
...,
|
||||
@@ -422,7 +518,7 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
||||
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||||
torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||||
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
|
||||
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
|
||||
...,
|
||||
@@ -430,7 +526,7 @@ tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||||
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor(True, device='cuda:0')
|
||||
✅ Triton and Torch match
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
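The exact fp16 values differ slightly between the two implementations (the accumulation
order is different), which is why the comparison uses `triton.testing.allclose` rather
than exact equality. If that helper is unavailable in your Triton version, a plain
PyTorch check along these lines should behave similarly; the tolerances here are our
illustrative assumptions for fp16, not values from the tutorial:

if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")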
@@ -442,14 +538,19 @@ tensor(True, device='cuda:0')
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
        x_vals=[
            128 * i for i in range(1, 33)
        ],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg`
        line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
        # label name for the lines
        line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel="TFLOPS",  # label name for the y-axis
        plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={},
    )
)
def benchmark(M, N, K, provider):
@@ -461,9 +562,13 @@
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
    if provider == 'cublas + relu':
        torch_relu = torch.nn.ReLU(inplace=True)
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch_relu(torch.matmul(a, b))
        )
    if provider == 'triton + relu':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: matmul(a, b, activation=leaky_relu)
        )
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)
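As a sanity check on the `perf` lambda: a matmul performs 2 * M * N * K floating-point
operations, so at M = N = K = 4096 that is 2 * 4096**3 ≈ 1.37e11 FLOPs, and a run time
of 1.5 ms corresponds to 1.37e11 / 1.5e-3 ≈ 91.6 TFLOPS, consistent with the last row
of the table below. A minimal sketch of how the decorated benchmark is then driven
(the `run` arguments follow `triton.testing`'s conventions; treat them as assumptions):

# Sweeps x_vals, times each provider, prints the table, and renders the plot.
benchmark.run(show_plots=True, print_data=True)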
@@ -476,42 +581,42 @@ tensor(True, device='cuda:0')
matmul-performance:
         M     cuBLAS  ...     Triton  Triton (+ LeakyReLU)
0    128.0   0.455111  ...   0.512000              0.512000
1    256.0   2.978909  ...   2.978909              2.978909
2    384.0   7.372800  ...   8.507077              7.899428
3    512.0  14.563555  ...  16.384000             15.420235
4    640.0  22.260869  ...  24.380953             24.380953
5    768.0  32.768000  ...  34.028308             34.028308
6    896.0  39.025776  ...  39.025776             35.123201
7   1024.0  49.932191  ...  52.428801             52.428801
8   1152.0  44.566925  ...  46.656000             46.656000
9   1280.0  51.200001  ...  56.888887             56.109587
10  1408.0  64.138541  ...  64.902096             64.902096
11  1536.0  78.643199  ...  76.106321             76.106321
12  1664.0  62.929456  ...  62.061463             62.061463
13  1792.0  72.983276  ...  69.810085             69.379162
14  1920.0  67.764707  ...  70.530615             70.530615
15  2048.0  73.908442  ...  75.234154             74.898285
16  2176.0  83.500614  ...  81.143743             81.143743
17  2304.0  68.446623  ...  73.501144             73.501144
18  2432.0  71.305746  ...  82.147552             82.147552
19  2560.0  77.833728  ...  77.283019             77.101175
20  2688.0  81.053536  ...  81.928846             83.922689
21  2816.0  81.981598  ...  79.443003             80.320825
22  2944.0  82.373605  ...  77.385141             78.112900
23  3072.0  81.472093  ...  83.761985             79.638683
24  3200.0  84.768213  ...  88.888888             85.561498
25  3328.0  83.905938  ...  87.794262             87.156532
26  3456.0  80.220468  ...  85.676480             84.068369
27  3584.0  86.707226  ...  95.553020             94.847460
28  3712.0  83.247783  ...  84.303780             85.309435
29  3840.0  80.255442  ...  83.339866             85.005380
30  3968.0  88.938731  ...  87.409694             87.159957
31  4096.0  91.616198  ...  89.597949             89.538177

[32 rows x 5 columns]
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 14.738 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 30.425 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||||
|
@@ -174,7 +174,7 @@
Computation times

03:54.665 total execution time for getting-started_tutorials files:

@@ -183,15 +183,15 @@

Matrix Multiplication (03-matrix-multiplication.py)    02:30.425    0.0 MB
Fused Softmax (02-fused-softmax.py)                    01:13.186    0.0 MB
Vector Addition (01-vector-add.py)                     00:11.055    0.0 MB