Files
triton/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb

54 lines
12 KiB
Plaintext
Raw Normal View History

2022-06-05 21:05:02 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Layer Normalization\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n\nimport triton\nimport triton.language as tl\n\ntry:\n # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it\n # should not be added to extras_require in setup.py.\n import apex\n HAS_APEX = True\nexcept ModuleNotFoundError:\n HAS_APEX = False\n\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n Out,\n A,\n Weight,\n Bias,\n Mean, Rstd,\n stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n # position of elements processed by this program\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n # compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # write-back mean/rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # multiply by weight and add bias\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n # # write-back\n tl.store(Out + cols, out, mask=mask)\n\n# Backward pass (DA + partial DW + partial DB)\n\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n _DA,\n _DOut,\n _A,\n Weight,\n Mean, Rstd,\n stride, NumRows, NumCols, eps,\n BLOCK_SIZE_N: tl.constexpr,\n):\n # position of elements processed by this program\n pid = tl.program_id(0)\n row = pid\n A = _A + row * stride\n DOut = _DOut + row * stride\n DA = _DA + row * stride\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # load data to SRAM\n _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n _mean1 += a_hat * wdout\n _mean2 += wdout\n mean1 = tl.sum(_mean1, axis=0) / NumCols\n mean2 = 0.\n mean2 = tl.sum(_mean2, axis=0) / NumCols\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n da = (wdout - (a_hat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DA + cols, da, mask=mask)\n\n\n# Backward pass (total DW + total DB)\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n A, DOut,\n Mean, Var,\n DW,\n DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i i
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}