Layer Normalization

05 layer norm

Out:

layer-norm-backward:
          N      Triton       Torch        Apex
0    1024.0  311.088617   98.698793  299.707322
1    1536.0  351.085717  135.032961  335.127270
2    2048.0  427.408686  160.627450  321.254900
3    2560.0  461.954908  181.775141  328.556154
4    3072.0  519.211251  192.501302  319.168834
5    3584.0  554.941930  208.271186  312.785456
6    4096.0  568.231237  220.412561  300.623865
7    4608.0  500.416301  232.825259  287.251954
8    5120.0  527.381977  241.414550  281.834861
9    5632.0  547.238891  242.671458  288.820505
10   6144.0  550.208948  251.631408  287.438593
11   6656.0  534.260858  256.410903  286.279570
12   7168.0  510.480705  257.919035  282.482768
13   7680.0  482.513091  265.208635  279.272719
14   8192.0  463.698115  267.130429  276.523207
15   8704.0  417.791980  263.424974  279.646592
16   9216.0  428.651187  273.066667  288.751954
17   9728.0  439.683593  280.615388  288.950501
18  10240.0  446.836366  286.767793  289.469963
19  10752.0  432.241202  246.229020  288.644296
20  11264.0  429.786952  242.671458  283.966395
21  11776.0  423.089806  248.351483  287.219500
22  12288.0  418.314886  253.796902  294.323369
23  12800.0  416.260178  253.047766  289.265522
24  13312.0  413.844560  253.965017  291.769875
25  13824.0  403.130022  256.792581  291.543045
26  14336.0  397.302533  255.619613  288.402346
27  14848.0  381.533186  255.816222  287.844912
28  15360.0  371.987899  261.818180  288.450715
29  15872.0  366.276922  261.626369  289.679087

import torch
import triton.language as tl
import triton

# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
    # Fused layer-norm forward pass.
    # Each program instance handles one row: it loads up to BLOCK_SIZE
    # elements of the row, normalizes them, applies weight/bias, and also
    # records the row's mean (into M) and reciprocal std (into V).
    block = META['BLOCK_SIZE']
    # which row this program works on, and which columns it touches
    row_idx = tl.program_id(0)
    offsets = tl.arange(0, block)
    in_bounds = offsets < N
    # row base pointers for input and output
    x_row = X + row_idx * stride
    y_row = Y + row_idx * stride
    # read the row (out-of-range lanes read as 0) and promote to float32
    vals = tl.load(x_row + offsets, mask=in_bounds, other=0).to(tl.float32)
    # row mean
    mu = tl.sum(vals, axis=0) / N
    # centered values; padded lanes are forced to 0 so they do not
    # contribute to the variance
    centered = tl.where(in_bounds, vals - mu, 0.)
    variance = tl.sum(centered * centered, axis=0) / N
    inv_std = 1 / tl.sqrt(variance + eps)
    # save per-row statistics
    tl.store(M + row_idx, mu)
    tl.store(V + row_idx, inv_std)
    # scale by weight, shift by bias
    gamma = tl.load(W + offsets, mask=in_bounds)
    beta = tl.load(B + offsets, mask=in_bounds)
    out = centered * inv_std * gamma + beta
    # write the normalized row
    tl.store(y_row + offsets, out, mask=in_bounds)


# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
                       stride, N, eps,
                       **META):
    # Fused layer-norm backward pass for one row per program instance.
    # It writes the input gradient for the row into DX and adds the row's
    # contribution to the weight/bias gradients into one of GROUP_SIZE_M
    # partial buffers inside DW / DB (selected by row % GROUP_SIZE_M).
    #
    # Pointer arguments as used below:
    #   DX, DY, X : row-major data, rows are `stride` elements apart
    #   DW, DB    : partial buffers, addressed as [lock_id * N + col]
    #   W         : weight, addressed by column
    #   M, V      : per-row values loaded as `mean` and `rstd`
    #   Lock      : first GROUP_SIZE_M entries are locks, the next
    #               GROUP_SIZE_M entries are "has been written" counters
    # B and eps are accepted but not read in this kernel.
    GROUP_SIZE_M = META['GROUP_SIZE_M']
    BLOCK_SIZE_N = META['BLOCK_SIZE_N']
    # position of elements processed by this program
    # (a single block of BLOCK_SIZE_N columns covers the row; there is no
    # loop over columns, so columns >= BLOCK_SIZE_N are never visited)
    row =  tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N
    # offset data pointers to start at the row of interest
    X  += row * stride
    DY += row * stride
    DX += row * stride
    # offset locks and weight/bias gradient pointer
    # each kernel instance accumulates partial sums for
    # DW and DB into one of GROUP_SIZE_M independent buffers
    # these buffers stay in the L2, which allows this kernel
    # to be fast
    lock_id = row % GROUP_SIZE_M
    Lock   += lock_id
    # the counter for this buffer lives GROUP_SIZE_M entries after its lock
    Count   = Lock + GROUP_SIZE_M
    DW      = DW + lock_id*N + cols
    DB      = DB + lock_id*N + cols
    # load data to SRAM
    x     = tl.load(X  + cols, mask=mask, other=0).to(tl.float32)
    dy    = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w     = tl.load(W  + cols, mask=mask).to(tl.float32)
    mean  = tl.load(M + row)
    rstd  = tl.load(V + row)
    # compute dx
    xhat  = (x - mean)*rstd
    wdy   = w * dy
    # zero the padded lanes so they do not contribute to the row sums
    xhat  = tl.where(mask, xhat, 0.)
    wdy   = tl.where(mask, wdy , 0.)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    mean2 = tl.sum(wdy, axis=0) / N
    dx    = (wdy - (xhat*mean1 + mean2))*rstd
    # write-back dx
    tl.store(DX + cols, dx, mask=mask)
    # accumulate partial sums for dw/db
    # (w was converted to float32 when loaded, so w.dtype is float32 here)
    partial_dw = (dy*xhat).to(w.dtype)
    partial_db = (dy).to(w.dtype)
    # acquire the lock for this buffer: keep retrying while the
    # compare-and-swap reports that the lock value was already 1
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    # first store doesn't accumulate
    # (the counter is flipped to 1 so later programs add to the stored value;
    # this means the buffer contents before the first store are ignored)
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)
        partial_db += tl.load(DB, mask=mask)
    tl.store(DW, partial_dw, mask=mask)
    tl.store(DB, partial_db, mask=mask)
    # release lock
    # NOTE: buffers whose lock_id is never hit by any row are not written
    # by this kernel at all.
    tl.atomic_xchg(Lock, 0)

# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
    # Reduce the (M, N) partial-gradient buffers DW / DB over their rows.
    # Each program instance owns BLOCK_SIZE_N consecutive columns and writes
    # the column sums into FINAL_DW / FINAL_DB.
    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
    # columns handled by this program
    col_block = tl.program_id(0)
    col_idx = col_block*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    col_ok = col_idx < N
    # float32 accumulators, one tile each for dw and db
    acc_dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    acc_db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # walk down the rows BLOCK_SIZE_M at a time, adding each tile
    for start in range(0, M, BLOCK_SIZE_M):
        row_idx = start + tl.arange(0, BLOCK_SIZE_M)
        tile_mask = (row_idx[:, None] < M) & col_ok[None, :]
        tile_offs = row_idx[:, None]*N + col_idx[None, :]
        acc_dw += tl.load(DW + tile_offs, mask=tile_mask, other=0.)
        acc_db += tl.load(DB + tile_offs, mask=tile_mask, other=0.)
    # collapse the tile rows and write one value per column
    tl.store(FINAL_DW + col_idx, tl.sum(acc_dw, axis=0), mask=col_ok)
    tl.store(FINAL_DB + col_idx, tl.sum(acc_db, axis=0), mask=col_ok)

