2022-01-06 14:34:17 -08:00
|
|
|
import distutils
|
2021-07-27 12:38:38 -07:00
|
|
|
import os
|
|
|
|
import platform
|
2022-01-06 14:34:17 -08:00
|
|
|
import re
|
|
|
|
import shutil
|
2021-07-27 12:38:38 -07:00
|
|
|
import subprocess
|
2022-01-06 14:34:17 -08:00
|
|
|
import sys
|
|
|
|
import tarfile
|
|
|
|
import urllib.request
|
2021-07-27 12:38:38 -07:00
|
|
|
from distutils.version import LooseVersion
|
2022-09-23 15:54:07 -07:00
|
|
|
from typing import NamedTuple
|
2022-01-06 14:34:17 -08:00
|
|
|
|
|
|
|
from setuptools import Extension, setup
|
2021-07-27 12:38:38 -07:00
|
|
|
from setuptools.command.build_ext import build_ext
|
2022-01-06 14:34:17 -08:00
|
|
|
|
2021-03-09 16:32:44 -05:00
|
|
|
|
2022-07-01 12:17:22 -07:00
|
|
|
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
|
|
|
|
def check_env_flag(name: str, default: str = "") -> bool:
    """Return True when environment variable *name* holds a truthy value.

    The variable's value (or *default* when the variable is unset) is
    upper-cased and compared against ON, 1, YES, TRUE and Y.
    """
    truthy_values = ["ON", "1", "YES", "TRUE", "Y"]
    raw_value = os.getenv(name, default)
    return raw_value.upper() in truthy_values
|
|
|
|
|
|
|
|
|
|
|
|
def get_build_type():
    """Pick the CMake build type name from build-selection environment flags.

    The flags DEBUG, REL_WITH_DEB_INFO and TRITON_REL_BUILD_WITH_ASSERTS are
    tested in that order and the first one that is set wins. When none is
    set, "TritonRelBuildWithAsserts" is returned.
    """
    flag_to_build_type = (
        ("DEBUG", "Debug"),
        ("REL_WITH_DEB_INFO", "RelWithDebInfo"),
        ("TRITON_REL_BUILD_WITH_ASSERTS", "TritonRelBuildWithAsserts"),
    )
    for flag_name, build_type in flag_to_build_type:
        if check_env_flag(flag_name):
            return build_type
    # TODO: change to release when stable enough
    return "TritonRelBuildWithAsserts"
|
2022-07-01 12:17:22 -07:00
|
|
|
|
|
|
|
|
2022-12-21 01:30:50 -08:00
|
|
|
# --- third party packages -----
|
2022-09-23 15:54:07 -07:00
|
|
|
|
2022-12-21 01:30:50 -08:00
|
|
|
class Package(NamedTuple):
    """Description of a third-party dependency needed by the CMake build."""

    # Sub-directory of the Triton cache directory that holds this dependency.
    package: str
    # Directory name of the dependency inside that sub-directory.
    name: str
    # Location of the tar archive to download.
    url: str
    # Path, relative to the dependency directory, whose existence is checked
    # to decide whether the dependency is already present.
    test_file: str
    # CMake variable that receives "<dir>/include"; empty string to skip.
    include_flag: str
    # CMake variable that receives "<dir>/lib"; empty string to skip.
    lib_flag: str
    # Environment variable that, when set, overrides the dependency directory.
    syspath_var_name: str
