[GH-PAGES] Updated website

This commit is contained in:
Philippe Tillet
2022-02-09 07:15:50 +00:00
parent 796d260cbd
commit e3b4440dff
300 changed files with 15022 additions and 14 deletions

View File

@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 4dce8c5cbe57d2011293861b3d5fb4a0
config: bd537efaf126fb307d18a393c4eacf10
tags: 645f666f9bcd5a90fca523b33c5a78b7

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation that is significantly faster\nthan PyTorch's native op for a particular class of matrices: those whose rows can fit in\nthe GPU's SRAM.\nYou will learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- Reduction operators in Triton.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Motivations\nCustom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.\nLet us consider instead the case of a simple (numerically stabilized) softmax operation:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n\nimport triton\nimport triton.language as tl\n\n\n@torch.jit.script\ndef naive_softmax(x):\n \"\"\"Compute row-wise softmax of X using native pytorch\n\n We subtract the maximum element in order to avoid overflows. Softmax is invariant to\n this shift.\n \"\"\"\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read MN + M elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(z)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read MN + M elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements\n return ret"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$\nrequires reading $5MN + 2M$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads\nX once and does all the necessary computations on-chip.\nDoing so would require reading and writing back only $MN$ bytes, so we could\nexpect a theoretical speed-up of ~4x (i.e., $(8MN + 4M) / 2MN$).\nThe `torch.jit.script` flags aims to perform this kind of \"kernel fusion\" automatically\nbut, as we will see later, it is still far from ideal.\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input matrix X,\nnormalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a\npower-of-two number of elements, so we need to internally \"pad\" each row and guard the\nmemory operations properly if we want to handle any possible input shapes:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"@triton.jit\ndef softmax_kernel(\n output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,\n BLOCK_SIZE: tl.constexpr\n):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n # Substract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def softmax(x):\n n_rows, n_cols = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o\n # f the input matrix\n softmax_kernel[(n_rows,)](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Unit Test\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We make sure that we test our kernel on a matrix with an irregular number of rows and columns.\nThis will allow us to verify that our padding mechanism works.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"torch.manual_seed(0)\nx = torch.randn(1823, 781, device='cuda')\ny_triton = softmax(x)\ny_torch = torch.softmax(x, axis=1)\nassert torch.allclose(y_triton, y_torch), (y_triton, y_torch)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected, the results are identical.\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark\nHere we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.\nWe will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'], # argument names to use as an x-axis for the plot\n x_vals=[\n 128 * i for i in range(2, 100)\n ], # different possible values for `x_name`\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=[\n 'triton',\n 'torch-native',\n 'torch-jit',\n ], # possible values for `line_arg``\n line_names=[\n \"Triton\",\n \"Torch (native)\",\n \"Torch (jit)\",\n ], # label name for the lines\n styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles\n ylabel=\"GB/s\", # label name for the y-axis\n plot_name=\"softmax-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(M, N, provider):\n x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n if provider == 'torch-native':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))\n if provider == 'torch-jit':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))\n gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbenchmark.run(show_plots=True, print_data=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the above plot, we can see that:\n\n - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.\n - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.\n Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -0,0 +1,131 @@
"""
Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:
- The basic programming model of Triton
- The `triton.jit` decorator, which is used to define Triton kernels.
- The best practices for validating and benchmarking your custom ops against native reference implementations
"""
# %%
# Compute Kernel
# --------------------------
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
# %%
# Let's also declare a helper function to (1) allocate the `z` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes.
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
# In this case, we use a 1D grid where the size is the number of blocks
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE:
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
# %%
# Seems like we're good to go!
# %%
# Benchmark
# -----------
# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
# for different problem sizes.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[
2 ** i for i in range(12, 28, 1)
], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch'], # possible values for `line_arg`
line_names=['Triton', 'Torch'], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel='GB/s', # label name for the y-axis
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)
# %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/' to save them to disk along with raw CSV data
benchmark.run(print_data=True, show_plots=True)

View File

@@ -0,0 +1,261 @@
"""
Layer Normalization
====================
"""
import torch
import triton
import triton.language as tl
try:
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
# should not be added to extras_require in setup.py.
import apex
HAS_APEX = True
except ModuleNotFoundError:
HAS_APEX = False
# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
BLOCK_SIZE: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
Y += row * stride
# load data and cast to float32
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
# compute mean
mean = tl.sum(x, axis=0) / N
# compute std
xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean * rstd
# write-back mean/rstd
tl.store(M + row, mean)
tl.store(V + row, rstd)
# multiply by weight and add bias
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
y = xhat * w + b
# write-back
tl.store(Y + cols, y, mask=mask)
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
# these buffers stay in the L2, which allow this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# first store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
x.grad, weight.grad, bias.grad = None, None, None
# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
triton.testing.assert_almost_equal(y_tri, y_ref)
triton.testing.assert_almost_equal(dx_tri, dx_ref)
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[512 * i for i in range(2, 32)],
line_arg='provider',
line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
if provider == 'torch':
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,100 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Low-Memory Dropout\n\nIn this tutorial, you will write a memory-efficient implementation of dropout whose state\nwill be composed of a single int32 seed. This differs from more traditional implementations of dropout,\nwhose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:\n\n- The limitations of naive implementations of Dropout with PyTorch\n- Parallel pseudo-random number generation in Triton\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Baseline\nThe *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance\nof deep neural networks in low-data regime (i.e. regularization).\n\nIt takes a vector as input and produces a vector of the same shape as output. Each scalar in the\noutput has a probability $p$ of being changed to zero and otherwise it is copied from the input.\nThis forces the network to perform well even when only $1 - p$ scalars from the input are available.\n\nAt evaluation time we want to use the full power of the network so we set $p=0$. Naively this would\nincrease the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease\nin the output softmax temperature). To prevent this we multiply the output by $\\frac{1}{1 - p}$, which\nkeeps the norm consistent regardless of the dropout probability.\n\nLet's first take a look at the baseline implementation.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tabulate\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n\n# Input tensor\nx = torch.randn(size=(10,)).cuda()\n# Dropout mask\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\n#\noutput = dropout(x, x_keep=x_keep, p=p)\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"keep mask\"] + x_keep.tolist(),\n [\"output\"] + output.tolist()\n]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Seeded dropout\nAbove implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly\nwe need to store the dropout mask for backpropagation. Secondly, dropout state management can get\nvery tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in\nhttps://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation\nthat (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management\nof persisting randomness across multiple invocations of the kernel.\n\nPseudorandom number generation in Triton is simple! In this tutorial we will use the\n:code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`\nvalues in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides\nother `random number generation strategies <Random Number Generation>`.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Triton's implementation of PRNG is based on the Philox algorithm (described on [SALMON2011]_).</p></div>\n\nLet's put it all together.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n # compute memory offsets of elements handled by this instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n\nx = torch.randn(size=(10,)).cuda()\n# Compare this to the baseline - dropout mask is never instantiated!\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\n\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"output (seed = 123)\"] + output.tolist(),\n [\"output (seed = 123)\"] + output2.tolist(),\n [\"output (seed = 512)\"] + output3.tolist()\n]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Et Voil\u00e0! We have a triton kernel that applies the same dropout mask provided the seed is the same!\nIf you'd like explore further applications of pseudorandomness in GPU programming, we encourage you\nto explore the `triton/language/random` folder!\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercises\n1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.\n2. Add support for striding.\n3. (challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n\n.. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, \"Parallel Random Numbers: As Easy as 1, 2, 3\", 2011\n.. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, \"Dropout: A Simple Way to Prevent Neural Networks from Overfitting\", JMLR 2014\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -0,0 +1,166 @@
"""
Low-Memory Dropout
=================
In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:
- The limitations of naive implementations of Dropout with PyTorch
- Parallel pseudo-random number generation in Triton
"""
# %%
# Baseline
# -------------
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
# of deep neural networks in low-data regime (i.e. regularization).
#
# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
# output has a probability :math:`p` of being changed to zero and otherwise it is copied from the input.
# This forces the network to perform well even when only :math:`1 - p` scalars from the input are available.
#
# At evaluation time we want to use the full power of the network so we set :math:`p=0`. Naively this would
# increase the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease
# in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which
# keeps the norm consistent regardless of the dropout probability.
#
# Let's first take a look at the baseline implementation.
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
# The line below is the crucial part, described in the paragraph above!
output = tl.where(x_keep, x / (1 - p), 0.0)
# Write-back output
tl.store(output_ptr + offsets, output, mask=mask)
def dropout(x, x_keep, p):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
["input"] + x.tolist(),
["keep mask"] + x_keep.tolist(),
["output"] + output.tolist()
]))
# %%
# Seeded dropout
# -------------
# Above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly
# we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
# very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
# https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
# that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
# of persisting randomness across multiple invocations of the kernel.
#
# Pseudorandom number generation in Triton is simple! In this tutorial we will use the
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
# other :ref:`random number generation strategies <Random Number Generation>`.
#
# .. note::
# Triton's implementation of PRNG is based on the Philox algorithm (described on [SALMON2011]_).
#
# Let's put it all together.
@triton.jit
def _seeded_dropout(
x_ptr,
output_ptr,
n_elements,
p,
seed,
BLOCK_SIZE: tl.constexpr,
):
# compute memory offsets of elements handled by this instance
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# load data from x
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# randomly prune it
random = tl.rand(seed, offsets)
x_keep = random > p
# write-back
output = tl.where(x_keep, x / (1 - p), 0.0)
tl.store(output_ptr + offsets, output, mask=mask)
def seeded_dropout(x, p, seed):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
return output
x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)
print(tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist()
]))
# %%
# Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same!
# If you'd like explore further applications of pseudorandomness in GPU programming, we encourage you
# to explore the `triton/language/random` folder!
# %%
# Exercises
# -------------
# 1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.
# 2. Add support for striding.
# 3. (challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.
# %%
# References
# --------------
#
# .. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
# .. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014

View File

@@ -0,0 +1,355 @@
"""
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication
kernel that achieves performance on par with cuBLAS.
You will specifically learn about:
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
"""
# %%
# Motivations
# -------------
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is generally done by
# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be easily customized
# to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
# In this tutorial, you will learn how to implement efficient matrix multiplications by
# yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked
# algorithm to multiply a (M, K) by a (K, N) matrix:
#
# .. code-block:: python
#
# # do in parallel
# for m in range(0, M, BLOCK_SIZE_M):
# # do in parallel
# for n in range(0, N, BLOCK_SIZE_N):
# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
# for k in range(0, K, BLOCK_SIZE_K):
# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
# acc += dot(a, b)
# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
#
# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
# %%
# Compute Kernel
# ----------------
#
# The above algorithm is, actually, fairly straightforward to implement in Triton.
# The main difficulty comes from the computation of the memory locations at which blocks
# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
# multi-dimensional pointer arithmetics.
#
# Pointer Arithmetics
# ~~~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
# y :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
#
# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
#
# .. code-block:: python
#
# offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
# offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
# offs_k = tl.arange(0, BLOCK_SIZE_K)
# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
#
# And then updated in the inner loop as follows:
#
# .. code-block:: python
#
# pa += BLOCK_SIZE_K * stride_ak;
# pb += BLOCK_SIZE_K * stride_bk;
#
#
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
# block of :code:`C`.
# It is important to remember that the order in which these blocks are computed does
# matter, since it affects the L2 cache hit rate of our program. and unfortunately, a
# a simple row-major ordering
#
# .. code-block:: Python
#
# pid = triton.program_id(0);
# grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
# grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
# pid_m = pid / grid_n;
# pid_n = pid % grid_n;
#
# is just not going to cut it.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before
# switching to the next column:
#
# .. code-block:: python
#
# # program ID
# pid = tl.program_id(axis=0)
# # number of program ids along the M axis
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# # number of programs ids along the N axis
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# # number of programs in group
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# # id of the group this program is in
# group_id = pid // num_pid_in_group
# # row-id of the first program in the group
# first_pid_m = group_id * GROUP_SIZE_M
# # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# # *within groups*, programs are ordered in a column-major order
# # row-id of the program in the *launch grid*
# pid_m = first_pid_m + (pid % group_size_m)
# # col-id of the program in the *launch grid*
# pid_n = (pid % num_pid_in_group) // group_size_m
#
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
# we can see that if we compute the output in row-major ordering, we need to load 90
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
# ordering, we only need to load 54 blocks.
# .. image:: grouped_vs_row_major_ordering.png
#
# In practice, this can improve the performance of our matrix multiplication kernel by
# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
#
# %%
# Final Result
# -------------
#
import torch
import triton
import triton.language as tl
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
# decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of
# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
# See above `L2 Cache Optimizations` section for details
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
# see above `Pointer Arithmetics` section for details
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K,
# this will access out-of-bounds memory and produce an
# error or (worse!) incorrect results.
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
# We accumulate along the K dimension
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION:
accumulator = ACTIVATION(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01 * x)
# %%
# We can now create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
def matmul(a, b, activation=None):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
assert b.is_contiguous(), "matrix B must be contiguous"
M, K = a.shape
K, N = b.shape
assert (
K % 32 == 0
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
# allocates output
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
ACTIVATION=activation,
)
return c
# %%
# Unit Test
# -----------
#
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b, activation=None)
torch_output = torch.matmul(a, b)
print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")
if triton.testing.allclose(triton_output, torch_output):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
# %%
# Benchmark
# --------------
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 33)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
# label name for the lines
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={},
)
)
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_relu(torch.matmul(a, b))
)
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, activation=leaky_relu)
)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True, print_data=True)

