From b2629da1fea213496f3bfe01ec3641121c911d01 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 5 Sep 2019 02:21:07 -0400 Subject: [PATCH] [python] more cleaning of frameworks logic --- python/triton/frameworks.py | 11 +++++++++-- python/triton/function.py | 2 -- python/triton/kernel.py | 10 ---------- python/triton/utils.py | 4 ---- 4 files changed, 9 insertions(+), 18 deletions(-) diff --git a/python/triton/frameworks.py b/python/triton/frameworks.py index e3524c7ac..fcab5dcbf 100644 --- a/python/triton/frameworks.py +++ b/python/triton/frameworks.py @@ -25,7 +25,14 @@ def _import_tf_extra_ops(): tf_extra_ops = tensorflow.load_op_library(path) def has_tensorflow(): - return 'tensorflow' in sys.modules + result = 'tensorflow' in sys.modules + if result: + _import_tensorflow() + _import_tf_extra_ops() + return result def has_torch(): - return 'torch' in sys.modules \ No newline at end of file + result = 'torch' in sys.modules + if result: + _import_torch() + return result \ No newline at end of file diff --git a/python/triton/function.py b/python/triton/function.py index 53fc5dfb3..125cad668 100644 --- a/python/triton/function.py +++ b/python/triton/function.py @@ -28,7 +28,6 @@ class function(metaclass = function_meta): @classmethod def apply_torch(cls, *args, **kwargs): - fw._import_torch() class TorchFunction(fw.torch.autograd.Function): @staticmethod def forward(ctx, *targs, **tkwargs): @@ -40,7 +39,6 @@ class function(metaclass = function_meta): @classmethod def apply_tensorflow(cls, *args, **kwargs): - fw._import_tensorflow() ctx = OpContext() result = cls.forward(ctx, *args, **kwargs) id = result.op.get_attr('id') diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 554f0db1d..355bc3675 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -183,17 +183,7 @@ class kernel: self.src = src self.outputs = outputs - def _init_framework(self): - if fw.has_tensorflow(): - fw._import_tensorflow() - fw._import_tf_extra_ops() - elif fw.has_torch(): - fw._import_torch() - else: - assert False - def __call__(self, *args, **kwargs): - self._init_framework() # create a new framework op when defines are different key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()]) if key not in self.fw_id.keys(): diff --git a/python/triton/utils.py b/python/triton/utils.py index 3ef8be7b9..6c5df7b09 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -6,12 +6,10 @@ def cdiv(a, b): def empty(shapes, dtype): if fw.has_tensorflow(): - fw._import_tensorflow() args = [x.handle if isinstance(x, scalar) else x for x in shapes] args = fw.tensorflow.stack(args) return fw.tf_extra_ops.alloc_empty(args, T = dtype) elif fw.has_torch(): - fw._import_torch() return fw.torch.empty(*shapes).cuda() class lazy_shape: @@ -24,7 +22,6 @@ class lazy_shape: def shape(A) : if fw.has_tensorflow(): - fw._import_tensorflow() return lazy_shape(fw.tensorflow.shape(A)) elif fw.has_torch(): return A.shape @@ -36,7 +33,6 @@ class scalar: def __init__(self, x): self.id = libtriton.make_scalar_id() - fw._import_tf_extra_ops() self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id) self.assume_initialized = False