{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Layer Normalization\n",
    "\n",
    "A fused layer normalization (forward and backward) written in Triton. The results are checked against `torch.nn.functional.layer_norm`, and the backward pass is benchmarked against PyTorch (and NVIDIA Apex, when it is installed).\n",
    "\n",
    "Requirements: a CUDA GPU, PyTorch and Triton; Apex is optional."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import triton\n",
    "import triton.language as tl\n",
    "\n",
    "try:\n",
    "    # Apex (https://github.com/NVIDIA/apex) is only an optional benchmark baseline.\n",
    "    import apex\n",
    "    HAS_APEX = True\n",
    "except ImportError:\n",
    "    HAS_APEX = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forward pass\n",
    "\n",
    "`_layer_norm_fwd_fused` runs one program per row of the 2D view of the input. It computes the row mean and reciprocal standard deviation (`rstd`) in float32, saves both for the backward pass, and writes `y = (x - mean) * rstd * w + b`. A whole row must fit in a single block, which limits the feature dimension to less than 64KB."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Forward pass: one program normalizes one row.\n",
    "@triton.jit\n",
    "def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):\n",
    "    # Writes y = (x - mean) * rstd * w + b for row program_id(0) into Y and stores\n",
    "    # that row's mean and rstd = 1 / sqrt(var + eps) into M[row] and V[row].\n",
    "    # META['BLOCK_SIZE'] must cover the whole row (BLOCK_SIZE >= N).\n",
    "    BLOCK_SIZE = META['BLOCK_SIZE']\n",
    "    # position of elements processed by this program\n",
    "    row = tl.program_id(0)\n",
    "    cols = tl.arange(0, BLOCK_SIZE)\n",
    "    mask = cols < N\n",
    "    # offset data pointers to start at the row of interest\n",
    "    X += row * stride\n",
    "    Y += row * stride\n",
    "    # load data and cast to float32\n",
    "    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n",
    "    # compute mean\n",
    "    mean = tl.sum(x, axis=0) / N\n",
    "    # compute variance; out-of-range lanes are zeroed so they add nothing\n",
    "    xmean = tl.where(mask, x - mean, 0.)\n",
    "    var = tl.sum(xmean * xmean, axis=0) / N\n",
    "    rstd = 1 / tl.sqrt(var + eps)\n",
    "    xhat = xmean * rstd\n",
    "    # write-back mean/rstd\n",
    "    tl.store(M + row, mean)\n",
    "    tl.store(V + row, rstd)\n",
    "    # multiply by weight and add bias\n",
    "    w = tl.load(W + cols, mask=mask)\n",
    "    b = tl.load(B + cols, mask=mask)\n",
    "    y = xhat * w + b\n",
    "    # write-back\n",
    "    tl.store(Y + cols, y, mask=mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Backward pass\n",
    "\n",
    "The backward pass is split into two kernels:\n",
    "\n",
    "- `_layer_norm_bwd_dx_fused` runs one program per row. It computes the input gradient `dx` and adds the row's contribution to `dw`/`db` into one of `GROUP_SIZE_M` partial-sum buffers (selected by `row % GROUP_SIZE_M`). Each buffer is guarded by its own spin lock.\n",
    "- `_layer_norm_bwd_dwdb` then reduces the `GROUP_SIZE_M` partial buffers column-wise into the final `dw` and `db`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Backward pass, stage 1: DX plus partial DW/DB sums (one program per row).\n",
    "@triton.jit\n",
    "def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,\n",
    "                             stride, N, eps,\n",
    "                             **META):\n",
    "    # B and eps are not read here: the per-row mean (M) and rstd (V)\n",
    "    # saved by the forward pass are used instead.\n",
    "    GROUP_SIZE_M = META['GROUP_SIZE_M']\n",
    "    BLOCK_SIZE_N = META['BLOCK_SIZE_N']\n",
    "    # position of elements processed by this program\n",
    "    row = tl.program_id(0)\n",
    "    cols = tl.arange(0, BLOCK_SIZE_N)\n",
    "    mask = cols < N\n",
    "    # offset data pointers to start at the row of interest\n",
    "    X += row * stride\n",
    "    DY += row * stride\n",
    "    DX += row * stride\n",
    "    # offset locks and weight/bias gradient pointers:\n",
    "    # each kernel instance accumulates partial sums for\n",
    "    # DW and DB into one of GROUP_SIZE_M independent buffers;\n",
    "    # these buffers are expected to stay in the L2 cache, which\n",
    "    # allows this kernel to be fast\n",
    "    lock_id = row % GROUP_SIZE_M\n",
    "    Lock += lock_id\n",
    "    # Lock[:GROUP_SIZE_M] are spin locks; Lock[GROUP_SIZE_M:] flag buffers already written\n",
    "    Count = Lock + GROUP_SIZE_M\n",
    "    DW = DW + lock_id * N + cols\n",
    "    DB = DB + lock_id * N + cols\n",
    "    # load the row (all math below is done in float32)\n",
    "    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n",
    "    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n",
    "    w = tl.load(W + cols, mask=mask).to(tl.float32)\n",
    "    mean = tl.load(M + row)\n",
    "    rstd = tl.load(V + row)\n",
    "    # compute dx\n",
    "    xhat = (x - mean) * rstd\n",
    "    wdy = w * dy\n",
    "    xhat = tl.where(mask, xhat, 0.)\n",
    "    wdy = tl.where(mask, wdy, 0.)\n",
    "    mean1 = tl.sum(xhat * wdy, axis=0) / N\n",
    "    mean2 = tl.sum(wdy, axis=0) / N\n",
    "    dx = (wdy - (xhat * mean1 + mean2)) * rstd\n",
    "    # write-back dx\n",
    "    tl.store(DX + cols, dx, mask=mask)\n",
    "    # accumulate partial sums for dw/db\n",
    "    partial_dw = (dy * xhat).to(w.dtype)\n",
    "    partial_db = (dy).to(w.dtype)\n",
    "    # acquire the spin lock guarding this group's buffer\n",
    "    while tl.atomic_cas(Lock, 0, 1) == 1:\n",
    "        pass\n",
    "    count = tl.load(Count)\n",
    "    # first store doesn't accumulate\n",
    "    if count == 0:\n",
    "        tl.atomic_xchg(Count, 1)\n",
    "    else:\n",
    "        partial_dw += tl.load(DW, mask=mask)\n",
    "        partial_db += tl.load(DB, mask=mask)\n",
    "    tl.store(DW, partial_dw, mask=mask)\n",
    "    tl.store(DB, partial_db, mask=mask)\n",
    "    # release lock\n",
    "    tl.atomic_xchg(Lock, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Backward pass, stage 2: reduce the partial buffers into the total DW/DB.\n",
    "@triton.jit\n",
    "def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):\n",
    "    # DW and DB are row-major (M, N) partial sums; each program sums one\n",
    "    # BLOCK_SIZE_N-wide strip of columns over all M rows.\n",
    "    pid = tl.program_id(0)\n",
    "    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']\n",
    "    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']\n",
    "    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n",
    "    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n",
    "    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n",
    "    for i in range(0, M, BLOCK_SIZE_M):\n",
    "        rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])\n",
    "        mask = (rows[:, None] < M) & (cols[None, :] < N)\n",
    "        offs = rows[:, None] * N + cols[None, :]\n",
    "        dw += tl.load(DW + offs, mask=mask, other=0.)\n",
    "        db += tl.load(DB + offs, mask=mask, other=0.)\n",
    "    sum_dw = tl.sum(dw, axis=0)\n",
    "    sum_db = tl.sum(db, axis=0)\n",
    "    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n",
    "    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd wrapper\n",
    "\n",
    "`LayerNorm` exposes the kernels as a `torch.autograd.Function`. `layer_norm(x, normalized_shape, weight, bias, eps)` follows the argument order of `torch.nn.functional.layer_norm` and normalizes over the last dimension of `x`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class LayerNorm(torch.autograd.Function):\n",
    "    \"\"\"Layer norm over the last dimension, computed by the Triton kernels above.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, x, normalized_shape, weight, bias, eps):\n",
    "        # NOTE(review): the start of this method (allocations and BLOCK_SIZE choice)\n",
    "        # was lost from the source text and has been reconstructed from how these\n",
    "        # names are used below -- confirm against the upstream tutorial.\n",
    "        # `normalized_shape` is unused: normalization is always over the last dim.\n",
    "        # allocate output\n",
    "        y = torch.empty_like(x)\n",
    "        # reshape input data into 2D tensor\n",
    "        x_arg = x.reshape(-1, x.shape[-1])\n",
    "        M, N = x_arg.shape\n",
    "        mean = torch.empty((M, ), dtype=torch.float32, device=x.device)\n",
    "        rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)\n",
    "        # Less than 64KB per feature: enqueue fused kernel\n",
    "        MAX_FUSED_SIZE = 65536 // x.element_size()\n",
    "        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n",
    "        if N > BLOCK_SIZE:\n",
    "            raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n",
    "        # heuristics for number of warps\n",
    "        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n",
    "        # enqueue kernel\n",
    "        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,\n",
    "                                    x_arg.stride(0), N, eps,\n",
    "                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n",
    "        ctx.save_for_backward(x, weight, bias, mean, rstd)\n",
    "        ctx.BLOCK_SIZE = BLOCK_SIZE\n",
    "        ctx.num_warps = num_warps\n",
    "        ctx.eps = eps\n",
    "        return y\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, dy):\n",
    "        x, w, b, m, v = ctx.saved_tensors\n",
    "        # heuristics for amount of parallel reduction stream for DW/DB\n",
    "        N = w.shape[0]\n",
    "        GROUP_SIZE_M = 64\n",
    "        if N <= 8192: GROUP_SIZE_M = 96\n",
    "        if N <= 4096: GROUP_SIZE_M = 128\n",
    "        if N <= 1024: GROUP_SIZE_M = 256\n",
    "        # allocate output\n",
    "        # first GROUP_SIZE_M entries: spin locks; last GROUP_SIZE_M: 'buffer written' flags\n",
    "        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)\n",
    "        # the partial-sum buffers must start at zero: with fewer rows than GROUP_SIZE_M\n",
    "        # some buffers are never written, yet _layer_norm_bwd_dwdb sums all of them\n",
    "        _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)\n",
    "        _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)\n",
    "        dw = torch.empty((N, ), dtype=w.dtype, device=w.device)\n",
    "        db = torch.empty((N, ), dtype=w.dtype, device=w.device)\n",
    "        # dy may arrive non-contiguous (e.g. an expanded gradient), but the kernel\n",
    "        # indexes dy and dx rows with x's row stride\n",
    "        dy = dy.contiguous()\n",
    "        dx = torch.empty_like(dy)\n",
    "        # enqueue kernel using forward pass heuristics\n",
    "        # also compute partial sums for DW and DB\n",
    "        x_arg = x.reshape(-1, x.shape[-1])\n",
    "        M, N = x_arg.shape\n",
    "        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x_arg, w, b, m, v, locks,\n",
    "                                       x_arg.stride(0), N, ctx.eps,\n",
    "                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n",
    "                                       GROUP_SIZE_M=GROUP_SIZE_M,\n",
    "                                       num_warps=ctx.num_warps)\n",
    "        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n",
    "        # accumulate partial sums in separate kernel\n",
    "        _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,\n",
    "                                   BLOCK_SIZE_M=32,\n",
    "                                   BLOCK_SIZE_N=128)\n",
    "        # one gradient per forward input: x, normalized_shape, weight, bias, eps\n",
    "        return dx, None, dw, db, None\n",
    "\n",
    "\n",
    "layer_norm = LayerNorm.apply"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correctness check\n",
    "\n",
    "Compares the output and the gradients of `x`, `weight` and `bias` with PyTorch's reference implementation. The second case uses fewer rows than partial-sum buffers, so some buffers are never written by the first backward kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n",
    "    \"\"\"Compare Triton's layer norm (y, dx, dw, db) with PyTorch on random (M, N) data.\"\"\"\n",
    "    # create data\n",
    "    x_shape = (M, N)\n",
    "    w_shape = (x_shape[-1], )\n",
    "    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)\n",
    "    dy = .1 * torch.randn_like(x)\n",
    "    x.requires_grad_(True)\n",
    "    # forward pass\n",
    "    y_tri = layer_norm(x, w_shape, weight, bias, eps)\n",
    "    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n",
    "    # backward pass (triton)\n",
    "    y_tri.backward(dy, retain_graph=True)\n",
    "    dx_tri, dw_tri, db_tri = [t.grad.clone() for t in [x, weight, bias]]\n",
    "    x.grad, weight.grad, bias.grad = None, None, None\n",
    "    # backward pass (torch)\n",
    "    y_ref.backward(dy, retain_graph=True)\n",
    "    dx_ref, dw_ref, db_ref = [t.grad.clone() for t in [x, weight, bias]]\n",
    "    # compare\n",
    "    triton.testing.assert_almost_equal(y_tri, y_ref)\n",
    "    triton.testing.assert_almost_equal(dx_tri, dx_ref)\n",
    "    triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n",
    "    triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "test_layer_norm(1151, 8192, torch.float16)\n",
    "# fewer rows (64) than partial-sum buffers (256 for N <= 1024)\n",
    "test_layer_norm(64, 1024, torch.float16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark\n",
    "\n",
    "Backward-pass throughput (GB/s) for `M = 4096` rows of float16 data as the feature dimension `N` grows. Apex is included only when it is installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "@triton.testing.perf_report(\n",
    "    triton.testing.Benchmark(\n",
    "        x_names=['N'],\n",
    "        x_vals=[512 * i for i in range(2, 32)],\n",
    "        line_arg='provider',\n",
    "        line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),\n",
    "        line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),\n",
    "        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],\n",
    "        ylabel='GB/s',\n",
    "        plot_name='layer-norm-backward',\n",
    "        args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}\n",
    "    )\n",
    ")\n",
    "def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):\n",
    "    \"\"\"Return (median, min, max) throughput in GB/s of `provider`'s layer norm.\"\"\"\n",
    "    # create data\n",
    "    x_shape = (M, N)\n",
    "    w_shape = (x_shape[-1], )\n",
    "    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)\n",
    "    dy = .1 * torch.randn_like(x)\n",
    "    x.requires_grad_(True)\n",
    "    # utility functions\n",
    "    if provider == 'triton':\n",
    "        y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)\n",
    "    if provider == 'torch':\n",
    "        y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)\n",
    "    if provider == 'apex':\n",
    "        apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)\n",
    "        y_fwd = lambda: apex_layer_norm(x)\n",
    "    # forward pass: bytes of x read + y written, per ms -> GB/s\n",
    "    if mode == 'forward':\n",
    "        gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6\n",
    "        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)\n",
    "    # backward pass: three x-sized tensors moved, per ms -> GB/s\n",
    "    if mode == 'backward':\n",
    "        gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6\n",
    "        y = y_fwd()\n",
    "        ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),\n",
    "                                                     grad_to_none=[x], rep=500)\n",
    "    # the slowest time (max_ms) gives the lowest throughput\n",
    "    return gbps(ms), gbps(max_ms), gbps(min_ms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "bench_layer_norm.run(save_path='.', print_data=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}