fixes for DummyVecEnv
Fixes various problems running MuJoCo tasks.
This commit is contained in:
@@ -20,20 +20,19 @@ class NotSteppingError(Exception):
|
|||||||
Exception.__init__(self, msg)
|
Exception.__init__(self, msg)
|
||||||
|
|
||||||
class VecEnv(ABC):
|
class VecEnv(ABC):
|
||||||
|
"""
|
||||||
|
An abstract asynchronous, vectorized environment.
|
||||||
|
"""
|
||||||
def __init__(self, num_envs, observation_space, action_space):
|
def __init__(self, num_envs, observation_space, action_space):
|
||||||
self.num_envs = num_envs
|
self.num_envs = num_envs
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
|
|
||||||
"""
|
|
||||||
An abstract asynchronous, vectorized environment.
|
|
||||||
"""
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
Reset all the environments and return an array of
|
Reset all the environments and return an array of
|
||||||
observations.
|
observations, or a tuple of observation arrays.
|
||||||
|
|
||||||
If step_async is still doing work, that work will
|
If step_async is still doing work, that work will
|
||||||
be cancelled and step_wait() should not be called
|
be cancelled and step_wait() should not be called
|
||||||
@@ -59,10 +58,11 @@ class VecEnv(ABC):
|
|||||||
Wait for the step taken with step_async().
|
Wait for the step taken with step_async().
|
||||||
|
|
||||||
Returns (obs, rews, dones, infos):
|
Returns (obs, rews, dones, infos):
|
||||||
- obs: an array of observations
|
- obs: an array of observations, or a tuple of
|
||||||
|
arrays of observations.
|
||||||
- rews: an array of rewards
|
- rews: an array of rewards
|
||||||
- dones: an array of "episode done" booleans
|
- dones: an array of "episode done" booleans
|
||||||
- infos: an array of info objects
|
- infos: a sequence of info objects
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@@ -21,14 +21,16 @@ class DummyVecEnv(VecEnv):
|
|||||||
def step_wait(self):
|
def step_wait(self):
|
||||||
for i in range(self.num_envs):
|
for i in range(self.num_envs):
|
||||||
obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i])
|
obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i])
|
||||||
|
if self.buf_dones[i]:
|
||||||
|
obs_tuple = self.envs[i].reset()
|
||||||
if isinstance(obs_tuple, (tuple, list)):
|
if isinstance(obs_tuple, (tuple, list)):
|
||||||
for t,x in enumerate(obs_tuple):
|
for t,x in enumerate(obs_tuple):
|
||||||
self.buf_obs[t][i] = x
|
self.buf_obs[t][i] = x
|
||||||
else:
|
else:
|
||||||
self.buf_obs[0][i] = obs_tuple
|
self.buf_obs[0][i] = obs_tuple
|
||||||
return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos
|
return self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
for i in range(self.num_envs):
|
for i in range(self.num_envs):
|
||||||
obs_tuple = self.envs[i].reset()
|
obs_tuple = self.envs[i].reset()
|
||||||
if isinstance(obs_tuple, (tuple, list)):
|
if isinstance(obs_tuple, (tuple, list)):
|
||||||
@@ -36,7 +38,13 @@ class DummyVecEnv(VecEnv):
|
|||||||
self.buf_obs[t][i] = x
|
self.buf_obs[t][i] = x
|
||||||
else:
|
else:
|
||||||
self.buf_obs[0][i] = obs_tuple
|
self.buf_obs[0][i] = obs_tuple
|
||||||
return self.buf_obs
|
return self._obs_from_buf()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _obs_from_buf(self):
|
||||||
|
if len(self.buf_obs) == 1:
|
||||||
|
return self.buf_obs[0]
|
||||||
|
else:
|
||||||
|
return tuple(self.buf_obs)
|
||||||
|
Reference in New Issue
Block a user