Files
triton/python/triton/frameworks.py
2019-09-04 21:55:47 -04:00

46 lines
1.1 KiB
Python

import sys
import os
import libtriton
# String identifiers naming the frameworks this module knows about.
tensorflow_id = 'tensorflow'
torch_id = 'torch'

# Lazily-bound handles: each starts as None and is filled in by the
# corresponding _import_* helper defined in this module.
torch = tensorflow = tf_extra_ops = None
def _import_torch():
    """Bind the module-level ``torch`` name to the imported package.

    The import is deferred until this helper is called. If ``torch`` is
    already something other than ``None``, the call does nothing.
    """
    global torch
    if torch is not None:
        return
    # Because of the ``global`` declaration above, this import rebinds the
    # module-level name rather than creating a function-local one.
    import torch
def _import_tensorflow():
    """Bind the module-level ``tensorflow`` name to the imported package.

    The import is deferred until this helper is called. If ``tensorflow``
    is already something other than ``None``, the call does nothing.
    """
    global tensorflow
    if tensorflow is not None:
        return
    # Because of the ``global`` declaration above, this import rebinds the
    # module-level name rather than creating a function-local one.
    import tensorflow
def _import_tf_extra_ops():
    """Load the extra TensorFlow ops library into ``tf_extra_ops``.

    The shared object ``libextra_tf_ops.so`` is looked up in the directory
    containing the ``libtriton`` module file and loaded through
    ``tensorflow.load_op_library``. If ``tf_extra_ops`` is already set,
    the call does nothing.
    """
    global tf_extra_ops
    if tf_extra_ops is not None:
        return
    lib_dir = os.path.dirname(libtriton.__file__)
    so_path = os.path.join(lib_dir, 'libextra_tf_ops.so')
    # The module-level ``tensorflow`` handle must be bound before use.
    _import_tensorflow()
    tf_extra_ops = tensorflow.load_op_library(so_path)
def _find_framework(default=None):
    """Return the identifier of the framework to use.

    A truthy ``default`` is validated against the supported identifiers
    and returned unchanged. Otherwise the framework is inferred from
    ``sys.modules``, which only works when exactly one of ``tensorflow``
    and ``torch`` has already been imported.

    Args:
        default: Optional framework identifier; falsy values trigger
            auto-detection.

    Returns:
        ``tensorflow_id`` or ``torch_id``.

    Raises:
        ValueError: If ``default`` is truthy but not a supported
            identifier, or if auto-detection finds both frameworks or
            neither framework imported.
    """
    if default:
        if default in (tensorflow_id, torch_id):
            return default
        raise ValueError('unsupported framework')

    tf_loaded = 'tensorflow' in sys.modules
    torch_loaded = 'torch' in sys.modules
    # Detection is unambiguous only when exactly one framework is loaded.
    if tf_loaded != torch_loaded:
        return tensorflow_id if tf_loaded else torch_id
    raise ValueError('cannot determine imported framework, '
                     'please provide framework argument')