[GH-PAGES] Updated website

This commit is contained in:
Philippe Tillet
2022-02-07 03:25:11 +00:00
parent 9e0b45ad2b
commit 0f03cfcfd3
48 changed files with 1601 additions and 324 deletions

View File

@@ -33,7 +33,7 @@
},
"outputs": [],
"source": [
"import torch\n\n\n@torch.jit.script\ndef naive_softmax(x):\n \"\"\"Compute row-wise softmax of X using native pytorch\n\n We subtract the maximum element in order to avoid overflows. Softmax is invariant to\n this shift.\n \"\"\"\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read MN + M elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(z)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read MN + M elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements\n return ret"
"import torch\n\nimport triton\nimport triton.language as tl\n\n\n@torch.jit.script\ndef naive_softmax(x):\n \"\"\"Compute row-wise softmax of X using native pytorch\n\n We subtract the maximum element in order to avoid overflows. Softmax is invariant to\n this shift.\n \"\"\"\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read MN + M elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(z)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read MN + M elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements\n return ret"
]
},
{
@@ -58,7 +58,7 @@
},
"outputs": [],
"source": [
"import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef softmax_kernel(\n output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta\n):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n BLOCK_SIZE = meta['BLOCK_SIZE']\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n # Substract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)"
"@triton.jit\ndef softmax_kernel(\n output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,\n BLOCK_SIZE: tl.constexpr\n):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n # Substract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)"
]
},
{
@@ -101,7 +101,7 @@
},
"outputs": [],
"source": [
"torch.manual_seed(0)\nx = torch.randn(1823, 781, device='cuda')\ny_triton = softmax(x)\ny_torch = torch.softmax(x, axis=1)\nprint(torch.allclose(y_triton, y_torch))"
"torch.manual_seed(0)\nx = torch.randn(1823, 781, device='cuda')\ny_triton = softmax(x)\ny_torch = torch.softmax(x, axis=1)\nassert torch.allclose(y_triton, y_torch), (y_triton, y_torch)"
]
},
{
@@ -133,7 +133,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In the above plot, we can see that:\n\n - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.\n - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. \n Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.\n\n"
"In the above plot, we can see that:\n\n - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.\n - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.\n Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.\n\n"
]
}
],

View File

@@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# --------------------------
import torch
import triton
import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
@@ -37,7 +38,7 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extar elements in case the input is not a
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
@@ -64,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output

View File

