{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Layer Normalization\n",
    "\n",
    "A fused LayerNorm forward and backward pass written in Triton. Every row of the input (all leading dimensions flattened together) is normalized over its last dimension and then scaled and shifted:\n",
    "\n",
    "$$y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\mathrm{Var}[x] + \\epsilon}} \\cdot w + b$$\n",
    "\n",
    "The notebook provides a correctness check against `torch.nn.functional.layer_norm` and benchmarks the kernels against PyTorch (and [NVIDIA Apex](https://github.com/NVIDIA/apex) when it is installed)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import triton\n",
    "import triton.language as tl\n",
    "\n",
    "try:\n",
    "    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it\n",
    "    # should not be added to extras_require in setup.py.\n",
    "    import apex\n",
    "    HAS_APEX = True\n",
    "except ModuleNotFoundError:\n",
    "    HAS_APEX = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forward pass\n",
    "\n",
    "One program normalizes one row. The row is read in `BLOCK_SIZE`-wide chunks three times: once for the mean, once for the variance, and once to apply the normalization, weight and bias. The per-row mean and reciprocal standard deviation (`rstd`) are stored for the backward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Forward kernel: one program per row of the 2D input.\n",
    "#   Out, A       : output / input; element (row, col) lives at ptr + row * stride + col\n",
    "#   Weight, Bias : per-column affine parameters (length N)\n",
    "#   Mean, Rstd   : per-row float32 statistics written for the backward pass\n",
    "@triton.jit\n",
    "def _layer_norm_fwd_fused(\n",
    "    Out,\n",
    "    A,\n",
    "    Weight,\n",
    "    Bias,\n",
    "    Mean, Rstd,\n",
    "    stride, N, eps,\n",
    "    BLOCK_SIZE: tl.constexpr,\n",
    "):\n",
    "    # position of elements processed by this program\n",
    "    row = tl.program_id(0)\n",
    "    Out += row * stride\n",
    "    A += row * stride\n",
    "    # compute mean (accumulated in float32)\n",
    "    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n",
    "    for off in range(0, N, BLOCK_SIZE):\n",
    "        cols = off + tl.arange(0, BLOCK_SIZE)\n",
    "        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n",
    "        _mean += a\n",
    "    mean = tl.sum(_mean, axis=0) / N\n",
    "    # compute variance; masked-out lanes must contribute 0, not (0 - mean)^2\n",
    "    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n",
    "    for off in range(0, N, BLOCK_SIZE):\n",
    "        cols = off + tl.arange(0, BLOCK_SIZE)\n",
    "        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n",
    "        a = tl.where(cols < N, a - mean, 0.)\n",
    "        _var += a * a\n",
    "    var = tl.sum(_var, axis=0) / N\n",
    "    rstd = 1 / tl.sqrt(var + eps)\n",
    "    # write-back mean/rstd\n",
    "    tl.store(Mean + row, mean)\n",
    "    tl.store(Rstd + row, rstd)\n",
    "    # normalize, multiply by weight and add bias\n",
    "    for off in range(0, N, BLOCK_SIZE):\n",
    "        cols = off + tl.arange(0, BLOCK_SIZE)\n",
    "        mask = cols < N\n",
    "        weight = tl.load(Weight + cols, mask=mask)\n",
    "        bias = tl.load(Bias + cols, mask=mask)\n",
    "        # last read of this row: hint the cache to evict it first\n",
    "        a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n",
    "        a_hat = (a - mean) * rstd\n",
    "        out = a_hat * weight + bias\n",
    "        # write-back\n",
    "        tl.store(Out + cols, out, mask=mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Backward pass\n",
    "\n",
    "The input gradient is computed by a fused per-row kernel. The weight and bias gradients are reductions over all rows, so they are computed by a second kernel in which each program owns a block of `BLOCK_SIZE_N` columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Backward pass (DA): one program per row.\n",
    "# With a_hat = (a - mean) * rstd and wdout = weight * dout:\n",
    "#   da = (wdout - (a_hat * mean(a_hat * wdout) + mean(wdout))) * rstd\n",
    "# NumRows and eps are accepted for call-site compatibility but are not used.\n",
    "@triton.jit\n",
    "def _layer_norm_bwd_dx_fused(\n",
    "    _DA,\n",
    "    _DOut,\n",
    "    _A,\n",
    "    Weight,\n",
    "    Mean, Rstd,\n",
    "    stride, NumRows, NumCols, eps,\n",
    "    BLOCK_SIZE_N: tl.constexpr,\n",
    "):\n",
    "    # position of elements processed by this program\n",
    "    row = tl.program_id(0)\n",
    "    A = _A + row * stride\n",
    "    DOut = _DOut + row * stride\n",
    "    DA = _DA + row * stride\n",
    "    mean = tl.load(Mean + row)\n",
    "    rstd = tl.load(Rstd + row)\n",
    "    # first sweep: row-wise means of a_hat * wdout and of wdout\n",
    "    # (masked lanes load weight = dout = 0, so they contribute nothing)\n",
    "    _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n",
    "    _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n",
    "    for off in range(0, NumCols, BLOCK_SIZE_N):\n",
    "        cols = off + tl.arange(0, BLOCK_SIZE_N)\n",
    "        mask = cols < NumCols\n",
    "        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n",
    "        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n",
    "        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n",
    "        a_hat = (a - mean) * rstd\n",
    "        wdout = weight * dout\n",
    "        _mean1 += a_hat * wdout\n",
    "        _mean2 += wdout\n",
    "    mean1 = tl.sum(_mean1, axis=0) / NumCols\n",
    "    mean2 = tl.sum(_mean2, axis=0) / NumCols\n",
    "    # second sweep: recompute a_hat / wdout and write the input gradient\n",
    "    for off in range(0, NumCols, BLOCK_SIZE_N):\n",
    "        cols = off + tl.arange(0, BLOCK_SIZE_N)\n",
    "        mask = cols < NumCols\n",
    "        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n",
    "        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n",
    "        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n",
    "        a_hat = (a - mean) * rstd\n",
    "        wdout = weight * dout\n",
    "        da = (wdout - (a_hat * mean1 + mean2)) * rstd\n",
    "        # write-back dx\n",
    "        tl.store(DA + cols, da, mask=mask)\n",
    "\n",
    "\n",
    "# Backward pass (DW + DB): one program per block of BLOCK_SIZE_N columns,\n",
    "# reducing over all M rows. A and DOut are indexed as dense row-major (M, N);\n",
    "# Mean and Rstd hold the per-row statistics written by the forward kernel.\n",
    "@triton.jit\n",
    "def _layer_norm_bwd_dwdb(\n",
    "    A, DOut,\n",
    "    Mean, Rstd,\n",
    "    DW,\n",
    "    DB,\n",
    "    M, N,\n",
    "    BLOCK_SIZE_M: tl.constexpr,\n",
    "    BLOCK_SIZE_N: tl.constexpr,\n",
    "):\n",
    "    pid = tl.program_id(0)\n",
    "    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n",
    "    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n",
    "    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n",
    "    for i in range(0, M, BLOCK_SIZE_M):\n",
    "        rows = i + tl.arange(0, BLOCK_SIZE_M)\n",
    "        mask = (rows[:, None] < M) & (cols[None, :] < N)\n",
    "        offs = rows[:, None] * N + cols[None, :]\n",
    "        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)\n",
    "        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)\n",
    "        mean = tl.load(Mean + rows, mask=rows < M, other=0.)\n",
    "        rstd = tl.load(Rstd + rows, mask=rows < M, other=0.)\n",
    "        a_hat = (a - mean[:, None]) * rstd[:, None]\n",
    "        dw += dout * a_hat\n",
    "        db += dout\n",
    "    # collapse the BLOCK_SIZE_M partial sums into one value per column\n",
    "    sum_dw = tl.sum(dw, axis=0)\n",
    "    sum_db = tl.sum(db, axis=0)\n",
    "    tl.store(DW + cols, sum_dw, mask=cols < N)\n",
    "    tl.store(DB + cols, sum_db, mask=cols < N)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd wrapper\n",
    "\n",
    "`LayerNorm` flattens the input into a contiguous `(M, N)` matrix, launches the kernels and exposes them to autograd. Only the last dimension is normalized; `normalized_shape` is accepted for signature compatibility with `torch.nn.functional.layer_norm` but is not otherwise used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class LayerNorm(torch.autograd.Function):\n",
    "    \"\"\"Triton LayerNorm over the last dimension of the input, with autograd support.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, a, normalized_shape, weight, bias, eps):\n",
    "        \"\"\"Normalize each row of `a` over its last dimension, then scale by `weight` and add `bias`.\n",
    "\n",
    "        `normalized_shape` is not used: normalization is always over `a.shape[-1]`.\n",
    "        `weight` and `bias` are read as contiguous vectors of length `a.shape[-1]`.\n",
    "        Returns a new contiguous tensor with the same shape, dtype and device as `a`.\n",
    "        \"\"\"\n",
    "        # The kernels assume unit column stride and a row stride shared by input\n",
    "        # and output (the DW/DB kernel additionally assumes dense row-major), so\n",
    "        # work on a contiguous 2D copy whenever `a` is not already laid out that way.\n",
    "        a_arg = a.reshape(-1, a.shape[-1]).contiguous()\n",
    "        M, N = a_arg.shape\n",
    "        # contiguous output: its flattened row stride matches a_arg.stride(0)\n",
    "        out = torch.empty(a.shape, dtype=a.dtype, device=a.device)\n",
    "        # per-row statistics live on the same device as the input\n",
    "        mean = torch.empty((M,), dtype=torch.float32, device=a.device)\n",
    "        rstd = torch.empty((M,), dtype=torch.float32, device=a.device)\n",
    "        # Each program walks its row in BLOCK_SIZE-element chunks: at most 64KB\n",
    "        # of input per chunk, clamped to [128, 4096] elements.\n",
    "        MAX_FUSED_SIZE = 65536 // a.element_size()\n",
    "        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n",
    "        BLOCK_SIZE = max(BLOCK_SIZE, 128)\n",
    "        BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n",
    "        # heuristic: one warp per 256 chunk elements, clamped to [1, 8]\n",
    "        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n",
    "        _layer_norm_fwd_fused[(M,)](\n",
    "            out,\n",
    "            a_arg,\n",
    "            weight,\n",
    "            bias,\n",
    "            mean, rstd,\n",
    "            a_arg.stride(0), N, eps,\n",
    "            BLOCK_SIZE=BLOCK_SIZE,\n",
    "            num_warps=num_warps,\n",
    "        )\n",
    "        ctx.save_for_backward(a_arg, weight, bias, mean, rstd)\n",
    "        # the backward pass reuses the forward launch configuration\n",
    "        ctx.BLOCK_SIZE = BLOCK_SIZE\n",
    "        ctx.num_warps = num_warps\n",
    "        ctx.eps = eps\n",
    "        return out\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, dout):\n",
    "        \"\"\"Return gradients for (a, normalized_shape, weight, bias, eps); non-tensor inputs get None.\"\"\"\n",
    "        a_arg, weight, bias, mean, rstd = ctx.saved_tensors\n",
    "        M, N = a_arg.shape\n",
    "        # The kernels address dout with a_arg's row stride (and the DW/DB kernel\n",
    "        # as dense row-major), so make it a contiguous (M, N) matrix.\n",
    "        dout_arg = dout.reshape(-1, N).contiguous()\n",
    "        # allocate outputs; da is contiguous, matching the row stride used below\n",
    "        da = torch.empty(dout.shape, dtype=dout.dtype, device=dout.device)\n",
    "        dweight = torch.empty((N,), dtype=weight.dtype, device=weight.device)\n",
    "        dbias = torch.empty((N,), dtype=bias.dtype, device=bias.device)\n",
    "        # input gradient, reusing the forward launch heuristics\n",
    "        _layer_norm_bwd_dx_fused[(M,)](\n",
    "            da,\n",
    "            dout_arg,\n",
    "            a_arg,\n",
    "            weight,\n",
    "            mean, rstd,\n",
    "            a_arg.stride(0), M, N,\n",
    "            ctx.eps,\n",
    "            BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n",
    "            num_warps=ctx.num_warps,\n",
    "        )\n",
    "        # weight / bias gradients: column reductions over all rows, in a separate kernel\n",
    "        grid = lambda meta: [triton.cdiv(N, meta[\"BLOCK_SIZE_N\"])]\n",
    "        _layer_norm_bwd_dwdb[grid](\n",
    "            a_arg, dout_arg,\n",
    "            mean, rstd,\n",
    "            dweight,\n",
    "            dbias,\n",
    "            M,\n",
    "            N,\n",
    "            BLOCK_SIZE_M=32,\n",
    "            BLOCK_SIZE_N=128,\n",
    "        )\n",
    "        # one gradient per forward input\n",
    "        return da, None, dweight, dbias, None\n",
    "\n",
    "\n",
    "def layer_norm(a, normalized_shape, weight, bias, eps):\n",
    "    \"\"\"Functional entry point for the Triton LayerNorm (see `LayerNorm.forward`).\"\"\"\n",
    "    return LayerNorm.apply(a, normalized_shape, weight, bias, eps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correctness test\n",
    "\n",
    "`test_layer_norm` compares the output and the `x`, `weight` and `bias` gradients of the Triton implementation against `torch.nn.functional.layer_norm` on random data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n",
    "    \"\"\"Compare forward output and dx/dw/db against PyTorch for an (M, N) input of `dtype` on `device`.\"\"\"\n",
    "    torch.manual_seed(0)\n",
    "    # create data\n",
    "    x_shape = (M, N)\n",
    "    w_shape = (x_shape[-1], )\n",
    "    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)\n",
    "    dy = .1 * torch.randn_like(x)\n",
    "    x.requires_grad_(True)\n",
    "    # forward pass\n",
    "    y_tri = layer_norm(x, w_shape, weight, bias, eps)\n",
    "    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n",
    "    # backward pass (triton)\n",
    "    y_tri.backward(dy, retain_graph=True)\n",
    "    dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]\n",
    "    x.grad, weight.grad, bias.grad = None, None, None\n",
    "    # backward pass (torch)\n",
    "    y_ref.backward(dy, retain_graph=True)\n",
    "    dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]\n",
    "    # compare; dw/db are checked with a looser tolerance (decimal=1)\n",
    "    triton.testing.assert_almost_equal(y_tri, y_ref)\n",
    "    triton.testing.assert_almost_equal(dx_tri, dx_ref)\n",
    "    triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n",
    "    triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Set to True to validate the Triton kernels against PyTorch before benchmarking.\n",
    "RUN_CORRECTNESS_TEST = False\n",
    "if RUN_CORRECTNESS_TEST:\n",
    "    test_layer_norm(1151, 8192, torch.float16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark\n",
    "\n",
    "Throughput in GB/s for a `4096 x N` float16 input. In `forward` mode the memory traffic is counted as two input-sized tensors; in `backward` mode as three."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "@triton.testing.perf_report(\n",
    "    triton.testing.Benchmark(\n",
    "        x_names=['N'],\n",
    "        x_vals=[512 * i for i in range(2, 32)],\n",
    "        line_arg='provider',\n",
    "        line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),\n",
    "        line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),\n",
    "        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],\n",
    "        ylabel='GB/s',\n",
    "        plot_name='layer-norm',\n",
    "        args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}\n",
    "    )\n",
    ")\n",
    "def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):\n",
    "    \"\"\"Time `provider`'s LayerNorm on an (M, N) input and return throughput in GB/s.\n",
    "\n",
    "    Returns GB/s for the (ms, max_ms, min_ms) timings from `triton.testing.do_bench`,\n",
    "    so the slower timing maps to the lower throughput bound.\n",
    "    \"\"\"\n",
    "    # create data\n",
    "    x_shape = (M, N)\n",
    "    w_shape = (x_shape[-1], )\n",
    "    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)\n",
    "    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)\n",
    "    dy = .1 * torch.randn_like(x)\n",
    "    x.requires_grad_(True)\n",
    "    # utility functions\n",
    "    if provider == 'triton':\n",
    "        y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)\n",
    "    if provider == 'torch':\n",
    "        y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)\n",
    "    if provider == 'apex':\n",
    "        apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)\n",
    "        y_fwd = lambda: apex_layer_norm(x)\n",
    "    # forward pass\n",
    "    if mode == 'forward':\n",
    "        gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6\n",
    "        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)\n",
    "    # backward pass\n",
    "    if mode == 'backward':\n",
    "        gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6\n",
    "        y = y_fwd()\n",
    "        ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),\n",
    "                                                     grad_to_none=[x], rep=500)\n",
    "    return gbps(ms), gbps(max_ms), gbps(min_ms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "bench_layer_norm.run(save_path='.', print_data=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}