[pytorch] clean-up of dynamic framework load

This commit is contained in:
Philippe Tillet
2019-09-05 02:16:27 -04:00
parent 65133cdf33
commit 44896ee777
5 changed files with 53 additions and 70 deletions

View File

@@ -34,5 +34,6 @@ def run_torch():
b = th.randn(K, N).cuda() b = th.randn(K, N).cuda()
th_c = th.matmul(a, b) th_c = th.matmul(a, b)
tr_c = triton.ops.dot(a, b) tr_c = triton.ops.dot(a, b)
print(tr_c)
run_torch() run_torch()

View File

@@ -2,21 +2,10 @@ import sys
import os import os
import libtriton import libtriton
torch_id = 'torch'
tensorflow_id = 'tensorflow'
torch = None torch = None
tensorflow = None tensorflow = None
tf_extra_ops = None tf_extra_ops = None
def to_str(framework):
if framework == tensorflow_id:
return 'tensorflow'
elif framework == torch_id:
return 'torch'
else:
assert False
def _import_torch(): def _import_torch():
global torch global torch
if torch is None: if torch is None:
@@ -35,19 +24,8 @@ def _import_tf_extra_ops():
_import_tensorflow() _import_tensorflow()
tf_extra_ops = tensorflow.load_op_library(path) tf_extra_ops = tensorflow.load_op_library(path)
def has_tensorflow():
return 'tensorflow' in sys.modules
def _find_framework(default = None): def has_torch():
is_tf_imported = 'tensorflow' in sys.modules return 'torch' in sys.modules
is_torch_imported = 'torch' in sys.modules
if default:
if default not in [tensorflow_id, torch_id]:
raise ValueError('unsupported framework')
else:
return default
elif is_tf_imported and not is_torch_imported:
return tensorflow_id
elif is_torch_imported and not is_tf_imported:
return torch_id
else:
raise ValueError('cannot determine imported framework, '
'please provide framework argument')

View File

@@ -14,7 +14,6 @@ class function_meta(type):
def __init__(cls, name, bases, attrs): def __init__(cls, name, bases, attrs):
cls.contexts = dict() cls.contexts = dict()
cls.registered = False cls.registered = False
cls.framework = None
return super(function_meta, cls).__init__(name, bases, attrs) return super(function_meta, cls).__init__(name, bases, attrs)
class function(metaclass = function_meta): class function(metaclass = function_meta):
@@ -59,8 +58,9 @@ class function(metaclass = function_meta):
@classmethod @classmethod
def apply(cls, *args, **kwargs): def apply(cls, *args, **kwargs):
cls.framework = fw._find_framework(cls.framework) if fw.has_tensorflow():
if cls.framework == fw.tensorflow_id:
return cls.apply_tensorflow(*args, **kwargs) return cls.apply_tensorflow(*args, **kwargs)
else: elif fw.has_torch():
return cls.apply_torch(*args, **kwargs) return cls.apply_torch(*args, **kwargs)
else:
assert False

View File

