fixes for DummyVecEnv
Fixes various problems running MuJoCo tasks.
This commit is contained in:
@@ -20,20 +20,19 @@ class NotSteppingError(Exception):
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
class VecEnv(ABC):
|
||||
|
||||
"""
|
||||
An abstract asynchronous, vectorized environment.
|
||||
"""
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
"""
|
||||
An abstract asynchronous, vectorized environment.
|
||||
"""
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""
|
||||
Reset all the environments and return an array of
|
||||
observations.
|
||||
observations, or a tuple of observation arrays.
|
||||
|
||||
If step_async is still doing work, that work will
|
||||
be cancelled and step_wait() should not be called
|
||||
@@ -59,10 +58,11 @@ class VecEnv(ABC):
|
||||
Wait for the step taken with step_async().
|
||||
|
||||
Returns (obs, rews, dones, infos):
|
||||
- obs: an array of observations
|
||||
- obs: an array of observations, or a tuple of
|
||||
arrays of observations.
|
||||
- rews: an array of rewards
|
||||
- dones: an array of "episode done" booleans
|
||||
- infos: an array of info objects
|
||||
- infos: a sequence of info objects
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@@ -21,14 +21,16 @@ class DummyVecEnv(VecEnv):
|
||||
def step_wait(self):
|
||||
for i in range(self.num_envs):
|
||||
obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i])
|
||||
if self.buf_dones[i]:
|
||||
obs_tuple = self.envs[i].reset()
|
||||
if isinstance(obs_tuple, (tuple, list)):
|
||||
for t,x in enumerate(obs_tuple):
|
||||
self.buf_obs[t][i] = x
|
||||
else:
|
||||
self.buf_obs[0][i] = obs_tuple
|
||||
return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos
|
||||
return self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos
|
||||
|
||||
def reset(self):
|
||||
def reset(self):
|
||||
for i in range(self.num_envs):
|
||||
obs_tuple = self.envs[i].reset()
|
||||
if isinstance(obs_tuple, (tuple, list)):
|
||||
@@ -36,7 +38,13 @@ class DummyVecEnv(VecEnv):
|
||||
self.buf_obs[t][i] = x
|
||||
else:
|
||||
self.buf_obs[0][i] = obs_tuple
|
||||
return self.buf_obs
|
||||
return self._obs_from_buf()
|
||||
|
||||
def close(self):
|
||||
return
|
||||
return
|
||||
|
||||
def _obs_from_buf(self):
|
||||
if len(self.buf_obs) == 1:
|
||||
return self.buf_obs[0]
|
||||
else:
|
||||
return tuple(self.buf_obs)
|
||||
|
Reference in New Issue
Block a user