[TUTORIALS] Removed #noformat in layer norm tutorial

Philippe Tillet
2022-05-12 12:41:25 -07:00
parent c736ba7c3e
commit 0835a4fb05


@@ -16,7 +16,6 @@ try:
 except ModuleNotFoundError:
     HAS_APEX = False
 
-# fmt: off
 
 @triton.jit
 def _layer_norm_fwd_fused(
@@ -37,17 +36,17 @@ def _layer_norm_fwd_fused(
     _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(A + cols, mask=cols<N, other=0., eviction_policy="evict_last").to(tl.float32)
+        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
         _mean += a
-    mean = tl.sum(_mean, axis = 0) / N
+    mean = tl.sum(_mean, axis=0) / N
     # compute variance
     _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(A + cols, mask=cols<N, other=0., eviction_policy="evict_last").to(tl.float32)
-        a = tl.where(cols<N, a - mean, 0.)
+        a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
+        a = tl.where(cols < N, a - mean, 0.)
         _var += a * a
-    var = tl.sum(_var, axis = 0) / N
+    var = tl.sum(_var, axis=0) / N
     rstd = 1 / tl.sqrt(var + eps)
     # write-back mean/rstd
     tl.store(Mean + row, mean)
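For reference, the two blocked loops in this hunk compute the row mean and variance in two masked passes. A minimal NumPy sketch of the same pattern (the function name and the clamp-then-mask indexing are illustrative, not part of the tutorial):

import numpy as np

def row_mean_var_blocked(a_row, BLOCK_SIZE, eps=1e-5):
    # mimics the kernel's two masked passes over one row of length N
    N = a_row.shape[0]
    _mean = np.zeros(BLOCK_SIZE, dtype=np.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + np.arange(BLOCK_SIZE)
        # clamp the index so out-of-range lanes read something valid,
        # then zero them (the analogue of mask=cols < N, other=0.)
        a = np.where(cols < N, a_row[np.minimum(cols, N - 1)], 0.).astype(np.float32)
        _mean += a
    mean = _mean.sum() / N
    _var = np.zeros(BLOCK_SIZE, dtype=np.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + np.arange(BLOCK_SIZE)
        a = np.where(cols < N, a_row[np.minimum(cols, N - 1)], 0.).astype(np.float32)
        a = np.where(cols < N, a - mean, 0.)
        _var += a * a
    var = _var.sum() / N
    rstd = 1 / np.sqrt(var + eps)
    return mean, rstd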
@@ -65,6 +64,8 @@ def _layer_norm_fwd_fused(
     tl.store(Out + cols, out, mask=mask)
 
+
+
 # Backward pass (DA + partial DW + partial DB)
 @triton.jit
 def _layer_norm_bwd_dx_fused(
     _DA,
@@ -78,9 +79,9 @@ def _layer_norm_bwd_dx_fused(
     # position of elements processed by this program
     pid = tl.program_id(0)
     row = pid
-    A = _A + row*stride
-    DOut = _DOut + row*stride
-    DA = _DA + row*stride
+    A = _A + row * stride
+    DOut = _DOut + row * stride
+    DA = _DA + row * stride
     mean = tl.load(Mean + row)
     rstd = tl.load(Rstd + row)
     # load data to SRAM
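The row * stride offsets in this hunk are plain row-major indexing: program instance `row` shifts each base pointer to the start of its own row. An illustrative check (not from the tutorial):

import torch

a = torch.arange(12.).reshape(3, 4)  # row-major, so a.stride(0) == 4
row, col = 2, 1
# element (row, col) sits at flat offset row * stride + col
assert a.flatten()[row * a.stride(0) + col] == a[row, col]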
@@ -129,12 +130,12 @@ def _layer_norm_bwd_dwdb(
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for i in range(0, M, BLOCK_SIZE_M):
         rows = i + tl.arange(0, BLOCK_SIZE_M)
-        mask = (rows[:, None] < M) & (cols[None,:] < N)
-        offs = rows[:, None] * N + cols[None,:]
+        mask = (rows[:, None] < M) & (cols[None, :] < N)
+        offs = rows[:, None] * N + cols[None, :]
         a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
         dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
-        mean = tl.load(Mean + rows, mask=rows<M, other=0.)
-        rstd = tl.load(Var + rows, mask=rows<M, other=0.)
+        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
         a_hat = (a - mean[:, None]) * rstd[:, None]
         dw += dout * a_hat
         db += dout
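The tile loop in this hunk accumulates partial weight/bias gradients over all rows of each column block; once reduced over rows, they agree with this dense PyTorch reference (a hedged sketch, the function name is illustrative):

import torch

def dwdb_reference(a, dout, mean, rstd):
    # a, dout: (M, N); mean, rstd: (M,)
    a_hat = (a - mean[:, None]) * rstd[:, None]  # normalized activations
    dw = (dout * a_hat).sum(dim=0)  # gradient w.r.t. weight, shape (N,)
    db = dout.sum(dim=0)            # gradient w.r.t. bias, shape (N,)
    return dw, db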
@@ -235,6 +236,7 @@ class LayerNorm(torch.autograd.Function):
 def layer_norm(a, normalized_shape, weight, bias, eps):
     return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
 
+
 def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
     torch.manual_seed(0)
     # create data
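A minimal usage sketch for the layer_norm wrapper shown above, assuming the tutorial module is in scope on a CUDA machine; the shapes, dtype, and tolerances here are illustrative, not taken from the commit:

import torch

M, N = 1151, 8192
a = torch.randn(M, N, device='cuda', dtype=torch.float16, requires_grad=True)
weight = torch.rand(N, device='cuda', dtype=torch.float16, requires_grad=True)
bias = torch.rand(N, device='cuda', dtype=torch.float16, requires_grad=True)
out = layer_norm(a, (N,), weight, bias, 1e-5)
ref = torch.nn.functional.layer_norm(a, (N,), weight, bias, 1e-5)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=0)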