[pytorch] clean-up of dynamic framework load

Philippe Tillet
2019-09-05 02:16:27 -04:00
parent 65133cdf33
commit 44896ee777
5 changed files with 53 additions and 70 deletions
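
Summary of the change: instead of threading an explicit `framework` id through every helper (`_make_framework_src`, `_write_bindings`, `_build`, `_cvt_to_def_str`, `_make_framework_op`) and through `kernel.__init__`, the module now asks `triton.frameworks` directly which framework is present via `fw.has_tensorflow()` / `fw.has_torch()`. Those helpers are not part of this diff; a minimal sketch of how such detection could be implemented (the function names match the calls in the diff below, the bodies are an assumption based on standard importlib probing):

import importlib.util

def has_tensorflow():
  # assumed implementation: probe for an installed TensorFlow
  # without importing it eagerly
  return importlib.util.find_spec('tensorflow') is not None

def has_torch():
  # assumed implementation: probe for an installed PyTorch
  return importlib.util.find_spec('torch') is not None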

@@ -14,10 +14,10 @@ import triton.frameworks as fw
 import triton.utils
 import libtriton
 
-def _make_framework_src(src, out, grid, framework):
-  if framework == fw.tensorflow_id:
+def _make_framework_src(src, out, grid):
+  if fw.has_tensorflow():
     return libtriton.make_tensorflow_src(src, out, grid)
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     return libtriton.make_torch_src(src, out, grid)
   else:
     assert False
@@ -32,10 +32,16 @@ def _make_cache_path(src):
     os.makedirs(cachepath)
   return cachepath
 
-def _write_bindings(src, root, framework):
-  cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
+def _write_bindings(src, root):
+  if fw.has_tensorflow():
+    name = 'tensorflow'
+  elif fw.has_torch():
+    name = 'torch'
+  else:
+    assert False
+  cpp = os.path.join(root, '{name}.cpp'.format(name=name))
   suffix = sysconfig.get_config_var('EXT_SUFFIX')
-  so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
+  so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
   recompile = False
   # recompile if .so does not exist
   if not os.path.exists(cpp) or not os.path.exists(so):
@@ -50,7 +56,7 @@ def _write_bindings(src, root, framework):
   # return path of cpp file
   return (cpp, so)
 
-def _build(src, path, framework):
+def _build(src, path):
   # include directories
   triton_include_dirs = ['/home/philippe/development/triton/include']
   include_dirs = triton_include_dirs
@@ -61,14 +67,15 @@ def _build(src, path, framework):
   libraries = ['triton']
   # add framework
   extra_compile_args = []
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     library_dirs += [fw.tensorflow.sysconfig.get_lib()]
     include_dirs += [fw.tensorflow.sysconfig.get_include()]
     include_dirs += ['/usr/local/cuda/include/']
     libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
     abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
-  elif framework == fw.torch_id:
+    name = 'tensorflow'
+  elif fw.has_torch():
     prefix = os.path.dirname(fw.torch.__file__)
     library_dirs += [os.path.join(prefix, 'lib')]
     include_dirs += ['/usr/local/cuda/include/',
@@ -79,6 +86,7 @@ def _build(src, path, framework):
     libraries += ['torch']
     abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
+    name = 'torch'
   else:
     assert False
   # extra arguments
@@ -87,7 +95,7 @@ def _build(src, path, framework):
   depends = [os.path.realpath(libtriton.__file__)]
   # create extension module
   ext = setuptools.Extension(
-      name = fw.to_str(framework),
+      name = name,
       language = 'c++',
       sources = [src],
       include_dirs = include_dirs,
@@ -104,19 +112,19 @@ def _build(src, path, framework):
   args.append('--build-lib=' + path)
   args.append('-q')
   args = dict(
-      name = 'tensorflow',
+      name = name,
       ext_modules = [ext],
       script_args = args,
   )
   setuptools.setup(**args)
   shutil.rmtree(tmp)
 
-def _cvt_to_def_str(obj, framework):
+def _cvt_to_def_str(obj):
   # bool
   if isinstance(obj, bool):
     return str(int(obj))
   # tensorflow type
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     if isinstance(obj, fw.tensorflow.DType):
       return {fw.tensorflow.int8: 'char',
               fw.tensorflow.int16: 'short',
@@ -126,7 +134,7 @@ def _cvt_to_def_str(obj, framework):
               fw.tensorflow.float32: 'float',
               fw.tensorflow.float64: 'double'}[obj]
   # torch type
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     if isinstance(obj, fw.torch.dtype):
       return {fw.torch.int8: 'char',
               fw.torch.int16: 'short',
@@ -141,14 +149,14 @@ def _cvt_to_def_str(obj, framework):
   return str(obj)
 
-def _make_framework_op(src, outputs, options, framework):
-  src, name = _make_framework_src(src, outputs, options, framework)
+def _make_framework_op(src, outputs, options):
+  src, name = _make_framework_src(src, outputs, options)
   cache_path = _make_cache_path(src)
-  cpp, so = _write_bindings(src, cache_path, framework)
-  _build(cpp, cache_path, framework)
-  if framework == fw.tensorflow_id:
+  cpp, so = _write_bindings(src, cache_path)
+  _build(cpp, cache_path)
+  if fw.has_tensorflow():
     return fw.tensorflow.load_op_library(so).__dict__[name]
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     fw.torch.ops.load_library(so)
     return getattr(fw.torch.ops.triton, name)
   else:
     assert False
@@ -168,22 +176,18 @@ def _make_grid(args) :
 class kernel:
 
-  def __init__(self, src, outputs, framework = None):
+  def __init__(self, src, outputs):
     self.fw_id = dict()
     self.fw_grids = dict()
     self.fw_op = None
     self.src = src
     self.outputs = outputs
-    self.framework = framework
 
   def _init_framework(self):
-    if self.framework is not None:
-      return
-    self.framework = fw._find_framework(self.framework)
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       fw._import_tensorflow()
       fw._import_tf_extra_ops()
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       fw._import_torch()
     else:
       assert False
@@ -196,7 +200,7 @@ class kernel:
     # code generation options
     defines = []
     for k, v in kwargs.items():
-      cvt = lambda x: _cvt_to_def_str(x, self.framework)
+      cvt = lambda x: _cvt_to_def_str(x)
       if(isinstance(v, list)):
         values = list(map(cvt, v))
       else:
@@ -211,7 +215,7 @@ class kernel:
       # register function
       libtriton.register_fn(op_id, self.src, opt)
       if self.fw_op is None:
-        self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
+        self.fw_op = _make_framework_op(self.src, self.outputs, opt)
     # retrieve framework op
     op_id = self.fw_id[key]
@@ -220,9 +224,9 @@ class kernel:
     # create operands
     op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
     # call framework function
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       return self.fw_op(*op_args, id=op_id)
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       return self.fw_op(op_id, *op_args)
     else:
       assert False
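
After this change, a `kernel` is constructed without naming a framework. A hypothetical usage sketch (assumes the class is exposed as `triton.kernel`, and that the last positional argument of a call is the launch grid, per the `args[:-1]` slice above):

import triton

src = "..."  # Triton-C source for the kernel (elided)
# no framework argument anymore; TensorFlow vs. PyTorch
# is detected at runtime by triton.frameworks
k = triton.kernel(src, outputs=['C'])
# C = k(A, B, grid)  # dispatches to whichever framework's op was built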