[TUTORIALS] adjust heuristics for dwdb kernel (#565)
parent 1895ceaa2d
commit 1bbb2430d9
@@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb(
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, BLOCK_SIZE_M)
-        mask = (rows[:, None] < M) & (cols[None, :] < N)
-        offs = rows[:, None] * N + cols[None, :]
-        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
-        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
-        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
-        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
-        a_hat = (a - mean[:, None]) * rstd[:, None]
-        dw += dout * a_hat
-        db += dout
+    UNROLL: tl.constexpr = 4
+    for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+        for j in range(UNROLL):
+            rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+            mask = (rows[:, None] < M) & (cols[None, :] < N)
+            offs = rows[:, None] * N + cols[None, :]
+            a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+            dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+            mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+            rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+            a_hat = (a - mean[:, None]) * rstd[:, None]
+            dw += dout * a_hat
+            db += dout
     sum_dw = tl.sum(dw, axis=0)
     sum_db = tl.sum(db, axis=0)
     tl.store(DW + cols, sum_dw, mask=cols < N)
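For context, the loop being unrolled accumulates the per-column partial sums of the weight and bias gradients; unrolling only changes how the M rows are traversed, not what is summed. A dense PyTorch sketch of the same reduction (the helper name and arguments are illustrative, not part of the tutorial):

import torch

def dwdb_reference(a, dout, mean, rstd):
    # Dense equivalent of the kernel's per-column accumulation:
    #   dW[j] = sum_i dout[i, j] * (a[i, j] - mean[i]) * rstd[i]
    #   db[j] = sum_i dout[i, j]
    a_hat = (a - mean[:, None]) * rstd[:, None]  # normalized activations
    dw = (dout * a_hat).sum(dim=0)               # reduce over the M rows
    db = dout.sum(dim=0)
    return dw, db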
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
             BLOCK_SIZE_N=ctx.BLOCK_SIZE,
             num_warps=ctx.num_warps,
         )
-        # accumulate partial sums in separate kernel
+        if N > 10240:
+            BLOCK_SIZE_N = 128
+            BLOCK_SIZE_M = 32
+            num_warps = 4
+        else:
+            # maximize occupancy for small N
+            BLOCK_SIZE_N = 16
+            BLOCK_SIZE_M = 16
+            num_warps = 8
         grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
         _layer_norm_bwd_dwdb[grid](
             a, dout,
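The new branch trades block width for occupancy: each program reduces one BLOCK_SIZE_N-wide column block over all M rows, so a smaller BLOCK_SIZE_N launches more programs when N is small. A rough sketch of how the launch parameters fall out of N (the helper is hypothetical; only triton.cdiv comes from the library):

import triton

def dwdb_launch_params(N):
    # Same branch as in backward() above: wide blocks for large N,
    # small blocks (and more warps) to keep the GPU busy for small N.
    if N > 10240:
        BLOCK_SIZE_N, BLOCK_SIZE_M, num_warps = 128, 32, 4
    else:
        BLOCK_SIZE_N, BLOCK_SIZE_M, num_warps = 16, 16, 8
    grid = (triton.cdiv(N, BLOCK_SIZE_N),)  # one program per column block
    return grid, BLOCK_SIZE_M, BLOCK_SIZE_N, num_warps

# e.g. N = 768   -> grid (48,),  16x16 blocks, 8 warps
#      N = 16384 -> grid (128,), 32x128 blocks, 4 warps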
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
             dbias,
             M,
             N,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=128,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            num_warps=num_warps
         )
-        return (da, None, dweight, dbias, None, None,
-                None, None, None, None,
-                None,
-                None, None, None,
-                None,
-                None, None, None,
-                None, None, None,
-                None, None, None)
+        return (da, None, dweight, dbias, None)


 def layer_norm(a, normalized_shape, weight, bias, eps):
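The shortened return tuple follows the torch.autograd.Function contract: backward returns exactly one value per argument of forward, using None for inputs that need no gradient. A minimal skeleton of that pairing, assuming forward takes the same five arguments as the layer_norm wrapper below (class name and return values here are placeholders):

import torch

class _LayerNormLike(torch.autograd.Function):
    # Illustrative skeleton only; the tutorial's LayerNorm class does the real work.
    @staticmethod
    def forward(ctx, a, normalized_shape, weight, bias, eps):
        return a.clone()  # placeholder for the fused forward kernel

    @staticmethod
    def backward(ctx, dout):
        # One return value per forward argument (a, normalized_shape,
        # weight, bias, eps); None for inputs that take no gradient.
        return dout, None, None, None, None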