[python] now packaging include and libtriton in triton._C submodule

This commit is contained in:
Philippe Tillet
2019-09-05 15:37:00 -04:00
parent 9ab2880fba
commit 7bfbb89612
5 changed files with 20 additions and 10 deletions

View File

@@ -5,16 +5,17 @@ import sysconfig
import platform
import subprocess
import distutils
import glob
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
class CMakeExtension(Extension):
def __init__(self, name, sourcedir=''):
def __init__(self, name, path, sourcedir=''):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.path = path
class CMakeBuild(build_ext):
@@ -36,7 +37,7 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
self.debug = True
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# python directories
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
@@ -78,6 +79,15 @@ class CMakeBuild(build_ext):
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))]
data = []
for d in directories:
files = glob.glob(os.path.join(d, '*.h'), recursive=False)
dest = os.path.relpath(d, os.path.pardir)
dest = os.path.join('triton', '_C', dest)
data += [(dest, files)]
setup(
name='triton',
version='0.1',
@@ -85,9 +95,9 @@ setup(
author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton',
'triton/ops'],
ext_modules=[CMakeExtension('triton')],
packages=['triton', 'triton/ops'],
data_files=data,
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False,
)

View File

@@ -6,7 +6,7 @@ import triton.ops
# clean-up libtriton resources
import atexit
import libtriton
import triton._C.libtriton as libtriton
@atexit.register
def cleanup():
libtriton.cleanup()

View File

@@ -1,6 +1,6 @@
import sys
import os
import libtriton
import triton._C.libtriton as libtriton
torch = None
tensorflow = None

View File

@@ -12,7 +12,7 @@ import setuptools
# triton
import triton.frameworks as fw
import triton.utils
import libtriton
import triton._C.libtriton as libtriton
def _make_framework_src(src, out, grid):
if fw.has_tensorflow():

View File

@@ -1,5 +1,5 @@
import triton.frameworks as fw
import libtriton
import triton._C.libtriton as libtriton
def cdiv(a, b):
return -(-a // b)