From 8b3a6c20519a45a1387c7d23b98b55d5961f3403 Mon Sep 17 00:00:00 2001 From: Alex Nichol Date: Fri, 2 Mar 2018 17:18:07 -0800 Subject: [PATCH] fix DummyVecEnv reusing buffers --- baselines/common/vec_env/dummy_vec_env.py | 7 ++++--- baselines/ddpg/training.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index 076dab5..edabf25 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -28,7 +28,8 @@ class DummyVecEnv(VecEnv): self.buf_obs[t][i] = x else: self.buf_obs[0][i] = obs_tuple - return self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos + return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), + self.buf_infos.copy()) def reset(self): for i in range(self.num_envs): @@ -45,6 +46,6 @@ class DummyVecEnv(VecEnv): def _obs_from_buf(self): if len(self.buf_obs) == 1: - return self.buf_obs[0] + return np.copy(self.buf_obs[0]) else: - return tuple(self.buf_obs) + return tuple(np.copy(x) for x in self.buf_obs) diff --git a/baselines/ddpg/training.py b/baselines/ddpg/training.py index 35388a2..74a9b8f 100644 --- a/baselines/ddpg/training.py +++ b/baselines/ddpg/training.py @@ -189,4 +189,3 @@ def train(env, nb_epochs, nb_epoch_cycles, render_eval, reward_scale, render, pa if eval_env and hasattr(eval_env, 'get_state'): with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f: pickle.dump(eval_env.get_state(), f) -