[python] more robust way to add triton includes to python package

This commit is contained in:
Philippe Tillet
2019-09-05 16:01:56 -04:00
parent 945593e847
commit 0a6329ea7d
3 changed files with 8 additions and 7 deletions

View File

@@ -84,9 +84,7 @@ directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))]
data = [] data = []
for d in directories: for d in directories:
files = glob.glob(os.path.join(d, '*.h'), recursive=False) files = glob.glob(os.path.join(d, '*.h'), recursive=False)
dest = os.path.relpath(d, os.path.pardir) data += [os.path.relpath(f, os.path.pardir) for f in files]
dest = os.path.join('triton', '_C', dest)
data += [(dest, files)]
setup( setup(
name='triton', name='triton',
@@ -95,8 +93,8 @@ setup(
author_email='ptillet@g.harvard.edu', author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations', description='A language and compiler for custom Deep Learning operations',
long_description='', long_description='',
packages=['triton', 'triton/ops'], packages=['triton', 'triton/_C', 'triton/ops'],
data_files=data, package_data={'': data},
ext_modules=[CMakeExtension('triton', 'triton/_C/')], ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass=dict(build_ext=CMakeBuild), cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False, zip_safe=False,

1
python/triton/_C/include Symbolic link
View File

@@ -0,0 +1 @@
../../../include/

View File

@@ -57,11 +57,13 @@ def _write_bindings(src, root):
return (cpp, so) return (cpp, so)
def _build(src, path): def _build(src, path):
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
ccdir = os.path.realpath(ccdir)
# include directories # include directories
triton_include_dirs = [os.path.realpath(os.path.join(libtriton.__file__, 'include'))] triton_include_dirs = [os.path.join(ccdir, 'include')]
include_dirs = triton_include_dirs include_dirs = triton_include_dirs
# library directories # library directories
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))] triton_library_dirs = [ccdir]
library_dirs = triton_library_dirs library_dirs = triton_library_dirs
# libraries # libraries
libraries = ['triton'] libraries = ['triton']