[pytorch] clean-up of dynamic framework load
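
In short: the Python binding layer no longer takes an explicit `framework` argument anywhere; it infers the active framework from what the caller has already imported. A minimal usage sketch of the resulting API (tensor shapes are illustrative, not from the repo):

    import torch as th
    import triton

    a = th.randn(512, 256).cuda()
    b = th.randn(256, 128).cuda()
    # no framework argument anywhere: importing torch is the only signal
    tr_c = triton.ops.dot(a, b)
    print(tr_c)
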
@@ -34,5 +34,6 @@ def run_torch():
   b = th.randn(K, N).cuda()
   th_c = th.matmul(a, b)
   tr_c = triton.ops.dot(a, b)
+  print(tr_c)

 run_torch()
@@ -2,21 +2,10 @@ import sys
 import os
 import libtriton

-torch_id = 'torch'
-tensorflow_id = 'tensorflow'
-
 torch = None
 tensorflow = None
 tf_extra_ops = None

-def to_str(framework):
-  if framework == tensorflow_id:
-    return 'tensorflow'
-  elif framework == torch_id:
-    return 'torch'
-  else:
-    assert False
-
 def _import_torch():
   global torch
   if torch is None:
@@ -35,19 +24,8 @@ def _import_tf_extra_ops():
     _import_tensorflow()
     tf_extra_ops = tensorflow.load_op_library(path)

+def has_tensorflow():
+  return 'tensorflow' in sys.modules

-def _find_framework(default = None):
-  is_tf_imported = 'tensorflow' in sys.modules
-  is_torch_imported = 'torch' in sys.modules
-  if default:
-    if default not in [tensorflow_id, torch_id]:
-      raise ValueError('unsupported framework')
-    else:
-      return default
-  elif is_tf_imported and not is_torch_imported:
-    return tensorflow_id
-  elif is_torch_imported and not is_tf_imported:
-    return torch_id
-  else:
-    raise ValueError('cannot determine imported framework, '
-                     'please provide framework argument')
+def has_torch():
+  return 'torch' in sys.modules
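
The deleted `_find_framework` raised unless exactly one framework was imported (or an explicit default was passed); the new helpers just probe `sys.modules`, so if both frameworks happen to be imported, every `if fw.has_tensorflow(): ... elif fw.has_torch(): ...` chain below silently prefers TensorFlow. A standalone sketch of the semantics, mirroring the two helpers (assumes torch is installed):

    import sys

    def has_tensorflow():
      return 'tensorflow' in sys.modules

    def has_torch():
      return 'torch' in sys.modules

    import torch
    print(has_torch())       # True: torch is now in sys.modules
    print(has_tensorflow())  # False, unless tensorflow was imported too
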
@@ -14,7 +14,6 @@ class function_meta(type):
   def __init__(cls, name, bases, attrs):
     cls.contexts = dict()
     cls.registered = False
-    cls.framework = None
     return super(function_meta, cls).__init__(name, bases, attrs)

 class function(metaclass = function_meta):
@@ -59,8 +58,9 @@ class function(metaclass = function_meta):

   @classmethod
   def apply(cls, *args, **kwargs):
-    cls.framework = fw._find_framework(cls.framework)
-    if cls.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       return cls.apply_tensorflow(*args, **kwargs)
-    else:
+    elif fw.has_torch():
       return cls.apply_torch(*args, **kwargs)
+    else:
+      assert False
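
With the cached `cls.framework` gone, dispatch is re-decided on every `apply` call. A hypothetical `function` subclass showing the contract (`_scale` and its method bodies are invented for illustration):

    class _scale(function):

      @classmethod
      def apply_tensorflow(cls, x):
        return x * 2

      @classmethod
      def apply_torch(cls, x):
        return x * 2

    # _scale.apply(x) routes to one of the two methods per call,
    # based on which framework is present in sys.modules
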
@@ -14,10 +14,10 @@ import triton.frameworks as fw
 import triton.utils
 import libtriton

-def _make_framework_src(src, out, grid, framework):
-  if framework == fw.tensorflow_id:
+def _make_framework_src(src, out, grid):
+  if fw.has_tensorflow():
     return libtriton.make_tensorflow_src(src, out, grid)
-  elif framework == fw.torch_id:
+  elif fw.has_torch:
     return libtriton.make_torch_src(src, out, grid)
   else:
     assert False
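
One caveat in this hunk: `elif fw.has_torch:` is missing its call parentheses, so the branch tests the function object itself, which is always truthy. The chain still works because the TensorFlow test runs first, but the `assert False` arm becomes unreachable. A standalone illustration:

    def has_torch():
      return False  # pretend torch is absent

    print(bool(has_torch))    # True regardless: a function object is truthy
    print(bool(has_torch()))  # False: the intended check
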
@@ -32,10 +32,16 @@ def _make_cache_path(src):
     os.makedirs(cachepath)
   return cachepath

-def _write_bindings(src, root, framework):
-  cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
+def _write_bindings(src, root):
+  if fw.has_tensorflow():
+    name = 'tensorflow'
+  elif fw.has_torch():
+    name = 'torch'
+  else:
+    assert False
+  cpp = os.path.join(root, '{name}.cpp'.format(name=name))
   suffix = sysconfig.get_config_var('EXT_SUFFIX')
-  so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
+  so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
   recompile = False
   # recompile if .so does not exist
   if not os.path.exists(cpp) or not os.path.exists(so):
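
`_write_bindings` now derives the binding name from the detected framework, so the cache directory holds one `<name>.cpp`/shared-object pair per framework. A sketch of the resulting paths (the cache root here is hypothetical):

    import os
    import sysconfig

    root = '/tmp/triton-cache'                       # hypothetical cache dir
    suffix = sysconfig.get_config_var('EXT_SUFFIX')  # e.g. '.cpython-36m-x86_64-linux-gnu.so'
    cpp = os.path.join(root, 'torch.cpp')
    so = os.path.join(root, 'torch' + suffix)
    print(cpp, so)
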
@@ -50,7 +56,7 @@ def _write_bindings(src, root, framework):
   # return path of cpp file
   return (cpp, so)

-def _build(src, path, framework):
+def _build(src, path):
   # include directories
   triton_include_dirs = ['/home/philippe/development/triton/include']
   include_dirs = triton_include_dirs
@@ -61,14 +67,15 @@ def _build(src, path, framework):
   libraries = ['triton']
   # add framework
   extra_compile_args = []
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     library_dirs += [fw.tensorflow.sysconfig.get_lib()]
     include_dirs += [fw.tensorflow.sysconfig.get_include()]
     include_dirs += ['/usr/local/cuda/include/']
     libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
     abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
-  elif framework == fw.torch_id:
+    name = 'tensorflow'
+  elif fw.has_torch():
     prefix = os.path.dirname(fw.torch.__file__)
     library_dirs += [os.path.join(prefix, 'lib')]
     include_dirs += ['/usr/local/cuda/include/',
@@ -79,6 +86,7 @@ def _build(src, path, framework):
     libraries += ['torch']
     abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
+    name = 'torch'
   else:
     assert False
   # extra arguments
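
Both build branches pin `_GLIBCXX_USE_CXX11_ABI` to the ABI the framework itself was compiled with; a mismatch typically shows up as undefined C++ symbols when the binding is loaded. A minimal check of the torch side (assumes torch is installed):

    import torch

    # torch records the C++ ABI it was built with
    abi = torch._C._GLIBCXX_USE_CXX11_ABI
    print('-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=int(abi)))
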
@@ -87,7 +95,7 @@ def _build(src, path, framework):
   depends = [os.path.realpath(libtriton.__file__)]
   # create extension module
   ext = setuptools.Extension(
-    name = fw.to_str(framework),
+    name = name,
     language = 'c++',
     sources = [src],
     include_dirs = include_dirs,
@@ -104,19 +112,19 @@ def _build(src, path, framework):
   args.append('--build-lib=' + path)
   args.append('-q')
   args = dict(
-    name = 'tensorflow',
+    name = name,
     ext_modules = [ext],
     script_args = args,
   )
   setuptools.setup(**args)
   shutil.rmtree(tmp)

-def _cvt_to_def_str(obj, framework):
+def _cvt_to_def_str(obj):
   # bool
   if isinstance(obj, bool):
     return str(int(obj))
   # tensorflow type
-  if framework == fw.tensorflow_id:
+  if fw.has_tensorflow():
     if isinstance(obj, fw.tensorflow.DType):
       return {fw.tensorflow.int8: 'char',
               fw.tensorflow.int16: 'short',
@@ -126,7 +134,7 @@ def _cvt_to_def_str(obj, framework):
               fw.tensorflow.float32: 'float',
               fw.tensorflow.float64: 'double'}[obj]
   # torch type
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     if isinstance(obj, fw.torch.dtype):
       return {fw.torch.int8: 'char',
               fw.torch.int16: 'short',
@@ -141,14 +149,14 @@ def _cvt_to_def_str(obj, framework):
   return str(obj)


-def _make_framework_op(src, outputs, options, framework):
-  src, name = _make_framework_src(src, outputs, options, framework)
+def _make_framework_op(src, outputs, options):
+  src, name = _make_framework_src(src, outputs, options)
   cache_path = _make_cache_path(src)
-  cpp, so = _write_bindings(src, cache_path, framework)
-  _build(cpp, cache_path, framework)
-  if framework == fw.tensorflow_id:
+  cpp, so = _write_bindings(src, cache_path)
+  _build(cpp, cache_path)
+  if fw.has_tensorflow():
     return fw.tensorflow.load_op_library(so).__dict__[name]
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     fw.torch.ops.load_library(so)
     return getattr(fw.torch.ops.triton, name)
   else:
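
On the torch path the generated op is registered as a custom op in the `triton` namespace, so loading the shared object makes it reachable through `torch.ops.triton`. A hypothetical load (path and op name invented for illustration):

    import torch

    torch.ops.load_library('/tmp/triton-cache/torch.so')  # hypothetical .so path
    op = torch.ops.triton.dot                             # hypothetical op name
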
@@ -168,22 +176,18 @@ def _make_grid(args) :

 class kernel:

-  def __init__(self, src, outputs, framework = None):
+  def __init__(self, src, outputs):
     self.fw_id = dict()
     self.fw_grids = dict()
     self.fw_op = None
     self.src = src
     self.outputs = outputs
-    self.framework = framework

   def _init_framework(self):
-    if self.framework is not None:
-      return
-    self.framework = fw._find_framework(self.framework)
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       fw._import_tensorflow()
       fw._import_tf_extra_ops()
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       fw._import_torch()
     else:
       assert False
@@ -196,7 +200,7 @@ class kernel:
     # code generation options
     defines = []
     for k, v in kwargs.items():
-      cvt = lambda x: _cvt_to_def_str(x, self.framework)
+      cvt = lambda x: _cvt_to_def_str(x)
       if(isinstance(v, list)):
         values = list(map(cvt, v))
       else:
@@ -211,7 +215,7 @@ class kernel:
     # register function
     libtriton.register_fn(op_id, self.src, opt)
     if self.fw_op is None:
-      self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
+      self.fw_op = _make_framework_op(self.src, self.outputs, opt)

     # retrieve framework op
     op_id = self.fw_id[key]
@@ -220,9 +224,9 @@ class kernel:
     # create operands
     op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
     # call framework function
-    if self.framework == fw.tensorflow_id:
+    if fw.has_tensorflow():
       return self.fw_op(*op_args, id=op_id)
-    elif self.framework == fw.torch_id:
+    elif fw.has_torch():
       return self.fw_op(op_id, *op_args)
     else:
       assert False
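
Note the asymmetric calling conventions the dispatch has to preserve: the TensorFlow op takes the kernel id as an `id` keyword argument, while the torch op takes it as the first positional argument. A stub illustration (the stubs stand in for the generated framework op):

    def tf_op(*op_args, id=None):    # TensorFlow convention
      return ('tf', id, op_args)

    def torch_op(op_id, *op_args):   # torch convention
      return ('torch', op_id, op_args)

    print(tf_op(1.0, 2.0, id=7))
    print(torch_op(7, 1.0, 2.0))
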
@@ -4,14 +4,13 @@ import libtriton
 def cdiv(a, b):
   return -(-a // b)

-def empty(shapes, dtype, framework = None):
-  framework = fw._find_framework(framework)
-  if framework == fw.tensorflow_id:
+def empty(shapes, dtype):
+  if fw.has_tensorflow():
     fw._import_tensorflow()
     args = [x.handle if isinstance(x, scalar) else x for x in shapes]
     args = fw.tensorflow.stack(args)
     return fw.tf_extra_ops.alloc_empty(args, T = dtype)
-  elif framework == fw.torch_id:
+  elif fw.has_torch():
     fw._import_torch()
     return fw.torch.empty(*shapes).cuda()

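
As the hunk shows, the torch branch of `empty` ignores `dtype` and returns a default-dtype CUDA tensor, while the TensorFlow branch forwards it to `alloc_empty` as `T = dtype`. A hypothetical torch-side call (requires a CUDA device):

    import torch
    import triton

    buf = triton.utils.empty([128, 64], torch.float32)
    print(buf.shape, buf.dtype)  # torch.Size([128, 64]) torch.float32 (the default)
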
@@ -23,18 +22,19 @@ class lazy_shape:
   def __getitem__(self, key):
     return scalar(self.shape[key])

-def shape(A, framework = None) :
-  framework = fw._find_framework(framework)
-  if framework == fw.tensorflow_id:
+def shape(A) :
+  if fw.has_tensorflow():
     fw._import_tensorflow()
     return lazy_shape(fw.tensorflow.shape(A))
-  else:
+  elif fw.has_torch():
     return A.shape
+  else:
+    assert False


 class scalar:

-  def __init__(self, x, framework = None):
+  def __init__(self, x):
     self.id = libtriton.make_scalar_id()
     fw._import_tf_extra_ops()
     self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)
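
`shape` follows the same split: a `lazy_shape` of graph-time `scalar` handles on the TensorFlow path, the eager `A.shape` on the torch path; note that `scalar.__init__` itself still goes through the TensorFlow extra ops unconditionally in this revision. A hypothetical torch-side use (import path assumed):

    import torch
    from triton.utils import shape

    A = torch.empty(64, 32)
    M, N = shape(A)  # torch path: simply A.shape
    print(M, N)      # 64 32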