[python] fixed various issues in pytorch support

Philippe Tillet
2019-09-05 00:19:42 -04:00
parent 945b5d0de9
commit ed0f706005
8 changed files with 182 additions and 92 deletions


@@ -66,16 +66,19 @@ def _build(src, path, framework):
     include_dirs += [fw.tensorflow.sysconfig.get_include()]
     include_dirs += ['/usr/local/cuda/include/']
     libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
-    ABI = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
-    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
+    abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
+    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
   elif framework == fw.torch_id:
-    prefix = os.path.dirname(torch.__file__)
+    prefix = os.path.dirname(fw.torch.__file__)
     library_dirs += [os.path.join(prefix, 'lib')]
-    include_dirs += [os.path.join(prefix, 'lib', 'include'),
+    include_dirs += ['/usr/local/cuda/include/',
+                     os.path.join(prefix, 'lib', 'include'),
                      os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
                      os.path.join(prefix, 'include'),
                      os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
     libraries += ['torch']
+    abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
+    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
   else:
     assert False
   # extra arguments
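The PyTorch branch now mirrors the TensorFlow one: the generated extension must be compiled with the same _GLIBCXX_USE_CXX11_ABI value as the framework itself, otherwise std::string-based symbols fail to link. A minimal standalone sketch of that probe (assuming an importable torch wheel and CUDA headers under /usr/local/cuda):

```python
import os
import torch

# locate PyTorch's bundled headers and probe the C++ ABI it was built with
prefix = os.path.dirname(torch.__file__)
include_dirs = ['/usr/local/cuda/include/',
                os.path.join(prefix, 'include'),
                os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
abi = int(torch._C._GLIBCXX_USE_CXX11_ABI)
extra_compile_args = ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
```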
@@ -84,7 +87,7 @@ def _build(src, path, framework):
   depends = [os.path.realpath(libtriton.__file__)]
   # create extension module
   ext = setuptools.Extension(
-      name = 'tensorflow',
+      name = fw.to_str(framework),
       language = 'c++',
       sources = [src],
       include_dirs = include_dirs,
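Naming the extension after the target framework (instead of hard-coding 'tensorflow') keeps the TensorFlow and PyTorch builds of the same kernel from overwriting each other's shared objects. A rough sketch of the construction, with illustrative parameter names:

```python
import setuptools

def make_extension(framework_name, src, include_dirs, library_dirs, libraries,
                   extra_compile_args):
    # e.g. framework_name = 'tensorflow' or 'torch'; one .so per framework
    return setuptools.Extension(
        name=framework_name,
        language='c++',
        sources=[src],
        include_dirs=include_dirs,
        library_dirs=library_dirs,
        libraries=libraries,
        extra_compile_args=extra_compile_args)
```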
@@ -124,14 +127,14 @@ def _cvt_to_def_str(obj, framework):
               fw.tensorflow.float64: 'double'}[obj]
   # torch type
   elif framework == fw.torch_id:
-    if isinstance(obj, torch.dtype):
-      return {torch.int8: 'char',
-              torch.int16: 'short',
-              torch.int32: 'int',
-              torch.int64: 'long',
-              torch.float16: 'half',
-              torch.float32: 'float',
-              torch.float64: 'double'}[obj]
+    if isinstance(obj, fw.torch.dtype):
+      return {fw.torch.int8: 'char',
+              fw.torch.int16: 'short',
+              fw.torch.int32: 'int',
+              fw.torch.int64: 'long',
+              fw.torch.float16: 'half',
+              fw.torch.float32: 'float',
+              fw.torch.float64: 'double'}[obj]
   else:
     assert False
   # default
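The dtype table itself is unchanged; only the access path moves to the lazily imported fw.torch module. As a self-contained illustration of the mapping (assuming torch is importable):

```python
import torch

# torch dtype -> C type string used when specializing kernel source
_TORCH_TO_C = {torch.int8:    'char',
               torch.int16:   'short',
               torch.int32:   'int',
               torch.int64:   'long',
               torch.float16: 'half',
               torch.float32: 'float',
               torch.float64: 'double'}

def cvt_to_def_str(obj):
    if isinstance(obj, torch.dtype):
        return _TORCH_TO_C[obj]
    return str(obj)
```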
@@ -146,8 +149,8 @@ def _make_framework_op(src, outputs, options, framework):
   if framework == fw.tensorflow_id:
     return fw.tensorflow.load_op_library(so).__dict__[name]
   elif framework == fw.torch_id:
-    torch.ops.load_library(so)
-    return torch.ops.triton.__dict__[name]
+    fw.torch.ops.load_library(so)
+    return getattr(fw.torch.ops.triton, name)
   else:
     assert False
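For PyTorch, the compiled library registers its ops through torch.ops.load_library, and attribute access on the torch.ops.triton namespace resolves a registered op lazily, which is why getattr is preferred over poking __dict__. A hedged sketch with a placeholder path and op name:

```python
import torch

def load_torch_op(so_path, name):
    # registers every op compiled into the shared library
    torch.ops.load_library(so_path)
    # attribute access resolves the op lazily, e.g. torch.ops.triton.my_kernel
    return getattr(torch.ops.triton, name)
```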
@@ -171,7 +174,12 @@ class kernel:
     self.fw_op = None
     self.src = src
     self.outputs = outputs
-    self.framework = fw._find_framework(framework)
+    self.framework = framework
+  def _init_framework(self):
+    if self.framework is not None:
+      return
+    self.framework = fw._find_framework(self.framework)
     if self.framework == fw.tensorflow_id:
       fw._import_tensorflow()
       fw._import_tf_extra_ops()
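Framework detection moves out of __init__ and into a lazy _init_framework step, so constructing a kernel no longer forces TensorFlow or PyTorch to be imported; resolution happens on first use. A sketch of the pattern, with a hypothetical detect_framework helper:

```python
def detect_framework():
    # hypothetical auto-detection: prefer torch if it is importable
    try:
        import torch  # noqa: F401
        return 'torch'
    except ImportError:
        return 'tensorflow'

class LazyKernel:
    def __init__(self, framework=None):
        # may stay unresolved until the kernel is first called
        self.framework = framework

    def _init_framework(self):
        if self.framework is None:
            self.framework = detect_framework()

    def __call__(self, *args, **kwargs):
        self._init_framework()
        # ... build and dispatch using self.framework ...
```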
@@ -180,8 +188,8 @@ class kernel:
     else:
       assert False
   def __call__(self, *args, **kwargs):
+    self._init_framework()
     # create a new framework op when defines are different
     key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
     if key not in self.fw_id.keys():
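Because every distinct set of compile-time defines needs its own compiled op, __call__ keys a cache on the kwargs. A small standalone sketch of that memoization (cache and builder names are illustrative):

```python
_op_cache = {}

def get_op(build_op, **defines):
    # one compiled/loaded framework op per distinct set of defines
    key = '-'.join('{k}-{v}'.format(k=k, v=v) for k, v in defines.items())
    if key not in _op_cache:
        _op_cache[key] = build_op(**defines)
    return _op_cache[key]
```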
@@ -212,4 +220,9 @@ class kernel:
     # create operands
     op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
     # call framework function
-    return self.fw_op(*op_args, id=op_id)
+    if self.framework == fw.tensorflow_id:
+      return self.fw_op(*op_args, id=op_id)
+    elif self.framework == fw.torch_id:
+      return self.fw_op(op_id, *op_args)
+    else:
+      assert False
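The dispatch at the end reflects the two calling conventions: the generated TensorFlow op receives the kernel id as an id= attribute, while the registered Torch op takes it as its first positional argument. A hedged standalone sketch:

```python
def call_framework_op(framework, fw_op, op_args, op_id):
    # 'tensorflow' / 'torch' stand in for fw.tensorflow_id / fw.torch_id
    if framework == 'tensorflow':
        return fw_op(*op_args, id=op_id)   # id passed as a keyword attribute
    elif framework == 'torch':
        return fw_op(op_id, *op_args)      # id passed as the first argument
    else:
        raise ValueError('unsupported framework: {}'.format(framework))
```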