diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index 26c33fa..c85a28e 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -330,8 +330,9 @@ def save_variables(save_path, variables, sess=None): sess = get_session() ps = sess.run(variables) + save_dict = {v.name: value for v, value in zip(variables, ps)} os.makedirs(os.path.dirname(save_path), exist_ok=True) - joblib.dump(ps, save_path) + joblib.dump(save_dict, save_path) def load_variables(load_path, variables, sess=None): if sess is None: @@ -339,8 +340,8 @@ def load_variables(load_path, variables, sess=None): loaded_params = joblib.load(load_path) restores = [] - for p, loaded_p in zip(variables, loaded_params): - restores.append(p.assign(loaded_p)) + for v in variables: + restores.append(v.assign(loaded_params[v.name])) sess.run(restores)