[OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)
* dds layout now internally re-uses dsd code path for increased code * at_mask and kp_mask related things are now dropped from the softmax API. I couldn't think of any case where it was needed beyond is_causal. And if there is any, we should probably find a way to get it implemented statically so that users don't have to materialize masks. * fixed bug in blocksparse matmul that caused troubles when layout had a full row/col of zeros * blocksparse softmax now no longer modifies any data in-place * blocksparse softmax now takes an is_dense arguments that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention. * unit tests now test backward pass
This commit is contained in:
		@@ -10,77 +10,108 @@ import triton
 | 
			
		||||
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
 | 
			
		||||
@pytest.mark.parametrize("DTYPE", [torch.float16])
 | 
			
		||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
 | 
			
		||||
    # set seed
 | 
			
		||||
    torch.random.manual_seed(0)
 | 
			
		||||
    seed = 0
 | 
			
		||||
    torch.manual_seed(seed)
 | 
			
		||||
    is_sdd = MODE == "sdd"
 | 
			
		||||
    is_dsd = MODE == "dsd"
 | 
			
		||||
    is_dds = MODE == "dds"
 | 
			
		||||
    do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
 | 
			
		||||
    do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
 | 
			
		||||
    # create inputs
 | 
			
		||||
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
 | 
			
		||||
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
 | 
			
		||||
    # create op
 | 
			
		||||
    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
 | 
			
		||||
    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
 | 
			
		||||
    c_shape = (Z, H, M, N)
 | 
			
		||||
    shape = {
 | 
			
		||||
        "sdd": (M, N),
 | 
			
		||||
        "dsd": (a.shape[2], a.shape[3]),
 | 
			
		||||
        "dds": (b.shape[2], b.shape[3]),
 | 
			
		||||
        "dsd": (a_shape[2], a_shape[3]),
 | 
			
		||||
        "dds": (b_shape[2], b_shape[3]),
 | 
			
		||||
    }[MODE]
 | 
			
		||||
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
 | 
			
		||||
    layout[1, 2, :] = 0
 | 
			
		||||
    layout[1, :, 1] = 0
 | 
			
		||||
    # create data
 | 
			
		||||
    a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
 | 
			
		||||
    b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
 | 
			
		||||
    dc_ref, dc_tri = triton.testing.make_pair(c_shape)
 | 
			
		||||
    # compute [torch]
 | 
			
		||||
    dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
 | 
			
		||||
    a_ref = do_mask(a_ref) if is_dsd else a_ref
 | 
			
		||||
    b_ref = do_mask(b_ref) if is_dds else b_ref
 | 
			
		||||
    a_ref.retain_grad()
 | 
			
		||||
    b_ref.retain_grad()
 | 
			
		||||
    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
 | 
			
		||||
                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
 | 
			
		||||
    c_ref.backward(dc_ref)
 | 
			
		||||
    c_ref = do_sparsify(c_ref) if is_sdd else c_ref
 | 
			
		||||
    da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
 | 
			
		||||
    db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
 | 
			
		||||
    # triton result
 | 
			
		||||
    dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
 | 
			
		||||
    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
 | 
			
		||||
    b_tri = do_sparsify(b_tri) if is_dds else b_tri
 | 
			
		||||
    a_tri.retain_grad()
 | 
			
		||||
    b_tri.retain_grad()
 | 
			
		||||
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
 | 
			
		||||
    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
 | 
			
		||||
    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
 | 
			
		||||
    rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
 | 
			
		||||
    # torch result
 | 
			
		||||
    ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
 | 
			
		||||
    tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
 | 
			
		||||
    ta = ta.transpose(2, 3) if TRANS_A else ta
 | 
			
		||||
    tb = tb.transpose(2, 3) if TRANS_B else tb
 | 
			
		||||
    tc = torch.matmul(ta, tb)
 | 
			
		||||
    tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
 | 
			
		||||
    tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
 | 
			
		||||
    c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
 | 
			
		||||
    triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
 | 
			
		||||
    da_tri = a_tri.grad
 | 
			
		||||
    db_tri = b_tri.grad
 | 
			
		||||
    # compare
 | 
			
		||||
    triton.testing.assert_almost_equal(rc, tc)
 | 
			
		||||
    triton.testing.assert_almost_equal(c_ref, c_tri)
 | 
			
		||||
    triton.testing.assert_almost_equal(da_ref, da_tri)
 | 
			
		||||
    triton.testing.assert_almost_equal(db_ref, db_tri)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
 | 
			
		||||
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
 | 
			
		||||
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
 | 
			
		||||
def test_softmax(BLOCK, WIDTH, DTYPE):
 | 
			
		||||
    is_causal = True
 | 
			
		||||
