[TUTORIALS] adjust heuristics for dwdb kernel (#565)

This commit is contained in:
Natalia Gimelshein
2022-06-29 17:00:22 -07:00
committed by GitHub
parent 1895ceaa2d
commit 1bbb2430d9

View File

@@ -128,8 +128,10 @@ def _layer_norm_bwd_dwdb(
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
# accumulate partial sums in separate kernel
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
dbias,
M,
N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps
)
return (da, None, dweight, dbias, None, None,
None, None, None, None,
None,
None, None, None,
None,
None, None, None,
None, None, None,
None, None, None)
return (da, None, dweight, dbias, None)
def layer_norm(a, normalized_shape, weight, bias, eps):