@@ -0,0 +1,261 @@
"""
Layer Normalization
====================
"""
import torch
import triton
import triton.language as tl
try:
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
# should not be added to extras_require in setup.py.
import apex
HAS_APEX = True
except ModuleNotFoundError:
HAS_APEX = False
# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
BLOCK_SIZE: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
Y += row * stride
# load data and cast to float32
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
# compute mean
mean = tl.sum(x, axis=0) / N
# compute std
xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean * rstd
# write-back mean/rstd
tl.store(M + row, mean)
tl.store(V + row, rstd)
# multiply by weight and add bias
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
y = xhat * w + b
# write-back
tl.store(Y + cols, y, mask=mask)
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
# these buffers stay in the L2, which allow this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# first store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
x.grad, weight.grad, bias.grad = None, None, None
# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
triton.testing.assert_almost_equal(y_tri, y_ref)
triton.testing.assert_almost_equal(dx_tri, dx_ref)
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[512 * i for i in range(2, 32)],
line_arg='provider',
line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
if provider == 'torch':
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -22,7 +22,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Baseline\nThe *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance \nof deep neural networks in low-data regime (i.e. regularization).\n\nIt takes a vector as input and produces a vector of the same shape as output. Each scalar in the\noutput has a probability $p$ of being changed to zero and otherwise it is copied from the input.\nThis forces the network to perform well even when only $1 - p$ scalars from the input are available.\n\nAt evaluation time we want to use the full power of the network so we set $p=0$. Naively this would\nincrease the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease\nin the output softmax temperature). To prevent this we multiply the output by $\\frac{1}{1 - p}$, which\nkeeps the norm consistent regardless of the dropout probability.\n\nLet's first take a look at the baseline implementation.\n\n"
"## Baseline\nThe *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance\nof deep neural networks in low-data regime (i.e. regularization).\n\nIt takes a vector as input and produces a vector of the same shape as output. Each scalar in the\noutput has a probability $p$ of being changed to zero and otherwise it is copied from the input.\nThis forces the network to perform well even when only $1 - p$ scalars from the input are available.\n\nAt evaluation time we want to use the full power of the network so we set $p=0$. Naively this would\nincrease the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease\nin the output softmax temperature). To prevent this we multiply the output by $\\frac{1}{1 - p}$, which\nkeeps the norm consistent regardless of the dropout probability.\n\nLet's first take a look at the baseline implementation.\n\n"
]
},
{
@@ -33,14 +33,14 @@
},
"outputs": [],
"source": [
"import tabulate\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n **meta,\n):\n BLOCK_SIZE = meta['BLOCK_SIZE']\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n# Input tensor\nx = torch.randn(size=(10,)).cuda()\n# Dropout mask\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\n#\noutput = dropout(x, x_keep=x_keep, p=p)\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"keep mask\"] + x_keep.tolist(),\n [\"output\"] + output.tolist()\n]))"
"import tabulate\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n\n# Input tensor\nx = torch.randn(size=(10,)).cuda()\n# Dropout mask\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\n#\noutput = dropout(x, x_keep=x_keep, p=p)\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"keep mask\"] + x_keep.tolist(),\n [\"output\"] + output.tolist()\n]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Seeded dropout\nAbove implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly\nwe need to store the dropout mask for backpropagation. Secondly, dropout state management can get\nvery tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in\nhttps://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation\nthat (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management\nof persisting randomness across multiple invocations of the kernel.\n\nPseudorandom number generation in Triton is simple! In this tutorial we will use the\n:code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` \nvalues in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides\nother `random number generation strategies <Random Number Generation>`.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Triton's implementation of PRNG is based on the Philox algorithm (described on [SALMON2011]_).</p></div>\n\nLet's put it all together.\n\n"
"## Seeded dropout\nAbove implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly\nwe need to store the dropout mask for backpropagation. Secondly, dropout state management can get\nvery tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in\nhttps://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation\nthat (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management\nof persisting randomness across multiple invocations of the kernel.\n\nPseudorandom number generation in Triton is simple! In this tutorial we will use the\n:code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`\nvalues in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides\nother `random number generation strategies <Random Number Generation>`.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Triton's implementation of PRNG is based on the Philox algorithm (described on [SALMON2011]_).</p></div>\n\nLet's put it all together.\n\n"
]
},
{
@@ -51,7 +51,7 @@
},
"outputs": [],
"source": [
"@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n **meta,\n):\n # compute memory offsets of elements handled by this instance\n BLOCK_SIZE = meta['BLOCK_SIZE']\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n\nx = torch.randn(size=(10,)).cuda()\n# Compare this to the baseline - dropout mask is never instantiated!\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\n\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"output (seed = 123)\"] + output.tolist(),\n [\"output (seed = 123)\"] + output2.tolist(),\n [\"output (seed = 512)\"] + output3.tolist()\n]))"
"@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n # compute memory offsets of elements handled by this instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n\nx = torch.randn(size=(10,)).cuda()\n# Compare this to the baseline - dropout mask is never instantiated!\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\n\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"output (seed = 123)\"] + output.tolist(),\n [\"output (seed = 123)\"] + output2.tolist(),\n [\"output (seed = 512)\"] + output3.tolist()\n]))"
]
},
{

View File

@@ -30,19 +30,20 @@ whose state is generally composed of a bit mask tensor of the same shape as the
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
**meta,
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
BLOCK_SIZE: tl.constexpr,
):
BLOCK_SIZE = meta['BLOCK_SIZE']
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -64,6 +65,7 @@ def dropout(x, x_keep, p):
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
@@ -97,6 +99,7 @@ print(tabulate.tabulate([
#
# Let's put it all together.
@triton.jit
def _seeded_dropout(
x_ptr,
@@ -104,10 +107,9 @@ def _seeded_dropout(
n_elements,
p,
seed,
**meta,
BLOCK_SIZE: tl.constexpr,
):
# compute memory offsets of elements handled by this instance
BLOCK_SIZE = meta['BLOCK_SIZE']
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)

View File

@@ -141,6 +141,7 @@ You will specifically learn about:
#
import torch
import triton
import triton.language as tl
@@ -152,23 +153,22 @@ import triton.language as tl
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def matmul_kernel(
# Pointers to matrices
@@ -182,17 +182,13 @@ def matmul_kernel(
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
**meta,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# extract meta-parameters
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
GROUP_SIZE_M = 8
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
@@ -217,8 +213,8 @@ def matmul_kernel(
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
@@ -239,9 +235,9 @@ def matmul_kernel(
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here
# while the accumulator is still in FP32 !
if meta['ACTIVATION']:
accumulator = meta['ACTIVATION'](accumulator)
# while the accumulator is still in FP32!
if ACTIVATION:
accumulator = ACTIVATION(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------

View File

@@ -18,6 +18,9 @@ You will learn about:
import torch
import triton
import triton.language as tl
@torch.jit.script
def naive_softmax(x):
@@ -59,17 +62,14 @@ def naive_softmax(x):
# power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
BLOCK_SIZE = meta['BLOCK_SIZE']
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
@@ -134,9 +134,9 @@ torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
print(torch.allclose(y_triton, y_torch))
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
#%%
# %%
# As expected, the results are identical.
# %%

View File

@@ -33,7 +33,7 @@
},
"outputs": [],
"source": [
"import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n **meta, # Optional meta-parameters for the kernel\n):\n BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process\n # There are multiple 'program's processing different data. We identify which program\n # we are here\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n # This program will process inputs that are offset from the initial data.\n # for instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extar elements in case the input is not a\n # multiple of the block size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM\n tl.store(output_ptr + offsets, output, mask=mask)"
"import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process\n # NOTE: `constexpr` so it can be used as a shape value\n):\n # There are multiple 'program's processing different data. We identify which program\n # we are here\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n # This program will process inputs that are offset from the initial data.\n # for instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extra elements in case the input is not a\n # multiple of the block size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM\n tl.store(output_ptr + offsets, output, mask=mask)"
]
},
{
@@ -51,7 +51,7 @@
},
"outputs": [],
"source": [
"def add(x: torch.Tensor, y: torch.Tensor):\n # We need to preallocate the output\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n # In this case, we use a 1D grid where the size is the number of blocks\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n # NOTE:\n # - each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return output"
"def add(x: torch.Tensor, y: torch.Tensor):\n # We need to preallocate the output\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n # In this case, we use a 1D grid where the size is the number of blocks\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n # NOTE:\n # - each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return output"
]
},
{

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 37 KiB

After

Width:  |  Height:  |  Size: 37 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 57 KiB

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 33 KiB

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@@ -44,7 +44,7 @@ You can then test your installation by running the unit tests:
.. code-block:: bash
pip install -r requirements-test.txt
pip install -e '.[tests]'
pytest -vs test/unit/
and the benchmarks

View File

@@ -31,12 +31,13 @@ In this tutorial, you will write a simple vector addition using Triton and learn
Compute Kernel
--------------------------
.. GENERATED FROM PYTHON SOURCE LINES 14-49
.. GENERATED FROM PYTHON SOURCE LINES 14-50
.. code-block:: default
import torch
import triton
import triton.language as tl
@@ -47,9 +48,9 @@ Compute Kernel
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
@@ -61,7 +62,7 @@ Compute Kernel
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extar elements in case the input is not a
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
@@ -77,12 +78,12 @@ Compute Kernel
.. GENERATED FROM PYTHON SOURCE LINES 50-52
.. GENERATED FROM PYTHON SOURCE LINES 51-53
Let's also declare a helper function to (1) allocate the `z` tensor
and (2) enqueue the above kernel with appropriate grid/block sizes.
.. GENERATED FROM PYTHON SOURCE LINES 52-73
.. GENERATED FROM PYTHON SOURCE LINES 53-74
.. code-block:: default
@@ -101,7 +102,7 @@ and (2) enqueue the above kernel with appropriate grid/block sizes.
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
@@ -114,11 +115,11 @@ and (2) enqueue the above kernel with appropriate grid/block sizes.
.. GENERATED FROM PYTHON SOURCE LINES 74-75
.. GENERATED FROM PYTHON SOURCE LINES 75-76
We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
.. GENERATED FROM PYTHON SOURCE LINES 75-89
.. GENERATED FROM PYTHON SOURCE LINES 76-90
.. code-block:: default
@@ -153,11 +154,11 @@ We can now use the above function to compute the element-wise sum of two `torch.
.. GENERATED FROM PYTHON SOURCE LINES 90-91
.. GENERATED FROM PYTHON SOURCE LINES 91-92
Seems like we're good to go!
.. GENERATED FROM PYTHON SOURCE LINES 93-98
.. GENERATED FROM PYTHON SOURCE LINES 94-99
Benchmark
-----------
@@ -165,7 +166,7 @@ We can now benchmark our custom op on vectors of increasing sizes to get a sense
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
for different problem sizes.
.. GENERATED FROM PYTHON SOURCE LINES 98-127
.. GENERATED FROM PYTHON SOURCE LINES 99-128
.. code-block:: default
@@ -205,12 +206,12 @@ for different problem sizes.
.. GENERATED FROM PYTHON SOURCE LINES 128-130
.. GENERATED FROM PYTHON SOURCE LINES 129-131
We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
`save_path='/path/to/results/' to save them to disk along with raw CSV data
.. GENERATED FROM PYTHON SOURCE LINES 130-131
.. GENERATED FROM PYTHON SOURCE LINES 131-132
.. code-block:: default
@@ -237,14 +238,14 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
3 32768.0 76.800002 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 341.333321 384.000001
6 262144.0 341.333321 341.333321
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
10 4194304.0 780.190482 780.190482
11 8388608.0 812.429770 812.429770
12 16777216.0 833.084721 833.084721
13 33554432.0 842.004273 843.811163
13 33554432.0 842.004273 842.004273
14 67108864.0 847.448255 848.362445
15 134217728.0 849.737435 850.656574
@@ -254,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 52.411 seconds)
**Total running time of the script:** ( 1 minutes 50.312 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

View File

@@ -35,13 +35,16 @@ Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
.. GENERATED FROM PYTHON SOURCE LINES 18-43
.. GENERATED FROM PYTHON SOURCE LINES 18-46
.. code-block:: default
import torch
import triton
import triton.language as tl
@torch.jit.script
def naive_softmax(x):
@@ -71,7 +74,7 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
.. GENERATED FROM PYTHON SOURCE LINES 44-52
.. GENERATED FROM PYTHON SOURCE LINES 47-55
When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
@@ -82,7 +85,7 @@ expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically
but, as we will see later, it is still far from ideal.
.. GENERATED FROM PYTHON SOURCE LINES 54-61
.. GENERATED FROM PYTHON SOURCE LINES 57-64
Compute Kernel
----------------
@@ -92,22 +95,19 @@ Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shapes:
.. GENERATED FROM PYTHON SOURCE LINES 61-93
.. GENERATED FROM PYTHON SOURCE LINES 64-93
.. code-block:: default
import triton
import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
BLOCK_SIZE = meta['BLOCK_SIZE']
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
@@ -199,20 +199,12 @@ This will allow us to verify that our padding mechanism works.
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
print(torch.allclose(y_triton, y_torch))
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
True
@@ -286,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 512.000001 546.133347 188.321838
1 384.0 585.142862 585.142862 151.703707
2 512.0 655.360017 606.814814 154.566038
0 256.0 546.133347 512.000001 190.511628
1 384.0 614.400016 558.545450 153.600004
2 512.0 655.360017 585.142849 154.566038
3 640.0 682.666684 640.000002 160.000000
4 768.0 702.171410 646.736871 163.839992
4 768.0 722.823517 664.216187 162.754967
.. ... ... ... ...
93 12160.0 810.666687 405.755985 199.038365
94 12288.0 812.429770 415.661740 199.197579
95 12416.0 809.189387 412.149375 198.854847
96 12544.0 807.661970 412.971190 199.012395
97 12672.0 807.776923 412.097543 199.167004
93 12160.0 814.058574 406.179533 198.530610
94 12288.0 814.111783 415.661740 198.694297
95 12416.0 814.163950 412.149375 198.457532
96 12544.0 814.214963 412.546756 198.716830
97 12672.0 814.265046 412.097543 198.679085
[98 rows x 4 columns]
@@ -314,7 +306,7 @@ In the above plot, we can see that:
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 3 minutes 26.243 seconds)
**Total running time of the script:** ( 3 minutes 22.431 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

View File

@@ -161,12 +161,13 @@ Final Result
-------------
.. GENERATED FROM PYTHON SOURCE LINES 142-262
.. GENERATED FROM PYTHON SOURCE LINES 142-258
.. code-block:: default
import torch
import triton
import triton.language as tl
@@ -178,23 +179,22 @@ Final Result
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def matmul_kernel(
# Pointers to matrices
@@ -208,17 +208,13 @@ Final Result
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
**meta,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# extract meta-parameters
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
GROUP_SIZE_M = 8
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
@@ -243,8 +239,8 @@ Final Result
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
@@ -265,9 +261,9 @@ Final Result
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here
# while the accumulator is still in FP32 !
if meta['ACTIVATION']:
accumulator = meta['ACTIVATION'](accumulator)
# while the accumulator is still in FP32!
if ACTIVATION:
accumulator = ACTIVATION(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
@@ -292,12 +288,12 @@ Final Result
.. GENERATED FROM PYTHON SOURCE LINES 263-265
.. GENERATED FROM PYTHON SOURCE LINES 259-261
We can now create a convenience wrapper function that only takes two input tensors
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
.. GENERATED FROM PYTHON SOURCE LINES 265-294
.. GENERATED FROM PYTHON SOURCE LINES 261-290
.. code-block:: default
@@ -337,14 +333,14 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
.. GENERATED FROM PYTHON SOURCE LINES 295-299
.. GENERATED FROM PYTHON SOURCE LINES 291-295
Unit Test
-----------
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
.. GENERATED FROM PYTHON SOURCE LINES 299-312
.. GENERATED FROM PYTHON SOURCE LINES 295-308
.. code-block:: default
@@ -392,7 +388,7 @@ We can test our custom matrix multiplication operation against a native torch im
.. GENERATED FROM PYTHON SOURCE LINES 313-319
.. GENERATED FROM PYTHON SOURCE LINES 309-315
Benchmark
--------------
@@ -401,7 +397,7 @@ Square Matrix Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
.. GENERATED FROM PYTHON SOURCE LINES 319-360
.. GENERATED FROM PYTHON SOURCE LINES 315-356
.. code-block:: default
@@ -463,36 +459,36 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
1 384.0 7.372800 ... 8.507077 8.507077
2 512.0 14.563555 ... 16.384000 16.384000
1 384.0 7.372800 ... 8.507077 7.899428
2 512.0 14.563555 ... 15.420235 16.384000
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 34.028308 34.028308
5 896.0 37.971025 ... 40.140799 39.025776
6 1024.0 49.932191 ... 53.773130 52.428801
7 1152.0 44.566925 ... 46.656000 46.656000
5 896.0 37.971025 ... 39.025776 39.025776
6 1024.0 49.932191 ... 52.428801 52.428801
7 1152.0 45.242181 ... 46.656000 46.656000
8 1280.0 51.200001 ... 56.888887 56.109587
9 1408.0 64.138541 ... 67.305878 67.305878
10 1536.0 80.430545 ... 79.526831 79.526831
11 1664.0 63.372618 ... 62.492442 62.492442
12 1792.0 72.983276 ... 72.047592 72.047592
13 1920.0 69.120002 ... 70.172588 70.172588
14 2048.0 73.908442 ... 76.959706 76.608294
15 2176.0 83.500614 ... 86.367588 85.632545
16 2304.0 68.446623 ... 76.809875 76.809875
17 2432.0 71.305746 ... 74.918570 85.393507
18 2560.0 78.019048 ... 80.908642 80.709358
19 2688.0 83.552988 ... 89.149366 89.464755
20 2816.0 82.602666 ... 83.074685 83.233226
21 2944.0 82.646820 ... 82.373605 82.784108
22 3072.0 81.943708 ... 88.612060 87.516392
23 3200.0 78.914919 ... 91.822093 93.567248
24 3328.0 81.530349 ... 84.003845 84.496824
25 3456.0 82.519518 ... 91.200871 90.943675
26 3584.0 85.633710 ... 94.847460 96.579370
27 3712.0 85.528545 ... 85.019017 87.706180
28 3840.0 81.980725 ... 86.130841 91.247522
29 3968.0 85.993854 ... 91.747320 86.053553
30 4096.0 93.727466 ... 88.417474 84.307617
9 1408.0 64.138541 ... 66.485074 65.684049
10 1536.0 79.526831 ... 79.526831 78.643199
11 1664.0 62.929456 ... 62.061463 62.061463
12 1792.0 72.983276 ... 72.047592 71.588687
13 1920.0 68.776119 ... 70.172588 69.818184
14 2048.0 73.262953 ... 76.959706 76.608294
15 2176.0 83.500614 ... 85.998493 85.269692
16 2304.0 68.446623 ... 76.319081 75.834511
17 2432.0 71.125224 ... 82.509438 84.877538
18 2560.0 77.833728 ... 80.709358 81.108913
19 2688.0 83.369354 ... 89.676257 89.464755
20 2816.0 83.233226 ... 82.446516 81.827785
21 2944.0 82.373605 ... 82.373605 81.298583
22 3072.0 82.062468 ... 88.060814 88.473602
23 3200.0 82.368085 ... 89.761569 94.955488
24 3328.0 80.889094 ... 80.527177 82.939284
25 3456.0 81.849303 ... 86.783176 91.304157
26 3584.0 87.042978 ... 98.375705 90.364394
27 3712.0 79.726532 ... 90.815768 85.820159
28 3840.0 82.592983 ... 88.191387 91.398346
29 3968.0 85.873762 ... 90.656713 83.867052
30 4096.0 92.563952 ... 82.441739 82.291681
[31 rows x 5 columns]
@@ -502,7 +498,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 5 minutes 41.070 seconds)
**Total running time of the script:** ( 6 minutes 5.923 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:

View File

@@ -46,7 +46,7 @@ keeps the norm consistent regardless of the dropout probability.
Let's first take a look at the baseline implementation.
.. GENERATED FROM PYTHON SOURCE LINES 29-80
.. GENERATED FROM PYTHON SOURCE LINES 29-82
.. code-block:: default
@@ -54,19 +54,20 @@ Let's first take a look at the baseline implementation.
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
**meta,
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
BLOCK_SIZE: tl.constexpr,
):
BLOCK_SIZE = meta['BLOCK_SIZE']
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -88,6 +89,7 @@ Let's first take a look at the baseline implementation.
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
@@ -120,7 +122,7 @@ Let's first take a look at the baseline implementation.
.. GENERATED FROM PYTHON SOURCE LINES 81-99
.. GENERATED FROM PYTHON SOURCE LINES 83-101
Seeded dropout
-------------
@@ -141,11 +143,12 @@ other :ref:`random number generation strategies <Random Number Generation>`.
Let's put it all together.
.. GENERATED FROM PYTHON SOURCE LINES 99-147
.. GENERATED FROM PYTHON SOURCE LINES 101-149
.. code-block:: default
@triton.jit
def _seeded_dropout(
x_ptr,
@@ -153,10 +156,9 @@ Let's put it all together.
n_elements,
p,
seed,
**meta,
BLOCK_SIZE: tl.constexpr,
):
# compute memory offsets of elements handled by this instance
BLOCK_SIZE = meta['BLOCK_SIZE']
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -205,21 +207,21 @@ Let's put it all together.
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
output (seed = 123) 0 0.743443 0 2.84284 0.298794 -1.34172 0 0 0 0
output (seed = 123) 0 0.743443 0 2.84284 0.298794 -1.34172 0 0 0 0
output (seed = 512) -1.90567 0.743443 0 2.84284 0.298794 -1.34172 0 -0.863938 0 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
.. GENERATED FROM PYTHON SOURCE LINES 148-151
.. GENERATED FROM PYTHON SOURCE LINES 150-153
Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same!
If you'd like explore further applications of pseudorandomness in GPU programming, we encourage you
to explore the `triton/language/random` folder!
.. GENERATED FROM PYTHON SOURCE LINES 153-158
.. GENERATED FROM PYTHON SOURCE LINES 155-160
Exercises
-------------
@@ -227,7 +229,7 @@ Exercises
2. Add support for striding.
3. (challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix one the fly each time using a seed.
.. GENERATED FROM PYTHON SOURCE LINES 160-165
.. GENERATED FROM PYTHON SOURCE LINES 162-167
References
--------------
@@ -238,7 +240,7 @@ References
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.010 seconds)
**Total running time of the script:** ( 0 minutes 0.477 seconds)
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:

View File

@@ -0,0 +1,370 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:
Layer Normalization
====================
.. GENERATED FROM PYTHON SOURCE LINES 5-262
.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
:alt: 05 layer norm
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
layer-norm-backward:
N Triton Torch
0 1024.0 311.088617 99.497980
1 1536.0 347.773587 133.565214
2 2048.0 420.102553 162.217818
3 2560.0 455.111129 182.857144
4 3072.0 511.999982 191.501303
5 3584.0 551.384634 208.271186
6 4096.0 568.231237 220.907859
7 4608.0 502.690905 232.336141
8 5120.0 527.381977 243.326731
9 5632.0 540.671974 243.545956
10 6144.0 544.118087 249.081070
11 6656.0 528.953642 256.410903
12 7168.0 507.469040 262.243907
13 7680.0 481.253256 261.076480
14 8192.0 461.521112 269.326017
15 8704.0 417.791980 268.159180
16 9216.0 431.157889 273.404206
17 9728.0 442.181815 280.953074
18 10240.0 448.467168 286.767793
19 10752.0 427.231788 246.699797
20 11264.0 427.071098 245.313973
21 11776.0 420.571432 249.447482
22 12288.0 420.102570 254.673582
23 12800.0 414.016170 253.674644
24 13312.0 410.652963 252.759501
25 13824.0 403.620451 257.390218
26 14336.0 396.387109 254.862216
27 14848.0 382.351933 257.293872
28 15360.0 374.253788 257.790220
29 15872.0 368.046389 262.890274
|
.. code-block:: default
import torch
import triton
import triton.language as tl
try:
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
# should not be added to extras_require in setup.py.
import apex
HAS_APEX = True
except ModuleNotFoundError:
HAS_APEX = False
# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
BLOCK_SIZE: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
Y += row * stride
# load data and cast to float32
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
# compute mean
mean = tl.sum(x, axis=0) / N
# compute std
xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean * rstd
# write-back mean/rstd
tl.store(M + row, mean)
tl.store(V + row, rstd)
# multiply by weight and add bias
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
y = xhat * w + b
# write-back
tl.store(Y + cols, y, mask=mask)
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
# these buffers stay in the L2, which allow this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# first store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
x.grad, weight.grad, bias.grad = None, None, None
# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
triton.testing.assert_almost_equal(y_tri, y_ref)
triton.testing.assert_almost_equal(dx_tri, dx_ref)
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[512 * i for i in range(2, 32)],
line_arg='provider',
line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
if provider == 'torch':
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True)
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 23.770 seconds)
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -9,6 +9,14 @@ Tutorials
Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.
To install the dependencies for the tutorials:
.. code-block:: bash
cd triton
pip install -e './python[tutorials]'
.. raw:: html
@@ -93,6 +101,27 @@ Below is a gallery of tutorials for writing various basic operations with Triton
:hidden:
/getting-started/tutorials/04-low-memory-dropout
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Layer Normalization">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_05-layer-norm_thumb.png
:alt: Layer Normalization
:ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/05-layer-norm
.. raw:: html
<div class="sphx-glr-clear"></div>

View File

@@ -5,14 +5,16 @@
Computation times
=================
**10:59.734** total execution time for **getting-started_tutorials** files:
**12:42.913** total execution time for **getting-started_tutorials** files:
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 05:41.070 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 06:05.923 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:26.243 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:22.431 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:52.411 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:50.312 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.010 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 01:23.770 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.477 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+

View File

@@ -208,7 +208,7 @@ pip install -e .
</div>
<p>Note that, if llvm-11 is not present on your system, the setup.py script will download the official LLVM11 static libraries link against that.</p>
<p>You can then test your installation by running the unit tests:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip install -r requirements-test.txt
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip install -e <span class="s1">&#39;.[tests]&#39;</span>
pytest -vs test/unit/
</pre></div>
</div>

View File

@@ -104,6 +104,7 @@
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
@@ -202,6 +203,7 @@ to download the full example code</p>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
@@ -212,9 +214,9 @@ to download the full example code</p>
<span class="n">y_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to second input vector</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># *Pointer* to output vector</span>
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
<span class="o">**</span><span class="n">meta</span><span class="p">,</span> <span class="c1"># Optional meta-parameters for the kernel</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="c1"># Number of elements each program should process</span>
<span class="c1"># NOTE: `constexpr` so it can be used as a shape value</span>
<span class="p">):</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]</span> <span class="c1"># How many inputs each program should process</span>
<span class="c1"># There are multiple &#39;program&#39;s processing different data. We identify which program</span>
<span class="c1"># we are here</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># We use a 1D launch grid so axis is 0</span>
@@ -226,7 +228,7 @@ to download the full example code</p>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="c1"># Create a mask to guard memory operations against out-of-bounds accesses</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
<span class="c1"># Load x and y from DRAM, masking out any extar elements in case the input is not a</span>
<span class="c1"># Load x and y from DRAM, masking out any extra elements in case the input is not a</span>
<span class="c1"># multiple of the block size</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
@@ -250,7 +252,7 @@ and (2) enqueue the above kernel with appropriate grid/block sizes.</p>
<span class="c1"># - each torch.tensor object is implicitly converted into a pointer to its first element.</span>
<span class="c1"># - `triton.jit`&#39;ed functions can be index with a launch grid to obtain a callable GPU kernel</span>
<span class="c1"># - don&#39;t forget to pass meta-parameters as keywords arguments</span>
<span class="n">pgm</span> <span class="o">=</span> <span class="n">add_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="n">add_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="c1"># We return a handle to z but, since `torch.cuda.synchronize()` hasn&#39;t been called, the kernel is still</span>
<span class="c1"># running asynchronously at this point.</span>
<span class="k">return</span> <span class="n">output</span>
@@ -326,19 +328,19 @@ for different problem sizes.</p>
3 32768.0 76.800002 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 341.333321 384.000001
6 262144.0 341.333321 341.333321
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
10 4194304.0 780.190482 780.190482
11 8388608.0 812.429770 812.429770
12 16777216.0 833.084721 833.084721
13 33554432.0 842.004273 843.811163
13 33554432.0 842.004273 842.004273
14 67108864.0 847.448255 848.362445
15 134217728.0 849.737435 850.656574
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 52.411 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 50.312 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>

View File

@@ -107,6 +107,7 @@
</li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
@@ -210,6 +211,9 @@ You will learn about:</p>
Let us consider instead the case of a simple (numerically stabilized) softmax operation:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
@@ -248,17 +252,13 @@ normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally “pad” each row and guard the
memory operations properly if we want to handle any possible input shapes:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">softmax_kernel</span><span class="p">(</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="n">input_ptr</span><span class="p">,</span> <span class="n">input_row_stride</span><span class="p">,</span> <span class="n">output_row_stride</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="o">**</span><span class="n">meta</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="n">input_ptr</span><span class="p">,</span> <span class="n">input_row_stride</span><span class="p">,</span> <span class="n">output_row_stride</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span>
<span class="p">):</span>
<span class="c1"># The rows of the softmax are independent, so we parallelize across those</span>
<span class="n">row_idx</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]</span>
<span class="c1"># The stride represents how much we need to increase the pointer to advance 1 row</span>
<span class="n">row_start_ptr</span> <span class="o">=</span> <span class="n">input_ptr</span> <span class="o">+</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">input_row_stride</span>
<span class="c1"># The block size is the next power of two greater than n_cols, so we can fit each</span>
@@ -318,11 +318,7 @@ This will allow us to verify that our padding mechanism works.</p>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y_triton</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y_torch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>True
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">),</span> <span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">)</span>
</pre></div>
</div>
<p>As expected, the results are identical.</p>
@@ -373,17 +369,17 @@ We will then compare its performance against (1) <code class="code docutils lite
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 512.000001 546.133347 188.321838
1 384.0 585.142862 585.142862 151.703707
2 512.0 655.360017 606.814814 154.566038
0 256.0 546.133347 512.000001 190.511628
1 384.0 614.400016 558.545450 153.600004
2 512.0 655.360017 585.142849 154.566038
3 640.0 682.666684 640.000002 160.000000
4 768.0 702.171410 646.736871 163.839992
4 768.0 722.823517 664.216187 162.754967
.. ... ... ... ...
93 12160.0 810.666687 405.755985 199.038365
94 12288.0 812.429770 415.661740 199.197579
95 12416.0 809.189387 412.149375 198.854847
96 12544.0 807.661970 412.971190 199.012395
97 12672.0 807.776923 412.097543 199.167004
93 12160.0 814.058574 406.179533 198.530610
94 12288.0 814.111783 415.661740 198.694297
95 12416.0 814.163950 412.149375 198.457532
96 12544.0 814.214963 412.546756 198.716830
97 12672.0 814.265046 412.097543 198.679085
[98 rows x 4 columns]
</pre></div>
@@ -396,7 +392,7 @@ We will then compare its performance against (1) <code class="code docutils lite
Note however that the PyTorch <cite>softmax</cite> operation is more general and will works on tensors of any shape.</p></li>
</ul>
</div></blockquote>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 3 minutes 26.243 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 3 minutes 22.431 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>