class LayerNorm(torch.autograd.Function):
    """Layer normalization over the last dimension of ``x``.

    Forward and backward are delegated to the Triton kernels
    ``_layer_norm_fwd_fused``, ``_layer_norm_bwd_dx_fused`` and
    ``_layer_norm_bwd_dwdb``.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps):
        """Normalize each row of ``x`` and apply ``weight`` / ``bias``.

        Args:
            ctx: autograd context; receives the saved tensors and the launch
                configuration reused by ``backward``.
            x: input tensor; it is viewed as ``(-1, x.shape[-1])``.
            normalized_shape: accepted for signature parity with
                ``torch.nn.functional.layer_norm``; not read here.
            weight: per-feature scale, passed to the kernel as ``W``.
            bias: per-feature shift, passed to the kernel as ``B``.
            eps: value added to the variance before the square root.

        Returns:
            A tensor allocated with ``torch.empty_like(x)`` holding the result.

        Raises:
            RuntimeError: if one row does not fit in a single fused block
                (more than 64KB of data per row).
        """
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        # per-row statistics live on the same device as the input, not on
        # whatever the current default CUDA device happens to be
        mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
        rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel: one program per row
        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
                                    x_arg.stride(0), N, eps,
                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        ctx.save_for_backward(x, weight, bias, mean, rstd)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        """Compute gradients with respect to ``x``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context populated by ``forward``.
            dy: gradient of the loss with respect to the forward output.

        Returns:
            ``(dx, None, dw, db, None)`` matching the positional inputs of
            ``forward`` (``normalized_shape`` and ``eps`` get no gradient).
        """
        x, w, b, m, v = ctx.saved_tensors
        # heuristics for amount of parallel reduction stream for DG/DB
        N = w.shape[0]
        GROUP_SIZE_M = 64
        if N <= 8192: GROUP_SIZE_M = 96
        if N <= 4096: GROUP_SIZE_M = 128
        if N <= 1024: GROUP_SIZE_M = 256
        # allocate output
        # locks[:GROUP_SIZE_M] are the locks, locks[GROUP_SIZE_M:] the counters;
        # keep them on the same device as the buffers they guard
        locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device=w.device)
        # The partial buffers must start at zero: the dx kernel only writes
        # rows `row % GROUP_SIZE_M`, so when there are fewer input rows than
        # GROUP_SIZE_M some buffer rows are never touched, yet the final
        # reduction below sums all GROUP_SIZE_M of them.
        _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
        _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
        dw = torch.empty((N,), dtype=w.dtype, device=w.device)
        db = torch.empty((N,), dtype=w.dtype, device=w.device)
        dx = torch.empty_like(dy)
        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
                                       x_arg.stride(0), N, ctx.eps,
                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                       GROUP_SIZE_M=GROUP_SIZE_M,
                                       num_warps=ctx.num_warps)
        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
        # accumulate partial sums in separate kernel
        _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
                                   BLOCK_SIZE_M=32,
                                   BLOCK_SIZE_N=128)
        return dx, None, dw, db, None


# functional entry point: invokes LayerNorm through torch.autograd.Function.apply,
# taking the same positional arguments as LayerNorm.forward (minus ctx)
layer_norm = LayerNorm.apply


def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
    """Check the Triton layer norm against ``torch.nn.functional.layer_norm``.

    Builds a random ``(M, N)`` input with random weight and bias, runs the
    forward and backward pass through both implementations, and asserts that
    the outputs and all three gradients are almost equal.

    Args:
        M: number of rows of the input.
        N: number of features (size of the normalized dimension).
        dtype: torch dtype of the input, weight and bias.
        eps: value added to the variance before the square root.
        device: device on which all test tensors are created.
    """
    # create data (on the requested device rather than a hard-coded one)
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1*torch.randn_like(x)
    x.requires_grad_(True)
    # forward pass
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # backward pass (triton)
    y_tri.backward(dy, retain_graph=True)
    dx_tri, dw_tri, db_tri = [t.grad.clone() for t in [x, weight, bias]]
    # reset gradients so the reference pass does not accumulate onto them
    x.grad, weight.grad, bias.grad = None, None, None
    # backward pass (torch)
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref, db_ref = [t.grad.clone() for t in [x, weight, bias]]
    # compare (weight/bias gradients use a looser tolerance)
    triton.testing.assert_almost_equal(y_tri, y_ref)
    triton.testing.assert_almost_equal(dx_tri, dx_ref)
    triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
    triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[512 * i for i in range(2, 32)],
        line_arg='provider',
        line_vals=['triton', 'torch', 'apex'],
        line_names=['Triton', 'Torch', 'Apex'],
        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
        ylabel='GB/s',
        plot_name='layer-norm-backward',
        args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
    )
)
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
    """Benchmark one layer-norm implementation on an ``(M, N)`` input.

    Args:
        M: number of rows of the input.
        N: number of features (size of the normalized dimension).
        dtype: torch dtype of the input, weight and bias.
        provider: ``'triton'``, ``'torch'`` or ``'apex'``.
        mode: ``'forward'`` or ``'backward'``.
        eps: value added to the variance before the square root.
        device: device on which the benchmark tensors are created.

    Returns:
        Three throughput figures in GB/s computed from the ``ms``, ``max_ms``
        and ``min_ms`` values returned by ``triton.testing.do_bench`` (in that
        order).

    Raises:
        ValueError: if ``provider`` or ``mode`` is not one of the supported
            values.
    """
    # create data (on the requested device rather than a hard-coded one)
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1*torch.randn_like(x)
    x.requires_grad_(True)
    # pick the forward function for the requested implementation
    if provider == 'triton':
        y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
    elif provider == 'torch':
        y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
    elif provider == 'apex':
        # imported lazily so apex is only required when it is benchmarked
        import apex
        apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
        y_fwd = lambda: apex_layer_norm(x)
    else:
        # fail with a clear message instead of a NameError on y_fwd below
        raise ValueError("unknown provider: {!r}".format(provider))
    if mode == 'forward':
        # forward pass: throughput counts 2x the input size in bytes
        gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6
        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
    elif mode == 'backward':
        # backward pass: throughput counts 3x the input size in bytes
        gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6
        y = y_fwd()
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                     grad_to_none=[x], rep=500)
    else:
        # fail with a clear message instead of a NameError on gbps below
        raise ValueError("unknown mode: {!r}".format(mode))
    # the longest time maps to the lowest throughput and vice versa
    return gbps(ms), gbps(max_ms), gbps(min_ms)

# launch the benchmark at import/run time; save_path='.' and print_data=True
# are forwarded to the perf_report runner (presumably: write results to the
# current directory and print the data table)
bench_layer_norm.run(save_path='.', print_data=True)

Total running time of the script: ( 2 minutes 11.205 seconds)

Gallery generated by Sphinx-Gallery