[OPS] Faster and cleaner block-sparse implementation (#311)
@@ -2500,9 +2500,8 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
void generator::visit_basic_block(ir::basic_block * block) {
  BasicBlock *parent = bbs_[block];
  builder_->SetInsertPoint(parent);
  for(ir::instruction *i: block->get_inst_list()){
  for(ir::instruction *i: block->get_inst_list())
    visit_value(i);
  }
  // Update ir bb -> llvm bb mapping
  bbs_[block] = builder_->GetInsertBlock();
}

@@ -117,7 +117,7 @@ void coalesce::run(ir::module &mod) {
            queue.push_back(inst_op);
        }
      }
      if(in_contig.empty() || out_contig==in_contig)
      if(in_contig.size() <= 1 || out_contig==in_contig)
        continue;
      builder.set_insert_point_after(val_inst);
      auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));

@@ -188,7 +188,7 @@ void pipeline::run(ir::module &mod) {
      builder.set_insert_point(block->get_first_non_phi());
      std::map<ir::phi_node*, ir::value*> load_ivs;
      std::map<ir::phi_node*, ir::value*> next_load_ivs;
      for (ir::phi_node* iv : induction_vars) {
      for (auto& [iv, val] : prev_phi_vals) {
        ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
        pn->add_incoming(prev_phi_vals[iv], header);
        load_ivs[iv] = pn;
@@ -198,7 +198,8 @@ void pipeline::run(ir::module &mod) {

      // pre-fetch next iteration
      builder.set_insert_point(block->get_inst_list().back());
      ir::value* next_ptr = ptr->get_value_for_block(block);
//      ir::value* next_ptr = ptr->get_value_for_block(block);
      ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
      ir::value* next_mask = builder.create_splat(
          rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
      if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {

@@ -11,20 +11,20 @@ square_confs = [
              x_names = ['M', 'N', 'K'],
              x_vals  = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
              line_arg  = 'block',
              line_vals  = [16, 32, 64],
              line_names = ['Block16', 'Block32', 'Block64'],
              line_vals  = [16, 32, 64, 128],
              line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
              ylabel  = 'TFLOPS',
              plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
              args = {'layout_mode': layout_mode, 'op_mode': op_mode,
                      'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
    )\
    for AT in [False] for BT in [False] \
    for op_mode in ['sdd', 'dsd', 'dds']  for layout_mode in ['tril', 'dense']
    for op_mode in ['dsd']  for layout_mode in ['dense']
]


@triton.testing.perf_report(square_confs)
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5):
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
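Note: the benchmark body is truncated in this hunk, so the TFLOPS accounting below is a hypothetical sketch (all names are illustrative, not from the diff). For 'dsd', each nonzero block of the sparse operand contributes a (block x block) @ (block x N) product:

def dsd_tflops(ms, layout, block, N, Z=1):
    # layout: (H, M//block, K//block) binary block mask
    nnz = int(layout.sum())                    # number of nonzero blocks
    flops = 2.0 * Z * nnz * block * block * N  # one multiply-add counted as 2 flops
    return flops * 1e-12 / (ms * 1e-3)         # ms -> TFLOPS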
@@ -86,3 +86,6 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
        gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
        return gbps(mean_ms), gbps(min_ms), gbps(max_ms)


bench_matmul.run(print_data=True, show_plots=True)
@@ -7,7 +7,7 @@ import pytest
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    # set seed
    torch.random.manual_seed(0)
@@ -41,15 +41,20 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
def test_softmax(BLOCK, WIDTH, DTYPE):
    is_causal = True
    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 2, 4, WIDTH, WIDTH
    Z, H, M, N = 1, 1, WIDTH, WIDTH
    scale = 0.4
    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
    # make sure each row has at least one non-zero element
    torch.diagonal(layout)[:] = 1
    torch.diagonal(at_mask)[:] = 1
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
    kp_mask[:] = 0
    kp_mask[kp_mask == 1.0] = float("-inf")
    # triton result
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -61,16 +66,15 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
        key_padding_mask_mode="add",
        attn_mask=at_mask.to(DTYPE),
        attn_mask_mode="mul",
        is_causal=is_causal,
    )
    # torch result
    rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
    if at_mask is not None:
    # broadcast at_mask to the same shape as rx
    if is_causal: at_mask = torch.tril(at_mask)
    M = at_mask[None, None, :, :] + torch.zeros_like(rx)
    rx[M == 0] = float("-inf")
    if kp_mask is not None:
        rx += kp_mask[:, None, None, :]
    ry = torch.softmax(rx * scale, -1)
    # rx += kp_mask[:, None, None, :]
    ry = torch.softmax(rx * scale, -1)
    ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
    # compare

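Note: a hedged reference for the two triton.testing helpers the test relies on above — a sketch of their semantics as I understand them, not the library source:

import torch

def mask_tensor_ref(x, layout, block, value=float("-inf")):
    # x: (Z, H, M, N); layout: (H, M//block, N//block) of {0, 1}
    out = x.clone()
    for h, i, j in (layout == 0).nonzero(as_tuple=False).tolist():
        out[:, h, i*block:(i+1)*block, j*block:(j+1)*block] = value
    return out

def sparsify_tensor_ref(x, layout, block):
    # gather the nonzero blocks into shape (Z, nnz, block, block)
    blocks = [x[:, h, i*block:(i+1)*block, j*block:(j+1)*block]
              for h, i, j in layout.nonzero(as_tuple=False).tolist()]
    return torch.stack(blocks, dim=1)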
@@ -699,3 +699,35 @@ def ravel(x):
    :type x: Block
    """
    return triton.language.reshape(x, [x.type.numel])

@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are row major for each group of size_g rows.
    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
     [4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11],
     [12, 13, 14, 15]]
    into
    [[0, 2,  4 , 6 ],
     [1, 3,  5 , 7 ],
     [8, 10, 12, 14],
     [9, 11, 13, 15]]
    """
    # "unrolled index in array"
    ij = i*size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows
    size_g = minimum(size_i - off_i, size_g)
    # new row and column indices
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j
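Note: a quick host-side check of the docstring example (plain Python with the same arithmetic as the jitted body; `minimum` becomes `min`):

def swizzle2d_ref(i, j, size_i, size_j, size_g):
    ij = i * size_j + j
    size_gj = size_g * size_j
    group_id = ij // size_gj
    off_i = group_id * size_g
    size_g = min(size_i - off_i, size_g)
    return off_i + (ij % size_g), (ij % size_gj) // size_g

grid = [[0] * 4 for _ in range(4)]
for i in range(4):
    for j in range(4):
        new_i, new_j = swizzle2d_ref(i, j, 4, 4, 2)
        grid[new_i][new_j] = i * 4 + j
assert grid == [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]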
@@ -3,260 +3,134 @@ import triton.language as tl
import triton._C.libtriton as libtriton
import torch

# ********************************************************
# --------------------------------------------------------
# Sparse = Dense x Dense (SDD)
# This operation uses super-blocking to make sure that
# it's done efficiently when small blocks can be grouped
# together
# --------------------------------------------------------
# ********************************************************

@triton.heuristics({
    'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0,
})
@triton.jit
def _kernel(
    A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc,
    stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta
def _sdd_kernel(
    A, B, C,
    stride_za, stride_ha, stride_ma, stride_ak,
    stride_zb, stride_hb, stride_bk, stride_nb,
    stride_zc, stride_hc, stride_mc, stride_nc,
    K, grid_offset, lut, **meta
):
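Note on the EVEN_K heuristic above: the positional parameters of _sdd_kernel are A, B, C followed by twelve strides, so args[15] is K; EVEN_K therefore means "K is a multiple of TILE_K", which lets the inner loop use unmasked loads. A plain-Python sanity check:

params = ['A', 'B', 'C',
          'stride_za', 'stride_ha', 'stride_ma', 'stride_ak',
          'stride_zb', 'stride_hb', 'stride_bk', 'stride_nb',
          'stride_zc', 'stride_hc', 'stride_mc', 'stride_nc',
          'K', 'grid_offset', 'lut']
assert params.index('K') == 15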
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    TILE_M = meta['TILE_M']
    TILE_N = meta['TILE_N']
    TILE_K = meta['TILE_K']
    BLOCK  = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        offlutm = blockidm * (TN // BLOCK) * 4
    pid1 = tl.program_id(1) + grid_offset
    blockidm = tl.arange(0, TILE_M) // BLOCK
    blockidn = tl.arange(0, TILE_N) // BLOCK
    offlutm = blockidm * (TILE_N // BLOCK) * 4
    offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        AS1 = SDD_K
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)

    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = A + pidz * TZ * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
    pb = B + pidz * TZ * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)

    header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4
    # batch offset
    off_z = tl.program_id(2)
    # head offset
    off_h = tl.load(header + 0)
    # initialize pointers to A
    start_am = tl.load(header + 1 + offlutm)
    offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
    offs_ak = tl.arange(0, TILE_K)
    a_ptrs = A + off_z * stride_za \
                + off_h * stride_ha \
                + offs_am[:, None] * stride_ma \
                + offs_ak[None, :] * stride_ak
    # initialize pointers to B
    start_bn = tl.load(header + 2 + offlutn)
    offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
    offs_bk = tl.arange(0, TILE_K)
    b_ptrs = B + off_z * stride_zb \
                + off_h * stride_hb \
                + offs_bn[None, :] * stride_nb \
                + offs_bk[:, None] * stride_bk
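Note: the SDD LUT header layout can be read off the loads above — four int32 entries per (block-row, block-col) pair of the super-tile. An illustrative host-side unpacking (an assumption for clarity, with lut indexable as a 1-D array):

def unpack_sdd_header(lut, pid1, tile_m, tile_n, block):
    n_m, n_n = tile_m // block, tile_n // block
    base = pid1 * n_m * n_n * 4
    off_h = lut[base + 0]                                         # head index
    start_am = [lut[base + 1 + m * n_n * 4] for m in range(n_m)]  # A block rows
    start_bn = [lut[base + 2 + n * 4] for n in range(n_n)]        # B block cols
    block_id = [[lut[base + 3 + m * n_n * 4 + n * 4]              # output blocks
                 for n in range(n_n)] for m in range(n_m)]
    return off_h, start_am, start_bn, block_id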
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK*TZ):
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(K, 0, -TILE_K):
        if meta['EVEN_K']:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
            b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
        a = tl.load(pa)
        b = tl.load(pb)
        acc += tl.dot(a, b)
        if meta['SDD']:
            inc_a = TK * TZ * stride_ka
            inc_b = TK * TZ * stride_kb
        else:
            pinc += 2
        if meta['DSD']:
            inc_b = tl.load(pinc)
            inc_a = tl.load(pinc + 1)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = inc_b * stride_kb
        if meta['DDS']:
            inc_a = tl.load(pinc)
            inc_b = tl.load(pinc + 1)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk
    c = acc.to(C.dtype.element_ty)
    ## ---------------- ##
    ##    Epilogue      ##
    ## ---------------- ##
    blockidm = tl.arange(0, TILE_M) // BLOCK
    blockidn = tl.arange(0, TILE_N) // BLOCK
    offlutm = blockidm * (TILE_N // BLOCK) * 4
    offlutn = blockidn * 4
    off_block_id = 3 + offlutm[:, None] + offlutn[None, :]
    block_id = tl.load(header + off_block_id)
    # initialize pointers to C
    offs_cm = tl.arange(0, TILE_M) % BLOCK
    offs_cn = tl.arange(0, TILE_N) % BLOCK
    pc = C + off_z * stride_zc \
            + block_id * stride_hc \
            + offs_cm[:, None] * stride_mc \
            + offs_cn[None, :] * stride_nc
    tl.store(pc, c, mask=True)

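Note: the unconditional store works because sdd_matmul (below) allocates the output as (batch, n_blocks, block, block), so stride_hc equals block * block and block_id * stride_hc jumps straight to one output block. A small check of that invariant:

import torch
c = torch.zeros((2, 10, 16, 16))   # (batch, n_blocks, block, block)
assert c.stride(1) == 16 * 16      # stride_hc == block * block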
    if meta['SDD']:
        checkc = True
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
    if meta['DSD']:
        checkc = rcn[None, :] < DS0
    if meta['DDS']:
        checkc = rcm[:, None] < DS0
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
    # (A * B)^T = B^T * A^T
    if trans_c:
        a, b = b, a
        trans_a, trans_b = not trans_b, not trans_a
    # shape constraints
    a_dim = -2 if trans_a else -1
    b_dim = -1 if trans_b else -2
    Ka, Kb = a.shape[a_dim], b.shape[b_dim]
    if Ka != Kb:
        raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
    if Ka % 16 != 0:
        raise ValueError('Reduction size for SDD must be a multiple of 16')
    # allocate output
    n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)])
    c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device)
    # each iteration of the loop below
    # computes the value for one group of super-blocks
    # (e.g., all 4x4 super-blocks)
    for lut, width, pack in zip(luts, widths, packs):
        # maximum grid size in Triton/CUDA is 64k but we may have more
        # super-blocks than that.
        max_grid = 65535
        for off_grid in range(0, width, max_grid):
            grid = [1, min(max_grid, width - off_grid), c.shape[0]]
            # fmt: off
            pgm = _sdd_kernel[grid](
                a, b, c,
                a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
                b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
                c.stride(0), c.stride(1), c.stride(2), c.stride(3),
                Ka, off_grid, lut,
                TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3,
                num_warps=4,
            )
            # print(pgm.asm['ptx'])
            # exit()
    return c

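Note: the off_grid loop above simply launches once per 65535-wide slice of the super-block list; with illustrative numbers:

width, max_grid = 100_000, 65535
slices = [(off, min(max_grid, width - off)) for off in range(0, width, max_grid)]
assert slices == [(0, 65535), (65535, 34465)]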
    pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        tl.atomic_xchg(plock, 0)


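Note: the removed epilogue accumulated partial results behind a spin lock. A CPU-side analogue for intuition (an illustration, not Triton code): acquire via compare-and-swap, store on the first visit, accumulate afterwards, and cycle the counter modulo maxid so the lock slot can be reused:

import threading

class LockedAccumulator:
    def __init__(self, maxid):
        self.lock = threading.Lock()   # plays the role of atomic_cas(plock, 0, 1)
        self.count = 0                 # plays the role of *pcount
        self.maxid = maxid
        self.value = None

    def add(self, partial):
        with self.lock:                # exiting = atomic_xchg(plock, 0)
            self.value = partial if self.count == 0 else self.value + partial
            self.count = (self.count + 1) % self.maxid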
##############
#  MAIN API  #
##############
class _matmul(torch.autograd.Function):

    sdd_cache = dict()
    dsd_cache = dict()
    dds_cache = dict()
    locks = dict()

    # Given an array sizes representing reduction size for each
    # column of a block-mode matrix multiplication,
    # performs load-balancing to achieve more smaller reductions
    # between `seg_size` elements
    @staticmethod
    def load_balance(sizes):
        # segment size
        # heuristics taken from OpenAI blocksparse code
        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
        max_size = sizes.max()
        min_size = sizes[sizes != 0].min()
        #if max_size > min_size * 2.0:
        #  seg_max = max(triton.cdiv(max_size, 4), min_size*2)
        #else:
        #  seg_max = max_size
        seg_max = max_size
        seg_min = max(triton.cdiv(seg_max, 4), 4)
        # split reduction into segments
        div = sizes // seg_max
        rem = sizes % seg_max
        packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
        width = packs.sum()
        segments = torch.empty(width, dtype=sizes.dtype)
        column = torch.empty_like(segments)
        lockid = torch.zeros_like(segments)
        maxid = torch.zeros_like(segments)
        nlocks = 0
        current = 0
        col_idx = 0
        for i in range(len(sizes)):
            d, r = div[i], rem[i]
            isempty = sizes[i] < seg_min
            last = current + d + (r >= seg_min) + isempty
            # column id
            column[current:last] = col_idx
            # lock id
            if d > 1 or (d == 1 and r >= seg_min):
                nlocks += 1
                lockid[current:last] = nlocks
                maxid[current:last] = last - current
            # segment size
            segments[current:current + d] = seg_max
            if r < seg_min and not isempty:
                segments[current + d - 1] += r
            if r >= seg_min or isempty:
                segments[current + d] = r
            current = last
            col_idx += 1
        offsets = torch.zeros_like(segments)
        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
        return segments, column, lockid, maxid, offsets

    @staticmethod
    def get_locks(size, dev):
        if dev not in _matmul.locks or \
            size > _matmul.locks[dev].size(0):
            _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
        return _matmul.locks[dev]

    ##########################
    # SPARSE = DENSE x DENSE #
    ##########################

    @staticmethod
    def make_sdd_lut(layout, block, device):
def sdd_lut(layout, block, device):
    start_width = 128 // block
    layout = layout.type(torch.int32)
    superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
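Note: start_width = 128 // block caps super-blocks so that a full group never spans more than 128x128 elements; sdd_matmul above launches with TILE_M = TILE_N = block * pack, so pack presumably never exceeds start_width. A trivial check of the cap:

for block in (16, 32, 64, 128):
    start_width = 128 // block
    assert start_width * block == 128   # largest super-tile is 128x128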
@@ -267,266 +141,301 @@ class _matmul(torch.autograd.Function):
        luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
        widths.append(width)
        packs.append(size)

        # create locks
    return luts, None, widths, packs

    @staticmethod
    def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
        # (A * B)^T = (B^T * A^T)
        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
    A, B, C,
    stride_az, stride_ha, stride_am, stride_ak,
    stride_zb, stride_hb, stride_bk, stride_bn,
    stride_zc, stride_hc, stride_cm, stride_cn,
    DS0, DS1, lut, **meta
):
    TILE_M = meta['TILE_M']
    TILE_N = meta['TILE_N']
    TILE_K = meta['TILE_K']
    GROUP_SIZE_M = meta['GROUP_SIZE_M']
    #------------#
    #- Prologue -#
    #------------#
    pid_m   = tl.program_id(0)
    pid_n   = tl.program_id(1)
    num_pid_m = tl.num_programs(0)
    num_pid_n = tl.num_programs(1)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
    pidz   = tl.program_id(2)
    header = lut + pid_m * 4
    offset = tl.load(header + 0)
    K      = tl.load(header + 1)
    column = tl.load(header + 2)
    off_h  = tl.load(header + 3)
    pinc   = lut + offset
    # initialize pointers to A (sparse)
    block_id   = tl.load(pinc + 1)
    block_id   = tl.multiple_of(block_id, 8)  # compiler hint
    offs_am = tl.arange(0, TILE_M)
    offs_ak = tl.arange(0, TILE_K)
    pa = A + pidz * stride_az \
            + block_id * stride_ha \
            + offs_am[:, None] * stride_am \
            + offs_ak[None, :] * stride_ak
    # initialize pointers to B (dense)
    offs_bn  = pid_n*TILE_N + tl.arange(0, TILE_N)
    start_bk = tl.load(pinc)
    start_bk = tl.multiple_of(start_bk, 8)  # compiler hint
    offs_bk  = start_bk + tl.arange(0, TILE_K)
    pb = B + pidz * stride_zb \
            + off_h * stride_hb \
            + offs_bn[None, :] * stride_bn \
            + offs_bk[:, None] * stride_bk
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(K, 0, -TILE_K):
        a = tl.load(pa, mask=True)
        b = tl.load(pb, mask=offs_bn[None, :] < DS0)
        acc += tl.dot(a, b)
        pinc += 2
        inc_a = tl.load(pinc + 1)
        inc_a = tl.multiple_of(inc_a, 8)
        inc_b = tl.load(pinc)
        inc_b = tl.multiple_of(inc_b, 8)
        pa += inc_a
        pb += inc_b*stride_bk
    c = acc.to(C.dtype.element_ty)
    # initialize pointers to C
    offs_cm = column*TILE_M + tl.arange(0, TILE_M)
    offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N)
    pc = C + off_h * stride_hc \
            + pidz * stride_zc \
            + offs_cm[:, None] * stride_cm \
            + offs_cn[None, :] * stride_cn
    tl.store(pc, c, mask = offs_cn[None, :] < DS0)

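Note: the DSD LUT header is four int32 values per program, matching the loads in the prologue above; an illustrative unpacking (an assumption for clarity, with lut indexable as a 1-D array):

def unpack_dsd_header(lut, pid_m):
    offset, K, column, off_h = (int(v) for v in lut[pid_m * 4 : pid_m * 4 + 4])
    # lut[offset:] then holds (dense-increment, sparse-increment) pairs,
    # consumed two per TILE_K step of the inner loop
    return offset, K, column, off_h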
        # Shape check
        a_dim = -2 if trans_a else -1
        b_dim = -1 if trans_b else -2
        a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
        if a_inner != b_inner:
            raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
                             f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
        if a_inner % 16 != 0:
            raise ValueError('Reduction size for SDD must be a multiple of 16')

        batch_size = a.size(0)
        a_outer = a.size(3 if trans_a else 2)
        dtype = a.dtype
        device = a.device

        # create kernel
        total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
        c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device)
        for lut, width, pack in zip(luts, widths, packs):
            num_lock = 1
            meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1,
                    'SDD': True, 'DSD': False, 'DDS': False}
            # create output
            locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device)
            # maximum grid size is 65535
            # so operation might be decomposed into multiple
            # kernel calls
            max_width = 49152
            for off_width in range(0, width, max_width):
                grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
                _kernel[grid](
                    a,
                    b,
                    c,
                    a.stride(0),
                    a.stride(1),
                    a.stride(3 if trans_a else 2),
                    a.stride(2 if trans_a else 3),
                    b.stride(0),
                    b.stride(1),
                    b.stride(3 if trans_b else 2),
                    b.stride(2 if trans_b else 3),
                    c.stride(0),
                    c.stride(0),
                    c.stride(2),
                    c.stride(3),
                    a_outer,
                    a_outer,
                    a_inner,
                    off_width,
                    lut,
                    locks,
                    num_lock,
                    num_warps=4,
                    **meta
                )
        # save for backward pass
        return c

    ##########################
    # DENSE = DENSE x SPARSE #
    # DENSE = SPARSE x DENSE #
    ##########################

    # Given a binary layout of 0s and 1s,
    # Construct look-up table for efficient execution on GPUs
    @staticmethod
    def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
        # load-balancing
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        segments = _empty.clone()
        column = _empty.clone()
        depth = _empty.clone()
        lockid = _empty.clone()
        maxid = _empty.clone()
        offsets = _empty.clone()
        current_offset = 0
        current_maxid = 0
        for z in range(layout.size(0)):
            if trans:
                sizes = torch.sum(layout[z, :, :], 1)
            else:
                sizes = torch.sum(layout[z, :, :], 0)
            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes)
            z_depth = z * torch.ones_like(z_segments)
            z_lockid[z_lockid > 0] += current_maxid
            current_maxid = z_lockid.max()
            # concatenate depth
            segments = torch.cat((segments, z_segments))
            column = torch.cat((column, z_column))
            depth = torch.cat((depth, z_depth))
            maxid = torch.cat((maxid, z_maxid))
            offsets = torch.cat((offsets, current_offset + z_offsets))
            lockid = torch.cat((lockid, z_lockid))
            current_offset += layout[z, :, :].sum()
        segments *= step
        # pointer increments
        if trans:
            nnz = layout.nonzero(as_tuple=False)
        else:
            nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
        num_blocks = nnz.size(0)
        offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
        idx = transform(nnz[:, 2] * block)
        xincs = idx.clone()
        xincs[1:] -= idx[:-1]
        # divide block into multiple steps
        div = block // step
        xincs = xincs.view(-1, 1).repeat(1, div)
        xincs[:, 1:] = step
        xincs[:, 0] -= (div - 1) * step
        # first increment for each reduction is actually the offset
        xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
        xincs = xincs.view(-1)
        # block-mode input increments
        if trans:
            widx = torch.arange(num_blocks)
        else:
            widx = _empty.clone()
            current_offset = 0
            for z in range(layout.size(0)):
                layoutw = layout[z, :, :].clone()
                msum = layoutw.sum()
                layoutw[layoutw > 0] = 1 + torch.arange(msum)
                widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
                current_offset += msum
        widx = widx
        wincs = widx * block * block
        wincs[1:] -= widx[:-1] * block * block
        wincs = wincs.view(-1, 1).repeat(1, div)
        if trans:
            wincs[:, 1:] = step
            wincs[:, 0] -= (div - 1) * step
        else:
            wincs[:, 1:] = step * block
            wincs[:, 0] -= (div - 1) * step * block
        wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
        wincs = wincs.view(-1)
        # adjust offset and segment size
        offsets *= 2 * div
        segments *= div
        # create header
        width = column.size(0)
        offsets += 6 * width
        header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
        incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
        incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
        # create lut
        lut = torch.cat((header, incs))
        lut = lut.type(torch.int32).to(device)
        # create locks
        num_locks = max(1, lockid.max())
        return lut, num_locks, width, None

    @staticmethod
    def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        BS2 = block * spdims[1 if trans_b else 2]
        dtype = a.dtype
        # kernel
        meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
                'SDD': False, 'DSD': False, 'DDS': True}
        # output
        CS0 = AS0
        CS1 = AS1
        CS2 = BS2 if trans_c else AS2
        CS3 = AS2 if trans_c else BS2
        locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
        _kernel[grid](
            a,
            b,
            c,
            a.stride(0),
            a.stride(1),
            a.stride(3 if trans_a else 2),
            a.stride(2 if trans_a else 3),
            b.stride(0),
            b.stride(1),
            b.stride(3 if trans_b else 2),
            b.stride(2 if trans_b else 3),
            c.stride(0),
            c.stride(1),
            c.stride(3 if trans_c else 2),
            c.stride(2 if trans_c else 3),
            AS2,
            BS2,
            0,
            0,
            lut,
            locks,
            num_locks,
            num_warps=4,
            **meta
        )
        return c

    @staticmethod
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
    # shapes / dtypes
    AS1 = block * spdims[2 if trans_a else 1]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
        # kernel
        meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
                'SDD': False, 'DSD': True, 'DDS': False}
        # output
    # allocate output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
        locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
        _kernel[grid](
            a,
            b,
            c,
            a.stride(0),
            a.stride(1),
            a.stride(3 if trans_a else 2),
            a.stride(2 if trans_a else 3),
            b.stride(0),
            b.stride(1),
            b.stride(3 if trans_b else 2),
            b.stride(2 if trans_b else 3),
            c.stride(0),
            c.stride(1),
            c.stride(3 if trans_c else 2),
            c.stride(2 if trans_c else 3),
            BS3,
            AS1,
            0,
            0,
            lut,
            locks,
            num_locks,
            num_warps=4,
            **meta
    # meta-parameter heuristics
    TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block]
    # compute output
    grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0]
    # fmt: off
    _dsd_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
        b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
        c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
        BS3, AS1, lut,
        TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
        num_warps=4, GROUP_SIZE_M=8,
    )
    # exit()
    return c

def dsd_lut(layout, block, step, trans, device):
    sizes = torch.sum(layout, 2 if trans else 1)
    head_id, col_id = sizes.nonzero(as_tuple=True)
    sizes = sizes.flatten()
    segments = sizes*step
    # pointer increments
    if trans:
        nnz = layout.nonzero(as_tuple=False)
    else:
        nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
    num_blocks = nnz.size(0)
    offsets = torch.zeros_like(sizes)
    offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
    offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
    # -------------------------------
    # dense input pointer increments
    # -------------------------------
    # given a list of the indices for the first element of each non-zero block.
    # For example, for the indices
    # [32, 80, 128, 256, 288]
    # we would generate the increments
    # [32, 48, 48, 128, 32]
    #        ^
    #   index of first element
    # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
    # that is smaller than the block size, so we need to do a bit of extra work
    # to handle this case
    B_idx = nnz[:, 2] * block
    B_incs = B_idx.clone()
    B_incs[1:] -= B_idx[:-1]
    div = block // step
    B_incs = B_incs.view(-1, 1).repeat(1, div)
    B_incs[:, 1:] = step
    B_incs[:, 0] -= (div - 1) * step
    # first increment for each reduction is actually the offset
    B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
    B_incs = B_incs.view(-1)
    # -------------------------------
    # sparse input pointer increments
    # -------------------------------
    # same as above, except that the increments are in the sparse memory layout
    if trans:
        A_idx = torch.arange(num_blocks)
    else:
        A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
        current_offset = 0
        for z in range(layout.size(0)):
            layoutw = layout[z, :, :].clone()
            msum = layoutw.sum()
            layoutw[layoutw > 0] = 1 + torch.arange(msum)
            A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
            current_offset += msum
    A_incs = A_idx * block * block
    A_incs[1:] -= A_idx[:-1] * block * block
    A_incs = A_incs.view(-1, 1).repeat(1, div)
    if trans:
        A_incs[:, 1:] = step
        A_incs[:, 0] -= (div - 1) * step
    else:
        A_incs[:, 1:] = step * block
        A_incs[:, 0] -= (div - 1) * step * block
    A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
    A_incs = A_incs.view(-1)
    # create header
    width    = col_id.size(0)
    offsets  = offsets*2*div + 4*width
    segments = segments*div
    header   = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
    # create increments
    incs     = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
    incs     = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
    # create lut
    lut = torch.cat((header, incs))
    lut = lut.type(torch.int32).to(device)
    # create locks
    return lut, None, width, None

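Note: a numeric check of the increment example in the comment above (the first entry is the raw offset, the rest are deltas between consecutive block starts):

import torch
B_idx = torch.tensor([32, 80, 128, 256, 288])
B_incs = B_idx.clone()
B_incs[1:] -= B_idx[:-1]
assert B_incs.tolist() == [32, 48, 48, 128, 32]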
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@triton.jit
def _dds_kernel(
    A, B, C,
    stride_za, stride_ha, stride_ma, stride_ka,
    stride_zb, stride_hb, stride_bk, stride_bn,
    stride_zc, stride_hc, stride_mc, stride_nc,
    DS0, DS1, lut, **meta
):
    TILE_M = meta['TILE_M']
    TILE_N = meta['TILE_N']
    TILE_K = meta['TILE_K']
    GROUP_SIZE_M = meta['GROUP_SIZE_M']
    #------------#
    #- Prologue -#
    #------------#
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_pid_m = tl.num_programs(0)
    num_pid_n = tl.num_programs(1)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
    pid_z = tl.program_id(2)
    header = lut + pid_m * 4
    offset = tl.load(header + 0)
    AS1 = tl.load(header + 1)
    column = tl.load(header + 2)
    off_h = tl.load(header + 3)
    pinc = lut + offset
    # initialize pointers to A (dense)
    offs_am = pid_n*TILE_M + tl.arange(0, TILE_M)
    start_ak = tl.load(pinc)
    start_ak = tl.multiple_of(start_ak, 8)
    offs_ak = start_ak + tl.arange(0, TILE_K)
    ptrs_a = A + pid_z * stride_za \
            + off_h * stride_ha \
            + offs_am[:, None] * stride_ma \
            + offs_ak[None, :] * stride_ka
    # initialize pointers to B (sparse)
    block_id = tl.load(pinc + 1)
    block_id = tl.multiple_of(block_id, 8)
    offs_bn = tl.arange(0, TILE_N)
    offs_bk = tl.arange(0, TILE_K)
    ptrs_b = B + pid_z * stride_zb \
            + block_id * stride_hb \
            + offs_bn[None, :] * stride_bn \
            + offs_bk[:, None] * stride_bk
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for k in range(AS1, 0, -TILE_K):
        a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0)
        b = tl.load(ptrs_b, mask = True)
        acc += tl.dot(a, b)
        pinc += 2
        inc_a = tl.load(pinc)
        inc_b = tl.load(pinc + 1)
        inc_a = tl.multiple_of(inc_a, 8)
        inc_b = tl.multiple_of(inc_b, 8)
        inc_a = inc_a * stride_ka
        ptrs_a += inc_a
        ptrs_b += inc_b
    ## ---------------- ##
    ##    Epilogue      ##
    ## ---------------- ##
    c = acc.to(C.dtype.element_ty)
    # initialize pointers to C (dense)
    offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M)
    offs_cn = column * TILE_N + tl.arange(0, TILE_N)
    ptrs_c = C + off_h * stride_hc \
            + pid_z * stride_zc \
            + offs_cm[:, None] * stride_mc \
            + offs_cn[None, :] * stride_nc
    # write back
    tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)

def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
    # shapes / dtypes
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    BS2 = block * spdims[1 if trans_b else 2]
    dtype = a.dtype
    # output
    CS0 = AS0
    CS1 = AS1
    CS2 = BS2 if trans_c else AS2
    CS3 = AS2 if trans_c else BS2
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
    grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0]
    # fmt: off
    _dds_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
        b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
        c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
        AS2, BS2, lut,
        TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
        num_warps=4, GROUP_SIZE_M=8,
    )
    return c

    fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
##############
#  MAIN API  #
##############
class _matmul(torch.autograd.Function):

    fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}

    @staticmethod
    def forward(
@@ -586,27 +495,27 @@ class matmul:
            return self.lut_cache[key]
        # C look-up table
        layout, block = self.layout, self.block
        step = 16
        step = min(block, 32)
        if self.mode == 'sdd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device)
            c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
            c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device)
        elif self.mode == 'dds':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
            c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device)
        # DA look-up table
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
            da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device)
            da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
            da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
            db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
            db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device)
        elif self.mode == 'dds':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
            db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                               da_lut, da_num_locks, da_width, da_packs,
                               db_lut, db_num_locks, db_width, db_packs)
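
The nine branches above follow one pattern: the output of each mode and the gradients of its two inputs each need either the sdd builder (no step/transpose arguments) or the dsd builder with a particular transpose flag. A compact restatement of that dispatch, as a sketch (not code from the commit):

# (mode, operand) -> which look-up-table builder the branch above selects.
LUT_DISPATCH = {
    ('sdd', 'c'):  'sdd_lut()',
    ('sdd', 'da'): 'dsd_lut(trans=True)',
    ('sdd', 'db'): 'dsd_lut(trans=False)',
    ('dsd', 'c'):  'dsd_lut(trans=not trans_a)',
    ('dsd', 'da'): 'sdd_lut()',
    ('dsd', 'db'): 'dsd_lut(trans=trans_a)',
    ('dds', 'c'):  'dsd_lut(trans=trans_b)',
    ('dds', 'da'): 'dsd_lut(trans=not trans_b)',
    ('dds', 'db'): 'sdd_lut()',
}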
@@ -709,7 +618,3 @@ class matmul:
        b = add_extra_dims(b)

        return a, b

def _dim_to_name(x):
    # assert x in (-1, -2)
    return "last" if x == -1 else "second to last"

@@ -1,18 +1,6 @@
import triton.language as tl
import triton
import torch
import os


def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n
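
(This local helper is deleted in favor of the library's own triton.next_power_of_2. For reference, the bit-smearing above rounds an integer up to the next power of two:)

# Quick check of the rounding behaviour (valid for 0 < n < 2**32):
for n, expected in [(1, 1), (3, 4), (5, 8), (17, 32), (1024, 1024)]:
    assert next_power_of_2(n) == expected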


def num_warps(n):
@@ -23,11 +11,11 @@ def num_warps(n):
    return 16

@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])})
@triton.jit
def _forward(
    X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
    X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
    **meta
):
    TN = meta['TN']
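
Note the heuristics move from args[6] to args[7]: inserting is_causal before sizemax shifts every later positional argument by one, and both lambdas must keep reading sizemax. Spelled out:

# 0-based positions in the new _forward signature; the heuristics want `sizemax`.
params = ['X', 'scale', 'LUT', 'RPE', 'KP_M', 'ATTN_M', 'is_causal', 'sizemax']
assert params.index('sizemax') == 7   # was 6 before is_causal was added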
@@ -76,13 +64,16 @@ def _forward(
        if meta['ATTN_MASK_MUL']:
            attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
        x = x + attn_m
    # apply causal mask
    is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm
    x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
    # computation
    x = tl.softmax(x)
    tl.store(px, x, mask=check)

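The two added lines send every element whose global column index exceeds its global row index to -inf before the softmax whenever is_causal is set. A dense-PyTorch equivalent, for intuition only (the kernel operates on the block-sparse layout, not a full matrix):

import torch

def causal_softmax_reference(x):
    # x: (..., seq, seq) attention scores
    seq = x.size(-1)
    upper = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=x.device), diagonal=1)
    return torch.softmax(x.masked_fill(upper, float('-inf')), dim=-1)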

@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
    pidhm = tl.program_id(0)
@@ -139,10 +130,13 @@ class _softmax(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, maxlut, bench, time
        ctx, x, scale, rpe,
        key_padding_mask, attn_mask,
        kp_mask_mode, attn_mask_mode,
        is_causal,
        spdims, block, lut, maxlut
    ):
        apply_scale = False if scale == 1.0 else True

        # handle None rpe
        if rpe is None:
            apply_rpe = False
@@ -151,7 +145,6 @@ class _softmax(torch.autograd.Function):
        else:
            apply_rpe = True
            stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)

        # handle None key_padding_mask
        if key_padding_mask is None:
            apply_kp_mask = False
@@ -160,7 +153,6 @@ class _softmax(torch.autograd.Function):
        else:
            apply_kp_mask = True
            stride_zkpm = key_padding_mask.stride(0)

        # handle None attention_mask
        if attn_mask is None:
            apply_attn_mask = False
@@ -169,22 +161,19 @@ class _softmax(torch.autograd.Function):
        else:
            apply_attn_mask = True
            stride_zattnm = attn_mask.stride(0)

        # run kernel
        M = x.shape[0]
        meta = {
            'BLOCK': block,
            'APPLY_SCALE': apply_scale,
            'APPLY_RPE': apply_rpe,
            'APPLY_KP_MASK': apply_kp_mask,
            'APPLY_ATTN_MASK': apply_attn_mask,
            'KP_MASK_MUL': kp_mask_mode == 'mul',
            'ATTN_MASK_MUL': attn_mask_mode == 'mul',
        }
        grid = lambda opt: [spdims[0] * spdims[1] * block, M]
        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)

        grid = [spdims[0] * spdims[1] * block, M]
        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\
                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
                       BLOCK           = block,
                       APPLY_SCALE     = apply_scale,
                       APPLY_RPE       = apply_rpe,
                       APPLY_KP_MASK   = apply_kp_mask,
                       APPLY_ATTN_MASK = apply_attn_mask,
                       KP_MASK_MUL     = (kp_mask_mode == 'mul'),
                       ATTN_MASK_MUL   = (attn_mask_mode == 'mul'),
                       force_nc_cache  = True)
        # save to context
        ctx.mark_dirty(x)
        ctx.save_for_backward(x, lut)
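
Two launch-site changes here: the grid no longer depends on meta, so it becomes a plain list, and the compile-time flags are passed as explicit keyword arguments instead of being collected in a meta dict and splatted. The two styles hand the kernel identical keywords, sketched with an ordinary function:

def fake_launch(**kwargs):
    # Stand-in for the Triton kernel launch; just echoes the keywords it gets.
    return kwargs

meta = {'BLOCK': 16, 'APPLY_SCALE': True, 'APPLY_RPE': False}
assert fake_launch(**meta) == fake_launch(BLOCK=16, APPLY_SCALE=True, APPLY_RPE=False)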
@@ -213,25 +202,24 @@ class _softmax(torch.autograd.Function):

class softmax:

    apply_softmax = _softmax.apply

    def make_lut(self, device):
        key = (device, )
        if key not in self.lut_cache:
            self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
        return self.lut_cache[key]

    def __init__(self, layout, block, bench=False):
    def __init__(self, layout, block):
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.bench = bench
        self.lut_cache = dict()

    def __call__(
        self, x, scale=1., rpe=None, key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add'
        self, x, scale=1., rpe=None,
        key_padding_mask=None, attn_mask=None,
        key_padding_mask_mode='add', attn_mask_mode='add',
        is_causal = False
    ):
        time_y = [None]
        if rpe is not None and rpe.dtype != x.dtype:
            raise ValueError('relative position embedding must be %s' % x.dtype)
        if attn_mask is not None and attn_mask.dtype != x.dtype:
@@ -239,8 +227,12 @@ class softmax:
        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
            raise ValueError('Key padding mask must be %s' % x.dtype)
        lut, maxlut = self.make_lut(x.device)
        x = softmax.apply_softmax(
            x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, self.spdims, self.block, lut,
            maxlut, self.bench, time_y
        x = _softmax.apply(
            x, scale, rpe,
            key_padding_mask, attn_mask,
            key_padding_mask_mode, attn_mask_mode,
            is_causal,
            self.spdims, self.block,
            lut, maxlut
        )
        return x
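
A hypothetical end-to-end use of the updated softmax class (shapes, layout, and device are illustrative; x is assumed to hold one (block, block) tile per nonzero in the layout):

import torch

H, S, BLOCK = 2, 128, 16                      # heads, sequence length, block size
layout = torch.tril(torch.ones(H, S // BLOCK, S // BLOCK, dtype=torch.int64))
sm = softmax(layout, BLOCK)
nnz = int(layout.sum())
x = torch.randn(1, nnz, BLOCK, BLOCK, dtype=torch.float16, device='cuda')
y = sm(x, is_causal=True)                     # exercises the new causal-masking path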