[PYTHON] CUTLASS wrapper for fair benchmarks (#75)
Before this commit, the benchmarking infrastructure used heterogeneous protocols across libraries (e.g., CUTLASS was timed with a C++ binary that reported mean TFLOPS, while torch and triton were called from Python and reported the 10th, 50th, and 90th percentiles). For the sake of uniformity and fair benchmarking practices, this PR adds a Python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to the CUTLASS profiler. Importantly, this also ensures that all the matmuls run on the *same* input data, which should stabilize GPU clocks across providers.
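A minimal sketch of the uniform benchmarking flow this enables, assuming a Triton release contemporaneous with this PR: every provider is timed by `triton.testing.do_bench` on the same pre-allocated tensors. The `cutlass_matmul` name stands in for the new wrapper and is an assumption, not necessarily the binding's actual name; the availability of `triton.ops.matmul` and the exact return format of `do_bench` also vary across Triton versions.

```python
# Sketch only: time every provider on the *same* input tensors with
# triton.testing.do_bench, instead of shelling out to the CUTLASS profiler.
import torch
import triton
import triton.ops      # triton.ops.matmul exists in Triton releases of this era
import triton.testing

M, N, K = 4096, 4096, 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)

providers = {
    "torch": lambda: torch.matmul(a, b),
    "triton": lambda: triton.ops.matmul(a, b),
    # "cutlass": lambda: cutlass_matmul(a, b),  # wrapper added by this PR (name assumed)
}

for name, fn in providers.items():
    res = triton.testing.do_bench(fn)
    # Depending on the Triton version, do_bench returns a single time in ms
    # or a tuple of percentile times; take the first entry either way.
    ms = res[0] if isinstance(res, (tuple, list)) else res
    tflops = 2 * M * N * K / (ms * 1e-3) / 1e12
    print(f"{name}: {ms:.3f} ms ({tflops:.1f} TFLOPS)")
```

Because the tensors are allocated once and shared, every provider sees identical data and memory layout, which is what the commit message means by stabilizing clocks across providers.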
committed by Philippe Tillet
parent d6f18742b1
commit eacbb73968
@@ -3,10 +3,14 @@
 void init_superblocking(pybind11::module &m);
 void init_torch_utils(pybind11::module &m);
 void init_triton(pybind11::module &m);
+void init_cutlass(pybind11::module &m);
 
 PYBIND11_MODULE(libtriton, m) {
   m.doc() = "Python bindings to the C++ Triton API";
   init_triton(m);
   init_torch_utils(m);
   init_superblocking(m);
+#ifdef WITH_CUTLASS_BINDINGS
+  init_cutlass(m);
+#endif
 }
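The new `init_cutlass` registration is guarded by `#ifdef WITH_CUTLASS_BINDINGS`, so the CUTLASS bindings only exist when the extension was built with that flag. A minimal sketch of how Python-side code might probe for the optional bindings; the import path and attribute name are assumptions for illustration, not taken from this commit.

```python
# Sketch only: detect whether the optional CUTLASS bindings were compiled in.
HAS_CUTLASS = False
try:
    import triton._C.libtriton as _libtriton  # pybind11 module built from this file (path assumed)
    HAS_CUTLASS = hasattr(_libtriton, "cutlass_matmul")  # attribute name assumed
except ImportError:
    pass
```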