|
2022-09-23 15:54:07 -07:00
|
|
|
|
|
|
|
|
2022-12-21 01:30:50 -08:00
|
|
|
def get_pybind11_package_info():
    """Return the Package record describing the pybind11 2.10.0 source archive.

    pybind11 only contributes an include directory, so its lib flag is empty.
    """
    return Package(
        package="pybind11",
        name="pybind11-2.10.0",
        url="https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz",
        test_file="include/pybind11/pybind11.h",
        include_flag="PYBIND11_INCLUDE_DIR",
        lib_flag="",
        syspath_var_name="PYBIND11_SYSPATH",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def get_llvm_package_info():
    """Return the Package record for the prebuilt LLVM 14.0.0 archive.

    The archive name depends on ``platform.system()``. When the
    TRITON_USE_ASSERT_ENABLED_LLVM environment flag is set, an
    assert-enabled "llvm+mlir" build is selected; otherwise the official
    "clang+llvm" release archive is used.

    Raises:
        RuntimeError: if the current platform is neither Linux nor Darwin,
            since no prebuilt archive name is known for it.
    """
    # download if nothing is installed
    system = platform.system()
    system_suffixes = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}
    if system not in system_suffixes:
        # A bare dict lookup would surface as an unexplained KeyError.
        raise RuntimeError(
            "No prebuilt LLVM package is available for platform '{}' (supported: {})".format(
                system, ", ".join(sorted(system_suffixes))
            )
        )
    system_suffix = system_suffixes[system]
    # NOTE(review): the architecture is hard-coded to x86_64 in both archive
    # names below; other architectures would fetch an x86_64 build.
    use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
    if use_assert_enabled_llvm:
        name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix)
        url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name)
    else:
        name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
        url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
|
|
|
|
|
|
|
|
|
|
|
def get_thirdparty_packages(triton_cache_path):
    """Locate (downloading if necessary) third-party packages for the build.

    For each package, the directory named by the package's ``*_SYSPATH``
    environment variable is used when that variable is set; nothing is
    downloaded in that case. Otherwise the package is looked up under
    ``triton_cache_path`` and, when its test file is missing, the cached
    copy is wiped and the archive is downloaded and extracted again.

    Args:
        triton_cache_path: directory under which packages are cached.

    Returns:
        A list of ``-D<flag>=<dir>/include`` and ``-D<flag>=<dir>/lib``
        CMake arguments, one per non-empty flag of each package.
    """
    packages = [get_pybind11_package_info(), get_llvm_package_info()]
    thirdparty_cmake_args = []
    for p in packages:
        if p.syspath_var_name in os.environ:
            # User-supplied installation: use it as-is. Checking the cache
            # here would trigger a pointless download whose result is unused.
            package_dir = os.environ[p.syspath_var_name]
        else:
            package_root_dir = os.path.join(triton_cache_path, p.package)
            package_dir = os.path.join(package_root_dir, p.name)
            test_file_path = os.path.join(package_dir, p.test_file)
            if not os.path.exists(test_file_path):
                # Best-effort removal of a stale or partial cached copy.
                shutil.rmtree(package_root_dir, ignore_errors=True)
                os.makedirs(package_root_dir, exist_ok=True)
                print('downloading and extracting {} ...'.format(p.url))
                # Stream the archive straight from the HTTP response; the
                # context managers close both handles even on failure.
                with urllib.request.urlopen(p.url) as response:
                    with tarfile.open(fileobj=response, mode="r|*") as archive:
                        archive.extractall(path=package_root_dir)
        if p.include_flag:
            thirdparty_cmake_args.append("-D{}={}/include".format(p.include_flag, package_dir))
        if p.lib_flag:
            thirdparty_cmake_args.append("-D{}={}/lib".format(p.lib_flag, package_dir))
    return thirdparty_cmake_args
|
2021-07-28 01:02:31 -07:00
|
|
|
|
2022-12-21 01:30:50 -08:00
|
|
|
# ---- cmake extension ----
|
|
|
|
|
2021-07-28 01:02:31 -07:00
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
class CMakeExtension(Extension):
    """Extension with no sources of its own; it is built by CMake instead."""

    def __init__(self, name, path, sourcedir=""):
        """Record the extension path and the absolute source directory.

        Args:
            name: extension name forwarded to ``Extension``.
            path: stored unchanged as ``self.path``.
            sourcedir: source directory; stored as an absolute path.
        """
        # The source list is empty: nothing is compiled by setuptools itself.
        super().__init__(name, sources=[])
        self.path = path
        self.sourcedir = os.path.abspath(sourcedir)
