.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:


Layer Normalization
====================
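
For reference, the fused kernels below implement the standard layer
normalization: for each row :math:`x \in \mathbb{R}^{N}` with per-feature
weight :math:`w`, bias :math:`b` and a small constant :math:`\epsilon`,

.. math::

   \hat{x} = \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}}, \qquad
   y = \hat{x} \cdot w + b,

where :math:`\mu` and :math:`\sigma^{2}` are the mean and variance of the row.
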
.. GENERATED FROM PYTHON SOURCE LINES 5-252


.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
    :alt: 05 layer norm
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    layer-norm-backward:
               N      Triton       Torch        Apex
     0    1024.0  311.088617   98.303995  303.407414
     1    1536.0  351.085717  133.083026  341.333333
     2    2048.0  420.102553  161.684218  334.367350
     3    2560.0  461.954908  182.314537  326.808501
     4    3072.0  511.999982  191.501303  317.793096
     5    3584.0  551.384634  207.768111  309.410081
     6    4096.0  568.231237  219.919464  296.096389
     7    4608.0  498.162157  232.336141  286.507772
     8    5120.0  525.128191  242.366855  284.444444
     9    5632.0  538.517949  243.107920  290.060087
    10    6144.0  542.117638  248.242431  286.322318
    11    6656.0  527.207907  256.000009  285.767438
    12    7168.0  507.469040  261.446807  287.678923
    13    7680.0  482.513091  261.076480  278.429013
    14    8192.0  461.521112  268.223740  287.438585
    15    8704.0  417.791980  267.130429  284.987724
    16    9216.0  430.319054  272.059034  288.751954
    17    9728.0  438.857162  280.278512  289.308559
    18   10240.0  445.217381  286.433562  290.496460
    19   10752.0  427.940303  246.464170  290.267711
    20   11264.0  428.424741  245.091565  286.676558
    21   11776.0  421.826879  249.447482  288.391833
    22   12288.0  419.504980  254.453844  294.617366
    23   12800.0  414.016170  253.674644  289.538159
    24   13312.0  412.242569  252.360194  289.653667
    25   13824.0  405.594132  256.991469  291.543045
    26   14336.0  397.761846  254.297107  286.481278
    27   14848.0  386.498925  257.479779  289.717061
    28   15360.0  376.932517  257.970599  288.000007
    29   15872.0  366.629453  262.708969  291.229369



|

.. code-block:: default


    import torch

    import triton
    import triton.language as tl


    # Forward Pass
    @triton.jit
    def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
        BLOCK_SIZE = META['BLOCK_SIZE']
        # position of elements processed by this program
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # offset data pointers to start at the row of interest
        X += row * stride
        Y += row * stride
        # load data and cast to float32
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        # compute mean
        mean = tl.sum(x, axis=0) / N
        # compute variance and reciprocal standard deviation
        xmean = tl.where(mask, x - mean, 0.)
        var = tl.sum(xmean * xmean, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        xhat = xmean * rstd
        # write-back mean/rstd
        tl.store(M + row, mean)
        tl.store(V + row, rstd)
        # multiply by weight and add bias
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        y = xhat * w + b
        # write-back
        tl.store(Y + cols, y, mask=mask)


    # Backward pass (DX + partial DW + partial DB)
    @triton.jit
    def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
                                 stride, N, eps,
                                 **META):
        GROUP_SIZE_M = META['GROUP_SIZE_M']
        BLOCK_SIZE_N = META['BLOCK_SIZE_N']
        # position of elements processed by this program
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_SIZE_N)
        mask = cols < N
        # offset data pointers to start at the row of interest
        X += row * stride
        DY += row * stride
        DX += row * stride
        # offset locks and weight/bias gradient pointers
        # each kernel instance accumulates partial sums for
        # DW and DB into one of GROUP_SIZE_M independent buffers
        # these buffers stay in L2 cache, which allows this kernel
        # to be fast
        lock_id = row % GROUP_SIZE_M
        Lock += lock_id
        Count = Lock + GROUP_SIZE_M
        DW = DW + lock_id * N + cols
        DB = DB + lock_id * N + cols
        # load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        mean = tl.load(M + row)
        rstd = tl.load(V + row)
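        # a sketch of the standard layer-norm gradient that the next few
        # lines implement: with xhat = (x - mean) * rstd and wdy = w * dy,
        #   dx = rstd * (wdy - xhat * mean(xhat * wdy) - mean(wdy))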
        # compute dx
        xhat = (x - mean) * rstd
        wdy = w * dy
        xhat = tl.where(mask, xhat, 0.)
        wdy = tl.where(mask, wdy, 0.)
        mean1 = tl.sum(xhat * wdy, axis=0) / N
        mean2 = tl.sum(wdy, axis=0) / N
        dx = (wdy - (xhat * mean1 + mean2)) * rstd
        # write-back dx
        tl.store(DX + cols, dx, mask=mask)
        # accumulate partial sums for dw/db
        partial_dw = (dy * xhat).to(w.dtype)
        partial_db = (dy).to(w.dtype)
        while tl.atomic_cas(Lock, 0, 1) == 1:
            pass
        count = tl.load(Count)
        # first store doesn't accumulate
        if count == 0:
            tl.atomic_xchg(Count, 1)
        else:
            partial_dw += tl.load(DW, mask=mask)
            partial_db += tl.load(DB, mask=mask)
        tl.store(DW, partial_dw, mask=mask)
        tl.store(DB, partial_db, mask=mask)
        # release lock
        tl.atomic_xchg(Lock, 0)


    # Backward pass (total DW + total DB)
    @triton.jit
    def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
        pid = tl.program_id(0)
        BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
        BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
        cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for i in range(0, M, BLOCK_SIZE_M):
            rows = i + tl.arange(0, BLOCK_SIZE_M)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            offs = rows[:, None] * N + cols[None, :]
            dw += tl.load(DW + offs, mask=mask, other=0.)
            db += tl.load(DB + offs, mask=mask, other=0.)
        sum_dw = tl.sum(dw, axis=0)
        sum_db = tl.sum(db, axis=0)
        tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
        tl.store(FINAL_DB + cols, sum_db, mask=cols < N)


    class LayerNorm(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, normalized_shape, weight, bias, eps):
            # allocate output
            y = torch.empty_like(x)
            # reshape input data into 2D tensor
            x_arg = x.reshape(-1, x.shape[-1])
            M, N = x_arg.shape
            mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
            rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
            # Less than 64KB per feature: enqueue fused kernel
            MAX_FUSED_SIZE = 65536 // x.element_size()
            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
            if N > BLOCK_SIZE:
                raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
            # heuristics for number of warps
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
            # enqueue kernel
            _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
                                        x_arg.stride(0), N, eps,
                                        BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
            ctx.save_for_backward(x, weight, bias, mean, rstd)
            ctx.BLOCK_SIZE = BLOCK_SIZE
            ctx.num_warps = num_warps
            ctx.eps = eps
            return y

        @staticmethod
        def backward(ctx, dy):
            x, w, b, m, v = ctx.saved_tensors
            # heuristics for the number of parallel reduction streams for DW/DB
            N = w.shape[0]
            GROUP_SIZE_M = 64
            if N <= 8192: GROUP_SIZE_M = 96
            if N <= 4096: GROUP_SIZE_M = 128
            if N <= 1024: GROUP_SIZE_M = 256
            # allocate output
            locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
            _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
            _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
            dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
            db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
            dx = torch.empty_like(dy)
            # enqueue kernel using forward pass heuristics
            # also compute partial sums for DW and DB
            x_arg = x.reshape(-1, x.shape[-1])
            M, N = x_arg.shape
            _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
                                           x_arg.stride(0), N, ctx.eps,
                                           BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                           GROUP_SIZE_M=GROUP_SIZE_M,
                                           num_warps=ctx.num_warps)
            grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
            # accumulate partial sums in separate kernel
            _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
                                       BLOCK_SIZE_M=32,
                                       BLOCK_SIZE_N=128)
            return dx, None, dw, db, None


    layer_norm = LayerNorm.apply
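
    # A minimal usage sketch with hypothetical shapes (assumes a CUDA device;
    # see test_layer_norm below for the full comparison against PyTorch):
    #   x = torch.randn(4, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
    #   w = torch.ones(1024, dtype=torch.float16, device='cuda', requires_grad=True)
    #   b = torch.zeros(1024, dtype=torch.float16, device='cuda', requires_grad=True)
    #   y = layer_norm(x, (1024,), w, b, 1e-5)
    #   y.backward(torch.ones_like(y))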


    def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # forward pass
        y_tri = layer_norm(x, w_shape, weight, bias, eps)
        y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
        # backward pass (triton)
        y_tri.backward(dy, retain_graph=True)
        dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
        x.grad, weight.grad, bias.grad = None, None, None
        # backward pass (torch)
        y_ref.backward(dy, retain_graph=True)
        dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
        # compare
        triton.testing.assert_almost_equal(y_tri, y_ref)
        triton.testing.assert_almost_equal(dx_tri, dx_ref)
        triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
        triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],
            x_vals=[512 * i for i in range(2, 32)],
            line_arg='provider',
            line_vals=['triton', 'torch', 'apex'],
            line_names=['Triton', 'Torch', 'Apex'],
            styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
            ylabel='GB/s',
            plot_name='layer-norm-backward',
            args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
        )
    )
    def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # utility functions
        if provider == 'triton':
            y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
        if provider == 'torch':
            y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
        if provider == 'apex':
            import apex
            apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
            y_fwd = lambda: apex_layer_norm(x)
        # forward pass
        if mode == 'forward':
            gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
            ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
        # backward pass
        if mode == 'backward':
            gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
            y = y_fwd()
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                         grad_to_none=[x], rep=500)
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    bench_layer_norm.run(save_path='.', print_data=True)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 12.961 seconds)


.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_