.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:


Layer Normalization
====================

.. GENERATED FROM PYTHON SOURCE LINES 5-252



.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
    :alt: 05 layer norm
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    layer-norm-backward:
              N      Triton       Torch        Apex
    0    1024.0  307.200008   98.303995  303.407414
    1    1536.0  351.085717  133.083026  341.333333
    2    2048.0  420.102553  161.684218  334.367350
    3    2560.0  461.954908  182.314537  326.808501
    4    3072.0  511.999982  191.501303  317.793096
    5    3584.0  551.384634  207.768111  309.410081
    6    4096.0  568.231237  219.919464  296.096389
    7    4608.0  498.162157  232.336141  286.507772
    8    5120.0  525.128191  242.366855  284.444444
    9    5632.0  538.517949  243.107920  289.438969
    10   6144.0  540.131844  248.242431  285.767458
    11   6656.0  527.207907  256.000009  285.767438
    12   7168.0  507.469040  261.446807  287.678923
    13   7680.0  482.513091  261.446804  278.850215
    14   8192.0  461.521112  267.858310  287.018988
    15   8704.0  417.791980  267.130429  284.987724
    16   9216.0  430.319054  272.059034  288.563595
    17   9728.0  439.683593  280.278512  289.308559
    18  10240.0  445.217381  286.433562  291.184826
    19  10752.0  427.940303  246.699797  290.267711
    20  11264.0  428.424741  245.091565  286.372873
    21  11776.0  421.826879  249.447482  288.391833
    22  12288.0  419.504980  254.673582  294.911986
    23  12800.0  414.016170  253.674644  288.180121
    24  13312.0  412.242569  252.360194  289.653667
    25  13824.0  405.594132  257.190689  291.543045
    26  14336.0  397.761846  254.109315  286.481278
    27  14848.0  386.498925  257.293872  289.952797
    28  15360.0  376.932517  257.970599  287.550706
    29  15872.0  366.982663  262.708969  291.229369





|