View File

@@ -0,0 +1,191 @@
"""
Fused Softmax
=================
In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.
You will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
"""
# %%
# Motivations
# ------------
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
@torch.jit.script
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
# %%
# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
# X once and does all the necessary computations on-chip.
# Doing so would require reading and writing back only :math:`MN` bytes, so we could
# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically
# but, as we will see later, it is still far from ideal.
# %%
# Compute Kernel
# ----------------
# Our softmax kernel works as follows: each program loads a row of the input matrix X,
# normalizes it and writes back the result to the output Y.
# Note that one important limitation of Triton is that each block must have a
# power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Substract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
def softmax(x):
n_rows, n_cols = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
y = torch.empty_like(x)
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
# f the input matrix
softmax_kernel[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return y
# %%
# Unit Test
# ----------
# %%
# We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
# This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
# %%
# As expected, the results are identical.
# %%
# Benchmark
# -------------
# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 100)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=[
'triton',
'torch-native',
'torch-jit',
], # possible values for `line_arg``
line_names=[
"Triton",
"Torch (native)",
"Torch (jit)",
], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
if provider == 'torch-native':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
if provider == 'torch-jit':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
benchmark.run(show_plots=True, print_data=True)
# %%
# In the above plot, we can see that:
#
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
# Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.

View File

@@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Vector Addition\nIn this tutorial, you will write a simple vector addition using Triton and learn about:\n\n- The basic programming model of Triton\n- The `triton.jit` decorator, which is used to define Triton kernels.\n- The best practices for validating and benchmarking your custom ops against native reference implementations\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute Kernel\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process\n # NOTE: `constexpr` so it can be used as a shape value\n):\n # There are multiple 'program's processing different data. We identify which program\n # we are here\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n # This program will process inputs that are offset from the initial data.\n # for instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extra elements in case the input is not a\n # multiple of the block size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM\n tl.store(output_ptr + offsets, output, mask=mask)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's also declare a helper function to (1) allocate the `z` tensor\nand (2) enqueue the above kernel with appropriate grid/block sizes.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def add(x: torch.Tensor, y: torch.Tensor):\n # We need to preallocate the output\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n # In this case, we use a 1D grid where the size is the number of blocks\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n # NOTE:\n # - each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Seems like we're good to go!\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark\nWe can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.\nTo make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops\nfor different problem sizes.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['size'], # argument names to use as an x-axis for the plot\n x_vals=[\n 2 ** i for i in range(12, 28, 1)\n ], # different possible values for `x_name`\n x_log=True, # x axis is logarithmic\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['triton', 'torch'], # possible values for `line_arg`\n line_names=['Triton', 'Torch'], # label name for the lines\n styles=[('blue', '-'), ('green', '-')], # line styles\n ylabel='GB/s', # label name for the y-axis\n plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.\n args={}, # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(size, provider):\n x = torch.rand(size, device='cuda', dtype=torch.float32)\n y = torch.rand(size, device='cuda', dtype=torch.float32)\n if provider == 'torch':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))\n gbps = lambda ms: 12 * size / ms * 1e-6\n return gbps(ms), gbps(max_ms), gbps(min_ms)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or\n`save_path='/path/to/results/' to save them to disk along with raw CSV data\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"benchmark.run(print_data=True, show_plots=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 465 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

View File

