[PYTHON] CUTLASS wrapper for fair benchmarks (#75)

Before this commit, the benchmarking infrastructure used heterogeneous protocols between library (e.g., CUTLASS uses a C++ binary that reports mean TFLOPS; torch and triton use python call and report 10th, 50th and 90th quantiles). For the sake of uniformity and fair benchmark practices, this PR adds a python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to CUTLASS profiler. Importantly, this also ensures that all the matmuls are done on the *same* input data which should stabilize clock across providers.
This commit is contained in:
Philippe Tillet
2021-03-09 16:32:44 -05:00
committed by Philippe Tillet
parent d6f18742b1
commit eacbb73968
6 changed files with 257 additions and 41 deletions

View File

@@ -14,6 +14,7 @@ from setuptools.command.test import test as TestCommand
import distutils.spawn
import torch
def find_llvm():
versions = ["-10", "-10.0", ""]
supported = ["llvm-config{v}".format(v=v) for v in versions]
@@ -28,19 +29,23 @@ def find_llvm():
version = os.popen("{config} --version".format(config=config)).read()
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.path = path
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
raise RuntimeError(
"CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions)
)
if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
@@ -92,6 +97,7 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
setup(
name="triton",
version="1.0.0",
@@ -101,7 +107,10 @@ setup(
long_description="",
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
install_requires=["numpy", "torch"],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
package_data={
"triton/ops": ["*.c"],
"triton/ops/blocksparse": ["*.c"]
},
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild},