[python] more cleaning of frameworks logic
@@ -25,7 +25,14 @@ def _import_tf_extra_ops():
     tf_extra_ops = tensorflow.load_op_library(path)

 def has_tensorflow():
-    return 'tensorflow' in sys.modules
+    result = 'tensorflow' in sys.modules
+    if result:
+        _import_tensorflow()
+        _import_tf_extra_ops()
+    return result

 def has_torch():
-    return 'torch' in sys.modules
+    result = 'torch' in sys.modules
+    if result:
+        _import_torch()
+    return result
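
Note on the hunk above: after this change, has_tensorflow() and has_torch() only report a framework as active when the caller has already imported it, and they perform the lazy module binding themselves. A minimal standalone sketch of the underlying sys.modules pattern (the helper name is illustrative, not part of Triton's API):

    import sys

    def framework_active(module_name):
        # A framework counts as "active" only if user code already imported it,
        # so merely having TensorFlow or PyTorch installed changes nothing.
        return module_name in sys.modules

    if framework_active('torch'):
        import torch  # effectively free: the module is already in sys.modules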
@@ -28,7 +28,6 @@ class function(metaclass = function_meta):

     @classmethod
     def apply_torch(cls, *args, **kwargs):
-        fw._import_torch()
         class TorchFunction(fw.torch.autograd.Function):
             @staticmethod
             def forward(ctx, *targs, **tkwargs):
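
The TorchFunction built inside apply_torch follows the standard torch.autograd.Function protocol. A self-contained sketch of that protocol, unrelated to Triton's generated kernels:

    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Stash whatever the backward pass will need on the autograd context.
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return 2 * x * grad_output

    y = Square.apply(torch.randn(4, requires_grad=True))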
@@ -40,7 +39,6 @@ class function(metaclass = function_meta):

     @classmethod
     def apply_tensorflow(cls, *args, **kwargs):
-        fw._import_tensorflow()
         ctx = OpContext()
         result = cls.forward(ctx, *args, **kwargs)
         id = result.op.get_attr('id')
@@ -183,17 +183,7 @@ class kernel:
         self.src = src
         self.outputs = outputs

-    def _init_framework(self):
-        if fw.has_tensorflow():
-            fw._import_tensorflow()
-            fw._import_tf_extra_ops()
-        elif fw.has_torch():
-            fw._import_torch()
-        else:
-            assert False
-
     def __call__(self, *args, **kwargs):
-        self._init_framework()
         # create a new framework op when defines are different
         key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
         if key not in self.fw_id.keys():
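
For context, __call__ above memoizes one framework op per combination of compile-time defines by folding the kwargs into a string key. A rough standalone sketch of that caching idea (the cache dict and build_op helper are illustrative stand-ins, not Triton's internals):

    _op_cache = {}

    def build_op(defines):
        # Stand-in for the expensive step of generating a new framework op.
        return dict(defines)

    def get_op(**defines):
        # A string key derived from the defines decides whether an
        # already-built op can be reused; sorting just makes the key stable.
        key = '-'.join('{key}-{val}'.format(key=k, val=v) for k, v in sorted(defines.items()))
        if key not in _op_cache:
            _op_cache[key] = build_op(defines)
        return _op_cache[key]

    op = get_op(TILE=128, TYPE='float')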
@@ -6,12 +6,10 @@ def cdiv(a, b):

 def empty(shapes, dtype):
     if fw.has_tensorflow():
-        fw._import_tensorflow()
         args = [x.handle if isinstance(x, scalar) else x for x in shapes]
         args = fw.tensorflow.stack(args)
         return fw.tf_extra_ops.alloc_empty(args, T = dtype)
     elif fw.has_torch():
-        fw._import_torch()
         return fw.torch.empty(*shapes).cuda()

 class lazy_shape:
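
empty() above, like shape() and the scalar class below, dispatches on whichever framework the user already imported and then goes through the fw module. A stripped-down sketch of that dispatch using plain imports instead of fw (dtype handling and the alloc_empty custom op are omitted, so this is an approximation, not Triton's code):

    import sys

    def empty_dispatch(shapes):
        if 'tensorflow' in sys.modules:
            import tensorflow
            return tensorflow.zeros(shapes)   # stand-in for the alloc_empty custom op
        elif 'torch' in sys.modules:
            import torch
            return torch.empty(*shapes)       # .cuda() dropped so the sketch runs on CPU
        raise RuntimeError('import tensorflow or torch before calling empty_dispatch')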
@@ -24,7 +22,6 @@ class lazy_shape:

 def shape(A) :
     if fw.has_tensorflow():
-        fw._import_tensorflow()
         return lazy_shape(fw.tensorflow.shape(A))
     elif fw.has_torch():
         return A.shape
@@ -36,7 +33,6 @@ class scalar:

     def __init__(self, x):
         self.id = libtriton.make_scalar_id()
-        fw._import_tf_extra_ops()
         self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)
         self.assume_initialized = False