From eacbb7396897ab338b0d0081dad2e8efb8c491a6 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 9 Mar 2021 16:32:44 -0500
Subject: [PATCH] [PYTHON] CUTLASS wrapper for fair benchmarks (#75)

Before this commit, the benchmarking infrastructure used heterogeneous
protocols across libraries (e.g., CUTLASS was timed via a C++ binary that
reports mean TFLOPS, while torch and triton were timed via Python calls
that report the 10th, 50th, and 90th percentiles). For the sake of
uniformity and fair benchmarking practices, this PR adds a Python wrapper
for auto-tuned CUTLASS matrix multiplication. Benchmarks have been
rewritten to use this wrapper with `triton.testing.do_bench` rather than
system calls to the CUTLASS profiler. Importantly, this also ensures that
all the matmuls are run on the *same* input data, which should stabilize
clocks across providers.
---
 CMakeLists.txt               |  17 ++-
 python/bench/bench_matmul.py |  42 ++-----
 python/setup.py              |  15 ++-
 python/src/cutlass.cc        | 206 +++++++++++++++++++++++++++++++++++
 python/src/main.cc           |   4 +
 python/triton/testing.py     |  14 +++
 6 files changed, 257 insertions(+), 41 deletions(-)
 create mode 100644 python/src/cutlass.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a65f3fc41..98a20d6bc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -34,10 +34,19 @@ if(BUILD_PYTHON_MODULE)
   message(STATUS "Adding Python module")
   # PyBind11 wrapper source file
   file(GLOB_RECURSE TORCH_SRC torch/*.cc)
+  # Build CUTLASS python wrapper if requested
+  set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
+  set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
+  if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
+    set(TORCH_SRC ${TORCH_SRC} cutlass.cc)
+    add_definitions(-DWITH_CUTLASS_BINDINGS)
+    set(CUTLASS_LIBRARIES "cutlass")
+  endif()
+  message(STATUS ${CUTLASS_INCLUDE_DIR})
   set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
-  set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
-  include_directories("." ${PYTHON_INCLUDE_DIRS})
-  link_directories(${PYTHON_LINK_DIRS})
+  set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ${CUTLASS_OPT}")
+  include_directories("." ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
+  link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
 endif()
@@ -47,5 +56,5 @@
 add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
 target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS})
 if(BUILD_PYTHON_MODULE)
-  target_link_libraries(triton ${TORCH_LIBRARIES})
+  target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
 endif()
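With the hunk above, the CUTLASS bindings are strictly opt-in: they are compiled only when both environment variables are set at build time, e.g. `CUTLASS_INCLUDE_DIR=<cutlass>/include CUTLASS_LIBRARY_DIR=<dir containing libcutlass> pip install -e python` (paths illustrative). Otherwise the module builds exactly as before and the `libtriton.cutlass` submodule is simply absent.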
diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py
index d83c17c06..3648657c2 100644
--- a/python/bench/bench_matmul.py
+++ b/python/bench/bench_matmul.py
@@ -48,7 +48,7 @@ transformer_confs = [
 ]
 
 
-@triton.testing.perf_report(transformer_confs)
+@triton.testing.perf_report(square_confs)
 def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
@@ -62,37 +62,11 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     if provider == "triton":
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
         return tflops(ms), tflops(max_ms), tflops(min_ms)
-    if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
-        import subprocess
-        import tempfile
-        import pandas as pd
-        # run program specified by CUTLASS_PROFILER env variable
-        layout_a = "column" if AT else "row"
-        layout_b = "column" if BT else "row"
-        # create temporary file name
-        fd, fname = tempfile.mkstemp()
-        # run program and gets its output
-        cmd = [
-            os.environ["CUTLASS_PROFILER"],
-            f"--m={M}",
-            f"--n={N}",
-            f"--k={K}",
-            f"--A=f16:{layout_a}",
-            f"--B=f16:{layout_b}",
-            "--C=f16:column",
-            "--accum=f32",
-            "--operation=gemm",
-            "--verification-enabled=false",
-            f"--warmup-iterations={warmup}",
-            f"--profiling-iterations={rep}",
-            f"--output={fname}",
-            "--dist=uniform,min:0,max:1,scale:-1",
-            "--verbose=false",
-        ]
-        # run cmd
-        subprocess.run(cmd, stdout=subprocess.PIPE)
-        # read CSV output
-        df_c = pd.read_csv(f"{fname}.gemm.csv")
-        tflops = max(df_c["GFLOPs"]) / 1e3
-        return tflops
+    if provider == "cutlass":
+        cutlass_matmul = triton.testing.cutlass_matmul
+        try:
+            ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
+            return tflops(ms), tflops(max_ms), tflops(min_ms)
+        except Exception:
+            return None
     return None
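One subtlety in the hunk above: do_bench returns times, while perf_report expects throughputs, so the minimum and maximum swap places on conversion. A small sketch (timings and problem size are made up for illustration):

    # do_bench returns (mean_ms, min_ms, max_ms)
    mean_ms, min_ms, max_ms = 1.0, 0.8, 1.3  # assumed timings, in ms
    M = N = K = 4096                         # assumed problem size
    tflops = lambda ms: 2 * M * N * K / ms * 1e-9
    # the fastest run (min_ms) yields the highest TFLOPS, hence the swap
    assert tflops(min_ms) > tflops(mean_ms) > tflops(max_ms)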
".format(v=version) + instructions) + class CMakeExtension(Extension): def __init__(self, name, path, sourcedir=""): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) self.path = path + class CMakeBuild(build_ext): def run(self): try: out = subprocess.check_output(["cmake", "--version"]) except OSError: - raise RuntimeError("CMake must be installed to build the following extensions: " + - ", ".join(e.name for e in self.extensions)) + raise RuntimeError( + "CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions) + ) if platform.system() == "Windows": cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1)) @@ -92,6 +97,7 @@ class CMakeBuild(build_ext): subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp) + setup( name="triton", version="1.0.0", @@ -101,7 +107,10 @@ setup( long_description="", packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"], install_requires=["numpy", "torch"], - package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, + package_data={ + "triton/ops": ["*.c"], + "triton/ops/blocksparse": ["*.c"] + }, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild}, diff --git a/python/src/cutlass.cc b/python/src/cutlass.cc new file mode 100644 index 000000000..e680d83e2 --- /dev/null +++ b/python/src/cutlass.cc @@ -0,0 +1,206 @@ +#include "cutlass/library/handle.h" +#include "cutlass/library/library.h" +#include "cutlass/library/operation_table.h" +#include "cutlass/library/singleton.h" +#include "pybind11/pybind11.h" +#include "triton/tools/bench.hpp" +#include +#include + +using namespace cutlass; +using namespace cutlass::library; + +std::map, const Operation *> op_cache_; + +static int const kHostWorkspaceSize = (4 << 10); +static int const kDeviceWorkspaceSize = (4 << 20); + +void run(int M, int N, int K, + int lda, int ldb, int ldc, int ldd, + void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D, + void const *alpha, void const *beta, + ScalarPointerMode scalar_mode, + const Operation *operation, + cudaStream_t stream) { + + GemmUniversalConfiguration configuration{ + GemmUniversalMode::kGemm, + {M, N, K}, + 1, + lda, + ldb, + ldc, + ldd}; + + // host workspace size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) + throw std::runtime_error("Unable to find gemm operation"); + char host_workspace[kHostWorkspaceSize]; + + // device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed) + throw std::runtime_error("Unable to find gemm operation"); + static void *device_workspace; + + // Initialize host and device workspaces + Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream); + if (status != cutlass::Status::kSuccess) + throw std::runtime_error("Unable to initialize workspace"); + + // Run the operator + GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode}; + operation->run(&arguments, host_workspace, device_workspace, stream); +} + +const Operation *autotune(int M, int N, int K, + NumericTypeID element_compute, + NumericTypeID 
diff --git a/python/src/cutlass.cc b/python/src/cutlass.cc
new file mode 100644
index 000000000..e680d83e2
--- /dev/null
+++ b/python/src/cutlass.cc
@@ -0,0 +1,206 @@
+#include "cutlass/library/handle.h"
+#include "cutlass/library/library.h"
+#include "cutlass/library/operation_table.h"
+#include "cutlass/library/singleton.h"
+#include "pybind11/pybind11.h"
+#include "triton/tools/bench.hpp"
+#include <map>
+#include <torch/extension.h>
+
+using namespace cutlass;
+using namespace cutlass::library;
+
+std::map<std::vector<size_t>, const Operation *> op_cache_;
+
+static int const kHostWorkspaceSize = (4 << 10);
+static int const kDeviceWorkspaceSize = (4 << 20);
+
+void run(int M, int N, int K,
+         int lda, int ldb, int ldc, int ldd,
+         void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
+         void const *alpha, void const *beta,
+         ScalarPointerMode scalar_mode,
+         const Operation *operation,
+         cudaStream_t stream) {
+
+  GemmUniversalConfiguration configuration{
+      GemmUniversalMode::kGemm,
+      {M, N, K},
+      1,
+      lda,
+      ldb,
+      ldc,
+      ldd};
+
+  // host workspace size
+  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
+  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
+    throw std::runtime_error("Host workspace size exceeded");
+  char host_workspace[kHostWorkspaceSize];
+
+  // device workspace size
+  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
+  if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
+    throw std::runtime_error("Device workspace size exceeded");
+  static void *device_workspace;
+
+  // Initialize host and device workspaces
+  Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
+  if (status != cutlass::Status::kSuccess)
+    throw std::runtime_error("Unable to initialize workspace");
+
+  // Run the operator
+  GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
+  operation->run(&arguments, host_workspace, device_workspace, stream);
+}
+
+const Operation *autotune(int M, int N, int K,
+                          NumericTypeID element_compute,
+                          NumericTypeID element_scalar,
+                          void const *alpha,
+                          NumericTypeID element_A,
+                          LayoutTypeID layout_A,
+                          ComplexTransform transform_A,
+                          void const *ptr_A,
+                          int lda,
+                          NumericTypeID element_B,
+                          LayoutTypeID layout_B,
+                          ComplexTransform transform_B,
+                          void const *ptr_B,
+                          int ldb,
+                          void const *beta,
+                          NumericTypeID element_C,
+                          void const *ptr_C,
+                          int ldc,
+                          void *ptr_D,
+                          int ldd,
+                          ScalarPointerMode scalar_mode,
+                          int device_id,
+                          cudaStream_t stream) {
+
+  // index operation table with functional key
+  GemmFunctionalKey key(
+      Provider::kCUTLASS,
+      GemmKind::kUniversal,
+      element_compute,
+      element_scalar,
+      element_A,
+      layout_A,
+      transform_A,
+      element_B,
+      layout_B,
+      transform_B,
+      element_C);
+
+  auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
+  if (operators_it == Singleton::get().operation_table.gemm_operations.end())
+    throw std::runtime_error("Unable to find gemm operation");
+  if (operators_it->second.empty())
+    throw std::runtime_error("Unable to find gemm operation");
+
+  cudaDeviceProp device_prop;
+  cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id);
+  if (error != cudaSuccess)
+    throw std::runtime_error("Unable to get device properties");
+  int cc = device_prop.major * 10 + device_prop.minor;
+
+  // index operation table with preference key
+  // assume 8-bytes aligned memory pointers
+  int alignment = 8;
+  GemmPreferenceKey preference_key(cc, alignment);
+  auto autotune_it = operators_it->second.find(preference_key);
+  if (autotune_it == operators_it->second.end())
+    throw std::runtime_error("Unable to find gemm operation");
+  const std::vector<Operation const *> &operations = autotune_it->second;
+  if (operations.empty())
+    throw std::runtime_error("Unable to find gemm operation");
+
+  // auto-tune
+  const Operation *best = nullptr;
+  double best_ms = std::numeric_limits<double>::max();
+  for (const Operation *op : operations) {
+    auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D,
+                          alpha, beta, scalar_mode, op, stream); };
+    triton::driver::cu_stream tt_stream((CUstream)stream, false);
+    double ms = triton::tools::bench(fn, &tt_stream, 10, 25);
+    if (ms < best_ms) {
+      best_ms = ms;
+      best = op;
+    }
+  }
+  return best;
+}
+
+// map of torch datatypes to cutlass datatypes
+std::map<caffe2::TypeIdentifier, NumericTypeID> type_map = {
+    {caffe2::TypeMeta::Id<at::Half>(), NumericTypeID::kF16},
+    {caffe2::TypeMeta::Id<float>(), NumericTypeID::kF32},
+    {caffe2::TypeMeta::Id<double>(), NumericTypeID::kF64}};
+
+void cutlass_matmul(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
+  size_t M = A.size(0);
+  size_t N = B.size(1);
+  size_t K = A.size(1);
+  size_t lda = A.stride(0);
+  size_t ldb = B.stride(0);
+  size_t ldc = C.stride(1);
+  size_t ldd = C.stride(1);
+  void *ptr_A = A.data_ptr();
+  void *ptr_B = B.data_ptr();
+  void *ptr_C = C.data_ptr();
+  void *ptr_D = ptr_C;
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  // layout for A
+  LayoutTypeID layout_A;
+  if (A.stride(0) == 1)
+    layout_A = LayoutTypeID::kColumnMajor;
+  else if (A.stride(1) == 1)
+    layout_A = LayoutTypeID::kRowMajor;
+  else {
+    A = A.contiguous();
+    layout_A = LayoutTypeID::kRowMajor;
+  }
+  // layout for B
+  LayoutTypeID layout_B;
+  if (B.stride(0) == 1)
+    layout_B = LayoutTypeID::kColumnMajor;
+  else if (B.stride(1) == 1)
+    layout_B = LayoutTypeID::kRowMajor;
+  else {
+    B = B.contiguous();
+    layout_B = LayoutTypeID::kRowMajor;
+  }
+  // data types
+  NumericTypeID element_compute = NumericTypeID::kF32;
+  NumericTypeID element_A = type_map[A.dtype().id()];
+  NumericTypeID element_B = type_map[B.dtype().id()];
+  NumericTypeID element_C = type_map[C.dtype().id()];
+  // misc. flags
+  ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
+  NumericTypeID element_scalar = NumericTypeID::kF32;
+  ComplexTransform transform_A = ComplexTransform::kNone;
+  ComplexTransform transform_B = ComplexTransform::kNone;
+  // runtime flags
+  size_t dev_id = C.device().index();
+  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev_id).stream();
+  // auto-tune
+  std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
+                                  dev_id, (size_t)element_compute, (size_t)scalar_mode};
+  auto it = op_cache_.find(tune_key);
+  if (it == op_cache_.end()) {
+    const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha,
+                                   element_A, layout_A, transform_A, ptr_A, lda,
+                                   element_B, layout_B, transform_B, ptr_B, ldb,
+                                   &beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode,
+                                   dev_id, stream);
+    it = op_cache_.insert({tune_key, op}).first;
+  }
+  run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta,
+      scalar_mode, it->second, stream);
+}
+
+void init_cutlass(pybind11::module &m) {
+  pybind11::module subm = m.def_submodule("cutlass");
+  subm.def("matmul", &cutlass_matmul, "matrix multiplication");
+}
\ No newline at end of file
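The core of cutlass.cc is an autotune-and-cache pattern: the first call for a given problem key benchmarks every eligible CUTLASS kernel and caches the winner in op_cache_; later calls reuse it directly. A Python sketch of the same pattern (all names here are illustrative, not part of the bindings):

    import time

    op_cache = {}  # problem key (shape, dtypes, device, ...) -> fastest kernel

    def bench(fn, warmup=10, rep=25):
        # average runtime of fn over `rep` runs after `warmup` runs,
        # analogous to the triton::tools::bench call above
        for _ in range(warmup):
            fn()
        start = time.perf_counter()
        for _ in range(rep):
            fn()
        return (time.perf_counter() - start) / rep

    def cached_matmul(key, candidates):
        # first call for this key: time every candidate, keep the fastest
        if key not in op_cache:
            op_cache[key] = min(candidates, key=bench)
        # subsequent calls: run the cached winner directly
        return op_cache[key]()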
diff --git a/python/src/main.cc b/python/src/main.cc
index 73394a30c..1d664f8f8 100644
--- a/python/src/main.cc
+++ b/python/src/main.cc
@@ -3,10 +3,14 @@
 void init_superblocking(pybind11::module &m);
 void init_torch_utils(pybind11::module &m);
 void init_triton(pybind11::module &m);
+void init_cutlass(pybind11::module &m);
 
 PYBIND11_MODULE(libtriton, m) {
   m.doc() = "Python bindings to the C++ Triton API";
   init_triton(m);
   init_torch_utils(m);
   init_superblocking(m);
+#ifdef WITH_CUTLASS_BINDINGS
+  init_cutlass(m);
+#endif
 }
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 22c97cc7f..b3b64a498 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -1,6 +1,11 @@
 import torch
 import os
 
+try:
+    import triton._C.libtriton.cutlass as _cutlass
+except ImportError:
+    _cutlass = None
+
 
 def sparsify_tensor(x, mask, block):
     ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
@@ -9,6 +14,15 @@ def sparsify_tensor(x, mask, block):
     return ret
 
 
+def cutlass_matmul(a, b):
+    if _cutlass is None:
+        raise RuntimeError("Cannot find cutlass library")
+    M, N = a.shape[0], b.shape[1]
+    c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
+    _cutlass.matmul(a, b, c)
+    return c
+
+
 def mask_tensor(x, mask, block, value=0):
     ret = x.clone()
     for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
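End to end, the new benchmarking path looks as follows; a minimal sketch, assuming the bindings were compiled in, with an illustrative problem size (the tflops conversion mirrors bench_matmul.py):

    import torch
    import triton

    M, N, K = 2048, 2048, 2048  # illustrative problem size
    a = torch.rand((M, K), device="cuda", dtype=torch.float16)
    b = torch.rand((K, N), device="cuda", dtype=torch.float16)
    tflops = lambda ms: 2 * M * N * K / ms * 1e-9

    # Both providers are timed by the same harness on the *same* inputs.
    ms, _, _ = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=10, rep=50)
    print(f"triton:  {tflops(ms):.1f} TFLOPS")
    ms, _, _ = triton.testing.do_bench(lambda: triton.testing.cutlass_matmul(a, b), warmup=10, rep=50)
    print(f"cutlass: {tflops(ms):.1f} TFLOPS")

Note that the first cutlass_matmul call for a given shape triggers the autotuner; do_bench's warmup runs absorb that one-time cost.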