[TUTORIALS] adjust heuristics for dwdb kernel (#565)
This commit is contained in:
committed by
GitHub
parent
1895ceaa2d
commit
1bbb2430d9
@@ -128,8 +128,10 @@ def _layer_norm_bwd_dwdb(
|
|||||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
for i in range(0, M, BLOCK_SIZE_M):
|
UNROLL: tl.constexpr = 4
|
||||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
|
||||||
|
for j in range(UNROLL):
|
||||||
|
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||||
offs = rows[:, None] * N + cols[None, :]
|
offs = rows[:, None] * N + cols[None, :]
|
||||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||||
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
|
|||||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||||
num_warps=ctx.num_warps,
|
num_warps=ctx.num_warps,
|
||||||
)
|
)
|
||||||
# accumulate partial sums in separate kernel
|
if N > 10240:
|
||||||
|
BLOCK_SIZE_N = 128
|
||||||
|
BLOCK_SIZE_M = 32
|
||||||
|
num_warps = 4
|
||||||
|
else:
|
||||||
|
# maximize occupancy for small N
|
||||||
|
BLOCK_SIZE_N = 16
|
||||||
|
BLOCK_SIZE_M = 16
|
||||||
|
num_warps = 8
|
||||||
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
|
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
|
||||||
_layer_norm_bwd_dwdb[grid](
|
_layer_norm_bwd_dwdb[grid](
|
||||||
a, dout,
|
a, dout,
|
||||||
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
|
|||||||
dbias,
|
dbias,
|
||||||
M,
|
M,
|
||||||
N,
|
N,
|
||||||
BLOCK_SIZE_M=32,
|
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||||
BLOCK_SIZE_N=128,
|
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||||
|
num_warps=num_warps
|
||||||
)
|
)
|
||||||
return (da, None, dweight, dbias, None, None,
|
return (da, None, dweight, dbias, None)
|
||||||
None, None, None, None,
|
|
||||||
None,
|
|
||||||
None, None, None,
|
|
||||||
None,
|
|
||||||
None, None, None,
|
|
||||||
None, None, None,
|
|
||||||
None, None, None)
|
|
||||||
|
|
||||||
|
|
||||||
def layer_norm(a, normalized_shape, weight, bias, eps):
|
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||||
|
Reference in New Issue
Block a user