[PYTHON] CUTLASS wrapper for fair benchmarks (#75)

Before this commit, the benchmarking infrastructure used heterogeneous protocols between library (e.g., CUTLASS uses a C++ binary that reports mean TFLOPS; torch and triton use python call and report 10th, 50th and 90th quantiles). For the sake of uniformity and fair benchmark practices, this PR adds a python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to CUTLASS profiler. Importantly, this also ensures that all the matmuls are done on the *same* input data which should stabilize clock across providers.
This commit is contained in:
Philippe Tillet
2021-03-09 16:32:44 -05:00
committed by Philippe Tillet
parent d6f18742b1
commit eacbb73968
6 changed files with 257 additions and 41 deletions

View File

@@ -34,10 +34,19 @@ if(BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# PyBind11 wrapper source file
file(GLOB_RECURSE TORCH_SRC torch/*.cc)
# Build CUTLASS python wrapper if requested
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
set(TORCH_SRC ${TORCH_SRC} cutlass.cc)
add_definitions(-DWITH_CUTLASS_BINDINGS)
set(CUTLASS_LIBRARIES "cutlass")
endif()
message(STATUS ${CUTLASS_INCLUDE_PATH})
set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
include_directories("." ${PYTHON_INCLUDE_DIRS})
link_directories(${PYTHON_LINK_DIRS})
set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ${CUTLASS_OPT}")
include_directories("." ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
endif()
@@ -47,5 +56,5 @@ add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS})
if(BUILD_PYTHON_MODULE)
target_link_libraries(triton ${TORCH_LIBRARIES})
target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
endif()