|
|
|
|
|
2021-03-09 16:32:44 -05:00
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
class CMakeBuild(build_ext):
    """``build_ext`` command that configures and builds extensions with CMake."""

    user_options = build_ext.user_options + [('base-dir=', None, 'base directory of Triton')]

    def initialize_options(self):
        """Set defaults; ``base_dir`` is the parent of this file's directory."""
        build_ext.initialize_options(self)
        self.base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))

    def finalize_options(self):
        """Delegate to ``build_ext``; no extra processing is needed."""
        build_ext.finalize_options(self)

    def run(self):
        """Check that CMake is usable, then build every extension.

        Raises:
            RuntimeError: if ``cmake`` cannot be executed, or if on Windows
                the detected CMake version is older than 3.1.0.
        """
        try:
            out = subprocess.check_output(["cmake", "--version"])
        except OSError as err:
            raise RuntimeError(
                "CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions)
            ) from err

        if platform.system() == "Windows":
            cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
            if cmake_version < "3.1.0":
                raise RuntimeError("CMake >= 3.1.0 is required on Windows")

        for ext in self.extensions:
            self.build_extension(ext)

    def build_extension(self, ext):
        """Configure and build *ext* by invoking CMake in ``self.build_temp``.

        Third-party packages are resolved (and downloaded when missing) under
        ``~/.triton`` before CMake is configured.
        """
        # Import the submodule explicitly; a bare ``import distutils`` does
        # not guarantee that ``distutils.sysconfig`` is loaded.
        import distutils.sysconfig

        # lit is used by the test suite
        lit_path = shutil.which('lit')
        # expanduser() honours HOME when set and still resolves a home
        # directory on platforms where HOME is not defined.
        triton_cache_path = os.path.join(os.path.expanduser("~"), ".triton")
        thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        # create build directories
        os.makedirs(self.build_temp, exist_ok=True)
        # python directories
        python_include_dir = distutils.sysconfig.get_python_inc()
        cmake_args = [
            "-DLLVM_ENABLE_WERROR=ON",
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
            "-DTRITON_BUILD_TUTORIALS=OFF",
            "-DTRITON_BUILD_PYTHON_MODULE=ON",
            "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
            "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
            "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
        ]
        if lit_path is not None:
            # shutil.which() returns None when lit is not on PATH; only pass
            # the option when there is an actual path to hand to CMake.
            cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_path)
        cmake_args += thirdparty_cmake_args

        # configuration
        cfg = get_build_type()
        build_args = ["--config", cfg]

        if platform.system() == "Windows":
            cmake_args += ["-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
            if sys.maxsize > 2**32:
                cmake_args += ["-A", "x64"]
            build_args += ["--", "/m"]
        else:
            import multiprocessing
            cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
            build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]

        env = os.environ.copy()
        subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
2020-03-13 18:03:25 +00:00
|
|
|
|
2021-03-09 16:32:44 -05:00
|
|
|
|
2021-07-27 12:38:38 -07:00
|
|
|
# Python packages shipped in the distribution.
_PACKAGES = [
    "triton",
    "triton/_C",
    "triton/language",
    "triton/tools",
    "triton/impl",
    "triton/ops",
    "triton/runtime",
    "triton/ops/blocksparse",
]

# Requirements installed together with the package.
_INSTALL_REQUIRES = [
    "cmake",
    "filelock",
    "torch",
    "lit",
]

# Non-Python files bundled with the listed packages.
_PACKAGE_DATA = {
    "triton/ops": ["*.c"],
    "triton/ops/blocksparse": ["*.c"],
    "triton/language": ["*.bc"]
}

_CLASSIFIERS = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Topic :: Software Development :: Build Tools",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3.6",
]

# Optional dependency groups.
_EXTRAS_REQUIRE = {
    "tests": [
        "autopep8",
        "flake8",
        "isort",
        "numpy",
        "pytest",
        "scipy>=1.7.1",
    ],
    "tutorials": [
        "matplotlib",
        "pandas",
        "tabulate",
    ],
}

setup(
    name="triton",
    version="2.0.0",
    author="Philippe Tillet",
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description="",
    packages=_PACKAGES,
    install_requires=_INSTALL_REQUIRES,
    package_data=_PACKAGE_DATA,
    include_package_data=True,
    # The extension is built by CMake through the custom build_ext command.
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
    cmdclass={"build_ext": CMakeBuild},
    zip_safe=False,
    # for PyPI
    keywords=["Compiler", "Deep Learning"],
    url="https://github.com/openai/triton/",
    classifiers=_CLASSIFIERS,
    test_suite="tests",
    extras_require=_EXTRAS_REQUIRE,
)
|