[TUTORIALS] Layer norm tutorial now uses residency control (#510)
This commit is contained in:
@@ -3,11 +3,9 @@ Layer Normalization
|
|||||||
====================
|
====================
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
import torch
|
||||||
try:
|
try:
|
||||||
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
|
# This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
|
||||||
# should not be added to extras_require in setup.py.
|
# should not be added to extras_require in setup.py.
|
||||||
@@ -16,99 +14,113 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
HAS_APEX = False
|
HAS_APEX = False
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
# Forward Pass
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
|
def _layer_norm_fwd_fused(
|
||||||
BLOCK_SIZE: tl.constexpr):
|
Out,
|
||||||
|
A,
|
||||||
|
Weight,
|
||||||
|
Bias,
|
||||||
|
Mean, Rstd,
|
||||||
|
stride, N, eps,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
# position of elements processed by this program
|
# position of elements processed by this program
|
||||||
row = tl.program_id(0)
|
row = tl.program_id(0)
|
||||||
cols = tl.arange(0, BLOCK_SIZE)
|
Out += row * stride
|
||||||
mask = cols < N
|
A += row * stride
|
||||||
# offset data pointers to start at the row of interest
|
|
||||||
X += row * stride
|
|
||||||
Y += row * stride
|
|
||||||
# load data and cast to float32
|
|
||||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
|
||||||
# compute mean
|
# compute mean
|
||||||
mean = tl.sum(x, axis=0) / N
|
mean = 0
|
||||||
# compute std
|
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||||
xmean = tl.where(mask, x - mean, 0.)
|
for off in range(0, N, BLOCK_SIZE):
|
||||||
var = tl.sum(xmean * xmean, axis=0) / N
|
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||||
|
a = tl.load(A + cols, mask=cols<N, other=0., eviction_policy="evict_last").to(tl.float32)
|
||||||
|
_mean += a
|
||||||
|
mean = tl.sum(_mean, axis = 0) / N
|
||||||
|
# compute variance
|
||||||
|
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||||
|
for off in range(0, N, BLOCK_SIZE):
|
||||||
|
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||||
|
a = tl.load(A + cols, mask=cols<N, other=0., eviction_policy="evict_last").to(tl.float32)
|
||||||
|
a = tl.where(cols<N, a - mean, 0.)
|
||||||
|
_var += a * a
|
||||||
|
var = tl.sum(_var, axis = 0) / N
|
||||||
rstd = 1 / tl.sqrt(var + eps)
|
rstd = 1 / tl.sqrt(var + eps)
|
||||||
xhat = xmean * rstd
|
|
||||||
# write-back mean/rstd
|
# write-back mean/rstd
|
||||||
tl.store(M + row, mean)
|
tl.store(Mean + row, mean)
|
||||||
tl.store(V + row, rstd)
|
tl.store(Rstd + row, rstd)
|
||||||
# multiply by weight and add bias
|
# multiply by weight and add bias
|
||||||
w = tl.load(W + cols, mask=mask)
|
for off in range(0, N, BLOCK_SIZE):
|
||||||
b = tl.load(B + cols, mask=mask)
|
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||||
y = xhat * w + b
|
|
||||||
# write-back
|
|
||||||
tl.store(Y + cols, y, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
# Backward pass (DX + partial DW + partial DB)
|
|
||||||
@triton.jit
|
|
||||||
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
|
|
||||||
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
|
|
||||||
# position of elements processed by this program
|
|
||||||
row = tl.program_id(0)
|
|
||||||
cols = tl.arange(0, BLOCK_SIZE_N)
|
|
||||||
mask = cols < N
|
mask = cols < N
|
||||||
# offset data pointers to start at the row of interest
|
weight = tl.load(Weight + cols, mask=mask)
|
||||||
X += row * stride
|
bias = tl.load(Bias + cols, mask=mask)
|
||||||
DY += row * stride
|
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
|
||||||
DX += row * stride
|
a_hat = (a - mean) * rstd
|
||||||
# offset locks and weight/bias gradient pointer
|
out = a_hat * weight + bias
|
||||||
# each kernel instance accumulates partial sums for
|
# # write-back
|
||||||
# DW and DB into one of GROUP_SIZE_M independent buffers
|
tl.store(Out + cols, out, mask=mask)
|
||||||
# these buffers stay in the L2, which allow this kernel
|
|
||||||
# to be fast
|
# Backward pass (DA + partial DW + partial DB)
|
||||||
lock_id = row % GROUP_SIZE_M
|
@triton.jit
|
||||||
Lock += lock_id
|
def _layer_norm_bwd_dx_fused(
|
||||||
Count = Lock + GROUP_SIZE_M
|
_DA,
|
||||||
DW = DW + lock_id * N + cols
|
_DOut,
|
||||||
DB = DB + lock_id * N + cols
|
_A,
|
||||||
|
Weight,
|
||||||
|
Mean, Rstd,
|
||||||
|
stride, NumRows, NumCols, eps,
|
||||||
|
BLOCK_SIZE_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
# position of elements processed by this program
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
row = pid
|
||||||
|
A = _A + row*stride
|
||||||
|
DOut = _DOut + row*stride
|
||||||
|
DA = _DA + row*stride
|
||||||
|
mean = tl.load(Mean + row)
|
||||||
|
rstd = tl.load(Rstd + row)
|
||||||
# load data to SRAM
|
# load data to SRAM
|
||||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
|
||||||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
|
||||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
for off in range(0, NumCols, BLOCK_SIZE_N):
|
||||||
mean = tl.load(M + row)
|
cols = off + tl.arange(0, BLOCK_SIZE_N)
|
||||||
rstd = tl.load(V + row)
|
mask = cols < NumCols
|
||||||
# compute dx
|
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
|
||||||
xhat = (x - mean) * rstd
|
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
|
||||||
wdy = w * dy
|
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
|
||||||
xhat = tl.where(mask, xhat, 0.)
|
a_hat = (a - mean) * rstd
|
||||||
wdy = tl.where(mask, wdy, 0.)
|
wdout = weight * dout
|
||||||
mean1 = tl.sum(xhat * wdy, axis=0) / N
|
_mean1 += a_hat * wdout
|
||||||
mean2 = tl.sum(wdy, axis=0) / N
|
_mean2 += wdout
|
||||||
dx = (wdy - (xhat * mean1 + mean2)) * rstd
|
mean1 = tl.sum(_mean1, axis=0) / NumCols
|
||||||
|
mean2 = 0.
|
||||||
|
mean2 = tl.sum(_mean2, axis=0) / NumCols
|
||||||
|
for off in range(0, NumCols, BLOCK_SIZE_N):
|
||||||
|
cols = off + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
mask = cols < NumCols
|
||||||
|
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
|
||||||
|
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
|
||||||
|
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
|
||||||
|
a_hat = (a - mean) * rstd
|
||||||
|
wdout = weight * dout
|
||||||
|
da = (wdout - (a_hat * mean1 + mean2)) * rstd
|
||||||
# write-back dx
|
# write-back dx
|
||||||
tl.store(DX + cols, dx, mask=mask)
|
tl.store(DA + cols, da, mask=mask)
|
||||||
# accumulate partial sums for dw/db
|
|
||||||
partial_dw = (dy * xhat).to(w.dtype)
|
|
||||||
partial_db = (dy).to(w.dtype)
|
|
||||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
|
||||||
pass
|
|
||||||
count = tl.load(Count)
|
|
||||||
# first store doesn't accumulate
|
|
||||||
if count == 0:
|
|
||||||
tl.atomic_xchg(Count, 1)
|
|
||||||
else:
|
|
||||||
partial_dw += tl.load(DW, mask=mask)
|
|
||||||
partial_db += tl.load(DB, mask=mask)
|
|
||||||
tl.store(DW, partial_dw, mask=mask)
|
|
||||||
tl.store(DB, partial_db, mask=mask)
|
|
||||||
# release lock
|
|
||||||
tl.atomic_xchg(Lock, 0)
|
|
||||||
|
|
||||||
# Backward pass (total DW + total DB)
|
# Backward pass (total DW + total DB)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
|
def _layer_norm_bwd_dwdb(
|
||||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
|
A, DOut,
|
||||||
|
Mean, Var,
|
||||||
|
DW,
|
||||||
|
DB,
|
||||||
|
M, N,
|
||||||
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
|
BLOCK_SIZE_N: tl.constexpr,
|
||||||
|
):
|
||||||
pid = tl.program_id(0)
|
pid = tl.program_id(0)
|
||||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
@@ -117,79 +129,112 @@ def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
|
|||||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||||
offs = rows[:, None] * N + cols[None, :]
|
offs = rows[:, None] * N + cols[None, :]
|
||||||
dw += tl.load(DW + offs, mask=mask, other=0.)
|
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||||
db += tl.load(DB + offs, mask=mask, other=0.)
|
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||||
|
mean = tl.load(Mean + rows, mask=rows<M, other=0.)
|
||||||
|
rstd = tl.load(Var + rows, mask=rows<M, other=0.)
|
||||||
|
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||||
|
dw += dout * a_hat
|
||||||
|
db += dout
|
||||||
sum_dw = tl.sum(dw, axis=0)
|
sum_dw = tl.sum(dw, axis=0)
|
||||||
sum_db = tl.sum(db, axis=0)
|
sum_db = tl.sum(db, axis=0)
|
||||||
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
|
tl.store(DW + cols, sum_dw, mask=cols < N)
|
||||||
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
|
tl.store(DB + cols, sum_db, mask=cols < N)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(torch.autograd.Function):
|
class LayerNorm(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x, normalized_shape, weight, bias, eps):
|
def forward(ctx, a, normalized_shape, weight, bias, eps):
|
||||||
# allocate output
|
# allocate output
|
||||||
y = torch.empty_like(x)
|
out = torch.empty_like(a)
|
||||||
# reshape input data into 2D tensor
|
# reshape input data into 2D tensor
|
||||||
x_arg = x.reshape(-1, x.shape[-1])
|
a_arg = a.reshape(-1, a.shape[-1])
|
||||||
M, N = x_arg.shape
|
M, N = a_arg.shape
|
||||||
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
|
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||||
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
|
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||||
# Less than 64KB per feature: enqueue fused kernel
|
# Less than 64KB per feature: enqueue fused kernel
|
||||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
MAX_FUSED_SIZE = 65536 // a.element_size()
|
||||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||||
if N > BLOCK_SIZE:
|
BLOCK_SIZE = max(BLOCK_SIZE, 128)
|
||||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
|
||||||
# heuristics for number of warps
|
# heuristics for number of warps
|
||||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||||
# enqueue kernel
|
_layer_norm_fwd_fused[(M,)](
|
||||||
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
|
out,
|
||||||
x_arg.stride(0), N, eps,
|
a_arg,
|
||||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
|
weight,
|
||||||
ctx.save_for_backward(x, weight, bias, mean, rstd)
|
bias,
|
||||||
|
mean, rstd,
|
||||||
|
a_arg.stride(0), N, eps,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
ctx.save_for_backward(
|
||||||
|
a, weight, bias, mean, rstd,
|
||||||
|
)
|
||||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||||
ctx.num_warps = num_warps
|
ctx.num_warps = num_warps
|
||||||
ctx.eps = eps
|
ctx.eps = eps
|
||||||
return y
|
if hasattr(bias, "config"):
|
||||||
|
assert bias.config.grad_scale_name == weight.config.grad_scale_name
|
||||||
|
grad_scale_name = bias.config.grad_scale_name
|
||||||
|
else:
|
||||||
|
grad_scale_name = None
|
||||||
|
ctx.grad_scale_gain_bias_name = grad_scale_name
|
||||||
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, dy):
|
def backward(ctx, dout):
|
||||||
x, w, b, m, v = ctx.saved_tensors
|
assert dout.is_contiguous()
|
||||||
|
a, weight, bias, mean, var = ctx.saved_tensors
|
||||||
# heuristics for amount of parallel reduction stream for DG/DB
|
# heuristics for amount of parallel reduction stream for DG/DB
|
||||||
N = w.shape[0]
|
N = weight.shape[0]
|
||||||
GROUP_SIZE_M = 64
|
|
||||||
if N <= 8192: GROUP_SIZE_M = 96
|
|
||||||
if N <= 4096: GROUP_SIZE_M = 128
|
|
||||||
if N <= 1024: GROUP_SIZE_M = 256
|
|
||||||
# allocate output
|
# allocate output
|
||||||
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
|
da = torch.empty_like(dout)
|
||||||
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
|
|
||||||
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
|
|
||||||
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
|
|
||||||
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
|
|
||||||
dx = torch.empty_like(dy)
|
|
||||||
# enqueue kernel using forward pass heuristics
|
# enqueue kernel using forward pass heuristics
|
||||||
# also compute partial sums for DW and DB
|
# also compute partial sums for DW and DB
|
||||||
x_arg = x.reshape(-1, x.shape[-1])
|
x_arg = a.reshape(-1, a.shape[-1])
|
||||||
M, N = x_arg.shape
|
M, N = x_arg.shape
|
||||||
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
|
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
|
||||||
x_arg.stride(0), N, ctx.eps,
|
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
|
||||||
|
_layer_norm_bwd_dx_fused[(M,)](
|
||||||
|
da,
|
||||||
|
dout,
|
||||||
|
a,
|
||||||
|
weight,
|
||||||
|
mean, var,
|
||||||
|
x_arg.stride(0), M, N,
|
||||||
|
ctx.eps,
|
||||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||||
GROUP_SIZE_M=GROUP_SIZE_M,
|
num_warps=ctx.num_warps,
|
||||||
num_warps=ctx.num_warps)
|
)
|
||||||
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
|
|
||||||
# accumulate partial sums in separate kernel
|
# accumulate partial sums in separate kernel
|
||||||
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
|
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
|
||||||
|
_layer_norm_bwd_dwdb[grid](
|
||||||
|
a, dout,
|
||||||
|
mean, var,
|
||||||
|
dweight,
|
||||||
|
dbias,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
BLOCK_SIZE_M=32,
|
BLOCK_SIZE_M=32,
|
||||||
BLOCK_SIZE_N=128)
|
BLOCK_SIZE_N=128,
|
||||||
return dx, None, dw, db, None
|
)
|
||||||
|
return (da, None, dweight, dbias, None, None,
|
||||||
|
None, None, None, None,
|
||||||
|
None,
|
||||||
|
None, None, None,
|
||||||
|
None,
|
||||||
|
None, None, None,
|
||||||
|
None, None, None,
|
||||||
|
None, None, None)
|
||||||
|
|
||||||
|
|
||||||
layer_norm = LayerNorm.apply
|
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||||
|
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
|
||||||
|
|
||||||
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
||||||
|
torch.manual_seed(0)
|
||||||
# create data
|
# create data
|
||||||
x_shape = (M, N)
|
x_shape = (M, N)
|
||||||
w_shape = (x_shape[-1], )
|
w_shape = (x_shape[-1], )
|
||||||
@@ -224,11 +269,11 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
|||||||
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
|
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
|
||||||
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
|
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
|
||||||
ylabel='GB/s',
|
ylabel='GB/s',
|
||||||
plot_name='layer-norm-backward',
|
plot_name='layer-norm',
|
||||||
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
|
args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
|
def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
|
||||||
# create data
|
# create data
|
||||||
x_shape = (M, N)
|
x_shape = (M, N)
|
||||||
w_shape = (x_shape[-1], )
|
w_shape = (x_shape[-1], )
|
||||||
@@ -258,4 +303,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
|
|||||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
# test_layer_norm(1151, 8192, torch.float16)
|
||||||
bench_layer_norm.run(save_path='.', print_data=True)
|
bench_layer_norm.run(save_path='.', print_data=True)
|
||||||
|
Reference in New Issue
Block a user