diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py
index 1cefc60b9..6581c809e 100644
--- a/python/tutorials/05-layer-norm.py
+++ b/python/tutorials/05-layer-norm.py
@@ -3,11 +3,9 @@ Layer Normalization
 ====================
 """

-import torch
-
 import triton
 import triton.language as tl
-
+import torch
 try:
     # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
     # should not be added to extras_require in setup.py.
@@ -16,99 +14,113 @@ try:
 except ModuleNotFoundError:
     HAS_APEX = False

+# fmt: off

-# Forward Pass
 @triton.jit
-def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
-                          BLOCK_SIZE: tl.constexpr):
+def _layer_norm_fwd_fused(
+    Out,
+    A,
+    Weight,
+    Bias,
+    Mean, Rstd,
+    stride, N, eps,
+    BLOCK_SIZE: tl.constexpr,
+):
     # position of elements processed by this program
     row = tl.program_id(0)
-    cols = tl.arange(0, BLOCK_SIZE)
-    mask = cols < N
-    # offset data pointers to start at the row of interest
-    X += row * stride
-    Y += row * stride
-    # load data and cast to float32
-    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
+    Out += row * stride
+    A += row * stride
     # compute mean
-    mean = tl.sum(x, axis=0) / N
-    # compute std
-    xmean = tl.where(mask, x - mean, 0.)
-    var = tl.sum(xmean * xmean, axis=0) / N
+    mean = 0
+    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        a = tl.load(A + cols, mask=cols < N,

[... diff truncated in the source: the remainder of _layer_norm_fwd_fused, the _layer_norm_bwd_dx_fused and _layer_norm_bwd_dwdb backward kernels, and the beginning of LayerNorm.forward are missing here ...]

-        if N > BLOCK_SIZE:
-            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+        BLOCK_SIZE = max(BLOCK_SIZE, 128)
+        BLOCK_SIZE = min(BLOCK_SIZE, 4096)
         # heuristics for number of warps
         num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
-        # enqueue kernel
-        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
-                                    x_arg.stride(0), N, eps,
-                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
-        ctx.save_for_backward(x, weight, bias, mean, rstd)
+        _layer_norm_fwd_fused[(M,)](
+            out,
+            a_arg,
+            weight,
+            bias,
+            mean, rstd,
+            a_arg.stride(0), N, eps,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+        ctx.save_for_backward(
+            a, weight, bias, mean, rstd,
+        )
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.eps = eps
-        return y
+        if hasattr(bias, "config"):
+            assert bias.config.grad_scale_name == weight.config.grad_scale_name
+            grad_scale_name = bias.config.grad_scale_name
+        else:
+            grad_scale_name = None
+        ctx.grad_scale_gain_bias_name = grad_scale_name
+        return out

     @staticmethod
-    def backward(ctx, dy):
-        x, w, b, m, v = ctx.saved_tensors
+    def backward(ctx, dout):
+        assert dout.is_contiguous()
+        a, weight, bias, mean, var = ctx.saved_tensors
         # heuristics for amount of parallel reduction stream for DG/DB
-        N = w.shape[0]
-        GROUP_SIZE_M = 64
-        if N <= 8192: GROUP_SIZE_M = 96
-        if N <= 4096: GROUP_SIZE_M = 128
-        if N <= 1024: GROUP_SIZE_M = 256
+        N = weight.shape[0]
         # allocate output
-        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
-        _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
-        _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
-        dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
-        db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
-        dx = torch.empty_like(dy)
+        da = torch.empty_like(dout)
         # enqueue kernel using forward pass heuristics
         # also compute partial sums for DW and DB
-        x_arg = x.reshape(-1, x.shape[-1])
+        x_arg = a.reshape(-1, a.shape[-1])
         M, N = x_arg.shape
-        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
-                                       x_arg.stride(0), N, ctx.eps,
-                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,
-                                       GROUP_SIZE_M=GROUP_SIZE_M,
-                                       num_warps=ctx.num_warps)
-        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
+        dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
+        dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
+        _layer_norm_bwd_dx_fused[(M,)](
+            da,
+            dout,
+            a,
+            weight,
+            mean, var,
+            x_arg.stride(0), M, N,
+            ctx.eps,
+            BLOCK_SIZE_N=ctx.BLOCK_SIZE,
+            num_warps=ctx.num_warps,
+        )
         # accumulate partial sums in separate kernel
-        _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
-                                   BLOCK_SIZE_M=32,
-                                   BLOCK_SIZE_N=128)
-        return dx, None, dw, db, None
+        grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
+        _layer_norm_bwd_dwdb[grid](
+            a, dout,
+            mean, var,
+            dweight,
+            dbias,
+            M,
+            N,
+            BLOCK_SIZE_M=32,
+            BLOCK_SIZE_N=128,
+        )
+        return (da, None, dweight, dbias, None, None,
+                None, None, None, None,
+                None,
+                None, None, None,
+                None,
+                None, None, None,
+                None, None, None,
+                None, None, None)


-layer_norm = LayerNorm.apply
-
+def layer_norm(a, normalized_shape, weight, bias, eps):
+    return LayerNorm.apply(a, normalized_shape, weight, bias, eps)


 def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
+    torch.manual_seed(0)
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
@@ -224,11 +269,11 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
         line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
         styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
         ylabel='GB/s',
-        plot_name='layer-norm-backward',
-        args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
+        plot_name='layer-norm',
+        args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
     )
 )
-def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
+def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
@@ -258,4 +303,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
     return gbps(ms), gbps(max_ms), gbps(min_ms)


+# test_layer_norm(1151, 8192, torch.float16)
 bench_layer_norm.run(save_path='.', print_data=True)
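
Below is a minimal usage sketch (not part of the patch) of the new `layer_norm` wrapper this diff introduces in place of `layer_norm = LayerNorm.apply`. The module import path is a hypothetical assumption; the dtype and feature size follow the commented-out test call and benchmark defaults above.

    import torch
    # hypothetical import of the tutorial script; adjust to wherever 05-layer-norm.py is importable from
    from layer_norm_tutorial import layer_norm

    # shapes mirror the commented-out test call: test_layer_norm(1151, 8192, torch.float16)
    M, N, eps = 1151, 8192, 1e-5
    a = torch.randn(M, N, dtype=torch.float16, device='cuda', requires_grad=True)
    weight = torch.rand(N, dtype=torch.float16, device='cuda', requires_grad=True)
    bias = torch.rand(N, dtype=torch.float16, device='cuda', requires_grad=True)
    # new signature from the diff: layer_norm(a, normalized_shape, weight, bias, eps)
    out = layer_norm(a, (N,), weight, bias, eps)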