Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
Philippe Tillet
2022-12-21 01:30:50 -08:00
committed by GitHub
parent 8650b4d1cb
commit 20100a7254
285 changed files with 26312 additions and 50143 deletions

View File

@@ -1,5 +1,4 @@
import distutils
import distutils.spawn
import os
import platform
import re
@@ -25,42 +24,54 @@ def get_build_type():
return "Debug"
elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo"
elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"):
return "TritonRelBuildWithAsserts"
else:
return "Release"
# TODO: change to release when stable enough
return "TritonRelBuildWithAsserts"
def use_system_llvm():
    """Report whether a system-wide LLVM 11 installation should be used.

    Returns:
        bool: always ``True`` on Windows; on other systems ``True`` when at
        least one of ``llvm-config-11.0``, ``llvm-config-11`` or
        ``llvm-config-11-64`` can be found on ``PATH``, ``False`` otherwise.
    """
    if platform.system() == "Windows":
        return True
    versions = ['-11.0', '-11', '-11-64']
    # shutil.which replaces distutils.spawn.find_executable: distutils is
    # deprecated (PEP 632) and is no longer shipped with Python 3.12+.
    return any(
        shutil.which('llvm-config{v}'.format(v=v)) is not None
        for v in versions
    )
# --- third party packages -----
class Package(NamedTuple):
    """Description of one third-party dependency needed by the build."""

    # Short identifier; used as the sub-directory name under the Triton cache.
    package: str
    # Versioned directory name of the unpacked package inside that sub-directory.
    name: str
    # Location of the package archive.
    url: str
    # Path, relative to the package directory, whose existence is checked to
    # decide whether the package is already present.
    test_file: str
    # Name of the CMake variable for the include directory
    # (NOTE(review): its use is outside the visible code -- confirm).
    include_flag: str
    # Name of the CMake variable that is pointed at "<package dir>/lib".
    lib_flag: str
    # Environment variable that, when set, overrides the package directory.
    syspath_var_name: str
def get_pybind11_package_info():
    """Return the ``Package`` record describing the pinned pybind11 release.

    pybind11 is header-only here, so the library flag is left empty.
    """
    return Package(
        package="pybind11",
        name="pybind11-2.10.0",
        url="https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz",
        test_file="include/pybind11/pybind11.h",
        include_flag="PYBIND11_INCLUDE_DIR",
        lib_flag="",
        syspath_var_name="PYBIND11_SYSPATH",
    )
def get_llvm_package_info():
    """Return the ``Package`` record for the prebuilt LLVM 14 archive.

    The archive is chosen from the host operating system and from the
    ``TRITON_USE_ASSERT_ENABLED_LLVM`` environment flag: when the flag is
    set, an assert-enabled ``llvm+mlir`` build is selected, otherwise the
    official ``clang+llvm`` release. Both archive names are x86_64 builds.

    Returns:
        Package: description of the LLVM archive to use.

    Raises:
        RuntimeError: if the host system is neither Linux nor Darwin, since
            no prebuilt archive name is known for it.
    """
    system = platform.system()
    system_suffixes = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}
    if system not in system_suffixes:
        # Fail with an actionable message instead of a bare KeyError.
        raise RuntimeError(
            "No prebuilt LLVM package is available for platform {!r}; "
            "supported platforms are: {}".format(system, ", ".join(sorted(system_suffixes)))
        )
    system_suffix = system_suffixes[system]
    use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
    if use_assert_enabled_llvm:
        name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix)
        url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name)
    else:
        name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
        url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
def get_thirdparty_packages(triton_cache_path):
class Package(NamedTuple):
package: str
name: str
url: str
test_file: str
include_flag: str
lib_flag: str
packages = [
Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
]
if not use_system_llvm():
# download LLVM if no suitable system LLVM is installed
packages.append(
Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
)
packages = [get_pybind11_package_info(), get_llvm_package_info()]
thirdparty_cmake_args = []
for p in packages:
package_root_dir = os.path.join(triton_cache_path, p.package)
package_dir = os.path.join(package_root_dir, p.name)
test_file_path = os.path.join(package_dir, p.test_file)
if p.syspath_var_name in os.environ:
package_dir = os.environ[p.syspath_var_name]
if not os.path.exists(test_file_path):
try:
shutil.rmtree(package_root_dir)
@@ -77,6 +88,8 @@ def get_thirdparty_packages(triton_cache_path):
thirdparty_cmake_args.append("-D{}={}/lib".format(p.lib_flag, package_dir))
return thirdparty_cmake_args
# ---- cmake extension ----
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=""):
@@ -113,22 +126,27 @@ class CMakeBuild(build_ext):
self.build_extension(ext)
def build_extension(self, ext):
lit_dir = shutil.which('lit')
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
# lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
# python directories
python_include_dirs = [distutils.sysconfig.get_python_inc()]
python_include_dir = distutils.sysconfig.get_python_inc()
cmake_args = [
"-DLLVM_ENABLE_WERROR=ON",
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON",
"-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON",
# '-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DLLVM_EXTERNAL_LIT=" + lit_dir
] + thirdparty_cmake_args
# configuration
cfg = get_build_type()
build_args = ["--config", cfg]
@@ -155,16 +173,17 @@ setup(
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/runtime", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
install_requires=[
"cmake",
"filelock",
"torch",
"lit",
],
package_data={
"triton/ops": ["*.c"],
"triton/ops/blocksparse": ["*.c"],
"triton/language": ["*.bc"],
"triton/language": ["*.bc"]
},
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
@@ -180,6 +199,7 @@ setup(
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.6",
],
test_suite="tests",
extras_require={
"tests": [
"autopep8",