{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n# Layer Normalization\n"
   ]
  },
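  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the fused forward kernel in the next cell normalizes each row $x \\in \\mathbb{R}^{N}$ of the input and applies a learned affine transform\n\n$$ y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\mathrm{Var}[x] + \\epsilon}} \\cdot w + b, $$\n\nwhile the backward kernel uses the matching gradient with respect to $x$. Writing $\\hat{x} = (x - \\mathrm{E}[x]) \\cdot \\mathrm{rstd}$ and $g = w \\odot dy$,\n\n$$ dx = \\big( g - (\\hat{x} \\cdot \\overline{\\hat{x} \\odot g} + \\overline{g}) \\big) \\cdot \\mathrm{rstd}, $$\n\nwhere $\\overline{\\,\\cdot\\,}$ denotes the mean over the $N$ elements of a row. These formulas restate the arithmetic in `_layer_norm_fwd_fused` and `_layer_norm_bwd_dx_fused` below."
   ]
  },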
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
"import torch\n\nimport triton\nimport triton.language as tl\n\ntry:\n # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it\n # should not be added to extras_require in setup.py.\n import apex\n HAS_APEX = True\nexcept ModuleNotFoundError:\n HAS_APEX = False\n\n\n# Forward Pass\n@triton.jit\ndef _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,\n BLOCK_SIZE: tl.constexpr):\n # position of elements processed by this program\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n # offset data pointers to start at the row of interest\n X += row * stride\n Y += row * stride\n # load data and cast to float32\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n # compute mean\n mean = tl.sum(x, axis=0) / N\n # compute std\n xmean = tl.where(mask, x - mean, 0.)\n var = tl.sum(xmean * xmean, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n xhat = xmean * rstd\n # write-back mean/rstd\n tl.store(M + row, mean)\n tl.store(V + row, rstd)\n # multiply by weight and add bias\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n y = xhat * w + b\n # write-back\n tl.store(Y + cols, y, mask=mask)\n\n\n# Backward pass (DX + partial DW + partial DB)\n@triton.jit\ndef _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,\n GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n # position of elements processed by this program\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n # offset data pointers to start at the row of interest\n X += row * stride\n DY += row * stride\n DX += row * stride\n # offset locks and weight/bias gradient pointer\n # each kernel instance accumulates partial sums for\n # DW and DB into one of GROUP_SIZE_M independent buffers\n # these buffers stay in the L2, which allow this kernel\n # to be fast\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n DW = DW + lock_id * N + cols\n DB = DB + lock_id * N + cols\n # load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n mean = tl.load(M + row)\n rstd = tl.load(V + row)\n # compute dx\n xhat = (x - mean) * rstd\n wdy = w * dy\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n mean1 = tl.sum(xhat * wdy, axis=0) / N\n mean2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DX + cols, dx, mask=mask)\n # accumulate partial sums for dw/db\n partial_dw = (dy * xhat).to(w.dtype)\n partial_db = (dy).to(w.dtype)\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n # first store doesn't accumulate\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(DW, mask=mask)\n partial_db += tl.load(DB, mask=mask)\n tl.store(DW, partial_dw, mask=mask)\n tl.store(DB, partial_db, mask=mask)\n # release lock\n tl.atomic_xchg(Lock, 0)\n\n# Backward pass (total DW + total DB)\n\n\n@triton.jit\ndef _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + 
cols[None, :]\n dw += tl.load(DW + offs, mask=mask, other=0.)\n db += tl.load(DB + offs, mask=mask, other=0.)\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n tl.store(FINAL_DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device='cuda')\n rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, b, m, v = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DG/DB\n N = w.shape[0]\n GROUP_SIZE_M = 64\n if N <= 8192: GROUP_SIZE_M = 96\n if N <= 4096: GROUP_SIZE_M = 128\n if N <= 1024: GROUP_SIZE_M = 256\n # allocate output\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')\n _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)\n db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,\n x_arg.stride(0), N, ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n GROUP_SIZE_M=GROUP_SIZE_M,\n num_warps=ctx.num_warps)\n grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n # accumulate partial sums in separate kernel\n _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128)\n return dx, None, dw, db, None\n\n\nlayer_norm = LayerNorm.apply\n\n\ndef test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # forward pass\n y_tri = layer_norm(x, w_shape, weight, bias, eps)\n y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n # backward pass (triton)\n y_tri.backward(dy, retain_graph=True)\n dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]\n x.grad, weight.grad, bias.grad = None, None, None\n # backward pass (torch)\n y_ref.backward(dy, retain_graph=True)\n dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]\n # compare\n 
triton.testing.assert_almost_equal(y_tri, y_ref)\n triton.testing.assert_almost_equal(dx_tri, dx_ref)\n triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)\n\n\n@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'],\n x_vals=[512 * i for i in range(2, 32)],\n line_arg='provider',\n line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),\n line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),\n styles=[('blue', '-'), ('green', '-'), ('orange', '-')],\n ylabel='GB/s',\n plot_name='layer-norm-backward',\n args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}\n )\n)\ndef bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # utility functions\n if provider == 'triton':\n y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'torch':\n y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'apex':\n apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)\n y_fwd = lambda: apex_layer_norm(x)\n # forward pass\n if mode == 'forward':\n gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6\n ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)\n # backward pass\n if mode == 'backward':\n gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6\n y = y_fwd()\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),\n grad_to_none=[x], rep=500)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbench_layer_norm.run(save_path='.', print_data=True)"
   ]
  },
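  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check: the next cell is a minimal usage sketch. It assumes a CUDA-capable GPU (the kernels and `test_layer_norm` allocate tensors on `'cuda'`); it runs the correctness test defined above on a single shape, then calls the `layer_norm` op directly and backpropagates through it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Minimal usage sketch (assumes a CUDA-capable GPU).\n# Run the correctness test defined above on a single shape.\ntest_layer_norm(1151, 8192, torch.float16)\n\n# Call the custom op directly: layer_norm(x, normalized_shape, weight, bias, eps).\nx = torch.randn(32, 1024, dtype=torch.float16, device='cuda', requires_grad=True)\nw = torch.rand(1024, dtype=torch.float16, device='cuda', requires_grad=True)\nb = torch.rand(1024, dtype=torch.float16, device='cuda', requires_grad=True)\ny = layer_norm(x, (1024,), w, b, 1e-5)\n# Backpropagate through the Triton kernels; gradients land in x.grad, w.grad, b.grad.\ny.backward(torch.randn_like(y))\nprint(y.shape, x.grad.shape, w.grad.shape, b.grad.shape)"
   ]
  }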
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}