[PYTHON] CUTLASS wrapper for fair benchmarks (#75)
Before this commit, the benchmarking infrastructure used heterogeneous protocols across libraries (e.g., CUTLASS used a C++ binary that reports mean TFLOPS, whereas torch and triton used Python calls and reported the 10th, 50th and 90th percentiles). For the sake of uniformity and fair benchmarking practices, this PR adds a Python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to the CUTLASS profiler. Importantly, this also ensures that all the matmuls are done on the *same* input data, which should stabilize clock speeds across providers.
This commit is contained in:
committed by
Philippe Tillet
parent
d6f18742b1
commit
eacbb73968
@@ -14,6 +14,7 @@ from setuptools.command.test import test as TestCommand
|
||||
import distutils.spawn
|
||||
import torch
|
||||
|
||||
|
||||
def find_llvm():
|
||||
versions = ["-10", "-10.0", ""]
|
||||
supported = ["llvm-config{v}".format(v=v) for v in versions]
|
||||
@@ -28,19 +29,23 @@ def find_llvm():
|
||||
version = os.popen("{config} --version".format(config=config)).read()
|
||||
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
|
||||
|
||||
|
||||
class CMakeExtension(Extension):
    """Extension declared with an empty source list.

    Besides the extension name, each instance records ``path`` unchanged
    and ``sourcedir`` converted to an absolute path.
    """

    def __init__(self, name, path, sourcedir=""):
        # No sources are handed to the base class; presumably the actual
        # build is driven by CMake elsewhere -- confirm against the
        # build_ext command that consumes these attributes.
        super().__init__(name, sources=[])
        self.path = path
        # Resolved against the current working directory at construction.
        self.sourcedir = os.path.abspath(sourcedir)
|
||||
|
||||
|
||||
class CMakeBuild(build_ext):
|
||||
def run(self):
|
||||
try:
|
||||
out = subprocess.check_output(["cmake", "--version"])
|
||||
except OSError:
|
||||
raise RuntimeError("CMake must be installed to build the following extensions: " +
|
||||
", ".join(e.name for e in self.extensions))
|
||||
raise RuntimeError(
|
||||
"CMake must be installed to build the following extensions: " +
|
||||
", ".join(e.name for e in self.extensions)
|
||||
)
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
|
||||
@@ -92,6 +97,7 @@ class CMakeBuild(build_ext):
|
||||
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
||||
|
||||
|
||||
setup(
|
||||
name="triton",
|
||||
version="1.0.0",
|
||||
@@ -101,7 +107,10 @@ setup(
|
||||
long_description="",
|
||||
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
|
||||
install_requires=["numpy", "torch"],
|
||||
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
|
||||
package_data={
|
||||
"triton/ops": ["*.c"],
|
||||
"triton/ops/blocksparse": ["*.c"]
|
||||
},
|
||||
include_package_data=True,
|
||||
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
||||
cmdclass={"build_ext": CMakeBuild},
|
||||
|
Reference in New Issue
Block a user