[PYTHON] Added automated benchmark script (#63)
This adds a bench command to setup.py that runs the benchmark suite and generates a set of CSV files (and, optionally, plots):

python setup.py bench
python setup.py bench --with-plots
python setup.py bench --filter=cross_entropy
committed by Philippe Tillet
parent 66c94f21d7
commit 5e3c7f5a60
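The setup.py hook itself is not shown in this diff. As a rough sketch only (the command class, option handling, and result directory below are illustrative assumptions, not the committed implementation), the bench command could be a setuptools Command that discovers the python/bench/bench_*.py files added here and calls the .run() entry point that @triton.testing.perf_report attaches to each benchmark:

import os
import glob
import runpy
from setuptools import Command

class BenchCommand(Command):
    # hypothetical sketch; the real command lives in this commit's setup.py
    description = 'run the benchmark suite, writing CSV files (and optionally plots)'
    user_options = [
        ('with-plots', None, 'also save plots alongside the CSV files'),
        ('filter=', None, 'only run bench_*.py files whose name contains this string'),
        ('result-dir=', None, 'output directory (assumed option, for illustration)'),
    ]

    def initialize_options(self):
        self.with_plots = False
        self.filter = ''
        self.result_dir = 'results'

    def finalize_options(self):
        os.makedirs(self.result_dir, exist_ok=True)

    def run(self):
        for path in sorted(glob.glob(os.path.join('python', 'bench', 'bench_*.py'))):
            if self.filter and self.filter not in os.path.basename(path):
                continue
            # execute the file, then call .run() on each perf_report'ed function,
            # mirroring the bench_op.run('tmp', False) pattern in bench_cross_entropy.py
            scope = runpy.run_path(path)
            for name, obj in scope.items():
                if name.startswith('bench_') and hasattr(obj, 'run'):
                    obj.run(self.result_dir, self.with_plots)

# registered via setup(..., cmdclass={'bench': BenchCommand})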
python/bench/bench_blocksparse.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import torch
import triton

# -------------------------------
# Matrix Multiplication
# -------------------------------

nt = {False: 'n', True: 't'}
square_confs = [
    triton.testing.Benchmark(
        x_names = ['M', 'N', 'K'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        y_name = 'block',
        y_vals = [16, 32, 64],
        y_lines = ['Block16', 'Block32', 'Block64'],
        ylabel = 'TFLOPS',
        loglog = False,
        plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
        args = {'layout_mode': layout_mode, 'op_mode': op_mode,
                'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
    )
    for AT in [False] for BT in [False]
    for op_mode in ['sdd', 'dsd', 'dds'] for layout_mode in ['tril', 'dense']
]

@triton.testing.perf_report(square_confs)
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5):
    Z, H = 1, 1
    # layout generators: lower-triangular (e.g. causal masking) or fully dense
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    # create layout
    shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode]
    layout = make_layout(H, shape[0] // block, shape[1] // block)
    # create inputs
    a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda')
    b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda')
    # create op
    if provider == 'triton':
        op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT)
        # sparsify whichever operand the mode treats as block-sparse
        a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
        b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
        ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {
            'sdd': 2 * Z * K * float(layout.sum()) * block * block,
            'dsd': 2 * Z * N * float(layout.sum()) * block * block,
            'dds': 2 * Z * M * float(layout.sum()) * block * block,
        }[op_mode] * 1e-12
        triton_tflops = num_flops / ms * 1e3
        return triton_tflops
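The FLOP counts above assign 2 flops (one multiply, one add) to every scalar inside a nonzero block: float(layout.sum()) * block * block is the number of nonzero elements of the sparse operand ('dsd', 'dds') or output ('sdd'), and each is paired with the remaining dense dimension (K, N, or M). As a quick sanity check (illustrative values, not part of the file):

# 'sdd' with a dense layout should reduce to the dense GEMM count 2*M*N*K
M = N = K = 256; block = 16; Z = 1
nnz_elems = (M // block) * (N // block) * block * block  # = M * N
assert 2 * Z * K * nnz_elems == 2 * M * N * K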
# -------------------------------
# Softmax
# -------------------------------

square_confs = [
    triton.testing.Benchmark(
        x_names = ['M', 'N'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        y_name = 'block',
        y_vals = [16, 32, 64],
        y_lines = ['Block16', 'Block32', 'Block64'],
        ylabel = 'GBPS',
        loglog = False,
        plot_name = f'{layout_mode}-square',
        args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
    )
    for layout_mode in ['dense', 'tril']
]

@triton.testing.perf_report(square_confs)
def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    layout = make_layout(H, M // block, N // block)
    a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
    if provider == 'triton':
        a = triton.testing.sparsify_tensor(a, layout, block)
        op = triton.ops.blocksparse.softmax(layout, block)
        ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
        # effective bandwidth: one read + one write of the sparsified tensor
        gbps = (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
        return gbps
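Note that, unlike bench_cross_entropy.py below, this file has no __main__ entry point. Assuming the .run() signature matches the bench_op.run('tmp', False) call in that file (output directory, then a with-plots flag), a standalone driver would look like:

if __name__ == '__main__':
    # hypothetical driver; 'results' is an arbitrary output directory
    bench_matmul.run('results', False)
    bench_softmax.run('results', False)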
python/bench/bench_cross_entropy.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import torch
import triton

confs = [
    triton.testing.Benchmark(
        x_names = ['N'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
        y_name = 'provider',
        y_vals = ['triton', 'torch'],
        y_lines = ['Triton', 'Torch'],
        ylabel = 'GBPS',
        loglog = False,
        plot_name = f'{mode}-2048',
        args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
    )
    for mode in ['forward', 'backward']
]

@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    num_gb = (2 * x.numel() * x.element_size() * 1e-9)
    # forward pass
    op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
          'triton': triton.ops.cross_entropy}[provider]
    if mode == 'forward':
        ms = triton.testing.do_bench(lambda: op(x, idx))
    if mode == 'backward':
        y = op(x, idx)
        dy = torch.randn_like(y)
        ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True))
    return num_gb / ms * 1e3

if __name__ == '__main__':
    bench_op.run('tmp', False)
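The returned value is an effective bandwidth: num_gb models one read plus one write of x, and the ms reported by do_bench is converted to seconds. A quick check of the arithmetic at the largest configuration (illustrative, not part of the file):

M, N, element_size = 2048, 8192, 2        # fp16 elements are 2 bytes
num_gb = 2 * M * N * element_size * 1e-9  # one read + one write of x
assert abs(num_gb - 0.067108864) < 1e-12
# if do_bench reports ms = 0.1, the plotted value is num_gb / ms * 1e3 ≈ 671 GB/s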
python/bench/bench_matmul.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import os

import torch
import triton

# square benchmarks
nt = {False: 'n', True: 't'}
square_confs = [
    triton.testing.Benchmark(
        x_names = ['M', 'N', 'K'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        y_name = 'provider',
        y_vals = ['torch', 'triton', 'cutlass'],
        y_lines = ['Torch', 'Triton', 'CUTLASS'],
        ylabel = 'TFLOPS',
        loglog = False,
        plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
        args = {'AT': AT, 'BT': BT, 'dtype': torch.float16}
    )
    for AT in [False, True] for BT in [False, True]
]

@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
    if AT: a = a.t()
    if BT: b = b.t()
    num_flops = 2 * M * N * K
    if provider == 'torch':
        torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
        torch_tflops = num_flops / torch_ms * 1e-9
        return torch_tflops
    if provider == 'triton':
        triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
        triton_tflops = num_flops / triton_ms * 1e-9
        return triton_tflops
    if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
        import subprocess
        import tempfile
        import pandas as pd
        # run the profiler specified by the CUTLASS_PROFILER env variable
        layout_a = 'column' if AT else 'row'
        layout_b = 'column' if BT else 'row'
        # create temporary file name
        fd, fname = tempfile.mkstemp()
        os.close(fd)  # only the name is needed; the profiler writes {fname}.gemm.csv
        # run the profiler and collect its CSV output
        cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}',
               f'--A=f16:{layout_a}', f'--B=f16:{layout_b}',
               '--C=f16:column', '--accum=f32', '--operation=gemm',
               '--verification-enabled=false', f'--warmup-iterations={warmup}',
               f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
        subprocess.run(cmd, stdout=subprocess.PIPE)
        # read CSV output
        df_c = pd.read_csv(f'{fname}.gemm.csv')
        cutlass_tflops = max(df_c['GFLOPs']) / 1e3
        return cutlass_tflops
    return None

if __name__ == '__main__':
    bench_op.run()
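The cutlass provider is skipped unless the CUTLASS_PROFILER environment variable points at a built cutlass_profiler binary, e.g. (path illustrative):

CUTLASS_PROFILER=/path/to/cutlass/build/tools/profiler/cutlass_profiler python setup.py bench --filter=matmul

Taking max(df_c['GFLOPs']) keeps the fastest kernel configuration the profiler tried, converted from GFLOPs to TFLOPS so it is comparable with the torch and triton numbers.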