[GH-PAGES] Updated website
master/.buildinfo
Normal file
@@ -0,0 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 9fdb9e4639876051a20b8f90b2128cca
tags: 645f666f9bcd5a90fca523b33c5a78b7
BIN
master/.doctrees/environment.pickle
Normal file
BIN
master/.doctrees/getting-started/installation.doctree
Normal file
BIN
master/.doctrees/getting-started/tutorials/01-vector-add.doctree
Normal file
BIN
master/.doctrees/getting-started/tutorials/05-layer-norm.doctree
Normal file
BIN
master/.doctrees/getting-started/tutorials/index.doctree
Normal file
BIN
master/.doctrees/index.doctree
Normal file
BIN
master/.doctrees/python-api/generated/triton.Config.doctree
Normal file
BIN
master/.doctrees/python-api/generated/triton.autotune.doctree
Normal file
BIN
master/.doctrees/python-api/generated/triton.heuristics.doctree
Normal file
BIN
master/.doctrees/python-api/generated/triton.jit.doctree
Normal file
BIN
master/.doctrees/python-api/triton.doctree
Normal file
BIN
master/.doctrees/python-api/triton.language.doctree
Normal file
BIN
master/.doctrees/python-api/triton.testing.doctree
Normal file
@@ -0,0 +1,161 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation that is significantly faster\nthan PyTorch's native op for a particular class of matrices: those whose rows can fit in\nthe GPU's SRAM.\nYou will learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- Reduction operators in Triton.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Motivations\nCustom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.\nLet us consider instead the case of a simple (numerically stabilized) softmax operation:\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@torch.jit.script\ndef naive_softmax(x):\n    \"\"\"Compute row-wise softmax of X using native pytorch\n\n    We subtract the maximum element in order to avoid overflows. Softmax is invariant to\n    this shift.\n    \"\"\"\n    # read MN elements ; write M elements\n    x_max = x.max(dim=1)[0]\n    # read MN + M elements ; write MN elements\n    z = x - x_max[:, None]\n    # read MN elements ; write MN elements\n    numerator = torch.exp(z)\n    # read MN elements ; write M elements\n    denominator = numerator.sum(dim=1)\n    # read MN + M elements ; write MN elements\n    ret = numerator / denominator[:, None]\n    # in total: read 5MN + 2M elements ; write 3MN + 2M elements\n    return ret"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$\nrequires reading $5MN + 2M$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads\nX once and does all the necessary computations on-chip.\nDoing so would require reading and writing back only $MN$ elements, so we could\nexpect a theoretical speed-up of ~4x (i.e., $(8MN + 4M) / 2MN$).\nThe `torch.jit.script` flag aims to perform this kind of \"kernel fusion\" automatically\nbut, as we will see later, it is still far from ideal.\n\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input matrix X,\nnormalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a\npower-of-two number of elements, so we need to internally \"pad\" each row and guard the\nmemory operations properly if we want to handle any possible input shapes:\n\n"
   ]
  },
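  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# A quick illustration of the padding rule above (an added sketch, not part\n# of the original tutorial): a hypothetical row of 1000 columns is padded to\n# the next power of two, 1024, and the 24 extra lanes are masked off with\n# `-inf` so they affect neither the row max nor the sum of exponentials.\nimport triton\nprint(triton.next_power_of_2(1000))  # 1024"
   ]
  },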
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "@triton.jit\ndef softmax_kernel(\n    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,\n    BLOCK_SIZE: tl.constexpr\n):\n    # The rows of the softmax are independent, so we parallelize across those\n    row_idx = tl.program_id(0)\n    # The stride represents how much we need to increase the pointer to advance 1 row\n    row_start_ptr = input_ptr + row_idx * input_row_stride\n    # The block size is the next power of two greater than n_cols, so we can fit each\n    # row in a single block\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    input_ptrs = row_start_ptr + col_offsets\n    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n    # Subtract maximum for numerical stability\n    row_minus_max = row - tl.max(row, axis=0)\n    # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n    numerator = tl.exp(row_minus_max)\n    denominator = tl.sum(numerator, axis=0)\n    softmax_output = numerator / denominator\n    # Write back output to DRAM\n    output_row_start_ptr = output_ptr + row_idx * output_row_stride\n    output_ptrs = output_row_start_ptr + col_offsets\n    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def softmax(x):\n    n_rows, n_cols = x.shape\n    # The block size is the smallest power of two greater than the number of columns in `x`\n    BLOCK_SIZE = triton.next_power_of_2(n_cols)\n    # Another trick we can use is to ask the compiler to use more threads per row by\n    # increasing the number of warps (`num_warps`) over which each row is distributed.\n    # You will see in the next tutorial how to auto-tune this value in a more natural\n    # way so you don't have to come up with manual heuristics yourself.\n    num_warps = 4\n    if BLOCK_SIZE >= 2048:\n        num_warps = 8\n    if BLOCK_SIZE >= 4096:\n        num_warps = 16\n    # Allocate output\n    y = torch.empty_like(x)\n    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row\n    # of the input matrix\n    softmax_kernel[(n_rows,)](\n        y,\n        x,\n        x.stride(0),\n        y.stride(0),\n        n_cols,\n        num_warps=num_warps,\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unit Test\n\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We make sure that we test our kernel on a matrix with an irregular number of rows and columns.\nThis will allow us to verify that our padding mechanism works.\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\nx = torch.randn(1823, 781, device='cuda')\ny_triton = softmax(x)\ny_torch = torch.softmax(x, axis=1)\nassert torch.allclose(y_triton, y_torch), (y_triton, y_torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, the results are identical.\n\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark\nHere we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.\nWe will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=['N'],  # argument names to use as an x-axis for the plot\n        x_vals=[\n            128 * i for i in range(2, 100)\n        ],  # different possible values for `x_name`\n        line_arg='provider',  # argument name whose value corresponds to a different line in the plot\n        line_vals=[\n            'triton',\n            'torch-native',\n            'torch-jit',\n        ],  # possible values for `line_arg`\n        line_names=[\n            \"Triton\",\n            \"Torch (native)\",\n            \"Torch (jit)\",\n        ],  # label name for the lines\n        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles\n        ylabel=\"GB/s\",  # label name for the y-axis\n        plot_name=\"softmax-performance\",  # name for the plot. Used also as a file name for saving the plot.\n        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`\n    )\n)\ndef benchmark(M, N, provider):\n    x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n    if provider == 'torch-native':\n        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))\n    if provider == 'triton':\n        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))\n    if provider == 'torch-jit':\n        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))\n    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n    return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbenchmark.run(show_plots=True, print_data=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above plot, we can see that:\n\n - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.\n - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.\n   Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.\n\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
@@ -0,0 +1,131 @@
"""
Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:

- The basic programming model of Triton
- The `triton.jit` decorator, which is used to define Triton kernels.
- The best practices for validating and benchmarking your custom ops against native reference implementations
"""

# %%
# Compute Kernel
# --------------------------

import torch

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
    # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and a block size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)


# %%
# Let's also declare a helper function to (1) allocate the `output` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes.


def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks (see the sketch after this function).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # NOTE:
    #  - each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - don't forget to pass meta-parameters as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to `output` but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output
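
# %%
# As a quick illustration (an added sketch, not part of the original tutorial):
# for a hypothetical vector of 98432 elements and ``BLOCK_SIZE=1024``, the grid
# above resolves to ``ceil(98432 / 1024) = 97`` program instances.

example_grid = (triton.cdiv(98432, 1024),)  # hypothetical size, for illustration only
print(example_grid)  # (97,)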

# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
    f'The maximum difference between torch and triton is '
    f'{torch.max(torch.abs(output_torch - output_triton))}'
)

# %%
# Seems like we're good to go!

# %%
# Benchmark
# -----------
# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
# for different problem sizes.


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # argument names to use as an x-axis for the plot
        x_vals=[
            2 ** i for i in range(12, 28, 1)
        ],  # different possible values for `x_name`
        x_log=True,  # x axis is logarithmic
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=['Triton', 'Torch'],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel='GB/s',  # label name for the y-axis
        plot_name='vector-add-performance',  # name for the plot. Used also as a file name for saving the plot.
        args={},  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
    # 12 bytes moved per element: x and y are each read (4 bytes each) and the output is written (4 bytes)
    gbps = lambda ms: 12 * size / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)


# %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance numbers, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/'` to save them to disk along with raw CSV data
benchmark.run(print_data=True, show_plots=True)
@@ -0,0 +1,311 @@
"""
Layer Normalization
====================
"""

import torch

import triton
import triton.language as tl

try:
    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPI, so it
    # should not be added to extras_require in setup.py.
    import apex
    HAS_APEX = True
except ModuleNotFoundError:
    HAS_APEX = False


@triton.jit
def _layer_norm_fwd_fused(
    Out,
    A,
    Weight,
    Bias,
    Mean, Rstd,
    stride, N, eps,
    BLOCK_SIZE: tl.constexpr,
):
    # position of elements processed by this program
    row = tl.program_id(0)
    Out += row * stride
    A += row * stride
    # compute mean
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
        a = tl.where(cols < N, a - mean, 0.)
        _var += a * a
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # write-back mean/rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # multiply by weight and add bias
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        weight = tl.load(Weight + cols, mask=mask)
        bias = tl.load(Bias + cols, mask=mask)
        a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
        # write-back
        tl.store(Out + cols, out, mask=mask)
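
# In other words, the fused kernel above computes, for each row `a` of `A`:
#   mean = sum(a) / N
#   var  = sum((a - mean)^2) / N
#   rstd = 1 / sqrt(var + eps)
#   out  = (a - mean) * rstd * weight + bias
# i.e. the standard layer-norm forward pass, fused into a single kernel launch per batch.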


# Backward pass (DA + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(
    _DA,
    _DOut,
    _A,
    Weight,
    Mean, Rstd,
    stride, NumRows, NumCols, eps,
    BLOCK_SIZE_N: tl.constexpr,
):
    # position of elements processed by this program
    pid = tl.program_id(0)
    row = pid
    A = _A + row * stride
    DOut = _DOut + row * stride
    DA = _DA + row * stride
    mean = tl.load(Mean + row)
    rstd = tl.load(Rstd + row)
    # load data to SRAM
    _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
    _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
    for off in range(0, NumCols, BLOCK_SIZE_N):
        cols = off + tl.arange(0, BLOCK_SIZE_N)
        mask = cols < NumCols
        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
        a_hat = (a - mean) * rstd
        wdout = weight * dout
        _mean1 += a_hat * wdout
        _mean2 += wdout
    mean1 = tl.sum(_mean1, axis=0) / NumCols
    mean2 = tl.sum(_mean2, axis=0) / NumCols
    for off in range(0, NumCols, BLOCK_SIZE_N):
        cols = off + tl.arange(0, BLOCK_SIZE_N)
        mask = cols < NumCols
        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
        a_hat = (a - mean) * rstd
        wdout = weight * dout
        da = (wdout - (a_hat * mean1 + mean2)) * rstd
        # write-back dx
        tl.store(DA + cols, da, mask=mask)
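
# The second loop above applies the standard layer-norm input gradient:
#   da = (w * dout - (a_hat * mean(a_hat * w * dout) + mean(w * dout))) * rstd
# where the two row-wise means were accumulated in the first loop.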


# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
    A, DOut,
    Mean, Var,
    DW,
    DB,
    M, N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
        a_hat = (a - mean[:, None]) * rstd[:, None]
        dw += dout * a_hat
        db += dout
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(DW + cols, sum_dw, mask=cols < N)
    tl.store(DB + cols, sum_db, mask=cols < N)


class LayerNorm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, normalized_shape, weight, bias, eps):
        # allocate output
        out = torch.empty_like(a)
        # reshape input data into 2D tensor
        a_arg = a.reshape(-1, a.shape[-1])
        M, N = a_arg.shape
        mean = torch.empty((M,), dtype=torch.float32, device="cuda")
        rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // a.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        BLOCK_SIZE = max(BLOCK_SIZE, 128)
        BLOCK_SIZE = min(BLOCK_SIZE, 4096)
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        _layer_norm_fwd_fused[(M,)](
            out,
            a_arg,
            weight,
            bias,
            mean, rstd,
            a_arg.stride(0), N, eps,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        ctx.save_for_backward(
            a, weight, bias, mean, rstd,
        )
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.eps = eps
        if hasattr(bias, "config"):
            assert bias.config.grad_scale_name == weight.config.grad_scale_name
            grad_scale_name = bias.config.grad_scale_name
        else:
            grad_scale_name = None
        ctx.grad_scale_gain_bias_name = grad_scale_name
        return out

    @staticmethod
    def backward(ctx, dout):
        assert dout.is_contiguous()
        a, weight, bias, mean, var = ctx.saved_tensors
        # heuristics for the amount of parallel reduction streams for DW/DB
        N = weight.shape[0]
        # allocate output
        da = torch.empty_like(dout)
        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        x_arg = a.reshape(-1, a.shape[-1])
        M, N = x_arg.shape
        dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
        dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
        _layer_norm_bwd_dx_fused[(M,)](
            da,
            dout,
            a,
            weight,
            mean, var,
            x_arg.stride(0), M, N,
            ctx.eps,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE,
            num_warps=ctx.num_warps,
        )
        # accumulate partial sums in separate kernel
        grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
        _layer_norm_bwd_dwdb[grid](
            a, dout,
            mean, var,
            dweight,
            dbias,
            M,
            N,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_N=128,
        )
        return (da, None, dweight, dbias, None, None,
                None, None, None, None,
                None,
                None, None, None,
                None,
                None, None, None,
                None, None, None,
                None, None, None)


def layer_norm(a, normalized_shape, weight, bias, eps):
    return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
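
# A minimal usage sketch (hypothetical shapes, for illustration only;
# `test_layer_norm` below exercises the same path more thoroughly):
#
#   x = torch.randn(16, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
#   w = torch.ones(1024, dtype=torch.float16, device='cuda', requires_grad=True)
#   b = torch.zeros(1024, dtype=torch.float16, device='cuda', requires_grad=True)
#   y = layer_norm(x, (1024,), w, b, 1e-5)
#   y.backward(torch.ones_like(y))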


def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
    torch.manual_seed(0)
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # forward pass
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # backward pass (triton)
    y_tri.backward(dy, retain_graph=True)
    dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
    x.grad, weight.grad, bias.grad = None, None, None
    # backward pass (torch)
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
    # compare
    triton.testing.assert_almost_equal(y_tri, y_ref)
    triton.testing.assert_almost_equal(dx_tri, dx_ref)
    triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
    triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[512 * i for i in range(2, 32)],
        line_arg='provider',
        line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
        line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
        ylabel='GB/s',
        plot_name='layer-norm',
        args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
    )
)
def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # utility functions
    if provider == 'triton':
        y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
    if provider == 'torch':
        y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
    if provider == 'apex':
        apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
        y_fwd = lambda: apex_layer_norm(x)
    # forward pass
    if mode == 'forward':
        gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
    # backward pass
    if mode == 'backward':
        gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
        y = y_fwd()
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                     grad_to_none=[x], rep=500)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


# test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)
@@ -0,0 +1,100 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n# Low-Memory Dropout\n\nIn this tutorial, you will write a memory-efficient implementation of dropout whose state\nwill be composed of a single int32 seed. This differs from more traditional implementations of dropout,\nwhose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:\n\n- The limitations of naive implementations of Dropout with PyTorch\n- Parallel pseudo-random number generation in Triton\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline\nThe *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance\nof deep neural networks in low-data regimes (i.e. regularization).\n\nIt takes a vector as input and produces a vector of the same shape as output. Each scalar in the\noutput has a probability $p$ of being changed to zero and otherwise it is copied from the input.\nThis forces the network to perform well even when only $1 - p$ scalars from the input are available.\n\nAt evaluation time we want to use the full power of the network so we set $p=0$. Naively this would\nincrease the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease\nin the output softmax temperature). To prevent this we multiply the output by $\\frac{1}{1 - p}$, which\nkeeps the norm consistent regardless of the dropout probability.\n\nLet's first take a look at the baseline implementation.\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tabulate\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _dropout(\n    x_ptr,  # pointer to the input\n    x_keep_ptr,  # pointer to a mask of 0s and 1s\n    output_ptr,  # pointer to the output\n    n_elements,  # number of elements in the `x` tensor\n    p,  # probability that an element of `x` is changed to zero\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    block_start = pid * BLOCK_SIZE\n    offsets = block_start + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n    # Load data\n    x = tl.load(x_ptr + offsets, mask=mask)\n    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n    # The line below is the crucial part, described in the paragraph above!\n    output = tl.where(x_keep, x / (1 - p), 0.0)\n    # Write-back output\n    tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n    output = torch.empty_like(x)\n    assert x.is_contiguous()\n    n_elements = x.numel()\n    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n    return output\n\n\n# Input tensor\nx = torch.randn(size=(10,)).cuda()\n# Dropout mask\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\noutput = dropout(x, x_keep=x_keep, p=p)\nprint(tabulate.tabulate([\n    [\"input\"] + x.tolist(),\n    [\"keep mask\"] + x_keep.tolist(),\n    [\"output\"] + output.tolist()\n]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seeded dropout\nThe above implementation of dropout works fine, but it can be a bit awkward to deal with. First,\nwe need to store the dropout mask for backpropagation. Second, dropout state management can get\nvery tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in\nhttps://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation\nthat (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management\nof persisting randomness across multiple invocations of the kernel.\n\nPseudorandom number generation in Triton is simple! In this tutorial we will use the\n:code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`\nvalues in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides\nother `random number generation strategies <Random Number Generation>`.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_).</p></div>\n\nLet's put it all together.\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "@triton.jit\ndef _seeded_dropout(\n    x_ptr,\n    output_ptr,\n    n_elements,\n    p,\n    seed,\n    BLOCK_SIZE: tl.constexpr,\n):\n    # compute memory offsets of elements handled by this instance\n    pid = tl.program_id(axis=0)\n    block_start = pid * BLOCK_SIZE\n    offsets = block_start + tl.arange(0, BLOCK_SIZE)\n    # load data from x\n    mask = offsets < n_elements\n    x = tl.load(x_ptr + offsets, mask=mask)\n    # randomly prune it\n    random = tl.rand(seed, offsets)\n    x_keep = random > p\n    # write-back\n    output = tl.where(x_keep, x / (1 - p), 0.0)\n    tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n    output = torch.empty_like(x)\n    assert x.is_contiguous()\n    n_elements = x.numel()\n    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n    return output\n\n\nx = torch.randn(size=(10,)).cuda()\n# Compare this to the baseline - dropout mask is never instantiated!\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\n\nprint(tabulate.tabulate([\n    [\"input\"] + x.tolist(),\n    [\"output (seed = 123)\"] + output.tolist(),\n    [\"output (seed = 123)\"] + output2.tolist(),\n    [\"output (seed = 512)\"] + output3.tolist()\n]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Et Voil\u00e0! We have a Triton kernel that applies the same dropout mask provided the seed is the same!\nIf you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you\nto explore the `triton/language/random` folder!\n\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercises\n1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.\n2. Add support for striding.\n3. (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.\n\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n\n.. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, \"Parallel Random Numbers: As Easy as 1, 2, 3\", 2011\n.. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, \"Dropout: A Simple Way to Prevent Neural Networks from Overfitting\", JMLR 2014\n\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
@@ -0,0 +1,166 @@
"""
Low-Memory Dropout
==================

In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:

- The limitations of naive implementations of Dropout with PyTorch
- Parallel pseudo-random number generation in Triton
"""

# %%
# Baseline
# -------------
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
# of deep neural networks in low-data regimes (i.e. regularization).
#
# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
# output has a probability :math:`p` of being changed to zero and otherwise it is copied from the input.
# This forces the network to perform well even when only :math:`1 - p` scalars from the input are available.
#
# At evaluation time we want to use the full power of the network so we set :math:`p=0`. Naively this would
# increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
# in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which
# keeps the norm consistent regardless of the dropout probability.
#
# Let's first take a look at the baseline implementation.


import tabulate
import torch

import triton
import triton.language as tl


@triton.jit
def _dropout(
    x_ptr,  # pointer to the input
    x_keep_ptr,  # pointer to a mask of 0s and 1s
    output_ptr,  # pointer to the output
    n_elements,  # number of elements in the `x` tensor
    p,  # probability that an element of `x` is changed to zero
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist()
]))
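
# %%
# A quick sanity check on the scaling above (an added sketch, not part of the
# original tutorial): each kept element is multiplied by :math:`\frac{1}{1 - p}`,
# so in expectation :math:`E[\mathrm{output}] = (1 - p) \cdot \frac{x}{1 - p} + p \cdot 0 = x`,
# i.e. dropout preserves the expected value of every input element. Empirically:

big = torch.randn(size=(100000,)).cuda()  # hypothetical large input, for illustration only
big_keep = (torch.rand(size=(100000,)) > p).to(torch.int32).cuda()
print(big.mean().item(), dropout(big, big_keep, p).mean().item())  # both close to 0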

# %%
# Seeded dropout
# -------------
# The above implementation of dropout works fine, but it can be a bit awkward to deal with. First,
# we need to store the dropout mask for backpropagation. Second, dropout state management can get
# very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
# https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
# that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
# of persisting randomness across multiple invocations of the kernel.
#
# Pseudorandom number generation in Triton is simple! In this tutorial we will use the
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
# other :ref:`random number generation strategies <Random Number Generation>`.
#
# .. note::
#    Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_).
#
# Let's put it all together.


@triton.jit
def _seeded_dropout(
    x_ptr,
    output_ptr,
    n_elements,
    p,
    seed,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (seed = 123)"] + output.tolist(),
    ["output (seed = 123)"] + output2.tolist(),
    ["output (seed = 512)"] + output3.tolist()
]))

# %%
# Et Voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
# If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
# to explore the `triton/language/random` folder!

# %%
# Exercises
# -------------
# 1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.
# 2. Add support for striding.
# 3. (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.

# %%
# References
# --------------
#
# .. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
# .. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014
@@ -0,0 +1,356 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
kernel that achieves performance on par with cuBLAS.
You will specifically learn about:

- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
"""

# %%
# Motivations
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is generally done by
# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be easily customized
# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
# In this tutorial, you will learn how to implement efficient matrix multiplications by
# yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked
# algorithm to multiply a (M, K) by a (K, N) matrix:
#
# .. code-block:: python
#
#     # do in parallel
#     for m in range(0, M, BLOCK_SIZE_M):
#       # do in parallel
#       for n in range(0, N, BLOCK_SIZE_N):
#         acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
#         for k in range(0, K, BLOCK_SIZE_K):
#           a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
#           b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
#           acc += dot(a, b)
#         C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
#
# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
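#
# As a reference (an added sketch, not part of the original tutorial), the same
# blocked loop nest can be written directly in NumPy; the Triton kernel below is
# a parallel, on-chip version of exactly this computation:
#
# .. code-block:: python
#
#     import numpy as np
#
#     def blocked_matmul(A, B, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16):
#         M, K = A.shape
#         _, N = B.shape
#         C = np.zeros((M, N), dtype=np.float32)
#         for m in range(0, M, BLOCK_M):          # "do in parallel"
#             for n in range(0, N, BLOCK_N):      # "do in parallel"
#                 acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
#                 for k in range(0, K, BLOCK_K):
#                     acc += A[m:m + BLOCK_M, k:k + BLOCK_K] @ B[k:k + BLOCK_K, n:n + BLOCK_N]
#                 C[m:m + BLOCK_M, n:n + BLOCK_N] = acc
#         return C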

# %%
# Compute Kernel
# ----------------
#
# The above algorithm is, actually, fairly straightforward to implement in Triton.
# The main difficulty comes from the computation of the memory locations at which blocks
# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
# multi-dimensional pointer arithmetic.
#
# Pointer Arithmetic
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given
# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
#    &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
#    &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
#
# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
#
# .. code-block:: python
#
#    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
#    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
#    offs_k = tl.arange(0, BLOCK_SIZE_K)
#    a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
#    b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
#
# And then updated in the inner loop as follows:
#
# .. code-block:: python
#
#    a_ptrs += BLOCK_SIZE_K * stride_ak;
#    b_ptrs += BLOCK_SIZE_K * stride_bk;
#
#
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
# block of :code:`C`.
# It is important to remember that the order in which these blocks are computed does
# matter, since it affects the L2 cache hit rate of our program, and unfortunately,
# a simple row-major ordering
#
# .. code-block:: python
#
#    pid = triton.program_id(0);
#    grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
#    grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
#    pid_m = pid / grid_n;
#    pid_n = pid % grid_n;
#
# is just not going to cut it.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before
# switching to the next column:
#
# .. code-block:: python
#
#    # program ID
#    pid = tl.program_id(axis=0)
#    # number of program ids along the M axis
#    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
#    # number of program ids along the N axis
#    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
#    # number of programs in group
#    num_pid_in_group = GROUP_SIZE_M * num_pid_n
#    # id of the group this program is in
#    group_id = pid // num_pid_in_group
#    # row-id of the first program in the group
#    first_pid_m = group_id * GROUP_SIZE_M
#    # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
#    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
#    # *within groups*, programs are ordered in a column-major order
#    # row-id of the program in the *launch grid*
#    pid_m = first_pid_m + (pid % group_size_m)
#    # col-id of the program in the *launch grid*
#    pid_n = (pid % num_pid_in_group) // group_size_m
#
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
# we can see that if we compute the output in row-major ordering, we need to load 90
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
# ordering, we only need to load 54 blocks.
#
# .. image:: grouped_vs_row_major_ordering.png
#
# In practice, this can improve the performance of our matrix multiplication kernel by
# more than 10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
#
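# A concrete trace of the grouped ordering above (an added illustration, not
# part of the original tutorial): with :code:`num_pid_m = num_pid_n = 9` and
# :code:`GROUP_SIZE_M = 3`, we get :code:`num_pid_in_group = 27`, and the first
# program ids map to the 3x9 group of block-rows in column-major order:
# :code:`pid=0 -> (pid_m=0, pid_n=0)`, :code:`pid=1 -> (1, 0)`,
# :code:`pid=2 -> (2, 0)`, :code:`pid=3 -> (0, 1)`, and so on.
#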

# %%
# Final Result
# -------------
#

import torch

import triton
import triton.language as tl

# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
# decorator, which consumes:
#   - A list of :code:`triton.Config` objects that define different configurations of
#     meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
#   - An autotuning *key* whose change in values will trigger evaluation of all the
#     provided configs


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section above for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetic` section above for details.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here.
        # This means that if K is not a multiple of BLOCK_SIZE_K,
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION:
        accumulator = ACTIVATION(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as the `ACTIVATION` meta-parameter of `matmul_kernel`
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)
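
# For instance (an added illustration): once the `matmul` wrapper below is
# defined, calling `matmul(a, b, activation=leaky_relu)` runs the fused kernel
# with the activation applied to the FP32 accumulator.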
|
||||
|
||||
|
||||
# %%
|
||||
# We can now create a convenience wrapper function that only takes two input tensors
|
||||
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
||||
|
||||
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
assert a.is_contiguous(), "matrix A must be contiguous"
|
||||
assert b.is_contiguous(), "matrix B must be contiguous"
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
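# NOTE: every config passed to the autotuner above uses BLOCK_SIZE_K = 32,
|
||||
# hence the hard-coded 32 in the check below
|
||||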
assert (
|
||||
K % 32 == 0
|
||||
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
||||
)
|
||||
matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
ACTIVATION=activation,
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
# %%
|
||||
# Unit Test
|
||||
# -----------
|
||||
#
|
||||
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
|
||||
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
triton_output = matmul(a, b)
|
||||
torch_output = torch.matmul(a, b)
|
||||
print(f"triton_output={triton_output}")
|
||||
print(f"torch_output={torch_output}")
|
||||
if triton.testing.allclose(triton_output, torch_output):
|
||||
print("✅ Triton and Torch match")
|
||||
else:
|
||||
print("❌ Triton and Torch differ")
|
||||
|
||||
# %%
|
||||
# Benchmark
|
||||
# --------------
|
||||
#
|
||||
# Square Matrix Performance
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 33)
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
# possible values for `line_arg`
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
if provider == 'cublas + relu':
|
||||
torch_relu = torch.nn.ReLU(inplace=True)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch_relu(torch.matmul(a, b))
|
||||
)
|
||||
if provider == 'triton + relu':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b, activation=leaky_relu)
|
||||
)
|
||||
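# a matmul performs 2 * M * N * K flops (one multiply and one add per
|
||||
# accumulation step); dividing by the runtime in seconds yields TFLOPS
|
||||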
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True, print_data=True)
|
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster
|
||||
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
|
||||
the GPU's SRAM.
|
||||
You will learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- Reduction operators in Triton.
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Motivations
|
||||
# ------------
|
||||
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def naive_softmax(x):
|
||||
"""Compute row-wise softmax of X using native pytorch
|
||||
|
||||
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
|
||||
this shift.
|
||||
"""
|
||||
# read MN elements ; write M elements
|
||||
x_max = x.max(dim=1)[0]
|
||||
# read MN + M elements ; write MN elements
|
||||
z = x - x_max[:, None]
|
||||
# read MN elements ; write MN elements
|
||||
numerator = torch.exp(z)
|
||||
# read MN elements ; write M elements
|
||||
denominator = numerator.sum(dim=1)
|
||||
# read MN + M elements ; write MN elements
|
||||
ret = numerator / denominator[:, None]
|
||||
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
|
||||
return ret
|
||||
|
||||
|
||||
# %%
|
||||
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
|
||||
# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
|
||||
# X once and does all the necessary computations on-chip.
|
||||
# Doing so would require reading and writing back only :math:`MN` elements, so we could
|
||||
# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
|
||||
# The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
|
||||
# but, as we will see later, it is still far from ideal.
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# ----------------
|
||||
# Our softmax kernel works as follows: each program loads a row of the input matrix X,
|
||||
# normalizes it and writes back the result to the output Y.
|
||||
# Note that one important limitation of Triton is that each block must have a
|
||||
# power-of-two number of elements, so we need to internally "pad" each row and guard the
|
||||
# memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_kernel(
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
# The rows of the softmax are independent, so we parallelize across those
|
||||
row_idx = tl.program_id(0)
|
||||
# The stride represents how much we need to increase the pointer to advance 1 row
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
# The block size is the next power of two greater than n_cols, so we can fit each
|
||||
# row in a single block
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + col_offsets
|
||||
# Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
|
||||
# Subtract maximum for numerical stability
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
# Write back output to DRAM
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
output_ptrs = output_row_start_ptr + col_offsets
|
||||
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
|
||||
|
||||
|
||||
# %%
|
||||
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
|
||||
|
||||
def softmax(x):
|
||||
n_rows, n_cols = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
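# (e.g., the 781-column matrix used in the unit test below is padded to BLOCK_SIZE = 1024)
|
||||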
BLOCK_SIZE = triton.next_power_of_2(n_cols)
|
||||
# Another trick we can use is to ask the compiler to use more threads per row by
|
||||
# increasing the number of warps (`num_warps`) over which each row is distributed.
|
||||
# You will see in the next tutorial how to auto-tune this value in a more natural
|
||||
# way so you don't have to come up with manual heuristics yourself.
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
# Allocate output
|
||||
y = torch.empty_like(x)
|
||||
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of
|
||||
# the input matrix
|
||||
softmax_kernel[(n_rows,)](
|
||||
y,
|
||||
x,
|
||||
x.stride(0),
|
||||
y.stride(0),
|
||||
n_cols,
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
# %%
|
||||
# Unit Test
|
||||
# ----------
|
||||
|
||||
# %%
|
||||
# We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
|
||||
# This will allow us to verify that our padding mechanism works.
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(1823, 781, device='cuda')
|
||||
y_triton = softmax(x)
|
||||
y_torch = torch.softmax(x, axis=1)
|
||||
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
|
||||
|
||||
# %%
|
||||
# As expected, the results are identical.
|
||||
|
||||
# %%
|
||||
# Benchmark
|
||||
# -------------
|
||||
# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
|
||||
# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 100)
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=[
|
||||
'triton',
|
||||
'torch-native',
|
||||
'torch-jit',
|
||||
], # possible values for `line_arg`
|
||||
line_names=[
|
||||
"Triton",
|
||||
"Torch (native)",
|
||||
"Torch (jit)",
|
||||
], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
if provider == 'torch-native':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
|
||||
if provider == 'torch-jit':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
|
||||
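# the fused kernel reads MN elements once and writes MN elements back,
|
||||
# i.e., 2 * numel * element_size bytes of traffic in total
|
||||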
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True, print_data=True)
|
||||
|
||||
# %%
|
||||
# In the above plot, we can see that:
|
||||
#
|
||||
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
|
||||
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
|
||||
# Note however that the PyTorch `softmax` operation is more general and works on tensors of any shape.
|
@@ -0,0 +1,140 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Vector Addition\nIn this tutorial, you will write a simple vector addition using Triton and learn about:\n\n- The basic programming model of Triton\n- The `triton.jit` decorator, which is used to define Triton kernels.\n- The best practices for validating and benchmarking your custom ops against native reference implementations\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process\n # NOTE: `constexpr` so it can be used as a shape value\n):\n # There are multiple 'program's processing different data. We identify which program\n # we are here\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n # This program will process inputs that are offset from the initial data.\n # for instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extra elements in case the input is not a\n # multiple of the block size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM\n tl.store(output_ptr + offsets, output, mask=mask)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's also declare a helper function to (1) allocate the `z` tensor\nand (2) enqueue the above kernel with appropriate grid/block sizes.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def add(x: torch.Tensor, y: torch.Tensor):\n # We need to preallocate the output\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n # In this case, we use a 1D grid where the size is the number of blocks\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n # NOTE:\n # - each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return output"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Seems like we're good to go!\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Benchmark\nWe can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.\nTo make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops\nfor different problem sizes.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['size'], # argument names to use as an x-axis for the plot\n x_vals=[\n 2 ** i for i in range(12, 28, 1)\n ], # different possible values for `x_name`\n x_log=True, # x axis is logarithmic\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['triton', 'torch'], # possible values for `line_arg`\n line_names=['Triton', 'Torch'], # label name for the lines\n styles=[('blue', '-'), ('green', '-')], # line styles\n ylabel='GB/s', # label name for the y-axis\n plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.\n args={}, # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(size, provider):\n x = torch.rand(size, device='cuda', dtype=torch.float32)\n y = torch.rand(size, device='cuda', dtype=torch.float32)\n if provider == 'torch':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))\n gbps = lambda ms: 12 * size / ms * 1e-6\n return gbps(ms), gbps(max_ms), gbps(min_ms)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or\n`save_path='/path/to/results/' to save them to disk along with raw CSV data\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"benchmark.run(print_data=True, show_plots=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
BIN
master/_images/cuda-parallel-matmul.png
Normal file
After Width: | Height: | Size: 9.5 KiB |
BIN
master/_images/grouped_vs_row_major_ordering.png
Normal file
After Width: | Height: | Size: 465 KiB |
BIN
master/_images/halide-iteration.png
Normal file
After Width: | Height: | Size: 12 KiB |
BIN
master/_images/polyhedral-iteration.png
Normal file
After Width: | Height: | Size: 59 KiB |
BIN
master/_images/sphx_glr_01-vector-add_001.png
Normal file
After Width: | Height: | Size: 23 KiB |
BIN
master/_images/sphx_glr_01-vector-add_thumb.png
Normal file
After Width: | Height: | Size: 15 KiB |
BIN
master/_images/sphx_glr_02-fused-softmax_001.png
Normal file
After Width: | Height: | Size: 37 KiB |
BIN
master/_images/sphx_glr_02-fused-softmax_thumb.png
Normal file
After Width: | Height: | Size: 23 KiB |
BIN
master/_images/sphx_glr_03-matrix-multiplication_001.png
Normal file
After Width: | Height: | Size: 58 KiB |
BIN
master/_images/sphx_glr_03-matrix-multiplication_thumb.png
Normal file
After Width: | Height: | Size: 33 KiB |
BIN
master/_images/sphx_glr_04-low-memory-dropout_thumb.png
Normal file
After Width: | Height: | Size: 26 KiB |
BIN
master/_images/sphx_glr_05-layer-norm_001.png
Normal file
After Width: | Height: | Size: 36 KiB |
BIN
master/_images/sphx_glr_05-layer-norm_thumb.png
Normal file
After Width: | Height: | Size: 22 KiB |
BIN
master/_images/triton-parallel-matmul.png
Normal file
After Width: | Height: | Size: 3.0 KiB |
55
master/_sources/getting-started/installation.rst.txt
Normal file
@@ -0,0 +1,55 @@
|
||||
==============
|
||||
Installation
|
||||
==============
|
||||
|
||||
---------------------
|
||||
Binary Distributions
|
||||
---------------------
|
||||
|
||||
You can install the latest stable release of Triton from pip:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install triton
|
||||
|
||||
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
|
||||
|
||||
And the latest nightly release:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U --pre triton
|
||||
|
||||
|
||||
--------------
|
||||
From Source
|
||||
--------------
|
||||
|
||||
+++++++++++++++
|
||||
Python Package
|
||||
+++++++++++++++
|
||||
|
||||
You can install the Python package from source by running the following commands:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/openai/triton.git;
|
||||
cd triton/python;
|
||||
pip install cmake; # build time dependency
|
||||
pip install -e .
|
||||
|
||||
Note that, if llvm-11 is not present on your system, the setup.py script will download the official LLVM11 static libraries and link against them.
|
||||
|
||||
You can then test your installation by running the unit tests:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -e '.[tests]'
|
||||
pytest -vs test/unit/
|
||||
|
||||
and the benchmarks
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd bench/
|
||||
python -m run --with-plots --result-dir /tmp/triton-bench
|
286
master/_sources/getting-started/tutorials/01-vector-add.rst.txt
Normal file
@@ -0,0 +1,286 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/01-vector-add.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_01-vector-add.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_01-vector-add.py:
|
||||
|
||||
|
||||
Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple vector addition using Triton and learn about:
|
||||
|
||||
- The basic programming model of Triton
|
||||
- The `triton.jit` decorator, which is used to define Triton kernels.
|
||||
- The best practices for validating and benchmarking your custom ops against native reference implementations
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-14
|
||||
|
||||
Compute Kernel
|
||||
--------------------------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-50
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_kernel(
|
||||
x_ptr, # *Pointer* to first input vector
|
||||
y_ptr, # *Pointer* to second input vector
|
||||
output_ptr, # *Pointer* to output vector
|
||||
n_elements, # Size of the vector
|
||||
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
|
||||
# NOTE: `constexpr` so it can be used as a shape value
|
||||
):
|
||||
# There are multiple 'program's processing different data. We identify which program
|
||||
# we are here
|
||||
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
|
||||
# This program will process inputs that are offset from the initial data.
|
||||
# for instance, if you had a vector of length 256 and block_size of 64, the programs
|
||||
# would each access the elements [0:64, 64:128, 128:192, 192:256].
|
||||
# Note that offsets is a list of pointers
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
# Create a mask to guard memory operations against out-of-bounds accesses
|
||||
mask = offsets < n_elements
|
||||
# Load x and y from DRAM, masking out any extra elements in case the input is not a
|
||||
# multiple of the block size
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
# Write x + y back to DRAM
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 51-53
|
||||
|
||||
Let's also declare a helper function to (1) allocate the output tensor
|
||||
and (2) enqueue the above kernel with appropriate grid/block sizes.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 53-74
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def add(x: torch.Tensor, y: torch.Tensor):
|
||||
# We need to preallocate the output
|
||||
output = torch.empty_like(x)
|
||||
assert x.is_cuda and y.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
# In this case, we use a 1D grid where the size is the number of blocks
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
# NOTE:
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keyword arguments
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
|
||||
# We return a handle to the output but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously at this point.
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 75-76
|
||||
|
||||
We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 76-90
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
y = torch.rand(size, device='cuda')
|
||||
output_torch = x + y
|
||||
output_triton = add(x, y)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
|
||||
The maximum difference between torch and triton is 0.0
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 91-92
|
||||
|
||||
Seems like we're good to go!
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 94-99
|
||||
|
||||
Benchmark
|
||||
-----------
|
||||
We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
|
||||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
|
||||
for different problem sizes.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 99-128
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['size'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
2 ** i for i in range(12, 28, 1)
|
||||
], # different possible values for `x_name`
|
||||
x_log=True, # x axis is logarithmic
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['triton', 'torch'], # possible values for `line_arg`
|
||||
line_names=['Triton', 'Torch'], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-')], # line styles
|
||||
ylabel='GB/s', # label name for the y-axis
|
||||
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(size, provider):
|
||||
x = torch.rand(size, device='cuda', dtype=torch.float32)
|
||||
y = torch.rand(size, device='cuda', dtype=torch.float32)
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
|
||||
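# 3 float32 tensors are moved (x and y are read, output is written): 12 bytes per element
|
||||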
gbps = lambda ms: 12 * size / ms * 1e-6
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 129-131
|
||||
|
||||
We can now run the decorated function above. Pass `print_data=True` to see the performance numbers, `show_plots=True` to plot them, and/or
|
||||
`save_path='/path/to/results/'` to save them to disk along with raw CSV data.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 131-132
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
|
||||
:alt: 01 vector add
|
||||
:class: sphx-glr-single-img
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
vector-add-performance:
|
||||
size Triton Torch
|
||||
0 4096.0 9.600000 9.600000
|
||||
1 8192.0 19.200000 19.200000
|
||||
2 16384.0 38.400001 38.400001
|
||||
3 32768.0 63.999998 63.999998
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 341.333321
|
||||
7 524288.0 472.615390 472.615390
|
||||
8 1048576.0 614.400016 614.400016
|
||||
9 2097152.0 722.823517 702.171410
|
||||
10 4194304.0 780.190482 780.190482
|
||||
11 8388608.0 812.429770 812.429770
|
||||
12 16777216.0 833.084721 833.084721
|
||||
13 33554432.0 842.004273 842.004273
|
||||
14 67108864.0 847.448255 848.362445
|
||||
15 134217728.0 849.737435 850.656574
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 39.514 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 01-vector-add.py <01-vector-add.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 01-vector-add.ipynb <01-vector-add.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
@@ -0,0 +1,337 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/02-fused-softmax.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
||||
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster
|
||||
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
|
||||
the GPU's SRAM.
|
||||
You will learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- Reduction operators in Triton.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 14-18
|
||||
|
||||
Motivations
|
||||
------------
|
||||
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 18-46
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def naive_softmax(x):
|
||||
"""Compute row-wise softmax of X using native pytorch
|
||||
|
||||
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
|
||||
this shift.
|
||||
"""
|
||||
# read MN elements ; write M elements
|
||||
x_max = x.max(dim=1)[0]
|
||||
# read MN + M elements ; write MN elements
|
||||
z = x - x_max[:, None]
|
||||
# read MN elements ; write MN elements
|
||||
numerator = torch.exp(z)
|
||||
# read MN elements ; write M elements
|
||||
denominator = numerator.sum(dim=1)
|
||||
# read MN + M elements ; write MN elements
|
||||
ret = numerator / denominator[:, None]
|
||||
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 47-55
|
||||
|
||||
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
|
||||
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
|
||||
X once and does all the necessary computations on-chip.
|
||||
Doing so would require reading and writing back only :math:`MN` elements, so we could
|
||||
expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
|
||||
The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
|
||||
but, as we will see later, it is still far from ideal.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 57-64
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
Our softmax kernel works as follows: each program loads a row of the input matrix X,
|
||||
normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a
|
||||
power-of-two number of elements, so we need to internally "pad" each row and guard the
|
||||
memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 64-93
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_kernel(
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
# The rows of the softmax are independent, so we parallelize across those
|
||||
row_idx = tl.program_id(0)
|
||||
# The stride represents how much we need to increase the pointer to advance 1 row
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
# The block size is the next power of two greater than n_cols, so we can fit each
|
||||
# row in a single block
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + col_offsets
|
||||
# Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
|
||||
# Subtract maximum for numerical stability
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
# Write back output to DRAM
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
output_ptrs = output_row_start_ptr + col_offsets
|
||||
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 94-95
|
||||
|
||||
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 95-125
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
def softmax(x):
|
||||
n_rows, n_cols = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
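# (e.g., the 781-column matrix used in the unit test below is padded to BLOCK_SIZE = 1024)
|
||||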
BLOCK_SIZE = triton.next_power_of_2(n_cols)
|
||||
# Another trick we can use is to ask the compiler to use more threads per row by
|
||||
# increasing the number of warps (`num_warps`) over which each row is distributed.
|
||||
# You will see in the next tutorial how to auto-tune this value in a more natural
|
||||
# way so you don't have to come up with manual heuristics yourself.
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
# Allocate output
|
||||
y = torch.empty_like(x)
|
||||
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of
|
||||
# the input matrix
|
||||
softmax_kernel[(n_rows,)](
|
||||
y,
|
||||
x,
|
||||
x.stride(0),
|
||||
y.stride(0),
|
||||
n_cols,
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 126-128
|
||||
|
||||
Unit Test
|
||||
----------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 130-132
|
||||
|
||||
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
|
||||
This will allow us to verify that our padding mechanism works.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 132-139
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(1823, 781, device='cuda')
|
||||
y_triton = softmax(x)
|
||||
y_torch = torch.softmax(x, axis=1)
|
||||
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 140-141
|
||||
|
||||
As expected, the results are identical.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 143-147
|
||||
|
||||
Benchmark
|
||||
-------------
|
||||
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
|
||||
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 147-186
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 100)
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=[
|
||||
'triton',
|
||||
'torch-native',
|
||||
'torch-jit',
|
||||
], # possible values for `line_arg`
|
||||
line_names=[
|
||||
"Triton",
|
||||
"Torch (native)",
|
||||
"Torch (jit)",
|
||||
], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
if provider == 'torch-native':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
|
||||
if provider == 'torch-jit':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
|
||||
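# the fused kernel reads MN elements once and writes MN elements back,
|
||||
# i.e., 2 * numel * element_size bytes of traffic in total
|
||||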
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True, print_data=True)
|
||||
|
||||
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
|
||||
:alt: 02 fused softmax
|
||||
:class: sphx-glr-single-img
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
softmax-performance:
|
||||
N Triton Torch (native) Torch (jit)
|
||||
0 256.0 512.000001 512.000001 190.511628
|
||||
1 384.0 614.400016 585.142862 153.600004
|
||||
2 512.0 655.360017 585.142849 154.566038
|
||||
3 640.0 706.206879 640.000002 158.759699
|
||||
4 768.0 722.823517 664.216187 162.754967
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 406.179533 198.631953
|
||||
94 12288.0 812.429770 415.881552 198.995960
|
||||
95 12416.0 812.498981 412.149375 198.556711
|
||||
96 12544.0 812.566838 412.546756 198.815254
|
||||
97 12672.0 811.007961 412.097543 198.971549
|
||||
|
||||
[98 rows x 4 columns]
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 187-192
|
||||
|
||||
In the above plot, we can see that:
|
||||
|
||||
- Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
|
||||
- Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
|
||||
Note however that the PyTorch `softmax` operation is more general and works on tensors of any shape.
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 3 minutes 22.699 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
@@ -0,0 +1,530 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/03-matrix-multiplication.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_03-matrix-multiplication.py:
|
||||
|
||||
|
||||
Matrix Multiplication
|
||||
======================
|
||||
In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication
|
||||
kernel that achieves performance on par with cuBLAS.
|
||||
You will specifically learn about:
|
||||
|
||||
- Block-level matrix multiplications
|
||||
- Multi-dimensional pointer arithmetic
|
||||
- Program re-ordering for improved L2 cache hit rate
|
||||
- Automatic performance tuning
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 15-42
|
||||
|
||||
Motivations
|
||||
-------------
|
||||
Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
They are notoriously hard to optimize, hence their implementation is generally done by
|
||||
hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized
|
||||
to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
In this tutorial, you will learn how to implement efficient matrix multiplications by
|
||||
yourself with Triton, in a way that is easy to customize and extend.
|
||||
|
||||
Roughly speaking, the kernel that we will write will implement the following blocked
|
||||
algorithm to multiply a (M, K) by a (K, N) matrix:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# do in parallel
|
||||
for m in range(0, M, BLOCK_SIZE_M):
|
||||
# do in parallel
|
||||
for n in range(0, N, BLOCK_SIZE_N):
|
||||
acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
|
||||
b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
|
||||
acc += dot(a, b)
|
||||
C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
|
||||
|
||||
where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-137
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
|
||||
The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the computation of the memory locations at which blocks
|
||||
of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
|
||||
multi-dimensional pointer arithmetics.
|
||||
|
||||
Pointer Arithmetics
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by
|
||||
:code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
|
||||
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
|
||||
:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
|
||||
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
|
||||
|
||||
This means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
|
||||
|
||||
And then updated in the inner loop as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
|
||||
L2 Cache Optimizations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
|
||||
block of :code:`C`.
|
||||
It is important to remember that the order in which these blocks are computed does
|
||||
matter, since it affects the L2 cache hit rate of our program, and unfortunately, a
|
||||
simple row-major ordering
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
|
||||
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
|
||||
pid_m = pid // grid_n
|
||||
pid_n = pid % grid_n
|
||||
|
||||
is just not going to cut it.
|
||||
|
||||
One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE_M` rows before
|
||||
switching to the next column:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# program ID
|
||||
pid = tl.program_id(axis=0)
|
||||
# number of program ids along the M axis
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
# number of programs ids along the N axis
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
# number of programs in group
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
# id of the group this program is in
|
||||
group_id = pid // num_pid_in_group
|
||||
# row-id of the first program in the group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
# *within groups*, programs are ordered in a column-major order
|
||||
# row-id of the program in the *launch grid*
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
# col-id of the program in the *launch grid*
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||
we can see that if we compute the output in row-major ordering, we need to load 90
|
||||
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
|
||||
ordering, we only need to load 54 blocks.
|
||||
.. image:: grouped_vs_row_major_ordering.png
|
||||
|
||||
In practice, this can improve the performance of our matrix multiplication kernel by
|
||||
more than 10\% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
|
||||
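As a back-of-the-envelope check of these numbers, here is a minimal sketch that counts
|
||||
the A and B blocks touched by the first 9 programs, assuming a 9-by-9 grid of blocks and
|
||||
:code:`GROUP_SIZE_M = 3` (the :code:`blocks_loaded` helper is purely illustrative):
|
||||

|
||||
.. code-block:: python
|
||||

|
||||
    # each output block (m, n) consumes a full row of A blocks and a full
|
||||
    # column of B blocks, i.e., `num_k_blocks` loads from each input matrix
|
||||
    def blocks_loaded(output_blocks, num_k_blocks=9):
|
||||
        a_rows = {m for m, n in output_blocks}
|
||||
        b_cols = {n for m, n in output_blocks}
|
||||
        return (len(a_rows) + len(b_cols)) * num_k_blocks
|
||||
    row_major = [(0, n) for n in range(9)]                  # first 9 programs, row-major
|
||||
    grouped = [(m, n) for m in range(3) for n in range(3)]  # first 9 programs, grouped
|
||||
    print(blocks_loaded(row_major))  # 90
|
||||
    print(blocks_loaded(grouped))    # 54
|
||||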
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 139-142
|
||||
|
||||
Final Result
|
||||
-------------
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 142-259
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# %
|
||||
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
|
||||
# decorator, which consumes:
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of
|
||||
# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
|
||||
# - An autotuning *key* whose change in values will trigger evaluation of all the
|
||||
# provided configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
# Matrix dimensions
|
||||
M, N, K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
||||
# by to get the element one row down (A has M rows)
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse
|
||||
# See above `L2 Cache Optimizations` section for details
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
# see above `Pointer Arithmetics` section for details
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
# Note that for simplicity, we don't apply a mask here.
|
||||
# This means that if K is not a multiple of BLOCK_SIZE_K,
|
||||
# this will access out-of-bounds memory and produce an
|
||||
# error or (worse!) incorrect results.
|
||||
a = tl.load(a_ptrs)
|
||||
b = tl.load(b_ptrs)
            # We accumulate along the K dimension.
            accumulator += tl.dot(a, b)
            # Advance the ptrs to the next K block.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
        # You can fuse arbitrary activation functions here
        # while the accumulator is still in FP32!
        if ACTIVATION:
            accumulator = ACTIVATION(accumulator)
        c = accumulator.to(tl.float16)

        # -----------------------------------------------------------
        # Write back the block of the output matrix C.
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)


    # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter to `matmul_kernel`.
    @triton.jit
    def leaky_relu(x):
        x = x + 1
        return tl.where(x >= 0, x, 0.01 * x)



.. GENERATED FROM PYTHON SOURCE LINES 260-262

We can now create a convenience wrapper function that only takes two input tensors,
and (1) checks any shape constraints; (2) allocates the output; (3) launches the above kernel.

.. GENERATED FROM PYTHON SOURCE LINES 262-291

.. code-block:: default

    def matmul(a, b, activation=None):
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        assert a.is_contiguous(), "matrix A must be contiguous"
        assert b.is_contiguous(), "matrix B must be contiguous"
        M, K = a.shape
        K, N = b.shape
        assert (
            K % 32 == 0
        ), "We don't check memory out-of-bounds with K, so K must be divisible by BLOCK_SIZE_K"
        # allocates output
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
        # 1D launch kernel where each block gets its own program.
        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
        )
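        # e.g. (an illustration, assuming the autotuner picks 128x128 blocks):
        # M = N = 512 gives 4 * 4 = 16 programs, one per output block of C.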
        matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            ACTIVATION=activation,
        )
        return c




.. GENERATED FROM PYTHON SOURCE LINES 292-296

Unit Test
-----------

We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).

.. GENERATED FROM PYTHON SOURCE LINES 296-309

.. code-block:: default


    torch.manual_seed(0)
    a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    if triton.testing.allclose(triton_output, torch_output):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    triton_output=tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3984,  24.4531, -32.3438],
            [  6.3555, -19.6094,  34.0938,  ...,  -5.8945,   5.2891,   6.8867],
            [-32.0625,   5.9492,  15.3984,  ..., -21.3906, -23.9844, -10.1328],
            ...,
            [ -5.7031,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
            [ 25.5000,  24.3281,  -8.4688,  ..., -18.9375,  32.5312, -29.9219],
            [ -5.3477,   4.9844,  11.8906,  ...,   5.5898,   6.4023, -17.3125]],
           device='cuda:0', dtype=torch.float16)
    torch_output=tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3906,  24.4531, -32.3438],
            [  6.3516, -19.6094,  34.0938,  ...,  -5.8906,   5.2812,   6.8828],
            [-32.0625,   5.9531,  15.3984,  ..., -21.4062, -23.9844, -10.1328],
            ...,
            [ -5.7070,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
            [ 25.5000,  24.3438,  -8.4609,  ..., -18.9375,  32.5312, -29.9219],
            [ -5.3477,   4.9805,  11.8828,  ...,   5.5859,   6.4023, -17.3125]],
           device='cuda:0', dtype=torch.float16)
    ✅ Triton and Torch match




.. GENERATED FROM PYTHON SOURCE LINES 310-316

Benchmark
--------------

Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to adapt this script to benchmark any other matrix shape.

.. GENERATED FROM PYTHON SOURCE LINES 316-357

.. code-block:: default

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
            x_vals=[
                128 * i for i in range(2, 33)
            ],  # different possible values for `x_name`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            # possible values for `line_arg`
            line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
            # label name for the lines
            line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
            # line styles
            styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
            ylabel="TFLOPS",  # label name for the y-axis
            plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={},
        )
    )
    def benchmark(M, N, K, provider):
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
        if provider == 'cublas':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
        if provider == 'cublas + relu':
            torch_relu = torch.nn.ReLU(inplace=True)
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch_relu(torch.matmul(a, b))
            )
        if provider == 'triton + relu':
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: matmul(a, b, activation=leaky_relu)
            )
        perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
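        # 2*M*N*K FLOPs per matmul (a multiply and an add for each of the M*N*K
        # partial products); 1e-12 converts to TFLOP and 1e-3 converts ms to s.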
        return perf(ms), perf(max_ms), perf(min_ms)


    benchmark.run(show_plots=True, print_data=True)




.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
   :alt: 03 matrix multiplication
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    matmul-performance:
             M      cuBLAS  ...     Triton  Triton (+ LeakyReLU)
    0    256.0    2.730667  ...   2.978909              2.978909
    1    384.0    7.372800  ...   8.507077              8.507077
    2    512.0   14.563555  ...  15.420235             15.420235
    3    640.0   22.260869  ...  24.380953             24.380953
    4    768.0   32.768000  ...  35.389441             34.028308
    5    896.0   39.025776  ...  41.321411             39.025776
    6   1024.0   49.932191  ...  53.773130             52.428801
    7   1152.0   45.242181  ...  48.161033             47.396572
    8   1280.0   51.200001  ...  57.690139             57.690139
    9   1408.0   64.138541  ...  68.147202             67.305878
    10  1536.0   80.430545  ...  80.430545             78.643199
    11  1664.0   63.372618  ...  63.372618             62.492442
    12  1792.0   72.983276  ...  63.499573             63.142831
    13  1920.0   69.120002  ...  71.626943             71.257735
    14  2048.0   73.908442  ...  78.033565             76.959706
    15  2176.0   83.155572  ...  87.115360             85.632545
    16  2304.0   68.251065  ...  78.064941             76.809875
    17  2432.0   71.125224  ...  75.522751             74.521127
    18  2560.0   77.833728  ...  82.331658             80.908642
    19  2688.0   84.108772  ...  90.966561             89.254248
    20  2816.0   83.552120  ...  83.712490             83.552120
    21  2944.0   82.237674  ...  84.324925             83.899046
    22  3072.0   81.825298  ...  89.735509             89.170242
    23  3200.0   84.432717  ...  96.240602             94.674553
    24  3328.0   82.939284  ...  86.736504             86.113988
    25  3456.0   82.688790  ...  88.014813             81.269178
    26  3584.0   86.457107  ...  99.684470             99.025764
    27  3712.0   83.247783  ...  89.594031             85.675250
    28  3840.0   85.070769  ...  93.090912             87.980905
    29  3968.0   93.648452  ...  86.114283             87.284643
    30  4096.0   89.627865  ...  89.240508             93.792965

    [31 rows x 5 columns]




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 6 minutes 16.582 seconds)


.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 03-matrix-multiplication.py <03-matrix-multiplication.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 03-matrix-multiplication.ipynb <03-matrix-multiplication.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
271
master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt
Normal file
@@ -0,0 +1,271 @@

.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/04-low-memory-dropout.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_04-low-memory-dropout.py:


Low-Memory Dropout
==================

In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:

- The limitations of naive implementations of Dropout with PyTorch
- Parallel pseudo-random number generation in Triton

.. GENERATED FROM PYTHON SOURCE LINES 14-29

Baseline
-------------

The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
of deep neural networks in low-data regimes (i.e., as a form of regularization).

It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
output has a probability :math:`p` of being set to zero, and otherwise it is copied from the input.
This forces the network to perform well even when only a fraction :math:`1 - p` of the input scalars are available.

At evaluation time we want to use the full power of the network, so we set :math:`p=0`. Naively, this would
increase the norm of the output (which can be a bad thing, e.g., it can lead to an artificial decrease
in the output softmax temperature). To prevent this, we multiply the output by :math:`\frac{1}{1 - p}`, which
keeps the expected magnitude of activations consistent regardless of the dropout probability.
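
A quick numerical sanity check of that rescaling (a sketch in plain PyTorch; the vector size
and tolerance below are arbitrary choices of ours):

.. code-block:: python

    import torch

    p = 0.5
    x = torch.randn(1_000_000)
    keep = (torch.rand_like(x) > p).float()
    y = x * keep / (1 - p)  # rescaled dropout
    # Each element is kept with probability (1 - p) and scaled by 1/(1 - p),
    # so E[y | x] = x and the mean activation is preserved:
    assert torch.allclose(x.mean(), y.mean(), atol=1e-2)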

Let's first take a look at the baseline implementation.

.. GENERATED FROM PYTHON SOURCE LINES 29-82

.. code-block:: default

    import tabulate
    import torch

    import triton
    import triton.language as tl


    @triton.jit
    def _dropout(
        x_ptr,  # pointer to the input
        x_keep_ptr,  # pointer to a mask of 0s and 1s
        output_ptr,  # pointer to the output
        n_elements,  # number of elements in the `x` tensor
        p,  # probability that an element of `x` is changed to zero
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        # Load data
        x = tl.load(x_ptr + offsets, mask=mask)
        x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
        # The line below is the crucial part, described in the paragraph above!
        output = tl.where(x_keep, x / (1 - p), 0.0)
        # Write-back output
        tl.store(output_ptr + offsets, output, mask=mask)


    def dropout(x, x_keep, p):
        output = torch.empty_like(x)
        assert x.is_contiguous()
        n_elements = x.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
        return output


    # Input tensor
    x = torch.randn(size=(10,)).cuda()
    # Dropout mask
    p = 0.5
    x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
    output = dropout(x, x_keep=x_keep, p=p)
    print(tabulate.tabulate([
        ["input"] + x.tolist(),
        ["keep mask"] + x_keep.tolist(),
        ["output"] + output.tolist()
    ]))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    ---------  -------  ---------  --------  --------  --------  --------  --------  --------  ---------  ---------
    input      1.541    -0.293429  -2.17879  0.568431  -1.08452  -1.3986   0.403347  0.838026  -0.719258  -0.403344
    keep mask  1         1          0        1          0         1        1         0          0          0
    output     3.08199  -0.586858   0        1.13686    0        -2.79719  0.806694  0          0          0
    ---------  -------  ---------  --------  --------  --------  --------  --------  --------  ---------  ---------

.. GENERATED FROM PYTHON SOURCE LINES 83-101

Seeded dropout
--------------

The implementation of dropout above works fine, but it can be a bit awkward to deal with. Firstly,
we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
very tricky when using recompute/checkpointing (e.g., see all the notes about `preserve_rng_state` in
https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
of persisting randomness across multiple invocations of the kernel.

Pseudo-random number generation in Triton is simple! In this tutorial we will use the
:code:`triton.language.rand` function, which generates a block of uniformly distributed :code:`float32`
values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
other :ref:`random number generation strategies <Random Number Generation>`.

.. note::
    Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_).

Let's put it all together.

.. GENERATED FROM PYTHON SOURCE LINES 101-149

.. code-block:: default


    @triton.jit
    def _seeded_dropout(
        x_ptr,
        output_ptr,
        n_elements,
        p,
        seed,
        BLOCK_SIZE: tl.constexpr,
    ):
        # compute memory offsets of elements handled by this instance
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # load data from x
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # randomly prune it
        random = tl.rand(seed, offsets)
        x_keep = random > p
        # write-back
        output = tl.where(x_keep, x / (1 - p), 0.0)
        tl.store(output_ptr + offsets, output, mask=mask)


    def seeded_dropout(x, p, seed):
        output = torch.empty_like(x)
        assert x.is_contiguous()
        n_elements = x.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
        return output


    x = torch.randn(size=(10,)).cuda()
    # Compare this to the baseline - dropout mask is never instantiated!
    output = seeded_dropout(x, p=0.5, seed=123)
    output2 = seeded_dropout(x, p=0.5, seed=123)
    output3 = seeded_dropout(x, p=0.5, seed=512)

    print(tabulate.tabulate([
        ["input"] + x.tolist(),
        ["output (seed = 123)"] + output.tolist(),
        ["output (seed = 123)"] + output2.tolist(),
        ["output (seed = 512)"] + output3.tolist()
    ]))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    -------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------
    input                -0.952835  0.371721  0.408716  1.42142  0.149397  -0.67086  -0.214186  -0.431969  -0.707878  -0.106434
    output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
    output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
    output (seed = 512)   0         0         0.817432  2.84284  0         -1.34172  -0.428372   0          0          0
    -------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------

.. GENERATED FROM PYTHON SOURCE LINES 150-153

Et voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the `triton/language/random` folder!

.. GENERATED FROM PYTHON SOURCE LINES 155-160

Exercises
-------------

1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row (a possible starting point is sketched below).
2. Add support for striding.
3. (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time, using a seed.
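
As a rough starting point for the first exercise, here is a sketch (untested, and the kernel
and wrapper names below are our own) in which each row of a contiguous 2D tensor gets its own seed:

.. code-block:: python

    @triton.jit
    def _row_seeded_dropout(
        x_ptr, output_ptr, seeds_ptr,
        n_rows, n_cols,
        p,
        BLOCK_SIZE: tl.constexpr,
    ):
        # one program per (row, block-of-columns) pair
        row = tl.program_id(axis=0)
        col_block = tl.program_id(axis=1)
        seed = tl.load(seeds_ptr + row)  # per-row seed
        cols = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        offsets = row * n_cols + cols
        x = tl.load(x_ptr + offsets, mask=mask)
        # column index as the Philox offset: same (seed, col) -> same decision
        random = tl.rand(seed, cols)
        x_keep = random > p
        output = tl.where(x_keep, x / (1 - p), 0.0)
        tl.store(output_ptr + offsets, output, mask=mask)


    def row_seeded_dropout(x, p, seeds):
        # x: contiguous 2D tensor; seeds: one int32 seed per row
        assert x.is_contiguous() and x.dim() == 2 and seeds.numel() == x.shape[0]
        output = torch.empty_like(x)
        n_rows, n_cols = x.shape
        grid = lambda meta: (n_rows, triton.cdiv(n_cols, meta['BLOCK_SIZE']))
        _row_seeded_dropout[grid](x, output, seeds, n_rows, n_cols, p, BLOCK_SIZE=1024)
        return output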

.. GENERATED FROM PYTHON SOURCE LINES 162-167

References
--------------

.. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
.. [SRIVASTAVA2014] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.484 seconds)


.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 04-low-memory-dropout.py <04-low-memory-dropout.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 04-low-memory-dropout.ipynb <04-low-memory-dropout.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
420
master/_sources/getting-started/tutorials/05-layer-norm.rst.txt
Normal file
@@ -0,0 +1,420 @@

.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:


Layer Normalization
====================

.. GENERATED FROM PYTHON SOURCE LINES 5-312


.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
   :alt: 05 layer norm
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    layer-norm:
              N      Triton       Torch        Apex
    0    1024.0  585.142849  277.694907  468.114273
    1    1536.0  630.153868  323.368435  511.999982
    2    2048.0  682.666643  334.367358  520.126988
    3    2560.0  694.237267  362.477870  512.000013
    4    3072.0  712.347810  375.206126  501.551037
    5    3584.0  725.873439  384.859062  458.751978
    6    4096.0  728.177767  381.023256  458.293714
    7    4608.0  670.254540  394.267384  426.173427
    8    5120.0  688.403381  397.669909  426.666652
    9    5632.0  704.000002  395.228063  413.357796
    10   6144.0  702.171410  402.885254  409.600010
    11   6656.0  705.271522  398.861429  400.360920
    12   7168.0  690.891575  396.844306  387.459443
    13   7680.0  686.480466  392.587863  387.634072
    14   8192.0  636.271854  393.609605  371.308771
    15   8704.0  630.153861  389.005597  380.502740
    16   9216.0  609.322328  407.337026  383.999986
    17   9728.0  589.575753  409.599987  383.369452
    18  10240.0  568.888869  408.578556  382.803739
    19  10752.0  551.384634  411.559798  381.445676
    20  11264.0  536.380957  406.826188  373.134567
    21  11776.0  523.377770  409.599991  377.587162
    22  12288.0  517.389457  413.911572  383.251457
    23  12800.0  505.679014  410.420828  376.470582
    24  13312.0  494.180982  405.699062  376.310952
    25  13824.0  482.934503  411.888257  379.389355
    26  14336.0  471.967074  406.695045  374.185964
    27  14848.0  461.297068  408.192434  374.712936
    28  15360.0  454.269882  406.214870  378.092307
    29  15872.0  447.887117  406.974373  376.225175


.. code-block:: default


    import torch

    import triton
    import triton.language as tl

    try:
        # This is https://github.com/NVIDIA/apex, NOT the apex on PyPI, so it
        # should not be added to extras_require in setup.py.
        import apex
        HAS_APEX = True
    except ModuleNotFoundError:
        HAS_APEX = False


    @triton.jit
    def _layer_norm_fwd_fused(
        Out,
        A,
        Weight,
        Bias,
        Mean, Rstd,
        stride, N, eps,
        BLOCK_SIZE: tl.constexpr,
    ):
        # position of elements processed by this program
        row = tl.program_id(0)
        Out += row * stride
        A += row * stride
        # compute mean
        _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
            _mean += a
        mean = tl.sum(_mean, axis=0) / N
        # compute variance
        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
            a = tl.where(cols < N, a - mean, 0.)
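            # zero out the masked tail so padding does not contribute to the variance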
            _var += a * a
        var = tl.sum(_var, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        # write-back mean/rstd
        tl.store(Mean + row, mean)
        tl.store(Rstd + row, rstd)
        # multiply by weight and add bias
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            weight = tl.load(Weight + cols, mask=mask)
            bias = tl.load(Bias + cols, mask=mask)
            a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
            a_hat = (a - mean) * rstd
            out = a_hat * weight + bias
            # write-back
            tl.store(Out + cols, out, mask=mask)


    # Backward pass (DA + partial DW + partial DB)
    @triton.jit
    def _layer_norm_bwd_dx_fused(
        _DA,
        _DOut,
        _A,
        Weight,
        Mean, Rstd,
        stride, NumRows, NumCols, eps,
        BLOCK_SIZE_N: tl.constexpr,
    ):
        # position of elements processed by this program
        pid = tl.program_id(0)
        row = pid
        A = _A + row * stride
        DOut = _DOut + row * stride
        DA = _DA + row * stride
        mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # load data to SRAM
        _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
        _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
        for off in range(0, NumCols, BLOCK_SIZE_N):
            cols = off + tl.arange(0, BLOCK_SIZE_N)
            mask = cols < NumCols
            a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
            dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
            weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
            a_hat = (a - mean) * rstd
            wdout = weight * dout
            _mean1 += a_hat * wdout
            _mean2 += wdout
        mean1 = tl.sum(_mean1, axis=0) / NumCols
        mean2 = tl.sum(_mean2, axis=0) / NumCols
        for off in range(0, NumCols, BLOCK_SIZE_N):
            cols = off + tl.arange(0, BLOCK_SIZE_N)
            mask = cols < NumCols
            a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
            dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
            weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
            a_hat = (a - mean) * rstd
            wdout = weight * dout
            da = (wdout - (a_hat * mean1 + mean2)) * rstd
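            # Equivalently (a derivation sketch of ours, with sigma = 1/rstd):
            #   da_i = (w_i*dout_i - a_hat_i*mean(a_hat*w*dout) - mean(w*dout)) / sigma
            # which is the standard layer-norm input gradient.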
            # write-back dx
            tl.store(DA + cols, da, mask=mask)


    # Backward pass (total DW + total DB)
    @triton.jit
    def _layer_norm_bwd_dwdb(
        A, DOut,
        Mean, Var,
        DW,
        DB,
        M, N,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
    ):
        pid = tl.program_id(0)
        cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
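        # each program reduces one BLOCK_SIZE_N-wide strip of columns over all M rows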
        dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for i in range(0, M, BLOCK_SIZE_M):
            rows = i + tl.arange(0, BLOCK_SIZE_M)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            offs = rows[:, None] * N + cols[None, :]
            a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
            dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
            mean = tl.load(Mean + rows, mask=rows < M, other=0.)
            rstd = tl.load(Var + rows, mask=rows < M, other=0.)
            a_hat = (a - mean[:, None]) * rstd[:, None]
            dw += dout * a_hat
            db += dout
        sum_dw = tl.sum(dw, axis=0)
        sum_db = tl.sum(db, axis=0)
        tl.store(DW + cols, sum_dw, mask=cols < N)
        tl.store(DB + cols, sum_db, mask=cols < N)


    class LayerNorm(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a, normalized_shape, weight, bias, eps):
            # allocate output
            out = torch.empty_like(a)
            # reshape input data into 2D tensor
            a_arg = a.reshape(-1, a.shape[-1])
            M, N = a_arg.shape
            mean = torch.empty((M,), dtype=torch.float32, device="cuda")
            rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
            # Less than 64KB per feature: enqueue fused kernel
            MAX_FUSED_SIZE = 65536 // a.element_size()
            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
            BLOCK_SIZE = max(BLOCK_SIZE, 128)
            BLOCK_SIZE = min(BLOCK_SIZE, 4096)
            # heuristics for number of warps
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
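            # i.e., roughly one warp per 256 elements of BLOCK_SIZE, clamped to [1, 8]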
            _layer_norm_fwd_fused[(M,)](
                out,
                a_arg,
                weight,
                bias,
                mean, rstd,
                a_arg.stride(0), N, eps,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
            ctx.save_for_backward(
                a, weight, bias, mean, rstd,
            )
            ctx.BLOCK_SIZE = BLOCK_SIZE
            ctx.num_warps = num_warps
            ctx.eps = eps
            if hasattr(bias, "config"):
                assert bias.config.grad_scale_name == weight.config.grad_scale_name
                grad_scale_name = bias.config.grad_scale_name
            else:
                grad_scale_name = None
            ctx.grad_scale_gain_bias_name = grad_scale_name
            return out

        @staticmethod
        def backward(ctx, dout):
            assert dout.is_contiguous()
            a, weight, bias, mean, var = ctx.saved_tensors
            # heuristics for amount of parallel reduction stream for DG/DB
            N = weight.shape[0]
            # allocate output
            da = torch.empty_like(dout)
            # enqueue kernel using forward pass heuristics
            # also compute partial sums for DW and DB
            x_arg = a.reshape(-1, a.shape[-1])
            M, N = x_arg.shape
            dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
            dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
            _layer_norm_bwd_dx_fused[(M,)](
                da,
                dout,
                a,
                weight,
                mean, var,
                x_arg.stride(0), M, N,
                ctx.eps,
                BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                num_warps=ctx.num_warps,
            )
            # accumulate partial sums in separate kernel
            grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
            _layer_norm_bwd_dwdb[grid](
                a, dout,
                mean, var,
                dweight,
                dbias,
                M,
                N,
                BLOCK_SIZE_M=32,
                BLOCK_SIZE_N=128,
            )
            return (da, None, dweight, dbias, None, None,
                    None, None, None, None,
                    None,
                    None, None, None,
                    None,
                    None, None, None,
                    None, None, None,
                    None, None, None)


    def layer_norm(a, normalized_shape, weight, bias, eps):
        return LayerNorm.apply(a, normalized_shape, weight, bias, eps)


    def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
        torch.manual_seed(0)
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # forward pass
        y_tri = layer_norm(x, w_shape, weight, bias, eps)
        y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
        # backward pass (triton)
        y_tri.backward(dy, retain_graph=True)
        dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
        x.grad, weight.grad, bias.grad = None, None, None
        # backward pass (torch)
        y_ref.backward(dy, retain_graph=True)
        dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
        # compare
        triton.testing.assert_almost_equal(y_tri, y_ref)
        triton.testing.assert_almost_equal(dx_tri, dx_ref)
        triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
        triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],
            x_vals=[512 * i for i in range(2, 32)],
            line_arg='provider',
            line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
            line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
            styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
            ylabel='GB/s',
            plot_name='layer-norm',
            args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
        )
    )
    def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # utility functions
        if provider == 'triton':
            y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
        if provider == 'torch':
            y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
        if provider == 'apex':
            apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
            y_fwd = lambda: apex_layer_norm(x)
        # forward pass
        if mode == 'forward':
            gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
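            # forward traffic: read x and write y (2 tensors); bytes/ms * 1e-6 = GB/s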
            ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
        # backward pass
        if mode == 'backward':
            gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
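            # backward traffic: read x and dy, write dx (3 tensors)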
            y = y_fwd()
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                         grad_to_none=[x], rep=500)
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    # test_layer_norm(1151, 8192, torch.float16)
    bench_layer_norm.run(save_path='.', print_data=True)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 5 minutes 26.747 seconds)


.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
152
master/_sources/getting-started/tutorials/index.rst.txt
Normal file
@@ -0,0 +1,152 @@

:orphan:



.. _sphx_glr_getting-started_tutorials:

Tutorials
==================

Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.

To install the dependencies for the tutorials:

.. code-block:: bash

    cd triton
    pip install -e './python[tutorials]'



.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model of Triton - The triton.jit decorator, which is used to define Tri...">

.. only:: html

 .. figure:: /getting-started/tutorials/images/thumb/sphx_glr_01-vector-add_thumb.png
     :alt: Vector Addition

     :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py`

.. raw:: html

    </div>


.. toctree::
   :hidden:

   /getting-started/tutorials/01-vector-add

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - Reduction operators in Triton...">

.. only:: html

 .. figure:: /getting-started/tutorials/images/thumb/sphx_glr_02-fused-softmax_thumb.png
     :alt: Fused Softmax

     :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py`

.. raw:: html

    </div>


.. toctree::
   :hidden:

   /getting-started/tutorials/02-fused-softmax

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="- Block-level matrix multiplications - Multi-dimensional pointer arithmetic - Program re-orderi...">

.. only:: html

 .. figure:: /getting-started/tutorials/images/thumb/sphx_glr_03-matrix-multiplication_thumb.png
     :alt: Matrix Multiplication

     :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py`

.. raw:: html

    </div>


.. toctree::
   :hidden:

   /getting-started/tutorials/03-matrix-multiplication

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, you will write a memory-efficient implementation of dropout whose state will ...">

.. only:: html

 .. figure:: /getting-started/tutorials/images/thumb/sphx_glr_04-low-memory-dropout_thumb.png
     :alt: Low-Memory Dropout

     :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py`

.. raw:: html

    </div>


.. toctree::
   :hidden:

   /getting-started/tutorials/04-low-memory-dropout

.. raw:: html

    <div class="sphx-glr-thumbcontainer" tooltip="Layer Normalization">

.. only:: html

 .. figure:: /getting-started/tutorials/images/thumb/sphx_glr_05-layer-norm_thumb.png
     :alt: Layer Normalization

     :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py`

.. raw:: html

    </div>


.. toctree::
   :hidden:

   /getting-started/tutorials/05-layer-norm

.. raw:: html

    <div class="sphx-glr-clear"></div>



.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-gallery


  .. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download all examples in Python source code: tutorials_python.zip </getting-started/tutorials/tutorials_python.zip>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

    :download:`Download all examples in Jupyter notebooks: tutorials_jupyter.zip </getting-started/tutorials/tutorials_jupyter.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

20
master/_sources/getting-started/tutorials/sg_execution_times.rst.txt
Normal file
@@ -0,0 +1,20 @@

:orphan:

.. _sphx_glr_getting-started_tutorials_sg_execution_times:

Computation times
=================
**16:46.026** total execution time for **getting-started_tutorials** files:

+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 06:16.582 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``)                       | 05:26.747 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``)                 | 03:22.699 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``)                       | 01:39.514 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``)       | 00:00.484 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
52
master/_sources/index.rst.txt
Normal file
@@ -0,0 +1,52 @@

Welcome to Triton's documentation!
==================================

Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.

Getting Started
---------------

- Follow the :doc:`installation instructions <getting-started/installation>` for your platform of choice.
- Take a look at the :doc:`tutorials <getting-started/tutorials/index>` to learn how to write your first Triton program.

.. toctree::
   :maxdepth: 1
   :caption: Getting Started
   :hidden:

   getting-started/installation
   getting-started/tutorials/index

Python API
-------------------

- :doc:`triton <python-api/triton>`
- :doc:`triton.language <python-api/triton.language>`
- :doc:`triton.testing <python-api/triton.testing>`


.. toctree::
   :maxdepth: 1
   :caption: Python API
   :hidden:

   python-api/triton
   python-api/triton.language
   python-api/triton.testing


Going Further
------------------

Check out the following documents to learn more about Triton and how it compares against other DSLs for DNNs:

- Chapter 1: :doc:`Introduction <programming-guide/chapter-1/introduction>`
- Chapter 2: :doc:`Related Work <programming-guide/chapter-2/related-work>`

.. toctree::
   :maxdepth: 1
   :caption: Programming Guide
   :hidden:

   programming-guide/chapter-1/introduction
   programming-guide/chapter-2/related-work

69
master/_sources/programming-guide/chapter-1/introduction.rst.txt
Normal file
@@ -0,0 +1,69 @@

==============
Introduction
==============

--------------
Motivations
--------------

Over the past decade, Deep Neural Networks (DNNs) have emerged as an important class of Machine Learning (ML) models, capable of achieving state-of-the-art performance across many domains ranging from natural language processing [SUTSKEVER2014]_ to computer vision [REDMON2016]_ to computational neuroscience [LEE2017]_. The strength of these models lies in their hierarchical structure, composed of a sequence of parametric (e.g., convolutional) and non-parametric (e.g., rectified linearity) *layers*. This pattern, though notoriously computationally expensive, also generates a large amount of highly parallelizable work particularly well suited for multi- and many-core processors.

As a consequence, Graphics Processing Units (GPUs) have become a cheap and accessible resource for exploring and/or deploying novel research ideas in the field. This trend has been accelerated by the release of several frameworks for General-Purpose GPU (GPGPU) computing, such as CUDA and OpenCL, which have made the development of high-performance programs easier. Yet, GPUs remain incredibly challenging to optimize for locality and parallelism, especially for computations that cannot be efficiently implemented using a combination of pre-existing optimized primitives. To make matters worse, GPU architectures are also rapidly evolving and specializing, as evidenced by the addition of tensor cores to NVIDIA (and more recently AMD) micro-architectures.

This tension between the computational opportunities offered by DNNs and the practical difficulty of GPU programming has created substantial academic and industrial interest in Domain-Specific Languages (DSLs) and compilers. Regrettably, these systems -- whether they be based on polyhedral machinery (*e.g.*, Tiramisu [BAGHDADI2021]_, Tensor Comprehensions [VASILACHE2018]_) or scheduling languages (*e.g.*, Halide [JRK2013]_, TVM [CHEN2018]_) -- remain less flexible and (for the same algorithm) markedly slower than the best handwritten compute kernels available in libraries like `cuBLAS <https://docs.nvidia.com/cuda/cublas/index.html>`_, `cuDNN <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>`_ or `TensorRT <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html>`_.

The main premise of this project is the following: programming paradigms based on blocked algorithms [LAM1991]_ can facilitate the construction of high-performance compute kernels for neural networks. We specifically revisit traditional "Single Program, Multiple Data" (SPMD [AUGUIN1983]_) execution models for GPUs, and propose a variant in which programs -- rather than threads -- are blocked. For example, in the case of matrix multiplication, CUDA and Triton differ as follows:

.. table::
   :widths: 50 50

   +-----------------------------------------------------+-----------------------------------------------------+
   | CUDA Programming Model                              | Triton Programming Model                            |
   |                                                     |                                                     |
   | (Scalar Program, Blocked Threads)                   | (Blocked Program, Scalar Threads)                   |
   +=====================================================+=====================================================+
   |                                                     |                                                     |
   | .. code-block:: C                                   | .. code-block:: C                                   |
   |                                                     |     :force:                                         |
   |                                                     |                                                     |
   |     #pragma parallel                                |     #pragma parallel                                |
   |     for(int m = 0; m < M; m++)                      |     for(int m = 0; m < M; m += MB)                  |
   |     #pragma parallel                                |     #pragma parallel                                |
   |     for(int n = 0; n < N; n++){                     |     for(int n = 0; n < N; n += NB){                 |
   |       float acc = 0;                                |       float acc[MB, NB] = 0;                        |
   |       for(int k = 0; k < K; k++)                    |       for(int k = 0; k < K; k += KB)                |
   |         acc += A[m, k] * B[k, n];                   |         acc += A[m:m+MB, k:k+KB]                    |
   |                                                     |                @ B[k:k+KB, n:n+NB];                 |
   |       C[m, n] = acc;                                |       C[m:m+MB, n:n+NB] = acc;                      |
   |     }                                               |     }                                               |
   |                                                     |                                                     |
   +-----------------------------------------------------+-----------------------------------------------------+
   | |pic1|                                              | |pic2|                                              |
   +-----------------------------------------------------+-----------------------------------------------------+


.. |pic1| image:: cuda-parallel-matmul.png

.. |pic2| image:: triton-parallel-matmul.png

A key benefit of this approach is that it leads to block-structured iteration spaces that offer programmers more flexibility than existing DSLs when implementing sparse operations, all while allowing compilers to aggressively optimize programs for data locality and parallelism.

--------------
Challenges
--------------

The main challenge posed by our proposed paradigm is that of work scheduling, i.e., how the work done by each program instance should be partitioned for efficient execution on modern GPUs. To address this issue, the Triton compiler makes heavy use of *block-level data-flow analysis*, a technique for scheduling iteration blocks statically based on the control- and data-flow structure of the target program. The resulting system actually works surprisingly well: our compiler manages to apply a broad range of interesting optimizations automatically (e.g., automatic coalescing, thread swizzling, pre-fetching, automatic vectorization, tensor core-aware instruction selection, shared memory allocation/synchronization, asynchronous copy scheduling). Of course, doing all this is not trivial; one of the purposes of this guide is to give you a sense of how it works.

--------------
References
--------------

.. [SUTSKEVER2014] I. Sutskever et al., "Sequence to Sequence Learning with Neural Networks", NIPS 2014
.. [REDMON2016] J. Redmon et al., "You Only Look Once: Unified, Real-Time Object Detection", CVPR 2016
.. [LEE2017] K. Lee et al., "Superhuman Accuracy on the SNEMI3D Connectomics Challenge", ArXiv 2017
.. [BAGHDADI2021] R. Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021
.. [VASILACHE2018] N. Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiv 2018
.. [JRK2013] J. Ragan-Kelley et al., "Halide: A Language and Compiler for Optimizing Parallelism, Locality, and Recomputation in Image Processing Pipelines", PLDI 2013
.. [CHEN2018] T. Chen et al., "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning", OSDI 2018
.. [LAM1991] M. Lam et al., "The Cache Performance and Optimizations of Blocked Algorithms", ASPLOS 1991
.. [AUGUIN1983] M. Auguin et al., "Opsila: an advanced SIMD for numerical analysis and signal processing", EUROMICRO 1983
210
master/_sources/programming-guide/chapter-2/related-work.rst.txt
Normal file
@@ -0,0 +1,210 @@

==============
Related Work
==============

At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlight its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.

-----------------------
Polyhedral Compilation
-----------------------

Traditional compilers typically rely on intermediate representations, such as LLVM-IR [LATTNER2004]_, that encode control flow information using (un)conditional branches. This relatively low-level format makes it difficult to statically analyze the runtime behavior (e.g., cache misses) of input programs, and to automatically optimize loops accordingly through the use of tiling [WOLFE1989]_, fusion [DARTE1999]_ and interchange [ALLEN1984]_. To solve this issue, polyhedral compilers [ANCOURT1991]_ rely on program representations that have statically predictable control flow, thereby enabling aggressive compile-time program transformations for data locality and parallelism. Though this strategy has been adopted by many languages and compilers for DNNs such as Tiramisu [BAGHDADI2021]_, Tensor Comprehensions [VASILACHE2018]_, Diesel [ELANGO2018]_ and the Affine dialect in MLIR [LATTNER2019]_, it also comes with a number of limitations that will be described later in this section.

+++++++++++++++++++++++
Program Representation
+++++++++++++++++++++++

Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.

.. table::
   :widths: 50 50

   +-----------------------------------------------------+-----------------------------------------------------+
   |                                                     |                                                     |
   | .. code-block:: C                                   | |pic1|                                              |
   |                                                     |                                                     |
   |     for(int i = 0; i < 3; i++)                      |                                                     |
   |       for(int j = i; j < 5; j++)                    |                                                     |
   |         A[i][j] = 0;                                |                                                     |
   +-----------------------------------------------------+-----------------------------------------------------+

.. |pic1| image:: polyhedral-iteration.png
   :width: 300

Polyhedral compilers focus on a class of programs commonly known as **Static Control Parts** (SCoP), *i.e.*, maximal sets of consecutive statements in which conditionals and loop bounds are affine functions of surrounding loop indices and global invariant parameters. As shown above, programs in this format always lead to iteration domains that are bounded by affine inequalities, i.e., polyhedral. These polyhedra can also be defined algebraically; for the above example:

.. math::

    \mathcal{P} = \{ i, j \in \mathbb{Z}^2
    ~|~
    \begin{pmatrix}
    1 & 0 \\
    -1 & 0 \\
    -1 & 1 \\
    0 & -1 \\
    \end{pmatrix}
    \begin{pmatrix}
    i \\
    j
    \end{pmatrix}
    +
    \begin{pmatrix}
    0 \\
    2 \\
    0 \\
    4
    \end{pmatrix}
    \geq
    0
    \}

Each point :math:`(i, j)` in :math:`\mathcal{P}` represents a *polyhedral statement*, that is, a program statement which (1) does not induce control-flow side effects (e.g., :code:`for`, :code:`if`, :code:`break`) and (2) contains only affine functions of loop indices and global parameters in array accesses. To facilitate alias analysis, array accesses are also mathematically abstracted, using so-called *access functions*. In other words, :code:`A[i][j]` is simply :code:`A[f(i,j)]` where the access function :math:`f` is defined by:

.. math::

    f(i, j) = \begin{pmatrix}
    1 & 0\\
    0 & 1\\
    \end{pmatrix}
    \begin{pmatrix}
    i\\
    j
    \end{pmatrix}
    =
    (i, j)

Note that the iteration domain of an SCoP does not specify the order in which its statements shall execute. In fact, this iteration domain may be traversed in many different legal orders, i.e., *schedules*. Formally, a schedule is defined as a p-dimensional affine transformation :math:`\Theta` of loop indices :math:`\mathbf{x}` and global invariant parameters :math:`\mathbf{g}`:

.. math::

    \Theta_S(\mathbf{x}) = T_S \begin{pmatrix}
    \mathbf{x}\\
    \mathbf{g}\\
    1
    \end{pmatrix}
    \qquad
    T_S \in \mathbb{Z} ^{p \times (\text{dim}(\mathbf{x}) + \text{dim}(\mathbf{g}) + 1)}


where :math:`\Theta_S(\mathbf{x})` is a p-dimensional vector representing the slowest to fastest growing indices (from left to right) when traversing the loop nest surrounding :math:`S`. For the code shown above, the original schedule defined by the loop nest in C can be retrieved by using:

.. math::

    \Theta_S(\mathbf{x}) = \begin{pmatrix}
    1 & 0 \\
    0 & 1 \\
    \end{pmatrix}
    \begin{pmatrix}
    i & j
    \end{pmatrix}^T
    =
    \begin{pmatrix}
    i & j
    \end{pmatrix}^T


where :math:`i` and :math:`j` are respectively the slowest and fastest growing loop indices in the nest. If :math:`T_S` is a vector (resp. tensor), then :math:`\Theta_S` is said to be one-dimensional (resp. multi-dimensional).
|
||||
|
||||
+++++++++++
|
||||
Advantages
|
||||
+++++++++++
|
||||
|
||||
Programs amenable to polyhedral compilation can be aggressively transformed and optimized. Most of these transformations actually boil down to the production of schedules and iteration domains that enable loop transformations promoting parallelism and spatial/temporal data locality (e.g., fusion, interchange, tiling, parallelization).
|
||||
|
||||
Polyhedral compilers can also automatically go through complex verification processes to ensure that the semantics of their input program is preserved throughout this optimization phase. Note that polyhedral optimizers are not incompatible with more standard optimization techniques. In fact, it is not uncommon for these systems to be implemented as a set of LLVM passes that can be run ahead of more traditional compilation techniques [GROSSER2012]_.
|
||||
|
||||
All in all, polyhedral machinery is extremely powerful, when applicable. It has been shown to support most common loop transformations, and has indeed achieved performance comparable to state-of-the-art GPU libraries for dense matrix multiplication [ELANGO2018]_. Additionally, it is also fully automatic and doesn't require any hint from programmers apart from source-code in a C-like format.
|
||||
|
||||
++++++++++++
|
||||
Limitations
|
||||
++++++++++++
|
||||
|
||||
Unfortunately, polyhedral compilers suffer from two major limitations that have prevented its adoption as a universal method for code generation in neural networks.
|
||||
|
||||
First, the set of possible program transformations :math:`\Omega = \{ \Theta_S ~|~ S \in \text{program} \}` is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [SATO2019]_.
|
||||
|
||||
Second, the polyhedral framework is not very generally applicable; SCoPs are relatively common [GIRBAL2006]_ but require loop bounds and array subscripts to be affine functions of loop indices, which typically only occurs in regular, dense computations. For this reason, this framework still has to be successfully applied to sparse -- or even structured-sparse -- neural networks, whose importance has been rapidly rising over the past few years.
On the other hand, blocked program representations advocated by this dissertation are less restricted in scope and can achieve close to peak performance using standard dataflow analysis.
-----------------------
Scheduling Languages
-----------------------

Separation of concerns [DIJKSTRA82]_ is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Lines 1-7) is completely disjoint from its implementation (Lines 8-16), meaning that both can be maintained, optimized and distributed independently.
.. code-block:: c++
    :linenos:

    // algorithm
    Var x("x"), y("y");
    Func matmul("matmul");
    RDom k(0, matrix_size);
    RVar ki;
    matmul(x, y) = 0.0f;
    matmul(x, y) += A(k, y) * B(x, k);
    // schedule
    Var xi("xi"), xo("xo"), yo("yo"), yi("yi"), yii("yii"), xii("xii");
    matmul.vectorize(x, 8);
    matmul.update(0)
        .split(x, x, xi, block_size).split(xi, xi, xii, 8)
        .split(y, y, yi, block_size).split(yi, yi, yii, 4)
        .split(k, k, ki, block_size)
        .reorder(xii, yii, xi, ki, yi, k, x, y)
        .parallel(y).vectorize(xii).unroll(xi).unroll(yii);
The resulting code may, however, not be completely portable, as schedules can sometimes rely on execution models (e.g., SPMD) or hardware intrinsics (e.g., matrix-multiply-accumulate) that are not widely available. This issue can be mitigated by auto-scheduling mechanisms [MULLAPUDI2016]_.
+++++++++++
Advantages
+++++++++++

The main advantage of this approach is that it allows programmers to write an algorithm *only once*, and focus on performance optimization separately. It makes it possible to manually specify optimizations that a polyhedral compiler wouldn't be able to figure out automatically using static data-flow analysis.
Scheduling languages are, without a doubt, one of the most popular approaches to neural network code generation. The best-known system for this purpose is probably TVM, which provides good performance across a wide range of platforms as well as built-in automatic scheduling mechanisms.
++++++++++++
Limitations
++++++++++++

This ease of development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more effort -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
.. table::
    :widths: 50 50

    +-----------------------------------------------------+-----------------------------------------------------+
    |                                                     |                                                     |
    |.. code-block:: C                                    |   |pic2|                                            |
    |                                                     |                                                     |
    |   for(int i = 0; i < 4; i++)                        |                                                     |
    |   for(int j = 0; j < 4; j++){                       |                                                     |
    |     float acc = 0;                                  |                                                     |
    |     for(int k = 0; k < K[i]; k++)                   |                                                     |
    |       acc += A[i][col[i][k]]*B[k][j];               |                                                     |
    |     C[i][j] = acc;                                  |                                                     |
    |   }                                                 |                                                     |
    +-----------------------------------------------------+-----------------------------------------------------+

.. |pic2| image:: halide-iteration.png
    :width: 300

On the other hand, the block-based program representation advocated throughout this work supports block-structured iteration spaces and allows programmers to manually handle load-balancing as they wish.
--------------
References
--------------

.. [LATTNER2004] C. Lattner et al., "LLVM: A Compilation Framework for Lifelong Program Analysis & Transformation", CGO 2004
.. [WOLFE1989] M. Wolfe, "More Iteration Space Tiling", SC 1989
.. [DARTE1999] A. Darte, "On the Complexity of Loop Fusion", PACT 1999
.. [ALLEN1984] J. Allen et al., "Automatic Loop Interchange", SIGPLAN Notices 1984
.. [ANCOURT1991] C. Ancourt et al., "Scanning Polyhedra with DO Loops", PPoPP 1991
.. [BAGHDADI2021] R. Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2019
.. [VASILACHE2018] N. Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", arXiv 2018
.. [ELANGO2018] V. Elango et al., "Diesel: DSL for Linear Algebra and Neural Net Computations on GPUs", MAPL 2018
.. [LATTNER2019] C. Lattner et al., "MLIR Primer: A Compiler Infrastructure for the End of Moore’s Law", arXiv 2019
.. [GROSSER2012] T. Grosser et al., "Polly -- Performing Polyhedral Optimizations on a Low-Level Intermediate Representation", Parallel Processing Letters 2012
.. [SATO2019] Y. Sato et al., "An Autotuning Framework for Scalable Execution of Tiled Code via Iterative Polyhedral Compilation", TACO 2019
.. [GIRBAL2006] S. Girbal et al., "Semi-Automatic Composition of Loop Transformations for Deep Parallelism and Memory Hierarchies", International Journal of Parallel Programming 2006
.. [DIJKSTRA82] E. W. Dijkstra, "On the Role of Scientific Thought", Selected Writings on Computing: A Personal Perspective, 1982
.. [MULLAPUDI2016] R. Mullapudi et al., "Automatically Scheduling Halide Image Processing Pipelines", TOG 2016

22
master/_sources/python-api/generated/triton.Config.rst.txt
Normal file
@@ -0,0 +1,22 @@
triton.Config
=============

.. currentmodule:: triton

.. autoclass:: Config


   .. automethod:: __init__


   .. rubric:: Methods

   .. autosummary::

      ~Config.__init__

6
master/_sources/python-api/generated/triton.autotune.rst.txt
Normal file
@@ -0,0 +1,6 @@
triton.autotune
===============

.. currentmodule:: triton

.. autofunction:: autotune

6
master/_sources/python-api/generated/triton.heuristics.rst.txt
Normal file
@@ -0,0 +1,6 @@
triton.heuristics
=================

.. currentmodule:: triton

.. autofunction:: heuristics

6
master/_sources/python-api/generated/triton.jit.rst.txt
Normal file
@@ -0,0 +1,6 @@
triton.jit
==========

.. currentmodule:: triton

.. autofunction:: jit

6
master/_sources/python-api/generated/triton.language.arange.rst.txt
Normal file
@@ -0,0 +1,6 @@
triton.language.arange
======================

.. currentmodule:: triton.language

.. autofunction:: arange

6
master/_sources/python-api/generated/triton.language.atomic_add.rst.txt
Normal file
@@ -0,0 +1,6 @@
triton.language.atomic\_add
===========================

.. currentmodule:: triton.language

.. autofunction:: atomic_add

6
master/_sources/python-api/generated/triton.language.atomic_cas.rst.txt
Normal file
@@ -0,0 +1,6 @@
triton.language.atomic\_cas
===========================

.. currentmodule:: triton.language

.. autofunction:: atomic_cas