From 2a93ea8782ce965fa2214e65e22efa2cd4fac8e1 Mon Sep 17 00:00:00 2001 From: Peter Zhokhov Date: Tue, 31 Jul 2018 11:13:31 -0700 Subject: [PATCH] serialize variables as a dict, not as a list --- baselines/common/tf_util.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index 26c33fa..c85a28e 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -330,8 +330,9 @@ def save_variables(save_path, variables, sess=None): sess = get_session() ps = sess.run(variables) + save_dict = {v.name: value for v, value in zip(variables, ps)} os.makedirs(os.path.dirname(save_path), exist_ok=True) - joblib.dump(ps, save_path) + joblib.dump(save_dict, save_path) def load_variables(load_path, variables, sess=None): if sess is None: @@ -339,8 +340,8 @@ def load_variables(load_path, variables, sess=None): loaded_params = joblib.load(load_path) restores = [] - for p, loaded_p in zip(variables, loaded_params): - restores.append(p.assign(loaded_p)) + for v in variables: + restores.append(v.assign(loaded_params[v.name])) sess.run(restores)