[TUTORIALS] Layer norm tutorial now uses residency control (#510)
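"Residency control" here means Triton's cache-eviction hints on tl.load: the rewritten kernels tag loads of data that a later pass over the same row will re-read with eviction_policy="evict_last" (keep resident), and tag the final streaming read with eviction_policy="evict_first" (evict eagerly). A minimal sketch of the pattern, separate from the tutorial code below (the row_sums kernel and its shapes are illustrative only, not part of this commit):

import torch
import triton
import triton.language as tl

@triton.jit
def row_sums(X, Out, N, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        # evict_last: hint to keep these cache lines around, since another
        # pass over the same row (as in the layer-norm kernels) would reuse them
        x = tl.load(X + row * N + cols, mask=cols < N, other=0.,
                    eviction_policy="evict_last").to(tl.float32)
        acc += x
    tl.store(Out + row, tl.sum(acc, axis=0))

x = torch.randn(8, 4096, device="cuda")
out = torch.empty(8, device="cuda")
row_sums[(8,)](x, out, 4096, BLOCK_SIZE=1024)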
@@ -3,11 +3,9 @@ Layer Normalization
 ====================
 """
 
-import torch
-
 import triton
 import triton.language as tl
-
+import torch
 try:
     # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
    # should not be added to extras_require in setup.py.
@@ -16,99 +14,113 @@ try:
 except ModuleNotFoundError:
     HAS_APEX = False
 
+# fmt: off
+
 # Forward Pass
 @triton.jit
-def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
-                          BLOCK_SIZE: tl.constexpr):
+def _layer_norm_fwd_fused(
+    Out,
+    A,
+    Weight,
+    Bias,
+    Mean, Rstd,
+    stride, N, eps,
+    BLOCK_SIZE: tl.constexpr,
+):
     # position of elements processed by this program
     row = tl.program_id(0)
-    cols = tl.arange(0, BLOCK_SIZE)
-    mask = cols < N
-    # offset data pointers to start at the row of interest
-    X += row * stride
-    Y += row * stride
-    # load data and cast to float32
-    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
+    Out += row * stride
+    A += row * stride
     # compute mean
-    mean = tl.sum(x, axis=0) / N
-    # compute std
-    xmean = tl.where(mask, x - mean, 0.)
-    var = tl.sum(xmean * xmean, axis=0) / N
+    mean = 0
+    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
+        _mean += a
+    mean = tl.sum(_mean, axis=0) / N
+    # compute variance
+    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
+        a = tl.where(cols < N, a - mean, 0.)
+        _var += a * a
+    var = tl.sum(_var, axis=0) / N
     rstd = 1 / tl.sqrt(var + eps)
-    xhat = xmean * rstd
     # write-back mean/rstd
-    tl.store(M + row, mean)
-    tl.store(V + row, rstd)
+    tl.store(Mean + row, mean)
+    tl.store(Rstd + row, rstd)
     # multiply by weight and add bias
-    w = tl.load(W + cols, mask=mask)
-    b = tl.load(B + cols, mask=mask)
-    y = xhat * w + b
-    # write-back
-    tl.store(Y + cols, y, mask=mask)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        weight = tl.load(Weight + cols, mask=mask)
+        bias = tl.load(Bias + cols, mask=mask)
+        a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
+        a_hat = (a - mean) * rstd
+        out = a_hat * weight + bias
+        # write-back
+        tl.store(Out + cols, out, mask=mask)
 
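For reference, the row-wise math that the rewritten forward kernel above computes in three block-strided passes is equivalent to the following eager PyTorch (a sketch written for this note, not code from the commit; a is an (M, N) tensor):

import torch

def layer_norm_fwd_reference(a, weight, bias, eps):
    # pass 1: mean; pass 2: variance; pass 3: normalize, scale, and shift
    mean = a.mean(dim=1, keepdim=True)
    var = ((a - mean) ** 2).mean(dim=1, keepdim=True)
    rstd = 1.0 / torch.sqrt(var + eps)
    a_hat = (a - mean) * rstd
    return a_hat * weight + bias, mean.squeeze(1), rstd.squeeze(1)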
 
-# Backward pass (DX + partial DW + partial DB)
-@triton.jit
-def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
-                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
-    # position of elements processed by this program
-    row = tl.program_id(0)
-    cols = tl.arange(0, BLOCK_SIZE_N)
-    mask = cols < N
-    # offset data pointers to start at the row of interest
-    X += row * stride
-    DY += row * stride
-    DX += row * stride
-    # offset locks and weight/bias gradient pointer
-    # each kernel instance accumulates partial sums for
-    # DW and DB into one of GROUP_SIZE_M independent buffers
-    # these buffers stay in the L2, which allow this kernel
-    # to be fast
-    lock_id = row % GROUP_SIZE_M
-    Lock += lock_id
-    Count = Lock + GROUP_SIZE_M
-    DW = DW + lock_id * N + cols
-    DB = DB + lock_id * N + cols
-    # load data to SRAM
-    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
-    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
-    w = tl.load(W + cols, mask=mask).to(tl.float32)
-    mean = tl.load(M + row)
-    rstd = tl.load(V + row)
-    # compute dx
-    xhat = (x - mean) * rstd
-    wdy = w * dy
-    xhat = tl.where(mask, xhat, 0.)
-    wdy = tl.where(mask, wdy, 0.)
-    mean1 = tl.sum(xhat * wdy, axis=0) / N
-    mean2 = tl.sum(wdy, axis=0) / N
-    dx = (wdy - (xhat * mean1 + mean2)) * rstd
-    # write-back dx
-    tl.store(DX + cols, dx, mask=mask)
-    # accumulate partial sums for dw/db
-    partial_dw = (dy * xhat).to(w.dtype)
-    partial_db = (dy).to(w.dtype)
-    while tl.atomic_cas(Lock, 0, 1) == 1:
-        pass
-    count = tl.load(Count)
-    # first store doesn't accumulate
-    if count == 0:
-        tl.atomic_xchg(Count, 1)
-    else:
-        partial_dw += tl.load(DW, mask=mask)
-        partial_db += tl.load(DB, mask=mask)
-    tl.store(DW, partial_dw, mask=mask)
-    tl.store(DB, partial_db, mask=mask)
-    # release lock
-    tl.atomic_xchg(Lock, 0)
+# Backward pass (DA + partial DW + partial DB)
+@triton.jit
+def _layer_norm_bwd_dx_fused(
+    _DA,
+    _DOut,
+    _A,
+    Weight,
+    Mean, Rstd,
+    stride, NumRows, NumCols, eps,
+    BLOCK_SIZE_N: tl.constexpr,
+):
+    # position of elements processed by this program
+    pid = tl.program_id(0)
+    row = pid
+    A = _A + row * stride
+    DOut = _DOut + row * stride
+    DA = _DA + row * stride
+    mean = tl.load(Mean + row)
+    rstd = tl.load(Rstd + row)
+    # load data to SRAM
+    _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
+    _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
+    for off in range(0, NumCols, BLOCK_SIZE_N):
+        cols = off + tl.arange(0, BLOCK_SIZE_N)
+        mask = cols < NumCols
+        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
+        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
+        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
+        a_hat = (a - mean) * rstd
+        wdout = weight * dout
+        _mean1 += a_hat * wdout
+        _mean2 += wdout
+    mean1 = tl.sum(_mean1, axis=0) / NumCols
+    mean2 = 0.
+    mean2 = tl.sum(_mean2, axis=0) / NumCols
+    for off in range(0, NumCols, BLOCK_SIZE_N):
+        cols = off + tl.arange(0, BLOCK_SIZE_N)
+        mask = cols < NumCols
+        a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
+        dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
+        weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
+        a_hat = (a - mean) * rstd
+        wdout = weight * dout
+        da = (wdout - (a_hat * mean1 + mean2)) * rstd
+        tl.store(DA + cols, da, mask=mask)
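The two passes of the new _layer_norm_bwd_dx_fused above implement the standard layer-norm input gradient: one pass accumulates the two row means, the second applies the gradient formula. An eager PyTorch equivalent of the same row-wise math (a sketch written for this note, not code from the commit):

import torch

def layer_norm_bwd_da_reference(a, dout, weight, mean, rstd):
    # a, dout: (M, N); weight: (N,); mean, rstd: (M,)
    a_hat = (a - mean[:, None]) * rstd[:, None]
    wdout = weight * dout
    # pass 1 of the kernel: the two row means
    mean1 = (a_hat * wdout).mean(dim=1, keepdim=True)
    mean2 = wdout.mean(dim=1, keepdim=True)
    # pass 2 of the kernel: the gradient itself
    return (wdout - (a_hat * mean1 + mean2)) * rstd[:, None]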
 
 
 # Backward pass (total DW + total DB)
 @triton.jit
-def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
-                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
+def _layer_norm_bwd_dwdb(
+    A, DOut,
+    Mean, Var,
+    DW,
+    DB,
+    M, N,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+):
     pid = tl.program_id(0)
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
@@ -117,79 +129,112 @@ def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
         rows = i + tl.arange(0, BLOCK_SIZE_M)
         mask = (rows[:, None] < M) & (cols[None, :] < N)
         offs = rows[:, None] * N + cols[None, :]
-        dw += tl.load(DW + offs, mask=mask, other=0.)
-        db += tl.load(DB + offs, mask=mask, other=0.)
+        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+        a_hat = (a - mean[:, None]) * rstd[:, None]
+        dw += dout * a_hat
+        db += dout
     sum_dw = tl.sum(dw, axis=0)
     sum_db = tl.sum(db, axis=0)
-    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
-    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
+    tl.store(DW + cols, sum_dw, mask=cols < N)
+    tl.store(DB + cols, sum_db, mask=cols < N)
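Note the design change in _layer_norm_bwd_dwdb: instead of reducing lock-protected partial DW/DB buffers, it now recomputes a_hat from A, Mean, and Var and reduces dout * a_hat and dout over rows directly, which removes the locks and the GROUP_SIZE_M staging buffers at the cost of some recomputation. In eager PyTorch the same reduction is simply (a sketch written for this note):

import torch

def layer_norm_bwd_dwdb_reference(a, dout, mean, rstd):
    # a, dout: (M, N); mean, rstd: (M,); returns (dweight, dbias), each (N,)
    a_hat = (a - mean[:, None]) * rstd[:, None]
    return (dout * a_hat).sum(dim=0), dout.sum(dim=0)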
 
 
 class LayerNorm(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x, normalized_shape, weight, bias, eps):
+    def forward(ctx, a, normalized_shape, weight, bias, eps):
         # allocate output
-        y = torch.empty_like(x)
+        out = torch.empty_like(a)
         # reshape input data into 2D tensor
-        x_arg = x.reshape(-1, x.shape[-1])
-        M, N = x_arg.shape
-        mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
-        rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
+        a_arg = a.reshape(-1, a.shape[-1])
+        M, N = a_arg.shape
+        mean = torch.empty((M,), dtype=torch.float32, device="cuda")
+        rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
         # Less than 64KB per feature: enqueue fused kernel
-        MAX_FUSED_SIZE = 65536 // x.element_size()
+        MAX_FUSED_SIZE = 65536 // a.element_size()
         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
-        if N > BLOCK_SIZE:
-            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+        BLOCK_SIZE = max(BLOCK_SIZE, 128)
+        BLOCK_SIZE = min(BLOCK_SIZE, 4096)
         # heuristics for number of warps
         num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
-        # enqueue kernel
-        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
-                                    x_arg.stride(0), N, eps,
-                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
-        ctx.save_for_backward(x, weight, bias, mean, rstd)
+        _layer_norm_fwd_fused[(M,)](
+            out,
+            a_arg,
+            weight,
+            bias,
+            mean, rstd,
+            a_arg.stride(0), N, eps,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+        ctx.save_for_backward(
+            a, weight, bias, mean, rstd,
+        )
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.eps = eps
-        return y
+        if hasattr(bias, "config"):
+            assert bias.config.grad_scale_name == weight.config.grad_scale_name
+            grad_scale_name = bias.config.grad_scale_name
+        else:
+            grad_scale_name = None
+        ctx.grad_scale_gain_bias_name = grad_scale_name
+        return out
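The RuntimeError for wide rows is gone because the kernels now loop over N in BLOCK_SIZE chunks; BLOCK_SIZE is simply clamped to the range [128, 4096]. A worked example of the new heuristic, assuming float16 inputs and N = 8192:

MAX_FUSED_SIZE = 65536 // 2                    # element_size() == 2  -> 32768
BLOCK_SIZE = min(MAX_FUSED_SIZE, 8192)         # next_power_of_2(8192) -> 8192
BLOCK_SIZE = max(BLOCK_SIZE, 128)              # -> 8192
BLOCK_SIZE = min(BLOCK_SIZE, 4096)             # -> 4096, i.e. two iterations per row
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)  # -> 8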
 
     @staticmethod
-    def backward(ctx, dy):
-        x, w, b, m, v = ctx.saved_tensors
+    def backward(ctx, dout):
+        assert dout.is_contiguous()
+        a, weight, bias, mean, var = ctx.saved_tensors
         # heuristics for amount of parallel reduction stream for DG/DB
-        N = w.shape[0]
-        GROUP_SIZE_M = 64
-        if N <= 8192: GROUP_SIZE_M = 96
-        if N <= 4096: GROUP_SIZE_M = 128
-        if N <= 1024: GROUP_SIZE_M = 256
+        N = weight.shape[0]
         # allocate output
-        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
-        _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
-        _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
-        dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
-        db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
-        dx = torch.empty_like(dy)
+        da = torch.empty_like(dout)
         # enqueue kernel using forward pass heuristics
         # also compute partial sums for DW and DB
-        x_arg = x.reshape(-1, x.shape[-1])
+        x_arg = a.reshape(-1, a.shape[-1])
         M, N = x_arg.shape
-        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
-                                       x_arg.stride(0), N, ctx.eps,
-                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,
-                                       GROUP_SIZE_M=GROUP_SIZE_M,
-                                       num_warps=ctx.num_warps)
-        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
+        dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
+        dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
+        _layer_norm_bwd_dx_fused[(M,)](
+            da,
+            dout,
+            a,
+            weight,
+            mean, var,
+            x_arg.stride(0), M, N,
+            ctx.eps,
+            BLOCK_SIZE_N=ctx.BLOCK_SIZE,
+            num_warps=ctx.num_warps,
+        )
         # accumulate partial sums in separate kernel
-        _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
-                                   BLOCK_SIZE_M=32,
-                                   BLOCK_SIZE_N=128)
-        return dx, None, dw, db, None
+        grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
+        _layer_norm_bwd_dwdb[grid](
+            a, dout,
+            mean, var,
+            dweight,
+            dbias,
+            M,
+            N,
+            BLOCK_SIZE_M=32,
+            BLOCK_SIZE_N=128,
+        )
+        return (da, None, dweight, dbias, None, None,
+                None, None, None, None,
+                None,
+                None, None, None,
+                None,
+                None, None, None,
+                None, None, None,
+                None, None, None)
 
 
-layer_norm = LayerNorm.apply
+def layer_norm(a, normalized_shape, weight, bias, eps):
+    return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
 
 
 def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
     torch.manual_seed(0)
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
@@ -224,11 +269,11 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
         line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
         styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
         ylabel='GB/s',
-        plot_name='layer-norm-backward',
-        args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
+        plot_name='layer-norm',
+        args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
     )
 )
-def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
+def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
@@ -258,4 +303,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
 
+# test_layer_norm(1151, 8192, torch.float16)
 bench_layer_norm.run(save_path='.', print_data=True)
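To exercise the updated entry point directly, a quick forward check against PyTorch can be appended after the tutorial code (a sketch mirroring test_layer_norm's setup; the shapes are illustrative):

x = torch.randn(4096, 1024, dtype=torch.float16, device='cuda')
w = torch.rand(1024, dtype=torch.float16, device='cuda')
b = torch.rand(1024, dtype=torch.float16, device='cuda')
y_tri = layer_norm(x, (1024,), w, b, 1e-5)
y_ref = torch.nn.functional.layer_norm(x, (1024,), w, b, 1e-5)
print((y_tri - y_ref).abs().max())  # should be within fp16 tolerance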