View File

@@ -114,6 +114,7 @@
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
@@ -332,6 +333,7 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</
<div class="section" id="final-result">
<h2>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
@@ -343,23 +345,22 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</
<span class="c1"># - An autotuning *key* whose change in values will trigger evaluation of all the</span>
<span class="c1"># provided configs</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span>
<span class="n">configs</span><span class="o">=</span><span class="p">[</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">&#39;GROUP_SIZE_M&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="p">],</span>
<span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;M&#39;</span><span class="p">,</span> <span class="s1">&#39;N&#39;</span><span class="p">,</span> <span class="s1">&#39;K&#39;</span><span class="p">],</span>
<span class="p">)</span>
<span class="c1"># %</span>
<span class="c1"># We can now define our kernel as normal, using all the techniques presented above</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">matmul_kernel</span><span class="p">(</span>
<span class="c1"># Pointers to matrices</span>
@@ -373,17 +374,13 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</
<span class="n">stride_bk</span><span class="p">,</span> <span class="n">stride_bn</span><span class="p">,</span>
<span class="n">stride_cm</span><span class="p">,</span> <span class="n">stride_cn</span><span class="p">,</span>
<span class="c1"># Meta-parameters</span>
<span class="o">**</span><span class="n">meta</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="n">ACTIVATION</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Kernel for computing the matmul C = A x B.</span>
<span class="sd"> A has shape (M, K), B has shape (K, N) and C has shape (M, N)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># extract meta-parameters</span>
<span class="n">BLOCK_SIZE_M</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_M&#39;</span><span class="p">]</span>
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">]</span>
<span class="n">BLOCK_SIZE_K</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_K&#39;</span><span class="p">]</span>
<span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">8</span>
<span class="c1"># -----------------------------------------------------------</span>
<span class="c1"># Map program ids `pid` to the block of C it should compute.</span>
<span class="c1"># This is done in a grouped ordering to promote L2 data reuse</span>
@@ -408,8 +405,8 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</
<span class="n">offs_am</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">offs_bn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span> <span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_ak</span><span class="p">)</span>
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span> <span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_bn</span><span class="p">)</span>
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">)</span>
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_bn</span><span class="p">)</span>
<span class="c1"># -----------------------------------------------------------</span>
<span class="c1"># Iterate to compute a block of the C matrix</span>
@@ -430,9 +427,9 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</
<span class="n">a_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
<span class="n">b_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
<span class="c1"># you can fuse arbitrary activation functions here</span>
<span class="c1"># while the accumulator is still in FP32 !</span>
<span class="k">if</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;ACTIVATION&#39;</span><span class="p">]:</span>
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;ACTIVATION&#39;</span><span class="p">](</span><span class="n">accumulator</span><span class="p">)</span>
<span class="c1"># while the accumulator is still in FP32!</span>
<span class="k">if</span> <span class="n">ACTIVATION</span><span class="p">:</span>
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">ACTIVATION</span><span class="p">(</span><span class="n">accumulator</span><span class="p">)</span>
<span class="n">c</span> <span class="o">=</span> <span class="n">accumulator</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="c1"># -----------------------------------------------------------</span>
@@ -568,41 +565,41 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
1 384.0 7.372800 ... 8.507077 8.507077
2 512.0 14.563555 ... 16.384000 16.384000
1 384.0 7.372800 ... 8.507077 7.899428
2 512.0 14.563555 ... 15.420235 16.384000
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 34.028308 34.028308
5 896.0 37.971025 ... 40.140799 39.025776
6 1024.0 49.932191 ... 53.773130 52.428801
7 1152.0 44.566925 ... 46.656000 46.656000
5 896.0 37.971025 ... 39.025776 39.025776
6 1024.0 49.932191 ... 52.428801 52.428801
7 1152.0 45.242181 ... 46.656000 46.656000
8 1280.0 51.200001 ... 56.888887 56.109587
9 1408.0 64.138541 ... 67.305878 67.305878
10 1536.0 80.430545 ... 79.526831 79.526831
11 1664.0 63.372618 ... 62.492442 62.492442
12 1792.0 72.983276 ... 72.047592 72.047592
13 1920.0 69.120002 ... 70.172588 70.172588
14 2048.0 73.908442 ... 76.959706 76.608294
15 2176.0 83.500614 ... 86.367588 85.632545
16 2304.0 68.446623 ... 76.809875 76.809875
17 2432.0 71.305746 ... 74.918570 85.393507
18 2560.0 78.019048 ... 80.908642 80.709358
19 2688.0 83.552988 ... 89.149366 89.464755
20 2816.0 82.602666 ... 83.074685 83.233226
21 2944.0 82.646820 ... 82.373605 82.784108
22 3072.0 81.943708 ... 88.612060 87.516392
23 3200.0 78.914919 ... 91.822093 93.567248
24 3328.0 81.530349 ... 84.003845 84.496824
25 3456.0 82.519518 ... 91.200871 90.943675
26 3584.0 85.633710 ... 94.847460 96.579370
27 3712.0 85.528545 ... 85.019017 87.706180
28 3840.0 81.980725 ... 86.130841 91.247522
29 3968.0 85.993854 ... 91.747320 86.053553
30 4096.0 93.727466 ... 88.417474 84.307617
9 1408.0 64.138541 ... 66.485074 65.684049
10 1536.0 79.526831 ... 79.526831 78.643199
11 1664.0 62.929456 ... 62.061463 62.061463
12 1792.0 72.983276 ... 72.047592 71.588687
13 1920.0 68.776119 ... 70.172588 69.818184
14 2048.0 73.262953 ... 76.959706 76.608294
15 2176.0 83.500614 ... 85.998493 85.269692
16 2304.0 68.446623 ... 76.319081 75.834511
17 2432.0 71.125224 ... 82.509438 84.877538
18 2560.0 77.833728 ... 80.709358 81.108913
19 2688.0 83.369354 ... 89.676257 89.464755
20 2816.0 83.233226 ... 82.446516 81.827785
21 2944.0 82.373605 ... 82.373605 81.298583
22 3072.0 82.062468 ... 88.060814 88.473602
23 3200.0 82.368085 ... 89.761569 94.955488
24 3328.0 80.889094 ... 80.527177 82.939284
25 3456.0 81.849303 ... 86.783176 91.304157
26 3584.0 87.042978 ... 98.375705 90.364394
27 3712.0 79.726532 ... 90.815768 85.820159
28 3840.0 82.592983 ... 88.191387 91.398346
29 3968.0 85.873762 ... 90.656713 83.867052
30 4096.0 92.563952 ... 82.441739 82.291681
[31 rows x 5 columns]
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 41.070 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 6 minutes 5.923 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>

