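Fix DummyVecEnv reusing buffers: return copies of the observation, reward, done, and info buffers instead of the internal buffers themselves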
This commit is contained in:
@@ -28,7 +28,8 @@ class DummyVecEnv(VecEnv):
|
|||||||
self.buf_obs[t][i] = x
|
self.buf_obs[t][i] = x
|
||||||
else:
|
else:
|
||||||
self.buf_obs[0][i] = obs_tuple
|
self.buf_obs[0][i] = obs_tuple
|
||||||
return self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos
|
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
|
||||||
|
self.buf_infos.copy())
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
for i in range(self.num_envs):
|
for i in range(self.num_envs):
|
||||||
@@ -45,6 +46,6 @@ class DummyVecEnv(VecEnv):
|
|||||||
|
|
||||||
def _obs_from_buf(self):
|
def _obs_from_buf(self):
|
||||||
if len(self.buf_obs) == 1:
|
if len(self.buf_obs) == 1:
|
||||||
return self.buf_obs[0]
|
return np.copy(self.buf_obs[0])
|
||||||
else:
|
else:
|
||||||
return tuple(self.buf_obs)
|
return tuple(np.copy(x) for x in self.buf_obs)
|
||||||
|
@@ -189,4 +189,3 @@ def train(env, nb_epochs, nb_epoch_cycles, render_eval, reward_scale, render, pa
|
|||||||
if eval_env and hasattr(eval_env, 'get_state'):
|
if eval_env and hasattr(eval_env, 'get_state'):
|
||||||
with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
|
with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
|
||||||
pickle.dump(eval_env.get_state(), f)
|
pickle.dump(eval_env.get_state(), f)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user