Merge branch 'master' into rcom52_fixes
This commit is contained in:
166
python/setup.py
166
python/setup.py
@@ -1,48 +1,81 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import sysconfig
|
||||
import platform
|
||||
import subprocess
|
||||
import distutils
|
||||
import glob
|
||||
import tempfile
|
||||
import shutil
|
||||
from distutils.version import LooseVersion
|
||||
from setuptools import setup, Extension, find_packages
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from setuptools.command.test import test as TestCommand
|
||||
import distutils.spawn
|
||||
import urllib.request
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import torch
|
||||
import urllib.request
|
||||
from distutils.version import LooseVersion
|
||||
from typing import NamedTuple
|
||||
|
||||
def get_llvm():
|
||||
# tries to find system LLVM
|
||||
versions = ['-13.0', '-13', '-13-64']
|
||||
from setuptools import Extension, setup
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
|
||||
def check_env_flag(name: str, default: str = "") -> bool:
|
||||
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
|
||||
|
||||
|
||||
def get_build_type():
|
||||
if check_env_flag("DEBUG"):
|
||||
return "Debug"
|
||||
elif check_env_flag("REL_WITH_DEB_INFO"):
|
||||
return "RelWithDebInfo"
|
||||
else:
|
||||
return "Release"
|
||||
|
||||
|
||||
def use_system_llvm():
|
||||
if platform.system() == "Windows":
|
||||
return True
|
||||
versions = ['-11.0', '-11', '-11-64']
|
||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||
paths = [p for p in paths if p is not None]
|
||||
if paths:
|
||||
return '', ''
|
||||
if platform.system() == "Windows":
|
||||
return '', ''
|
||||
# download if nothing is installed
|
||||
name = 'clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04'
|
||||
dir = '/tmp'
|
||||
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
||||
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
||||
if not os.path.exists(llvm_library_dir):
|
||||
try:
|
||||
shutil.rmtree(os.path.join(dir, name))
|
||||
except:
|
||||
pass
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/{name}.tar.xz".format(name=name)
|
||||
print('downloading and extracting ' + url + '...')
|
||||
ftpstream = urllib.request.urlopen(url)
|
||||
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
||||
file.extractall(path=dir)
|
||||
return llvm_include_dir, llvm_library_dir
|
||||
return any(p is not None for p in paths)
|
||||
|
||||
|
||||
def get_thirdparty_packages(triton_cache_path):
|
||||
class Package(NamedTuple):
|
||||
package: str
|
||||
name: str
|
||||
url: str
|
||||
test_file: str
|
||||
include_flag: str
|
||||
lib_flag: str
|
||||
|
||||
packages = [
|
||||
Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
|
||||
]
|
||||
if not use_system_llvm():
|
||||
# donwload LLVM if no suitable system LLVM is installed
|
||||
packages.append(
|
||||
Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
|
||||
)
|
||||
|
||||
thirdparty_cmake_args = []
|
||||
for p in packages:
|
||||
package_root_dir = os.path.join(triton_cache_path, p.package)
|
||||
package_dir = os.path.join(package_root_dir, p.name)
|
||||
test_file_path = os.path.join(package_dir, p.test_file)
|
||||
if not os.path.exists(test_file_path):
|
||||
try:
|
||||
shutil.rmtree(package_root_dir)
|
||||
except Exception:
|
||||
pass
|
||||
os.makedirs(package_root_dir, exist_ok=True)
|
||||
print('downloading and extracting {} ...'.format(p.url))
|
||||
ftpstream = urllib.request.urlopen(p.url)
|
||||
file = tarfile.open(fileobj=ftpstream, mode="r|*")
|
||||
file.extractall(path=package_root_dir)
|
||||
if p.include_flag:
|
||||
thirdparty_cmake_args.append("-D{}={}/include".format(p.include_flag, package_dir))
|
||||
if p.lib_flag:
|
||||
thirdparty_cmake_args.append("-D{}={}/lib".format(p.lib_flag, package_dir))
|
||||
return thirdparty_cmake_args
|
||||
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
@@ -80,34 +113,24 @@ class CMakeBuild(build_ext):
|
||||
self.build_extension(ext)
|
||||
|
||||
def build_extension(self, ext):
|
||||
llvm_include_dir, llvm_library_dir = get_llvm()
|
||||
self.debug = True
|
||||
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
|
||||
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# create build directories
|
||||
build_suffix = 'debug' if self.debug else 'release'
|
||||
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
|
||||
if not os.path.exists(self.build_temp):
|
||||
os.makedirs(self.build_temp)
|
||||
if not os.path.exists(llvm_build_dir):
|
||||
os.makedirs(llvm_build_dir)
|
||||
# python directories
|
||||
if torch.version.hip is not None:
|
||||
python_include_dirs= [distutils.sysconfig.get_python_inc()] +['/opt/rocm/include']
|
||||
else:
|
||||
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
|
||||
python_include_dirs = [distutils.sysconfig.get_python_inc()]
|
||||
cmake_args = [
|
||||
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
||||
"-DBUILD_TUTORIALS=OFF",
|
||||
"-DBUILD_PYTHON_MODULE=ON",
|
||||
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
|
||||
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
|
||||
#'-DPYTHON_EXECUTABLE=' + sys.executable,
|
||||
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
|
||||
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
|
||||
# '-DPYTHON_EXECUTABLE=' + sys.executable,
|
||||
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
|
||||
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
|
||||
]
|
||||
] + thirdparty_cmake_args
|
||||
# configuration
|
||||
cfg = "Debug" if self.debug else "Release"
|
||||
cfg = get_build_type()
|
||||
build_args = ["--config", cfg]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
@@ -130,14 +153,22 @@ class CMakeBuild(build_ext):
|
||||
|
||||
setup(
|
||||
name="triton",
|
||||
version="1.1.2",
|
||||
version="2.0.0",
|
||||
author="Philippe Tillet",
|
||||
author_email="phil@openai.com",
|
||||
description="A language and compiler for custom Deep Learning operations",
|
||||
long_description="",
|
||||
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
|
||||
install_requires=["torch", "filelock"],
|
||||
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
|
||||
packages=["triton", "triton/_C", "triton/language", "triton/runtime", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
|
||||
install_requires=[
|
||||
"cmake",
|
||||
"filelock",
|
||||
"torch",
|
||||
],
|
||||
package_data={
|
||||
"triton/ops": ["*.c"],
|
||||
"triton/ops/blocksparse": ["*.c"],
|
||||
"triton/language": ["*.bc"],
|
||||
},
|
||||
include_package_data=True,
|
||||
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
||||
cmdclass={"build_ext": CMakeBuild},
|
||||
@@ -152,4 +183,19 @@ setup(
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
],
|
||||
extras_require={
|
||||
"tests": [
|
||||
"autopep8",
|
||||
"flake8",
|
||||
"isort",
|
||||
"numpy",
|
||||
"pytest",
|
||||
"scipy>=1.7.1",
|
||||
],
|
||||
"tutorials": [
|
||||
"matplotlib",
|
||||
"pandas",
|
||||
"tabulate",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
Reference in New Issue
Block a user