[TUTORIALS] adjust heuristics for dwdb kernel (#565)
committed by GitHub
parent 1895ceaa2d
commit 1bbb2430d9
@@ -128,8 +128,10 @@ def _layer_norm_bwd_dwdb(
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, BLOCK_SIZE_M)
+    UNROLL: tl.constexpr = 4
+    for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+        for j in range(UNROLL):
+            rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
             mask = (rows[:, None] < M) & (cols[None, :] < N)
             offs = rows[:, None] * N + cols[None, :]
             a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
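The hunk above manually unrolls the row loop of the partial-reduction kernel by a factor of 4: each program still owns a BLOCK_SIZE_N-wide column slice, but it now walks the M rows four BLOCK_SIZE_M-tiles per outer iteration, giving the compiler more independent loads and adds to overlap. The standalone sketch below shows the same unrolled-reduction pattern on a plain column sum so it can be run in isolation; it is illustrative only and is not the tutorial kernel (the kernel name, the tl.sum/tl.store epilogue, and the test shapes are assumptions).

import torch
import triton
import triton.language as tl


@triton.jit
def _col_sum(X, Out, M, N,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # each program reduces one BLOCK_N-wide column slice of an (M, N) matrix
    pid = tl.program_id(0)
    cols = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # manual 4x unroll, as in the diff: four row-tiles per outer iteration
    UNROLL: tl.constexpr = 4
    for i in range(0, M, BLOCK_M * UNROLL):
        for j in range(UNROLL):
            rows = i + j * BLOCK_M + tl.arange(0, BLOCK_M)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            offs = rows[:, None] * N + cols[None, :]
            acc += tl.load(X + offs, mask=mask, other=0.).to(tl.float32)
    # collapse the per-tile accumulator and write one row of partial sums
    out = tl.sum(acc, axis=0)
    tl.store(Out + cols, out, mask=cols < N)


# illustrative launch with made-up shapes
x = torch.randn(4096, 768, device='cuda')
out = torch.empty(768, device='cuda', dtype=torch.float32)
grid = lambda meta: (triton.cdiv(768, meta['BLOCK_N']),)
_col_sum[grid](x, out, 4096, 768, BLOCK_M=32, BLOCK_N=128)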
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
             BLOCK_SIZE_N=ctx.BLOCK_SIZE,
             num_warps=ctx.num_warps,
         )
-        # accumulate partial sums in separate kernel
+        if N > 10240:
+            BLOCK_SIZE_N = 128
+            BLOCK_SIZE_M = 32
+            num_warps = 4
+        else:
+            # maximize occupancy for small N
+            BLOCK_SIZE_N = 16
+            BLOCK_SIZE_M = 16
+            num_warps = 8
         grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
         _layer_norm_bwd_dwdb[grid](
             a, dout,
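The heuristic added above picks the reduction tile and warp count from N at launch time: for wide rows (N > 10240) each program reduces a 128-column slice with 32-row tiles and 4 warps, while for narrow rows it shrinks the tile to 16 x 16 and raises num_warps to 8, so that the grid of cdiv(N, BLOCK_SIZE_N) programs stays large enough to occupy the GPU. A small sketch of the resulting launch configurations; the helper name and the example shapes are made up for illustration and are not part of the tutorial:

import triton

def _dwdb_launch_config(M, N):
    # mirrors the heuristic in the diff: returns
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, num_warps, grid) for the dwdb reduction
    if N > 10240:
        BLOCK_SIZE_N, BLOCK_SIZE_M, num_warps = 128, 32, 4
    else:
        # maximize occupancy for small N: narrower tiles -> more programs
        BLOCK_SIZE_N, BLOCK_SIZE_M, num_warps = 16, 16, 8
    grid = (triton.cdiv(N, BLOCK_SIZE_N),)
    return BLOCK_SIZE_M, BLOCK_SIZE_N, num_warps, grid

# e.g. a hidden size of 768 now launches 48 programs instead of
# cdiv(768, 128) = 6, which was too few to fill a modern GPU
print(_dwdb_launch_config(M=4096, N=768))      # (16, 16, 8, (48,))
print(_dwdb_launch_config(M=4096, N=16384))    # (32, 128, 4, (128,))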
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
             dbias,
             M,
             N,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=128,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            num_warps=num_warps
         )
-        return (da, None, dweight, dbias, None, None,
-                None, None, None, None,
-                None,
-                None, None, None,
-                None,
-                None, None, None,
-                None, None, None,
-                None, None, None)
+        return (da, None, dweight, dbias, None)


 def layer_norm(a, normalized_shape, weight, bias, eps):
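For completeness, a sketch of how the affected dweight/dbias path can be exercised end to end against PyTorch, in the spirit of the tutorial's own unit test. It assumes layer_norm (signature as in the context line above) is in scope, e.g. by running inside the tutorial script; the shapes and tolerances are assumptions:

import torch
import torch.nn.functional as F

# assumes `layer_norm` from the tutorial above is in scope
M, N, eps = 4096, 768, 1e-5
x = torch.randn(M, N, device='cuda', dtype=torch.float16, requires_grad=True)
w = torch.rand(N, device='cuda', dtype=torch.float16, requires_grad=True)
b = torch.rand(N, device='cuda', dtype=torch.float16, requires_grad=True)
dy = torch.randn_like(x)

# Triton path: forward, then backward through the adjusted dwdb kernel
y_tri = layer_norm(x, (N,), w, b, eps)
y_tri.backward(dy)
dw_tri, db_tri = w.grad.clone(), b.grad.clone()
x.grad, w.grad, b.grad = None, None, None

# PyTorch reference
y_ref = F.layer_norm(x, (N,), w, b, eps)
y_ref.backward(dy)
torch.testing.assert_close(dw_tri, w.grad, atol=1e-2, rtol=0)
torch.testing.assert_close(db_tri, b.grad, atol=1e-2, rtol=0)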