Files
Gymnasium/gym/vector/sync_vector_env.py
Mark Towers 2ede09074f Full type hinting (#2942)
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9117.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Allow Box automatic scalar shape

* Add test box and change default from () to (1,)

* update Box shape inference with more strict checking

* Update the box shape and add check on the custom Box shape

* Removed incorrect shape type and assert shape code

* Update the Box and associated tests

* Remove all folders and files from pyright exclude

* Revert issues

* Push RedTachyon code review

* Add Python Platform

* Remove play from pyright check

* Fixed CI issues

* remove mujoco env type hinting

* Fixed pixel observation test

* Added some new type hints

* Fixed CI errors

* Fixed CI errors

* Remove play.py from exclude pyright

* Fixed pyright issues
2022-07-04 13:19:25 -04:00

237 lines
8.6 KiB
Python

"""A synchronous vector environment."""
from copy import deepcopy
from typing import Any, Callable, Iterator, List, Optional, Sequence, Union
import numpy as np
from gym import Env
from gym.spaces import Space
from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv
__all__ = ["SyncVectorEnv"]
class SyncVectorEnv(VectorEnv):
    """Vectorized environment that serially runs multiple environments.

    Example::

        >>> import gym
        >>> env = gym.vector.SyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v0", g=9.81),
        ...     lambda: gym.make("Pendulum-v0", g=1.62)
        ... ])
        >>> env.reset()
        array([[-0.8286432 ,  0.5597771 ,  0.90249056],
               [-0.85009176,  0.5266346 ,  0.60007906]], dtype=float32)
    """

    def __init__(
        self,
        env_fns: Iterator[Callable[[], Env]],
        observation_space: Optional[Space] = None,
        action_space: Optional[Space] = None,
        copy: bool = True,
    ):
        """Vectorized environment that serially runs multiple environments.

        Args:
            env_fns: iterable of callable functions that create the environments.
                At least one is required, because the first environment supplies
                the metadata (and the default spaces).
            observation_space: Observation space of a single environment. If ``None``,
                then the observation space of the first environment is taken.
            action_space: Action space of a single environment. If ``None``,
                then the action space of the first environment is taken.
            copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.

        Raises:
            RuntimeError: If the observation space or the action space of some
                sub-environment does not match ``observation_space`` / ``action_space``
                (or, by default, the spaces of the first sub-environment).
        """
        # A one-shot iterator (e.g. a generator) would be exhausted by the list
        # comprehension below, leaving ``self.env_fns`` useless afterwards.
        # Materialise it once; lists/tuples are stored unchanged.
        if not isinstance(env_fns, Sequence):
            env_fns = list(env_fns)

        self.env_fns = env_fns
        self.envs = [env_fn() for env_fn in env_fns]
        self.copy = copy
        self.metadata = self.envs[0].metadata

        # Compare against ``None`` explicitly: relying on truthiness would
        # silently discard a user-supplied space that evaluates as falsy
        # (for instance a container space with no sub-spaces).
        if observation_space is None:
            observation_space = self.envs[0].observation_space
        if action_space is None:
            action_space = self.envs[0].action_space

        super().__init__(
            num_envs=len(self.envs),
            observation_space=observation_space,
            action_space=action_space,
        )

        self._check_spaces()

        # Pre-allocated buffers that are refilled on every reset/step.
        self.observations = create_empty_array(
            self.single_observation_space, n=self.num_envs, fn=np.zeros
        )
        self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
        self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
        self._actions = None

    def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
        """Sets the seed in all sub-environments.

        Args:
            seed: The seed. ``None`` passes ``None`` to every sub-environment,
                an ``int`` gives sub-environment ``i`` the seed ``seed + i``, and
                a sequence supplies one seed per sub-environment.
        """
        super().seed(seed=seed)

        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert (
            len(seed) == self.num_envs
        ), f"Expected {self.num_envs} seeds, got {len(seed)}."

        for env, single_seed in zip(self.envs, seed):
            env.seed(single_seed)

    def reset_wait(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.

        Args:
            seed: The reset environment seed. An ``int`` gives sub-environment ``i``
                the seed ``seed + i``; a list supplies one seed per sub-environment.
            return_info: If to return information
            options: Option information for the environment reset; the same object
                is forwarded to every sub-environment.

        Returns:
            The reset observation of the environment and, if ``return_info`` is
            set, the reset information as well.
        """
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert (
            len(seed) == self.num_envs
        ), f"Expected {self.num_envs} seeds, got {len(seed)}."

        self._dones[:] = False
        observations = []
        infos = {}
        for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
            # Only forward the arguments that were actually provided.
            kwargs = {}
            if single_seed is not None:
                kwargs["seed"] = single_seed
            if options is not None:
                kwargs["options"] = options

            # Use one truthiness test both for forwarding the flag and for
            # unpacking the result, so the two can never disagree.
            if return_info:
                kwargs["return_info"] = True
                observation, info = env.reset(**kwargs)
                infos = self._add_info(infos, info, i)
            else:
                observation = env.reset(**kwargs)
            observations.append(observation)

        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )

        batched = deepcopy(self.observations) if self.copy else self.observations
        if return_info:
            return batched, infos
        return batched

    def step_async(self, actions):
        """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
        self._actions = iterate(self.action_space, actions)

    def step_wait(self):
        """Steps through each of the environments returning the batched results.

        A sub-environment that reports done is reset immediately: its final
        observation is stored under ``info["terminal_observation"]`` and the
        observation returned for it is the one produced by the reset.

        Returns:
            The batched environment step results
        """
        observations, infos = [], {}
        for i, (env, action) in enumerate(zip(self.envs, self._actions)):
            observation, self._rewards[i], self._dones[i], info = env.step(action)
            if self._dones[i]:
                info["terminal_observation"] = observation
                observation = env.reset()
            observations.append(observation)
            infos = self._add_info(infos, info, i)

        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.copy(self._rewards),
            np.copy(self._dones),
            infos,
        )

    def call(self, name, *args, **kwargs) -> tuple:
        """Calls the method with name and applies args and kwargs.

        If the attribute is not callable, its value is collected instead.

        Args:
            name: The method name
            *args: The method args
            **kwargs: The method kwargs

        Returns:
            Tuple of results
        """
        results = []
        for env in self.envs:
            function = getattr(env, name)
            if callable(function):
                results.append(function(*args, **kwargs))
            else:
                results.append(function)

        return tuple(results)

    def set_attr(self, name: str, values: Union[list, tuple, Any]):
        """Sets an attribute of the sub-environments.

        Args:
            name: The property name to change
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise, a single value is set for all environments.

        Raises:
            ValueError: Values must be a list or tuple with length equal to the number of environments.
        """
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the "
                f"number of environments. Got `{len(values)}` values for "
                f"{self.num_envs} environments."
            )

        for env, value in zip(self.envs, values):
            setattr(env, name, value)

    def close_extras(self, **kwargs):
        """Close the environments."""
        for env in self.envs:
            env.close()

    def _check_spaces(self) -> bool:
        """Verifies that every sub-environment shares the single observation and action spaces.

        Returns:
            ``True`` when all sub-environments match.

        Raises:
            RuntimeError: If a sub-environment's observation or action space differs.
        """
        for env in self.envs:
            if not (env.observation_space == self.single_observation_space):
                raise RuntimeError(
                    "Some environments have an observation space different from "
                    f"`{self.single_observation_space}`. In order to batch observations, "
                    "the observation spaces from all environments must be equal."
                )

            if not (env.action_space == self.single_action_space):
                raise RuntimeError(
                    "Some environments have an action space different from "
                    f"`{self.single_action_space}`. In order to batch actions, the "
                    "action spaces from all environments must be equal."
                )

        return True