.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "getting-started/tutorials/05-layer-norm.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_getting-started_tutorials_05-layer-norm.py: Layer Normalization ==================== .. GENERATED FROM PYTHON SOURCE LINES 5-252 .. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png :alt: 05 layer norm :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out Out: .. code-block:: none layer-norm-backward: N Triton Torch Apex 0 1024.0 311.088617 98.303995 307.200008 1 1536.0 354.461542 135.032961 341.333333 2 2048.0 423.724127 160.104230 323.368435 3 2560.0 461.954908 180.705883 326.808501 4 3072.0 519.211251 191.999993 319.168834 5 3584.0 551.384634 208.776702 309.410081 6 4096.0 568.231237 220.907859 300.623865 7 4608.0 500.416301 232.825259 292.571431 8 5120.0 527.381977 244.294240 289.129408 9 5632.0 545.032265 244.869560 291.310338 10 6144.0 550.208948 251.202731 288.563606 11 6656.0 539.675652 256.410903 286.793541 12 7168.0 512.000004 253.734520 277.470965 13 7680.0 491.519996 267.130429 284.884090 14 8192.0 463.698115 259.377306 280.467910 15 8704.0 416.958106 267.130429 285.377055 16 9216.0 430.319054 273.742580 289.887291 17 9728.0 438.033784 280.615388 288.950501 18 10240.0 446.836366 286.767793 289.811322 19 10752.0 431.518385 246.699797 289.941565 20 11264.0 430.471331 245.313973 285.465683 21 11776.0 421.826879 249.227509 289.277383 22 12288.0 418.909088 254.893699 294.617366 23 12800.0 414.574901 253.674644 289.811310 24 13312.0 412.242569 252.559690 290.179836 25 13824.0 404.604870 257.190689 292.313649 26 14336.0 398.683664 256.000002 289.129416 27 14848.0 382.351933 257.479779 289.012175 28 15360.0 377.704925 258.151252 288.000007 29 15872.0 368.046389 
261.806182 289.899545 | .. code-block:: default import torch import triton.language as tl import triton # Forward Pass @triton.jit def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # position of elements processed by this program row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE) mask = cols < N # offset data pointers to start at the row of interest X += row * stride Y += row * stride # load data and cast to float32 x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) # compute mean mean = tl.sum(x, axis=0) / N # compute std xmean = tl.where(mask, x - mean, 0.) var = tl.sum(xmean * xmean, axis=0) / N rstd = 1 / tl.sqrt(var + eps) xhat = xmean*rstd # write-back mean/rstd tl.store(M + row, mean) tl.store(V + row, rstd) # multiply by weight and add bias w = tl.load(W + cols, mask=mask) b = tl.load(B + cols, mask=mask) y = xhat * w + b # write-back tl.store(Y + cols, y, mask=mask) # Backward pass (DX + partial DW + partial DB) @triton.jit def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps, **META): GROUP_SIZE_M = META['GROUP_SIZE_M'] BLOCK_SIZE_N = META['BLOCK_SIZE_N'] # position of elements processed by this program row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) mask = cols < N # offset data pointers to start at the row of interest X += row * stride DY += row * stride DX += row * stride # offset locks and weight/bias gradient pointer # each kernel instance accumulates partial sums for # DW and DB into one of GROUP_SIZE_M independent buffers # these buffers stay in the L2, which allow this kernel # to be fast lock_id = row % GROUP_SIZE_M Lock += lock_id Count = Lock + GROUP_SIZE_M DW = DW + lock_id*N + cols DB = DB + lock_id*N + cols # load data to SRAM x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) w = tl.load(W + cols, mask=mask).to(tl.float32) mean = tl.load(M + row) rstd = tl.load(V + row) # compute 
dx xhat = (x - mean)*rstd wdy = w * dy xhat = tl.where(mask, xhat, 0.) wdy = tl.where(mask, wdy , 0.) mean1 = tl.sum(xhat * wdy, axis=0) / N mean2 = tl.sum(wdy, axis=0) / N dx = (wdy - (xhat*mean1 + mean2))*rstd # write-back dx tl.store(DX + cols, dx, mask=mask) # accumulate partial sums for dw/db partial_dw = (dy*xhat).to(w.dtype) partial_db = (dy).to(w.dtype) while tl.atomic_cas(Lock, 0, 1) == 1: pass count = tl.load(Count) # first store doesn't accumulate if count == 0: tl.atomic_xchg(Count, 1) else: partial_dw += tl.load(DW, mask=mask) partial_db += tl.load(DB, mask=mask) tl.store(DW, partial_dw, mask=mask) tl.store(DB, partial_db, mask=mask) # release lock tl.atomic_xchg(Lock, 0) # Backward pass (total DW + total DB) @triton.jit def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): pid = tl.program_id(0) BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): rows = i + tl.arange(0, meta['BLOCK_SIZE_M']) mask = (rows[:, None] < M) & (cols[None, :] < N) offs = rows[:, None]*N + cols[None, :] dw += tl.load(DW + offs, mask=mask, other=0.) db += tl.load(DB + offs, mask=mask, other=0.) 
sum_dw = tl.sum(dw, axis=0) sum_db = tl.sum(db, axis=0) tl.store(FINAL_DW + cols, sum_dw, mask=cols BLOCK_SIZE: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.eps = eps return y @staticmethod def backward(ctx, dy): x, w, b, m, v = ctx.saved_tensors # heuristics for amount of parallel reduction stream for DG/DB N = w.shape[0] GROUP_SIZE_M = 64 if N <= 8192: GROUP_SIZE_M = 96 if N <= 4096: GROUP_SIZE_M = 128 if N <= 1024: GROUP_SIZE_M = 256 # allocate output locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda') _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) dx = torch.empty_like(dy) # enqueue kernel using forward pass heuristics # also compute partial sums for DW and DB x_arg = x.reshape(-1, x.shape[-1]) M, N = x_arg.shape _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks, x_arg.stride(0), N, ctx.eps, BLOCK_SIZE_N=ctx.BLOCK_SIZE, GROUP_SIZE_M=GROUP_SIZE_M, num_warps=ctx.num_warps) grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] # accumulate partial sums in separate kernel _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, BLOCK_SIZE_M = 32, BLOCK_SIZE_N = 128) return dx, None, dw, db, None layer_norm = LayerNorm.apply def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) bias = 
torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') dy = .1*torch.randn_like(x) x.requires_grad_(True) # forward pass y_tri = layer_norm(x, w_shape, weight, bias, eps) y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) # backward pass (triton) y_tri.backward(dy, retain_graph=True) dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]] x.grad, weight.grad, bias.grad = None, None, None # backward pass (torch) y_ref.backward(dy, retain_graph=True) dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]] # compare triton.testing.assert_almost_equal(y_tri, y_ref) triton.testing.assert_almost_equal(dx_tri, dx_ref) triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1) triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1) @triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'], x_vals=[512 * i for i in range(2, 32)], line_arg='provider', line_vals=['triton', 'torch', 'apex'], line_names=['Triton', 'Torch', 'Apex'], styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', plot_name='layer-norm-backward', args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'} ) ) def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') dy = .1*torch.randn_like(x) x.requires_grad_(True) # utility functions if provider == 'triton': y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps) if provider == 'torch': y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) if provider == 'apex': import apex apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype) y_fwd = 
lambda: apex_layer_norm(x) # forward pass if mode == 'forward': gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6 ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500) # backward pass if mode == 'backward': gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6 y = y_fwd() ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), grad_to_none=[x], rep=500) return gbps(ms), gbps(max_ms), gbps(min_ms) bench_layer_norm.run(save_path='.', print_data=True) .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 2 minutes 11.851 seconds) .. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py: .. only :: html .. container:: sphx-glr-footer :class: sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_