[python] fixed various issues in pytorch supoport

This commit is contained in:
Philippe Tillet
2019-09-05 00:19:42 -04:00
parent 945b5d0de9
commit ed0f706005
8 changed files with 182 additions and 92 deletions

View File

@@ -41,18 +41,22 @@ class CMakeBuild(build_ext):
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
# tensorflow directories
import tensorflow as tf
tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
tf_include_dirs = tf.sysconfig.get_include()
tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '')
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TESTS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DTF_INCLUDE_DIRS=' + tf_include_dirs,
'-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(),
'-DTF_LIBS=' + tf_libs,
'-DTF_ABI=' + str(tf_abi)]
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs]
# tensorflow compatibility
try:
import tensorflow as tf
tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
tf_include_dirs = tf.sysconfig.get_include()
tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '')
cmake_args += ['-DTF_INCLUDE_DIRS=' + tf_include_dirs,
'-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(),
'-DTF_LIBS=' + tf_libs,
'-DTF_ABI=' + str(tf_abi)]
except ModuleNotFoundError:
pass
cfg = 'Debug' if self.debug else 'Release'
build_args = ['--config', cfg]