[PYTHON] CUTLASS wrapper for fair benchmarks (#75)
Before this commit, the benchmarking infrastructure used heterogeneous protocols between library (e.g., CUTLASS uses a C++ binary that reports mean TFLOPS; torch and triton use python call and report 10th, 50th and 90th quantiles). For the sake of uniformity and fair benchmark practices, this PR adds a python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to CUTLASS profiler. Importantly, this also ensures that all the matmuls are done on the *same* input data which should stabilize clock across providers.
This commit is contained in:
committed by
Philippe Tillet
parent
d6f18742b1
commit
eacbb73968
@@ -34,10 +34,19 @@ if(BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding Python module")
|
||||
# PyBind11 wrapper source file
|
||||
file(GLOB_RECURSE TORCH_SRC torch/*.cc)
|
||||
# Build CUTLASS python wrapper if requested
|
||||
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
|
||||
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
|
||||
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
|
||||
set(TORCH_SRC ${TORCH_SRC} cutlass.cc)
|
||||
add_definitions(-DWITH_CUTLASS_BINDINGS)
|
||||
set(CUTLASS_LIBRARIES "cutlass")
|
||||
endif()
|
||||
message(STATUS ${CUTLASS_INCLUDE_PATH})
|
||||
set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
|
||||
set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
|
||||
include_directories("." ${PYTHON_INCLUDE_DIRS})
|
||||
link_directories(${PYTHON_LINK_DIRS})
|
||||
set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ${CUTLASS_OPT}")
|
||||
include_directories("." ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
||||
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -47,5 +56,5 @@ add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS})
|
||||
|
||||
if(BUILD_PYTHON_MODULE)
|
||||
target_link_libraries(triton ${TORCH_LIBRARIES})
|
||||
target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
|
||||
endif()
|
||||
|
@@ -48,7 +48,7 @@ transformer_confs = [
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(transformer_confs)
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
|
||||
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
|
||||
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
|
||||
@@ -62,37 +62,11 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
|
||||
import subprocess
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
# run program specified by CUTLASS_PROFILER env variable
|
||||
layout_a = "column" if AT else "row"
|
||||
layout_b = "column" if BT else "row"
|
||||
# create temporary file name
|
||||
fd, fname = tempfile.mkstemp()
|
||||
# run program and gets its output
|
||||
cmd = [
|
||||
os.environ["CUTLASS_PROFILER"],
|
||||
f"--m={M}",
|
||||
f"--n={N}",
|
||||
f"--k={K}",
|
||||
f"--A=f16:{layout_a}",
|
||||
f"--B=f16:{layout_b}",
|
||||
"--C=f16:column",
|
||||
"--accum=f32",
|
||||
"--operation=gemm",
|
||||
"--verification-enabled=false",
|
||||
f"--warmup-iterations={warmup}",
|
||||
f"--profiling-iterations={rep}",
|
||||
f"--output={fname}",
|
||||
"--dist=uniform,min:0,max:1,scale:-1",
|
||||
"--verbose=false",
|
||||
]
|
||||
# run cmd
|
||||
subprocess.run(cmd, stdout=subprocess.PIPE)
|
||||
# read CSV output
|
||||
df_c = pd.read_csv(f"{fname}.gemm.csv")
|
||||
tflops = max(df_c["GFLOPs"]) / 1e3
|
||||
return tflops
|
||||
if provider == "cutlass":
|
||||
cutlass_matmul = triton.testing.cutlass_matmul
|
||||
try:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
except:
|
||||
return None
|
||||
return None
|
||||
|
@@ -14,6 +14,7 @@ from setuptools.command.test import test as TestCommand
|
||||
import distutils.spawn
|
||||
import torch
|
||||
|
||||
|
||||
def find_llvm():
|
||||
versions = ["-10", "-10.0", ""]
|
||||
supported = ["llvm-config{v}".format(v=v) for v in versions]
|
||||
@@ -28,19 +29,23 @@ def find_llvm():
|
||||
version = os.popen("{config} --version".format(config=config)).read()
|
||||
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
|
||||
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
def __init__(self, name, path, sourcedir=""):
|
||||
Extension.__init__(self, name, sources=[])
|
||||
self.sourcedir = os.path.abspath(sourcedir)
|
||||
self.path = path
|
||||
|
||||
|
||||
class CMakeBuild(build_ext):
|
||||
def run(self):
|
||||
try:
|
||||
out = subprocess.check_output(["cmake", "--version"])
|
||||
except OSError:
|
||||
raise RuntimeError("CMake must be installed to build the following extensions: " +
|
||||
", ".join(e.name for e in self.extensions))
|
||||
raise RuntimeError(
|
||||
"CMake must be installed to build the following extensions: " +
|
||||
", ".join(e.name for e in self.extensions)
|
||||
)
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
|
||||
@@ -92,6 +97,7 @@ class CMakeBuild(build_ext):
|
||||
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
||||
|
||||
|
||||
setup(
|
||||
name="triton",
|
||||
version="1.0.0",
|
||||
@@ -101,7 +107,10 @@ setup(
|
||||
long_description="",
|
||||
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
|
||||
install_requires=["numpy", "torch"],
|
||||
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
|
||||
package_data={
|
||||
"triton/ops": ["*.c"],
|
||||
"triton/ops/blocksparse": ["*.c"]
|
||||
},
|
||||
include_package_data=True,
|
||||
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
||||
cmdclass={"build_ext": CMakeBuild},
|
||||
|
206
python/src/cutlass.cc
Normal file
206
python/src/cutlass.cc
Normal file
@@ -0,0 +1,206 @@
|
||||
#include "cutlass/library/handle.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/operation_table.h"
|
||||
#include "cutlass/library/singleton.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
using namespace cutlass;
|
||||
using namespace cutlass::library;
|
||||
|
||||
std::map<std::vector<size_t>, const Operation *> op_cache_;
|
||||
|
||||
static int const kHostWorkspaceSize = (4 << 10);
|
||||
static int const kDeviceWorkspaceSize = (4 << 20);
|
||||
|
||||
void run(int M, int N, int K,
|
||||
int lda, int ldb, int ldc, int ldd,
|
||||
void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
|
||||
void const *alpha, void const *beta,
|
||||
ScalarPointerMode scalar_mode,
|
||||
const Operation *operation,
|
||||
cudaStream_t stream) {
|
||||
|
||||
GemmUniversalConfiguration configuration{
|
||||
GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
1,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd};
|
||||
|
||||
// host workspace size
|
||||
uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
|
||||
if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
char host_workspace[kHostWorkspaceSize];
|
||||
|
||||
// device workspace size
|
||||
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
|
||||
if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
static void *device_workspace;
|
||||
|
||||
// Initialize host and device workspaces
|
||||
Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
|
||||
if (status != cutlass::Status::kSuccess)
|
||||
throw std::runtime_error("Unable to initialize workspace");
|
||||
|
||||
// Run the operator
|
||||
GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
|
||||
operation->run(&arguments, host_workspace, device_workspace, stream);
|
||||
}
|
||||
|
||||
const Operation *autotune(int M, int N, int K,
|
||||
NumericTypeID element_compute,
|
||||
NumericTypeID element_scalar,
|
||||
void const *alpha,
|
||||
NumericTypeID element_A,
|
||||
LayoutTypeID layout_A,
|
||||
ComplexTransform transform_A,
|
||||
void const *ptr_A,
|
||||
int lda,
|
||||
NumericTypeID element_B,
|
||||
LayoutTypeID layout_B,
|
||||
ComplexTransform transform_B,
|
||||
void const *ptr_B,
|
||||
int ldb,
|
||||
void const *beta,
|
||||
NumericTypeID element_C,
|
||||
void const *ptr_C,
|
||||
int ldc,
|
||||
void *ptr_D,
|
||||
int ldd,
|
||||
ScalarPointerMode scalar_mode,
|
||||
int device_id,
|
||||
cudaStream_t stream) {
|
||||
|
||||
// index operation table with functional key
|
||||
GemmFunctionalKey key(
|
||||
Provider::kCUTLASS,
|
||||
GemmKind::kUniversal,
|
||||
element_compute,
|
||||
element_scalar,
|
||||
element_A,
|
||||
layout_A,
|
||||
transform_A,
|
||||
element_B,
|
||||
layout_B,
|
||||
transform_B,
|
||||
element_C);
|
||||
|
||||
auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
|
||||
if (operators_it == Singleton::get().operation_table.gemm_operations.end())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
if (operators_it->second.empty())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
|
||||
cudaDeviceProp device_prop;
|
||||
cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id);
|
||||
if (error != cudaSuccess)
|
||||
throw std::runtime_error("Unable to get device properties");
|
||||
int cc = device_prop.major * 10 + device_prop.minor;
|
||||
|
||||
// index operation table with preference key
|
||||
// assume 8-bytes aligned memory pointers
|
||||
int alignment = 8;
|
||||
GemmPreferenceKey preference_key(cc, alignment);
|
||||
auto autotune_it = operators_it->second.find(preference_key);
|
||||
if (autotune_it == operators_it->second.end())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
const std::vector<const Operation *> &operations = autotune_it->second;
|
||||
if (operations.empty())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
|
||||
// auto-tune
|
||||
const Operation *best = nullptr;
|
||||
double best_ms = std::numeric_limits<double>::max();
|
||||
for (const Operation *op : operations) {
|
||||
auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D,
|
||||
alpha, beta, scalar_mode, op, stream); };
|
||||
triton::driver::cu_stream tt_stream((CUstream)stream, false);
|
||||
double ms = triton::tools::bench(fn, &tt_stream, 10, 25);
|
||||
if (ms < best_ms) {
|
||||
best_ms = ms;
|
||||
best = op;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
|
||||
// map of torch datatypes to cutlass datatypes
|
||||
std::map<caffe2::TypeIdentifier, NumericTypeID> type_map = {
|
||||
{caffe2::TypeMeta::Id<at::Half>(), NumericTypeID::kF16},
|
||||
{caffe2::TypeMeta::Id<float>(), NumericTypeID::kF32},
|
||||
{caffe2::TypeMeta::Id<double>(), NumericTypeID::kF64}};
|
||||
|
||||
void cutlass_matmul(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
|
||||
size_t M = A.size(0);
|
||||
size_t N = B.size(1);
|
||||
size_t K = A.size(1);
|
||||
size_t lda = A.stride(0);
|
||||
size_t ldb = B.stride(0);
|
||||
size_t ldc = C.stride(1);
|
||||
size_t ldd = C.stride(1);
|
||||
void *ptr_A = A.data_ptr();
|
||||
void *ptr_B = B.data_ptr();
|
||||
void *ptr_C = C.data_ptr();
|
||||
void *ptr_D = ptr_C;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
// layout for A
|
||||
LayoutTypeID layout_A;
|
||||
if (A.stride(0) == 1)
|
||||
layout_A = LayoutTypeID::kColumnMajor;
|
||||
else if (A.stride(1) == 1)
|
||||
layout_A = LayoutTypeID::kRowMajor;
|
||||
else {
|
||||
A = A.contiguous();
|
||||
layout_A = LayoutTypeID::kRowMajor;
|
||||
}
|
||||
// layout for B
|
||||
LayoutTypeID layout_B;
|
||||
if (B.stride(0) == 1)
|
||||
layout_B = LayoutTypeID::kColumnMajor;
|
||||
else if (B.stride(1) == 1)
|
||||
layout_B = LayoutTypeID::kRowMajor;
|
||||
else {
|
||||
B = B.contiguous();
|
||||
layout_B = LayoutTypeID::kRowMajor;
|
||||
}
|
||||
// data types
|
||||
NumericTypeID element_compute = NumericTypeID::kF32;
|
||||
NumericTypeID element_A = type_map[A.dtype().id()];
|
||||
NumericTypeID element_B = type_map[B.dtype().id()];
|
||||
NumericTypeID element_C = type_map[C.dtype().id()];
|
||||
// misc. flags
|
||||
ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
|
||||
NumericTypeID element_scalar = NumericTypeID::kF32;
|
||||
ComplexTransform transform_A = ComplexTransform::kNone;
|
||||
ComplexTransform transform_B = ComplexTransform::kNone;
|
||||
// runtime flags
|
||||
size_t dev_id = C.device().index();
|
||||
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
// auto-tune
|
||||
std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
|
||||
dev_id, (size_t)element_compute, (size_t)scalar_mode};
|
||||
auto it = op_cache_.find(tune_key);
|
||||
if (it == op_cache_.end()) {
|
||||
const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha,
|
||||
element_A, layout_A, transform_A, ptr_A, lda,
|
||||
element_B, layout_B, transform_B, ptr_B, ldb,
|
||||
&beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode,
|
||||
dev_id, stream);
|
||||
it = op_cache_.insert({tune_key, op}).first;
|
||||
}
|
||||
run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta,
|
||||
scalar_mode, it->second, stream);
|
||||
}
|
||||
|
||||
void init_cutlass(pybind11::module &m) {
|
||||
pybind11::module subm = m.def_submodule("cutlass");
|
||||
subm.def("matmul", &cutlass_matmul, "matrix multiplication");
|
||||
}
|
@@ -3,10 +3,14 @@
|
||||
void init_superblocking(pybind11::module &m);
|
||||
void init_torch_utils(pybind11::module &m);
|
||||
void init_triton(pybind11::module &m);
|
||||
void init_cutlass(pybind11::module &m);
|
||||
|
||||
PYBIND11_MODULE(libtriton, m) {
|
||||
m.doc() = "Python bindings to the C++ Triton API";
|
||||
init_triton(m);
|
||||
init_torch_utils(m);
|
||||
init_superblocking(m);
|
||||
#ifdef WITH_CUTLASS_BINDINGS
|
||||
init_cutlass(m);
|
||||
#endif
|
||||
}
|
||||
|
@@ -1,6 +1,11 @@
|
||||
import torch
|
||||
import os
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
except ImportError:
|
||||
_cutlass = None
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
@@ -9,6 +14,15 @@ def sparsify_tensor(x, mask, block):
|
||||
return ret
|
||||
|
||||
|
||||
def cutlass_matmul(a, b):
|
||||
if _cutlass is None:
|
||||
raise RuntimeError("Cannot find cutlass library")
|
||||
M, N = a.shape[0], b.shape[1]
|
||||
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
|
||||
_cutlass.matmul(a, b, c)
|
||||
return c
|
||||
|
||||
|
||||
def mask_tensor(x, mask, block, value=0):
|
||||
ret = x.clone()
|
||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||
|
Reference in New Issue
Block a user