diff --git a/baselines/common/vec_env/vec_normalize.py b/baselines/common/vec_env/vec_normalize.py index cd80e20..f3255e9 100644 --- a/baselines/common/vec_env/vec_normalize.py +++ b/baselines/common/vec_env/vec_normalize.py @@ -26,6 +26,7 @@ class VecNormalize(VecEnvWrapper): if self.ret_rms: self.ret_rms.update(self.ret) rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew) + self.ret[news] = 0. return obs, rews, news, infos def _obfilt(self, obs): @@ -37,5 +38,6 @@ class VecNormalize(VecEnvWrapper): return obs def reset(self): + self.ret = np.zeros(self.num_envs) obs = self.venv.reset() return self._obfilt(obs) diff --git a/baselines/run.py b/baselines/run.py index 0133850..cf65099 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -53,7 +53,7 @@ _game_envs['retro'] = { def train(args, extra_args): env_type, env_id = get_env_type(args.env) - print(f'env_type: {env_type}') + print('env_type: {}'.format(env_type)) total_timesteps = int(args.num_timesteps) seed = args.seed