From 2c287544cb6fdd63ad3e6927e467c9d0660489a1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 27 Sep 2021 18:25:16 -0700 Subject: [PATCH] [OPS] Faster and cleaner block-sparse implementation (#311) --- lib/codegen/selection/generator.cc | 3 +- lib/codegen/transform/coalesce.cc | 2 +- lib/codegen/transform/pipeline.cc | 5 +- python/bench/bench_blocksparse.py | 13 +- .../test/unit/operators/test_blocksparse.py | 22 +- python/triton/language/core.py | 32 + python/triton/ops/blocksparse/matmul.py | 939 ++++++++---------- python/triton/ops/blocksparse/softmax.py | 84 +- 8 files changed, 518 insertions(+), 582 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index e4e2631d6..f0068be11 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2500,9 +2500,8 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) { void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); - for(ir::instruction *i: block->get_inst_list()){ + for(ir::instruction *i: block->get_inst_list()) visit_value(i); - } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); } diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 6b8cd7e70..448517408 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -117,7 +117,7 @@ void coalesce::run(ir::module &mod) { queue.push_back(inst_op); } } - if(in_contig.empty() || out_contig==in_contig) + if(in_contig.size() <= 1 || out_contig==in_contig) continue; builder.set_insert_point_after(val_inst); auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index 24e79e47d..096b7d7b6 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -188,7 +188,7 @@ void pipeline::run(ir::module &mod) { builder.set_insert_point(block->get_first_non_phi()); std::map load_ivs; std::map next_load_ivs; - for (ir::phi_node* iv : induction_vars) { + for (auto& [iv, val] : prev_phi_vals) { ir::phi_node* pn = builder.create_phi(iv->get_type(), 2); pn->add_incoming(prev_phi_vals[iv], header); load_ivs[iv] = pn; @@ -198,7 +198,8 @@ void pipeline::run(ir::module &mod) { // pre-fetch next iteration builder.set_insert_point(block->get_inst_list().back()); - ir::value* next_ptr = ptr->get_value_for_block(block); +// ir::value* next_ptr = ptr->get_value_for_block(block); + ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs); ir::value* next_mask = builder.create_splat( rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes()); if (auto* masked_load = dynamic_cast(load)) { diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py index 6954aa315..b6eacd884 100644 --- a/python/bench/bench_blocksparse.py +++ b/python/bench/bench_blocksparse.py @@ -11,20 +11,20 @@ square_confs = [ x_names = ['M', 'N', 'K'], x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], line_arg = 'block', - line_vals = [16, 32, 64], - line_names = ['Block16', 'Block32', 'Block64'], + line_vals = [16, 32, 64, 128], + line_names = ['Block16', 'Block32', 'Block64', 'Block128'], ylabel = 'TFLOPS', plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', args = {'layout_mode': layout_mode, 'op_mode': op_mode, 'AT': AT, 'BT': BT, 'dtype': 
torch.float16, 'provider': 'triton'} )\ for AT in [False] for BT in [False] \ - for op_mode in ['sdd', 'dsd', 'dds'] for layout_mode in ['tril', 'dense'] + for op_mode in ['dsd'] for layout_mode in ['dense'] ] @triton.testing.perf_report(square_confs) -def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5): +def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000): Z, H = 1, 1 make_layout = { 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\ @@ -85,4 +85,7 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50): op = triton.ops.blocksparse.softmax(layout, block) gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep) - return gbps(mean_ms), gbps(min_ms), gbps(max_ms) \ No newline at end of file + return gbps(mean_ms), gbps(min_ms), gbps(max_ms) + + +bench_matmul.run(print_data=True, show_plots=True) \ No newline at end of file diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 86006671f..3b9a1c17f 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -7,7 +7,7 @@ import pytest @pytest.mark.parametrize("TRANS_A", [False, True]) @pytest.mark.parametrize("TRANS_B", [False, True]) @pytest.mark.parametrize("BLOCK", [16, 32, 64]) -@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32]) +@pytest.mark.parametrize("DTYPE", [torch.float16]) def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256): # set seed torch.random.manual_seed(0) @@ -41,15 +41,20 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= @pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792]) @pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32]) def test_softmax(BLOCK, WIDTH, DTYPE): + is_causal = True # set seed torch.random.manual_seed(0) - Z, H, M, N = 2, 4, WIDTH, WIDTH + Z, H, M, N = 1, 1, WIDTH, WIDTH scale = 0.4 # create inputs layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda") at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda") + # make sure each row has at least one non-zero element + torch.diagonal(layout)[:] = 1 + torch.diagonal(at_mask)[:] = 1 kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda") + kp_mask[:] = 0 kp_mask[kp_mask == 1.0] = float("-inf") # triton result op = triton.ops.blocksparse.softmax(layout, BLOCK) @@ -61,16 +66,15 @@ def test_softmax(BLOCK, WIDTH, DTYPE): key_padding_mask_mode="add", attn_mask=at_mask.to(DTYPE), attn_mask_mode="mul", + is_causal=is_causal, ) # torch result rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) - if at_mask is not None: - # broadcast at_mask to the same shape as rx - M = at_mask[None, None, :, :] + torch.zeros_like(rx) - rx[M == 0] = float("-inf") - if kp_mask is not None: - rx += kp_mask[:, None, None, :] - ry = torch.softmax(rx * scale, -1) + # broadcast at_mask to the same shape as rx + if is_causal: at_mask = torch.tril(at_mask) + M = at_mask[None, None, :, :] + torch.zeros_like(rx) + rx[M == 0] = float("-inf") + # rx += kp_mask[:, None, None, :] ry = torch.softmax(rx * scale, -1) ry = 
triton.testing.sparsify_tensor(ry, layout, BLOCK) # compare diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 22cd717e7..6f192b2c6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -699,3 +699,35 @@ def ravel(x): :type x: Block """ return triton.language.reshape(x, [x.type.numel]) + +@triton.jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + transformes indices of a row-major size_i*size_j matrix into those + of one where indices are row major for each group of size_j rows. + For example, for size_i = size_j = 4 and size_g = 2, it will transform + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + into + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i*size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = minimum(size_i - off_i, size_g) + # new row and column indices + new_i = off_i + (ij % size_g) + new_j = (ij % size_gj) // size_g + return new_i, new_j \ No newline at end of file diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index d9528133c..bcba2e505 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -3,530 +3,439 @@ import triton.language as tl import triton._C.libtriton as libtriton import torch +# ******************************************************** +# -------------------------------------------------------- +# Sparse = Dense x Dense (SDD) +# This operation uses super-blocking to make sure that +# it's done efficiently when small blocks can be grouped +# together +# -------------------------------------------------------- +# ******************************************************** +@triton.heuristics({ + 'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0, +}) @triton.jit -def _kernel( - A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc, - stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta +def _sdd_kernel( + A, B, C, + stride_za, stride_ha, stride_ma, stride_ak, + stride_zb, stride_hb, stride_bk, stride_nb, + stride_zc, stride_hc, stride_mc, stride_nc, + K, grid_offset, lut, **meta ): - TM = meta['TM'] - TN = meta['TN'] - TK = meta['TK'] - TZ = meta['TZ'] - BLOCK = meta['BLOCK'] + TILE_M = meta['TILE_M'] + TILE_N = meta['TILE_N'] + TILE_K = meta['TILE_K'] + BLOCK = meta['BLOCK'] #------------# #- Prologue -# #------------# - pid0 = tl.program_id(0) - pid1 = tl.program_id(1) - pidz = tl.program_id(2) - if meta['SDD']: - pid1 = pid1 + SDD_off_width - blockidm = tl.arange(0, TM) // BLOCK - blockidn = tl.arange(0, TN) // BLOCK - offlutm = blockidm * (TN // BLOCK) * 4 - offlutn = blockidn * 4 - header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4 - z = tl.load(header + 0) - i = tl.load(header + 1 + offlutm) - j = tl.load(header + 2 + offlutn) - AS1 = SDD_K - lockid = tl.where(TZ > 1, 1, 0) - offka = pid0 * AS1 - offkb = pid0 * AS1 - offmc = 0 - offnc = 0 - offpa = 0 - offpb = 0 - maxid = TZ - offhc = 0 - offha = z - offhb = z - ram = i * BLOCK + (tl.arange(0, TM) % BLOCK) - rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK) - else: - header = lut + pid0 * 6 - offset = 
tl.load(header + 0) - AS1 = tl.load(header + 1) - column = tl.load(header + 2) - depth = tl.load(header + 3) - lockid = tl.load(header + 4) - maxid = tl.load(header + 5) - pinc = lut + offset - offhc = depth - if meta['DSD']: - # output offset - offnc = pid1 * TN - offmc = column * TM - offpc = 0 - # dense input offset - offnb = pid1 * TN - offkb = tl.load(pinc) - offkb = tl.multiple_of(offkb, 8) # compiler hint - offpb = 0 - # sparse input offset - offma = 0 - offka = 0 - offpa = tl.load(pinc + 1) - offpa = tl.multiple_of(offpa, 8) # compiler hint - offpa = offpa * BLOCK * BLOCK - offha = 0 - offhb = depth - else: - # output offset - offmc = pid1 * TM - offnc = column * TN - offpc = 0 - # dense input offset - offma = pid1 * TM - offka = tl.load(pinc) - offka = tl.multiple_of(offka, 8) # compiler hint - offpa = 0 - # sparse input offset - offnb = 0 - offkb = 0 - offpb = tl.load(pinc + 1) - offpb = tl.multiple_of(offpb, 8) # compiler hint - offpb = offpb * BLOCK * BLOCK - offha = depth - offhb = 0 - ram = offma + tl.arange(0, TM) - rbn = offnb + tl.arange(0, TN) - - # initialize a, b pointers - rka = offka + tl.arange(0, TK) - rkb = offkb + tl.arange(0, TK) - pa = A + pidz * TZ * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka - pb = B + pidz * TZ * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb - if meta['DDS']: - checkam = ram[:, None] < DS0 - else: - checkam = AS1 > 0 - if meta['DSD']: - checkbn = rbn[None, :] < DS0 - else: - checkbn = AS1 > 0 - a = tl.load(pa, mask=checkam, other=0.) - b = tl.load(pb, mask=checkbn, other=0.) - + pid1 = tl.program_id(1) + grid_offset + blockidm = tl.arange(0, TILE_M) // BLOCK + blockidn = tl.arange(0, TILE_N) // BLOCK + offlutm = blockidm * (TILE_N // BLOCK) * 4 + offlutn = blockidn * 4 + header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4 + # batch offset + off_z = tl.program_id(2) + # head offset + off_h = tl.load(header + 0) + # initialize pointers to A + start_am = tl.load(header + 1 + offlutm) + offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) + offs_ak = tl.arange(0, TILE_K) + a_ptrs = A + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B + start_bn = tl.load(header + 2 + offlutn) + offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) + offs_bk = tl.arange(0, TILE_K) + b_ptrs = B + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk ## ---------------- ## ## Inner Loop ## ## ---------------- ## - acc = tl.zeros((TM, TN), dtype=tl.float32) - for k in range(AS1, 0, -TK*TZ): + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(K, 0, -TILE_K): + if meta['EVEN_K']: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) + b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) 
+ a = tl.load(a_ptrs) + b = tl.load(b_ptrs) acc += tl.dot(a, b) - if meta['SDD']: - inc_a = TK * TZ * stride_ka - inc_b = TK * TZ * stride_kb - else: - pinc += 2 - if meta['DSD']: - inc_b = tl.load(pinc) - inc_a = tl.load(pinc + 1) - inc_b = tl.multiple_of(inc_b, 8) - inc_a = tl.multiple_of(inc_a, 8) - inc_b = inc_b * stride_kb - if meta['DDS']: - inc_a = tl.load(pinc) - inc_b = tl.load(pinc + 1) - inc_a = tl.multiple_of(inc_a, 8) - inc_b = tl.multiple_of(inc_b, 8) - inc_a = inc_a * stride_ka - pa += inc_a - pb += inc_b - # pre-fetch - checkak = k > TK - checkbk = k > TK - checka = checkam & checkak - checkb = checkbn & checkbk - a = tl.load(pa, mask=checka) - b = tl.load(pb, mask=checkb) + a_ptrs += TILE_K * stride_ak + b_ptrs += TILE_K * stride_bk c = acc.to(C.dtype.element_ty) + ## ---------------- ## + ## Epilogue ## + ## ---------------- ## + blockidm = tl.arange(0, TILE_M) // BLOCK + blockidn = tl.arange(0, TILE_N) // BLOCK + offlutm = blockidm * (TILE_N // BLOCK) * 4 + offlutn = blockidn * 4 + off_block_id = 3 + offlutm[:, None] + offlutn[None, :] + block_id = tl.load(header + off_block_id) + # initialize pointers to C + offs_cm = tl.arange(0, TILE_M) % BLOCK + offs_cn = tl.arange(0, TILE_N) % BLOCK + pc = C + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc + tl.store(pc, c, mask=True) - if meta['SDD']: - checkc = True - rr_blockidm = tl.arange(0, TM) // BLOCK - rr_blockidn = tl.arange(0, TN) // BLOCK - rr_offlutm = rr_blockidm * (TN // BLOCK) * 4 - rr_offlutn = rr_blockidn * 4 - off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :] - bkid = tl.load(header + off_bkid) - offpc = bkid * BLOCK * BLOCK - rcm = tl.arange(0, TM) % BLOCK - rcn = tl.arange(0, TN) % BLOCK +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs): + # (A * B)^T = B^T * A^T + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + # shape constraints + a_dim = -2 if trans_a else -1 + b_dim = -1 if trans_b else -2 + Ka, Kb = a.shape[a_dim], b.shape[b_dim] + if Ka != Kb: + raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") + if Ka % 16 != 0: + raise ValueError('Reduction size for SDD must be a multiple of 16') + # allocate output + n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)]) + c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device) + # each iteration of the loop below + # computes the value for one group of super-blocks + # (e.g., all 4x4 super-blocks) + for lut, width, pack in zip(luts, widths, packs): + # maximum grid size in Triton/CUDA is 64k but we may have more + # super-blocks than that. 
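+        # each launch below therefore covers at most `max_grid` super-blocks
+        # along grid axis 1; `off_grid` is forwarded to the kernel as
+        # `grid_offset` so that program_id(1) can be rebased onto the right
+        # slice of the LUT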
+ max_grid = 65535 + for off_grid in range(0, width, max_grid): + grid = [1, min(max_grid, width - off_grid), c.shape[0]] + # fmt: off + pgm = _sdd_kernel[grid]( + a, b, c, + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), c.stride(2), c.stride(3), + Ka, off_grid, lut, + TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3, + num_warps=4, + ) + # print(pgm.asm['ptx']) + # exit() + return c + +def sdd_lut(layout, block, device): + start_width = 128 // block + layout = layout.type(torch.int32) + superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width) + luts, widths, packs = [], [], [] + for size, nnz in superblocks: + nnz = nnz.reshape(-1, 4) + width = nnz.shape[0] // (size * size) + luts.append(torch.from_numpy(nnz).type(torch.int32).to(device)) + widths.append(width) + packs.append(size) + return luts, None, widths, packs + +# ----------------------------- +# Dense = Sparse x Dense (DSD) +# This operation uses a look-up table that contains pre-computed pointer increments +# in order to minimize computations in the inner loop of the matmul kernel. +# ----------------------------- +@triton.jit +def _dsd_kernel( + A, B, C, + stride_az, stride_ha, stride_am, stride_ak, + stride_zb, stride_hb, stride_bk, stride_bn, + stride_zc, stride_hc, stride_cm, stride_cn, + DS0, DS1, lut, **meta +): + TILE_M = meta['TILE_M'] + TILE_N = meta['TILE_N'] + TILE_K = meta['TILE_K'] + GROUP_SIZE_M = meta['GROUP_SIZE_M'] + #------------# + #- Prologue -# + #------------# + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + pidz = tl.program_id(2) + header = lut + pid_m * 4 + offset = tl.load(header + 0) + K = tl.load(header + 1) + column = tl.load(header + 2) + off_h = tl.load(header + 3) + pinc = lut + offset + # initialize pointers to A (sparse) + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint + offs_am = tl.arange(0, TILE_M) + offs_ak = tl.arange(0, TILE_K) + pa = A + pidz * stride_az \ + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B (dense) + offs_bn = pid_n*TILE_N + tl.arange(0, TILE_N) + start_bk = tl.load(pinc) + start_bk = tl.multiple_of(start_bk, 8) # compiler hint + offs_bk = start_bk + tl.arange(0, TILE_K) + pb = B + pidz * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + ## ---------------- ## + ## Inner Loop ## + ## ---------------- ## + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(K, 0, -TILE_K): + a = tl.load(pa, mask=True) + b = tl.load(pb, mask=offs_bn[None, :] < DS0) + acc += tl.dot(a, b) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + pa += inc_a + pb += inc_b*stride_bk + c = acc.to(C.dtype.element_ty) + # initialize pointers to C + offs_cm = column*TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N) + pc = C + off_h * stride_hc \ + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask = offs_cn[None, :] < DS0) + +def dsd_matmul(a, b, trans_a, 
trans_b, trans_c, spdims, block, lut, num_locks, width, packs): + # shapes / dtypes + AS1 = block * spdims[2 if trans_a else 1] + BS0 = b.size(0) + BS1 = b.size(1) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # allocate output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + # meta-parameter heuristics + TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block] + # compute output + grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0] + # fmt: off + _dsd_kernel[grid]( + a, b, c, + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), + BS3, AS1, lut, + TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3, + num_warps=4, GROUP_SIZE_M=8, + ) + # exit() + return c + +def dsd_lut(layout, block, step, trans, device): + sizes = torch.sum(layout, 2 if trans else 1) + head_id, col_id = sizes.nonzero(as_tuple=True) + sizes = sizes.flatten() + segments = sizes*step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) else: - rcm = offmc + tl.arange(0, TM) - rcn = offnc + tl.arange(0, TN) - if meta['DSD']: - checkc = rcn[None, :] < DS0 - if meta['DDS']: - checkc = rcm[:, None] < DS0 - - pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc - # write-back directly - if lockid == 0: - tl.store(pc, c, mask=checkc) - # accumulate partial results using spin-locks + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) + # ------------------------------- + # dense input pointer increments + # ------------------------------- + # given a list of the indices for the first element of each non-zero block. 
+ # For example, for the indices + # [32, 80, 128, 256, 288] + # we would generate the increments + # [32, 48, 48, 128, 32] + # ^ + # index of first element + # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) + # that is smaller than the block size, so we need to do a bit of extra work + # to handle this case + B_idx = nnz[:, 2] * block + B_incs = B_idx.clone() + B_incs[1:] -= B_idx[:-1] + div = block // step + B_incs = B_incs.view(-1, 1).repeat(1, div) + B_incs[:, 1:] = step + B_incs[:, 0] -= (div - 1) * step + # first increment for each reduction is actually the offset + B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]] + B_incs = B_incs.view(-1) + # ------------------------------- + # sparse input pointer increments + # ------------------------------- + # same as above, except that the increments are in the sparse memory layout + if trans: + A_idx = torch.arange(num_blocks) else: - plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1 - pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks - while tl.atomic_cas(plock, 0, 1) == 1: - pass - count = tl.load(pcount) - if count == 0: - tl.store(pc, c, mask=checkc) - else: - d = tl.load(pc, mask=checkc) - tl.store(pc, d + c, mask=checkc) - tl.atomic_xchg(pcount, (count + 1) % maxid) - tl.atomic_xchg(plock, 0) + A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum) + A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + A_incs = A_idx * block * block + A_incs[1:] -= A_idx[:-1] * block * block + A_incs = A_incs.view(-1, 1).repeat(1, div) + if trans: + A_incs[:, 1:] = step + A_incs[:, 0] -= (div - 1) * step + else: + A_incs[:, 1:] = step * block + A_incs[:, 0] -= (div - 1) * step * block + A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] + A_incs = A_incs.view(-1) + # create header + width = col_id.size(0) + offsets = offsets*2*div + 4*width + segments = segments*div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + # create increments + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) + # create lut + lut = torch.cat((header, incs)) + lut = lut.type(torch.int32).to(device) + # create locks + return lut, None, width, None +# ----------------------------- +# Dense = Dense x Sparse (DDS) +# ----------------------------- +@triton.jit +def _dds_kernel( + A, B, C, + stride_za, stride_ha, stride_ma, stride_ka, + stride_zb, stride_hb, stride_bk, stride_bn, + stride_zc, stride_hc, stride_mc, stride_nc, + DS0, DS1, lut, **meta +): + TILE_M = meta['TILE_M'] + TILE_N = meta['TILE_N'] + TILE_K = meta['TILE_K'] + GROUP_SIZE_M = meta['GROUP_SIZE_M'] + #------------# + #- Prologue -# + #------------# + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_z = tl.program_id(2) + header = lut + pid_m * 4 + offset = tl.load(header + 0) + AS1 = tl.load(header + 1) + column = tl.load(header + 2) + off_h = tl.load(header + 3) + pinc = lut + offset + # initialize pointers to A (dense) + offs_am = pid_n*TILE_M + tl.arange(0, TILE_M) + start_ak = 
tl.load(pinc) + start_ak = tl.multiple_of(start_ak, 8) + offs_ak = start_ak + tl.arange(0, TILE_K) + ptrs_a = A + pid_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ka + # initialize pointers to B (sparse) + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) + offs_bn = tl.arange(0, TILE_N) + offs_bk = tl.arange(0, TILE_K) + ptrs_b = B + pid_z * stride_zb \ + + block_id * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + ## ---------------- ## + ## Inner Loop ## + ## ---------------- ## + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(AS1, 0, -TILE_K): + a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0) + b = tl.load(ptrs_b, mask = True) + acc += tl.dot(a, b) + pinc += 2 + inc_a = tl.load(pinc) + inc_b = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.multiple_of(inc_b, 8) + inc_a = inc_a * stride_ka + ptrs_a += inc_a + ptrs_b += inc_b + ## ---------------- ## + ## Epilogue ## + ## ---------------- ## + c = acc.to(C.dtype.element_ty) + # initialize pointers to C (dense) + offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M) + offs_cn = column * TILE_N + tl.arange(0, TILE_N) + ptrs_c = C + off_h * stride_hc \ + + pid_z * stride_zc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc + # write back + tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): + # shapes / dtypes + AS0 = a.size(0) + AS1 = a.size(1) + AS2 = a.size(3 if trans_a else 2) + BS2 = block * spdims[1 if trans_b else 2] + dtype = a.dtype + # output + CS0 = AS0 + CS1 = AS1 + CS2 = BS2 if trans_c else AS2 + CS3 = AS2 if trans_c else BS2 + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block] + grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0] + # fmt: off + _dds_kernel[grid]( + a, b, c, + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), + AS2, BS2, lut, + TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3, + num_warps=4, GROUP_SIZE_M=8, + ) + return c ############## # MAIN API # ############## class _matmul(torch.autograd.Function): - sdd_cache = dict() - dsd_cache = dict() - dds_cache = dict() - locks = dict() - - # Given an array sizes representing reduction size for each - # column of a block-mode matrix multiplication, - # performs load-balancing to achieve more smaller reductions - # between `seg_size` elements - @staticmethod - def load_balance(sizes): - # segment size - # heuristics taken from OpenAI blocksparse code - # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95 - max_size = sizes.max() - min_size = sizes[sizes != 0].min() - #if max_size > min_size * 2.0: - # seg_max = max(triton.cdiv(max_size, 4), min_size*2) - #else: - # seg_max = max_size - seg_max = max_size - seg_min = max(triton.cdiv(seg_max, 4), 4) - # split reduction into segments - div = sizes // seg_max - rem = sizes % seg_max - packs = div + (sizes < seg_min).long() + (rem >= seg_min).long() - width = packs.sum() - segments = torch.empty(width, dtype=sizes.dtype) - column = torch.empty_like(segments) - lockid = torch.zeros_like(segments) - 
maxid = torch.zeros_like(segments) - nlocks = 0 - current = 0 - col_idx = 0 - for i in range(len(sizes)): - d, r = div[i], rem[i] - isempty = sizes[i] < seg_min - last = current + d + (r >= seg_min) + isempty - # column id - column[current:last] = col_idx - # lock id - if d > 1 or (d == 1 and r >= seg_min): - nlocks += 1 - lockid[current:last] = nlocks - maxid[current:last] = last - current - # segment size - segments[current:current + d] = seg_max - if r < seg_min and not isempty: - segments[current + d - 1] += r - if r >= seg_min or isempty: - segments[current + d] = r - current = last - col_idx += 1 - offsets = torch.zeros_like(segments) - offsets[1:] = torch.cumsum(segments[:-1], dim=0) - return segments, column, lockid, maxid, offsets - - @staticmethod - def get_locks(size, dev): - if dev not in _matmul.locks or \ - size > _matmul.locks[dev].size(0): - _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev) - return _matmul.locks[dev] - - ########################## - # SPARSE = DENSE x DENSE # - ########################## - - @staticmethod - def make_sdd_lut(layout, block, device): - start_width = 128 // block - layout = layout.type(torch.int32) - superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width) - luts, widths, packs = [], [], [] - for size, nnz in superblocks: - nnz = nnz.reshape(-1, 4) - width = nnz.shape[0] // (size * size) - luts.append(torch.from_numpy(nnz).type(torch.int32).to(device)) - widths.append(width) - packs.append(size) - - # create locks - return luts, None, widths, packs - - @staticmethod - def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs): - # (A * B)^T = (B^T * A^T) - if trans_c: - a, b = b, a - trans_a, trans_b = not trans_b, not trans_a - - # Shape check - a_dim = -2 if trans_a else -1 - b_dim = -1 if trans_b else -2 - a_inner, b_inner = a.shape[a_dim], b.shape[b_dim] - if a_inner != b_inner: - raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size " - f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})") - if a_inner % 16 != 0: - raise ValueError('Reduction size for SDD must be a multiple of 16') - - batch_size = a.size(0) - a_outer = a.size(3 if trans_a else 2) - dtype = a.dtype - device = a.device - - # create kernel - total_width = sum([width * pack * pack for width, pack in zip(widths, packs)]) - c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device) - for lut, width, pack in zip(luts, widths, packs): - num_lock = 1 - meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, - 'SDD': True, 'DSD': False, 'DDS': False} - # create output - locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device) - # maximum grid size is 65535 - # so operation might be decomposed into multiple - # kernel calls - max_width = 49152 - for off_width in range(0, width, max_width): - grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size] - _kernel[grid]( - a, - b, - c, - a.stride(0), - a.stride(1), - a.stride(3 if trans_a else 2), - a.stride(2 if trans_a else 3), - b.stride(0), - b.stride(1), - b.stride(3 if trans_b else 2), - b.stride(2 if trans_b else 3), - c.stride(0), - c.stride(0), - c.stride(2), - c.stride(3), - a_outer, - a_outer, - a_inner, - off_width, - lut, - locks, - num_lock, - num_warps=4, - **meta - ) - # save for backward pass - return c - - ########################## - # DENSE = DENSE x SPARSE # 
- # DENSE = SPARSE x DENSE # - ########################## - - # Given a binary layout of 0s and 1s, - # Construct look-up table for efficient execution on GPUs - @staticmethod - def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx): - # load-balancing - _empty = torch.tensor([], dtype=torch.int64, device=layout.device) - segments = _empty.clone() - column = _empty.clone() - depth = _empty.clone() - lockid = _empty.clone() - maxid = _empty.clone() - offsets = _empty.clone() - current_offset = 0 - current_maxid = 0 - for z in range(layout.size(0)): - if trans: - sizes = torch.sum(layout[z, :, :], 1) - else: - sizes = torch.sum(layout[z, :, :], 0) - z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes) - z_depth = z * torch.ones_like(z_segments) - z_lockid[z_lockid > 0] += current_maxid - current_maxid = z_lockid.max() - # concatenate depth - segments = torch.cat((segments, z_segments)) - column = torch.cat((column, z_column)) - depth = torch.cat((depth, z_depth)) - maxid = torch.cat((maxid, z_maxid)) - offsets = torch.cat((offsets, current_offset + z_offsets)) - lockid = torch.cat((lockid, z_lockid)) - current_offset += layout[z, :, :].sum() - segments *= step - # pointer increments - if trans: - nnz = layout.nonzero(as_tuple=False) - else: - nnz = layout.transpose(1, 2).nonzero(as_tuple=False) - num_blocks = nnz.size(0) - offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) - idx = transform(nnz[:, 2] * block) - xincs = idx.clone() - xincs[1:] -= idx[:-1] - # divide block into multiple steps - div = block // step - xincs = xincs.view(-1, 1).repeat(1, div) - xincs[:, 1:] = step - xincs[:, 0] -= (div - 1) * step - # first increment for each reduction is actually the offset - xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]] - xincs = xincs.view(-1) - # block-mode input increments - if trans: - widx = torch.arange(num_blocks) - else: - widx = _empty.clone() - current_offset = 0 - for z in range(layout.size(0)): - layoutw = layout[z, :, :].clone() - msum = layoutw.sum() - layoutw[layoutw > 0] = 1 + torch.arange(msum) - widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1)) - current_offset += msum - widx = widx - wincs = widx * block * block - wincs[1:] -= widx[:-1] * block * block - wincs = wincs.view(-1, 1).repeat(1, div) - if trans: - wincs[:, 1:] = step - wincs[:, 0] -= (div - 1) * step - else: - wincs[:, 1:] = step * block - wincs[:, 0] -= (div - 1) * step * block - wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]] - wincs = wincs.view(-1) - # adjust offset and segment size - offsets *= 2 * div - segments *= div - # create header - width = column.size(0) - offsets += 6 * width - header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous() - incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous() - incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) - # create lut - lut = torch.cat((header, incs)) - lut = lut.type(torch.int32).to(device) - # create locks - num_locks = max(1, lockid.max()) - return lut, num_locks, width, None - - @staticmethod - def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): - # shapes / dtypes - AS0 = a.size(0) - AS1 = a.size(1) - AS2 = a.size(3 if trans_a else 2) - BS2 = block * spdims[1 if trans_b else 2] - dtype = a.dtype - # kernel - meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1, - 'SDD': False, 'DSD': False, 'DDS': 
True} - # output - CS0 = AS0 - CS1 = AS1 - CS2 = BS2 if trans_c else AS2 - CS3 = AS2 if trans_c else BS2 - locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device) - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0] - _kernel[grid]( - a, - b, - c, - a.stride(0), - a.stride(1), - a.stride(3 if trans_a else 2), - a.stride(2 if trans_a else 3), - b.stride(0), - b.stride(1), - b.stride(3 if trans_b else 2), - b.stride(2 if trans_b else 3), - c.stride(0), - c.stride(1), - c.stride(3 if trans_c else 2), - c.stride(2 if trans_c else 3), - AS2, - BS2, - 0, - 0, - lut, - locks, - num_locks, - num_warps=4, - **meta - ) - return c - - @staticmethod - def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): - # shapes / dtypes - AS1 = block * spdims[2 if trans_a else 1] - BS0 = b.size(0) - BS1 = b.size(1) - BS3 = b.size(2 if trans_b else 3) - dtype = a.dtype - # kernel - meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1, - 'SDD': False, 'DSD': True, 'DDS': False} - # output - CS0 = BS0 - CS1 = BS1 - CS2 = BS3 if trans_c else AS1 - CS3 = AS1 if trans_c else BS3 - locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device) - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0] - _kernel[grid]( - a, - b, - c, - a.stride(0), - a.stride(1), - a.stride(3 if trans_a else 2), - a.stride(2 if trans_a else 3), - b.stride(0), - b.stride(1), - b.stride(3 if trans_b else 2), - b.stride(2 if trans_b else 3), - c.stride(0), - c.stride(1), - c.stride(3 if trans_c else 2), - c.stride(2 if trans_c else 3), - BS3, - AS1, - 0, - 0, - lut, - locks, - num_locks, - num_warps=4, - **meta - ) - return c - - fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)} + fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} @staticmethod def forward( @@ -586,27 +495,27 @@ class matmul: return self.lut_cache[key] # C look-up table layout, block = self.layout, self.block - step = 16 + step = min(block, 32) if self.mode == 'sdd': - c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device) + c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device) elif self.mode == 'dsd': - c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device) + c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device) elif self.mode == 'dds': - c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device) + c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device) # DA look-up table if self.mode == 'sdd': - da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device) + da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device) elif self.mode == 'dsd': - da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device) + da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device) elif self.mode == 'dds': - da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device) + da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device) # DB look-up table if self.mode == 'sdd': - db_lut, 
db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device) + db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device) elif self.mode == 'dsd': - db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device) + db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device) elif self.mode == 'dds': - db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device) + db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device) self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs) @@ -709,7 +618,3 @@ class matmul: b = add_extra_dims(b) return a, b - -def _dim_to_name(x): - # assert x in (-1, -2) - return "last" if x == -1 else "second to last" diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 412530722..10f806af2 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -1,18 +1,6 @@ import triton.language as tl import triton import torch -import os - - -def next_power_of_2(n): - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n += 1 - return n def num_warps(n): @@ -23,11 +11,11 @@ def num_warps(n): return 16 -@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])}) -@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])}) +@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])}) +@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])}) @triton.jit def _forward( - X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, + X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta ): TN = meta['TN'] @@ -46,10 +34,10 @@ def _forward( check = rbn < size rbmn = tl.where(check, rbn, size - 1) # block id and column id - blockid = tl.load(LUT + offset + rbmn * 4 + 0) + blockid = tl.load(LUT + offset + rbmn * 4 + 0) columnid = tl.load(LUT + offset + rbmn * 4 + 1) - rowid = tl.load(LUT + offset + rbmn * 4 + 2) - headid = tl.load(LUT + offset + rbmn * 4 + 3) + rowid = tl.load(LUT + offset + rbmn * 4 + 2) + headid = tl.load(LUT + offset + rbmn * 4 + 3) # pointers to X px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn x = tl.load(px, mask=check, other=-float('inf')) @@ -76,13 +64,16 @@ def _forward( if meta['ATTN_MASK_MUL']: attn_m = tl.where(attn_m == 0, -float('inf'), 0.) x = x + attn_m + # apply causal mask + is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm + x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.) 
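+    # note: columnid*BLOCK + rxn is the absolute column index of each element
+    # and rowid*BLOCK + rxm its absolute row index, so entries strictly above
+    # the diagonal receive -inf (zero probability after softmax) whenever
+    # is_causal is set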
# computation x = tl.softmax(x) tl.store(px, x, mask=check) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) -@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']}) +@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']}) @triton.jit def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta): pidhm = tl.program_id(0) @@ -139,10 +130,13 @@ class _softmax(torch.autograd.Function): @staticmethod def forward( - ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, maxlut, bench, time + ctx, x, scale, rpe, + key_padding_mask, attn_mask, + kp_mask_mode, attn_mask_mode, + is_causal, + spdims, block, lut, maxlut ): apply_scale = False if scale == 1.0 else True - # handle None rpe if rpe is None: apply_rpe = False @@ -151,7 +145,6 @@ class _softmax(torch.autograd.Function): else: apply_rpe = True stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2) - # handle None key_padding_mask if key_padding_mask is None: apply_kp_mask = False @@ -160,7 +153,6 @@ class _softmax(torch.autograd.Function): else: apply_kp_mask = True stride_zkpm = key_padding_mask.stride(0) - # handle None attention_mask if attn_mask is None: apply_attn_mask = False @@ -169,22 +161,19 @@ class _softmax(torch.autograd.Function): else: apply_attn_mask = True stride_zattnm = attn_mask.stride(0) - # run kernel M = x.shape[0] - meta = { - 'BLOCK': block, - 'APPLY_SCALE': apply_scale, - 'APPLY_RPE': apply_rpe, - 'APPLY_KP_MASK': apply_kp_mask, - 'APPLY_ATTN_MASK': apply_attn_mask, - 'KP_MASK_MUL': kp_mask_mode == 'mul', - 'ATTN_MASK_MUL': attn_mask_mode == 'mul', - } - grid = lambda opt: [spdims[0] * spdims[1] * block, M] - _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\ - stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta) - + grid = [spdims[0] * spdims[1] * block, M] + _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\ + stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, + BLOCK = block, + APPLY_SCALE = apply_scale, + APPLY_RPE = apply_rpe, + APPLY_KP_MASK = apply_kp_mask, + APPLY_ATTN_MASK = apply_attn_mask, + KP_MASK_MUL = (kp_mask_mode == 'mul'), + ATTN_MASK_MUL = (attn_mask_mode == 'mul'), + force_nc_cache = True) # save to context ctx.mark_dirty(x) ctx.save_for_backward(x, lut) @@ -213,25 +202,24 @@ class _softmax(torch.autograd.Function): class softmax: - apply_softmax = _softmax.apply - def make_lut(self, device): key = (device, ) if key not in self.lut_cache: self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device) return self.lut_cache[key] - def __init__(self, layout, block, bench=False): + def __init__(self, layout, block): self.spdims = layout.shape self.layout = layout self.block = block - self.bench = bench self.lut_cache = dict() def __call__( - self, x, scale=1., rpe=None, key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add' + self, x, scale=1., rpe=None, + key_padding_mask=None, attn_mask=None, + key_padding_mask_mode='add', attn_mask_mode='add', + is_causal = False ): - time_y = [None] if rpe is not None and rpe.dtype != x.dtype: raise ValueError('relative position embedding must be %s' % x.dtype) if attn_mask is not None and attn_mask.dtype != x.dtype: @@ -239,8 +227,12 @@ class softmax: if key_padding_mask is not 
None and key_padding_mask.dtype != x.dtype: raise ValueError('Key padding mask must be %s' % x.dtype) lut, maxlut = self.make_lut(x.device) - x = softmax.apply_softmax( - x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, self.spdims, self.block, lut, - maxlut, self.bench, time_y + x = _softmax.apply( + x, scale, rpe, + key_padding_mask, attn_mask, + key_padding_mask_mode, attn_mask_mode, + is_causal, + self.spdims, self.block, + lut, maxlut ) return x \ No newline at end of file
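
A minimal usage sketch of the reworked API, assuming the `matmul` constructor keeps its existing `(layout, block, mode, trans_a, trans_b)` signature; the `is_causal` flag on `softmax.__call__` is the one introduced by this patch.

import torch
import triton

Z, H, M, N, K, block = 1, 2, 256, 256, 256, 32
# lower-triangular block layout, so every block-row has at least one non-zero block
layout = torch.tril(torch.ones((H, M // block, N // block), dtype=torch.int64))
a = torch.randn((Z, H, M, K), dtype=torch.float16, device='cuda')
b = torch.randn((Z, H, K, N), dtype=torch.float16, device='cuda')
# sparse = dense x dense (SDD): returns the non-zero blocks, shape (Z, nnz, block, block)
sdd = triton.ops.blocksparse.matmul(layout, block, 'sdd', trans_a=False, trans_b=False)
w = sdd(a, b)
# block-sparse softmax over the non-zero blocks, using the new causal-mask flag
sm = triton.ops.blocksparse.softmax(layout, block)
p = sm(w, scale=1.0 / K ** 0.5, is_causal=True)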