2021-12-08 22:14:15 +01:00
|
|
|
from typing import List, Union, Optional
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
import numpy as np
|
2019-06-22 11:30:04 -04:00
|
|
|
from copy import deepcopy
|
2019-06-21 17:29:44 -04:00
|
|
|
|
|
|
|
from gym import logger
|
2021-12-08 22:14:15 +01:00
|
|
|
from gym.logger import warn
|
2019-06-21 17:29:44 -04:00
|
|
|
from gym.vector.vector_env import VectorEnv
|
2021-12-08 21:31:41 -05:00
|
|
|
from gym.vector.utils import concatenate, iterate, create_empty_array
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
# Names exported by ``from <module> import *``: only the SyncVectorEnv class.
__all__ = ["SyncVectorEnv"]
|
2019-06-21 17:29:44 -04:00
|
|
|
|
|
|
|
|
|
|
|
class SyncVectorEnv(VectorEnv):
    """Vectorized environment that serially runs multiple environments.

    Parameters
    ----------
    env_fns : iterable of callable
        Functions that create the environments.

    observation_space : :class:`gym.spaces.Space`, optional
        Observation space of a single environment. If ``None``, then the
        observation space of the first environment is taken.

    action_space : :class:`gym.spaces.Space`, optional
        Action space of a single environment. If ``None``, then the action space
        of the first environment is taken.

    copy : bool
        If ``True``, then the :meth:`reset` and :meth:`step` methods return a
        copy of the observations.

    Raises
    ------
    RuntimeError
        If the observation space or the action space of some sub-environment
        does not match :obj:`observation_space` / :obj:`action_space` (or, by
        default, the corresponding space of the first sub-environment).

    Example
    -------

    .. code-block::

        >>> env = gym.vector.SyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v0", g=9.81),
        ...     lambda: gym.make("Pendulum-v0", g=1.62)
        ... ])
        >>> env.reset()
        array([[-0.8286432 ,  0.5597771 ,  0.90249056],
               [-0.85009176,  0.5266346 ,  0.60007906]], dtype=float32)
    """

    def __init__(self, env_fns, observation_space=None, action_space=None, copy=True):
        self.env_fns = env_fns
        # All sub-environments are created eagerly and live in this process;
        # they are stepped one after another in ``step_wait``.
        self.envs = [env_fn() for env_fn in env_fns]
        self.copy = copy
        self.metadata = self.envs[0].metadata

        # Compare against ``None`` explicitly instead of using ``space or
        # default``: spaces that define ``__len__`` (e.g. an empty Tuple/Dict
        # space) are falsy, and a user-supplied space must never be replaced.
        if observation_space is None:
            observation_space = self.envs[0].observation_space
        if action_space is None:
            action_space = self.envs[0].action_space

        super().__init__(
            # Use the materialized list: ``env_fns`` may be a one-shot
            # iterable (e.g. a generator) that has no ``len`` and is
            # already exhausted at this point.
            num_envs=len(self.envs),
            observation_space=observation_space,
            action_space=action_space,
        )

        self._check_spaces()
        # Pre-allocated batched observation buffer, refilled in place by
        # ``concatenate`` on every reset/step.
        self.observations = create_empty_array(
            self.single_observation_space, n=self.num_envs, fn=np.zeros
        )
        self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
        self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
        # Set by ``step_async`` and consumed by ``step_wait``.
        self._actions = None

    def _normalize_seeds(self, seed):
        """Expand ``seed`` into a list with one entry per sub-environment.

        ``None`` becomes ``[None] * num_envs``. An integer (Python or NumPy)
        ``s`` becomes ``[s, s + 1, ..., s + num_envs - 1]`` so that every
        sub-environment receives a distinct seed. Any other value is treated
        as a per-environment sequence and must have length ``num_envs``.
        """
        if seed is None:
            return [None for _ in range(self.num_envs)]
        if isinstance(seed, (int, np.integer)):
            # Convert to a Python int: NumPy integer scalars are not
            # instances of ``int`` and may be rejected by downstream seeding.
            base = int(seed)
            return [base + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs, (
            f"Expected {self.num_envs} seeds, got {len(seed)}."
        )
        return seed

    def seed(self, seed=None):
        """Seed every sub-environment.

        Parameters
        ----------
        seed : int, list of int or None
            ``None`` seeds each sub-environment with ``None``. An integer ``s``
            seeds sub-environment ``i`` with ``s + i``. A list must contain one
            seed per sub-environment.
        """
        super().seed(seed=seed)
        for env, single_seed in zip(self.envs, self._normalize_seeds(seed)):
            env.seed(single_seed)

    def reset_wait(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Reset every sub-environment and return the batched observations.

        Parameters
        ----------
        seed : int, list of int or None
            Same convention as :meth:`seed`. A sub-environment whose seed is
            ``None`` is reset without a ``seed`` keyword argument.
        options : dict, optional
            Forwarded unchanged to every sub-environment's ``reset`` when not
            ``None``.

        Returns
        -------
        Batched observations (a deep copy when ``self.copy`` is true).
        """
        seeds = self._normalize_seeds(seed)

        self._dones[:] = False
        observations = []
        for env, single_seed in zip(self.envs, seeds):
            # Only pass keyword arguments that were actually provided, so
            # sub-environments with a plain ``reset()`` signature still work.
            single_kwargs = {}
            if single_seed is not None:
                single_kwargs["seed"] = single_seed
            if options is not None:
                single_kwargs["options"] = options
            observations.append(env.reset(**single_kwargs))

        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )
        return deepcopy(self.observations) if self.copy else self.observations

    def step_async(self, actions):
        """Store the batched ``actions``, split into one action per sub-environment."""
        self._actions = iterate(self.action_space, actions)

    def step_wait(self):
        """Step every sub-environment with the actions from :meth:`step_async`.

        A sub-environment that reports ``done`` is reset immediately. The
        observation that ended its episode is stored in its ``info`` dict
        under ``"terminal_observation"``, and the first observation of the
        new episode is returned in its place.

        Returns
        -------
        (observations, rewards, dones, infos)
            ``observations`` is deep-copied when ``self.copy`` is true.
            ``rewards`` and ``dones`` are fresh arrays, and ``infos`` is a
            list with one dict per sub-environment.
        """
        observations, infos = [], []
        for i, (env, action) in enumerate(zip(self.envs, self._actions)):
            observation, self._rewards[i], self._dones[i], info = env.step(action)
            if self._dones[i]:
                info["terminal_observation"] = observation
                observation = env.reset()
            observations.append(observation)
            infos.append(info)

        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            # Copies so callers cannot mutate the internal buffers.
            np.copy(self._rewards),
            np.copy(self._dones),
            infos,
        )

    def close_extras(self, **kwargs):
        """Close the environments."""
        for env in self.envs:
            env.close()

    def _check_spaces(self):
        """Verify that every sub-environment shares the single-env spaces.

        Returns ``True`` when all spaces match.

        Raises
        ------
        RuntimeError
            If any sub-environment's observation or action space differs from
            ``single_observation_space`` / ``single_action_space``.
        """
        for env in self.envs:
            if not (env.observation_space == self.single_observation_space):
                raise RuntimeError(
                    "Some environments have an observation space different from "
                    f"`{self.single_observation_space}`. In order to batch observations, "
                    "the observation spaces from all environments must be equal."
                )

            if not (env.action_space == self.single_action_space):
                raise RuntimeError(
                    "Some environments have an action space different from "
                    f"`{self.single_action_space}`. In order to batch actions, the "
                    "action spaces from all environments must be equal."
                )

        return True
|