[pytorch] clean-up of dynamic framework load
@@ -34,5 +34,6 @@ def run_torch():
   b = th.randn(K, N).cuda()
   th_c = th.matmul(a, b)
   tr_c = triton.ops.dot(a, b)
   print(tr_c)
+
 run_torch()
@@ -2,21 +2,10 @@ import sys
 import os
 import libtriton
 
-torch_id = 'torch'
-tensorflow_id = 'tensorflow'
-
 torch = None
 tensorflow = None
 tf_extra_ops = None
 
-def to_str(framework):
-  if framework == tensorflow_id:
-    return 'tensorflow'
-  elif framework == torch_id:
-    return 'torch'
-  else:
-    assert False
-
 def _import_torch():
   global torch
   if torch is None:
@@ -35,19 +24,8 @@ def _import_tf_extra_ops():
   _import_tensorflow()
   tf_extra_ops = tensorflow.load_op_library(path)
 
+def has_tensorflow():
+  return 'tensorflow' in sys.modules
 
-def _find_framework(default = None):
-  is_tf_imported = 'tensorflow' in sys.modules
-  is_torch_imported = 'torch' in sys.modules
-  if default:
-    if default not in [tensorflow_id, torch_id]:
-      raise ValueError('unsupported framework')
-    else:
-      return default
-  elif is_tf_imported and not is_torch_imported:
-    return tensorflow_id
-  elif is_torch_imported and not is_tf_imported:
-    return torch_id
-  else:
-    raise ValueError('cannot determine imported framework, '
-                     'please provide framework argument')
+def has_torch():
+  return 'torch' in sys.modules
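
The detection that replaces _find_framework is just a membership test on sys.modules, so whichever framework the user has imported selects the backend. A minimal standalone sketch of that behavior, mirroring the two helpers added above (assuming torch is installed):

    import sys

    def has_tensorflow():
      return 'tensorflow' in sys.modules

    def has_torch():
      return 'torch' in sys.modules

    import torch        # merely importing a framework...
    assert has_torch()  # ...is what flips the corresponding check to True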
@@ -14,7 +14,6 @@ class function_meta(type):
   def __init__(cls, name, bases, attrs):
     cls.contexts = dict()
     cls.registered = False
-    cls.framework = None
     return super(function_meta, cls).__init__(name, bases, attrs)
 
 class function(metaclass = function_meta):
@@ -59,8 +58,9 @@ class function(metaclass = function_meta):
 
   @classmethod
   def apply(cls, *args, **kwargs):
-    cls.framework = fw._find_framework(cls.framework)
-    if cls.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       return cls.apply_tensorflow(*args, **kwargs)
-    else:
+    elif fw.has_torch():
       return cls.apply_torch(*args, **kwargs)
+    else:
+      assert False
@@ -14,10 +14,10 @@ import triton.frameworks as fw
 import triton.utils
 import libtriton
 
-def _make_framework_src(src, out, grid, framework):
-  if framework == fw.tensorflow_id:
+def _make_framework_src(src, out, grid):
+  if fw.has_tensorflow():
     return libtriton.make_tensorflow_src(src, out, grid)
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     return libtriton.make_torch_src(src, out, grid)
   else:
     assert False
@@ -32,10 +32,16 @@ def _make_cache_path(src):
     os.makedirs(cachepath)
   return cachepath
 
-def _write_bindings(src, root, framework):
-  cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
+def _write_bindings(src, root):
+  if fw.has_tensorflow():
+    name = 'tensorflow'
+  elif fw.has_torch():
+    name = 'torch'
+  else:
+    assert False
+  cpp = os.path.join(root, '{name}.cpp'.format(name=name))
   suffix = sysconfig.get_config_var('EXT_SUFFIX')
-  so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
+  so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
   recompile = False
   # recompile if .so does not exist
   if not os.path.exists(cpp) or not os.path.exists(so):
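
With _write_bindings no longer taking a framework argument, the binding source and shared object are keyed by the detected framework name. A sketch of the resulting cache layout, using a hypothetical cache root:

    import os, sysconfig

    root = '/tmp/triton_cache'  # hypothetical cache directory
    name = 'torch'              # as chosen by the fw.has_torch() branch
    cpp = os.path.join(root, '{name}.cpp'.format(name=name))
    so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=sysconfig.get_config_var('EXT_SUFFIX')))
    print(cpp)  # /tmp/triton_cache/torch.cpp
    print(so)   # e.g. /tmp/triton_cache/torch.cpython-36m-x86_64-linux-gnu.so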
@@ -50,7 +56,7 @@ def _write_bindings(src, root, framework):
   # return path of cpp file
   return (cpp, so)
 
-def _build(src, path, framework):
+def _build(src, path):
   # include directories
   triton_include_dirs = ['/home/philippe/development/triton/include']
   include_dirs = triton_include_dirs
@@ -61,14 +67,15 @@ def _build(src, path, framework):
   libraries = ['triton']
   # add framework
   extra_compile_args = []
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     library_dirs += [fw.tensorflow.sysconfig.get_lib()]
     include_dirs += [fw.tensorflow.sysconfig.get_include()]
     include_dirs += ['/usr/local/cuda/include/']
     libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
     abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
-  elif framework == fw.torch_id:
+    name = 'tensorflow'
+  elif fw.has_torch():
     prefix = os.path.dirname(fw.torch.__file__)
     library_dirs += [os.path.join(prefix, 'lib')]
     include_dirs += ['/usr/local/cuda/include/',
@@ -79,6 +86,7 @@ def _build(src, path, framework):
     libraries += ['torch']
     abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
+    name = 'torch'
   else:
     assert False
   # extra arguments
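
The torch branch now records the binding name alongside the ABI flag it already derives from the installed wheel, so the generated extension links against libtorch with a matching libstdc++ ABI. A sketch of that derivation, assuming torch is installed:

    import torch

    # torch records whether it was built with the new C++11 libstdc++ ABI
    abi = torch._C._GLIBCXX_USE_CXX11_ABI
    flag = '-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=int(abi))
    print(flag)  # e.g. -D_GLIBCXX_USE_CXX11_ABI=1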
@@ -87,7 +95,7 @@ def _build(src, path, framework):
   depends = [os.path.realpath(libtriton.__file__)]
   # create extension module
   ext = setuptools.Extension(
-    name = fw.to_str(framework),
+    name = name,
     language = 'c++',
     sources = [src],
     include_dirs = include_dirs,
@@ -104,19 +112,19 @@ def _build(src, path, framework):
   args.append('--build-lib=' + path)
   args.append('-q')
   args = dict(
-    name = 'tensorflow',
+    name = name,
     ext_modules = [ext],
     script_args = args,
   )
   setuptools.setup(**args)
   shutil.rmtree(tmp)
 
-def _cvt_to_def_str(obj, framework):
+def _cvt_to_def_str(obj):
   # bool
   if isinstance(obj, bool):
     return str(int(obj))
   # tensorflow type
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     if isinstance(obj, fw.tensorflow.DType):
       return {fw.tensorflow.int8: 'char',
               fw.tensorflow.int16: 'short',
@@ -126,7 +134,7 @@ def _cvt_to_def_str(obj, framework):
               fw.tensorflow.float32: 'float',
               fw.tensorflow.float64: 'double'}[obj]
   # torch type
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     if isinstance(obj, fw.torch.dtype):
       return {fw.torch.int8: 'char',
               fw.torch.int16: 'short',
@@ -141,14 +149,14 @@ def _cvt_to_def_str(obj, framework):
   return str(obj)
 
 
-def _make_framework_op(src, outputs, options, framework):
-  src, name = _make_framework_src(src, outputs, options, framework)
+def _make_framework_op(src, outputs, options):
+  src, name = _make_framework_src(src, outputs, options)
   cache_path = _make_cache_path(src)
-  cpp, so = _write_bindings(src, cache_path, framework)
-  _build(cpp, cache_path, framework)
-  if framework == fw.tensorflow_id:
+  cpp, so = _write_bindings(src, cache_path)
+  _build(cpp, cache_path)
+  if fw.has_tensorflow():
     return fw.tensorflow.load_op_library(so).__dict__[name]
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     fw.torch.ops.load_library(so)
     return getattr(fw.torch.ops.triton, name)
   else:
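
On the torch path above, load_library registers the ops compiled into the .so under the namespace declared in the generated C++ source, after which they resolve as attributes. A sketch with placeholder path and op name:

    import torch

    so = '/tmp/triton_cache/torch.so'     # hypothetical output of _write_bindings
    name = 'dot'                          # hypothetical op name from the generated source
    torch.ops.load_library(so)            # loads the extension and registers its ops
    op = getattr(torch.ops.triton, name)  # looked up by namespace and name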
@@ -168,22 +176,18 @@ def _make_grid(args) :
 
 class kernel:
 
-  def __init__(self, src, outputs, framework = None):
+  def __init__(self, src, outputs):
     self.fw_id = dict()
     self.fw_grids = dict()
     self.fw_op = None
     self.src = src
     self.outputs = outputs
-    self.framework = framework
 
   def _init_framework(self):
-    if self.framework is not None:
-      return
-    self.framework = fw._find_framework(self.framework)
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       fw._import_tensorflow()
       fw._import_tf_extra_ops()
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       fw._import_torch()
     else:
       assert False
@@ -196,7 +200,7 @@ class kernel:
     # code generation options
     defines = []
     for k, v in kwargs.items():
-      cvt = lambda x: _cvt_to_def_str(x, self.framework)
+      cvt = lambda x: _cvt_to_def_str(x)
       if(isinstance(v, list)):
         values = list(map(cvt, v))
       else:
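
For the loop above: each keyword argument is stringified by _cvt_to_def_str and collected as a compile-time define, with list values expanding to one candidate string each. A sketch with hypothetical kwargs and str standing in for _cvt_to_def_str:

    kwargs = {'TYPE': 'float', 'TM': [64, 128]}  # hypothetical specialization options
    cvt = str                                    # stand-in for _cvt_to_def_str
    defines = []
    for k, v in kwargs.items():
      values = list(map(cvt, v)) if isinstance(v, list) else [cvt(v)]
      defines.append((k, values))
    print(defines)  # [('TYPE', ['float']), ('TM', ['64', '128'])]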
@@ -211,7 +215,7 @@ class kernel:
     # register function
     libtriton.register_fn(op_id, self.src, opt)
     if self.fw_op is None:
-      self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
+      self.fw_op = _make_framework_op(self.src, self.outputs, opt)
 
     # retrieve framework op
     op_id = self.fw_id[key]
@@ -220,9 +224,9 @@ class kernel:
     # create operands
     op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
     # call framework function
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       return self.fw_op(*op_args, id=op_id)
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       return self.fw_op(op_id, *op_args)
     else:
       assert False
@@ -4,14 +4,13 @@ import libtriton
 def cdiv(a, b):
   return -(-a // b)
 
-def empty(shapes, dtype, framework = None):
-  framework = fw._find_framework(framework)
-  if framework == fw.tensorflow_id:
+def empty(shapes, dtype):
+  if fw.has_tensorflow():
     fw._import_tensorflow()
     args = [x.handle if isinstance(x, scalar) else x for x in shapes]
     args = fw.tensorflow.stack(args)
     return fw.tf_extra_ops.alloc_empty(args, T = dtype)
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     fw._import_torch()
     return fw.torch.empty(*shapes).cuda()
 
@@ -23,18 +22,19 @@ class lazy_shape:
   def __getitem__(self, key):
     return scalar(self.shape[key])
 
-def shape(A, framework = None) :
-  framework = fw._find_framework(framework)
-  if framework == fw.tensorflow_id:
+def shape(A) :
+  if fw.has_tensorflow():
     fw._import_tensorflow()
     return lazy_shape(fw.tensorflow.shape(A))
-  else:
+  elif fw.has_torch():
     return A.shape
+  else:
+    assert False
 
 
 class scalar:
 
-  def __init__(self, x, framework = None):
+  def __init__(self, x):
     self.id = libtriton.make_scalar_id()
     fw._import_tf_extra_ops()
     self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)
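
The net effect for users is that no call site takes a framework argument anymore; importing the framework is enough. A minimal end-to-end sketch on the PyTorch side (shapes arbitrary, CUDA assumed available):

    import torch
    import triton

    a = torch.randn(512, 512).cuda()
    b = torch.randn(512, 512).cuda()
    c = triton.ops.dot(a, b)  # backend selected because torch is in sys.modules
    print(c)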