[TUTORIALS] adjust heuristics for dwdb kernel (#565)

Author: Natalia Gimelshein
Date: 2022-06-29 17:00:22 -07:00
Committed by: GitHub
Parent: 1895ceaa2d
Commit: 1bbb2430d9

@@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb(
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, BLOCK_SIZE_M)
-        mask = (rows[:, None] < M) & (cols[None, :] < N)
-        offs = rows[:, None] * N + cols[None, :]
-        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
-        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
-        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
-        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
-        a_hat = (a - mean[:, None]) * rstd[:, None]
-        dw += dout * a_hat
-        db += dout
+    UNROLL: tl.constexpr = 4
+    for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+        for j in range(UNROLL):
+            rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+            mask = (rows[:, None] < M) & (cols[None, :] < N)
+            offs = rows[:, None] * N + cols[None, :]
+            a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+            dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+            mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+            rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+            a_hat = (a - mean[:, None]) * rstd[:, None]
+            dw += dout * a_hat
+            db += dout
     sum_dw = tl.sum(dw, axis=0)
     sum_db = tl.sum(db, axis=0)
     tl.store(DW + cols, sum_dw, mask=cols < N)
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
             BLOCK_SIZE_N=ctx.BLOCK_SIZE,
             num_warps=ctx.num_warps,
         )
-        # accumulate partial sums in separate kernel
+        if N > 10240:
+            BLOCK_SIZE_N = 128
+            BLOCK_SIZE_M = 32
+            num_warps = 4
+        else:
+            # maximize occupancy for small N
+            BLOCK_SIZE_N = 16
+            BLOCK_SIZE_M = 16
+            num_warps = 8
         grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
         _layer_norm_bwd_dwdb[grid](
             a, dout,
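
Note on the heuristic above: the grid lambda launches one program per BLOCK_SIZE_N-wide column slice, so with the old fixed config (BLOCK_SIZE_N=128) a small N produced only a handful of programs and left most of the GPU idle. The new branch shrinks the tile and raises num_warps in that regime. A rough illustration of the trade-off (threshold and tile sizes are taken from the diff; the helper name and the sample values of N are made up):

    # Illustrative only: mirror the heuristic and report how many programs the
    # grid lambda would launch for a few hypothetical values of N.
    def dwdb_launch_params(N):
        if N > 10240:
            return {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_M": 32, "num_warps": 4}
        # maximize occupancy for small N: smaller tiles -> more programs
        return {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_M": 16, "num_warps": 8}

    for N in (768, 4096, 16384):
        p = dwdb_launch_params(N)
        programs = -(-N // p["BLOCK_SIZE_N"])  # cdiv(N, BLOCK_SIZE_N)
        print(N, p, "programs:", programs)
    # N=768  -> 48 programs with the small-N config (vs. 6 with BLOCK_SIZE_N=128)
    # N=16384 -> 128 programs with the large-N config
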
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
             dbias,
             M,
             N,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=128,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            num_warps=num_warps
         )
-        return (da, None, dweight, dbias, None, None,
-                None, None, None, None,
-                None,
-                None, None, None,
-                None,
-                None, None, None,
-                None, None, None,
-                None, None, None)
+        return (da, None, dweight, dbias, None)
 def layer_norm(a, normalized_shape, weight, bias, eps):
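
Note on the shortened return statement: torch.autograd.Function.backward returns one value per argument of forward, with None for inputs that need no gradient. With a five-argument forward matching layer_norm's (a, normalized_shape, weight, bias, eps), five entries are what the new code returns; the long tail of extra Nones is dropped. A minimal sketch of the pattern, using a toy Function with the same argument list (assumed signature, not the tutorial's implementation):

    import torch

    # Toy Function with the same five-argument signature as layer_norm; it only
    # scales and shifts, but the backward return shape is the point: one entry
    # per forward input, None for the non-tensor arguments.
    class ScaleShift(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a, normalized_shape, weight, bias, eps):
            ctx.save_for_backward(a, weight)
            return a * weight + bias

        @staticmethod
        def backward(ctx, dout):
            a, weight = ctx.saved_tensors
            da = dout * weight
            dweight = (dout * a).sum(dim=0)
            dbias = dout.sum(dim=0)
            # gradients for: a, normalized_shape, weight, bias, eps
            return da, None, dweight, dbias, None

    a = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, requires_grad=True)
    b = torch.randn(8, requires_grad=True)
    ScaleShift.apply(a, (8,), w, b, 1e-5).sum().backward()
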