From 97be70d6c86846c04a8c0af74ebdd3b0c6337c08 Mon Sep 17 00:00:00 2001 From: Alex Nichol Date: Tue, 27 Feb 2018 18:55:10 -0800 Subject: [PATCH] fixes for DummyVecEnv Fixes various problems running MuJoCo tasks. --- baselines/common/vec_env/__init__.py | 14 +++++++------- baselines/common/vec_env/dummy_vec_env.py | 16 ++++++++++++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index 6211937..2af2d37 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -20,20 +20,19 @@ class NotSteppingError(Exception): Exception.__init__(self, msg) class VecEnv(ABC): - + """ + An abstract asynchronous, vectorized environment. + """ def __init__(self, num_envs, observation_space, action_space): self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space - """ - An abstract asynchronous, vectorized environment. - """ @abstractmethod def reset(self): """ Reset all the environments and return an array of - observations. + observations, or a tuple of observation arrays. If step_async is still doing work, that work will be cancelled and step_wait() should not be called @@ -59,10 +58,11 @@ class VecEnv(ABC): Wait for the step taken with step_async(). Returns (obs, rews, dones, infos): - - obs: an array of observations + - obs: an array of observations, or a tuple of + arrays of observations. - rews: an array of rewards - dones: an array of "episode done" booleans - - infos: an array of info objects + - infos: a sequence of info objects """ pass diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index a09e375..076dab5 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -21,14 +21,16 @@ class DummyVecEnv(VecEnv): def step_wait(self): for i in range(self.num_envs): obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i]) + if self.buf_dones[i]: + obs_tuple = self.envs[i].reset() if isinstance(obs_tuple, (tuple, list)): for t,x in enumerate(obs_tuple): self.buf_obs[t][i] = x else: self.buf_obs[0][i] = obs_tuple - return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos + return self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos - def reset(self): + def reset(self): for i in range(self.num_envs): obs_tuple = self.envs[i].reset() if isinstance(obs_tuple, (tuple, list)): @@ -36,7 +38,13 @@ class DummyVecEnv(VecEnv): self.buf_obs[t][i] = x else: self.buf_obs[0][i] = obs_tuple - return self.buf_obs + return self._obs_from_buf() def close(self): - return \ No newline at end of file + return + + def _obs_from_buf(self): + if len(self.buf_obs) == 1: + return self.buf_obs[0] + else: + return tuple(self.buf_obs)