View File

@@ -47,7 +47,7 @@
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="triton" href="../../python-api/triton.html" />
<link rel="next" title="Layer Normalization" href="05-layer-norm.html" />
<link rel="prev" title="Matrix Multiplication" href="03-matrix-multiplication.html" />
</head>
@@ -107,6 +107,7 @@
<li class="toctree-l3"><a class="reference internal" href="#references">References</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
@@ -217,19 +218,20 @@ keeps the norm consistent regardless of the dropout probability.</p>
<p>Lets first take a look at the baseline implementation.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tabulate</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_dropout</span><span class="p">(</span>
<span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># pointer to the input</span>
<span class="n">x_keep_ptr</span><span class="p">,</span> <span class="c1"># pointer to a mask of 0s and 1s</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># pointer to the output</span>
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># number of elements in the `x` tensor</span>
<span class="n">p</span><span class="p">,</span> <span class="c1"># probability that an element of `x` is changed to zero</span>
<span class="o">**</span><span class="n">meta</span><span class="p">,</span>
<span class="n">x_ptr</span><span class="p">,</span> <span class="c1"># pointer to the input</span>
<span class="n">x_keep_ptr</span><span class="p">,</span> <span class="c1"># pointer to a mask of 0s and 1s</span>
<span class="n">output_ptr</span><span class="p">,</span> <span class="c1"># pointer to the output</span>
<span class="n">n_elements</span><span class="p">,</span> <span class="c1"># number of elements in the `x` tensor</span>
<span class="n">p</span><span class="p">,</span> <span class="c1"># probability that an element of `x` is changed to zero</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
@@ -251,6 +253,7 @@ keeps the norm consistent regardless of the dropout probability.</p>
<span class="n">_dropout</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">x_keep</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="c1"># Input tensor</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1"># Dropout mask</span>
@@ -298,10 +301,9 @@ other <a class="reference internal" href="../../python-api/triton.language.html#
<span class="n">n_elements</span><span class="p">,</span>
<span class="n">p</span><span class="p">,</span>
<span class="n">seed</span><span class="p">,</span>
<span class="o">**</span><span class="n">meta</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span>
<span class="p">):</span>
<span class="c1"># compute memory offsets of elements handled by this instance</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">]</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
@@ -342,9 +344,9 @@ other <a class="reference internal" href="../../python-api/triton.language.html#
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
input -0.952835 0.371721 0.408716 1.42142 0.149397 -0.67086 -0.214186 -0.431969 -0.707878 -0.106434
output (seed = 123) 0 0.743443 0 2.84284 0.298794 -1.34172 0 0 0 0
output (seed = 123) 0 0.743443 0 2.84284 0.298794 -1.34172 0 0 0 0
output (seed = 512) -1.90567 0.743443 0 2.84284 0.298794 -1.34172 0 -0.863938 0 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 123) 0 0.743443 0 0 0 -1.34172 0 0 -1.41576 -0.212868
output (seed = 512) 0 0 0.817432 2.84284 0 -1.34172 -0.428372 0 0 0
------------------- --------- -------- -------- ------- -------- -------- --------- --------- --------- ---------
</pre></div>
</div>
@@ -370,7 +372,7 @@ to explore the <cite>triton/language/random</cite> folder!</p>
<dd><p>Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014</p>
</dd>
</dl>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.010 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 0.477 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-04-low-memory-dropout-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/c9aed78977a4c05741d675a38dde3d7d/04-low-memory-dropout.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">04-low-memory-dropout.py</span></code></a></p>
@@ -389,7 +391,7 @@ to explore the <cite>triton/language/random</cite> folder!</p>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="05-layer-norm.html" class="btn btn-neutral float-right" title="Layer Normalization" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="03-matrix-multiplication.html" class="btn btn-neutral float-left" title="Matrix Multiplication" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>

