diff --git a/python/examples/dot.py b/python/examples/dot.py index 56788d422..f60397bb7 100644 --- a/python/examples/dot.py +++ b/python/examples/dot.py @@ -34,5 +34,6 @@ def run_torch(): b = th.randn(K, N).cuda() th_c = th.matmul(a, b) tr_c = triton.ops.dot(a, b) + print(tr_c) run_torch() \ No newline at end of file diff --git a/python/triton/frameworks.py b/python/triton/frameworks.py index 4d10697ad..e3524c7ac 100644 --- a/python/triton/frameworks.py +++ b/python/triton/frameworks.py @@ -2,21 +2,10 @@ import sys import os import libtriton -torch_id = 'torch' -tensorflow_id = 'tensorflow' - torch = None tensorflow = None tf_extra_ops = None -def to_str(framework): - if framework == tensorflow_id: - return 'tensorflow' - elif framework == torch_id: - return 'torch' - else: - assert False - def _import_torch(): global torch if torch is None: @@ -35,19 +24,8 @@ def _import_tf_extra_ops(): _import_tensorflow() tf_extra_ops = tensorflow.load_op_library(path) +def has_tensorflow(): + return 'tensorflow' in sys.modules -def _find_framework(default = None): - is_tf_imported = 'tensorflow' in sys.modules - is_torch_imported = 'torch' in sys.modules - if default: - if default not in [tensorflow_id, torch_id]: - raise ValueError('unsupported framework') - else: - return default - elif is_tf_imported and not is_torch_imported: - return tensorflow_id - elif is_torch_imported and not is_tf_imported: - return torch_id - else: - raise ValueError('cannot determine imported framework, ' - 'please provide framework argument') \ No newline at end of file +def has_torch(): + return 'torch' in sys.modules \ No newline at end of file diff --git a/python/triton/function.py b/python/triton/function.py index c51061652..53fc5dfb3 100644 --- a/python/triton/function.py +++ b/python/triton/function.py @@ -14,7 +14,6 @@ class function_meta(type): def __init__(cls, name, bases, attrs): cls.contexts = dict() cls.registered = False - cls.framework = None return super(function_meta, cls).__init__(name, bases, attrs) class function(metaclass = function_meta): @@ -59,8 +58,9 @@ class function(metaclass = function_meta): @classmethod def apply(cls, *args, **kwargs): - cls.framework = fw._find_framework(cls.framework) - if cls.framework == fw.tensorflow_id: + if fw.has_tensorflow(): return cls.apply_tensorflow(*args, **kwargs) - else: + elif fw.has_torch(): return cls.apply_torch(*args, **kwargs) + else: + assert False diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 2a7f2c929..554f0db1d 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -14,10 +14,10 @@ import triton.frameworks as fw import triton.utils import libtriton -def _make_framework_src(src, out, grid, framework): - if framework == fw.tensorflow_id: +def _make_framework_src(src, out, grid): + if fw.has_tensorflow(): return libtriton.make_tensorflow_src(src, out, grid) - elif framework == fw.torch_id: + elif fw.has_torch: return libtriton.make_torch_src(src, out, grid) else: assert False @@ -32,10 +32,16 @@ def _make_cache_path(src): os.makedirs(cachepath) return cachepath -def _write_bindings(src, root, framework): - cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework)) +def _write_bindings(src, root): + if fw.has_tensorflow(): + name = 'tensorflow' + elif fw.has_torch(): + name = 'torch' + else: + assert False + cpp = os.path.join(root, '{name}.cpp'.format(name=name)) suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix)) + so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix)) recompile = False # recompile if .so does not exist if not os.path.exists(cpp) or not os.path.exists(so): @@ -50,7 +56,7 @@ def _write_bindings(src, root, framework): # return path of cpp file return (cpp, so) -def _build(src, path, framework): +def _build(src, path): # include directories triton_include_dirs = ['/home/philippe/development/triton/include'] include_dirs = triton_include_dirs @@ -61,14 +67,15 @@ def _build(src, path, framework): libraries = ['triton'] # add framework extra_compile_args = [] - if framework == fw.tensorflow_id: + if fw.has_tensorflow(): library_dirs += [fw.tensorflow.sysconfig.get_lib()] include_dirs += [fw.tensorflow.sysconfig.get_include()] include_dirs += ['/usr/local/cuda/include/'] libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')] abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0 extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] - elif framework == fw.torch_id: + name = 'tensorflow' + elif fw.has_torch(): prefix = os.path.dirname(fw.torch.__file__) library_dirs += [os.path.join(prefix, 'lib')] include_dirs += ['/usr/local/cuda/include/', @@ -79,6 +86,7 @@ def _build(src, path, framework): libraries += ['torch'] abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] + name = 'torch' else: assert False # extra arguments @@ -87,7 +95,7 @@ def _build(src, path, framework): depends = [os.path.realpath(libtriton.__file__)] # create extension module ext = setuptools.Extension( - name = fw.to_str(framework), + name = name, language = 'c++', sources = [src], include_dirs = include_dirs, @@ -104,19 +112,19 @@ def _build(src, path, framework): args.append('--build-lib=' + path) args.append('-q') args = dict( - name = 'tensorflow', + name = name, ext_modules = [ext], script_args = args, ) setuptools.setup(**args) shutil.rmtree(tmp) -def _cvt_to_def_str(obj, framework): +def _cvt_to_def_str(obj): # bool if isinstance(obj, bool): return str(int(obj)) # tensorflow type - if framework == fw.tensorflow_id: + if fw.has_tensorflow(): if isinstance(obj, fw.tensorflow.DType): return {fw.tensorflow.int8: 'char', fw.tensorflow.int16: 'short', @@ -126,7 +134,7 @@ def _cvt_to_def_str(obj, framework): fw.tensorflow.float32: 'float', fw.tensorflow.float64: 'double'}[obj] # torch type - elif framework == fw.torch_id: + elif fw.has_torch(): if isinstance(obj, fw.torch.dtype): return {fw.torch.int8: 'char', fw.torch.int16: 'short', @@ -141,14 +149,14 @@ def _cvt_to_def_str(obj, framework): return str(obj) -def _make_framework_op(src, outputs, options, framework): - src, name = _make_framework_src(src, outputs, options, framework) +def _make_framework_op(src, outputs, options): + src, name = _make_framework_src(src, outputs, options) cache_path = _make_cache_path(src) - cpp, so = _write_bindings(src, cache_path, framework) - _build(cpp, cache_path, framework) - if framework == fw.tensorflow_id: + cpp, so = _write_bindings(src, cache_path) + _build(cpp, cache_path) + if fw.has_tensorflow(): return fw.tensorflow.load_op_library(so).__dict__[name] - elif framework == fw.torch_id: + elif fw.has_torch(): fw.torch.ops.load_library(so) return getattr(fw.torch.ops.triton, name) else: @@ -168,22 +176,18 @@ def _make_grid(args) : class kernel: - def __init__(self, src, outputs, framework = None): + def __init__(self, src, outputs): self.fw_id = dict() self.fw_grids = dict() self.fw_op = None self.src = src self.outputs = outputs - self.framework = framework def _init_framework(self): - if self.framework is not None: - return - self.framework = fw._find_framework(self.framework) - if self.framework == fw.tensorflow_id: + if fw.has_tensorflow(): fw._import_tensorflow() fw._import_tf_extra_ops() - elif self.framework == fw.torch_id: + elif fw.has_torch(): fw._import_torch() else: assert False @@ -196,7 +200,7 @@ class kernel: # code generation options defines = [] for k, v in kwargs.items(): - cvt = lambda x: _cvt_to_def_str(x, self.framework) + cvt = lambda x: _cvt_to_def_str(x) if(isinstance(v, list)): values = list(map(cvt, v)) else: @@ -211,7 +215,7 @@ class kernel: # register function libtriton.register_fn(op_id, self.src, opt) if self.fw_op is None: - self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework) + self.fw_op = _make_framework_op(self.src, self.outputs, opt) # retrieve framework op op_id = self.fw_id[key] @@ -220,9 +224,9 @@ class kernel: # create operands op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]] # call framework function - if self.framework == fw.tensorflow_id: + if fw.has_tensorflow(): return self.fw_op(*op_args, id=op_id) - elif self.framework == fw.torch_id: + elif fw.has_torch(): return self.fw_op(op_id, *op_args) else: assert False \ No newline at end of file diff --git a/python/triton/utils.py b/python/triton/utils.py index 422f1117b..3ef8be7b9 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -4,14 +4,13 @@ import libtriton def cdiv(a, b): return -(-a // b) -def empty(shapes, dtype, framework = None): - framework = fw._find_framework(framework) - if framework == fw.tensorflow_id: +def empty(shapes, dtype): + if fw.has_tensorflow(): fw._import_tensorflow() args = [x.handle if isinstance(x, scalar) else x for x in shapes] args = fw.tensorflow.stack(args) return fw.tf_extra_ops.alloc_empty(args, T = dtype) - elif framework == fw.torch_id: + elif fw.has_torch(): fw._import_torch() return fw.torch.empty(*shapes).cuda() @@ -23,18 +22,19 @@ class lazy_shape: def __getitem__(self, key): return scalar(self.shape[key]) -def shape(A, framework = None) : - framework = fw._find_framework(framework) - if framework == fw.tensorflow_id: +def shape(A) : + if fw.has_tensorflow(): fw._import_tensorflow() return lazy_shape(fw.tensorflow.shape(A)) - else: + elif fw.has_torch(): return A.shape + else: + assert False class scalar: - def __init__(self, x, framework = None): + def __init__(self, x): self.id = libtriton.make_scalar_id() fw._import_tf_extra_ops() self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)