[OPS] Faster and cleaner block-sparse implementation (#311)

Author:    Philippe Tillet
Committed: 2021-09-27 18:25:16 -07:00 (via GitHub)
Parent:    c3756d1c33
Commit:    2c287544cb
8 changed files with 518 additions and 582 deletions

@@ -11,20 +11,20 @@ square_confs = [
         x_names = ['M', 'N', 'K'],
         x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
         line_arg = 'block',
-        line_vals = [16, 32, 64],
-        line_names = ['Block16', 'Block32', 'Block64'],
+        line_vals = [16, 32, 64, 128],
+        line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
         ylabel = 'TFLOPS',
         plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
         args = {'layout_mode': layout_mode, 'op_mode': op_mode,
                 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
     )\
     for AT in [False] for BT in [False] \
-    for op_mode in ['sdd', 'dsd', 'dds'] for layout_mode in ['tril', 'dense']
+    for op_mode in ['dsd'] for layout_mode in ['dense']
 ]


 @triton.testing.perf_report(square_confs)
-def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5):
+def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
     Z, H = 1, 1
     make_layout = {
         'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
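
A note between the two hunks on what the surviving configuration measures: with op_mode = 'dsd' (sparse A times dense B) and a 'dense' layout, every block of A is populated, so the TFLOPS plotted on the 'TFLOPS' y-axis reduce to the familiar 2*M*N*K operation count over the measured time. The sketch below spells out that accounting from a block layout; the helper name dsd_tflops and the exact FLOP bookkeeping are assumptions for illustration, not code from this commit.

    import torch

    def dsd_tflops(layout, block, N, ms):
        # layout: (H, M//block, K//block) mask of non-zero blocks in the sparse operand A.
        # Each non-zero block of A touches a (block x N) slab of the dense operand B,
        # costing 2 * block * block * N FLOPs (one multiply and one add per output element).
        nnz_blocks = int(layout.sum())
        flops = 2.0 * nnz_blocks * block * block * N
        return flops * 1e-12 / (ms * 1e-3)  # FLOPs -> TFLOPS, ms -> seconds

    # With the 'dense' layout benchmarked above, every block is non-zero and the
    # count collapses to 2*M*N*K.
    M = N = K = 1024
    block = 32
    layout = torch.ones((1, M // block, K // block), dtype=torch.int64)
    print(dsd_tflops(layout, block, N, ms=1.0))  # ~2.15 TFLOPS if the op took 1 ms
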
@@ -85,4 +85,7 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
     op = triton.ops.blocksparse.softmax(layout, block)
     gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
     mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
-    return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
+    return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
+
+
+bench_matmul.run(print_data=True, show_plots=True)
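
In the second hunk, the gbps lambda converts a measured time in milliseconds into effective bandwidth, counting each element of a once for the read and once for the write (hence the factor 2). A small self-contained check of that arithmetic, illustrative only and not part of the commit:

    import torch

    a = torch.empty((4096, 4096), dtype=torch.float16)  # 32 MiB of half-precision data
    gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)

    print(gbps(0.1))  # reading and writing ~67 MB in 0.1 ms  ->  ~671 GB/s

The newly appended bench_matmul.run(print_data=True, show_plots=True) call means the matmul benchmark now executes when the script is run directly, printing its results table and showing the plot.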