configs = [
 | 
			
		||||
    (16, 256),
 | 
			
		||||
    (32, 576),
 | 
			
		||||
    (64, 1871),
 | 
			
		||||
    (128, 2511),
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("is_dense", [False, True])
 | 
			
		||||
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
 | 
			
		||||
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
 | 
			
		||||
    # set seed
 | 
			
		||||
    torch.random.manual_seed(0)
 | 
			
		||||
    Z, H, M, N = 1, 1, WIDTH, WIDTH
 | 
			
		||||
    scale = 0.4
 | 
			
		||||
    # create inputs
 | 
			
		||||
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
 | 
			
		||||
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
 | 
			
		||||
    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
 | 
			
		||||
    Z, H, M, N = 2, 3, WIDTH, WIDTH
 | 
			
		||||
    # initialize layout
 | 
			
		||||
    # make sure each row has at least one non-zero element
 | 
			
		||||
    torch.diagonal(layout)[:] = 1
 | 
			
		||||
    torch.diagonal(at_mask)[:] = 1
 | 
			
		||||
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
 | 
			
		||||
    kp_mask[:] = 0
 | 
			
		||||
    kp_mask[kp_mask == 1.0] = float("-inf")
 | 
			
		||||
    # triton result
 | 
			
		||||
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
 | 
			
		||||
    tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
 | 
			
		||||
    ty = op(
 | 
			
		||||
        tx,
 | 
			
		||||
        scale=scale,
 | 
			
		||||
        key_padding_mask=kp_mask,
 | 
			
		||||
        key_padding_mask_mode="add",
 | 
			
		||||
        attn_mask=at_mask.to(DTYPE),
 | 
			
		||||
        attn_mask_mode="mul",
 | 
			
		||||
        is_causal=is_causal,
 | 
			
		||||
    )
 | 
			
		||||
    # torch result
 | 
			
		||||
    rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
 | 
			
		||||
    # broadcast at_mask to the same shape as rx
 | 
			
		||||
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
 | 
			
		||||
    if is_dense:
 | 
			
		||||
        layout[:] = 1
 | 
			
		||||
    else:
 | 
			
		||||
        layout[1, 2, :] = 0
 | 
			
		||||
        layout[1, :, 1] = 0
 | 
			
		||||
    # initialize data
 | 
			
		||||
    a_shape = (Z, H, M, N)
 | 
			
		||||
    a_ref, a_tri = triton.testing.make_pair(a_shape)
 | 
			
		||||
    dout_ref, dout_tri = triton.testing.make_pair(a_shape)
 | 
			
		||||
    # compute [torch]
 | 
			
		||||
    a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
 | 
			
		||||
    a_ref.retain_grad()
 | 
			
		||||
    at_mask = torch.ones((M, N), device="cuda")
 | 
			
		||||
    if is_causal:
 | 
			
		||||
        at_mask = torch.tril(at_mask)
 | 
			
		||||
    M = at_mask[None, None, :, :] + torch.zeros_like(rx)
 | 
			
		||||
    rx[M == 0] = float("-inf")
 | 
			
		||||
    # rx += kp_mask[:, None, None, :]
 | 
			
		||||
    ry = torch.softmax(rx * scale, -1)
 | 
			
		||||
    ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
 | 
			
		||||
    M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
 | 
			
		||||
    a_ref[M == 0] = float("-inf")
 | 
			
		||||
    out_ref = torch.softmax(a_ref * scale, -1)
 | 
			
		||||
    out_ref.backward(dout_ref)
 | 
			
		||||
    out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
 | 
			
		||||
    da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
 | 
			
		||||
    # compute [triton]
 | 
			
		||||
    a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
 | 
			
		||||
    a_tri.retain_grad()
 | 
			
		||||
    dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
 | 
			
		||||
    op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
 | 
			
		||||
    out_tri = op(a_tri, scale=scale, is_causal=is_causal)
 | 
			
		||||
    out_tri.backward(dout_tri)
 | 
			
		||||
    da_tri = a_tri.grad
 | 
			
		||||
    # compare
 | 
			
		||||
    triton.testing.assert_almost_equal(ry, ty)
 | 
			
		||||
    triton.testing.assert_almost_equal(out_tri, out_ref)
 | 
			
		||||
    triton.testing.assert_almost_equal(da_tri, da_ref)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("block", [16, 32, 64])
 | 
			
		||||
@@ -99,14 +130,6 @@ def test_attention_fwd_bwd(
 | 
			
		||||
    qkvs = [
 | 
			
		||||
        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
 | 
			
		||||
    ]
 | 
			
		||||
    attn_mask = torch.tril(
 | 
			
		||||
        torch.ones(
 | 
			
		||||
            [n_ctx, n_ctx],
 | 
			
		||||
            device="cuda",
 | 
			
		||||
            dtype=dtype,
 | 
			
		||||
        ),
 | 
			
		||||
        diagonal=0,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Triton:
 | 
			
		||||
    n_blocks = n_ctx // block
 | 
			
		||||
@@ -115,7 +138,7 @@ def test_attention_fwd_bwd(
 | 
			
		||||
    query.retain_grad()
 | 
			
		||||
    key.retain_grad()
 | 
			
		||||
    value.retain_grad()
 | 
			
		||||
    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
 | 
			
		||||
    attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
 | 
			
		||||
    # ad hoc loss
 | 
			
		||||
    loss = (attn_out ** 2).mean()
 | 
			
		||||
    loss.backward()
 | 
			
		||||
@@ -123,6 +146,8 @@ def test_attention_fwd_bwd(
 | 
			
		||||
 | 
			
		||||
    # Torch version:
 | 
			
		||||
    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
 | 
			
		||||
    attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
 | 
			
		||||
    attn_mask = torch.tril(attn_mask, diagonal=0)
 | 
			
		||||
    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
 | 
			
		||||
    torch_q.retain_grad()
 | 
			
		||||
    torch_k.retain_grad()
 | 
			
		||||
@@ -147,7 +172,6 @@ def test_attention_fwd_bwd(
 | 
			
		||||
def triton_attention(
 | 
			
		||||
    layout,
 | 
			
		||||
    block: int,
 | 
			
		||||
    attn_mask: torch.Tensor,
 | 
			
		||||
    query: torch.Tensor,
 | 
			
		||||
    key: torch.Tensor,
 | 
			
		||||
    value: torch.Tensor,
 | 
			
		||||
@@ -155,12 +179,9 @@ def triton_attention(
 | 
			
		||||
):
 | 
			
		||||
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
 | 
			
		||||
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
 | 
			
		||||
    sparse_softmax = triton.ops.blocksparse.softmax(
 | 
			
		||||
        layout,
 | 
			
		||||
        block,
 | 
			
		||||
    )
 | 
			
		||||
    sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
 | 
			
		||||
 | 
			
		||||
    w = sparse_dot_sdd_nt(query, key)
 | 
			
		||||
    w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul")
 | 
			
		||||
    w = sparse_softmax(w, scale=scale, is_causal=True)
 | 
			
		||||
    a = sparse_dot_dsd_nn(w, value)
 | 
			
		||||
    return a
 | 
			
		||||
 
 | 
			
		||||
@@ -242,8 +242,31 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dsd_lut(layout, block, step, trans, device):
 | 
			
		||||
    """
 | 
			
		||||
    Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
 | 
			
		||||
    Example (BLOCK=32, STEP=16)
 | 
			
		||||
    [[1, 0, 0, 1, 0],
 | 
			
		||||
     [0, 1, 1, 0, 1],
 | 
			
		||||
     [1, 0, 1, 0, 0]]
 | 
			
		||||
 | 
			
		||||
    Then the offsets for A are
 | 
			
		||||
     [0 , 16, 32, 48] <- row 0
 | 
			
		||||
      \\----/  \\----/
 | 
			
		||||
      col=0   col=3
 | 
			
		||||
     [64, 80, 96, 112, 128, 144] <- row 1
 | 
			
		||||
      \\----/   \\----/  \\------/
 | 
			
		||||
       col=1    col=2    col=3
 | 
			
		||||
     [160, 176, 192, 208]
 | 
			
		||||
    which leads to increments table
 | 
			
		||||
    [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
 | 
			
		||||
 | 
			
		||||
    Because B is dense, the offsets are
 | 
			
		||||
    [0, 16, 96, 112] <- row 0
 | 
			
		||||
    [32, 48, 64, 80]  <- row 1
 | 
			
		||||
    [0, 16, 64, 80]   <- row 2
 | 
			
		||||
    """
 | 
			
		||||
    sizes = torch.sum(layout, 2 if trans else 1)
 | 
			
		||||
    head_id, col_id = sizes.nonzero(as_tuple=True)
 | 
			
		||||
    head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
 | 
			
		||||
    sizes = sizes.flatten()
 | 
			
		||||
    segments = sizes * step
 | 
			
		||||
    # pointer increments
 | 
			
		||||
@@ -258,13 +281,6 @@ def dsd_lut(layout, block, step, trans, device):
 | 
			
		||||
    # -------------------------------
 | 
			
		||||
    # dense input pointer increments
 | 
			
		||||
    # -------------------------------
 | 
			
		||||
    # given a list of the indices for the first element of each non-zero block.
 | 
			
		||||
    # For example, for the indices
 | 
			
		||||
    # [32, 80, 128, 256, 288]
 | 
			
		||||
    # we would generate the increments
 | 
			
		||||
    # [32, 48, 48, 128, 32]
 | 
			
		||||
    #        ^
 | 
			
		||||
    #   index of first element
 | 
			
		||||
    # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
 | 
			
		||||
    # that is smaller than the block size, so we need to do a bit of extra work
 | 
			
		||||
    # to handle this case
 | 
			
		||||
@@ -324,116 +340,11 @@ def dsd_lut(layout, block, step, trans, device):
 | 
			
		||||
# -----------------------------
 | 
			
		||||
# Dense = Dense x Sparse (DDS)
 | 
			
		||||
# -----------------------------
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _dds_kernel(
 | 
			
		||||
    A, B, C,
 | 
			
		||||
    stride_za, stride_ha, stride_ma, stride_ka,
 | 
			
		||||
    stride_zb, stride_hb, stride_bk, stride_bn,
 | 
			
		||||
    stride_zc, stride_hc, stride_mc, stride_nc,
 | 
			
		||||
    DS0, DS1, lut,
 | 
			
		||||
    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
 | 
			
		||||
    GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    # ------------ #
 | 
			
		||||
    # - Prologue - #
 | 
			
		||||
    # ------------ #
 | 
			
		||||
    pid_m = tl.program_id(0)
 | 
			
		||||
    pid_n = tl.program_id(1)
 | 
			
		||||
    num_pid_m = tl.num_programs(0)
 | 
			
		||||
    num_pid_n = tl.num_programs(1)
 | 
			
		||||
    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
 | 
			
		||||
    pid_z = tl.program_id(2)
 | 
			
		||||
    header = lut + pid_n * 4
 | 
			
		||||
    offset = tl.load(header + 0)
 | 
			
		||||
    AS1 = tl.load(header + 1)
 | 
			
		||||
    column = tl.load(header + 2)
 | 
			
		||||
    off_h = tl.load(header + 3)
 | 
			
		||||
    pinc = lut + offset
 | 
			
		||||
    # initialize pointers to A (dense)
 | 
			
		||||
    offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
 | 
			
		||||
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
 | 
			
		||||
    start_ak = tl.load(pinc)
 | 
			
		||||
    start_ak = tl.multiple_of(start_ak, 8)
 | 
			
		||||
    offs_ak = start_ak + tl.arange(0, TILE_K)
 | 
			
		||||
    ptrs_a = A + pid_z * stride_za \
 | 
			
		||||
        + off_h * stride_ha \
 | 
			
		||||
        + offs_am[:, None] * stride_ma \
 | 
			
		||||
        + offs_ak[None, :] * stride_ka
 | 
			
		||||
    # initialize pointers to B (sparse)
 | 
			
		||||
    block_id = tl.load(pinc + 1)
 | 
			
		||||
    block_id = tl.multiple_of(block_id, 8)
 | 
			
		||||
    offs_bn = tl.arange(0, TILE_N)
 | 
			
		||||
    offs_bk = tl.arange(0, TILE_K)
 | 
			
		||||
    ptrs_b = B + pid_z * stride_zb \
 | 
			
		||||
        + block_id * stride_hb \
 | 
			
		||||
        + offs_bn[None, :] * stride_bn \
 | 
			
		||||
        + offs_bk[:, None] * stride_bk
 | 
			
		||||
    # ---------------- #
 | 
			
		||||
    #    Inner Loop    #
 | 
			
		||||
    # ---------------- #
 | 
			
		||||
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
 | 
			
		||||
    for k in range(AS1, 0, -TILE_K):
 | 
			
		||||
        a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
 | 
			
		||||
        b = tl.load(ptrs_b, mask=True)
 | 
			
		||||
        acc += tl.dot(a, b)
 | 
			
		||||
        pinc += 2
 | 
			
		||||
        inc_a = tl.load(pinc)
 | 
			
		||||
        inc_b = tl.load(pinc + 1)
 | 
			
		||||
        inc_a = tl.multiple_of(inc_a, 8)
 | 
			
		||||
        inc_b = tl.multiple_of(inc_b, 8)
 | 
			
		||||
        inc_a = inc_a * stride_ka
 | 
			
		||||
        ptrs_a += inc_a
 | 
			
		||||
        ptrs_b += inc_b
 | 
			
		||||
    # ---------------- #
 | 
			
		||||
    #    Epilogue      #
 | 
			
		||||
    # ---------------- #
 | 
			
		||||
    c = acc.to(C.dtype.element_ty)
 | 
			
		||||
    # initialize pointers to C (dense)
 | 
			
		||||
    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
 | 
			
		||||
    offs_cn = column * TILE_N + tl.arange(0, TILE_N)
 | 
			
		||||
    ptrs_c = C + off_h * stride_hc \
 | 
			
		||||
        + pid_z * stride_zc \
 | 
			
		||||
        + offs_cm[:, None] * stride_mc \
 | 
			
		||||
        + offs_cn[None, :] * stride_nc
 | 
			
		||||
    # write back
 | 
			
		||||
    tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
 | 
			
		||||
# AB = (B^T A^T)^T
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
 | 
			
		||||
    if a.stride(2) != 1 and a.stride(3) != 1:
 | 
			
		||||
        a = a.contiguous()
 | 
			
		||||
    if b.stride(2) != 1 and b.stride(3) != 1:
 | 
			
		||||
        b = b.contiguous()
 | 
			
		||||
    # shapes / dtypes
 | 
			
		||||
    AS0 = a.size(0)
 | 
			
		||||
    AS1 = a.size(1)
 | 
			
		||||
    AS2 = a.size(3 if trans_a else 2)
 | 
			
		||||
    BS2 = block * spdims[1 if trans_b else 2]
 | 
			
		||||
    dtype = a.dtype
 | 
			
		||||
    # output
 | 
			
		||||
    CS0 = AS0
 | 
			
		||||
    CS1 = AS1
 | 
			
		||||
    CS2 = BS2 if trans_c else AS2
 | 
			
		||||
    CS3 = AS2 if trans_c else BS2
 | 
			
		||||
    if out is None:
 | 
			
		||||
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
 | 
			
		||||
    else:
 | 
			
		||||
        assert out.shape == (CS0, CS1, CS2, CS3)
 | 
			
		||||
        c = out
 | 
			
		||||
    TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
 | 
			
		||||
    grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
 | 
			
		||||
    _dds_kernel[grid](
 | 
			
		||||
        a, b, c,
 | 
			
		||||
        a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
 | 
			
		||||
        b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
 | 
			
		||||
        c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
 | 
			
		||||
        AS2, BS2, lut,
 | 
			
		||||
        TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
 | 
			
		||||
        num_warps=4, GROUP_SIZE_M=4,
 | 
			
		||||
    )
 | 
			
		||||
    return c
 | 
			
		||||
    return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
 | 
			
		||||
 | 
			
		||||
##############
 | 
			
		||||
#  MAIN API  #
 | 
			
		||||
 
 | 
			
		||||
@@ -5,230 +5,235 @@ import triton.language as tl
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def num_warps(n):
 | 
			
		||||
    if n < 512:
 | 
			
		||||
    if n <= 128:
 | 
			
		||||
        return 1
 | 
			
		||||
    if n <= 256:
 | 
			
		||||
        return 2
 | 
			
		||||
    if n <= 512:
 | 
			
		||||
        return 4
 | 
			
		||||
    if n < 2048:
 | 
			
		||||
    if n <= 4096:
 | 
			
		||||
        return 8
 | 
			
		||||
    return 16
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
 | 
			
		||||
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _forward(
 | 
			
		||||
    X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
 | 
			
		||||
    TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
 | 
			
		||||
    KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
 | 
			
		||||
def _blocksparse_softmax_fwd(
 | 
			
		||||
    Out, A, stride_xz, LUT,
 | 
			
		||||
    R, extent, stride_zr, stride_hr,  # relative attention
 | 
			
		||||
    scale, is_causal,
 | 
			
		||||
    ROW_SIZE: tl.constexpr,
 | 
			
		||||
    BLOCK_SIZE: tl.constexpr,
 | 
			
		||||
    IS_DENSE: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    pidhm = tl.program_id(0)
 | 
			
		||||
    pidz = tl.program_id(1)
 | 
			
		||||
    h = tl.program_id(0)
 | 
			
		||||
    m = tl.program_id(1)
 | 
			
		||||
    z = tl.program_id(2)
 | 
			
		||||
    # create index ranges
 | 
			
		||||
    rxm = pidhm % BLOCK
 | 
			
		||||
    rbm = pidhm // BLOCK
 | 
			
		||||
    rxn = tl.arange(0, TN) % BLOCK
 | 
			
		||||
    rbn = tl.arange(0, TN) // BLOCK
 | 
			
		||||
    hm = h * tl.num_programs(1) + m
 | 
			
		||||
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
 | 
			
		||||
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
 | 
			
		||||
    # extract information from LUT
 | 
			
		||||
    header = LUT + rbm * 2
 | 
			
		||||
    header = LUT + (hm // BLOCK_SIZE) * 2
 | 
			
		||||
    size = tl.load(header + 0)
 | 
			
		||||
    offset = tl.load(header + 1)
 | 
			
		||||
    check = rbn < size
 | 
			
		||||
    rbmn = tl.where(check, rbn, size - 1)
 | 
			
		||||
    # block id and column id
 | 
			
		||||
    blockid = tl.load(LUT + offset + rbmn * 4 + 0)
 | 
			
		||||
    columnid = tl.load(LUT + offset + rbmn * 4 + 1)
 | 
			
		||||
    rowid = tl.load(LUT + offset + rbmn * 4 + 2)
 | 
			
		||||
    headid = tl.load(LUT + offset + rbmn * 4 + 3)
 | 
			
		||||
    # pointers to X
 | 
			
		||||
    px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
 | 
			
		||||
    x = tl.load(px, mask=check, other=-float('inf'))
 | 
			
		||||
    x = x.to(tl.float32)
 | 
			
		||||
    # apply scale
 | 
			
		||||
    if APPLY_SCALE:
 | 
			
		||||
        x = x * scale
 | 
			
		||||
    # apply RPE
 | 
			
		||||
    if APPLY_RPE:
 | 
			
		||||
        prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
 | 
			
		||||
        rpe = tl.load(prpe, mask=check, other=0)
 | 
			
		||||
        x = x + rpe
 | 
			
		||||
    # apply key-padding mask
 | 
			
		||||
    if APPLY_KP_MASK:
 | 
			
		||||
        pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
 | 
			
		||||
        kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
 | 
			
		||||
        if KP_MASK_MUL:
 | 
			
		||||
            kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
 | 
			
		||||
        x = x + kp_m
 | 
			
		||||
    # apply attention mask
 | 
			
		||||
    if APPLY_ATTN_MASK:
 | 
			
		||||
        pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
 | 
			
		||||
        attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
 | 
			
		||||
        if ATTN_MASK_MUL:
 | 
			
		||||
            attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
 | 
			
		||||
        x = x + attn_m
 | 
			
		||||
    # pointer offset
 | 
			
		||||
    off_a = z * stride_xz
 | 
			
		||||
    off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE  # block indx
 | 
			
		||||
    off_a += (m % BLOCK_SIZE) * BLOCK_SIZE  # row indx
 | 
			
		||||
    # do not need to read column indices in the dense case
 | 
			
		||||
    if IS_DENSE:
 | 
			
		||||
        ns = tl.arange(0, ROW_SIZE)
 | 
			
		||||
    else:
 | 
			
		||||
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
 | 
			
		||||
        start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
 | 
			
		||||
        ns = start_n * BLOCK_SIZE + lane_n
 | 
			
		||||
    # load X
 | 
			
		||||
    mask = block_n < size
 | 
			
		||||
    a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
 | 
			
		||||
    a = a.to(tl.float32)
 | 
			
		||||
    # compute
 | 
			
		||||
    out = a
 | 
			
		||||
    out *= scale
 | 
			
		||||
    # apply relative attention
 | 
			
		||||
    if R is not None:
 | 
			
		||||
        R += z * stride_zr
 | 
			
		||||
        R += h * stride_hr
 | 
			
		||||
        off_lo = (extent - m - 1) + ns
 | 
			
		||||
        mask_lo = (off_lo >= 0) & (off_lo < extent)
 | 
			
		||||
        rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
 | 
			
		||||
        out += rel_logits
 | 
			
		||||
    out = out.to(tl.float32)
 | 
			
		||||
    # apply causal mask
 | 
			
		||||
    is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
 | 
			
		||||
    x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
 | 
			
		||||
    out = tl.where((ns > m) & is_causal, -float("inf"), out)
 | 
			
		||||
    # computation
 | 
			
		||||
    x = tl.softmax(x)
 | 
			
		||||
    tl.store(px, x, mask=check)
 | 
			
		||||
    out = tl.softmax(out)
 | 
			
		||||
    # write-back
 | 
			
		||||
    tl.store(Out + off_a + lane_n, out, mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
 | 
			
		||||
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
 | 
			
		||||
    pidhm = tl.program_id(0)
 | 
			
		||||
    pidz = tl.program_id(1)
 | 
			
		||||
def _blocksparse_softmax_bwd(
 | 
			
		||||
    DA, stride_zdx,
 | 
			
		||||
    DOut, stride_zdout,
 | 
			
		||||
    Out, stride_zout,
 | 
			
		||||
    scale,
 | 
			
		||||
    LUT,
 | 
			
		||||
    DR, extent, stride_zr, stride_hr, stride_er,
 | 
			
		||||
    is_causal,
 | 
			
		||||
    ROW_SIZE: tl.constexpr,
 | 
			
		||||
    BLOCK_SIZE: tl.constexpr,
 | 
			
		||||
    IS_DENSE: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    h = tl.program_id(0)
 | 
			
		||||
    m = tl.program_id(1)
 | 
			
		||||
    z = tl.program_id(2)
 | 
			
		||||
    # create index ranges
 | 
			
		||||
    rxm = pidhm % BLOCK
 | 
			
		||||
    rbm = pidhm // BLOCK
 | 
			
		||||
    rxn = tl.arange(0, TN) % BLOCK
 | 
			
		||||
    rbn = tl.arange(0, TN) // BLOCK
 | 
			
		||||
    # extract information from look-up table
 | 
			
		||||
    header = LUT + rbm * 2
 | 
			
		||||
    hm = h * tl.num_programs(1) + m
 | 
			
		||||
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
 | 
			
		||||
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
 | 
			
		||||
    # extract information from LUT
 | 
			
		||||
    header = LUT + (hm // BLOCK_SIZE) * 2
 | 
			
		||||
    size = tl.load(header + 0)
 | 
			
		||||
    offset = tl.load(header + 1)
 | 
			
		||||
    # bounds checking on lut
 | 
			
		||||
    check = rbn < size
 | 
			
		||||
    rbmn = tl.where(check, rbn, size - 1)
 | 
			
		||||
    # initialize pointers to block-sparse input
 | 
			
		||||
    blockid = tl.load(LUT + offset + rbmn * 4)
 | 
			
		||||
    X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
 | 
			
		||||
    DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
 | 
			
		||||
    # compute fused softmax backward
 | 
			
		||||
    x = tl.load(X, mask=check, other=0)
 | 
			
		||||
    dx = tl.load(DX, mask=check, other=0)
 | 
			
		||||
    x = x.to(tl.float32)
 | 
			
		||||
    dx = dx.to(tl.float32)
 | 
			
		||||
    y = x * (dx - tl.sum(x * dx, 0)) * scale
 | 
			
		||||
    tl.store(DX, y, mask=check)
 | 
			
		||||
    # row-col offset
 | 
			
		||||
    off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
 | 
			
		||||
    off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
 | 
			
		||||
    mask = block_n < size
 | 
			
		||||
    # pointers
 | 
			
		||||
    As = Out + z * stride_zout + off_mn
 | 
			
		||||
    DOuts = DOut + z * stride_zdout + off_mn
 | 
			
		||||
    # do not need to read column indices in the dense case
 | 
			
		||||
    if IS_DENSE:
 | 
			
		||||
        ns = tl.arange(0, ROW_SIZE)
 | 
			
		||||
    else:
 | 
			
		||||
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
 | 
			
		||||
        start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
 | 
			
		||||
        ns = start_n * BLOCK_SIZE + lane_n
 | 
			
		||||
    # load data
 | 
			
		||||
    a = tl.load(As + lane_n, mask=mask, other=0.0)
 | 
			
		||||
    a = a.to(tl.float32)
 | 
			
		||||
    dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
 | 
			
		||||
    dout = dout.to(tl.float32)
 | 
			
		||||
    # compute
 | 
			
		||||
    da = a * (dout - tl.sum(a * dout, 0))
 | 
			
		||||
    da = tl.where((ns > m) & is_causal, 0., da)
 | 
			
		||||
    # apply relative attention
 | 
			
		||||
    if DR is not None:
 | 
			
		||||
        DR += z * stride_zr
 | 
			
		||||
        DR += h * stride_hr
 | 
			
		||||
        off_lo = (extent - m - 1) + ns
 | 
			
		||||
        mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
 | 
			
		||||
        tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
 | 
			
		||||
    da = da * scale
 | 
			
		||||
    # convert da
 | 
			
		||||
    # write-back
 | 
			
		||||
    DAs = DA + z * stride_zdx + off_mn
 | 
			
		||||
    tl.store(DAs + lane_n, da, mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _softmax(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def make_lut(layout, block, device):
 | 
			
		||||
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
 | 
			
		||||
        sizes = _empty.clone()
 | 
			
		||||
        # sizes along rows
 | 
			
		||||
        sizes = layout.sum(-1).view(-1)
 | 
			
		||||
        for h in range(layout.shape[0]):
 | 
			
		||||
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
 | 
			
		||||
        total_sizes = sizes * block
 | 
			
		||||
        # offsets in block format
 | 
			
		||||
        offsets = torch.zeros_like(sizes)
 | 
			
		||||
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
 | 
			
		||||
        # block indices
 | 
			
		||||
        layout_sum = sizes.sum()
 | 
			
		||||
        idx = torch.arange(layout_sum, device=layout.device)
 | 
			
		||||
        layout_nonzero = layout.nonzero(as_tuple=False)
 | 
			
		||||
        head = layout_nonzero[:, 0]
 | 
			
		||||
        rows = layout_nonzero[:, 1]
 | 
			
		||||
        columns = layout_nonzero[:, 2]
 | 
			
		||||
        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
 | 
			
		||||
        # construct look-up table
 | 
			
		||||
        offsets = offsets * 4 + 2 * sizes.numel()
 | 
			
		||||
        columns = layout.nonzero(as_tuple=False)[:, 2]
 | 
			
		||||
        header = torch.stack((sizes, offsets), dim=1).view(-1)
 | 
			
		||||
        lut = torch.cat((header, core)).type(torch.int32).to(device)
 | 
			
		||||
        return lut, int(sizes.max())
 | 
			
		||||
        lut = torch.cat((header, columns)).type(torch.int32).to(device)
 | 
			
		||||
        return lut, int(total_sizes.max())
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx, x, scale, rpe,
 | 
			
		||||
        key_padding_mask, attn_mask,
 | 
			
		||||
        kp_mask_mode, attn_mask_mode,
 | 
			
		||||
        is_causal,
 | 
			
		||||
        spdims, block, lut, maxlut
 | 
			
		||||
        ctx, a, scale, rel_logits, is_causal,
 | 
			
		||||
        spdims, block, lut, maxlut, is_dense
 | 
			
		||||
    ):
 | 
			
		||||
        apply_scale = False if scale == 1.0 else True
 | 
			
		||||
        # handle None rpe
 | 
			
		||||
        if rpe is None:
 | 
			
		||||
            apply_rpe = False
 | 
			
		||||
            stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
 | 
			
		||||
            rpe = torch.empty(0, dtype=x.dtype, device=x.device)
 | 
			
		||||
        else:
 | 
			
		||||
            apply_rpe = True
 | 
			
		||||
            stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
 | 
			
		||||
        # handle None key_padding_mask
 | 
			
		||||
        if key_padding_mask is None:
 | 
			
		||||
            apply_kp_mask = False
 | 
			
		||||
            stride_zkpm = 0
 | 
			
		||||
            key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
 | 
			
		||||
        else:
 | 
			
		||||
            apply_kp_mask = True
 | 
			
		||||
            stride_zkpm = key_padding_mask.stride(0)
 | 
			
		||||
        # handle None attention_mask
 | 
			
		||||
        if attn_mask is None:
 | 
			
		||||
            apply_attn_mask = False
 | 
			
		||||
            stride_zattnm = 0
 | 
			
		||||
            attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
 | 
			
		||||
        else:
 | 
			
		||||
            apply_attn_mask = True
 | 
			
		||||
            stride_zattnm = attn_mask.stride(0)
 | 
			
		||||
        # run kernel
 | 
			
		||||
        M = x.shape[0]
 | 
			
		||||
        grid = [spdims[0] * spdims[1] * block, M]
 | 
			
		||||
        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
 | 
			
		||||
                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
 | 
			
		||||
                       BLOCK=block,
 | 
			
		||||
                       APPLY_SCALE=apply_scale,
 | 
			
		||||
                       APPLY_RPE=apply_rpe,
 | 
			
		||||
                       APPLY_KP_MASK=apply_kp_mask,
 | 
			
		||||
                       APPLY_ATTN_MASK=apply_attn_mask,
 | 
			
		||||
                       KP_MASK_MUL=(kp_mask_mode == 'mul'),
 | 
			
		||||
                       ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
 | 
			
		||||
        if scale is not None and isinstance(scale, torch.Tensor):
 | 
			
		||||
            assert scale.device.type == "cpu"
 | 
			
		||||
            scale = scale.item()
 | 
			
		||||
        M = a.shape[0]
 | 
			
		||||
        grid = [spdims[0], spdims[1] * block, M]
 | 
			
		||||
        rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
 | 
			
		||||
        rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
 | 
			
		||||
        # enqueue kernel
 | 
			
		||||
        out = torch.empty_like(a)
 | 
			
		||||
        _blocksparse_softmax_fwd[grid](
 | 
			
		||||
            out, a, a.stride(0), lut,
 | 
			
		||||
            rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn
 | 
			
		||||
            scale,
 | 
			
		||||
            is_causal,
 | 
			
		||||
            BLOCK_SIZE=block,
 | 
			
		||||
            ROW_SIZE=triton.next_power_of_2(maxlut),
 | 
			
		||||
            IS_DENSE=is_dense,
 | 
			
		||||
            num_warps=num_warps(maxlut)
 | 
			
		||||
        )
 | 
			
		||||
        # save to context
 | 
			
		||||
        ctx.mark_dirty(x)
 | 
			
		||||
        ctx.save_for_backward(x, lut)
 | 
			
		||||
        # ctx.mark_dirty(x)
 | 
			
		||||
        ctx.save_for_backward(out, lut)
 | 
			
		||||
        ctx.spdims = spdims
 | 
			
		||||
        ctx.block = block
 | 
			
		||||
        ctx.maxlut = maxlut
 | 
			
		||||
        ctx.scale = scale
 | 
			
		||||
        ctx.apply_scale = apply_scale
 | 
			
		||||
        ctx.apply_rpe = apply_rpe
 | 
			
		||||
        ctx.apply_kp_mask = apply_kp_mask
 | 
			
		||||
        ctx.apply_attn_mask = apply_attn_mask
 | 
			
		||||
        ctx.kp_mask_mode = kp_mask_mode
 | 
			
		||||
        ctx.attn_mask_mode = attn_mask_mode
 | 
			
		||||
        return x
 | 
			
		||||
        ctx.rel_shape = rel_shape
 | 
			
		||||
        ctx.rel_strides = rel_strides
 | 
			
		||||
        ctx.rel_dtype = a.dtype
 | 
			
		||||
        ctx.is_dense = is_dense
 | 
			
		||||
        ctx.is_causal = is_causal
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, dx):
 | 
			
		||||
    def backward(ctx, dout):
 | 
			
		||||
        # retrieve from context
 | 
			
		||||
        x, lut = ctx.saved_tensors
 | 
			
		||||
        out, lut = ctx.saved_tensors
 | 
			
		||||
        # relative logits gradients
 | 
			
		||||
        dr = None
 | 
			
		||||
        if ctx.needs_input_grad[3]:
 | 
			
		||||
            dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
 | 
			
		||||
        # run kernel
 | 
			
		||||
        M = x.shape[0]
 | 
			
		||||
        grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
 | 
			
		||||
        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
 | 
			
		||||
        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 | 
			
		||||
        M = out.shape[0]
 | 
			
		||||
        grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
 | 
			
		||||
        da = torch.empty_like(dout)
 | 
			
		||||
        _blocksparse_softmax_bwd[grid](
 | 
			
		||||
            da, da.stride(0),
 | 
			
		||||
            dout, dout.stride(0),
 | 
			
		||||
            out, out.stride(0),
 | 
			
		||||
            ctx.scale,
 | 
			
		||||
            lut,
 | 
			
		||||
            dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
 | 
			
		||||
            ctx.is_causal,
 | 
			
		||||
            BLOCK_SIZE=ctx.block,
 | 
			
		||||
            ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
 | 
			
		||||
            IS_DENSE=ctx.is_dense,
 | 
			
		||||
            num_warps=num_warps(ctx.maxlut)
 | 
			
		||||
        )
 | 
			
		||||
        return (da, None, None, dr, None,
 | 
			
		||||
                None, None, None, None, None,
 | 
			
		||||
                None,
 | 
			
		||||
                None, None, None,
 | 
			
		||||
                None,
 | 
			
		||||
                None, None, None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class softmax:
 | 
			
		||||
 | 
			
		||||
    def make_lut(self, device):
 | 
			
		||||
        key = (device, )
 | 
			
		||||
        if key not in self.lut_cache:
 | 
			
		||||
            self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
 | 
			
		||||
        return self.lut_cache[key]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, layout, block):
 | 
			
		||||
    def __init__(self, layout, block, device, is_dense=False):
 | 
			
		||||
        self.spdims = layout.shape
 | 
			
		||||
        self.layout = layout
 | 
			
		||||
        self.block = block
 | 
			
		||||
        self.lut_cache = dict()
 | 
			
		||||
        self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
 | 
			
		||||
        self.is_dense = is_dense
 | 
			
		||||
 | 
			
		||||
    def __call__(
 | 
			
		||||
        self, x, scale=1., rpe=None,
 | 
			
		||||
        key_padding_mask=None, attn_mask=None,
 | 
			
		||||
        key_padding_mask_mode='add', attn_mask_mode='add',
 | 
			
		||||
        is_causal=False
 | 
			
		||||
    ):
 | 
			
		||||
        if rpe is not None and rpe.dtype != x.dtype:
 | 
			
		||||
            raise ValueError('relative position embedding must be %s' % x.dtype)
 | 
			
		||||
        if attn_mask is not None and attn_mask.dtype != x.dtype:
 | 
			
		||||
            raise ValueError('Attention mask must be %s' % x.dtype)
 | 
			
		||||
        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
 | 
			
		||||
            raise ValueError('Key padding mask must be %s' % x.dtype)
 | 
			
		||||
        lut, maxlut = self.make_lut(x.device)
 | 
			
		||||
        x = _softmax.apply(
 | 
			
		||||
            x, scale, rpe,
 | 
			
		||||
            key_padding_mask, attn_mask,
 | 
			
		||||
            key_padding_mask_mode, attn_mask_mode,
 | 
			
		||||
            is_causal,
 | 
			
		||||
            self.spdims, self.block,
 | 
			
		||||
            lut, maxlut
 | 
			
		||||
    def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
 | 
			
		||||
        if rel_logits is not None and rel_logits.dtype != a.dtype:
 | 
			
		||||
            raise ValueError("relative position embedding must be %s" % a.dtype)
 | 
			
		||||
        a = _softmax.apply(
 | 
			
		||||
            a, scale, rel_logits, is_causal,
 | 
			
		||||
            self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
 | 
			
		||||
        )
 | 
			
		||||
        return x
 | 
			
		||||
        return a
 | 
			
		||||
 
 | 
			
		||||
@@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block):
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
 | 
			
		||||
    if data is None:
 | 
			
		||||
        data = torch.randn(shape, dtype=torch.float32, device=device)
 | 
			
		||||
    ref_ret = data
 | 
			
		||||
    ref_ret = ref_ret * alpha + beta
 | 
			
		||||
    ref_ret = ref_ret.half().float()
 | 
			
		||||
    if trans:
 | 
			
		||||
        ref_ret = ref_ret.t().requires_grad_()
 | 
			
		||||
    ref_ret = ref_ret.detach().requires_grad_()
 | 
			
		||||
    tri_ret = ref_ret.clone().detach().requires_grad_()
 | 
			
		||||
    return ref_ret, tri_ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cutlass_matmul(a, b):
 | 
			
		||||
    if _cutlass is None:
 | 
			
		||||
        raise RuntimeError("Cannot find cutlass library")
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user