From 51cefc933b2054e531e017446c127007c0e9361b Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Thu, 30 Aug 2018 15:32:55 -0700 Subject: [PATCH] make load_variables compatible with old list format (#71) * make load_variables compatible with old list format * cosmetic fixes --- baselines/common/tf_util.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index 3da9441..92dde9a 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -312,11 +312,15 @@ def get_available_gpus(): # ================================================================ def load_state(fname, sess=None): + from baselines import logger + logger.warn('load_state method is deprecated, please use load_variables instead') sess = sess or get_session() saver = tf.train.Saver() saver.restore(tf.get_default_session(), fname) def save_state(fname, sess=None): + from baselines import logger + logger.warn('save_state method is deprecated, please use save_variables instead') sess = sess or get_session() os.makedirs(os.path.dirname(fname), exist_ok=True) saver = tf.train.Saver() @@ -339,11 +343,16 @@ def load_variables(load_path, variables=None, sess=None): variables = variables or tf.trainable_variables() loaded_params = joblib.load(os.path.expanduser(load_path)) - restores = [] - for v in variables: - restores.append(v.assign(loaded_params[v.name])) - sess.run(restores) + restores = [] + if isinstance(loaded_params, list): + assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)' + for d, v in zip(loaded_params, variables): + restores.append(v.assign(d)) + else: + for v in variables: + restores.append(v.assign(loaded_params[v.name])) + sess.run(restores) # ================================================================ # Shape adjustment for feeding into tf placeholders