Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
@@ -19,8 +19,8 @@ except ModuleNotFoundError:
@triton.jit
def _layer_norm_fwd_fused(
Out,
A,
Out,
Weight,
Bias,
Mean, Rstd,
@@ -36,14 +36,14 @@ def _layer_norm_fwd_fused(
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.)
_var += a * a
var = tl.sum(_var, axis=0) / N
@@ -57,192 +57,155 @@ def _layer_norm_fwd_fused(
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# # write-back
tl.store(Out + cols, out, mask=mask)
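For orientation, the arithmetic the fused forward kernel above performs per row is ordinary layer normalization. A minimal PyTorch sketch of the same computation (variable names illustrative, not part of the diff):

import torch

def layer_norm_fwd_reference(a, weight, bias, eps=1e-5):
    # per-row statistics, matching the two reduction loops in the kernel
    mean = a.float().mean(dim=-1, keepdim=True)
    var = ((a.float() - mean) ** 2).mean(dim=-1, keepdim=True)
    rstd = 1.0 / torch.sqrt(var + eps)
    # normalize, then scale and shift
    a_hat = (a.float() - mean) * rstd
    return (a_hat * weight + bias).to(a.dtype)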
# Backward pass (DA + partial DW + partial DB)
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(
_DA,
_DOut,
_A,
Weight,
Mean, Rstd,
stride, NumRows, NumCols, eps,
BLOCK_SIZE_N: tl.constexpr,
):
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# position of elements processed by this program
pid = tl.program_id(0)
row = pid
A = _A + row * stride
DOut = _DOut + row * stride
DA = _DA + row * stride
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
# these buffers stay in the L2, which allow this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# load data to SRAM
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
_mean1 += a_hat * wdout
_mean2 += wdout
mean1 = tl.sum(_mean1, axis=0) / NumCols
mean2 = 0.
mean2 = tl.sum(_mean2, axis=0) / NumCols
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DA + cols, da, mask=mask)
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# first store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
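The input-gradient math in `_layer_norm_bwd_dx_fused` is the standard layer-norm backward. A rough PyTorch equivalent of the `dx` computation shown above (a sketch for orientation, not code from the commit):

import torch

def layer_norm_bwd_dx_reference(x, dy, w, mean, rstd):
    # x, dy: (M, N); w: (N,); mean, rstd: (M,)
    x_hat = (x - mean[:, None]) * rstd[:, None]
    wdy = w * dy
    mean1 = (x_hat * wdy).mean(dim=1, keepdim=True)  # tl.sum(xhat * wdy) / N
    mean2 = wdy.mean(dim=1, keepdim=True)            # tl.sum(wdy) / N
    return (wdy - (x_hat * mean1 + mean2)) * rstd[:, None]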
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
A, DOut,
Mean, Var,
DW,
DB,
M, N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
tl.store(DB + cols, sum_db, mask=cols < N)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
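The new backward splits the weight/bias gradients into two stages: `_layer_norm_bwd_dx_fused` accumulates each row's contribution into one of GROUP_SIZE_M partial buffers (guarded by one lock per buffer), and `_layer_norm_bwd_dwdb` then reduces those buffers into the final `dw`/`db`. A serial PyTorch sketch of that two-stage reduction (illustrative only, not part of the diff):

import torch

def dwdb_two_stage_reference(x, dy, mean, rstd, group_size_m=64):
    M, N = x.shape
    x_hat = (x - mean[:, None]) * rstd[:, None]
    partial_dw = torch.zeros(group_size_m, N)
    partial_db = torch.zeros(group_size_m, N)
    # stage 1: each row adds into buffer `row % GROUP_SIZE_M` (lock_id in the kernel)
    for row in range(M):
        buf = row % group_size_m
        partial_dw[buf] += dy[row] * x_hat[row]
        partial_db[buf] += dy[row]
    # stage 2: _layer_norm_bwd_dwdb sums the partial buffers column-wise
    return partial_dw.sum(dim=0), partial_db.sum(dim=0)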
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, a, normalized_shape, weight, bias, eps):
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
out = torch.empty_like(a)
y = torch.empty_like(x)
# reshape input data into 2D tensor
a_arg = a.reshape(-1, a.shape[-1])
M, N = a_arg.shape
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_layer_norm_fwd_fused[(M,)](
out,
a_arg,
weight,
bias,
mean, rstd,
a_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.save_for_backward(
a, weight, bias, mean, rstd,
)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
if hasattr(bias, "config"):
assert bias.config.grad_scale_name == weight.config.grad_scale_name
grad_scale_name = bias.config.grad_scale_name
else:
grad_scale_name = None
ctx.grad_scale_gain_bias_name = grad_scale_name
return out
return y
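As a concrete instance of the warp heuristic in `forward` above: BLOCK_SIZE = 512 gives num_warps = min(max(512 // 256, 1), 8) = 2, while BLOCK_SIZE = 4096 gives num_warps = min(max(16, 1), 8) = 8.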
@staticmethod
def backward(ctx, dout):
assert dout.is_contiguous()
a, weight, bias, mean, var = ctx.saved_tensors
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = weight.shape[0]
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
da = torch.empty_like(dout)
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = a.reshape(-1, a.shape[-1])
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
_layer_norm_bwd_dx_fused[(M,)](
da,
dout,
a,
weight,
mean, var,
x_arg.stride(0), M, N,
ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
mean, var,
dweight,
dbias,
M,
N,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps
)
return (da, None, dweight, dbias, None)
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None
def layer_norm(a, normalized_shape, weight, bias, eps):
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
layer_norm = LayerNorm.apply
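A hypothetical end-to-end usage of the `layer_norm` wrapper defined above (shapes borrowed from the test below; this is a sketch, not part of the diff):

import torch

M, N, eps = 1151, 8192, 1e-5
x = torch.randn(M, N, dtype=torch.float16, device='cuda', requires_grad=True)
weight = torch.rand(N, dtype=torch.float16, device='cuda', requires_grad=True)
bias = torch.rand(N, dtype=torch.float16, device='cuda', requires_grad=True)

y = layer_norm(x, (N,), weight, bias, eps)   # forward: _layer_norm_fwd_fused
y.backward(torch.randn_like(y))              # backward: dx kernel + two-stage dw/db reduction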
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
torch.manual_seed(0)
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
@@ -277,11 +240,11 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
@@ -311,5 +274,5 @@ def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
return gbps(ms), gbps(max_ms), gbps(min_ms)
# test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)
test_layer_norm(1151, 8192, torch.float16)
# bench_layer_norm.run(save_path='.', print_data=True)