[TUTORIALS] Removed #noformat in layer norm tutorial
@@ -16,14 +16,13 @@ try:
 except ModuleNotFoundError:
     HAS_APEX = False
 
-# fmt: off
 
 @triton.jit
 def _layer_norm_fwd_fused(
     Out,
     A,
     Weight,
     Bias,
     Mean, Rstd,
     stride, N, eps,
     BLOCK_SIZE: tl.constexpr,
@@ -37,17 +36,17 @@ def _layer_norm_fwd_fused(
     _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(A + cols, mask=cols<N, other=0., eviction_policy="evict_last").to(tl.float32)
+        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
         _mean += a
-    mean = tl.sum(_mean, axis = 0) / N
+    mean = tl.sum(_mean, axis=0) / N
     # compute variance
     _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(A + cols, mask=cols<N, other=0., eviction_policy="evict_last").to(tl.float32)
-        a = tl.where(cols<N, a - mean, 0.)
+        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
+        a = tl.where(cols < N, a - mean, 0.)
         _var += a * a
-    var = tl.sum(_var, axis = 0) / N
+    var = tl.sum(_var, axis=0) / N
     rstd = 1 / tl.sqrt(var + eps)
     # write-back mean/rstd
     tl.store(Mean + row, mean)
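
For readers cross-checking this hunk, here is a minimal PyTorch reference of the per-row computation in _layer_norm_fwd_fused (two reduction passes over a row, then an affine transform). It is a verification sketch under assumed (M, N) input and (N,) weight/bias shapes, not the tutorial's code:

import torch

def layer_norm_fwd_reference(a, weight, bias, eps=1e-5):
    # a: (M, N) rows; weight, bias: (N,) -- mirror A, Weight, Bias above
    a32 = a.float()
    mean = a32.mean(dim=1, keepdim=True)                 # first reduction pass
    var = ((a32 - mean) ** 2).mean(dim=1, keepdim=True)  # second reduction pass
    rstd = 1.0 / torch.sqrt(var + eps)
    out = (a32 - mean) * rstd * weight.float() + bias.float()
    return out.to(a.dtype), mean.squeeze(1), rstd.squeeze(1)
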
@@ -65,22 +64,24 @@ def _layer_norm_fwd_fused(
     tl.store(Out + cols, out, mask=mask)
 
 # Backward pass (DA + partial DW + partial DB)
+
+
 @triton.jit
 def _layer_norm_bwd_dx_fused(
     _DA,
     _DOut,
     _A,
     Weight,
     Mean, Rstd,
     stride, NumRows, NumCols, eps,
     BLOCK_SIZE_N: tl.constexpr,
 ):
     # position of elements processed by this program
     pid = tl.program_id(0)
     row = pid
-    A = _A + row*stride
-    DOut = _DOut + row*stride
-    DA = _DA + row*stride
+    A = _A + row * stride
+    DOut = _DOut + row * stride
+    DA = _DA + row * stride
     mean = tl.load(Mean + row)
     rstd = tl.load(Rstd + row)
     # load data to SRAM
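
This hunk only touches the pointer arithmetic of _layer_norm_bwd_dx_fused; for context, the per-row input gradient that kernel fuses is the standard layer-norm identity. A hedged PyTorch sketch (names are illustrative, mean/rstd assumed shape (M,)):

import torch

def layer_norm_bwd_dx_reference(dout, a, weight, mean, rstd):
    # Standard per-row identity:
    #   da = rstd * (w*dout - mean(w*dout) - a_hat * mean(w*dout * a_hat))
    a_hat = (a.float() - mean[:, None]) * rstd[:, None]
    wdout = weight.float() * dout.float()
    c1 = (a_hat * wdout).mean(dim=1, keepdim=True)
    c2 = wdout.mean(dim=1, keepdim=True)
    da = (wdout - (a_hat * c1 + c2)) * rstd[:, None]
    return da.to(a.dtype)
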
@@ -117,8 +118,8 @@ def _layer_norm_bwd_dx_fused(
 def _layer_norm_bwd_dwdb(
     A, DOut,
     Mean, Var,
     DW,
     DB,
     M, N,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -129,12 +130,12 @@ def _layer_norm_bwd_dwdb(
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for i in range(0, M, BLOCK_SIZE_M):
         rows = i + tl.arange(0, BLOCK_SIZE_M)
-        mask = (rows[:, None] < M) & (cols[None,:] < N)
-        offs = rows[:, None] * N + cols[None,:]
+        mask = (rows[:, None] < M) & (cols[None, :] < N)
+        offs = rows[:, None] * N + cols[None, :]
         a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
         dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
-        mean = tl.load(Mean + rows, mask=rows<M, other=0.)
-        rstd = tl.load(Var + rows, mask=rows<M, other=0.)
+        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
         a_hat = (a - mean[:, None]) * rstd[:, None]
         dw += dout * a_hat
         db += dout
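
The _layer_norm_bwd_dwdb kernel above accumulates dout * a_hat and dout over row blocks; the quantity it ultimately reduces to is just a column-wise sum. A hedged PyTorch equivalent of that reduction, for comparison only:

import torch

def layer_norm_bwd_dwdb_reference(a, dout, mean, rstd):
    # Column-wise reductions over all rows:
    #   dW[j] = sum_i dout[i, j] * a_hat[i, j],   dB[j] = sum_i dout[i, j]
    a_hat = (a.float() - mean[:, None]) * rstd[:, None]
    dw = (dout.float() * a_hat).sum(dim=0)
    db = dout.float().sum(dim=0)
    return dw, db
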
@@ -162,10 +163,10 @@ class LayerNorm(torch.autograd.Function):
         # heuristics for number of warps
         num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
         _layer_norm_fwd_fused[(M,)](
             out,
             a_arg,
             weight,
             bias,
             mean, rstd,
             a_arg.stride(0), N, eps,
             BLOCK_SIZE=BLOCK_SIZE,
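
The forward launch above picks num_warps from BLOCK_SIZE. A host-side sketch of that heuristic together with the usual next-power-of-two block sizing; the cap against a maximum fused size is an assumption here, not taken from this hunk:

import triton

def forward_launch_params(N):
    # One program handles one row; BLOCK_SIZE is the next power of two >= N
    # (the tutorial additionally caps this against a maximum fused size).
    BLOCK_SIZE = triton.next_power_of_2(N)
    # warp heuristic from the hunk: ~one warp per 256 columns, clamped to [1, 8]
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    return BLOCK_SIZE, num_warps
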
@@ -192,18 +193,18 @@ class LayerNorm(torch.autograd.Function):
         # heuristics for amount of parallel reduction stream for DG/DB
         N = weight.shape[0]
         # allocate output
         da = torch.empty_like(dout)
         # enqueue kernel using forward pass heuristics
         # also compute partial sums for DW and DB
         x_arg = a.reshape(-1, a.shape[-1])
         M, N = x_arg.shape
         dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
         dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
         _layer_norm_bwd_dx_fused[(M,)](
             da,
             dout,
             a,
             weight,
             mean, var,
             x_arg.stride(0), M, N,
             ctx.eps,
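
The backward entry point above allocates da, dweight, and dbias and enqueues one program per row for the dx kernel. As a cross-check, the same three gradients can be obtained from PyTorch autograd; a sketch with illustrative names, not part of the tutorial:

import torch

def layer_norm_grads_via_autograd(a, weight, bias, dout, eps=1e-5):
    # Reference gradients to compare against da / dweight / dbias
    a = a.detach().clone().requires_grad_(True)
    weight = weight.detach().clone().requires_grad_(True)
    bias = bias.detach().clone().requires_grad_(True)
    out = torch.nn.functional.layer_norm(a, (a.shape[-1],), weight, bias, eps)
    out.backward(dout)
    return a.grad, weight.grad, bias.grad
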
@@ -216,7 +217,7 @@ class LayerNorm(torch.autograd.Function):
             a, dout,
             mean, var,
             dweight,
             dbias,
             M,
             N,
             BLOCK_SIZE_M=32,
@@ -235,6 +236,7 @@ class LayerNorm(torch.autograd.Function):
 def layer_norm(a, normalized_shape, weight, bias, eps):
     return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
 
+
 def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
     torch.manual_seed(0)
     # create data
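
A short usage sketch along the lines of what test_layer_norm checks, with the tutorial's layer_norm wrapper in scope; the data values and tolerances below are illustrative assumptions, not the tutorial's exact test:

import torch

def check_layer_norm(M=1151, N=8192, dtype=torch.float16, eps=1e-5, device='cuda'):
    torch.manual_seed(0)
    weight = torch.rand((N,), dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand((N,), dtype=dtype, device=device, requires_grad=True)
    a = -2.3 + 0.5 * torch.randn((M, N), dtype=dtype, device=device)
    # Triton wrapper from the diff vs. the PyTorch reference
    out_tri = layer_norm(a, (N,), weight, bias, eps)
    out_ref = torch.nn.functional.layer_norm(a, (N,), weight, bias, eps)
    assert torch.allclose(out_tri, out_ref, atol=1e-2, rtol=0), "forward mismatch"
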