[python] more cleaning of frameworks logic

This commit is contained in:
Philippe Tillet
2019-09-05 02:21:07 -04:00
parent 44896ee777
commit b2629da1fe
4 changed files with 9 additions and 18 deletions

View File

@@ -25,7 +25,14 @@ def _import_tf_extra_ops():
tf_extra_ops = tensorflow.load_op_library(path)
def has_tensorflow():
return 'tensorflow' in sys.modules
result = 'tensorflow' in sys.modules
if result:
_import_tensorflow()
_import_tf_extra_ops()
return result
def has_torch():
return 'torch' in sys.modules
result = 'torch' in sys.modules
if result:
_import_torch()
return result

View File

@@ -28,7 +28,6 @@ class function(metaclass = function_meta):
@classmethod
def apply_torch(cls, *args, **kwargs):
fw._import_torch()
class TorchFunction(fw.torch.autograd.Function):
@staticmethod
def forward(ctx, *targs, **tkwargs):
@@ -40,7 +39,6 @@ class function(metaclass = function_meta):
@classmethod
def apply_tensorflow(cls, *args, **kwargs):
fw._import_tensorflow()
ctx = OpContext()
result = cls.forward(ctx, *args, **kwargs)
id = result.op.get_attr('id')

View File

@@ -183,17 +183,7 @@ class kernel:
self.src = src
self.outputs = outputs
def _init_framework(self):
if fw.has_tensorflow():
fw._import_tensorflow()
fw._import_tf_extra_ops()
elif fw.has_torch():
fw._import_torch()
else:
assert False
def __call__(self, *args, **kwargs):
self._init_framework()
# create a new framework op when defines are different
key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
if key not in self.fw_id.keys():

View File

@@ -6,12 +6,10 @@ def cdiv(a, b):
def empty(shapes, dtype):
if fw.has_tensorflow():
fw._import_tensorflow()
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
args = fw.tensorflow.stack(args)
return fw.tf_extra_ops.alloc_empty(args, T = dtype)
elif fw.has_torch():
fw._import_torch()
return fw.torch.empty(*shapes).cuda()
class lazy_shape:
@@ -24,7 +22,6 @@ class lazy_shape:
def shape(A) :
if fw.has_tensorflow():
fw._import_tensorflow()
return lazy_shape(fw.tensorflow.shape(A))
elif fw.has_torch():
return A.shape
@@ -36,7 +33,6 @@ class scalar:
def __init__(self, x):
self.id = libtriton.make_scalar_id()
fw._import_tf_extra_ops()
self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)
self.assume_initialized = False