Merge pull request #39 from mirceamironenco/master
Fix TF graph variables deprecation
This commit is contained in:
@@ -526,7 +526,7 @@ class Module(object):
|
|||||||
@property
|
@property
|
||||||
def variables(self):
|
def variables(self):
|
||||||
assert self.scope is not None, "need to call module once before getting variables"
|
assert self.scope is not None, "need to call module once before getting variables"
|
||||||
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
|
return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
|
||||||
|
|
||||||
|
|
||||||
def module(name):
|
def module(name):
|
||||||
@@ -681,7 +681,7 @@ def scope_vars(scope, trainable_only=False):
|
|||||||
list of variables in `scope`.
|
list of variables in `scope`.
|
||||||
"""
|
"""
|
||||||
return tf.get_collection(
|
return tf.get_collection(
|
||||||
tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.VARIABLES,
|
tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.GLOBAL_VARIABLES,
|
||||||
scope=scope if isinstance(scope, str) else scope.name
|
scope=scope if isinstance(scope, str) else scope.name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user