@@ -14,10 +14,10 @@ import triton.frameworks as fw
import triton.utils import triton.utils
import libtriton import libtriton
def _make_framework_src(src, out, grid, framework): def _make_framework_src(src, out, grid):
if framework == fw.tensorflow_id: if fw.has_tensorflow():
return libtriton.make_tensorflow_src(src, out, grid) return libtriton.make_tensorflow_src(src, out, grid)
elif framework == fw.torch_id: elif fw.has_torch:
return libtriton.make_torch_src(src, out, grid) return libtriton.make_torch_src(src, out, grid)
else: else:
assert False assert False
@@ -32,10 +32,16 @@ def _make_cache_path(src):
os.makedirs(cachepath) os.makedirs(cachepath)
return cachepath return cachepath
def _write_bindings(src, root, framework): def _write_bindings(src, root):
cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework)) if fw.has_tensorflow():
name = 'tensorflow'
elif fw.has_torch():
name = 'torch'
else:
assert False
cpp = os.path.join(root, '{name}.cpp'.format(name=name))
suffix = sysconfig.get_config_var('EXT_SUFFIX') suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix)) so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
recompile = False recompile = False
# recompile if .so does not exist # recompile if .so does not exist
if not os.path.exists(cpp) or not os.path.exists(so): if not os.path.exists(cpp) or not os.path.exists(so):
@@ -50,7 +56,7 @@ def _write_bindings(src, root, framework):
# return path of cpp file # return path of cpp file
return (cpp, so) return (cpp, so)
def _build(src, path, framework): def _build(src, path):
# include directories # include directories
triton_include_dirs = ['/home/philippe/development/triton/include'] triton_include_dirs = ['/home/philippe/development/triton/include']
include_dirs = triton_include_dirs include_dirs = triton_include_dirs
@@ -61,14 +67,15 @@ def _build(src, path, framework):
libraries = ['triton'] libraries = ['triton']
# add framework # add framework
extra_compile_args = [] extra_compile_args = []
if framework == fw.tensorflow_id: if fw.has_tensorflow():
library_dirs += [fw.tensorflow.sysconfig.get_lib()] library_dirs += [fw.tensorflow.sysconfig.get_lib()]
include_dirs += [fw.tensorflow.sysconfig.get_include()] include_dirs += [fw.tensorflow.sysconfig.get_include()]
include_dirs += ['/usr/local/cuda/include/'] include_dirs += ['/usr/local/cuda/include/']
libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')] libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0 abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
elif framework == fw.torch_id: name = 'tensorflow'
elif fw.has_torch():
prefix = os.path.dirname(fw.torch.__file__) prefix = os.path.dirname(fw.torch.__file__)
library_dirs += [os.path.join(prefix, 'lib')] library_dirs += [os.path.join(prefix, 'lib')]
include_dirs += ['/usr/local/cuda/include/', include_dirs += ['/usr/local/cuda/include/',
@@ -79,6 +86,7 @@ def _build(src, path, framework):
libraries += ['torch'] libraries += ['torch']
abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
name = 'torch'
else: else:
assert False assert False
# extra arguments # extra arguments
@@ -87,7 +95,7 @@ def _build(src, path, framework):
depends = [os.path.realpath(libtriton.__file__)] depends = [os.path.realpath(libtriton.__file__)]
# create extension module # create extension module
ext = setuptools.Extension( ext = setuptools.Extension(
name = fw.to_str(framework), name = name,
language = 'c++', language = 'c++',
sources = [src], sources = [src],
include_dirs = include_dirs, include_dirs = include_dirs,
@@ -104,19 +112,19 @@ def _build(src, path, framework):
args.append('--build-lib=' + path) args.append('--build-lib=' + path)
args.append('-q') args.append('-q')
args = dict( args = dict(
name = 'tensorflow', name = name,
ext_modules = [ext], ext_modules = [ext],
script_args = args, script_args = args,
) )
setuptools.setup(**args) setuptools.setup(**args)
shutil.rmtree(tmp) shutil.rmtree(tmp)
def _cvt_to_def_str(obj, framework): def _cvt_to_def_str(obj):
# bool # bool
if isinstance(obj, bool): if isinstance(obj, bool):
return str(int(obj)) return str(int(obj))
# tensorflow type # tensorflow type
if framework == fw.tensorflow_id: if fw.has_tensorflow():
if isinstance(obj, fw.tensorflow.DType): if isinstance(obj, fw.tensorflow.DType):
return {fw.tensorflow.int8: 'char', return {fw.tensorflow.int8: 'char',
fw.tensorflow.int16: 'short', fw.tensorflow.int16: 'short',
@@ -126,7 +134,7 @@ def _cvt_to_def_str(obj, framework):
fw.tensorflow.float32: 'float', fw.tensorflow.float32: 'float',
fw.tensorflow.float64: 'double'}[obj] fw.tensorflow.float64: 'double'}[obj]
# torch type # torch type
elif framework == fw.torch_id: elif fw.has_torch():
if isinstance(obj, fw.torch.dtype): if isinstance(obj, fw.torch.dtype):
return {fw.torch.int8: 'char', return {fw.torch.int8: 'char',
fw.torch.int16: 'short', fw.torch.int16: 'short',
@@ -141,14 +149,14 @@ def _cvt_to_def_str(obj, framework):
return str(obj) return str(obj)
def _make_framework_op(src, outputs, options, framework): def _make_framework_op(src, outputs, options):
src, name = _make_framework_src(src, outputs, options, framework) src, name = _make_framework_src(src, outputs, options)
cache_path = _make_cache_path(src) cache_path = _make_cache_path(src)
cpp, so = _write_bindings(src, cache_path, framework) cpp, so = _write_bindings(src, cache_path)
_build(cpp, cache_path, framework) _build(cpp, cache_path)
if framework == fw.tensorflow_id: if fw.has_tensorflow():
return fw.tensorflow.load_op_library(so).__dict__[name] return fw.tensorflow.load_op_library(so).__dict__[name]
elif framework == fw.torch_id: elif fw.has_torch():
fw.torch.ops.load_library(so) fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name) return getattr(fw.torch.ops.triton, name)
else: else:
@@ -168,22 +176,18 @@ def _make_grid(args) :
class kernel: class kernel:
def __init__(self, src, outputs, framework = None): def __init__(self, src, outputs):
self.fw_id = dict() self.fw_id = dict()
self.fw_grids = dict() self.fw_grids = dict()
self.fw_op = None self.fw_op = None
self.src = src self.src = src
self.outputs = outputs self.outputs = outputs
self.framework = framework
def _init_framework(self): def _init_framework(self):
if self.framework is not None: if fw.has_tensorflow():
return
self.framework = fw._find_framework(self.framework)
if self.framework == fw.tensorflow_id:
fw._import_tensorflow() fw._import_tensorflow()
fw._import_tf_extra_ops() fw._import_tf_extra_ops()
elif self.framework == fw.torch_id: elif fw.has_torch():
fw._import_torch() fw._import_torch()
else: else:
assert False assert False
@@ -196,7 +200,7 @@ class kernel:
# code generation options # code generation options
defines = [] defines = []
for k, v in kwargs.items(): for k, v in kwargs.items():
cvt = lambda x: _cvt_to_def_str(x, self.framework) cvt = lambda x: _cvt_to_def_str(x)
if(isinstance(v, list)): if(isinstance(v, list)):
values = list(map(cvt, v)) values = list(map(cvt, v))
else: else:
@@ -211,7 +215,7 @@ class kernel:
# register function # register function
libtriton.register_fn(op_id, self.src, opt) libtriton.register_fn(op_id, self.src, opt)
if self.fw_op is None: if self.fw_op is None:
self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework) self.fw_op = _make_framework_op(self.src, self.outputs, opt)
# retrieve framework op # retrieve framework op
op_id = self.fw_id[key] op_id = self.fw_id[key]
@@ -220,9 +224,9 @@ class kernel:
# create operands # create operands
op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]] op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
# call framework function # call framework function
if self.framework == fw.tensorflow_id: if fw.has_tensorflow():
return self.fw_op(*op_args, id=op_id) return self.fw_op(*op_args, id=op_id)
elif self.framework == fw.torch_id: elif fw.has_torch():
return self.fw_op(op_id, *op_args) return self.fw_op(op_id, *op_args)
else: else:
assert False assert False

View File

@@ -4,14 +4,13 @@ import libtriton
def cdiv(a, b): def cdiv(a, b):
return -(-a // b) return -(-a // b)
def empty(shapes, dtype, framework = None): def empty(shapes, dtype):
framework = fw._find_framework(framework) if fw.has_tensorflow():
if framework == fw.tensorflow_id:
fw._import_tensorflow() fw._import_tensorflow()
args = [x.handle if isinstance(x, scalar) else x for x in shapes] args = [x.handle if isinstance(x, scalar) else x for x in shapes]
args = fw.tensorflow.stack(args) args = fw.tensorflow.stack(args)
return fw.tf_extra_ops.alloc_empty(args, T = dtype) return fw.tf_extra_ops.alloc_empty(args, T = dtype)
elif framework == fw.torch_id: elif fw.has_torch():
fw._import_torch() fw._import_torch()
return fw.torch.empty(*shapes).cuda() return fw.torch.empty(*shapes).cuda()
@@ -23,18 +22,19 @@ class lazy_shape:
def __getitem__(self, key): def __getitem__(self, key):
return scalar(self.shape[key]) return scalar(self.shape[key])
def shape(A, framework = None) : def shape(A) :
framework = fw._find_framework(framework) if fw.has_tensorflow():
if framework == fw.tensorflow_id:
fw._import_tensorflow() fw._import_tensorflow()
return lazy_shape(fw.tensorflow.shape(A)) return lazy_shape(fw.tensorflow.shape(A))
else: elif fw.has_torch():
return A.shape return A.shape
else:
assert False
class scalar: class scalar:
def __init__(self, x, framework = None): def __init__(self, x):
self.id = libtriton.make_scalar_id() self.id = libtriton.make_scalar_id()
fw._import_tf_extra_ops() fw._import_tf_extra_ops()
self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id) self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)