.. code-block:: default


    import torch
    import triton.language as tl
    import triton


    # Forward Pass
    # One program instance normalizes one row: X and Y are advanced by
    # `row * stride`, W and B are indexed by column, and the per-row mean
    # and reciprocal standard deviation are written to M[row] and V[row].
    @triton.jit
    def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
        BLOCK_SIZE = META['BLOCK_SIZE']
        # position of elements processed by this program
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # offset data pointers to start at the row of interest
        X += row * stride
        Y += row * stride
        # load data and cast to float32
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        # compute mean
        mean = tl.sum(x, axis=0) / N
        # compute std
        # (masked-out lanes are forced to 0 so they do not contribute
        # `-mean` terms to the variance)
        xmean = tl.where(mask, x - mean, 0.)
        var = tl.sum(xmean * xmean, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        xhat = xmean * rstd
        # write-back mean/rstd
        tl.store(M + row, mean)
        tl.store(V + row, rstd)
        # multiply by weight and add bias
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        y = xhat * w + b
        # write-back
        tl.store(Y + cols, y, mask=mask)


    # Backward pass (DX + partial DW + partial DB)
    # One program instance handles one row. `Lock` points at a buffer whose
    # first GROUP_SIZE_M entries are used as spin locks and whose next
    # GROUP_SIZE_M entries are used as "already written" flags.
    # `B` and `eps` are accepted but not read by this kernel.
    @triton.jit
    def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps, **META):
        GROUP_SIZE_M = META['GROUP_SIZE_M']
        BLOCK_SIZE_N = META['BLOCK_SIZE_N']
        # position of elements processed by this program
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_SIZE_N)
        mask = cols < N
        # offset data pointers to start at the row of interest
        X += row * stride
        DY += row * stride
        DX += row * stride
        # offset locks and weight/bias gradient pointer
        # each kernel instance accumulates partial sums for
        # DW and DB into one of GROUP_SIZE_M independent buffers
        # these buffers stay in the L2, which allows this kernel
        # to be fast
        lock_id = row % GROUP_SIZE_M
        Lock += lock_id
        Count = Lock + GROUP_SIZE_M
        DW = DW + lock_id * N + cols
        DB = DB + lock_id * N + cols
        # load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        mean = tl.load(M + row)
        rstd = tl.load(V + row)
        # compute dx
        xhat = (x - mean) * rstd
        wdy = w * dy
        xhat = tl.where(mask, xhat, 0.)
        wdy = tl.where(mask, wdy, 0.)
        mean1 = tl.sum(xhat * wdy, axis=0) / N
        mean2 = tl.sum(wdy, axis=0) / N
        dx = (wdy - (xhat * mean1 + mean2)) * rstd
        # write-back dx
        tl.store(DX + cols, dx, mask=mask)
        # accumulate partial sums for dw/db
        partial_dw = (dy * xhat).to(w.dtype)
        partial_db = (dy).to(w.dtype)
        # spin until this program owns the lock of its buffer
        while tl.atomic_cas(Lock, 0, 1) == 1:
            pass
        count = tl.load(Count)
        # first store doesn't accumulate
        # (the partial buffers are allocated uninitialized on the host)
        if count == 0:
            tl.atomic_xchg(Count, 1)
        else:
            partial_dw += tl.load(DW, mask=mask)
            partial_db += tl.load(DB, mask=mask)
        tl.store(DW, partial_dw, mask=mask)
        tl.store(DB, partial_db, mask=mask)
        # release lock
        tl.atomic_xchg(Lock, 0)


    # Backward pass (total DW + total DB)
    # Each program instance reduces a BLOCK_SIZE_N-wide column slice of the
    # (M, N) partial-sum buffers DW/DB over all M rows.
    @triton.jit
    def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
        pid = tl.program_id(0)
        BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
        BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
        cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for i in range(0, M, BLOCK_SIZE_M):
            rows = i + tl.arange(0, BLOCK_SIZE_M)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            offs = rows[:, None] * N + cols[None, :]
            dw += tl.load(DW + offs, mask=mask, other=0.)
            db += tl.load(DB + offs, mask=mask, other=0.)
        sum_dw = tl.sum(dw, axis=0)
        sum_db = tl.sum(db, axis=0)
        # NOTE(review): everything from the `< N` of the next mask down to
        # the `if N > BLOCK_SIZE` check in LayerNorm.forward was lost from
        # this page (text between '<' and '>' was stripped). It is
        # reconstructed here from the names the surviving code uses --
        # confirm against the source tutorial.
        tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
        tl.store(FINAL_DB + cols, sum_db, mask=cols < N)


    class LayerNorm(torch.autograd.Function):
        """Layer normalization over the last dimension, backed by the kernels above."""

        @staticmethod
        def forward(ctx, x, normalized_shape, weight, bias, eps):
            """Normalize each row of `x` and apply `weight` and `bias`.

            `normalized_shape` is accepted to mirror the call made in
            `test_layer_norm` but is not read here. Raises RuntimeError when
            the last dimension does not fit in a single block.
            """
            # allocate output
            y = torch.empty_like(x)
            # reshape input data into 2D tensor
            x_arg = x.reshape(-1, x.shape[-1])
            M, N = x_arg.shape
            mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
            rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
            # Less than 64KB per feature: enqueue fused kernel
            # NOTE(review): the 64KB bound is inferred from the error
            # message below -- confirm against the source tutorial.
            MAX_FUSED_SIZE = 65536 // x.element_size()
            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
            if N > BLOCK_SIZE:
                raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
            # heuristics for number of warps
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
            # enqueue kernel
            _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
                                        x_arg.stride(0), N, eps,
                                        BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
            # keep what backward() needs: inputs, statistics and launch config
            ctx.save_for_backward(x, weight, bias, mean, rstd)
            ctx.BLOCK_SIZE = BLOCK_SIZE
            ctx.num_warps = num_warps
            ctx.eps = eps
            return y

        @staticmethod
        def backward(ctx, dy):
            """Return gradients for (x, normalized_shape, weight, bias, eps).

            Only x, weight and bias receive a gradient; the other two slots
            are None.
            """
            x, w, b, m, v = ctx.saved_tensors
            # heuristics for amount of parallel reduction stream for DW/DB
            N = w.shape[0]
            if N <= 1024:
                GROUP_SIZE_M = 256
            elif N <= 4096:
                GROUP_SIZE_M = 128
            elif N <= 8192:
                GROUP_SIZE_M = 96
            else:
                GROUP_SIZE_M = 64
            # allocate output
            # locks holds GROUP_SIZE_M locks followed by GROUP_SIZE_M
            # first-write flags; it must start zeroed. It lives on the same
            # device as the other buffers rather than a hard-coded 'cuda'.
            locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
            _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
            _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
            dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
            db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
            dx = torch.empty_like(dy)
            # enqueue kernel using forward pass heuristics
            # also compute partial sums for DW and DB
            x_arg = x.reshape(-1, x.shape[-1])
            M, N = x_arg.shape
            _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
                                           x_arg.stride(0), N, ctx.eps,
                                           BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                           GROUP_SIZE_M=GROUP_SIZE_M,
                                           num_warps=ctx.num_warps)
            grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
            # accumulate partial sums in separate kernel
            _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
                                       BLOCK_SIZE_M=32,
                                       BLOCK_SIZE_N=128)
            return dx, None, dw, db, None


    layer_norm = LayerNorm.apply


    def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
        """Compare `layer_norm` against torch.nn.functional.layer_norm.

        Checks the forward output and the gradients w.r.t. x, weight and
        bias on random (M, N) data of the given dtype, allocated on `device`.
        """
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # forward pass
        y_tri = layer_norm(x, w_shape, weight, bias, eps)
        y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
        # backward pass (triton)
        y_tri.backward(dy, retain_graph=True)
        dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
        # reset so the reference gradients are not accumulated on top
        x.grad, weight.grad, bias.grad = None, None, None
        # backward pass (torch)
        y_ref.backward(dy, retain_graph=True)
        dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
        # compare
        triton.testing.assert_almost_equal(y_tri, y_ref)
        triton.testing.assert_almost_equal(dx_tri, dx_ref)
        triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
        triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],
            x_vals=[512 * i for i in range(2, 32)],
            line_arg='provider',
            line_vals=['triton', 'torch', 'apex'],
            line_names=['Triton', 'Torch', 'Apex'],
            styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
            ylabel='GB/s',
            plot_name='layer-norm-backward',
            args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
        )
    )
    def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
        """Benchmark one provider ('triton', 'torch' or 'apex') on (M, N) input.

        Returns three throughput figures computed from the mean, max and
        min timings reported by `triton.testing.do_bench`. Raises ValueError
        for an unknown `provider` or `mode`.
        """
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # utility functions
        if provider == 'triton':
            y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
        elif provider == 'torch':
            y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
        elif provider == 'apex':
            import apex
            apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
            y_fwd = lambda: apex_layer_norm(x)
        else:
            raise ValueError(f"unknown provider: {provider!r}")
        # forward pass
        if mode == 'forward':
            # traffic is counted as 2 tensors the size of x
            gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
            ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
        # backward pass
        elif mode == 'backward':
            # traffic is counted as 3 tensors the size of x
            gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
            y = y_fwd()
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                         grad_to_none=[x], rep=500)
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        # gbps is inversely proportional to the time, so max_ms gives the
        # lowest throughput and min_ms the highest
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    bench_layer_norm.run(save_path='.', print_data=True)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 13.206 seconds)


.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_