[PYTHON] CUTLASS wrapper for fair benchmarks (#75)
Before this commit, the benchmarking infrastructure used heterogeneous protocols across libraries (e.g., CUTLASS was measured through a C++ binary that reports mean TFLOPS, while torch and triton were measured through Python calls that report the 10th, 50th, and 90th quantiles). For the sake of uniformity and fair benchmarking practices, this PR adds a Python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to the CUTLASS profiler. Importantly, this also ensures that all the matmuls run on the *same* input data, which should stabilize GPU clocks across providers.
committed by Philippe Tillet
parent d6f18742b1
commit eacbb73968
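
For context, the measurement path is now identical for every provider: allocate the inputs once, then hand a closure to `triton.testing.do_bench`, which is unpacked as `ms, min_ms, max_ms` in the benchmark below. The following is a minimal sketch of that pattern, not part of the diff: the problem size and the `tflops()` conversion helper are illustrative assumptions, while `triton.testing.cutlass_matmul` is the wrapper introduced by this PR.

```python
import torch
import triton
import triton.ops
import triton.testing

# Illustrative sketch (not part of the diff): time both providers on the *same* inputs.
# M, N, K and tflops() are assumptions for the example; triton.testing.cutlass_matmul
# is the auto-tuned CUTLASS wrapper added by this PR.
M = N = K = 4096
a = torch.rand((M, K), device="cuda", dtype=torch.float16)
b = torch.rand((K, N), device="cuda", dtype=torch.float16)

def tflops(ms):
    # 2*M*N*K floating-point operations per GEMM, converted from milliseconds to TFLOP/s
    return 2 * M * N * K / (ms * 1e-3) / 1e12

for name, fn in [
    ("triton", lambda: triton.ops.matmul(a, b)),
    ("cutlass", lambda: triton.testing.cutlass_matmul(a, b)),
]:
    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=10, rep=50)
    # a larger time means a lower throughput, hence max_ms -> min TFLOPS
    print(f"{name}: {tflops(ms):.1f} TFLOPS (min {tflops(max_ms):.1f}, max {tflops(min_ms):.1f})")
```
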
```diff
@@ -48,7 +48,7 @@ transformer_confs = [
 ]
 
 
-@triton.testing.perf_report(transformer_confs)
+@triton.testing.perf_report(square_confs)
 def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
@@ -62,37 +62,11 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     if provider == "triton":
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
         return tflops(ms), tflops(max_ms), tflops(min_ms)
-    if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
-        import subprocess
-        import tempfile
-        import pandas as pd
-        # run program specified by CUTLASS_PROFILER env variable
-        layout_a = "column" if AT else "row"
-        layout_b = "column" if BT else "row"
-        # create temporary file name
-        fd, fname = tempfile.mkstemp()
-        # run program and gets its output
-        cmd = [
-            os.environ["CUTLASS_PROFILER"],
-            f"--m={M}",
-            f"--n={N}",
-            f"--k={K}",
-            f"--A=f16:{layout_a}",
-            f"--B=f16:{layout_b}",
-            "--C=f16:column",
-            "--accum=f32",
-            "--operation=gemm",
-            "--verification-enabled=false",
-            f"--warmup-iterations={warmup}",
-            f"--profiling-iterations={rep}",
-            f"--output={fname}",
-            "--dist=uniform,min:0,max:1,scale:-1",
-            "--verbose=false",
-        ]
-        # run cmd
-        subprocess.run(cmd, stdout=subprocess.PIPE)
-        # read CSV output
-        df_c = pd.read_csv(f"{fname}.gemm.csv")
-        tflops = max(df_c["GFLOPs"]) / 1e3
-        return tflops
+    if provider == "cutlass":
+        cutlass_matmul = triton.testing.cutlass_matmul
+        try:
+            ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
+            return tflops(ms), tflops(max_ms), tflops(min_ms)
+        except:
+            return None
     return None
```