Files
triton/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt

425 lines
15 KiB
Plaintext
Raw Normal View History

2022-06-05 21:05:02 +00:00
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:
Layer Normalization
====================
2022-07-14 07:22:19 +00:00
.. GENERATED FROM PYTHON SOURCE LINES 5-316
2022-06-05 21:05:02 +00:00
.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
:alt: 05 layer norm
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
layer-norm:
N Triton Torch Apex
2022-08-07 00:51:30 +00:00
0 1024.0 585.142849 277.694907 481.882344
2022-07-17 00:49:40 +00:00
1 1536.0 630.153868 323.368435 511.999982
2022-08-07 00:51:30 +00:00
2 2048.0 668.734716 337.814445 520.126988
3 2560.0 694.237267 362.477870 518.481028
4 3072.0 702.171410 375.206126 501.551037
2022-08-06 00:49:10 +00:00
5 3584.0 725.873439 384.859062 455.111115
2022-08-07 00:51:30 +00:00
6 4096.0 728.177767 383.251446 451.972420
7 4608.0 670.254540 396.387087 421.302872
8 5120.0 688.403381 397.669909 422.268057
9 5632.0 704.000002 398.725657 413.357796
10 6144.0 702.171410 402.885254 411.313806
11 6656.0 700.631610 400.360920 400.360920
12 7168.0 690.891575 383.571898 381.023265
2022-08-06 00:49:10 +00:00
13 7680.0 678.895043 392.587863 386.415087
2022-08-07 00:51:30 +00:00
14 8192.0 636.271854 390.095241 375.564460
15 8704.0 627.315309 392.292962 380.502740
16 9216.0 609.322328 403.989025 381.023249
17 9728.0 587.350922 408.524944 382.427505
18 10240.0 566.920437 408.578556 382.803739
19 10752.0 547.872604 412.546760 379.761601
20 11264.0 531.634232 396.096702 369.311483
21 11776.0 521.927959 408.711507 378.345375
22 12288.0 514.007840 413.911572 383.251457
23 12800.0 504.433489 410.420828 377.163903
24 13312.0 494.180982 404.159395 376.310952
25 13824.0 481.882350 409.600016 378.739711
26 14336.0 471.967074 400.307157 369.961287
27 14848.0 461.297068 404.027214 375.304904
28 15360.0 454.269882 406.214870 378.092307
29 15872.0 447.098578 408.940410 376.783377
2022-06-05 21:05:02 +00:00
|
.. code-block:: default
import torch
import triton
import triton.language as tl
try:
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
# should not be added to extras_require in setup.py.
import apex
HAS_APEX = True
except ModuleNotFoundError:
HAS_APEX = False
@triton.jit
def _layer_norm_fwd_fused(
    Out,
    A,
    Weight,
    Bias,
    Mean, Rstd,
    stride, N, eps,
    BLOCK_SIZE: tl.constexpr,
):
    # Fused layer-norm forward pass; each program instance handles one row.
    #   Out, A       -- output / input pointers; row r starts at ptr + r * stride
    #   Weight, Bias -- per-column scale and shift, indexed by column
    #   Mean, Rstd   -- per-row outputs: the row mean and 1 / sqrt(var + eps)
    #   stride       -- element offset between consecutive rows of A and Out
    #   N            -- number of columns in a row
    #   eps          -- added to the variance before taking the square root
    #   BLOCK_SIZE   -- compile-time number of columns processed per iteration
    # position of elements processed by this program
    row = tl.program_id(0)
    Out += row * stride
    A += row * stride
    # compute mean: accumulate block-wise partial sums, then reduce once
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # compute variance (sum of squared deviations divided by N)
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
        # masked-out lanes were loaded as 0; force their deviation to 0 so
        # they do not contribute (0 - mean)^2 to the variance
        a = tl.where(cols < N, a - mean, 0.)
        _var += a * a
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # write-back mean/rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # normalize, then multiply by weight and add bias
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        weight = tl.load(Weight + cols, mask=mask)
        bias = tl.load(Bias + cols, mask=mask)
        a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
        # write-back
        tl.store(Out + cols, out, mask=mask)
