.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/05-layer-norm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_getting-started_tutorials_05-layer-norm.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_05-layer-norm.py:


Layer Normalization
====================

.. GENERATED FROM PYTHON SOURCE LINES 5-252

.. image:: /getting-started/tutorials/images/sphx_glr_05-layer-norm_001.png
    :alt: 05 layer norm
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    layer-norm-backward:
               N      Triton       Torch        Apex
    0     1024.0  311.088617   99.902435  311.088617
    1     1536.0  354.461542  133.083026  341.333333
    2     2048.0  427.408686  158.554837  321.254900
    3     2560.0  461.954908  182.857144  323.368411
    4     3072.0  515.580429  191.999993  319.168834
    5     3584.0  551.384634  208.271186  309.410081
    6     4096.0  568.231237  219.919464  299.707322
    7     4608.0  500.416301  232.825259  287.251954
    8     5120.0  529.655159  243.809526  289.811322
    9     5632.0  540.671974  244.869560  291.310338
    10    6144.0  548.163546  251.631408  288.000001
    11    6656.0  534.260858  256.000009  286.279570
    12    7168.0  516.612607  254.485198  278.820105
    13    7680.0  487.619051  266.743841  284.884090
    14    8192.0  468.114289  257.003920  276.912679
    15    8704.0  416.958106  267.815384  285.767450
    16    9216.0  430.319054  274.081793  289.887291
    17    9728.0  439.683593  280.278512  289.308559
    18   10240.0  446.025405  287.102804  290.153487
    19   10752.0  430.797982  246.699797  289.291486
    20   11264.0  429.104745  246.656943  286.980888
    21   11776.0  422.457417  250.109737  288.981596
    22   12288.0  419.504980  254.893699  294.323369
    23   12800.0  414.574901  254.094291  288.993430
    24   13312.0  413.309181  252.759501  289.653667
    25   13824.0  407.587209  257.390218  292.056329
    26   14336.0  395.930964  255.429842  288.644296
    27   14848.0  386.918555  257.293872  287.380642
    28   15360.0  375.015246  258.513318  286.656296
    29   15872.0  368.046389  261.267482  289.679087


|
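
For reference, the fused kernels in the listing below compute, for every row
:math:`x \in \mathbb{R}^N` of the input, the layer normalization

.. math::

    \mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
    \sigma^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu)^2, \qquad
    \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
    y_i = \hat{x}_i \, w_i + b_i

and its input gradient

.. math::

    dx_i = \left( w_i \, dy_i - \hat{x}_i \, c_1 - c_2 \right)
           \frac{1}{\sqrt{\sigma^2 + \epsilon}}, \qquad
    c_1 = \frac{1}{N}\sum_{j} \hat{x}_j \, w_j \, dy_j, \qquad
    c_2 = \frac{1}{N}\sum_{j} w_j \, dy_j.

The symbols :math:`\mu`, :math:`\sigma^2`, :math:`c_1` and :math:`c_2` are introduced
here only to mirror the ``mean``, ``var``, ``mean1`` and ``mean2`` variables used in
the kernels; they are not part of the generated script output.
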

.. code-block:: default


    import torch
    import triton.language as tl
    import triton


    # Forward Pass
    @triton.jit
    def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
        BLOCK_SIZE = META['BLOCK_SIZE']
        # position of elements processed by this program
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # offset data pointers to start at the row of interest
        X += row * stride
        Y += row * stride
        # load data and cast to float32
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        # compute mean
        mean = tl.sum(x, axis=0) / N
        # compute std
        xmean = tl.where(mask, x - mean, 0.)
        var = tl.sum(xmean * xmean, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        xhat = xmean * rstd
        # write-back mean/rstd
        tl.store(M + row, mean)
        tl.store(V + row, rstd)
        # multiply by weight and add bias
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        y = xhat * w + b
        # write-back
        tl.store(Y + cols, y, mask=mask)


    # Backward pass (DX + partial DW + partial DB)
    @triton.jit
    def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
                                 stride, N, eps, **META):
        GROUP_SIZE_M = META['GROUP_SIZE_M']
        BLOCK_SIZE_N = META['BLOCK_SIZE_N']
        # position of elements processed by this program
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_SIZE_N)
        mask = cols < N
        # offset data pointers to start at the row of interest
        X += row * stride
        DY += row * stride
        DX += row * stride
        # offset locks and weight/bias gradient pointers
        # each kernel instance accumulates partial sums for DW and DB
        # into one of GROUP_SIZE_M independent buffers; these buffers
        # stay in the L2 cache, which allows this kernel to be fast
        lock_id = row % GROUP_SIZE_M
        Lock += lock_id
        Count = Lock + GROUP_SIZE_M
        DW = DW + lock_id * N + cols
        DB = DB + lock_id * N + cols
        # load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        mean = tl.load(M + row)
        rstd = tl.load(V + row)
        # compute dx
        xhat = (x - mean) * rstd
        wdy = w * dy
        xhat = tl.where(mask, xhat, 0.)
        wdy = tl.where(mask, wdy, 0.)
        mean1 = tl.sum(xhat * wdy, axis=0) / N
        mean2 = tl.sum(wdy, axis=0) / N
        dx = (wdy - (xhat * mean1 + mean2)) * rstd
        # write-back dx
        tl.store(DX + cols, dx, mask=mask)
        # accumulate partial sums for dw/db
        partial_dw = (dy * xhat).to(w.dtype)
        partial_db = (dy).to(w.dtype)
        while tl.atomic_cas(Lock, 0, 1) == 1:
            pass
        count = tl.load(Count)
        # first store doesn't accumulate
        if count == 0:
            tl.atomic_xchg(Count, 1)
        else:
            partial_dw += tl.load(DW, mask=mask)
            partial_db += tl.load(DB, mask=mask)
        tl.store(DW, partial_dw, mask=mask)
        tl.store(DB, partial_db, mask=mask)
        # release lock
        tl.atomic_xchg(Lock, 0)

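
    # NOTE (illustrative sketch, not part of the original tutorial script): across the
    # whole batch, the partial buffers accumulated above and the reduction kernel below
    # together produce the usual layer-norm weight/bias gradients. In plain PyTorch the
    # same result could be written as in the helper below, which only documents the math
    # and is never called by this script.
    def _reference_dwdb(dy, x, mean, rstd):
        # re-normalize each row with its saved statistics, then reduce over the rows
        xhat = (x - mean[:, None]) * rstd[:, None]
        dw = (dy * xhat).sum(dim=0)
        db = dy.sum(dim=0)
        return dw, db
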

    # Backward pass (total DW + total DB)
    @triton.jit
    def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
        pid = tl.program_id(0)
        BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
        BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
        cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for i in range(0, M, BLOCK_SIZE_M):
            rows = i + tl.arange(0, BLOCK_SIZE_M)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            offs = rows[:, None] * N + cols[None, :]
            dw += tl.load(DW + offs, mask=mask, other=0.)
            db += tl.load(DB + offs, mask=mask, other=0.)
        sum_dw = tl.sum(dw, axis=0)
        sum_db = tl.sum(db, axis=0)
        tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
        tl.store(FINAL_DB + cols, sum_db, mask=cols < N)


    class LayerNorm(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, normalized_shape, weight, bias, eps):
            # allocate output
            y = torch.empty_like(x)
            # reshape input data into 2D tensor
            x_arg = x.reshape(-1, x.shape[-1])
            M, N = x_arg.shape
            mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
            rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
            # less than 64KB per feature: enqueue fused kernel
            MAX_FUSED_SIZE = 65536 // x.element_size()
            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
            if N > BLOCK_SIZE:
                raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
            # heuristics for number of warps
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
            # enqueue kernel
            _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
                                        x_arg.stride(0), N, eps,
                                        BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
            ctx.save_for_backward(x, weight, bias, mean, rstd)
            ctx.BLOCK_SIZE = BLOCK_SIZE
            ctx.num_warps = num_warps
            ctx.eps = eps
            return y

        @staticmethod
        def backward(ctx, dy):
            x, w, b, m, v = ctx.saved_tensors
            # heuristics for amount of parallel reduction stream for DG/DB
            N = w.shape[0]
            GROUP_SIZE_M = 64
            if N <= 8192: GROUP_SIZE_M = 96
            if N <= 4096: GROUP_SIZE_M = 128
            if N <= 1024: GROUP_SIZE_M = 256
            # allocate output
            locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
            _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
            _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
            dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)
            db = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)
            dx = torch.empty_like(dy)
            # enqueue kernel using forward pass heuristics
            # also compute partial sums for DW and DB
            x_arg = x.reshape(-1, x.shape[-1])
            M, N = x_arg.shape
            _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
                                           x_arg.stride(0), N, ctx.eps,
                                           BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                           GROUP_SIZE_M=GROUP_SIZE_M,
                                           num_warps=ctx.num_warps)
            grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
            # accumulate partial sums in separate kernel
            _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
                                       BLOCK_SIZE_M=32, BLOCK_SIZE_N=128)
            return dx, None, dw, db, None


    layer_norm = LayerNorm.apply


    def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # forward pass
        y_tri = layer_norm(x, w_shape, weight, bias, eps)
        y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
        # backward pass (triton)
        y_tri.backward(dy, retain_graph=True)
        dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
        x.grad, weight.grad, bias.grad = None, None, None
        # backward pass (torch)
        y_ref.backward(dy, retain_graph=True)
        dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
        # compare
        triton.testing.assert_almost_equal(y_tri, y_ref)
        triton.testing.assert_almost_equal(dx_tri, dx_ref)
        triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
        triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)

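
    # Usage note (example values are hypothetical; the script as generated does not
    # call the test): the correctness check above can be invoked directly, e.g.
    #
    #     test_layer_norm(1151, 8192, torch.float16)
    #
    # which compares the Triton forward/backward results against
    # torch.nn.functional.layer_norm.
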

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],
            x_vals=[512 * i for i in range(2, 32)],
            line_arg='provider',
            line_vals=['triton', 'torch', 'apex'],
            line_names=['Triton', 'Torch', 'Apex'],
            styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
            ylabel='GB/s',
            plot_name='layer-norm-backward',
            args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
        )
    )
    def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
        # create data
        x_shape = (M, N)
        w_shape = (x_shape[-1], )
        weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
        dy = .1 * torch.randn_like(x)
        x.requires_grad_(True)
        # utility functions
        if provider == 'triton':
            y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
        if provider == 'torch':
            y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
        if provider == 'apex':
            import apex
            apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
            y_fwd = lambda: apex_layer_norm(x)
        # forward pass
        if mode == 'forward':
            gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
            ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
        # backward pass
        if mode == 'backward':
            gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
            y = y_fwd()
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                         grad_to_none=[x], rep=500)
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    bench_layer_norm.run(save_path='.', print_data=True)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 12.791 seconds)


.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: 05-layer-norm.py <05-layer-norm.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: 05-layer-norm.ipynb <05-layer-norm.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_