View File

@@ -0,0 +1,550 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Layer Normalization &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="triton" href="../../python-api/triton.html" />
<link rel="prev" title="Low-Memory Dropout" href="04-low-memory-dropout.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Layer Normalization</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Layer Normalization</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/05-layer-norm.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="layer-normalization">
<span id="sphx-glr-getting-started-tutorials-05-layer-norm-py"></span><h1>Layer Normalization<a class="headerlink" href="#layer-normalization" title="Permalink to this headline"></a></h1>
<img alt="05 layer norm" class="sphx-glr-single-img" src="../../_images/sphx_glr_05-layer-norm_001.png" />
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>layer-norm-backward:
N Triton Torch
0 1024.0 311.088617 99.497980
1 1536.0 347.773587 133.565214
2 2048.0 420.102553 162.217818
3 2560.0 455.111129 182.857144
4 3072.0 511.999982 191.501303
5 3584.0 551.384634 208.271186
6 4096.0 568.231237 220.907859
7 4608.0 502.690905 232.336141
8 5120.0 527.381977 243.326731
9 5632.0 540.671974 243.545956
10 6144.0 544.118087 249.081070
11 6656.0 528.953642 256.410903
12 7168.0 507.469040 262.243907
13 7680.0 481.253256 261.076480
14 8192.0 461.521112 269.326017
15 8704.0 417.791980 268.159180
16 9216.0 431.157889 273.404206
17 9728.0 442.181815 280.953074
18 10240.0 448.467168 286.767793
19 10752.0 427.231788 246.699797
20 11264.0 427.071098 245.313973
21 11776.0 420.571432 249.447482
22 12288.0 420.102570 254.673582
23 12800.0 414.016170 253.674644
24 13312.0 410.652963 252.759501
25 13824.0 403.620451 257.390218
26 14336.0 396.387109 254.862216
27 14848.0 382.351933 257.293872
28 15360.0 374.253788 257.790220
29 15872.0 368.046389 262.890274
</pre></div>
</div>
<div class="line-block">
<div class="line"><br /></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
<span class="k">try</span><span class="p">:</span>
<span class="c1"># This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it</span>
<span class="c1"># should not be added to extras_require in setup.py.</span>
<span class="kn">import</span> <span class="nn">apex</span>
<span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">except</span> <span class="ne">ModuleNotFoundError</span><span class="p">:</span>
<span class="n">HAS_APEX</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># Forward Pass</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_fwd_fused</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="c1"># position of elements processed by this program</span>
<span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span>
<span class="c1"># offset data pointers to start at the row of interest</span>
<span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">Y</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="c1"># load data and cast to float32</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># compute mean</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="c1"># compute std</span>
<span class="n">xmean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">var</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xmean</span> <span class="o">*</span> <span class="n">xmean</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">tl</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="n">xmean</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="c1"># write-back mean/rstd</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">mean</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
<span class="c1"># multiply by weight and add bias</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">xhat</span> <span class="o">*</span> <span class="n">w</span> <span class="o">+</span> <span class="n">b</span>
<span class="c1"># write-back</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># Backward pass (DX + partial DW + partial DB)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dx_fused</span><span class="p">(</span><span class="n">DX</span><span class="p">,</span> <span class="n">DY</span><span class="p">,</span> <span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">Lock</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="c1"># position of elements processed by this program</span>
<span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span>
<span class="c1"># offset data pointers to start at the row of interest</span>
<span class="n">X</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">DY</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="n">DX</span> <span class="o">+=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">stride</span>
<span class="c1"># offset locks and weight/bias gradient pointer</span>
<span class="c1"># each kernel instance accumulates partial sums for</span>
<span class="c1"># DW and DB into one of GROUP_SIZE_M independent buffers</span>
<span class="c1"># these buffers stay in the L2, which allow this kernel</span>
<span class="c1"># to be fast</span>
<span class="n">lock_id</span> <span class="o">=</span> <span class="n">row</span> <span class="o">%</span> <span class="n">GROUP_SIZE_M</span>
<span class="n">Lock</span> <span class="o">+=</span> <span class="n">lock_id</span>
<span class="n">Count</span> <span class="o">=</span> <span class="n">Lock</span> <span class="o">+</span> <span class="n">GROUP_SIZE_M</span>
<span class="n">DW</span> <span class="o">=</span> <span class="n">DW</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
<span class="n">DB</span> <span class="o">=</span> <span class="n">DB</span> <span class="o">+</span> <span class="n">lock_id</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span>
<span class="c1"># load data to SRAM</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DY</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">W</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">V</span> <span class="o">+</span> <span class="n">row</span><span class="p">)</span>
<span class="c1"># compute dx</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="n">wdy</span> <span class="o">=</span> <span class="n">w</span> <span class="o">*</span> <span class="n">dy</span>
<span class="n">xhat</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">xhat</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">wdy</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">wdy</span><span class="p">,</span> <span class="mf">0.</span><span class="p">)</span>
<span class="n">mean1</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">mean2</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">wdy</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">N</span>
<span class="n">dx</span> <span class="o">=</span> <span class="p">(</span><span class="n">wdy</span> <span class="o">-</span> <span class="p">(</span><span class="n">xhat</span> <span class="o">*</span> <span class="n">mean1</span> <span class="o">+</span> <span class="n">mean2</span><span class="p">))</span> <span class="o">*</span> <span class="n">rstd</span>
<span class="c1"># write-back dx</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DX</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">dx</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># accumulate partial sums for dw/db</span>
<span class="n">partial_dw</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span> <span class="o">*</span> <span class="n">xhat</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">partial_db</span> <span class="o">=</span> <span class="p">(</span><span class="n">dy</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">while</span> <span class="n">tl</span><span class="o">.</span><span class="n">atomic_cas</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">pass</span>
<span class="n">count</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Count</span><span class="p">)</span>
<span class="c1"># first store doesn&#39;t accumulate</span>
<span class="k">if</span> <span class="n">count</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Count</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">partial_dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">partial_db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">partial_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">DB</span><span class="p">,</span> <span class="n">partial_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># release lock</span>
<span class="n">tl</span><span class="o">.</span><span class="n">atomic_xchg</span><span class="p">(</span><span class="n">Lock</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># Backward pass (total DW + total DB)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_layer_norm_bwd_dwdb</span><span class="p">(</span><span class="n">DW</span><span class="p">,</span> <span class="n">DB</span><span class="p">,</span> <span class="n">FINAL_DW</span><span class="p">,</span> <span class="n">FINAL_DB</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">:</span> <span class="n">tl</span><span class="o">.</span><span class="n">constexpr</span><span class="p">):</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cols</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
<span class="n">dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
<span class="n">rows</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">offs</span> <span class="o">=</span> <span class="n">rows</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">N</span> <span class="o">+</span> <span class="n">cols</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">dw</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DW</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">db</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">DB</span> <span class="o">+</span> <span class="n">offs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">sum_dw</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dw</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">sum_db</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DW</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_dw</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">FINAL_DB</span> <span class="o">+</span> <span class="n">cols</span><span class="p">,</span> <span class="n">sum_db</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">cols</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">LayerNorm</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
<span class="c1"># allocate output</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># reshape input data into 2D tensor</span>
<span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
<span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">rstd</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="c1"># Less than 64KB per feature: enqueue fused kernel</span>
<span class="n">MAX_FUSED_SIZE</span> <span class="o">=</span> <span class="mi">65536</span> <span class="o">//</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span>
<span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">MAX_FUSED_SIZE</span><span class="p">,</span> <span class="n">triton</span><span class="o">.</span><span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">))</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&gt;</span> <span class="n">BLOCK_SIZE</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;This layer norm doesn&#39;t support feature dim &gt;= 64KB.&quot;</span><span class="p">)</span>
<span class="c1"># heuristics for number of warps</span>
<span class="n">num_warps</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="n">BLOCK_SIZE</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">8</span><span class="p">)</span>
<span class="c1"># enqueue kernel</span>
<span class="n">_layer_norm_fwd_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">x_arg</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">,</span>
<span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="n">BLOCK_SIZE</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">rstd</span><span class="p">)</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">BLOCK_SIZE</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span> <span class="o">=</span> <span class="n">num_warps</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="k">return</span> <span class="n">y</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dy</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
<span class="c1"># heuristics for amount of parallel reduction stream for DG/DB</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">64</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">8192</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">96</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">128</span>
<span class="k">if</span> <span class="n">N</span> <span class="o">&lt;=</span> <span class="mi">1024</span><span class="p">:</span> <span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">256</span>
<span class="c1"># allocate output</span>
<span class="n">locks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">_dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">_db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">dy</span><span class="p">)</span>
<span class="c1"># enqueue kernel using forward pass heuristics</span>
<span class="c1"># also compute partial sums for DW and DB</span>
<span class="n">x_arg</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x_arg</span><span class="o">.</span><span class="n">shape</span>
<span class="n">_layer_norm_bwd_dx_fused</span><span class="p">[(</span><span class="n">M</span><span class="p">,)](</span><span class="n">dx</span><span class="p">,</span> <span class="n">dy</span><span class="p">,</span> <span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">locks</span><span class="p">,</span>
<span class="n">x_arg</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">N</span><span class="p">,</span> <span class="n">ctx</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
<span class="n">GROUP_SIZE_M</span><span class="o">=</span><span class="n">GROUP_SIZE_M</span><span class="p">,</span>
<span class="n">num_warps</span><span class="o">=</span><span class="n">ctx</span><span class="o">.</span><span class="n">num_warps</span><span class="p">)</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">[</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK_SIZE_N&#39;</span><span class="p">])]</span>
<span class="c1"># accumulate partial sums in separate kernel</span>
<span class="n">_layer_norm_bwd_dwdb</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">_dw</span><span class="p">,</span> <span class="n">_db</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_M</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">BLOCK_SIZE_N</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span>
<span class="k">return</span> <span class="n">dx</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dw</span><span class="p">,</span> <span class="n">db</span><span class="p">,</span> <span class="kc">None</span>
<span class="n">layer_norm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="o">.</span><span class="n">apply</span>
<span class="k">def</span> <span class="nf">test_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
<span class="c1"># create data</span>
<span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># forward pass</span>
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># backward pass (triton)</span>
<span class="n">y_tri</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">dx_tri</span><span class="p">,</span> <span class="n">dw_tri</span><span class="p">,</span> <span class="n">db_tri</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
<span class="n">x</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
<span class="c1"># backward pass (torch)</span>
<span class="n">y_ref</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">dx_ref</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">db_ref</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">]]</span>
<span class="c1"># compare</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dx_tri</span><span class="p">,</span> <span class="n">dx_ref</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">db_tri</span><span class="p">,</span> <span class="n">db_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">dw_tri</span><span class="p">,</span> <span class="n">dw_ref</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;N&#39;</span><span class="p">],</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">512</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">32</span><span class="p">)],</span>
<span class="n">line_arg</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span>
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="s1">&#39;torch&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;apex&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;Triton&#39;</span><span class="p">,</span> <span class="s1">&#39;Torch&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="p">([</span><span class="s1">&#39;Apex&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="n">HAS_APEX</span> <span class="k">else</span> <span class="p">[]),</span>
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">&#39;blue&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;orange&#39;</span><span class="p">,</span> <span class="s1">&#39;-&#39;</span><span class="p">)],</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s1">&#39;GB/s&#39;</span><span class="p">,</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s1">&#39;layer-norm-backward&#39;</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;M&#39;</span><span class="p">:</span> <span class="mi">4096</span><span class="p">,</span> <span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s1">&#39;mode&#39;</span><span class="p">:</span> <span class="s1">&#39;backward&#39;</span><span class="p">}</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">bench_layer_norm</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">provider</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;backward&#39;</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
<span class="c1"># create data</span>
<span class="n">x_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.3</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">x_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">dy</span> <span class="o">=</span> <span class="mf">.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># utility functions</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch&#39;</span><span class="p">:</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;apex&#39;</span><span class="p">:</span>
<span class="n">apex_layer_norm</span> <span class="o">=</span> <span class="n">apex</span><span class="o">.</span><span class="n">normalization</span><span class="o">.</span><span class="n">FusedLayerNorm</span><span class="p">(</span><span class="n">w_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">y_fwd</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">apex_layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># forward pass</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;forward&#39;</span><span class="p">:</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">y_fwd</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="c1"># backward pass</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;backward&#39;</span><span class="p">:</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y_fwd</span><span class="p">()</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">y</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="n">grad_to_none</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">],</span> <span class="n">rep</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
<span class="n">bench_layer_norm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">save_path</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 23.770 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-05-layer-norm-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">05-layer-norm.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">05-layer-norm.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -100,6 +100,7 @@
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
<li class="toctree-l2"><a class="reference internal" href="04-low-memory-dropout.html">Low-Memory Dropout</a></li>
<li class="toctree-l2"><a class="reference internal" href="05-layer-norm.html">Layer Normalization</a></li>
</ul>
</li>
</ul>
@@ -183,6 +184,11 @@
<div class="section" id="tutorials">
<span id="sphx-glr-getting-started-tutorials"></span><h1>Tutorials<a class="headerlink" href="#tutorials" title="Permalink to this headline"></a></h1>
<p>Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.</p>
<p>To install the dependencies for the tutorials:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span> triton
pip install -e <span class="s1">&#39;./python[tutorials]&#39;</span>
</pre></div>
</div>
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model of Triton - The triton.jit decorator, which is used to define Tri..."><div class="figure align-default" id="id1">
<img alt="Vector Addition" src="../../_images/sphx_glr_01-vector-add_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a></span><a class="headerlink" href="#id1" title="Permalink to this image"></a></p>
@@ -207,6 +213,12 @@
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-thumbcontainer" tooltip="Layer Normalization"><div class="figure align-default" id="id5">
<img alt="Layer Normalization" src="../../_images/sphx_glr_05-layer-norm_thumb.png" />
<p class="caption"><span class="caption-text"><a class="reference internal" href="05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">Layer Normalization</span></a></span><a class="headerlink" href="#id5" title="Permalink to this image"></a></p>
</div>
</div><div class="toctree-wrapper compound">
</div>
<div class="sphx-glr-clear"></div><div class="sphx-glr-footer class sphx-glr-footer-gallery docutils container">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tutorials_python.zip</span></code></a></p>

View File

@@ -174,7 +174,7 @@
<div class="section" id="computation-times">
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline"></a></h1>
<p><strong>10:59.734</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
<p><strong>12:42.913</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
<table class="docutils align-default">
<colgroup>
<col style="width: 85%" />
@@ -183,19 +183,23 @@
</colgroup>
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
<td><p>05:41.070</p></td>
<td><p>06:05.923</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
<td><p>03:26.243</p></td>
<td><p>03:22.431</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
<td><p>01:52.411</p></td>
<td><p>01:50.312</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-even"><td><p><a class="reference internal" href="04-low-memory-dropout.html#sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">Low-Memory Dropout</span></a> (<code class="docutils literal notranslate"><span class="pre">04-low-memory-dropout.py</span></code>)</p></td>
<td><p>00:00.010</p></td>
<tr class="row-even"><td><p><a class="reference internal" href="05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py"><span class="std std-ref">Layer Normalization</span></a> (<code class="docutils literal notranslate"><span class="pre">05-layer-norm.py</span></code>)</p></td>
<td><p>01:23.770</p></td>
<td><p>0.0 MB</p></td>
</tr>
<tr class="row-odd"><td><p><a class="reference internal" href="04-low-memory-dropout.html#sphx-glr-getting-started-tutorials-04-low-memory-dropout-py"><span class="std std-ref">Low-Memory Dropout</span></a> (<code class="docutils literal notranslate"><span class="pre">04-low-memory-dropout.py</span></code>)</p></td>
<td><p>00:00.477</p></td>
<td><p>0.0 MB</p></td>
</tr>
</tbody>

Binary file not shown.

View File

@@ -186,7 +186,7 @@
<h1>triton.Config<a class="headerlink" href="#triton-config" title="Permalink to this headline"></a></h1>
<dl class="py class">
<dt class="sig sig-object py" id="triton.Config">
<em class="property"><span class="pre">class</span> </em><span class="sig-prename descclassname"><span class="pre">triton.</span></span><span class="sig-name descname"><span class="pre">Config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">meta</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_warps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">4</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_stages</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">2</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.Config" title="Permalink to this definition"></a></dt>
<em class="property"><span class="pre">class</span> </em><span class="sig-prename descclassname"><span class="pre">triton.</span></span><span class="sig-name descname"><span class="pre">Config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kwargs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_warps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">4</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_stages</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pre_hook</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.Config" title="Permalink to this definition"></a></dt>
<dd><p>An object that represents a possible kernel configuration for the auto-tuner to try.</p>
<dl class="field-list simple">
<dt class="field-odd">Variables</dt>
@@ -197,12 +197,14 @@
cooperatively execute using <cite>8 * 32 = 256</cite> threads.</p></li>
<li><p><strong>num_stages</strong> the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.</p></li>
<li><p><strong>pre_hook</strong> a function that will be called before the kernel is called. Parameters of this
function are args.</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt class="sig sig-object py" id="triton.Config.__init__">
<span class="sig-name descname"><span class="pre">__init__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">meta</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_warps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">4</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_stages</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">2</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.Config.__init__" title="Permalink to this definition"></a></dt>
<span class="sig-name descname"><span class="pre">__init__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kwargs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_warps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">4</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_stages</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pre_hook</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.Config.__init__" title="Permalink to this definition"></a></dt>
<dd></dd></dl>
<p class="rubric">Methods</p>
@@ -212,7 +214,7 @@ Mostly useful for matrix multiplication workloads on SM80+ GPUs.</p></li>
<col style="width: 90%" />
</colgroup>
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="#triton.Config.__init__" title="triton.Config.__init__"><code class="xref py py-obj docutils literal notranslate"><span class="pre">__init__</span></code></a>(self, meta[, num_warps, num_stages])</p></td>
<tr class="row-odd"><td><p><a class="reference internal" href="#triton.Config.__init__" title="triton.Config.__init__"><code class="xref py py-obj docutils literal notranslate"><span class="pre">__init__</span></code></a>(self, kwargs[, num_warps, ])</p></td>
<td><p></p></td>
</tr>
</tbody>

View File

@@ -186,7 +186,7 @@
<h1>triton.autotune<a class="headerlink" href="#triton-autotune" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.autotune">
<span class="sig-prename descclassname"><span class="pre">triton.</span></span><span class="sig-name descname"><span class="pre">autotune</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">configs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reset_to_zero</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.autotune" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.</span></span><span class="sig-name descname"><span class="pre">autotune</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">configs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prune_configs_by</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reset_to_zero</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.autotune" title="Permalink to this definition"></a></dt>
<dd><p>Decorator for auto-tuning a <code class="code docutils literal notranslate"><span class="pre">triton.jit</span></code>d function.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span><span class="n">configs</span><span class="o">=</span><span class="p">[</span>
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">(</span><span class="n">meta</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;BLOCK_SIZE&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
@@ -211,6 +211,10 @@ reset the value of the provided tensor to <cite>zero</cite> before running any c
<dd class="field-even"><ul class="simple">
<li><p><strong>configs</strong> (<em>list</em><em>[</em><a class="reference internal" href="triton.Config.html#triton.Config" title="triton.Config"><em>triton.Config</em></a><em>]</em>) a list of <code class="code docutils literal notranslate"><span class="pre">triton.Config</span></code> objects</p></li>
<li><p><strong>key</strong> (<em>list</em><em>[</em><em>str</em><em>]</em>) a list of argument names whose change in value will trigger the evaluation of all provided configs.</p></li>
<li><p><strong>prune_configs_by</strong> a dict of functions that are used to prune configs, fields:
perf_model: performance model used to predicate running time with different configs, returns running time
top_k: number of configs to bench
prune_num_stages_by(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.</p></li>
<li><p><strong>reset_to_zero</strong> (<em>list</em><em>[</em><em>str</em><em>]</em>) a list of argument names whose value will be reset to zero before evaluating any configs.</p></li>
</ul>
</dd>

View File

@@ -197,14 +197,14 @@
<h1>triton.language.dot<a class="headerlink" href="#triton-language-dot" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.dot">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">dot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">dot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_tf32</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition"></a></dt>
<dd><p>Returns the matrix product of two blocks.</p>
<p>The two blocks must be two dimensionals and have compatible inner dimensions.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>input</strong> (2D block of scalar-type in {<code class="code docutils literal notranslate"><span class="pre">float16</span></code>, <code class="code docutils literal notranslate"><span class="pre">float32</span></code>}) The first block to be multiplied.</p></li>
<li><p><strong>other</strong> (2D block of scalar-type in {<code class="code docutils literal notranslate"><span class="pre">float16</span></code>, <code class="code docutils literal notranslate"><span class="pre">float32</span></code>}) The second block to be multiplied.</p></li>
<li><p><strong>input</strong> (2D block of scalar-type in {<code class="code docutils literal notranslate"><span class="pre">float16</span></code>, <code class="code docutils literal notranslate"><span class="pre">bfloat16</span></code>, <code class="code docutils literal notranslate"><span class="pre">float32</span></code>}) The first block to be multiplied.</p></li>
<li><p><strong>other</strong> (2D block of scalar-type in {<code class="code docutils literal notranslate"><span class="pre">float16</span></code>, <code class="code docutils literal notranslate"><span class="pre">bfloat16</span></code>, <code class="code docutils literal notranslate"><span class="pre">float32</span></code>}) The second block to be multiplied.</p></li>
</ul>
</dd>
</dl>

View File

@@ -200,7 +200,7 @@
<h1>triton.language.load<a class="headerlink" href="#triton-language-load" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.load">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">load</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.load" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">load</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">cache_modifier</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">volatile</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.load" title="Permalink to this definition"></a></dt>
<dd><p>Return a block of data whose values are, elementwise, loaded from memory at location defined by <code class="code docutils literal notranslate"><span class="pre">pointer</span></code>.</p>
<p><code class="code docutils literal notranslate"><span class="pre">mask</span></code> and <code class="code docutils literal notranslate"><span class="pre">other</span></code> are implicitly broadcast to <code class="code docutils literal notranslate"><span class="pre">pointer.shape</span></code>.</p>
<p><code class="code docutils literal notranslate"><span class="pre">other</span></code> is implicitly typecast to <code class="code docutils literal notranslate"><span class="pre">pointer.dtype.element_ty</span></code>.</p>
@@ -210,9 +210,11 @@
<li><p><strong>pointer</strong> (<em>Block of dtype=triton.PointerDType</em>) Pointers to the data to be loaded.</p></li>
<li><p><strong>mask</strong> (<em>Block of triton.int1</em><em>, </em><em>optional</em>) if mask[idx] is false, do not load the data at address <code class="code docutils literal notranslate"><span class="pre">pointer[idx]</span></code>.</p></li>
<li><p><strong>other</strong> (<em>Block</em><em>, </em><em>optional</em>) if mask[idx] is false, return other[idx]</p></li>
<li><p><strong>cache_modifier</strong> changes cache option in nvidia ptx</p></li>
</ul>
</dd>
</dl>
<p>type cache_modifier: str, optional</p>
</dd></dl>
</div>

View File

@@ -201,7 +201,7 @@
<h1>triton.language.rand<a class="headerlink" href="#triton-language-rand" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.rand">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">rand</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.rand" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">rand</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_rounds</span></span><span class="p"><span class="pre">:</span></span> <span class="n"><span class="pre">triton.language.core.constexpr</span></span> <span class="o"><span class="pre">=</span></span> <span class="default_value"><span class="pre">10</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.rand" title="Permalink to this definition"></a></dt>
<dd><p>Given a <code class="code docutils literal notranslate"><span class="pre">seed</span></code> scalar and an <code class="code docutils literal notranslate"><span class="pre">offset</span></code> block,
returns a block of random <code class="code docutils literal notranslate"><span class="pre">float32</span></code> in <span class="math notranslate nohighlight">\(U(0, 1)\)</span></p>
<dl class="field-list simple">

View File

@@ -200,7 +200,7 @@
<h1>triton.language.randint<a class="headerlink" href="#triton-language-randint" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.randint">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">randint</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.randint" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">randint</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_rounds</span></span><span class="p"><span class="pre">:</span></span> <span class="n"><span class="pre">triton.language.core.constexpr</span></span> <span class="o"><span class="pre">=</span></span> <span class="default_value"><span class="pre">10</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.randint" title="Permalink to this definition"></a></dt>
<dd><p>Given a <code class="code docutils literal notranslate"><span class="pre">seed</span></code> scalar and an <code class="code docutils literal notranslate"><span class="pre">offset</span></code> block, returns a single
block of random <code class="code docutils literal notranslate"><span class="pre">int32</span></code>.</p>
<p>If you need multiple streams of random numbers,

View File

@@ -200,7 +200,7 @@
<h1>triton.language.randint4x<a class="headerlink" href="#triton-language-randint4x" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.randint4x">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">randint4x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.randint4x" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">randint4x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_rounds</span></span><span class="p"><span class="pre">:</span></span> <span class="n"><span class="pre">triton.language.core.constexpr</span></span> <span class="o"><span class="pre">=</span></span> <span class="default_value"><span class="pre">10</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.randint4x" title="Permalink to this definition"></a></dt>
<dd><p>Given a <code class="code docutils literal notranslate"><span class="pre">seed</span></code> scalar and an <code class="code docutils literal notranslate"><span class="pre">offset</span></code> block, returns four
blocks of random <code class="code docutils literal notranslate"><span class="pre">int32</span></code>.</p>
<p>This is the maximally efficient entry point

View File

@@ -201,7 +201,7 @@
<h1>triton.language.randn<a class="headerlink" href="#triton-language-randn" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.randn">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">randn</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.randn" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">randn</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">offset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_rounds</span></span><span class="p"><span class="pre">:</span></span> <span class="n"><span class="pre">triton.language.core.constexpr</span></span> <span class="o"><span class="pre">=</span></span> <span class="default_value"><span class="pre">10</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.randn" title="Permalink to this definition"></a></dt>
<dd><p>Given a <code class="code docutils literal notranslate"><span class="pre">seed</span></code> scalar and an <code class="code docutils literal notranslate"><span class="pre">offset</span></code> block,
returns a block of random <code class="code docutils literal notranslate"><span class="pre">float32</span></code> in <span class="math notranslate nohighlight">\(\mathcal{N}(0, 1)\)</span></p>
<dl class="field-list simple">

View File

@@ -203,7 +203,7 @@
<h1>triton.language.softmax<a class="headerlink" href="#triton-language-softmax" title="Permalink to this headline"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="triton.language.softmax">
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">softmax</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.softmax" title="Permalink to this definition"></a></dt>
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">softmax</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ieee_rounding</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.softmax" title="Permalink to this definition"></a></dt>
<dd><p>Computes the element-wise softmax of <code class="code docutils literal notranslate"><span class="pre">x</span></code></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>

View File

@@ -47,7 +47,7 @@
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="triton.jit" href="generated/triton.jit.html" />
<link rel="prev" title="Low-Memory Dropout" href="../getting-started/tutorials/04-low-memory-dropout.html" />
<link rel="prev" title="Layer Normalization" href="../getting-started/tutorials/05-layer-norm.html" />
</head>
<body class="wy-body-for-nav">
@@ -211,7 +211,7 @@
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="generated/triton.jit.html" class="btn btn-neutral float-right" title="triton.jit" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="../getting-started/tutorials/04-low-memory-dropout.html" class="btn btn-neutral float-left" title="Low-Memory Dropout" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="../getting-started/tutorials/05-layer-norm.html" class="btn btn-neutral float-left" title="Layer Normalization" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>

File diff suppressed because one or more lines are too long