From f703776c91ccec5c6cb4fb442756cbf309f67ba0 Mon Sep 17 00:00:00 2001
From: Haiyang Chen <38243078+DylanHaiyangChen@users.noreply.github.com>
Date: Fri, 27 Sep 2019 23:39:41 +0100
Subject: [PATCH] fix a bug in acer saving and loading model (#990)

---
 baselines/acer/acer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/baselines/acer/acer.py b/baselines/acer/acer.py
index df4e0bf..62764db 100644
--- a/baselines/acer/acer.py
+++ b/baselines/acer/acer.py
@@ -6,7 +6,7 @@ from baselines import logger
 
 from baselines.common import set_global_seeds
 from baselines.common.policies import build_policy
-from baselines.common.tf_util import get_session, save_variables
+from baselines.common.tf_util import get_session, save_variables, load_variables
 from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 
 from baselines.a2c.utils import batch_to_seq, seq_to_batch
@@ -216,7 +216,8 @@ class Model(object):
 
 
         self.train = train
-        self.save = functools.partial(save_variables, sess=sess, variables=params)
+        self.save = functools.partial(save_variables, sess=sess)
+        self.load = functools.partial(load_variables, sess=sess)
         self.train_model = train_model
         self.step_model = step_model
         self._step = _step
@@ -358,6 +359,9 @@ def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=
                   total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
                   trust_region=trust_region, alpha=alpha, delta=delta)
 
+    if load_path is not None:
+        model.load(load_path)
+
     runner = Runner(env=env, model=model, nsteps=nsteps)
     if replay_ratio > 0:
         buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)