@@ -0,0 +1,286 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/01-vector-add.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_01-vector-add.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_01-vector-add.py:
Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:
- The basic programming model of Triton
- The `triton.jit` decorator, which is used to define Triton kernels.
- The best practices for validating and benchmarking your custom ops against native reference implementations
.. GENERATED FROM PYTHON SOURCE LINES 12-14
Compute Kernel
--------------------------
.. GENERATED FROM PYTHON SOURCE LINES 14-50
.. code-block:: default
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
.. GENERATED FROM PYTHON SOURCE LINES 51-53
Let's also declare a helper function to (1) allocate the `z` tensor
and (2) enqueue the above kernel with appropriate grid/block sizes.
.. GENERATED FROM PYTHON SOURCE LINES 53-74
.. code-block:: default
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
# In this case, we use a 1D grid where the size is the number of blocks
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE:
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
.. GENERATED FROM PYTHON SOURCE LINES 75-76
We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
.. GENERATED FROM PYTHON SOURCE LINES 76-90
.. code-block:: default
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0
.. GENERATED FROM PYTHON SOURCE LINES 91-92
Seems like we're good to go!
.. GENERATED FROM PYTHON SOURCE LINES 94-99
Benchmark
-----------
We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
for different problem sizes.
.. GENERATED FROM PYTHON SOURCE LINES 99-128
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # argument names to use as an x-axis for the plot
x_vals=[
2 ** i for i in range(12, 28, 1)
], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch'], # possible values for `line_arg`
line_names=['Triton', 'Torch'], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel='GB/s', # label name for the y-axis
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)
.. GENERATED FROM PYTHON SOURCE LINES 129-131
We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
`save_path='/path/to/results/' to save them to disk along with raw CSV data
.. GENERATED FROM PYTHON SOURCE LINES 131-132
.. code-block:: default
benchmark.run(print_data=True, show_plots=True)
.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
:alt: 01 vector add
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
vector-add-performance:
size Triton Torch
0 4096.0 9.600000 9.600000
1 8192.0 19.200000 19.200000
2 16384.0 38.400001 38.400001
3 32768.0 63.999998 63.999998
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 384.000001 384.000001
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
10 4194304.0 780.190482 780.190482
11 8388608.0 812.429770 812.429770
12 16777216.0 833.084721 833.084721
13 33554432.0 842.004273 842.004273
14 67108864.0 847.448255 848.362445
15 134217728.0 849.737435 850.656574
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 42.537 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 01-vector-add.py <01-vector-add.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 01-vector-add.ipynb <01-vector-add.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,337 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:
Fused Softmax
=================
In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.
You will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
.. GENERATED FROM PYTHON SOURCE LINES 14-18
Motivations
------------
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
.. GENERATED FROM PYTHON SOURCE LINES 18-46
.. code-block:: default
import torch
import triton
import triton.language as tl
@torch.jit.script
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
.. GENERATED FROM PYTHON SOURCE LINES 47-55
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` bytes, so we could
expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically
but, as we will see later, it is still far from ideal.
.. GENERATED FROM PYTHON SOURCE LINES 57-64
Compute Kernel
----------------
Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shapes:
.. GENERATED FROM PYTHON SOURCE LINES 64-93
.. code-block:: default
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Substract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
.. GENERATED FROM PYTHON SOURCE LINES 94-95
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
.. GENERATED FROM PYTHON SOURCE LINES 95-125
.. code-block:: default
def softmax(x):
n_rows, n_cols = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
y = torch.empty_like(x)
# Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o
# f the input matrix
softmax_kernel[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return y
.. GENERATED FROM PYTHON SOURCE LINES 126-128
Unit Test
----------
.. GENERATED FROM PYTHON SOURCE LINES 130-132
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.
.. GENERATED FROM PYTHON SOURCE LINES 132-139
.. code-block:: default
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
.. GENERATED FROM PYTHON SOURCE LINES 140-141
As expected, the results are identical.
.. GENERATED FROM PYTHON SOURCE LINES 143-147
Benchmark
-------------
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
.. GENERATED FROM PYTHON SOURCE LINES 147-186
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 100)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=[
'triton',
'torch-native',
'torch-jit',
], # possible values for `line_arg``
line_names=[
"Triton",
"Torch (native)",
"Torch (jit)",
], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
if provider == 'torch-native':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
if provider == 'torch-jit':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
benchmark.run(show_plots=True, print_data=True)
.. image:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
:alt: 02 fused softmax
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 512.000001 546.133347 190.511628
1 384.0 614.400016 585.142862 153.600004
2 512.0 655.360017 606.814814 154.566038
3 640.0 706.206879 640.000002 160.000000
4 768.0 722.823517 664.216187 162.754967
.. ... ... ... ...
93 12160.0 814.058574 406.179533 198.733401
94 12288.0 814.111783 415.661740 198.995960
95 12416.0 814.163950 412.149375 198.556711
96 12544.0 814.214963 412.971190 198.815254
97 12672.0 814.265046 412.097543 198.873965
[98 rows x 4 columns]
.. GENERATED FROM PYTHON SOURCE LINES 187-192
In the above plot, we can see that:
- Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
- Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 3 minutes 22.066 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,529 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/03-matrix-multiplication.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_03-matrix-multiplication.py:
Matrix Multiplication
======================
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication
kernel that achieves performance on par with cuBLAS.
You will specifically learn about:
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
.. GENERATED FROM PYTHON SOURCE LINES 15-42
Motivations
-------------
Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is generally done by
hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized
to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
In this tutorial, you will learn how to implement efficient matrix multiplications by
yourself with Triton, in a way that is easy to customize and extend.
Roughly speaking, the kernel that we will write will implement the following blocked
algorithm to multiply a (M, K) by a (K, N) matrix:
.. code-block:: python
# do in parallel
for m in range(0, M, BLOCK_SIZE_M):
# do in parallel
for n in range(0, N, BLOCK_SIZE_N):
acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
for k in range(0, K, BLOCK_SIZE_K):
a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
acc += dot(a, b)
C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
.. GENERATED FROM PYTHON SOURCE LINES 44-137
Compute Kernel
----------------
The above algorithm is, actually, fairly straightforward to implement in Triton.
The main difficulty comes from the computation of the memory locations at which blocks
of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
multi-dimensional pointer arithmetics.
Pointer Arithmetics
~~~~~~~~~~~~~~~~~~~~
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
y :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
.. code-block:: python
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
.. code-block:: python
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
And then updated in the inner loop as follows:
.. code-block:: python
pa += BLOCK_SIZE_K * stride_ak;
pb += BLOCK_SIZE_K * stride_bk;
L2 Cache Optimizations
~~~~~~~~~~~~~~~~~~~~~~~~
As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
block of :code:`C`.
It is important to remember that the order in which these blocks are computed does
matter, since it affects the L2 cache hit rate of our program. and unfortunately, a
a simple row-major ordering
.. code-block:: Python
pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;
is just not going to cut it.
One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before
switching to the next column:
.. code-block:: python
# program ID
pid = tl.program_id(axis=0)
# number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# id of the group this program is in
group_id = pid // num_pid_in_group
# row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *within groups*, programs are ordered in a column-major order
# row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
we can see that if we compute the output in row-major ordering, we need to load 90
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
ordering, we only need to load 54 blocks.
.. image:: grouped_vs_row_major_ordering.png
In practice, this can improve the performance of our matrix multiplication kernel by
more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
.. GENERATED FROM PYTHON SOURCE LINES 139-142
Final Result
-------------
.. GENERATED FROM PYTHON SOURCE LINES 142-258
.. code-block:: default
import torch
import triton
import triton.language as tl
# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
# decorator, which consumes:
# - A list of :code:`triton.Config` objects that define different configurations of
# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
# See above `L2 Cache Optimizations` section for details
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
# see above `Pointer Arithmetics` section for details
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K,
# this will access out-of-bounds memory and produce an
# error or (worse!) incorrect results.
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
# We accumulate along the K dimension
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION:
accumulator = ACTIVATION(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01 * x)
.. GENERATED FROM PYTHON SOURCE LINES 259-261
We can now create a convenience wrapper function that only takes two input tensors
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
.. GENERATED FROM PYTHON SOURCE LINES 261-290
.. code-block:: default
def matmul(a, b, activation=None):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
assert b.is_contiguous(), "matrix B must be contiguous"
M, K = a.shape
K, N = b.shape
assert (
K % 32 == 0
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
# allocates output
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
ACTIVATION=activation,
)
return c
.. GENERATED FROM PYTHON SOURCE LINES 291-295
Unit Test
-----------
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
.. GENERATED FROM PYTHON SOURCE LINES 295-308
.. code-block:: default
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b, activation=None)
torch_output = torch.matmul(a, b)
print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")
if triton.testing.allclose(triton_output, torch_output):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
triton_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
...,
[ -5.7031, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
device='cuda:0', dtype=torch.float16)
torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
...,
[ -5.7070, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
.. GENERATED FROM PYTHON SOURCE LINES 309-315
Benchmark
--------------
Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
.. GENERATED FROM PYTHON SOURCE LINES 315-356
.. code-block:: default
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 33)
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
# label name for the lines
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
args={},
)
)
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_relu(torch.matmul(a, b))
)
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, activation=leaky_relu)
)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True, print_data=True)
.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
:alt: 03 matrix multiplication
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
1 384.0 7.372800 ... 8.507077 7.899428
2 512.0 14.563555 ... 16.384000 16.384000
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 34.028308 34.028308
5 896.0 39.025776 ... 39.025776 39.025776
6 1024.0 49.932191 ... 52.428801 52.428801
7 1152.0 45.242181 ... 46.656000 46.656000
8 1280.0 51.200001 ... 56.888887 56.888887
9 1408.0 64.138541 ... 67.305878 66.485074
10 1536.0 80.430545 ... 79.526831 78.643199
11 1664.0 62.929456 ... 62.492442 62.061463
12 1792.0 72.512412 ... 71.588687 72.047592
13 1920.0 69.467336 ... 70.172588 70.172588
14 2048.0 73.262953 ... 76.608294 76.608294
15 2176.0 83.500614 ... 85.998493 85.998493
16 2304.0 68.446623 ... 76.563695 76.319081
17 2432.0 71.305746 ... 74.521127 85.134737
18 2560.0 78.019048 ... 80.908642 81.108913
19 2688.0 82.823267 ... 89.888756 89.464755
20 2816.0 83.233226 ... 82.916747 82.759409
21 2944.0 81.967162 ... 82.509987 82.237674
22 3072.0 82.181572 ... 88.197981 87.651868
23 3200.0 84.099871 ... 93.567248 94.674553
24 3328.0 79.901550 ... 84.945483 84.496824
25 3456.0 81.600781 ... 91.511426 84.775569
26 3584.0 83.798127 ... 86.542919 95.148565
27 3712.0 81.615477 ... 88.640059 82.017526
28 3840.0 84.744825 ... 91.777595 85.399230
29 3968.0 92.024087 ... 84.183469 89.558851
30 4096.0 86.424811 ... 87.438257 91.992956
[31 rows x 5 columns]
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 6 minutes 4.559 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 03-matrix-multiplication.py <03-matrix-multiplication.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 03-matrix-multiplication.ipynb <03-matrix-multiplication.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,271 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/04-low-memory-dropout.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_04-low-memory-dropout.py:
Low-Memory Dropout
=================
In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:
- The limitations of naive implementations of Dropout with PyTorch
- Parallel pseudo-random number generation in Triton
.. GENERATED FROM PYTHON SOURCE LINES 14-29
Baseline
-------------
The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
of deep neural networks in low-data regime (i.e. regularization).
It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
output has a probability :math:`p` of being changed to zero and otherwise it is copied from the input.
This forces the network to perform well even when only :math:`1 - p` scalars from the input are available.
At evaluation time we want to use the full power of the network so we set :math:`p=0`. Naively this would
increase the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease
in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which
keeps the norm consistent regardless of the dropout probability.
Let's first take a look at the baseline implementation.
.. GENERATED FROM PYTHON SOURCE LINES 29-82
.. code-block:: default
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
# The line below is the crucial part, described in the paragraph above!
output = tl.where(x_keep, x / (1 - p), 0.0)
# Write-back output
tl.store(output_ptr + offsets, output, mask=mask)
def dropout(x, x_keep, p):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
["input"] + x.tolist(),
["keep mask"] + x_keep.tolist(),
["output"] + output.tolist()
]))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
input 1.541 -0.293429 -2.17879 0.568431 -1.08452 -1.3986 0.403347 0.838026 -0.719258 -0.403344
keep mask 1 1 0 1 0 1 1 0 0 0
output 3.08199 -0.586858 0 1.13686 0 -2.79719 0.806694 0 0 0
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
.. GENERATED FROM PYTHON SOURCE LINES 83-101
Seeded dropout
-------------
Above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly
we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
of persisting randomness across multiple invocations of the kernel.
Pseudorandom number generation in Triton is simple! In this tutorial we will use the
:code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
other :ref:`random number generation strategies <Random Number Generation>`.
.. note::
Triton's implementation of PRNG is based on the Philox algorithm (described on [SALMON2011]_).
Let's put it all together.
.. GENERATED FROM PYTHON SOURCE LINES 101-149
.. code-block:: default
@triton.jit
def _seeded_dropout(
x_ptr,
output_ptr,
n_elements,
p,
seed,
BLOCK_SIZE: tl.constexpr,
):
# compute memory offsets of elements handled by this instance
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# load data from x
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# randomly prune it
random = tl.rand(seed, offsets)
x_keep = random > p
# write-back
output = tl.where(x_keep, x / (1 - p), 0.0)
tl.store(output_ptr + offsets, output, mask=mask)
def seeded_dropout(x, p, seed):
output = torch.empty_like(x)
assert x.is_contiguous()
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
return output
x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)
print(tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist()
]))
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
.. GENERATED FROM PYTHON SOURCE LINES 150-153
Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same!
If you'd like explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the `triton/language/random` folder!
.. GENERATED FROM PYTHON SOURCE LINES 155-160
Exercises
-------------
1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.
2. Add support for striding.
3. (challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.
.. GENERATED FROM PYTHON SOURCE LINES 162-167
References
--------------
.. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
.. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.480 seconds)
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 04-low-memory-dropout.py <04-low-memory-dropout.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 04-low-memory-dropout.ipynb <04-low-memory-dropout.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,370 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:
Layer Normalization
====================
.. GENERATED FROM PYTHON SOURCE LINES 5-262
.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
:alt: 05 layer norm
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
layer-norm-backward:
N Triton Torch Apex
0 1024.0 307.200008 98.303995 307.200008
1 1536.0 347.773587 134.050910 341.333333
2 2048.0 420.102553 161.684218 325.509933
3 2560.0 458.507457 181.238943 326.808501
4 3072.0 511.999982 191.999993 317.793096
5 3584.0 547.872604 208.271186 310.527060
6 4096.0 564.965515 220.412561 294.323343
7 4608.0 504.986315 232.825259 290.267724
8 5120.0 527.381977 242.845844 287.775181
9 5632.0 542.843364 243.545956 288.820505
10 6144.0 546.133354 248.661056 285.767458
11 6656.0 532.479975 256.000009 285.767438
12 7168.0 507.469040 260.260201 286.242939
13 7680.0 479.999983 262.564106 279.272719
14 8192.0 462.607053 267.130429 284.526763
15 8704.0 417.791980 267.815384 284.599455
16 9216.0 431.157889 272.394084 288.751954
17 9728.0 438.857162 280.278512 290.027323
18 10240.0 449.287041 286.433562 290.153487
19 10752.0 427.231788 246.935876 290.594591
20 11264.0 426.397479 245.760001 286.676558
21 11776.0 423.089806 249.888595 288.981596
22 12288.0 419.504980 254.673582 294.323369
23 12800.0 414.016170 253.674644 288.180121
24 13312.0 411.181478 252.959629 289.916513
25 13824.0 404.112047 257.190689 292.056329
26 14336.0 393.215988 254.485198 286.719986
27 14848.0 385.245405 257.665934 289.012175
28 15360.0 373.495460 257.970599 286.211174
29 15872.0 371.637071 261.626369 289.899545
|
.. code-block:: default
import torch
import triton
import triton.language as tl
try:
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
# should not be added to extras_require in setup.py.
import apex
HAS_APEX = True
except ModuleNotFoundError:
HAS_APEX = False
# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
BLOCK_SIZE: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
Y += row * stride
# load data and cast to float32
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
# compute mean
mean = tl.sum(x, axis=0) / N
# compute std
xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean * rstd
# write-back mean/rstd
tl.store(M + row, mean)
tl.store(V + row, rstd)
# multiply by weight and add bias
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
y = xhat * w + b
# write-back
tl.store(Y + cols, y, mask=mask)
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
# these buffers stay in the L2, which allow this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# first store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
x.grad, weight.grad, bias.grad = None, None, None
# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
triton.testing.assert_almost_equal(y_tri, y_ref)
triton.testing.assert_almost_equal(dx_tri, dx_ref)
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[512 * i for i in range(2, 32)],
line_arg='provider',
line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
if provider == 'torch':
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True)
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 2 minutes 12.003 seconds)
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,152 @@
:orphan:
.. _sphx_glr_getting-started_tutorials:
Tutorials
==================
Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.
To install the dependencies for the tutorials:
.. code-block:: bash
cd triton
pip install -e './python[tutorials]'
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model of Triton - The triton.jit decorator, which is used to define Tri...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_01-vector-add_thumb.png
:alt: Vector Addition
:ref:`sphx_glr_getting-started_tutorials_01-vector-add.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/01-vector-add
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - Reduction operators in Triton...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_02-fused-softmax_thumb.png
:alt: Fused Softmax
:ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/02-fused-softmax
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="- Block-level matrix multiplications - Multi-dimensional pointer arithmetic - Program re-orderi...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_03-matrix-multiplication_thumb.png
:alt: Matrix Multiplication
:ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/03-matrix-multiplication
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, you will write a memory-efficient implementation of dropout whose state will ...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_04-low-memory-dropout_thumb.png
:alt: Low-Memory Dropout
:ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/04-low-memory-dropout
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Layer Normalization">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_05-layer-norm_thumb.png
:alt: Layer Normalization
:ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/05-layer-norm
.. raw:: html
<div class="sphx-glr-clear"></div>
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-gallery
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download all examples in Python source code: tutorials_python.zip </getting-started/tutorials/tutorials_python.zip>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download all examples in Jupyter notebooks: tutorials_jupyter.zip </getting-started/tutorials/tutorials_jupyter.zip>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,20 @@
:orphan:
.. _sphx_glr_getting-started_tutorials_sg_execution_times:
Computation times
=================
**13:21.645** total execution time for **getting-started_tutorials** files:
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 06:04.559 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:22.066 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 02:12.003 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:42.537 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.480 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="109" height="20"><linearGradient id="b" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="a"><rect width="109" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#a)"><path fill="#555" d="M0 0h64v20H0z"/><path fill="#579aca" d="M64 0h45v20H64z"/><path fill="url(#b)" d="M0 0h109v20H0z"/></g><g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="110"><image x="5" y="3" width="14" height="14" xlink:href=""/> <text x="415" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="370">launch</text><text x="415" y="140" transform="scale(.1)" textLength="370">launch</text><text x="855" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="350">binder</text><text x="855" y="140" transform="scale(.1)" textLength="350">binder</text></g> </svg>

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

View File

@@ -0,0 +1,6 @@
/* CSS for binder integration */
div.binder-badge {
margin: 1em auto;
vertical-align: middle;
}

View File

@@ -0,0 +1,36 @@
/* Pandas dataframe css */
/* Taken from: https://github.com/spatialaudio/nbsphinx/blob/fb3ba670fc1ba5f54d4c487573dbc1b4ecf7e9ff/src/nbsphinx.py#L587-L619 */
table.dataframe {
border: none !important;
border-collapse: collapse;
border-spacing: 0;
border-color: transparent;
color: black;
font-size: 12px;
table-layout: fixed;
}
table.dataframe thead {
border-bottom: 1px solid black;
vertical-align: bottom;
}
table.dataframe tr,
table.dataframe th,
table.dataframe td {
text-align: right;
vertical-align: middle;
padding: 0.5em 0.5em;
line-height: normal;
white-space: normal;
max-width: none;
border: none;
}
table.dataframe th {
font-weight: bold;
}
table.dataframe tbody tr:nth-child(odd) {
background: #f5f5f5;
}
table.dataframe tbody tr:hover {
background: rgba(66, 165, 245, 0.2);
}

View File

@@ -0,0 +1,209 @@
/* Adapted from notebook/static/style/style.min.css */
.rendered_html {
color: #000;
/* any extras will just be numbers: */
}
.rendered_html em {
font-style: italic;
}
.rendered_html strong {
font-weight: bold;
}
.rendered_html u {
text-decoration: underline;
}
.rendered_html :link {
text-decoration: underline;
}
.rendered_html :visited {
text-decoration: underline;
}
.rendered_html h1 {
font-size: 185.7%;
margin: 1.08em 0 0 0;
font-weight: bold;
line-height: 1.0;
}
.rendered_html h2 {
font-size: 157.1%;
margin: 1.27em 0 0 0;
font-weight: bold;
line-height: 1.0;
}
.rendered_html h3 {
font-size: 128.6%;
margin: 1.55em 0 0 0;
font-weight: bold;
line-height: 1.0;
}
.rendered_html h4 {
font-size: 100%;
margin: 2em 0 0 0;
font-weight: bold;
line-height: 1.0;
}
.rendered_html h5 {
font-size: 100%;
margin: 2em 0 0 0;
font-weight: bold;
line-height: 1.0;
font-style: italic;
}
.rendered_html h6 {
font-size: 100%;
margin: 2em 0 0 0;
font-weight: bold;
line-height: 1.0;
font-style: italic;
}
.rendered_html h1:first-child {
margin-top: 0.538em;
}
.rendered_html h2:first-child {
margin-top: 0.636em;
}
.rendered_html h3:first-child {
margin-top: 0.777em;
}
.rendered_html h4:first-child {
margin-top: 1em;
}
.rendered_html h5:first-child {
margin-top: 1em;
}
.rendered_html h6:first-child {
margin-top: 1em;
}
.rendered_html ul:not(.list-inline),
.rendered_html ol:not(.list-inline) {
padding-left: 2em;
}
.rendered_html ul {
list-style: disc;
}
.rendered_html ul ul {
list-style: square;
margin-top: 0;
}
.rendered_html ul ul ul {
list-style: circle;
}
.rendered_html ol {
list-style: decimal;
}
.rendered_html ol ol {
list-style: upper-alpha;
margin-top: 0;
}
.rendered_html ol ol ol {
list-style: lower-alpha;
}
.rendered_html ol ol ol ol {
list-style: lower-roman;
}
.rendered_html ol ol ol ol ol {
list-style: decimal;
}
.rendered_html * + ul {
margin-top: 1em;
}
.rendered_html * + ol {
margin-top: 1em;
}
.rendered_html hr {
color: black;
background-color: black;
}
.rendered_html pre {
margin: 1em 2em;
padding: 0px;
background-color: #fff;
}
.rendered_html code {
background-color: #eff0f1;
}
.rendered_html p code {
padding: 1px 5px;
}
.rendered_html pre code {
background-color: #fff;
}
.rendered_html pre,
.rendered_html code {
border: 0;
color: #000;
font-size: 100%;
}
.rendered_html blockquote {
margin: 1em 2em;
}
.rendered_html table {
margin-left: auto;
margin-right: auto;
border: none;
border-collapse: collapse;
border-spacing: 0;
color: black;
font-size: 12px;
table-layout: fixed;
}
.rendered_html thead {
border-bottom: 1px solid black;
vertical-align: bottom;
}
.rendered_html tr,
.rendered_html th,
.rendered_html td {
text-align: right;
vertical-align: middle;
padding: 0.5em 0.5em;
line-height: normal;
white-space: normal;
max-width: none;
border: none;
}
.rendered_html th {
font-weight: bold;
}
.rendered_html tbody tr:nth-child(odd) {
background: #f5f5f5;
}
.rendered_html tbody tr:hover {
background: rgba(66, 165, 245, 0.2);
}
.rendered_html * + table {
margin-top: 1em;
}
.rendered_html p {
text-align: left;
}
.rendered_html * + p {
margin-top: 1em;
}
.rendered_html img {
display: block;
margin-left: auto;
margin-right: auto;
}
.rendered_html * + img {
margin-top: 1em;
}
.rendered_html img,
.rendered_html svg {
max-width: 100%;
height: auto;
}
.rendered_html img.unconfined,
.rendered_html svg.unconfined {
max-width: none;
}
.rendered_html .alert {
margin-bottom: initial;
}
.rendered_html * + .alert {
margin-top: 1em;
}
[dir="rtl"] .rendered_html p {
text-align: right;
}

204
master/_static/gallery.css Normal file
View File

@@ -0,0 +1,204 @@
/*
Sphinx-Gallery has compatible CSS to fix default sphinx themes
Tested for Sphinx 1.3.1 for all themes: default, alabaster, sphinxdoc,
scrolls, agogo, traditional, nature, haiku, pyramid
Tested for Read the Docs theme 0.1.7 */
.sphx-glr-thumbcontainer {
background: #fff;
border: solid #fff 1px;
-moz-border-radius: 5px;
-webkit-border-radius: 5px;
border-radius: 5px;
box-shadow: none;
float: left;
margin: 5px;
min-height: 230px;
padding-top: 5px;
position: relative;
}
.sphx-glr-thumbcontainer:hover {
border: solid #b4ddfc 1px;
box-shadow: 0 0 15px rgba(142, 176, 202, 0.5);
}
.sphx-glr-thumbcontainer a.internal {
bottom: 0;
display: block;
left: 0;
padding: 150px 10px 0;
position: absolute;
right: 0;
top: 0;
}
/* Next one is to avoid Sphinx traditional theme to cover all the
thumbnail with its default link Background color */
.sphx-glr-thumbcontainer a.internal:hover {
background-color: transparent;
}
.sphx-glr-thumbcontainer p {
margin: 0 0 .1em 0;
}
.sphx-glr-thumbcontainer .figure {
margin: 10px;
width: 160px;
}
.sphx-glr-thumbcontainer img {
display: inline;
max-height: 112px;
max-width: 160px;
}
.sphx-glr-thumbcontainer[tooltip]:hover:after {
background: rgba(0, 0, 0, 0.8);
-webkit-border-radius: 5px;
-moz-border-radius: 5px;
border-radius: 5px;
color: #fff;
content: attr(tooltip);
left: 95%;
padding: 5px 15px;
position: absolute;
z-index: 98;
width: 220px;
bottom: 52%;
}
.sphx-glr-thumbcontainer[tooltip]:hover:before {
border: solid;
border-color: #333 transparent;
border-width: 18px 0 0 20px;
bottom: 58%;
content: '';
left: 85%;
position: absolute;
z-index: 99;
}
.sphx-glr-script-out {
color: #888;
margin: 0;
}
p.sphx-glr-script-out {
padding-top: 0.7em;
}
.sphx-glr-script-out .highlight {
background-color: transparent;
margin-left: 2.5em;
margin-top: -2.1em;
}
.sphx-glr-script-out .highlight pre {
background-color: #fafae2;
border: 0;
max-height: 30em;
overflow: auto;
padding-left: 1ex;
margin: 0px;
word-break: break-word;
}
.sphx-glr-script-out + p {
margin-top: 1.8em;
}
blockquote.sphx-glr-script-out {
margin-left: 0pt;
}
.sphx-glr-script-out.highlight-pytb .highlight pre {
color: #000;
background-color: #ffe4e4;
border: 1px solid #f66;
margin-top: 10px;
padding: 7px;
}
div.sphx-glr-footer {
text-align: center;
}
div.sphx-glr-download {
margin: 1em auto;
vertical-align: middle;
}
div.sphx-glr-download a {
background-color: #ffc;
background-image: linear-gradient(to bottom, #FFC, #d5d57e);
border-radius: 4px;
border: 1px solid #c2c22d;
color: #000;
display: inline-block;
font-weight: bold;
padding: 1ex;
text-align: center;
}
div.sphx-glr-download code.download {
display: inline-block;
white-space: normal;
word-break: normal;
overflow-wrap: break-word;
/* border and background are given by the enclosing 'a' */
border: none;
background: none;
}
div.sphx-glr-download a:hover {
box-shadow: inset 0 1px 0 rgba(255,255,255,.1), 0 1px 5px rgba(0,0,0,.25);
text-decoration: none;
background-image: none;
background-color: #d5d57e;
}
.sphx-glr-example-title:target::before {
display: block;
content: "";
margin-top: -50px;
height: 50px;
visibility: hidden;
}
ul.sphx-glr-horizontal {
list-style: none;
padding: 0;
}
ul.sphx-glr-horizontal li {
display: inline;
}
ul.sphx-glr-horizontal img {
height: auto !important;
}
.sphx-glr-single-img {
margin: auto;
display: block;
max-width: 100%;
}
.sphx-glr-multi-img {
max-width: 42%;
height: auto;
}
div.sphx-glr-animation {
margin: auto;
display: block;
max-width: 100%;
}
div.sphx-glr-animation .animation{
display: block;
}
p.sphx-glr-signature a.reference.external {
-moz-border-radius: 5px;
-webkit-border-radius: 5px;
border-radius: 5px;
padding: 3px;
font-size: 75%;
text-align: right;
margin-left: auto;
display: table;
}
.sphx-glr-clear{
clear: both;
}
a.sphx-glr-backref-instance {
text-decoration: none;
}

BIN
master/_static/no_image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

View File

@@ -15,6 +15,10 @@
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="_static/css/custom.css" type="text/css" />
@@ -89,6 +93,7 @@
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>

View File

@@ -15,6 +15,10 @@
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/custom.css" type="text/css" />
@@ -42,7 +46,7 @@
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="triton" href="../python-api/triton.html" />
<link rel="next" title="Tutorials" href="tutorials/index.html" />
<link rel="prev" title="Welcome to Tritons documentation!" href="../index.html" />
</head>
@@ -98,6 +102,7 @@
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="tutorials/index.html">Tutorials</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
@@ -222,7 +227,7 @@ python -m run --with-plots --result-dir /tmp/triton-bench
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="tutorials/index.html" class="btn btn-neutral float-right" title="Tutorials" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="../index.html" class="btn btn-neutral float-left" title="Welcome to Tritons documentation!" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>

View File

@@ -0,0 +1,421 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vector Addition &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
<link rel="prev" title="Tutorials" href="index.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Vector Addition</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/01-vector-add.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="vector-addition">
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a simple vector addition using Triton and learn about:</p>
<ul class="simple">
<li><p>The basic programming model of Triton</p></li>
<li><p>The <cite>triton.jit</cite> decorator, which is used to define Triton kernels.</p></li>
<li><p>The best practices for validating and benchmarking your custom ops against native reference implementations</p></li>
</ul>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">add_kernel</span><span class="p">(</span>
<span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to first input vector</span>
<span class="n">y_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to second input vector</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to output vector</span>
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="c1"># Number of elements each program should process</span>
<span class="c1"># NOTE: `constexpr` so it can be used as a shape value</span>
<span class="p">):</span>
<span class="c1"># There are multiple &#39;program&#39;s processing different data. We identify which program</span>
<span class="c1"># we are here</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># We use a 1D launch grid so axis is 0</span>
<span class="c1"># This program will process inputs that are offset from the initial data.</span>
<span class="c1"># for instance, if you had a vector of length 256 and block_size of 64, the programs</span>
<span class="c1"># would each access the elements [0:64, 64:128, 128:192, 192:256].</span>
<span class="c1"># Note that offsets is a list of pointers</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="c1"># Create a mask to guard memory operations against out-of-bounds accesses</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
<span class="c1"># Load x and y from DRAM, masking out any extra elements in case the input is not a</span>
<span class="c1"># multiple of the block size</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="c1"># Write x + y back to DRAM</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
</pre></div>
</div>
<p>Lets also declare a helper function to (1) allocate the <cite>z</cite> tensor
and (2) enqueue the above kernel with appropriate grid/block sizes.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="c1"># We need to preallocate the output</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">y</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">output</span><span class="o">.</span><span class="n">is_cuda</span>
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
<span class="c1"># The SPMD launch grid denotes the number of kernel instances that run in parallel.</span>
<span class="c1"># It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -&gt; Tuple[int]</span>
<span class="c1"># In this case, we use a 1D grid where the size is the number of blocks</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]),)</span>
<span class="c1"># NOTE:</span>
<span class="c1"># - each torch.tensor object is implicitly converted into a pointer to its first element.</span>
<span class="c1"># - `triton.jit`&#39;ed functions can be index with a launch grid to obtain a callable GPU kernel</span>
<span class="c1"># - don&#39;t forget to pass meta-parameters as keywords arguments</span>
<span class="n">add_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="c1"># We return a handle to z but, since `torch.cuda.synchronize()` hasn&#39;t been called, the kernel is still</span>
<span class="c1"># running asynchronously at this point.</span>
<span class="k">return</span> <span class="n">output</span>
</pre></div>
</div>
<p>We can now use the above function to compute the element-wise sum of two <cite>torch.tensor</cite> objects and test its correctness:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">output_torch</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">output_triton</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_torch</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output_triton</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 0.0
</pre></div>
</div>
<p>Seems like were good to go!</p>
</div>
<div class="section" id="benchmark">
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<p>We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
for different problem sizes.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;size&#39;</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span>
<span class="mi">2</span> <span class="o">**</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">],</span> <span class="c1"># different possible values for `x_name`</span>
<span class="n">x_log</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># x axis is logarithmic</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="s1">&#39;torch&#39;</span><span class="p">],</span> <span class="c1"># possible values for `line_arg`</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;Triton&#39;</span><span class="p">,</span> <span class="s1">&#39;Torch&#39;</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">)],</span> <span class="c1"># line styles</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s1">&#39;GB/s&#39;</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s1">&#39;vector-add-performance&#39;</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
<span class="n">args</span><span class="o">=</span><span class="p">{},</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">12</span> <span class="o">*</span> <span class="n">size</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
</pre></div>
</div>
<p>We can now run the decorated function above. Pass <cite>print_data=True</cite> to see the performance number, <cite>show_plots=True</cite> to plot them, and/or
<a href="#id1"><span class="problematic" id="id2">`</span></a>save_path=/path/to/results/ to save them to disk along with raw CSV data</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<img alt="01 vector add" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>vector-add-performance:
size Triton Torch
0 4096.0 9.600000 9.600000
1 8192.0 19.200000 19.200000
2 16384.0 38.400001 38.400001
3 32768.0 63.999998 63.999998
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 384.000001 384.000001
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
10 4194304.0 780.190482 780.190482
11 8388608.0 812.429770 812.429770
12 16777216.0 833.084721 833.084721
13 33554432.0 842.004273 842.004273
14 67108864.0 847.448255 848.362445
15 134217728.0 849.737435 850.656574
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 42.537 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">01-vector-add.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="index.html" class="btn btn-neutral float-left" title="Tutorials" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="01-vector-add.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -0,0 +1,473 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Fused Softmax &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Matrix Multiplication" href="03-matrix-multiplication.html" />
<link rel="prev" title="Vector Addition" href="01-vector-add.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Fused Softmax</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Fused Softmax</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/02-fused-softmax.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="fused-softmax">
<span id="sphx-glr-getting-started-tutorials-02-fused-softmax-py"></span><h1>Fused Softmax<a class="headerlink" href="#fused-softmax" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorchs native op for a particular class of matrices: those whose rows can fit in
the GPUs SRAM.
You will learn about:</p>
<ul class="simple">
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
<li><p>Reduction operators in Triton.</p></li>
</ul>
<div class="section" id="motivations">
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline"></a></h2>
<p>Custom GPU kernels for elementwise additions are educationally valuable but wont get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Compute row-wise softmax of X using native pytorch</span>
<span class="sd"> We subtract the maximum element in order to avoid overflows. Softmax is invariant to</span>
<span class="sd"> this shift.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># read MN elements ; write M elements</span>
<span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># read MN + M elements ; write MN elements</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># read MN elements ; write MN elements</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
<span class="c1"># read MN elements ; write M elements</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">numerator</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># read MN + M elements ; write MN elements</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># in total: read 5MN + 2M elements ; wrote 3MN + 2M elements</span>
<span class="k">return</span> <span class="n">ret</span>
</pre></div>
</div>
<p>When implemented naively in PyTorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span>
requires reading <span class="math notranslate nohighlight">\(5MN + 2M\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
This is obviously wasteful; wed prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could
expect a theoretical speed-up of ~4x (i.e., <span class="math notranslate nohighlight">\((8MN + 4M) / 2MN\)</span>).
The <cite>torch.jit.script</cite> flags aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.</p>
</div>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally “pad” each row and guard the
memory operations properly if we want to handle any possible input shapes:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">softmax_kernel</span><span class="p">(</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="n">input_ptr</span><span class="p">,</span> <span class="n">input_row_stride</span><span class="p">,</span> <span class="n">output_row_stride</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span>
<span class="p">):</span>
<span class="c1"># The rows of the softmax are independent, so we parallelize across those</span>
<span class="n">row_idx</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># The stride represents how much we need to increase the pointer to advance 1 row</span>
<span class="n">row_start_ptr</span> <span class="o">=</span> <span class="n">input_ptr</span> <span class="o">+</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">input_row_stride</span>
<span class="c1"># The block size is the next power of two greater than n_cols, so we can fit each</span>
<span class="c1"># row in a single block</span>
<span class="n">col_offsets</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="n">input_ptrs</span> <span class="o">=</span> <span class="n">row_start_ptr</span> <span class="o">+</span> <span class="n">col_offsets</span>
<span class="c1"># Load the row into SRAM, using a mask since BLOCK_SIZE may be &gt; than n_cols</span>
<span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">input_ptrs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">col_offsets</span> <span class="o">&lt;</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s1">&#39;inf&#39;</span><span class="p">))</span>
<span class="c1"># Substract maximum for numerical stability</span>
<span class="n">row_minus_max</span> <span class="o">=</span> <span class="n">row</span> <span class="o">-</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">row</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">row_minus_max</span><span class="p">)</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">softmax_output</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span>
<span class="c1"># Write back output to DRAM</span>
<span class="n">output_row_start_ptr</span> <span class="o">=</span> <span class="n">output_ptr</span> <span class="o">+</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">output_row_stride</span>
<span class="n">output_ptrs</span> <span class="o">=</span> <span class="n">output_row_start_ptr</span> <span class="o">+</span> <span class="n">col_offsets</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptrs</span><span class="p">,</span> <span class="n">softmax_output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">col_offsets</span> <span class="o">&lt;</span> <span class="n">n_cols</span><span class="p">)</span>
</pre></div>
</div>
<p>We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># The block size is the smallest power of two greater than the number of columns in `x`</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">next_power_of_2</span><span class="p">(</span><span class="n">n_cols</span><span class="p">)</span>
<span class="c1"># Another trick we can use is to ask the compiler to use more threads per row by</span>
<span class="c1"># increasing the number of warps (`num_warps`) over which each row is distributed.</span>
<span class="c1"># You will see in the next tutorial how to auto-tune this value in a more natural</span>
<span class="c1"># way so you don&#39;t have to come up with manual heuristics yourself.</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
<span class="k">if</span> <span class="n">BLOCK_SIZE</span> <span class="o">&gt;=</span> <span class="mi">2048</span><span class="p">:</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
<span class="k">if</span> <span class="n">BLOCK_SIZE</span> <span class="o">&gt;=</span> <span class="mi">4096</span><span class="p">:</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
<span class="c1"># Allocate output</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o</span>
<span class="c1"># f the input matrix</span>
<span class="n">softmax_kernel</span><span class="p">[(</span><span class="n">n_rows</span><span class="p">,)](</span>
<span class="n">y</span><span class="p">,</span>
<span class="n">x</span><span class="p">,</span>
<span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">n_cols</span><span class="p">,</span>
<span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">y</span>
</pre></div>
</div>
</div>
<div class="section" id="unit-test">
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline"></a></h2>
<p>We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y_triton</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y_torch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">),</span> <span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">)</span>
</pre></div>
</div>
<p>As expected, the results are identical.</p>
</div>
<div class="section" id="benchmark">
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<p>Here we will benchmark our operation as a function of the number of columns in the input matrix assuming 4096 rows.
We will then compare its performance against (1) <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> and (2) the <code class="code docutils literal notranslate"><span class="pre">naive_softmax</span></code> defined above.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;N&#39;</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span>
<span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="p">],</span> <span class="c1"># different possible values for `x_name`</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span>
<span class="s1">&#39;triton&#39;</span><span class="p">,</span>
<span class="s1">&#39;torch-native&#39;</span><span class="p">,</span>
<span class="s1">&#39;torch-jit&#39;</span><span class="p">,</span>
<span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span>
<span class="s2">&quot;Triton&quot;</span><span class="p">,</span>
<span class="s2">&quot;Torch (native)&quot;</span><span class="p">,</span>
<span class="s2">&quot;Torch (jit)&quot;</span><span class="p">,</span>
<span class="p">],</span> <span class="c1"># label name for the lines</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;--&#39;</span><span class="p">)],</span> <span class="c1"># line styles</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s2">&quot;GB/s&quot;</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s2">&quot;softmax-performance&quot;</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;M&#39;</span><span class="p">:</span> <span class="mi">4096</span><span class="p">},</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch-native&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch-jit&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">nelement</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="mf">1e-9</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<img alt="02 fused softmax" class="sphx-glr-single-img" src="../../_images/sphx_glr_02-fused-softmax_001.png" />
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 512.000001 546.133347 190.511628
1 384.0 614.400016 585.142862 153.600004
2 512.0 655.360017 606.814814 154.566038
3 640.0 706.206879 640.000002 160.000000
4 768.0 722.823517 664.216187 162.754967
.. ... ... ... ...
93 12160.0 814.058574 406.179533 198.733401
94 12288.0 814.111783 415.661740 198.995960
95 12416.0 814.163950 412.149375 198.556711
96 12544.0 814.214963 412.971190 198.815254
97 12672.0 814.265046 412.097543 198.873965
[98 rows x 4 columns]
</pre></div>
</div>
<p>In the above plot, we can see that:</p>
<blockquote>
<div><ul class="simple">
<li><p>Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.</p></li>
<li><p>Triton is noticeably faster than <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> in addition to being <strong>easier to read, understand and maintain</strong>.
Note however that the PyTorch <cite>softmax</cite> operation is more general and will works on tensors of any shape.</p></li>
</ul>
</div></blockquote>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 3 minutes 22.066 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/034d953b6214fedce6ea03803c712b89/02-fused-softmax.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">02-fused-softmax.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="03-matrix-multiplication.html" class="btn btn-neutral float-right" title="Matrix Multiplication" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="02-fused-softmax.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -0,0 +1,681 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Matrix Multiplication &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
<link rel="prev" title="Fused Softmax" href="02-fused-softmax.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Matrix Multiplication</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#pointer-arithmetics">Pointer Arithmetics</a></li>
<li class="toctree-l4"><a class="reference internal" href="#l2-cache-optimizations">L2 Cache Optimizations</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#final-result">Final Result</a></li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#square-matrix-performance">Square Matrix Performance</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Matrix Multiplication</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="matrix-multiplication">
<span id="sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"></span><h1>Matrix Multiplication<a class="headerlink" href="#matrix-multiplication" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication
kernel that achieves performance on par with cuBLAS.
You will specifically learn about:</p>
<ul class="simple">
<li><p>Block-level matrix multiplications</p></li>
<li><p>Multi-dimensional pointer arithmetic</p></li>
<li><p>Program re-ordering for improved L2 cache hit rate</p></li>
<li><p>Automatic performance tuning</p></li>
</ul>
<div class="section" id="motivations">
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline"></a></h2>
<p>Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, hence their implementation is generally done by
hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot be easily customized
to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
In this tutorial, you will learn how to implement efficient matrix multiplications by
yourself with Triton, in a way that is easy to customize and extend.</p>
<p>Roughly speaking, the kernel that we will write will implement the following blocked
algorithm to multiply a (M, K) by a (K, N) matrix:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
<span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
<span class="c1"># do in parallel</span>
<span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">):</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">):</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">]</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span>
<span class="n">acc</span> <span class="o">+=</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="n">C</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">acc</span><span class="p">;</span>
</pre></div>
</div>
</div></blockquote>
<p>where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.</p>
</div>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<p>The above algorithm is, actually, fairly straightforward to implement in Triton.
The main difficulty comes from the computation of the memory locations at which blocks
of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need
multi-dimensional pointer arithmetics.</p>
<div class="section" id="pointer-arithmetics">
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline"></a></h3>
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given b
y <code class="code docutils literal notranslate"><span class="pre">&amp;X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_xi</span> <span class="pre">+</span> <span class="pre">j*stride_xj</span></code>.
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_SIZE_M,</span> <span class="pre">k:k+BLOCK_SIZE_K]</span></code> and
<code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_SIZE_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_SIZE_N]</span></code> can be defined in pseudo-code as:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">]</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
<span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
</pre></div>
</div>
</div></blockquote>
<p>Which means that pointers for blocks of A and B can be initialized (i.e., <code class="code docutils literal notranslate"><span class="pre">k=0</span></code>) in Triton as:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">offs_am</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">offs_bn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span> <span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_ak</span><span class="p">)</span>
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span> <span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_bn</span><span class="p">)</span>
</pre></div>
</div>
</div></blockquote>
<p>And then updated in the inner loop as follows:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pa</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">;</span>
<span class="n">pb</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span><span class="p">;</span>
</pre></div>
</div>
</div></blockquote>
</div>
<div class="section" id="l2-cache-optimizations">
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline"></a></h3>
<p>As mentioned above, each program instance computes a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_SIZE_M,</span> <span class="pre">BLOCK_SIZE_N]</span></code>
block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
It is important to remember that the order in which these blocks are computed does
matter, since it affects the L2 cache hit rate of our program. and unfortunately, a
a simple row-major ordering</p>
<blockquote>
<div><div class="highlight-Python notranslate"><div class="highlight"><pre><span></span><span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_M</span><span class="p">;</span>
<span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_N</span><span class="p">;</span>
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">/</span> <span class="n">grid_n</span><span class="p">;</span>
<span class="n">pid_n</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">%</span> <span class="n">grid_n</span><span class="p">;</span>
</pre></div>
</div>
</div></blockquote>
<p>is just not going to cut it.</p>
<p>One possible solution is to launch blocks in an order that promotes data reuse.
This can be done by super-grouping blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_M</span></code> rows before
switching to the next column:</p>
<blockquote>
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># program ID</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># number of program ids along the M axis</span>
<span class="n">num_pid_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="c1"># number of programs ids along the N axis</span>
<span class="n">num_pid_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="c1"># number of programs in group</span>
<span class="n">num_pid_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">num_pid_n</span>
<span class="c1"># id of the group this program is in</span>
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">num_pid_in_group</span>
<span class="c1"># row-id of the first program in the group</span>
<span class="n">first_pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
<span class="c1"># if `num_pid_m` isn&#39;t divisible by `GROUP_SIZE_M`, the last group is smaller</span>
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_pid_m</span> <span class="o">-</span> <span class="n">first_pid_m</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
<span class="c1"># *within groups*, programs are ordered in a column-major order</span>
<span class="c1"># row-id of the program in the *launch grid*</span>
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">first_pid_m</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
<span class="c1"># col-id of the program in the *launch grid*</span>
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">num_pid_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
</pre></div>
</div>
</div></blockquote>
<p>For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
we can see that if we compute the output in row-major ordering, we need to load 90
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
ordering, we only need to load 54 blocks.</p>
<blockquote>
<div><img alt="../../_images/grouped_vs_row_major_ordering.png" src="../../_images/grouped_vs_row_major_ordering.png" />
</div></blockquote>
<p>In practice, this can improve the performance of our matrix multiplication kernel by
more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</p>
</div>
</div>
<div class="section" id="final-result">
<h2>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="c1"># %</span>
<span class="c1"># :code:`triton.jit`&#39;ed functions can be auto-tuned by using the `triton.autotune`</span>
<span class="c1"># decorator, which consumes:</span>
<span class="c1"># - A list of :code:`triton.Config` objects that define different configurations of</span>
<span class="c1"># meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try</span>
<span class="c1"># - An autotuning *key* whose change in values will trigger evaluation of all the</span>
<span class="c1"># provided configs</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span>
<span class="n">configs</span><span class="o">=</span><span class="p">[</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="p">],</span>
<span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="s1">&#39;K&#39;</span><span class="p">],</span>
<span class="p">)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">matmul_kernel</span><span class="p">(</span>
<span class="c1"># Pointers to matrices</span>
<span class="n">a_ptr</span><span class="p">,</span> <span class="n">b_ptr</span><span class="p">,</span> <span class="n">c_ptr</span><span class="p">,</span>
<span class="c1"># Matrix dimensions</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span>
<span class="c1"># The stride variables represent how much to increase the ptr by when moving by 1</span>
<span class="c1"># element in a particular dimension. E.g. stride_am is how much to increase a_ptr</span>
<span class="c1"># by to get the element one row down (A has M rows)</span>
<span class="n">stride_am</span><span class="p">,</span> <span class="n">stride_ak</span><span class="p">,</span>
<span class="n">stride_bk</span><span class="p">,</span> <span class="n">stride_bn</span><span class="p">,</span>
<span class="n">stride_cm</span><span class="p">,</span> <span class="n">stride_cn</span><span class="p">,</span>
<span class="c1"># Meta-parameters</span>
<span class="n">BLOCK_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">ACTIVATION</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Kernel for computing the matmul C = A x B.</span>
<span class="sd"> A has shape (M, K), B has shape (K, N) and C has shape (M, N)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># -----------------------------------------------------------</span>
<span class="c1"># Map program ids `pid` to the block of C it should compute.</span>
<span class="c1"># This is done in a grouped ordering to promote L2 data reuse</span>
<span class="c1"># See above `L2 Cache Optimizations` section for details</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">num_pid_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">num_pid_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">num_pid_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">num_pid_n</span>
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">num_pid_in_group</span>
<span class="n">first_pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_pid_m</span> <span class="o">-</span> <span class="n">first_pid_m</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">first_pid_m</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">num_pid_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
<span class="c1"># ----------------------------------------------------------</span>
<span class="c1"># Create pointers for the first blocks of A and B.</span>
<span class="c1"># We will advance this pointer as we move in the K direction</span>
<span class="c1"># and accumulate</span>
<span class="c1"># a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers</span>
<span class="c1"># b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers</span>
<span class="c1"># see above `Pointer Arithmetics` section for details</span>
<span class="n">offs_am</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">offs_bn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">)</span>
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_bn</span><span class="p">)</span>
<span class="c1"># -----------------------------------------------------------</span>
<span class="c1"># Iterate to compute a block of the C matrix</span>
<span class="c1"># We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block</span>
<span class="c1"># of fp32 values for higher accuracy.</span>
<span class="c1"># `accumulator` will be converted back to fp16 after the loop</span>
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">):</span>
<span class="c1"># Note that for simplicity, we don&#39;t apply a mask here.</span>
<span class="c1"># This means that if K is not a multiple of BLOCK_SIZE_K,</span>
<span class="c1"># this will access out-of-bounds memory and produce an</span>
<span class="c1"># error or (worse!) incorrect results.</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">a_ptrs</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">b_ptrs</span><span class="p">)</span>
<span class="c1"># We accumulate along the K dimension</span>
<span class="n">accumulator</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="c1"># Advance the ptrs to the next K block</span>
<span class="n">a_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
<span class="n">b_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
<span class="c1"># you can fuse arbitrary activation functions here</span>
<span class="c1"># while the accumulator is still in FP32!</span>
<span class="k">if</span> <span class="n">ACTIVATION</span><span class="p">:</span>
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">ACTIVATION</span><span class="p">(</span><span class="n">accumulator</span><span class="p">)</span>
<span class="n">c</span> <span class="o">=</span> <span class="n">accumulator</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="c1"># -----------------------------------------------------------</span>
<span class="c1"># Write back the block of the output matrix C</span>
<span class="n">offs_cm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">offs_cn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">c_ptrs</span> <span class="o">=</span> <span class="n">c_ptr</span> <span class="o">+</span> <span class="n">stride_cm</span> <span class="o">*</span> <span class="n">offs_cm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">stride_cn</span> <span class="o">*</span> <span class="n">offs_cn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">c_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">offs_cm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">offs_cn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">c_ptrs</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">c_mask</span><span class="p">)</span>
<span class="c1"># we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mf">0.01</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
<p>We can now create a convenience wrapper function that only takes two input tensors
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="c1"># checks constraints</span>
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;incompatible dimensions&quot;</span>
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">&quot;matrix A must be contiguous&quot;</span>
<span class="k">assert</span> <span class="n">b</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(),</span> <span class="s2">&quot;matrix B must be contiguous&quot;</span>
<span class="n">M</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span>
<span class="n">K</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">K</span> <span class="o">%</span> <span class="mi">32</span> <span class="o">==</span> <span class="mi">0</span>
<span class="p">),</span> <span class="s2">&quot;We don&#39;t check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K&quot;</span>
<span class="c1"># allocates output</span>
<span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># 1D launch kernel where each block gets its own program.</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">META</span><span class="p">:</span> <span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">]),</span>
<span class="p">)</span>
<span class="n">matmul_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
<span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span>
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">ACTIVATION</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">c</span>
</pre></div>
</div>
</div>
<div class="section" id="unit-test">
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline"></a></h2>
<p>We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">triton_output</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">torch_output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;triton_output=</span><span class="si">{</span><span class="n">triton_output</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;torch_output=</span><span class="si">{</span><span class="n">torch_output</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">triton_output</span><span class="p">,</span> <span class="n">torch_output</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;✅ Triton and Torch match&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;❌ Triton and Torch differ&quot;</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>triton_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
...,
[ -5.7031, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
device=&#39;cuda:0&#39;, dtype=torch.float16)
torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
...,
[ -5.7070, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
device=&#39;cuda:0&#39;, dtype=torch.float16)
✅ Triton and Torch match
</pre></div>
</div>
</div>
<div class="section" id="benchmark">
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<div class="section" id="square-matrix-performance">
<h3>Square Matrix Performance<a class="headerlink" href="#square-matrix-performance" title="Permalink to this headline"></a></h3>
<p>We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="s1">&#39;K&#39;</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span>
<span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">33</span><span class="p">)</span>
<span class="p">],</span> <span class="c1"># different possible values for `x_name`</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
<span class="c1"># possible values for `line_arg``</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;cublas&#39;</span><span class="p">,</span> <span class="s1">&#39;cublas + relu&#39;</span><span class="p">,</span> <span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="s1">&#39;triton + relu&#39;</span><span class="p">],</span>
<span class="c1"># label name for the lines</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;cuBLAS&quot;</span><span class="p">,</span> <span class="s2">&quot;cuBLAS (+ torch.nn.LeakyReLU)&quot;</span><span class="p">,</span> <span class="s2">&quot;Triton&quot;</span><span class="p">,</span> <span class="s2">&quot;Triton (+ LeakyReLU)&quot;</span><span class="p">],</span>
<span class="c1"># line styles</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;--&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;--&#39;</span><span class="p">)],</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s2">&quot;TFLOPS&quot;</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s2">&quot;matmul-performance&quot;</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
<span class="n">args</span><span class="o">=</span><span class="p">{},</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;cublas&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;cublas + relu&#39;</span><span class="p">:</span>
<span class="n">torch_relu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span>
<span class="k">lambda</span><span class="p">:</span> <span class="n">torch_relu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton + relu&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span>
<span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">leaky_relu</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">perf</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
<span class="k">return</span> <span class="n">perf</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<img alt="03 matrix multiplication" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
1 384.0 7.372800 ... 8.507077 7.899428
2 512.0 14.563555 ... 16.384000 16.384000
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 34.028308 34.028308
5 896.0 39.025776 ... 39.025776 39.025776
6 1024.0 49.932191 ... 52.428801 52.428801
7 1152.0 45.242181 ... 46.656000 46.656000
8 1280.0 51.200001 ... 56.888887 56.888887
9 1408.0 64.138541 ... 67.305878 66.485074
10 1536.0 80.430545 ... 79.526831 78.643199
11 1664.0 62.929456 ... 62.492442 62.061463
12 1792.0 72.512412 ... 71.588687 72.047592
13 1920.0 69.467336 ... 70.172588 70.172588
14 2048.0 73.262953 ... 76.608294 76.608294
15 2176.0 83.500614 ... 85.998493 85.998493
16 2304.0 68.446623 ... 76.563695 76.319081
17 2432.0 71.305746 ... 74.521127 85.134737
18 2560.0 78.019048 ... 80.908642 81.108913
19 2688.0 82.823267 ... 89.888756 89.464755
20 2816.0 83.233226 ... 82.916747 82.759409
21 2944.0 81.967162 ... 82.509987 82.237674
22 3072.0 82.181572 ... 88.197981 87.651868
23 3200.0 84.099871 ... 93.567248 94.674553
24 3328.0 79.901550 ... 84.945483 84.496824
25 3456.0 81.600781 ... 91.511426 84.775569
26 3584.0 83.798127 ... 86.542919 95.148565
27 3712.0 81.615477 ... 88.640059 82.017526
28 3840.0 84.744825 ... 91.777595 85.399230
29 3968.0 92.024087 ... 84.183469 89.558851
30 4096.0 86.424811 ... 87.438257 91.992956
[31 rows x 5 columns]
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 6 minutes 4.559 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/b51b68bc1c6b1a5e509f67800b6235af/03-matrix-multiplication.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">03-matrix-multiplication.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="04-low-memory-dropout.html" class="btn btn-neutral float-right" title="Low-Memory Dropout" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="02-fused-softmax.html" class="btn btn-neutral float-left" title="Fused Softmax" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="03-matrix-multiplication.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -0,0 +1,453 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Low-Memory Dropout &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Layer Normalization" href="05-layer-norm.html" />
<link rel="prev" title="Matrix Multiplication" href="03-matrix-multiplication.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Low-Memory Dropout</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#baseline">Baseline</a></li>
<li class="toctree-l3"><a class="reference internal" href="#seeded-dropout">Seeded dropout</a></li>
<li class="toctree-l3"><a class="reference internal" href="#exercises">Exercises</a></li>
<li class="toctree-l3"><a class="reference internal" href="#references">References</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Low-Memory Dropout</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="low-memory-dropout">
<span id="sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"></span><h1>Low-Memory Dropout<a class="headerlink" href="#low-memory-dropout" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a memory-efficient implementation of dropout whose state
will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally composed of a bit mask tensor of the same shape as the input. You will learn about:</p>
<ul class="simple">
<li><p>The limitations of naive implementations of Dropout with PyTorch</p></li>
<li><p>Parallel pseudo-random number generation in Triton</p></li>
</ul>
<div class="section" id="baseline">
<h2>Baseline<a class="headerlink" href="#baseline" title="Permalink to this headline"></a></h2>
<p>The <em>dropout</em> operator was first introduced in <a class="reference internal" href="#srivastava2014" id="id1"><span>[SRIVASTAVA2014]</span></a> as a way to improve the performance
of deep neural networks in low-data regime (i.e. regularization).</p>
<p>It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
output has a probability <span class="math notranslate nohighlight">\(p\)</span> of being changed to zero and otherwise it is copied from the input.
This forces the network to perform well even when only <span class="math notranslate nohighlight">\(1 - p\)</span> scalars from the input are available.</p>
<p>At evaluation time we want to use the full power of the network so we set <span class="math notranslate nohighlight">\(p=0\)</span>. Naively this would
increase the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease
in the output softmax temperature). To prevent this we multiply the output by <span class="math notranslate nohighlight">\(\frac{1}{1 - p}\)</span>, which
keeps the norm consistent regardless of the dropout probability.</p>
<p>Lets first take a look at the baseline implementation.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tabulate</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_dropout</span><span class="p">(</span>
<span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># pointer to the input</span>
<span class="n">x_keep_ptr</span><span class="p">,</span> <span class="c1"># pointer to a mask of 0s and 1s</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># pointer to the output</span>
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># number of elements in the `x` tensor</span>
<span class="n">p</span><span class="p">,</span> <span class="c1"># probability that an element of `x` is changed to zero</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
<span class="c1"># Load data</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">x_keep</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_keep_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># The line below is the crucial part, described in the paragraph above!</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
<span class="c1"># Write-back output</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]),)</span>
<span class="n">_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="c1"># Input tensor</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1"># Dropout mask</span>
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span>
<span class="n">x_keep</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span> <span class="o">&gt;</span> <span class="n">p</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1">#</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="o">=</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">p</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
<span class="p">[</span><span class="s2">&quot;input&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;keep mask&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">x_keep</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
<span class="p">]))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
input 1.541 -0.293429 -2.17879 0.568431 -1.08452 -1.3986 0.403347 0.838026 -0.719258 -0.403344
keep mask 1 1 0 1 0 1 1 0 0 0
output 3.08199 -0.586858 0 1.13686 0 -2.79719 0.806694 0 0 0
--------- ------- --------- -------- -------- -------- -------- -------- -------- --------- ---------
</pre></div>
</div>
</div>
<div class="section" id="seeded-dropout">
<h2>Seeded dropout<a class="headerlink" href="#seeded-dropout" title="Permalink to this headline"></a></h2>
<p>Above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly
we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
very tricky when using recompute/checkpointing (e.g. see all the notes about <cite>preserve_rng_state</cite> in
<a class="reference external" href="https://pytorch.org/docs/1.9.0/checkpoint.html">https://pytorch.org/docs/1.9.0/checkpoint.html</a>). In this tutorial well describe an alternative implementation
that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
of persisting randomness across multiple invocations of the kernel.</p>
<p>Pseudorandom number generation in Triton is simple! In this tutorial we will use the
<code class="code docutils literal notranslate"><span class="pre">triton.language.rand</span></code> function which generates a block of uniformly distributed <code class="code docutils literal notranslate"><span class="pre">float32</span></code>
values in [0, 1), given a seed and a block of <code class="code docutils literal notranslate"><span class="pre">int32</span></code> offsets. But if you need it, Triton also provides
other <a class="reference internal" href="../../python-api/triton.language.html#random-number-generation"><span class="std std-ref">random number generation strategies</span></a>.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Tritons implementation of PRNG is based on the Philox algorithm (described on <a class="reference internal" href="#salmon2011" id="id2"><span>[SALMON2011]</span></a>).</p>
</div>
<p>Lets put it all together.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_seeded_dropout</span><span class="p">(</span>
<span class="n">x_ptr</span><span class="p">,</span>
<span class="n">output_ptr</span><span class="p">,</span>
<span class="n">n_elements</span><span class="p">,</span>
<span class="n">p</span><span class="p">,</span>
<span class="n">seed</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="c1"># compute memory offsets of elements handled by this instance</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="c1"># load data from x</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># randomly prune it</span>
<span class="n">random</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)</span>
<span class="n">x_keep</span> <span class="o">=</span> <span class="n">random</span> <span class="o">&gt;</span> <span class="n">p</span>
<span class="c1"># write-back</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x_keep</span><span class="p">,</span> <span class="n">x</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">()</span>
<span class="n">n_elements</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]),)</span>
<span class="n">_seeded_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1"># Compare this to the baseline - dropout mask is never instantiated!</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
<span class="n">output2</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">)</span>
<span class="n">output3</span> <span class="o">=</span> <span class="n">seeded_dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">tabulate</span><span class="o">.</span><span class="n">tabulate</span><span class="p">([</span>
<span class="p">[</span><span class="s2">&quot;input&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output (seed = 123)&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output (seed = 123)&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output2</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="p">[</span><span class="s2">&quot;output (seed = 512)&quot;</span><span class="p">]</span> <span class="o">+</span> <span class="n">output3</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
<span class="p">]))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
</pre></div>
</div>
<p>Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same!
If youd like explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the <cite>triton/language/random</cite> folder!</p>
</div>
<div class="section" id="exercises">
<h2>Exercises<a class="headerlink" href="#exercises" title="Permalink to this headline"></a></h2>
<ol class="arabic simple">
<li><p>Extend the kernel to operate over a matrix and use a vector of seeds - one per row.</p></li>
<li><p>Add support for striding.</p></li>
<li><p>(challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.</p></li>
</ol>
</div>
<div class="section" id="references">
<h2>References<a class="headerlink" href="#references" title="Permalink to this headline"></a></h2>
<dl class="citation">
<dt class="label" id="salmon2011"><span class="brackets"><a class="fn-backref" href="#id2">SALMON2011</a></span></dt>
<dd><p>John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011</p>
</dd>
<dt class="label" id="srivastava2014"><span class="brackets"><a class="fn-backref" href="#id1">SRIVASTAVA2014</a></span></dt>
<dd><p>Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014</p>
</dd>
</dl>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.480 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">04-low-memory-dropout.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/bc847dec325798bdc436c4ef5ac8b78a/04-low-memory-dropout.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">04-low-memory-dropout.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="05-layer-norm.html" class="btn btn-neutral float-right" title="Layer Normalization" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="03-matrix-multiplication.html" class="btn btn-neutral float-left" title="Matrix Multiplication" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="04-low-memory-dropout.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -0,0 +1,567 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Layer Normalization &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="triton" href="../../python-api/triton.html" />
<link rel="prev" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Layer Normalization</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/05-layer-norm.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="layer-normalization">
<span id="sphx-glr-getting-started-tutorials-05-layer-norm-py"></span><h1>Layer Normalization<a class="headerlink" href="#layer-normalization" title="Permalink to this headline"></a></h1>
<img alt="05 layer norm" class="sphx-glr-single-img" src="../../_images/sphx_glr_05-layer-norm_001.png" />
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>layer-norm-backward:
N Triton Torch Apex
0 1024.0 307.200008 98.303995 307.200008
1 1536.0 347.773587 134.050910 341.333333
2 2048.0 420.102553 161.684218 325.509933
3 2560.0 458.507457 181.238943 326.808501
4 3072.0 511.999982 191.999993 317.793096
5 3584.0 547.872604 208.271186 310.527060
6 4096.0 564.965515 220.412561 294.323343
7 4608.0 504.986315 232.825259 290.267724
8 5120.0 527.381977 242.845844 287.775181
9 5632.0 542.843364 243.545956 288.820505
10 6144.0 546.133354 248.661056 285.767458
11 6656.0 532.479975 256.000009 285.767438
12 7168.0 507.469040 260.260201 286.242939
13 7680.0 479.999983 262.564106 279.272719
14 8192.0 462.607053 267.130429 284.526763
15 8704.0 417.791980 267.815384 284.599455
16 9216.0 431.157889 272.394084 288.751954
17 9728.0 438.857162 280.278512 290.027323
18 10240.0 449.287041 286.433562 290.153487
19 10752.0 427.231788 246.935876 290.594591
20 11264.0 426.397479 245.760001 286.676558
21 11776.0 423.089806 249.888595 288.981596
22 12288.0 419.504980 254.673582 294.323369
23 12800.0 414.016170 253.674644 288.180121
24 13312.0 411.181478 252.959629 289.916513
25 13824.0 404.112047 257.190689 292.056329
26 14336.0 393.215988 254.485198 286.719986
27 14848.0 385.245405 257.665934 289.012175
28 15360.0 373.495460 257.970599 286.211174
29 15872.0 371.637071 261.626369 289.899545
</pre></div>
</div>
<div class="line-block">
<div class="line"><br /></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="k">try</span><span class="p">:</span>
<span class="c1"># This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it</span>
<span class="c1"># should not be added to extras_require in setup.py.</span>
<span class="kn">import</span> <span class="nn">apex</span>
<span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">except</span> <span class="ne">ModuleNotFoundError</span><span class="p">:</span>
<span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># Forward Pass</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_fwd_fused</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="c1"># position of elements processed by this program</span>
<span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span>
<span class="c1"># offset data pointers to start at the row of interest</span>
<span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">Y</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="c1"># load data and cast to float32</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># compute mean</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="c1"># compute std</span>
<span class="n">xmean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">var</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xmean</span> <span class="o">*</span> <span class="n">xmean</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">tl</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="n">xmean</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="c1"># write-back mean/rstd</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">mean</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
<span class="c1"># multiply by weight and add bias</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">xhat</span> <span class="o">*</span> <span class="n">w</span> <span class="o">+</span> <span class="n">b</span>
<span class="c1"># write-back</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># Backward pass (DX + partial DW + partial DB)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dx_fused</span><span class="p">(</span><span class="n">DX</span><span class="p">,</span> <span class="n">DY</span><span class="p">,</span> <span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">Lock</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="c1"># position of elements processed by this program</span>
<span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span>
<span class="c1"># offset data pointers to start at the row of interest</span>
<span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">DY</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">DX</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="c1"># offset locks and weight/bias gradient pointer</span>
<span class="c1"># each kernel instance accumulates partial sums for</span>
<span class="c1"># DW and DB into one of GROUP_SIZE_M independent buffers</span>
<span class="c1"># these buffers stay in the L2, which allow this kernel</span>
<span class="c1"># to be fast</span>
<span class="n">lock_id</span> <span class="o">=</span> <span class="n">row</span> <span class="o">%</span> <span class="n">GROUP_SIZE_M</span>
<span class="n">Lock</span> <span class="o">+=</span> <span class="n">lock_id</span>
<span class="n">Count</span> <span class="o">=</span> <span class="n">Lock</span> <span class="o">+</span> <span class="n">GROUP_SIZE_M</span>
<span class="n">DW</span> <span class="o">=</span> <span class="n">DW</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
<span class="n">DB</span> <span class="o">=</span> <span class="n">DB</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
<span class="c1"># load data to SRAM</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DY</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
<span class="c1"># compute dx</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="n">wdy</span> <span class="o">=</span> <span class="n">w</span> <span class="o">*</span> <span class="n">dy</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">xhat</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">wdy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">wdy</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">mean1</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">mean2</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">dx</span> <span class="o">=</span> <span class="p">(</span><span class="n">wdy</span> <span class="o">-</span> <span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">mean1</span> <span class="o">+</span> <span class="n">mean2</span><span class="p">))</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="c1"># write-back dx</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DX</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">dx</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># accumulate partial sums for dw/db</span>
<span class="n">partial_dw</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span> <span class="o">*</span> <span class="n">xhat</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">partial_db</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">while</span> <span class="n">tl</span><span class="o">.</span><span class="n">atomic_cas</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">pass</span>
<span class="n">count</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Count</span><span class="p">)</span>
<span class="c1"># first store doesn&#39;t accumulate</span>
<span class="k">if</span> <span class="n">count</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Count</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">partial_dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">partial_db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">partial_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">partial_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># release lock</span>
<span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># Backward pass (total DW + total DB)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dwdb</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">FINAL_DW</span><span class="p">,</span> <span class="n">FINAL_DB</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
<span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">sum_dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dw</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">sum_db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DW</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DB</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">LayerNorm</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
<span class="c1"># allocate output</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># reshape input data into 2D tensor</span>
<span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="c1"># Less than 64KB per feature: enqueue fused kernel</span>
<span class="n">MAX_FUSED_SIZE</span> <span class="o">=</span> <span class="mi">65536</span> <span class="o">//</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">MAX_FUSED_SIZE</span><span class="p">,</span> <span class="n">triton</span><span class="o">.</span><span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">))</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&gt;</span> <span class="n">BLOCK_SIZE</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;This layer norm doesn&#39;t support feature dim &gt;= 64KB.&quot;</span><span class="p">)</span>
<span class="c1"># heuristics for number of warps</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="n">BLOCK_SIZE</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">8</span><span class="p">)</span>
<span class="c1"># enqueue kernel</span>
<span class="n">_layer_norm_fwd_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">x_arg</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">,</span>
<span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="n">BLOCK_SIZE</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">BLOCK_SIZE</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span> <span class="o">=</span> <span class="n">num_warps</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="k">return</span> <span class="n">y</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dy</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
<span class="c1"># heuristics for amount of parallel reduction stream for DG/DB</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">64</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">8192</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">96</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">128</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">1024</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">256</span>
<span class="c1"># allocate output</span>
<span class="n">locks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">_dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">_db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">dy</span><span class="p">)</span>
<span class="c1"># enqueue kernel using forward pass heuristics</span>
<span class="c1"># also compute partial sums for DW and DB</span>
<span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
<span class="n">_layer_norm_bwd_dx_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">dx</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">locks</span><span class="p">,</span>
<span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">ctx</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="o">=</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span>
<span class="n">num_warps</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span><span class="p">)</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">])]</span>
<span class="c1"># accumulate partial sums in separate kernel</span>
<span class="n">_layer_norm_bwd_dwdb</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span>
<span class="k">return</span> <span class="n">dx</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="kc">None</span>
<span class="n">layer_norm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="o">.</span><span class="n">apply</span>
<span class="k">def</span> <span class="nf">test_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
<span class="c1"># create data</span>
<span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># forward pass</span>
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># backward pass (triton)</span>
<span class="n">y_tri</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">dx_tri</span><span class="p">,</span> <span class="n">dw_tri</span><span class="p">,</span> <span class="n">db_tri</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
<span class="n">x</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="c1"># backward pass (torch)</span>
<span class="n">y_ref</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">dx_ref</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">db_ref</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
<span class="c1"># compare</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dx_tri</span><span class="p">,</span> <span class="n">dx_ref</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">db_tri</span><span class="p">,</span> <span class="n">db_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dw_tri</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;N&#39;</span><span class="p">],</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">512</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">32</span><span class="p">)],</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="s1">&#39;torch&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;apex&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;Triton&#39;</span><span class="p">,</span> <span class="s1">&#39;Torch&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;Apex&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;orange&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">)],</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s1">&#39;GB/s&#39;</span><span class="p">,</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s1">&#39;layer-norm-backward&#39;</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;M&#39;</span><span class="p">:</span> <span class="mi">4096</span><span class="p">,</span> <span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">&#39;mode&#39;</span><span class="p">:</span> <span class="s1">&#39;backward&#39;</span><span class="p">}</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">bench_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;backward&#39;</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
<span class="c1"># create data</span>
<span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># utility functions</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch&#39;</span><span class="p">:</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;apex&#39;</span><span class="p">:</span>
<span class="n">apex_layer_norm</span> <span class="o">=</span> <span class="n">apex</span><span class="o">.</span><span class="n">normalization</span><span class="o">.</span><span class="n">FusedLayerNorm</span><span class="p">(</span><span class="n">w_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">apex_layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># forward pass</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;forward&#39;</span><span class="p">:</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">y_fwd</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="c1"># backward pass</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;backward&#39;</span><span class="p">:</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y_fwd</span><span class="p">()</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">y</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="n">grad_to_none</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">],</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 12.003 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-05-layer-norm-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">05-layer-norm.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">05-layer-norm.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="05-layer-norm.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -0,0 +1,298 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Tutorials &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Vector Addition" href="01-vector-add.html" />
<link rel="prev" title="Installation" href="../installation.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Tutorials</a><ul>
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li>Tutorials</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/index.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="section" id="tutorials">
<span id="sphx-glr-getting-started-tutorials"></span><h1>Tutorials<a class="headerlink" href="#tutorials" title="Permalink to this headline"></a></h1>
<p>Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.</p>
<p>To install the dependencies for the tutorials:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span> triton
pip install -e <span class="s1">&#39;./python[tutorials]&#39;</span>
</pre></div>
</div>
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model of Triton - The triton.jit decorator, which is used to define Tri..."><div class="figure align-default" id="id1">
<img alt="Vector Addition" src="../../_images/sphx_glr_01-vector-add_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a></span><a class="headerlink" href="#id1" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - Reduction operators in Triton..."><div class="figure align-default" id="id2">
<img alt="Fused Softmax" src="../../_images/sphx_glr_02-fused-softmax_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a></span><a class="headerlink" href="#id2" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-thumbcontainer" tooltip="- Block-level matrix multiplications - Multi-dimensional pointer arithmetic - Program re-orderi..."><div class="figure align-default" id="id3">
<img alt="Matrix Multiplication" src="../../_images/sphx_glr_03-matrix-multiplication_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a></span><a class="headerlink" href="#id3" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, you will write a memory-efficient implementation of dropout whose state will ..."><div class="figure align-default" id="id4">
<img alt="Low-Memory Dropout" src="../../_images/sphx_glr_04-low-memory-dropout_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="04-low-memory-dropout.html#sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">Low-Memory Dropout</span></a></span><a class="headerlink" href="#id4" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-thumbcontainer" tooltip="Layer Normalization"><div class="figure align-default" id="id5">
<img alt="Layer Normalization" src="../../_images/sphx_glr_05-layer-norm_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">Layer Normalization</span></a></span><a class="headerlink" href="#id5" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-clear"></div><div class="sphx-glr-footer class sphx-glr-footer-gallery docutils container">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tutorials_python.zip</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Jupyter</span> <span class="pre">notebooks:</span> <span class="pre">tutorials_jupyter.zip</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="01-vector-add.html" class="btn btn-neutral float-right" title="Vector Addition" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="../installation.html" class="btn btn-neutral float-left" title="Installation" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: master
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Tags</dt>
<dd><a href="../../../v1.1.2/index.html">v1.1.2</a></dd>
</dl>
<dl>
<dt>Branches</dt>
<dd><a href="index.html">master</a></dd>
</dl>
</div>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

Some files were not shown because too many files have changed in this diff Show More