[python] fixed various issues in pytorch supoport
This commit is contained in:
@@ -66,16 +66,19 @@ def _build(src, path, framework):
|
||||
include_dirs += [fw.tensorflow.sysconfig.get_include()]
|
||||
include_dirs += ['/usr/local/cuda/include/']
|
||||
libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
|
||||
ABI = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
|
||||
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
|
||||
abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
|
||||
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
|
||||
elif framework == fw.torch_id:
|
||||
prefix = os.path.dirname(torch.__file__)
|
||||
prefix = os.path.dirname(fw.torch.__file__)
|
||||
library_dirs += [os.path.join(prefix, 'lib')]
|
||||
include_dirs += [os.path.join(prefix, 'lib', 'include'),
|
||||
include_dirs += ['/usr/local/cuda/include/',
|
||||
os.path.join(prefix, 'lib', 'include'),
|
||||
os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
|
||||
os.path.join(prefix, 'include'),
|
||||
os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
|
||||
libraries += ['torch']
|
||||
abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
|
||||
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
|
||||
else:
|
||||
assert False
|
||||
# extra arguments
|
||||
@@ -84,7 +87,7 @@ def _build(src, path, framework):
|
||||
depends = [os.path.realpath(libtriton.__file__)]
|
||||
# create extension module
|
||||
ext = setuptools.Extension(
|
||||
name = 'tensorflow',
|
||||
name = fw.to_str(framework),
|
||||
language = 'c++',
|
||||
sources = [src],
|
||||
include_dirs = include_dirs,
|
||||
@@ -124,14 +127,14 @@ def _cvt_to_def_str(obj, framework):
|
||||
fw.tensorflow.float64: 'double'}[obj]
|
||||
# torch type
|
||||
elif framework == fw.torch_id:
|
||||
if isinstance(obj, torch.dtype):
|
||||
return {torch.int8: 'char',
|
||||
torch.int16: 'short',
|
||||
torch.int32: 'int',
|
||||
torch.int64: 'long',
|
||||
torch.float16: 'half',
|
||||
torch.float32: 'float',
|
||||
torch.float64: 'double'}[obj]
|
||||
if isinstance(obj, fw.torch.dtype):
|
||||
return {fw.torch.int8: 'char',
|
||||
fw.torch.int16: 'short',
|
||||
fw.torch.int32: 'int',
|
||||
fw.torch.int64: 'long',
|
||||
fw.torch.float16: 'half',
|
||||
fw.torch.float32: 'float',
|
||||
fw.torch.float64: 'double'}[obj]
|
||||
else:
|
||||
assert False
|
||||
# default
|
||||
@@ -146,8 +149,8 @@ def _make_framework_op(src, outputs, options, framework):
|
||||
if framework == fw.tensorflow_id:
|
||||
return fw.tensorflow.load_op_library(so).__dict__[name]
|
||||
elif framework == fw.torch_id:
|
||||
torch.ops.load_library(so)
|
||||
return torch.ops.triton.__dict__[name]
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -171,7 +174,12 @@ class kernel:
|
||||
self.fw_op = None
|
||||
self.src = src
|
||||
self.outputs = outputs
|
||||
self.framework = fw._find_framework(framework)
|
||||
self.framework = framework
|
||||
|
||||
def _init_framework(self):
|
||||
if self.framework is not None:
|
||||
return
|
||||
self.framework = fw._find_framework(self.framework)
|
||||
if self.framework == fw.tensorflow_id:
|
||||
fw._import_tensorflow()
|
||||
fw._import_tf_extra_ops()
|
||||
@@ -180,8 +188,8 @@ class kernel:
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._init_framework()
|
||||
# create a new framework op when defines are different
|
||||
key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
|
||||
if key not in self.fw_id.keys():
|
||||
@@ -212,4 +220,9 @@ class kernel:
|
||||
# create operands
|
||||
op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
|
||||
# call framework function
|
||||
return self.fw_op(*op_args, id=op_id)
|
||||
if self.framework == fw.tensorflow_id:
|
||||
return self.fw_op(*op_args, id=op_id)
|
||||
elif self.framework == fw.torch_id:
|
||||
return self.fw_op(op_id, *op_args)
|
||||
else:
|
||||
assert False
|
Reference in New Issue
Block a user