[python] now packaging include and libtriton in triton._C submodule
This commit is contained in:
@@ -5,16 +5,17 @@ import sysconfig
|
||||
import platform
|
||||
import subprocess
|
||||
import distutils
|
||||
|
||||
import glob
|
||||
from distutils.version import LooseVersion
|
||||
from setuptools import setup, Extension, find_packages
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from setuptools.command.test import test as TestCommand
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
def __init__(self, name, sourcedir=''):
|
||||
def __init__(self, name, path, sourcedir=''):
|
||||
Extension.__init__(self, name, sources=[])
|
||||
self.sourcedir = os.path.abspath(sourcedir)
|
||||
self.path = path
|
||||
|
||||
|
||||
class CMakeBuild(build_ext):
|
||||
@@ -36,7 +37,7 @@ class CMakeBuild(build_ext):
|
||||
|
||||
def build_extension(self, ext):
|
||||
self.debug = True
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# python directories
|
||||
python_include_dirs = distutils.sysconfig.get_python_inc()
|
||||
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
|
||||
@@ -78,6 +79,15 @@ class CMakeBuild(build_ext):
|
||||
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
|
||||
|
||||
|
||||
directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))]
|
||||
data = []
|
||||
for d in directories:
|
||||
files = glob.glob(os.path.join(d, '*.h'), recursive=False)
|
||||
dest = os.path.relpath(d, os.path.pardir)
|
||||
dest = os.path.join('triton', '_C', dest)
|
||||
data += [(dest, files)]
|
||||
|
||||
setup(
|
||||
name='triton',
|
||||
version='0.1',
|
||||
@@ -85,9 +95,9 @@ setup(
|
||||
author_email='ptillet@g.harvard.edu',
|
||||
description='A language and compiler for custom Deep Learning operations',
|
||||
long_description='',
|
||||
packages=['triton',
|
||||
'triton/ops'],
|
||||
ext_modules=[CMakeExtension('triton')],
|
||||
packages=['triton', 'triton/ops'],
|
||||
data_files=data,
|
||||
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
|
||||
cmdclass=dict(build_ext=CMakeBuild),
|
||||
zip_safe=False,
|
||||
)
|
||||
|
@@ -6,7 +6,7 @@ import triton.ops
|
||||
|
||||
# clean-up libtriton resources
|
||||
import atexit
|
||||
import libtriton
|
||||
import triton._C.libtriton as libtriton
|
||||
@atexit.register
|
||||
def cleanup():
|
||||
libtriton.cleanup()
|
@@ -1,6 +1,6 @@
|
||||
import sys
|
||||
import os
|
||||
import libtriton
|
||||
import triton._C.libtriton as libtriton
|
||||
|
||||
torch = None
|
||||
tensorflow = None
|
||||
|
@@ -12,7 +12,7 @@ import setuptools
|
||||
# triton
|
||||
import triton.frameworks as fw
|
||||
import triton.utils
|
||||
import libtriton
|
||||
import triton._C.libtriton as libtriton
|
||||
|
||||
def _make_framework_src(src, out, grid):
|
||||
if fw.has_tensorflow():
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import triton.frameworks as fw
|
||||
import libtriton
|
||||
import triton._C.libtriton as libtriton
|
||||
|
||||
def cdiv(a, b):
|
||||
return -(-a // b)
|
||||
|
Reference in New Issue
Block a user