From ad219e205dc7358d93352452fd7a6e190823fd56 Mon Sep 17 00:00:00 2001 From: Isaac Lascasas Date: Thu, 6 Sep 2018 19:21:50 +0200 Subject: [PATCH 1/2] VecNormalize: set env. returns to zero on resets. (#556) * VecNormalize: set env. returns to zero on resets. * VecNormalize: returns reset in step_wait after ret_rms.update. --- baselines/common/vec_env/vec_normalize.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/baselines/common/vec_env/vec_normalize.py b/baselines/common/vec_env/vec_normalize.py index cd80e20..f3255e9 100644 --- a/baselines/common/vec_env/vec_normalize.py +++ b/baselines/common/vec_env/vec_normalize.py @@ -26,6 +26,7 @@ class VecNormalize(VecEnvWrapper): if self.ret_rms: self.ret_rms.update(self.ret) rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew) + self.ret[news] = 0. return obs, rews, news, infos def _obfilt(self, obs): @@ -37,5 +38,6 @@ class VecNormalize(VecEnvWrapper): return obs def reset(self): + self.ret = np.zeros(self.num_envs) obs = self.venv.reset() return self._obfilt(obs) From 35e95ee85a394e824868fd362cec3c888a8d6cfc Mon Sep 17 00:00:00 2001 From: Peter Zhokhov Date: Thu, 6 Sep 2018 12:00:19 -0700 Subject: [PATCH 2/2] fix python 3.5 string format compatibility --- baselines/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baselines/run.py b/baselines/run.py index a4bdde2..ca3a1a5 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -53,7 +53,7 @@ _game_envs['retro'] = { def train(args, extra_args): env_type, env_id = get_env_type(args.env) - print(f'env_type: {env_type}') + print('env_type: {}'.format(env_type)) total_timesteps = int(args.num_timesteps) seed = args.seed