[pytorch] clean-up of dynamic framework load

Philippe Tillet
2019-09-05 02:16:27 -04:00
parent 65133cdf33
commit 44896ee777
5 changed files with 53 additions and 70 deletions

View File

@@ -34,5 +34,6 @@ def run_torch():
   b = th.randn(K, N).cuda()
   th_c = th.matmul(a, b)
   tr_c = triton.ops.dot(a, b)
+  print(tr_c)

 run_torch()
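
For context, this hunk just adds a print of the Triton result next to the existing Torch reference. A self-contained sketch of the comparison pattern the example implements; the sizes and the allocation of a are assumptions, since they fall outside the hunk:

import torch as th
import triton

M, N, K = 512, 512, 512            # assumed sizes; not shown in the hunk

def run_torch():
  a = th.randn(M, K).cuda()        # assumed to mirror the b allocation above
  b = th.randn(K, N).cuda()
  th_c = th.matmul(a, b)           # Torch reference product
  tr_c = triton.ops.dot(a, b)      # same product computed by Triton
  print(tr_c)

run_torch()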

View File

@@ -2,21 +2,10 @@ import sys
 import os
 import libtriton

-torch_id = 'torch'
-tensorflow_id = 'tensorflow'
-
 torch = None
 tensorflow = None
 tf_extra_ops = None

-def to_str(framework):
-  if framework == tensorflow_id:
-    return 'tensorflow'
-  elif framework == torch_id:
-    return 'torch'
-  else:
-    assert False
-
 def _import_torch():
   global torch
   if torch is None:
@@ -35,19 +24,8 @@ def _import_tf_extra_ops():
     _import_tensorflow()
     tf_extra_ops = tensorflow.load_op_library(path)

+def has_tensorflow():
+  return 'tensorflow' in sys.modules
+
-def _find_framework(default = None):
-  is_tf_imported = 'tensorflow' in sys.modules
-  is_torch_imported = 'torch' in sys.modules
-  if default:
-    if default not in [tensorflow_id, torch_id]:
-      raise ValueError('unsupported framework')
-    else:
-      return default
-  elif is_tf_imported and not is_torch_imported:
-    return tensorflow_id
-  elif is_torch_imported and not is_tf_imported:
-    return torch_id
-  else:
-    raise ValueError('cannot determine imported framework, '
-                     'please provide framework argument')
+def has_torch():
+  return 'torch' in sys.modules
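
The net effect in this file: the explicit framework ids and _find_framework are gone, and every caller now keys off whether the user has already imported a framework. A minimal standalone sketch of the new detection idiom; the dispatch helper is illustrative, not part of the module:

import sys

def has_tensorflow():
  return 'tensorflow' in sys.modules   # true iff TF was already imported

def has_torch():
  return 'torch' in sys.modules        # true iff Torch was already imported

def dispatch(tf_impl, torch_impl, *args):
  # the if/elif/else shape repeated throughout this commit
  if has_tensorflow():
    return tf_impl(*args)
  elif has_torch():
    return torch_impl(*args)
  else:
    assert False, 'import tensorflow or torch before calling triton'

One consequence worth noting: a process that has imported both frameworks now silently resolves to TensorFlow, whereas the old _find_framework raised a ValueError unless an explicit framework argument was supplied.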

View File

@@ -14,7 +14,6 @@ class function_meta(type):
   def __init__(cls, name, bases, attrs):
     cls.contexts = dict()
     cls.registered = False
-    cls.framework = None
     return super(function_meta, cls).__init__(name, bases, attrs)

 class function(metaclass = function_meta):
@@ -59,8 +58,9 @@ class function(metaclass = function_meta):
   @classmethod
   def apply(cls, *args, **kwargs):
-    cls.framework = fw._find_framework(cls.framework)
-    if cls.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       return cls.apply_tensorflow(*args, **kwargs)
-    else:
+    elif fw.has_torch():
       return cls.apply_torch(*args, **kwargs)
+    else:
+      assert False
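
Two things change here: apply no longer caches a per-class framework attribute (each call re-checks sys.modules, which avoids stale state if a framework is imported after the class is defined), and the bare else branch becomes an explicit has_torch() check with an assert fallback. A hypothetical subclass showing how the dispatch lands; the standalone function mock and the dot bodies are invented for illustration:

import sys

class function:
  # standalone mock mirroring the new apply() dispatch above
  @classmethod
  def apply(cls, *args, **kwargs):
    if 'tensorflow' in sys.modules:      # fw.has_tensorflow()
      return cls.apply_tensorflow(*args, **kwargs)
    elif 'torch' in sys.modules:         # fw.has_torch()
      return cls.apply_torch(*args, **kwargs)
    else:
      assert False

class dot(function):
  @classmethod
  def apply_torch(cls, a, b):
    import torch
    return torch.matmul(a, b)
  @classmethod
  def apply_tensorflow(cls, a, b):
    import tensorflow as tf
    return tf.matmul(a, b)

# dot.apply(a, b) routes on whichever framework the caller imported.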

View File

@@ -14,10 +14,10 @@ import triton.frameworks as fw
 import triton.utils
 import libtriton

-def _make_framework_src(src, out, grid, framework):
-  if framework == fw.tensorflow_id:
+def _make_framework_src(src, out, grid):
+  if fw.has_tensorflow():
     return libtriton.make_tensorflow_src(src, out, grid)
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     return libtriton.make_torch_src(src, out, grid)
   else:
     assert False
@@ -32,10 +32,16 @@ def _make_cache_path(src):
     os.makedirs(cachepath)
   return cachepath

-def _write_bindings(src, root, framework):
-  cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
+def _write_bindings(src, root):
+  if fw.has_tensorflow():
+    name = 'tensorflow'
+  elif fw.has_torch():
+    name = 'torch'
+  else:
+    assert False
+  cpp = os.path.join(root, '{name}.cpp'.format(name=name))
   suffix = sysconfig.get_config_var('EXT_SUFFIX')
-  so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
+  so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
   recompile = False
   # recompile if .so does not exist
   if not os.path.exists(cpp) or not os.path.exists(so):
@@ -50,7 +56,7 @@ def _write_bindings(src, root, framework):
   # return path of cpp file
   return (cpp, so)

-def _build(src, path, framework):
+def _build(src, path):
   # include directories
   triton_include_dirs = ['/home/philippe/development/triton/include']
   include_dirs = triton_include_dirs
@@ -61,14 +67,15 @@ def _build(src, path, framework):
   libraries = ['triton']
   # add framework
   extra_compile_args = []
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     library_dirs += [fw.tensorflow.sysconfig.get_lib()]
     include_dirs += [fw.tensorflow.sysconfig.get_include()]
     include_dirs += ['/usr/local/cuda/include/']
     libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
     abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
-  elif framework == fw.torch_id:
+    name = 'tensorflow'
+  elif fw.has_torch():
     prefix = os.path.dirname(fw.torch.__file__)
     library_dirs += [os.path.join(prefix, 'lib')]
     include_dirs += ['/usr/local/cuda/include/',
@@ -79,6 +86,7 @@ def _build(src, path, framework):
     libraries += ['torch']
     abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
+    name = 'torch'
   else:
     assert False
   # extra arguments
@@ -87,7 +95,7 @@ def _build(src, path, framework):
   depends = [os.path.realpath(libtriton.__file__)]
   # create extension module
   ext = setuptools.Extension(
-    name = fw.to_str(framework),
+    name = name,
     language = 'c++',
     sources = [src],
     include_dirs = include_dirs,
@@ -104,19 +112,19 @@ def _build(src, path, framework):
   args.append('--build-lib=' + path)
   args.append('-q')
   args = dict(
-    name = 'tensorflow',
+    name = name,
     ext_modules = [ext],
     script_args = args,
   )
   setuptools.setup(**args)
   shutil.rmtree(tmp)

-def _cvt_to_def_str(obj, framework):
+def _cvt_to_def_str(obj):
   # bool
   if isinstance(obj, bool):
     return str(int(obj))
   # tensorflow type
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     if isinstance(obj, fw.tensorflow.DType):
       return {fw.tensorflow.int8: 'char',
               fw.tensorflow.int16: 'short',
@@ -126,7 +134,7 @@ def _cvt_to_def_str(obj, framework):
               fw.tensorflow.float32: 'float',
               fw.tensorflow.float64: 'double'}[obj]
   # torch type
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     if isinstance(obj, fw.torch.dtype):
       return {fw.torch.int8: 'char',
               fw.torch.int16: 'short',
@@ -141,14 +149,14 @@ def _cvt_to_def_str(obj, framework):
   return str(obj)

-def _make_framework_op(src, outputs, options, framework):
-  src, name = _make_framework_src(src, outputs, options, framework)
+def _make_framework_op(src, outputs, options):
+  src, name = _make_framework_src(src, outputs, options)
   cache_path = _make_cache_path(src)
-  cpp, so = _write_bindings(src, cache_path, framework)
-  _build(cpp, cache_path, framework)
-  if framework == fw.tensorflow_id:
+  cpp, so = _write_bindings(src, cache_path)
+  _build(cpp, cache_path)
+  if fw.has_tensorflow():
     return fw.tensorflow.load_op_library(so).__dict__[name]
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     fw.torch.ops.load_library(so)
     return getattr(fw.torch.ops.triton, name)
   else:
@@ -168,22 +176,18 @@ def _make_grid(args) :
 class kernel:

-  def __init__(self, src, outputs, framework = None):
+  def __init__(self, src, outputs):
     self.fw_id = dict()
     self.fw_grids = dict()
     self.fw_op = None
     self.src = src
     self.outputs = outputs
-    self.framework = framework

   def _init_framework(self):
-    if self.framework is not None:
-      return
-    self.framework = fw._find_framework(self.framework)
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       fw._import_tensorflow()
       fw._import_tf_extra_ops()
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       fw._import_torch()
     else:
       assert False
@@ -196,7 +200,7 @@ class kernel:
     # code generation options
     defines = []
     for k, v in kwargs.items():
-      cvt = lambda x: _cvt_to_def_str(x, self.framework)
+      cvt = lambda x: _cvt_to_def_str(x)
       if(isinstance(v, list)):
         values = list(map(cvt, v))
       else:
@@ -211,7 +215,7 @@ class kernel:
     # register function
     libtriton.register_fn(op_id, self.src, opt)
     if self.fw_op is None:
-      self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
+      self.fw_op = _make_framework_op(self.src, self.outputs, opt)
     # retrieve framework op
     op_id = self.fw_id[key]
@@ -220,9 +224,9 @@ class kernel:
     # create operands
     op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
     # call framework function
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       return self.fw_op(*op_args, id=op_id)
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       return self.fw_op(op_id, *op_args)
     else:
       assert False
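
Taken together, the changes in this file remove framework from the public surface: construction, source generation, binding compilation, and invocation all infer the backend at call time. A hedged before/after sketch of user code, assuming the class is re-exported as triton.kernel; the source string and output names are placeholders:

import torch            # merely importing torch is what selects the backend
import triton

src = "..."             # placeholder Triton-C source
# before this commit:  k = triton.kernel(src, ['C'], framework='torch')
# after:
k = triton.kernel(src, ['C'])   # backend inferred from sys.modules when called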

View File

@@ -4,14 +4,13 @@ import libtriton
 def cdiv(a, b):
   return -(-a // b)

-def empty(shapes, dtype, framework = None):
-  framework = fw._find_framework(framework)
-  if framework == fw.tensorflow_id:
+def empty(shapes, dtype):
+  if fw.has_tensorflow():
     fw._import_tensorflow()
     args = [x.handle if isinstance(x, scalar) else x for x in shapes]
     args = fw.tensorflow.stack(args)
     return fw.tf_extra_ops.alloc_empty(args, T = dtype)
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     fw._import_torch()
     return fw.torch.empty(*shapes).cuda()
@@ -23,18 +22,19 @@ class lazy_shape:
   def __getitem__(self, key):
     return scalar(self.shape[key])

-def shape(A, framework = None) :
-  framework = fw._find_framework(framework)
-  if framework == fw.tensorflow_id:
+def shape(A) :
+  if fw.has_tensorflow():
     fw._import_tensorflow()
     return lazy_shape(fw.tensorflow.shape(A))
-  else:
+  elif fw.has_torch():
     return A.shape
+  else:
+    assert False

 class scalar:

-  def __init__(self, x, framework = None):
+  def __init__(self, x):
     self.id = libtriton.make_scalar_id()
     fw._import_tf_extra_ops()
     self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)
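
A usage sketch of the slimmed-down helpers under the new scheme, assuming the module is importable as triton.utils (kernel.py above references triton.utils.scalar, so it is); sizes are arbitrary. Note that, as the hunk shows, the Torch path of empty() currently ignores its dtype argument:

import torch                   # importing torch routes utils to the Torch path
import triton.utils as tu

x = tu.empty([128, 64], torch.float32)   # -> torch.empty(128, 64).cuda();
                                         #    dtype is ignored on this path
print(tu.shape(x))                       # Torch branch returns x.shape directly
print(tu.cdiv(130, 64))                  # ceiling division helper: prints 3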