diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py
index 9880b428f..333cb80ec 100644
--- a/python/tutorials/05-layer-norm.py
+++ b/python/tutorials/05-layer-norm.py
@@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb(
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, BLOCK_SIZE_M)
-        mask = (rows[:, None] < M) & (cols[None, :] < N)
-        offs = rows[:, None] * N + cols[None, :]
-        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
-        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
-        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
-        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
-        a_hat = (a - mean[:, None]) * rstd[:, None]
-        dw += dout * a_hat
-        db += dout
+    UNROLL: tl.constexpr = 4
+    for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+        for j in range(UNROLL):
+            rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+            mask = (rows[:, None] < M) & (cols[None, :] < N)
+            offs = rows[:, None] * N + cols[None, :]
+            a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+            dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+            mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+            rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+            a_hat = (a - mean[:, None]) * rstd[:, None]
+            dw += dout * a_hat
+            db += dout
     sum_dw = tl.sum(dw, axis=0)
     sum_db = tl.sum(db, axis=0)
     tl.store(DW + cols, sum_dw, mask=cols < N)
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
             BLOCK_SIZE_N=ctx.BLOCK_SIZE,
             num_warps=ctx.num_warps,
         )
-        # accumulate partial sums in separate kernel
+        if N > 10240:
+            BLOCK_SIZE_N = 128
+            BLOCK_SIZE_M = 32
+            num_warps = 4
+        else:
+            # maximize occupancy for small N
+            BLOCK_SIZE_N = 16
+            BLOCK_SIZE_M = 16
+            num_warps = 8
         grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
         _layer_norm_bwd_dwdb[grid](
             a, dout,
@@ -220,17 +230,11 @@
             dbias,
             M, N,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=128,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            num_warps=num_warps
         )
-        return (da, None, dweight, dbias, None, None,
-                None, None, None, None,
-                None,
-                None, None, None,
-                None,
-                None, None, None,
-                None, None, None,
-                None, None, None)
+        return (da, None, dweight, dbias, None)
 
 
 def layer_norm(a, normalized_shape, weight, bias, eps):
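
Not part of the patch: a minimal sketch of how the modified backward (unrolled dW/db reduction plus the N-dependent block-size heuristic) can be sanity-checked against PyTorch's reference gradients. It assumes the `layer_norm` wrapper defined at the end of this tutorial file; the shapes, dtype, and tolerances below are illustrative, not taken from the diff.

    import torch

    M, N = 4096, 8192
    dtype = torch.float16
    # leaf tensors so .grad is populated after backward
    a = -2.3 + 0.5 * torch.randn((M, N), dtype=dtype, device="cuda", requires_grad=True)
    weight = torch.rand((N,), dtype=dtype, device="cuda", requires_grad=True)
    bias = torch.rand((N,), dtype=dtype, device="cuda", requires_grad=True)
    dout = 0.1 * torch.randn_like(a)

    # Triton path (layer_norm is the wrapper from this tutorial)
    y_tri = layer_norm(a, (N,), weight, bias, 1e-5)
    y_tri.backward(dout, retain_graph=True)
    da_tri, dw_tri, db_tri = (t.grad.clone() for t in (a, weight, bias))
    a.grad, weight.grad, bias.grad = None, None, None

    # PyTorch reference path
    y_ref = torch.nn.functional.layer_norm(a, (N,), weight, bias, 1e-5)
    y_ref.backward(dout, retain_graph=True)
    da_ref, dw_ref, db_ref = (t.grad.clone() for t in (a, weight, bias))

    # loose fp16 tolerances; the weight/bias gradients exercise _layer_norm_bwd_dwdb
    torch.testing.assert_close(db_tri, db_ref, atol=1e-2, rtol=0)
    torch.testing.assert_close(dw_tri, dw_ref, atol=1e-2, rtol=0)
    torch.testing.assert_close(da_tri, da_ref, atol=1e-2, rtol=0)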