[PYTHON] CUTLASS wrapper for fair benchmarks (#75)
Before this commit, the benchmarking infrastructure used heterogeneous protocols across libraries (e.g., CUTLASS used a C++ binary that reports mean TFLOPS, whereas torch and triton used Python calls and reported the 10th, 50th and 90th quantiles). For the sake of uniformity and fair benchmarking practices, this PR adds a Python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to the CUTLASS profiler. Importantly, this also ensures that all the matmuls are done on the *same* input data, which should stabilize the clock across providers.
This commit is contained in:
committed by
Philippe Tillet
parent
d6f18742b1
commit
eacbb73968
@@ -34,10 +34,19 @@ if(BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding Python module")
|
||||
# PyBind11 wrapper source file
|
||||
file(GLOB_RECURSE TORCH_SRC torch/*.cc)
|
||||
# Build CUTLASS python wrapper if requested
|
||||
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
|
||||
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
|
||||
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
|
||||
set(TORCH_SRC ${TORCH_SRC} cutlass.cc)
|
||||
add_definitions(-DWITH_CUTLASS_BINDINGS)
|
||||
set(CUTLASS_LIBRARIES "cutlass")
|
||||
endif()
|
||||
# Report the CUTLASS include directory read from the environment above.
# The variable set by this file is CUTLASS_INCLUDE_DIR (not *_PATH); quoting
# keeps the argument intact when the value is empty or contains ';'.
message(STATUS "CUTLASS include dir: ${CUTLASS_INCLUDE_DIR}")
|
||||
set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
|
||||
set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
|
||||
include_directories("." ${PYTHON_INCLUDE_DIRS})
|
||||
link_directories(${PYTHON_LINK_DIRS})
|
||||
set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ${CUTLASS_OPT}")
|
||||
include_directories("." ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
||||
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -47,5 +56,5 @@ add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS})
|
||||
|
||||
if(BUILD_PYTHON_MODULE)
|
||||
target_link_libraries(triton ${TORCH_LIBRARIES})
|
||||
target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
|
||||
endif()
|
||||
|
Reference in New Issue
Block a user