# Backward pass (DA only; DW and DB are reduced by _layer_norm_bwd_dwdb)
@triton.jit
def _layer_norm_bwd_dx_fused(
    _DA,
    _DOut,
    _A,
    Weight,
    Mean, Rstd,
    stride, NumRows, NumCols, eps,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Input-gradient kernel; each program instance handles one row.
    #   _DA          -- output pointer for the gradient w.r.t. the input
    #   _DOut        -- pointer to the gradient w.r.t. the output
    #   _A           -- pointer to the forward input
    #   Weight       -- per-column scale, indexed by column
    #   Mean, Rstd   -- per-row statistics, indexed by row
    #   stride       -- element offset between consecutive rows; the same
    #                   value is applied to _A, _DOut and _DA
    #   NumCols      -- number of columns in a row
    #   BLOCK_SIZE_N -- compile-time number of columns processed per iteration
    # NumRows and eps are accepted but not read by this kernel.
    # position of elements processed by this program
    pid = tl.program_id(0)
    row = pid
    A = _A + row * stride
    DOut = _DOut + row * stride
    DA = _DA + row * stride
    mean = tl.load(Mean + row)
    rstd = tl.load(Rstd + row)
    # first sweep: accumulate the two row-wise sums needed by the gradient
    _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
    _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
    for off in range(0, NumCols, BLOCK_SIZE_N):
        cols = off + tl.arange(0, BLOCK_SIZE_N)
        mask = cols < NumCols
        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
        a_hat = (a - mean) * rstd
        wdout = weight * dout
        # masked lanes have weight == dout == 0, so wdout is 0 there and
        # they add nothing to either accumulator
        _mean1 += a_hat * wdout
        _mean2 += wdout
    # mean1 = mean over columns of a_hat * weight * dout
    mean1 = tl.sum(_mean1, axis=0) / NumCols
    # mean2 = mean over columns of weight * dout
    mean2 = tl.sum(_mean2, axis=0) / NumCols
    # second sweep: recompute the per-element terms and emit the gradient
    for off in range(0, NumCols, BLOCK_SIZE_N):
        cols = off + tl.arange(0, BLOCK_SIZE_N)
        mask = cols < NumCols
        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
        a_hat = (a - mean) * rstd
        wdout = weight * dout
        da = (wdout - (a_hat * mean1 + mean2)) * rstd
        # write-back dx
        tl.store(DA + cols, da, mask=mask)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
    A, DOut,
    Mean, Var,
    DW,
    DB,
    M, N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Weight/bias-gradient kernel; each program instance reduces one strip of
    # BLOCK_SIZE_N columns over all M rows.
    #   A, DOut  -- pointers to the forward input and the output gradient;
    #               element (r, c) is read at offset r * N + c, i.e. both are
    #               addressed as dense row-major (M, N) buffers
    #   Mean     -- per-row mean, indexed by row
    #   Var      -- per-row value used as the reciprocal standard deviation
    #               (it multiplies the centered input below)
    #   DW, DB   -- output pointers for the weight / bias gradients, indexed
    #               by column
    #   M, N     -- number of rows / columns
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    UNROLL: tl.constexpr = 4
    # walk the rows in steps of BLOCK_SIZE_M * UNROLL; the inner loop covers
    # the UNROLL consecutive row tiles inside each step
    for i in range(0, M, BLOCK_SIZE_M * UNROLL):
        for j in range(UNROLL):
            rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            offs = rows[:, None] * N + cols[None, :]
            a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
            dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
            mean = tl.load(Mean + rows, mask=rows < M, other=0.)
            rstd = tl.load(Var + rows, mask=rows < M, other=0.)
            a_hat = (a - mean[:, None]) * rstd[:, None]
            # out-of-range elements have dout == 0 and contribute nothing
            dw += dout * a_hat
            db += dout
    # collapse the row dimension of the tile accumulators
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(DW + cols, sum_dw, mask=cols < N)
    tl.store(DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
    """Autograd function that runs layer normalization through the Triton kernels."""

    @staticmethod
    def forward(ctx, a, normalized_shape, weight, bias, eps):
        """Normalize ``a`` over its last dimension, then scale and shift.

        Args:
            ctx: autograd context; receives the saved tensors and launch
                parameters that ``backward`` reuses.
            a: input tensor; all leading dimensions are flattened into rows.
            normalized_shape: accepted for signature parity with
                ``torch.nn.functional.layer_norm``; not read here.
            weight: per-column scale.
            bias: per-column shift.
            eps: value added to the variance before the square root.

        Returns:
            A tensor with the shape and dtype of ``a``.
        """
        # View the input as a dense (M, N) matrix. The kernel applies a single
        # row stride to both input and output, so both must share one layout;
        # `.contiguous()` is a no-op when `a` is already contiguous.
        a_arg = a.reshape(-1, a.shape[-1]).contiguous()
        M, N = a_arg.shape
        # allocate output (contiguous, so its row stride equals a_arg's)
        out = torch.empty(a.shape, dtype=a.dtype, device=a.device)
        # per-row statistics are allocated on the input's device rather than
        # on whichever CUDA device happens to be current
        mean = torch.empty((M,), dtype=torch.float32, device=a.device)
        rstd = torch.empty((M,), dtype=torch.float32, device=a.device)
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // a.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        # clamp the block size to the range [128, 4096]
        BLOCK_SIZE = max(BLOCK_SIZE, 128)
        BLOCK_SIZE = min(BLOCK_SIZE, 4096)
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # one program instance per row
        _layer_norm_fwd_fused[(M,)](
            out,
            a_arg,
            weight,
            bias,
            mean, rstd,
            a_arg.stride(0), N, eps,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        ctx.save_for_backward(
            a, weight, bias, mean, rstd,
        )
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.eps = eps
        # Record the gradient-scale name when the parameters carry a `config`
        # attribute; weight and bias are required to agree on it.
        if hasattr(bias, "config"):
            assert bias.config.grad_scale_name == weight.config.grad_scale_name
            grad_scale_name = bias.config.grad_scale_name
        else:
            grad_scale_name = None
        ctx.grad_scale_gain_bias_name = grad_scale_name
        return out

    @staticmethod
    def backward(ctx, dout):
        """Compute gradients for the input, weight and bias.

        Args:
            ctx: autograd context populated by ``forward``.
            dout: gradient of the loss with respect to the forward output.

        Returns:
            ``(da, None, dweight, dbias, None)``, one entry per ``forward``
            argument; ``normalized_shape`` and ``eps`` receive no gradient.
        """
        # The kernels address dout as a dense row-major buffer; make that true
        # instead of failing on a strided upstream gradient.
        dout = dout.contiguous()
        a, weight, _, mean, rstd = ctx.saved_tensors
        # allocate output
        da = torch.empty_like(dout)
        # Same dense (M, N) view as in the forward pass, so the single row
        # stride handed to the kernel is valid for x_arg, dout and da alike.
        x_arg = a.reshape(-1, a.shape[-1]).contiguous()
        M, N = x_arg.shape
        dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
        dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
        # enqueue the input-gradient kernel using forward pass heuristics
        _layer_norm_bwd_dx_fused[(M,)](
            da,
            dout,
            x_arg,
            weight,
            mean, rstd,
            x_arg.stride(0), M, N,
            ctx.eps,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE,
            num_warps=ctx.num_warps,
        )
        # heuristics for the tile shape of the DW/DB reduction
        if N > 10240:
            BLOCK_SIZE_N = 128
            BLOCK_SIZE_M = 32
            num_warps = 4
        else:
            # maximize occupancy for small N
            BLOCK_SIZE_N = 16
            BLOCK_SIZE_M = 16
            num_warps = 8
        # one program instance per strip of BLOCK_SIZE_N columns
        grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
        _layer_norm_bwd_dwdb[grid](
            x_arg, dout,
            mean, rstd,
            dweight,
            dbias,
            M,
            N,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            num_warps=num_warps
        )
        return (da, None, dweight, dbias, None)
2022-06-05 21:05:02 +00:00
def layer_norm(a, normalized_shape, weight, bias, eps):
    """Functional entry point: forward every argument to ``LayerNorm.apply``."""
    call_args = (a, normalized_shape, weight, bias, eps)
    return LayerNorm.apply(*call_args)
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
    """Check the Triton layer norm against ``torch.nn.functional.layer_norm``.

    Compares the forward output and the gradients with respect to the input,
    weight and bias on a random ``(M, N)`` problem.

    Args:
        M: number of rows of the test input.
        N: number of columns (the normalized dimension).
        dtype: dtype of the input and parameters.
        eps: epsilon passed to both implementations.
        device: device on which all test tensors are created.
    """
    torch.manual_seed(0)
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # forward pass
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # backward pass (triton)
    y_tri.backward(dy, retain_graph=True)
    dx_tri, dw_tri, db_tri = [t.grad.clone() for t in [x, weight, bias]]
    # reset the accumulated gradients so the reference pass starts from zero
    x.grad, weight.grad, bias.grad = None, None, None
    # backward pass (torch)
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref, db_ref = [t.grad.clone() for t in [x, weight, bias]]
    # compare
    triton.testing.assert_almost_equal(y_tri, y_ref)
    triton.testing.assert_almost_equal(dx_tri, dx_ref)
    triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
    triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[512 * i for i in range(2, 32)],
        line_arg='provider',
        line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
        line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
        ylabel='GB/s',
        plot_name='layer-norm',
        args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
    )
)
def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
    """Benchmark one layer-norm implementation on an ``(M, N)`` input.

    Args:
        M: number of rows.
        N: number of columns (the normalized dimension).
        dtype: dtype of the input and parameters.
        provider: ``'triton'``, ``'torch'`` or ``'apex'``.
        mode: ``'forward'`` or ``'backward'``.
        eps: epsilon passed to the Triton and Torch implementations.
        device: device on which the benchmark tensors are created.

    Returns:
        Three throughput figures in GB/s, derived in this order from the
        first, third and second value returned by ``triton.testing.do_bench``
        (unpacked here as ``ms``, ``max_ms`` and ``min_ms``).

    Raises:
        ValueError: if ``provider`` or ``mode`` is not one of the values above.
    """
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # utility functions
    if provider == 'triton':
        y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
    elif provider == 'torch':
        y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
    elif provider == 'apex':
        apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
        y_fwd = lambda: apex_layer_norm(x)
    else:
        # fail with a clear message instead of a NameError on y_fwd below
        raise ValueError("unknown provider: {!r}".format(provider))
    # forward pass
    if mode == 'forward':
        # traffic model: 2 tensors the size of x
        gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
    # backward pass
    elif mode == 'backward':
        # traffic model: 3 tensors the size of x
        gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
        y = y_fwd()
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                     grad_to_none=[x], rep=500)
    else:
        # fail with a clear message instead of a NameError on gbps below
        raise ValueError("unknown mode: {!r}".format(mode))
    # the slowest time gives the lowest throughput, hence max_ms before min_ms
    return gbps(ms), gbps(max_ms), gbps(min_ms)
# Optional correctness check, disabled by default; uncomment to run it first:
# test_layer_norm(1151, 8192, torch.float16)
# Run the benchmark, print the table and save the results to the current directory.
bench_layer_norm.run(print_data=True, save_path='.')
.. rst-class:: sphx-glr-timing
2022-08-07 00:51:30 +00:00
**Total running time of the script:** ( 5 minutes 29.729 seconds)
2022-06-05 21:05:02 +00:00
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
.. only:: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_