fixes for DummyVecEnv

Fixes various problems running MuJoCo tasks.
This commit is contained in:
Alex Nichol
2018-02-27 18:55:10 -08:00
parent b71152eea0
commit 97be70d6c8
2 changed files with 19 additions and 11 deletions

View File

@@ -20,20 +20,19 @@ class NotSteppingError(Exception):
Exception.__init__(self, msg)
class VecEnv(ABC):
"""
An abstract asynchronous, vectorized environment.
"""
def __init__(self, num_envs, observation_space, action_space):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
"""
An abstract asynchronous, vectorized environment.
"""
@abstractmethod
def reset(self):
"""
Reset all the environments and return an array of
observations.
observations, or a tuple of observation arrays.
If step_async is still doing work, that work will
be cancelled and step_wait() should not be called
@@ -59,10 +58,11 @@ class VecEnv(ABC):
Wait for the step taken with step_async().
Returns (obs, rews, dones, infos):
- obs: an array of observations
- obs: an array of observations, or a tuple of
arrays of observations.
- rews: an array of rewards
- dones: an array of "episode done" booleans
- infos: an array of info objects
- infos: a sequence of info objects
"""
pass

View File

@@ -21,14 +21,16 @@ class DummyVecEnv(VecEnv):
def step_wait(self):
for i in range(self.num_envs):
obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i])
if self.buf_dones[i]:
obs_tuple = self.envs[i].reset()
if isinstance(obs_tuple, (tuple, list)):
for t,x in enumerate(obs_tuple):
self.buf_obs[t][i] = x
else:
self.buf_obs[0][i] = obs_tuple
return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos
return self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos
def reset(self):
def reset(self):
for i in range(self.num_envs):
obs_tuple = self.envs[i].reset()
if isinstance(obs_tuple, (tuple, list)):
@@ -36,7 +38,13 @@ class DummyVecEnv(VecEnv):
self.buf_obs[t][i] = x
else:
self.buf_obs[0][i] = obs_tuple
return self.buf_obs
return self._obs_from_buf()
def close(self):
return
return
def _obs_from_buf(self):
if len(self.buf_obs) == 1:
return self.buf_obs[0]
else:
return tuple(self.buf_obs)