[PYTHON] CUTLASS wrapper for fair benchmarks (#75)

Before this commit, the benchmarking infrastructure used heterogeneous protocols between library (e.g., CUTLASS uses a C++ binary that reports mean TFLOPS; torch and triton use python call and report 10th, 50th and 90th quantiles). For the sake of uniformity and fair benchmark practices, this PR adds a python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to CUTLASS profiler. Importantly, this also ensures that all the matmuls are done on the *same* input data which should stabilize clock across providers.
This commit is contained in:
Philippe Tillet
2021-03-09 16:32:44 -05:00
committed by Philippe Tillet
parent d6f18742b1
commit eacbb73968
6 changed files with 257 additions and 41 deletions

View File

@@ -48,7 +48,7 @@ transformer_confs = [
]
@triton.testing.perf_report(transformer_confs)
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
@@ -62,37 +62,11 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
import subprocess
import tempfile
import pandas as pd
# run program specified by CUTLASS_PROFILER env variable
layout_a = "column" if AT else "row"
layout_b = "column" if BT else "row"
# create temporary file name
fd, fname = tempfile.mkstemp()
# run program and gets its output
cmd = [
os.environ["CUTLASS_PROFILER"],
f"--m={M}",
f"--n={N}",
f"--k={K}",
f"--A=f16:{layout_a}",
f"--B=f16:{layout_b}",
"--C=f16:column",
"--accum=f32",
"--operation=gemm",
"--verification-enabled=false",
f"--warmup-iterations={warmup}",
f"--profiling-iterations={rep}",
f"--output={fname}",
"--dist=uniform,min:0,max:1,scale:-1",
"--verbose=false",
]
# run cmd
subprocess.run(cmd, stdout=subprocess.PIPE)
# read CSV output
df_c = pd.read_csv(f"{fname}.gemm.csv")
tflops = max(df_c["GFLOPs"]) / 1e3
return tflops
if provider == "cutlass":
cutlass_matmul = triton.testing.cutlass_matmul
try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
except:
return None
return None