[OPS] Faster and cleaner block-sparse implementation (#311)

Philippe Tillet
2021-09-27 18:25:16 -07:00
committed by GitHub
parent c3756d1c33
commit 2c287544cb
8 changed files with 518 additions and 582 deletions

View File

@@ -2500,9 +2500,8 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
 void generator::visit_basic_block(ir::basic_block * block) {
   BasicBlock *parent = bbs_[block];
   builder_->SetInsertPoint(parent);
-  for(ir::instruction *i: block->get_inst_list()){
+  for(ir::instruction *i: block->get_inst_list())
     visit_value(i);
-  }
   // Update ir bb -> llvm bb mapping
   bbs_[block] = builder_->GetInsertBlock();
 }

View File

@@ -117,7 +117,7 @@ void coalesce::run(ir::module &mod) {
         queue.push_back(inst_op);
       }
     }
-    if(in_contig.empty() || out_contig==in_contig)
+    if(in_contig.size() <= 1 || out_contig==in_contig)
       continue;
     builder.set_insert_point_after(val_inst);
     auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));

View File

@@ -188,7 +188,7 @@ void pipeline::run(ir::module &mod) {
   builder.set_insert_point(block->get_first_non_phi());
   std::map<ir::phi_node*, ir::value*> load_ivs;
   std::map<ir::phi_node*, ir::value*> next_load_ivs;
-  for (ir::phi_node* iv : induction_vars) {
+  for (auto& [iv, val] : prev_phi_vals) {
     ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
     pn->add_incoming(prev_phi_vals[iv], header);
     load_ivs[iv] = pn;
@@ -198,7 +198,8 @@ void pipeline::run(ir::module &mod) {
   // pre-fetch next iteration
   builder.set_insert_point(block->get_inst_list().back());
-  ir::value* next_ptr = ptr->get_value_for_block(block);
+  // ir::value* next_ptr = ptr->get_value_for_block(block);
+  ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
   ir::value* next_mask = builder.create_splat(
     rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
   if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
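For readers unfamiliar with what the pipelining pass does here, the sketch below (plain Python, not part of the commit; names are illustrative) shows the general idea: the load feeding iteration i+1 is issued at the end of iteration i, and its address/mask are recomputed ("rematerialized") from the next iteration's induction values rather than reused from the current one.

# Illustrative sketch of software pipelining, not the compiler pass itself.
# `data` stands in for global memory; `tile` is the per-iteration load width.
def pipelined_sum(data, tile):
    acc = 0
    cur = data[0:tile]                       # peeled first load
    for k in range(0, len(data), tile):
        nxt = data[k + tile:k + 2 * tile]    # pre-fetch next iteration's tile
        acc += sum(cur)                      # consume the current tile
        cur = nxt                            # rotate the double buffer
    return acc

assert pipelined_sum(list(range(16)), 4) == sum(range(16))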

View File

@@ -11,20 +11,20 @@ square_confs = [
         x_names = ['M', 'N', 'K'],
         x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
         line_arg = 'block',
-        line_vals = [16, 32, 64],
-        line_names = ['Block16', 'Block32', 'Block64'],
+        line_vals = [16, 32, 64, 128],
+        line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
         ylabel = 'TFLOPS',
         plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
         args = {'layout_mode': layout_mode, 'op_mode': op_mode,
                 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
     )\
     for AT in [False] for BT in [False] \
-    for op_mode in ['sdd', 'dsd', 'dds'] for layout_mode in ['tril', 'dense']
+    for op_mode in ['dsd'] for layout_mode in ['dense']
 ]
 @triton.testing.perf_report(square_confs)
-def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5):
+def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
     Z, H = 1, 1
     make_layout = {
         'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
@@ -86,3 +86,6 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
     gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
     mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
     return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
+bench_matmul.run(print_data=True, show_plots=True)
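For reference, the gbps lambda above can be checked by hand. The factor of 2 presumably accounts for one read and one write of the input; the numbers below are made up for illustration, not measurements.

# Illustrative arithmetic only (assumes fp16, i.e. element_size() == 2).
numel = 1 * 1 * 4096 * 4096            # a (Z, H, M, N) = (1, 1, 4096, 4096) input
bytes_moved = 2 * numel * 2            # read + write, 2 bytes per element
ms = 0.15                              # hypothetical do_bench timing, in milliseconds
print((bytes_moved * 1e-9) / (ms * 1e-3))   # ~447 GB/s for this made-up timing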

View File

@@ -7,7 +7,7 @@ import pytest
@pytest.mark.parametrize("TRANS_A", [False, True]) @pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True]) @pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64]) @pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32]) @pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256): def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
@@ -41,15 +41,20 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792]) @pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32]) @pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
def test_softmax(BLOCK, WIDTH, DTYPE): def test_softmax(BLOCK, WIDTH, DTYPE):
is_causal = True
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
Z, H, M, N = 2, 4, WIDTH, WIDTH Z, H, M, N = 1, 1, WIDTH, WIDTH
scale = 0.4 scale = 0.4
# create inputs # create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda") x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda") at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
# make sure each row has at least one non-zero element
torch.diagonal(layout)[:] = 1
torch.diagonal(at_mask)[:] = 1
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda") kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
kp_mask[:] = 0
kp_mask[kp_mask == 1.0] = float("-inf") kp_mask[kp_mask == 1.0] = float("-inf")
# triton result # triton result
op = triton.ops.blocksparse.softmax(layout, BLOCK) op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -61,16 +66,15 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
key_padding_mask_mode="add", key_padding_mask_mode="add",
attn_mask=at_mask.to(DTYPE), attn_mask=at_mask.to(DTYPE),
attn_mask_mode="mul", attn_mask_mode="mul",
is_causal=is_causal,
) )
# torch result # torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
if at_mask is not None: # broadcast at_mask to the same shape as rx
# broadcast at_mask to the same shape as rx if is_causal: at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx) M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf") rx[M == 0] = float("-inf")
if kp_mask is not None: # rx += kp_mask[:, None, None, :]
rx += kp_mask[:, None, None, :]
ry = torch.softmax(rx * scale, -1)
ry = torch.softmax(rx * scale, -1) ry = torch.softmax(rx * scale, -1)
ry = triton.testing.sparsify_tensor(ry, layout, BLOCK) ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
# compare # compare
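The reference path above converts between the dense layout consumed by torch.softmax and the packed block format consumed by the Triton op via triton.testing.mask_tensor and triton.testing.sparsify_tensor. A rough PyTorch approximation of what these helpers do (illustrative only, not the library code) is:

import torch

def mask_tensor_ref(x, layout, block, value=0.):
    # fill every block whose layout entry is 0 with `value`
    x = x.clone()
    for h, i, j in zip(*(layout == 0).nonzero(as_tuple=True)):
        x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return x

def sparsify_tensor_ref(x, layout, block):
    # gather the non-zero blocks into a (batch, nnz, block, block) tensor
    blocks = [x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
              for h, i, j in zip(*layout.nonzero(as_tuple=True))]
    return torch.stack(blocks, dim=1)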

View File

@@ -699,3 +699,35 @@ def ravel(x):
     :type x: Block
     """
     return triton.language.reshape(x, [x.type.numel])
+@triton.jit
+def swizzle2d(i, j, size_i, size_j, size_g):
+    """
+    transforms indices of a row-major size_i*size_j matrix into those
+    of one where indices are row major for each group of size_j rows.
+    For example, for size_i = size_j = 4 and size_g = 2, it will transform
+    [[0 , 1 , 2 , 3 ],
+     [4 , 5 , 6 , 7 ],
+     [8 , 9 , 10, 11],
+     [12, 13, 14, 15]]
+    into
+    [[0, 2,  4 , 6 ],
+     [1, 3,  5 , 7 ],
+     [8, 10, 12, 14],
+     [9, 11, 13, 15]]
+    """
+    # "unrolled index in array"
+    ij = i*size_j + j
+    # number of elements in `size_g` groups
+    # of `size_j` columns
+    size_gj = size_g * size_j
+    # index of the group in which (i,j) is
+    group_id = ij // size_gj
+    # row-index of the first element of this group
+    off_i = group_id * size_g
+    # last group may have fewer rows
+    size_g = minimum(size_i - off_i, size_g)
+    # new row and column indices
+    new_i = off_i + (ij % size_g)
+    new_j = (ij % size_gj) // size_g
+    return new_i, new_j
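For reference, the 4x4 / size_g = 2 example in the docstring can be reproduced on the host with a few lines of NumPy (an illustrative re-implementation of the same arithmetic, not the Triton builtin):

import numpy as np

def swizzle2d_ref(i, j, size_i, size_j, size_g):
    ij = i * size_j + j                    # row-major linear index
    size_gj = size_g * size_j              # elements per group of rows
    off_i = (ij // size_gj) * size_g       # first row of this group
    size_g = min(size_i - off_i, size_g)   # last group may have fewer rows
    return off_i + (ij % size_g), (ij % size_gj) // size_g

out = np.empty((4, 4), dtype=int)
for i in range(4):
    for j in range(4):
        ni, nj = swizzle2d_ref(i, j, 4, 4, 2)
        out[ni, nj] = i * 4 + j
print(out)   # [[0 2 4 6] [1 3 5 7] [8 10 12 14] [9 11 13 15]]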

View File

@@ -3,530 +3,439 @@ import triton.language as tl
 import triton._C.libtriton as libtriton
 import torch
# ********************************************************
# --------------------------------------------------------
# Sparse = Dense x Dense (SDD)
# This operation uses super-blocking to make sure that
# it's done efficiently when small blocks can be grouped
# together
# --------------------------------------------------------
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0,
})
@triton.jit @triton.jit
def _kernel( def _sdd_kernel(
A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc, A, B, C,
stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut, **meta
): ):
TM = meta['TM'] TILE_M = meta['TILE_M']
TN = meta['TN'] TILE_N = meta['TILE_N']
TK = meta['TK'] TILE_K = meta['TILE_K']
TZ = meta['TZ'] BLOCK = meta['BLOCK']
BLOCK = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
pid0 = tl.program_id(0) pid1 = tl.program_id(1) + grid_offset
pid1 = tl.program_id(1) blockidm = tl.arange(0, TILE_M) // BLOCK
pidz = tl.program_id(2) blockidn = tl.arange(0, TILE_N) // BLOCK
if meta['SDD']: offlutm = blockidm * (TILE_N // BLOCK) * 4
pid1 = pid1 + SDD_off_width offlutn = blockidn * 4
blockidm = tl.arange(0, TM) // BLOCK header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4
blockidn = tl.arange(0, TN) // BLOCK # batch offset
offlutm = blockidm * (TN // BLOCK) * 4 off_z = tl.program_id(2)
offlutn = blockidn * 4 # head offset
header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4 off_h = tl.load(header + 0)
z = tl.load(header + 0) # initialize pointers to A
i = tl.load(header + 1 + offlutm) start_am = tl.load(header + 1 + offlutm)
j = tl.load(header + 2 + offlutn) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
AS1 = SDD_K offs_ak = tl.arange(0, TILE_K)
lockid = tl.where(TZ > 1, 1, 0) a_ptrs = A + off_z * stride_za \
offka = pid0 * AS1 + off_h * stride_ha \
offkb = pid0 * AS1 + offs_am[:, None] * stride_ma \
offmc = 0 + offs_ak[None, :] * stride_ak
offnc = 0 # initialize pointers to B
offpa = 0 start_bn = tl.load(header + 2 + offlutn)
offpb = 0 offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
maxid = TZ offs_bk = tl.arange(0, TILE_K)
offhc = 0 b_ptrs = B + off_z * stride_zb \
offha = z + off_h * stride_hb \
offhb = z + offs_bn[None, :] * stride_nb \
ram = i * BLOCK + (tl.arange(0, TM) % BLOCK) + offs_bk[:, None] * stride_bk
rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
else:
header = lut + pid0 * 6
offset = tl.load(header + 0)
AS1 = tl.load(header + 1)
column = tl.load(header + 2)
depth = tl.load(header + 3)
lockid = tl.load(header + 4)
maxid = tl.load(header + 5)
pinc = lut + offset
offhc = depth
if meta['DSD']:
# output offset
offnc = pid1 * TN
offmc = column * TM
offpc = 0
# dense input offset
offnb = pid1 * TN
offkb = tl.load(pinc)
offkb = tl.multiple_of(offkb, 8) # compiler hint
offpb = 0
# sparse input offset
offma = 0
offka = 0
offpa = tl.load(pinc + 1)
offpa = tl.multiple_of(offpa, 8) # compiler hint
offpa = offpa * BLOCK * BLOCK
offha = 0
offhb = depth
else:
# output offset
offmc = pid1 * TM
offnc = column * TN
offpc = 0
# dense input offset
offma = pid1 * TM
offka = tl.load(pinc)
offka = tl.multiple_of(offka, 8) # compiler hint
offpa = 0
# sparse input offset
offnb = 0
offkb = 0
offpb = tl.load(pinc + 1)
offpb = tl.multiple_of(offpb, 8) # compiler hint
offpb = offpb * BLOCK * BLOCK
offha = depth
offhb = 0
ram = offma + tl.arange(0, TM)
rbn = offnb + tl.arange(0, TN)
# initialize a, b pointers
rka = offka + tl.arange(0, TK)
rkb = offkb + tl.arange(0, TK)
pa = A + pidz * TZ * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
pb = B + pidz * TZ * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
if meta['DDS']:
checkam = ram[:, None] < DS0
else:
checkam = AS1 > 0
if meta['DSD']:
checkbn = rbn[None, :] < DS0
else:
checkbn = AS1 > 0
a = tl.load(pa, mask=checkam, other=0.)
b = tl.load(pb, mask=checkbn, other=0.)
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
acc = tl.zeros((TM, TN), dtype=tl.float32) acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(AS1, 0, -TK*TZ): for k in range(K, 0, -TILE_K):
if meta['EVEN_K']:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
acc += tl.dot(a, b) acc += tl.dot(a, b)
if meta['SDD']: a_ptrs += TILE_K * stride_ak
inc_a = TK * TZ * stride_ka b_ptrs += TILE_K * stride_bk
inc_b = TK * TZ * stride_kb
else:
pinc += 2
if meta['DSD']:
inc_b = tl.load(pinc)
inc_a = tl.load(pinc + 1)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = inc_b * stride_kb
if meta['DDS']:
inc_a = tl.load(pinc)
inc_b = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
pa += inc_a
pb += inc_b
# pre-fetch
checkak = k > TK
checkbk = k > TK
checka = checkam & checkak
checkb = checkbn & checkbk
a = tl.load(pa, mask=checka)
b = tl.load(pb, mask=checkb)
c = acc.to(C.dtype.element_ty) c = acc.to(C.dtype.element_ty)
## ---------------- ##
## Epilogue ##
## ---------------- ##
blockidm = tl.arange(0, TILE_M) // BLOCK
blockidn = tl.arange(0, TILE_N) // BLOCK
offlutm = blockidm * (TILE_N // BLOCK) * 4
offlutn = blockidn * 4
off_block_id = 3 + offlutm[:, None] + offlutn[None, :]
block_id = tl.load(header + off_block_id)
# initialize pointers to C
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
tl.store(pc, c, mask=True)
if meta['SDD']: def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
checkc = True # (A * B)^T = B^T * A^T
rr_blockidm = tl.arange(0, TM) // BLOCK if trans_c:
rr_blockidn = tl.arange(0, TN) // BLOCK a, b = b, a
rr_offlutm = rr_blockidm * (TN // BLOCK) * 4 trans_a, trans_b = not trans_b, not trans_a
rr_offlutn = rr_blockidn * 4 # shape constraints
off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :] a_dim = -2 if trans_a else -1
bkid = tl.load(header + off_bkid) b_dim = -1 if trans_b else -2
offpc = bkid * BLOCK * BLOCK Ka, Kb = a.shape[a_dim], b.shape[b_dim]
rcm = tl.arange(0, TM) % BLOCK if Ka != Kb:
rcn = tl.arange(0, TN) % BLOCK raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
if Ka % 16 != 0:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# allocate output
n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device)
# each iteration of the loop below
# computes the value for one group of super-blocks
# (e.g., all 4x4 super-blocks)
for lut, width, pack in zip(luts, widths, packs):
# maximum grid size in Triton/CUDA is 64k but we may have more
# super-blocks than that.
max_grid = 65535
for off_grid in range(0, width, max_grid):
grid = [1, min(max_grid, width - off_grid), c.shape[0]]
# fmt: off
pgm = _sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, off_grid, lut,
TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3,
num_warps=4,
)
# print(pgm.asm['ptx'])
# exit()
return c
def sdd_lut(layout, block, device):
start_width = 128 // block
layout = layout.type(torch.int32)
superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
nnz = nnz.reshape(-1, 4)
width = nnz.shape[0] // (size * size)
luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
widths.append(width)
packs.append(size)
return luts, None, widths, packs
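The grouping itself happens in libtriton.superblock, which is C++ code not shown in this diff. The sketch below only illustrates the general idea, under the assumption (not verified against the C++ implementation) that aligned p x p groups of non-zero blocks are merged greedily, largest pack first, so that several small blocks can be computed by one larger tile:

import torch

def superblock_ref(layout, start_width):
    # layout: (H, M, N) tensor of 0/1 block flags; returns {pack_size: [(h, i, j), ...]}
    layout = layout.clone()
    groups, pack = {}, start_width
    while pack >= 1:
        found = []
        H, M, N = layout.shape
        for h in range(H):
            for i in range(0, M - pack + 1, pack):
                for j in range(0, N - pack + 1, pack):
                    if layout[h, i:i + pack, j:j + pack].all():
                        found.append((h, i, j))
                        layout[h, i:i + pack, j:j + pack] = 0   # consume these blocks
        groups[pack] = found
        pack //= 2
    return groups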
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut, **meta
):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
GROUP_SIZE_M = meta['GROUP_SIZE_M']
#------------#
#- Prologue -#
#------------#
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
pidz = tl.program_id(2)
header = lut + pid_m * 4
offset = tl.load(header + 0)
K = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \
+ block_id * stride_ha \
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B (dense)
offs_bn = pid_n*TILE_N + tl.arange(0, TILE_N)
start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
## ---------------- ##
## Inner Loop ##
## ---------------- ##
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K):
a = tl.load(pa, mask=True)
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
acc += tl.dot(a, b)
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
pa += inc_a
pb += inc_b*stride_bk
c = acc.to(C.dtype.element_ty)
# initialize pointers to C
offs_cm = column*TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N)
pc = C + off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask = offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0)
BS1 = b.size(1)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# allocate output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
# meta-parameter heuristics
TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block]
# compute output
grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0]
# fmt: off
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
num_warps=4, GROUP_SIZE_M=8,
)
# exit()
return c
def dsd_lut(layout, block, step, trans, device):
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = sizes.nonzero(as_tuple=True)
sizes = sizes.flatten()
segments = sizes*step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else: else:
rcm = offmc + tl.arange(0, TM) nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
rcn = offnc + tl.arange(0, TN) num_blocks = nnz.size(0)
if meta['DSD']: offsets = torch.zeros_like(sizes)
checkc = rcn[None, :] < DS0 offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
if meta['DDS']: offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
checkc = rcm[:, None] < DS0 # -------------------------------
# dense input pointer increments
pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc # -------------------------------
# write-back directly # given a list of the indices for the first element of each non-zero block.
if lockid == 0: # For example, for the indices
tl.store(pc, c, mask=checkc) # [32, 80, 128, 256, 288]
# accumulate partial results using spin-locks # we would generate the increments
# [32, 48, 48, 128, 32]
# ^
# index of first element
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
# that is smaller than the block size, so we need to do a bit of extra work
# to handle this case
B_idx = nnz[:, 2] * block
B_incs = B_idx.clone()
B_incs[1:] -= B_idx[:-1]
div = block // step
B_incs = B_incs.view(-1, 1).repeat(1, div)
B_incs[:, 1:] = step
B_incs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
B_incs = B_incs.view(-1)
# -------------------------------
# sparse input pointer increments
# -------------------------------
# same as above, except that the increments are in the sparse memory layout
if trans:
A_idx = torch.arange(num_blocks)
else: else:
plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1 A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks current_offset = 0
while tl.atomic_cas(plock, 0, 1) == 1: for z in range(layout.size(0)):
pass layoutw = layout[z, :, :].clone()
count = tl.load(pcount) msum = layoutw.sum()
if count == 0: layoutw[layoutw > 0] = 1 + torch.arange(msum)
tl.store(pc, c, mask=checkc) A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
else: current_offset += msum
d = tl.load(pc, mask=checkc) A_incs = A_idx * block * block
tl.store(pc, d + c, mask=checkc) A_incs[1:] -= A_idx[:-1] * block * block
tl.atomic_xchg(pcount, (count + 1) % maxid) A_incs = A_incs.view(-1, 1).repeat(1, div)
tl.atomic_xchg(plock, 0) if trans:
A_incs[:, 1:] = step
A_incs[:, 0] -= (div - 1) * step
else:
A_incs[:, 1:] = step * block
A_incs[:, 0] -= (div - 1) * step * block
A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
A_incs = A_incs.view(-1)
# create header
width = col_id.size(0)
offsets = offsets*2*div + 4*width
segments = segments*div
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
# create increments
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
return lut, None, width, None
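The increment scheme described in the comments above can be sanity-checked with plain PyTorch, using the example indices from the comment (illustrative only; it ignores the additional block/step splitting that dsd_lut performs):

import torch

idx = torch.tensor([32, 80, 128, 256, 288])   # offset of the first element of each block
incs = idx.clone()
incs[1:] -= idx[:-1]                          # delta from the previous block...
incs[0] = idx[0]                              # ...and the first entry stays absolute
print(incs.tolist())                          # [32, 48, 48, 128, 32]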
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@triton.jit
def _dds_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut, **meta
):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
GROUP_SIZE_M = meta['GROUP_SIZE_M']
#------------#
#- Prologue -#
#------------#
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
pid_z = tl.program_id(2)
header = lut + pid_m * 4
offset = tl.load(header + 0)
AS1 = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (dense)
offs_am = pid_n*TILE_M + tl.arange(0, TILE_M)
start_ak = tl.load(pinc)
start_ak = tl.multiple_of(start_ak, 8)
offs_ak = start_ak + tl.arange(0, TILE_K)
ptrs_a = A + pid_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ka
# initialize pointers to B (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8)
offs_bn = tl.arange(0, TILE_N)
offs_bk = tl.arange(0, TILE_K)
ptrs_b = B + pid_z * stride_zb \
+ block_id * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
## ---------------- ##
## Inner Loop ##
## ---------------- ##
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(AS1, 0, -TILE_K):
a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0)
b = tl.load(ptrs_b, mask = True)
acc += tl.dot(a, b)
pinc += 2
inc_a = tl.load(pinc)
inc_b = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
ptrs_a += inc_a
ptrs_b += inc_b
## ---------------- ##
## Epilogue ##
## ---------------- ##
c = acc.to(C.dtype.element_ty)
# initialize pointers to C (dense)
offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M)
offs_cn = column * TILE_N + tl.arange(0, TILE_N)
ptrs_c = C + off_h * stride_hc \
+ pid_z * stride_zc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
# write back
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0]
# fmt: off
_dds_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
AS2, BS2, lut,
TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
num_warps=4, GROUP_SIZE_M=8,
)
return c
##############
#  MAIN API   #
##############
class _matmul(torch.autograd.Function):
sdd_cache = dict() fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
dsd_cache = dict()
dds_cache = dict()
locks = dict()
# Given an array sizes representing reduction size for each
# column of a block-mode matrix multiplication,
# performs load-balancing to achieve more smaller reductions
# between `seg_size` elements
@staticmethod
def load_balance(sizes):
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
max_size = sizes.max()
min_size = sizes[sizes != 0].min()
#if max_size > min_size * 2.0:
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
#else:
# seg_max = max_size
seg_max = max_size
seg_min = max(triton.cdiv(seg_max, 4), 4)
# split reduction into segments
div = sizes // seg_max
rem = sizes % seg_max
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
width = packs.sum()
segments = torch.empty(width, dtype=sizes.dtype)
column = torch.empty_like(segments)
lockid = torch.zeros_like(segments)
maxid = torch.zeros_like(segments)
nlocks = 0
current = 0
col_idx = 0
for i in range(len(sizes)):
d, r = div[i], rem[i]
isempty = sizes[i] < seg_min
last = current + d + (r >= seg_min) + isempty
# column id
column[current:last] = col_idx
# lock id
if d > 1 or (d == 1 and r >= seg_min):
nlocks += 1
lockid[current:last] = nlocks
maxid[current:last] = last - current
# segment size
segments[current:current + d] = seg_max
if r < seg_min and not isempty:
segments[current + d - 1] += r
if r >= seg_min or isempty:
segments[current + d] = r
current = last
col_idx += 1
offsets = torch.zeros_like(segments)
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
return segments, column, lockid, maxid, offsets
@staticmethod
def get_locks(size, dev):
if dev not in _matmul.locks or \
size > _matmul.locks[dev].size(0):
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
return _matmul.locks[dev]
##########################
# SPARSE = DENSE x DENSE #
##########################
@staticmethod
def make_sdd_lut(layout, block, device):
start_width = 128 // block
layout = layout.type(torch.int32)
superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
nnz = nnz.reshape(-1, 4)
width = nnz.shape[0] // (size * size)
luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
widths.append(width)
packs.append(size)
# create locks
return luts, None, widths, packs
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
# (A * B)^T = (B^T * A^T)
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
# Shape check
a_dim = -2 if trans_a else -1
b_dim = -1 if trans_b else -2
a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
if a_inner != b_inner:
raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
if a_inner % 16 != 0:
raise ValueError('Reduction size for SDD must be a multiple of 16')
batch_size = a.size(0)
a_outer = a.size(3 if trans_a else 2)
dtype = a.dtype
device = a.device
# create kernel
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1,
'SDD': True, 'DSD': False, 'DDS': False}
# create output
locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
_kernel[grid](
a,
b,
c,
a.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(0),
c.stride(2),
c.stride(3),
a_outer,
a_outer,
a_inner,
off_width,
lut,
locks,
num_lock,
num_warps=4,
**meta
)
# save for backward pass
return c
##########################
# DENSE = DENSE x SPARSE #
# DENSE = SPARSE x DENSE #
##########################
# Given a binary layout of 0s and 1s,
# Construct look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
# load-balancing
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
segments = _empty.clone()
column = _empty.clone()
depth = _empty.clone()
lockid = _empty.clone()
maxid = _empty.clone()
offsets = _empty.clone()
current_offset = 0
current_maxid = 0
for z in range(layout.size(0)):
if trans:
sizes = torch.sum(layout[z, :, :], 1)
else:
sizes = torch.sum(layout[z, :, :], 0)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes)
z_depth = z * torch.ones_like(z_segments)
z_lockid[z_lockid > 0] += current_maxid
current_maxid = z_lockid.max()
# concatenate depth
segments = torch.cat((segments, z_segments))
column = torch.cat((column, z_column))
depth = torch.cat((depth, z_depth))
maxid = torch.cat((maxid, z_maxid))
offsets = torch.cat((offsets, current_offset + z_offsets))
lockid = torch.cat((lockid, z_lockid))
current_offset += layout[z, :, :].sum()
segments *= step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
idx = transform(nnz[:, 2] * block)
xincs = idx.clone()
xincs[1:] -= idx[:-1]
# divide block into multiple steps
div = block // step
xincs = xincs.view(-1, 1).repeat(1, div)
xincs[:, 1:] = step
xincs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
xincs = xincs.view(-1)
# block-mode input increments
if trans:
widx = torch.arange(num_blocks)
else:
widx = _empty.clone()
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum)
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
widx = widx
wincs = widx * block * block
wincs[1:] -= widx[:-1] * block * block
wincs = wincs.view(-1, 1).repeat(1, div)
if trans:
wincs[:, 1:] = step
wincs[:, 0] -= (div - 1) * step
else:
wincs[:, 1:] = step * block
wincs[:, 0] -= (div - 1) * step * block
wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
wincs = wincs.view(-1)
# adjust offset and segment size
offsets *= 2 * div
segments *= div
# create header
width = column.size(0)
offsets += 6 * width
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
num_locks = max(1, lockid.max())
return lut, num_locks, width, None
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
'SDD': False, 'DSD': False, 'DDS': True}
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(1),
c.stride(3 if trans_c else 2),
c.stride(2 if trans_c else 3),
AS2,
BS2,
0,
0,
lut,
locks,
num_locks,
num_warps=4,
**meta
)
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0)
BS1 = b.size(1)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
'SDD': False, 'DSD': True, 'DDS': False}
# output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
_kernel[grid](
a,
b,
c,
a.stride(0),
a.stride(1),
a.stride(3 if trans_a else 2),
a.stride(2 if trans_a else 3),
b.stride(0),
b.stride(1),
b.stride(3 if trans_b else 2),
b.stride(2 if trans_b else 3),
c.stride(0),
c.stride(1),
c.stride(3 if trans_c else 2),
c.stride(2 if trans_c else 3),
BS3,
AS1,
0,
0,
lut,
locks,
num_locks,
num_warps=4,
**meta
)
return c
fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
    @staticmethod
    def forward(
@@ -586,27 +495,27 @@ class matmul:
             return self.lut_cache[key]
         # C look-up table
         layout, block = self.layout, self.block
-        step = 16
+        step = min(block, 32)
         if self.mode == 'sdd':
-            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device)
+            c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device)
         elif self.mode == 'dsd':
-            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
+            c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device)
         elif self.mode == 'dds':
-            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
+            c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device)
         # DA look-up table
         if self.mode == 'sdd':
-            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
+            da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device)
         elif self.mode == 'dsd':
-            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device)
+            da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device)
         elif self.mode == 'dds':
-            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
+            da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device)
         # DB look-up table
         if self.mode == 'sdd':
-            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
+            db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device)
         elif self.mode == 'dsd':
-            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
+            db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device)
         elif self.mode == 'dds':
-            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
+            db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device)
         self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                                da_lut, da_num_locks, da_width, da_packs,
                                db_lut, db_num_locks, db_width, db_packs)
@@ -709,7 +618,3 @@ class matmul:
         b = add_extra_dims(b)
         return a, b
-def _dim_to_name(x):
-    # assert x in (-1, -2)
-    return "last" if x == -1 else "second to last"

View File

@@ -1,18 +1,6 @@
 import triton.language as tl
 import triton
 import torch
-import os
-def next_power_of_2(n):
-    n -= 1
-    n |= n >> 1
-    n |= n >> 2
-    n |= n >> 4
-    n |= n >> 8
-    n |= n >> 16
-    n += 1
-    return n
 def num_warps(n):
@@ -23,11 +11,11 @@ def num_warps(n):
     return 16
-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
+@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])})
+@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])})
 @triton.jit
 def _forward(
-    X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
+    X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
     **meta
 ):
     TN = meta['TN']
@@ -46,10 +34,10 @@ def _forward(
     check = rbn < size
     rbmn = tl.where(check, rbn, size - 1)
     # block id and column id
     blockid = tl.load(LUT + offset + rbmn * 4 + 0)
     columnid = tl.load(LUT + offset + rbmn * 4 + 1)
     rowid = tl.load(LUT + offset + rbmn * 4 + 2)
     headid = tl.load(LUT + offset + rbmn * 4 + 3)
     # pointers to X
     px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
     x = tl.load(px, mask=check, other=-float('inf'))
@@ -76,13 +64,16 @@ def _forward(
     if meta['ATTN_MASK_MUL']:
         attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
     x = x + attn_m
+    # apply causal mask
+    is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm
+    x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
     # computation
     x = tl.softmax(x)
     tl.store(px, x, mask=check)
 @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
+@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
 @triton.jit
 def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
     pidhm = tl.program_id(0)
@@ -139,10 +130,13 @@ class _softmax(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut, maxlut, bench, time
+        ctx, x, scale, rpe,
+        key_padding_mask, attn_mask,
+        kp_mask_mode, attn_mask_mode,
+        is_causal,
+        spdims, block, lut, maxlut
     ):
         apply_scale = False if scale == 1.0 else True
         # handle None rpe
         if rpe is None:
             apply_rpe = False
@@ -151,7 +145,6 @@ class _softmax(torch.autograd.Function):
         else:
             apply_rpe = True
             stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
         # handle None key_padding_mask
         if key_padding_mask is None:
             apply_kp_mask = False
@@ -160,7 +153,6 @@ class _softmax(torch.autograd.Function):
         else:
             apply_kp_mask = True
             stride_zkpm = key_padding_mask.stride(0)
         # handle None attention_mask
         if attn_mask is None:
             apply_attn_mask = False
@@ -169,22 +161,19 @@ class _softmax(torch.autograd.Function):
         else:
             apply_attn_mask = True
             stride_zattnm = attn_mask.stride(0)
         # run kernel
         M = x.shape[0]
-        meta = {
-            'BLOCK': block,
-            'APPLY_SCALE': apply_scale,
-            'APPLY_RPE': apply_rpe,
-            'APPLY_KP_MASK': apply_kp_mask,
-            'APPLY_ATTN_MASK': apply_attn_mask,
-            'KP_MASK_MUL': kp_mask_mode == 'mul',
-            'ATTN_MASK_MUL': attn_mask_mode == 'mul',
-        }
-        grid = lambda opt: [spdims[0] * spdims[1] * block, M]
-        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
+        grid = [spdims[0] * spdims[1] * block, M]
+        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\
+                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
+                       BLOCK = block,
+                       APPLY_SCALE = apply_scale,
+                       APPLY_RPE = apply_rpe,
+                       APPLY_KP_MASK = apply_kp_mask,
+                       APPLY_ATTN_MASK = apply_attn_mask,
+                       KP_MASK_MUL = (kp_mask_mode == 'mul'),
+                       ATTN_MASK_MUL = (attn_mask_mode == 'mul'),
+                       force_nc_cache = True)
         # save to context
         ctx.mark_dirty(x)
         ctx.save_for_backward(x, lut)
@@ -213,25 +202,24 @@ class _softmax(torch.autograd.Function):
 class softmax:
-    apply_softmax = _softmax.apply
     def make_lut(self, device):
         key = (device, )
         if key not in self.lut_cache:
             self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
         return self.lut_cache[key]
-    def __init__(self, layout, block, bench=False):
+    def __init__(self, layout, block):
         self.spdims = layout.shape
         self.layout = layout
         self.block = block
-        self.bench = bench
         self.lut_cache = dict()
     def __call__(
-        self, x, scale=1., rpe=None, key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add'
+        self, x, scale=1., rpe=None,
+        key_padding_mask=None, attn_mask=None,
+        key_padding_mask_mode='add', attn_mask_mode='add',
+        is_causal = False
     ):
-        time_y = [None]
         if rpe is not None and rpe.dtype != x.dtype:
             raise ValueError('relative position embedding must be %s' % x.dtype)
         if attn_mask is not None and attn_mask.dtype != x.dtype:
@@ -239,8 +227,12 @@ class softmax:
         if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
             raise ValueError('Key padding mask must be %s' % x.dtype)
         lut, maxlut = self.make_lut(x.device)
-        x = softmax.apply_softmax(
-            x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode, self.spdims, self.block, lut,
-            maxlut, self.bench, time_y
+        x = _softmax.apply(
+            x, scale, rpe,
+            key_padding_mask, attn_mask,
+            key_padding_mask_mode, attn_mask_mode,
+            is_causal,
+            self.spdims, self.block,
+            lut, maxlut
         )
         return x
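Putting the new is_causal flag together with the rest of the API, a call now looks roughly like this (shapes are illustrative and mirror the updated unit test; the packed input comes from triton.testing.sparsify_tensor):

import torch
import triton

Z, H, M, N, block = 1, 1, 256, 256, 32
layout = torch.randint(0, 2, (H, M // block, N // block))
layout[:, :, 0] = 1                              # keep every block-row non-empty
x = torch.randn((Z, H, M, N), dtype=torch.float16, device='cuda')
sparse_x = triton.testing.sparsify_tensor(x, layout, block)
softmax = triton.ops.blocksparse.softmax(layout, block)
# causal masking is now applied inside the kernel rather than via an explicit attn_mask
y = softmax(sparse_x, scale=0.5, is_causal=True)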