Files
triton/v1.1.2/_sources/getting-started/tutorials/05-layer-norm.rst.txt
2022-04-27 00:47:37 +00:00

361 lines
14 KiB
ReStructuredText

.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:
Layer Normalization
====================
.. GENERATED FROM PYTHON SOURCE LINES 5-252
.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
:alt: 05 layer norm
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
layer-norm-backward:
N Triton Torch Apex
0 1024.0 307.200008 99.497980 315.076934
1 1536.0 347.773587 133.083026 344.523365
2 2048.0 423.724127 158.045011 334.367350
3 2560.0 461.954908 182.314537 330.322572
4 3072.0 515.580429 191.501303 316.429186
5 3584.0 551.384634 207.768111 310.527060
6 4096.0 564.965515 219.919464 298.796351
7 4608.0 498.162157 231.849059 290.267724
8 5120.0 529.655159 243.809526 289.811322
9 5632.0 540.671974 244.426754 291.310338
10 6144.0 546.133354 251.202731 288.000001
11 6656.0 534.260858 256.000009 286.793541
12 7168.0 516.612607 253.734520 277.470965
13 7680.0 490.212752 266.358392 284.444450
14 8192.0 463.698115 258.694729 277.694924
15 8704.0 416.127506 267.130429 285.767450
16 9216.0 430.319054 273.066667 289.129410
17 9728.0 439.683593 280.278512 289.308559
18 10240.0 446.836366 286.767793 287.438599
19 10752.0 432.241202 246.699797 289.941565
20 11264.0 429.104745 246.432094 287.897767
21 11776.0 422.457417 249.667843 289.573776
22 12288.0 420.102570 254.673582 294.323369
23 12800.0 414.016170 253.884294 290.359162
24 13312.0 410.652963 253.360814 289.391298
25 13824.0 406.588243 257.190689 292.056329
26 14336.0 395.475867 255.619613 288.402346
27 14848.0 385.245405 257.015513 288.544136
28 15360.0 376.932517 258.513318 288.902809
29 15872.0 369.474279 261.626369 289.679087
|
.. code-block:: default
import torch
import triton.language as tl
import triton
# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
Y += row * stride
# load data and cast to float32
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
# compute mean
mean = tl.sum(x, axis=0) / N
# compute std
xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean*rstd
# write-back mean/rstd
tl.store(M + row, mean)
tl.store(V + row, rstd)
# multiply by weight and add bias
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
y = xhat * w + b
# write-back
tl.store(Y + cols, y, mask=mask)
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
stride, N, eps,
**META):
GROUP_SIZE_M = META['GROUP_SIZE_M']
BLOCK_SIZE_N = META['BLOCK_SIZE_N']
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
# these buffers stay in the L2, which allow this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id*N + cols
DB = DB + lock_id*N + cols
# load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean)*rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy , 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat*mean1 + mean2))*rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy*xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# first store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
pid = tl.program_id(0)
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None]*N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols<N)
tl.store(FINAL_DB + cols, sum_db, mask=cols<N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M = 32,
BLOCK_SIZE_N = 128)
return dx, None, dw, db, None
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
x.grad, weight.grad, bias.grad = None, None, None
# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
triton.testing.assert_almost_equal(y_tri, y_ref)
triton.testing.assert_almost_equal(dx_tri, dx_ref)
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[512 * i for i in range(2, 32)],
line_arg='provider',
line_vals=['triton', 'torch', 'apex'],
line_names=['Triton', 'Torch', 'Apex'],
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
if provider == 'torch':
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
if provider == 'apex':
import apex
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True)
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 2 minutes 10.422 seconds)
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_