Pydocstyle utils vector docstring (#2788)

* Added pydocstyle to pre-commit

* Added docstrings for tests and updated the tests for autoreset

* Add pydocstyle exclude folder to allow slowly adding new docstrings

* Add docstrings for setup.py and gym/__init__.py, core.py, error.py and logger.py

* Check that all unwrapped environments are of a particular wrapper type

* Reverted back to import gym.spaces.Space to gym.spaces

* Fixed the __init__.py docstring

* Fixed autoreset test

* Updated gym __init__.py top docstring

* Fix examples in docstrings

* Add docstrings and type hints where known to all functions and classes in gym/utils and gym/vector

* Remove unnecessary import

* Removed "unused error" and make APIerror deprecated at gym 1.0

* Add pydocstyle description to CONTRIBUTING.md

* Added docstrings section to CONTRIBUTING.md

* Added :meth: and :attr: keywords to docstrings

* Added :meth: and :attr: keywords to docstrings

* Imported annotations from __future__ to fix python 3.7

* Add __future__ import annotations for python 3.7

* isort

* Remove utils and vectors for this PR and spaces for previous PR

* Update gym/envs/classic_control/acrobot.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/envs/classic_control/acrobot.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/envs/classic_control/acrobot.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/spaces/dict.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/ezpickle.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/ezpickle.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/play.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Pre-commit

* Updated docstrings with :meth:

* Updated docstrings with :meth:

* Update gym/utils/play.py

* Update gym/utils/play.py

* Update gym/utils/play.py

* Apply suggestions from code review

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* pre-commit

* Update gym/utils/play.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Updated fps and zoom parameter docstring

* Update play docstring

* Apply suggestions from code review

Added suggested corrections from @markus28

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Pre-commit magic

* Update the `gym.make` docstring with a warning for `env_checker`

* Updated and fixed vector docstrings

* Update test names to reflect the project filename style

Co-authored-by: Markus Krimmel <montcyril@gmail.com>
This commit is contained in:
Mark Towers
2022-05-20 14:49:30 +01:00
committed by GitHub
parent 1b09191535
commit e2266025e6
24 changed files with 845 additions and 757 deletions

View File

@@ -30,7 +30,7 @@ repos:
rev: 6.1.1 # pick a git hash / tag to point to
hooks:
- id: pydocstyle
exclude: ^(gym/version.py)|(gym/(envs|utils|vector)/)|(tests/)
exclude: ^(gym/version.py)|(gym/envs/)|(tests/)
args:
- --source
- --explain

View File

@@ -398,27 +398,26 @@ def rk4(derivs, y0, t):
yourself stranded on a system w/o scipy. Otherwise use
:func:`scipy.integrate`.
Example:
>>> ### 2D system
>>> def derivs(x):
... d1 = x[0] + 2*x[1]
... d2 = -3*x[0] + 4*x[1]
... return (d1, d2)
>>> dt = 0.0005
>>> t = arange(0.0, 2.0, dt)
>>> y0 = (1,2)
>>> yout = rk4(derivs, y0, t)
If you have access to scipy, you should probably be using the
:func:`scipy.integrate` tools rather than this function.
This would then require re-adding the time variable to the signature of derivs.
Args:
derivs: the derivative of the system and has the signature ``dy = derivs(yi)``
y0: initial state vector
t: sample times
args: additional arguments passed to the derivative function
kwargs: additional keyword arguments passed to the derivative function
Example 1 ::
### 2D system
def derivs(x):
d1 = x[0] + 2*x[1]
d2 = -3*x[0] + 4*x[1]
return (d1, d2)
dt = 0.0005
t = arange(0.0, 2.0, dt)
y0 = (1,2)
yout = rk4(derivs6, y0, t)
If you have access to scipy, you should probably be using the
scipy.integrate tools rather than this function.
This would then require re-adding the time variable to the signature of derivs.
Returns:
yout: Runge-Kutta approximation of the ODE

View File

@@ -499,10 +499,16 @@ def make(
"""
Create an environment according to the given ID.
Warnings:
In v0.24, `gym.utils.env_checker.check_env` is run for every initialised environment.
This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to validate
that they follow the gym API. To disable this feature, set parameter `disable_env_checker=True`.
Args:
id: Name of the environment.
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
disable_env_checker: If to disable the environment checker
kwargs: Additional arguments to pass to the environment constructor.
Returns:
An instance of the environment.

View File

@@ -17,25 +17,27 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
Example usage::
Example usage:
>>> observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
>>> from gym.spaces import Dict, Discrete
>>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
>>> observation_space.sample()
OrderedDict([('position', 1), ('velocity', 2)])
Example usage [nested]::
>>> spaces.Dict(
>>> from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete
>>> Dict(
... {
... "ext_controller": spaces.MultiDiscrete((5, 2, 2)),
... "inner_state": spaces.Dict(
... "ext_controller": MultiDiscrete([5, 2, 2]),
... "inner_state": Dict(
... {
... "charge": spaces.Discrete(100),
... "system_checks": spaces.MultiBinary(10),
... "job_status": spaces.Dict(
... "charge": Discrete(100),
... "system_checks": MultiBinary(10),
... "job_status": Dict(
... {
... "task": spaces.Discrete(5),
... "progress": spaces.Box(low=0, high=100, shape=()),
... "task": Discrete(5),
... "progress": Box(low=0, high=100, shape=()),
... }
... ),
... }
@@ -63,9 +65,10 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
Example::
>>> spaces.Dict({"position": spaces.Box(-1, 1, shape=(2,)), "color": spaces.Discrete(3)})
>>> from gym.spaces import Box, Discrete
>>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)})
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
>>> spaces.Dict(position=spaces.Box(-1, 1, shape=(2,)), color=spaces.Discrete(3))
>>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3))
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
Args:

View File

@@ -16,11 +16,11 @@ class MultiBinary(Space[np.ndarray]):
Example Usage::
>>> self.observation_space = spaces.MultiBinary(5)
>>> self.observation_space.sample()
>>> observation_space = MultiBinary(5)
>>> observation_space.sample()
array([0, 1, 0, 1, 0], dtype=int8)
>>> self.observation_space = spaces.MultiBinary([3, 2])
>>> self.observation_space.sample()
>>> observation_space = MultiBinary([3, 2])
>>> observation_space.sample()
array([[0, 0],
[0, 1],
[1, 1]], dtype=int8)

View File

@@ -16,8 +16,9 @@ class Tuple(Space[tuple], Sequence):
Example usage::
>> observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Box(-1, 1, shape=(2,))))
>> observation_space.sample()
>>> from gym.spaces import Box, Discrete
>>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))))
>>> observation_space.sample()
(0, array([0.03633198, 0.42370757], dtype=float32))
"""

View File

@@ -25,8 +25,9 @@ def flatdim(space: Space) -> int:
Example usage::
>>> s = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
>>> spaces.flatdim(s)
>>> from gym.spaces import Discrete
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
>>> flatdim(space)
5
"""
raise NotImplementedError(f"Unknown space: `{space}`")
@@ -195,8 +196,7 @@ def flatten_space(space: Space) -> Box:
Example that recursively flattens a dict::
>>> space = Dict({"position": Discrete(2),
... "velocity": Box(0, 1, shape=(2, 2))})
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
>>> flatten_space(space)
Box(6,)
>>> flatten(space, space.sample()) in flatten_space(space)

View File

@@ -1,5 +1,6 @@
"""A set of common utilities used within the environments. These are
not intended as API functions, and will not remain stable over time.
"""A set of common utilities used within the environments.
These are not intended as API functions, and will not remain stable over time.
"""
color2num = dict(
@@ -15,12 +16,20 @@ color2num = dict(
)
def colorize(string, color, bold=False, highlight=False):
"""Return string surrounded by appropriate terminal color codes to
print colorized text. Valid colors: gray, red, green, yellow,
blue, magenta, cyan, white, crimson
"""
def colorize(
string: str, color: str, bold: bool = False, highlight: bool = False
) -> str:
"""Returns string surrounded by appropriate terminal colour codes to print colourised text.
Args:
string: The message to colourise
color: Literal values are gray, red, green, yellow, blue, magenta, cyan, white, crimson
bold: If to bold the string
highlight: If to highlight the string
Returns:
Colourised string
"""
attr = []
num = color2num[color]
if highlight:

View File

@@ -1,4 +1,5 @@
"""
"""A set of functions for checking an environment's details.
This file is originally from the Stable Baselines3 repository hosted on GitHub
(https://github.com/DLR-RM/stable-baselines3/)
Original Author: Antonin Raffin
@@ -16,21 +17,33 @@ from typing import Optional, Union
import numpy as np
import gym
from gym import logger, spaces
from gym import logger
from gym.spaces import Box, Dict, Discrete, Space, Tuple
def _is_numpy_array_space(space: spaces.Space) -> bool:
def _is_numpy_array_space(space: Space) -> bool:
"""Checks if a space can be represented as a single numpy array (e.g. Dict and Tuple spaces return False).
Args:
space: The space to check
Returns:
Returns False if the provided space is not representable as a single numpy array
"""
Returns False if provided space is not representable as a single numpy array
(e.g. Dict and Tuple spaces return False)
"""
return not isinstance(space, (spaces.Dict, spaces.Tuple))
return not isinstance(space, (Dict, Tuple))
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input adheres to general standards
when the observation is apparently an image.
def _check_image_input(observation_space: Box, key: str = ""):
"""Check whether an observation space of type :class:`Box` adheres to general standards for spaces that represent images.
It will check that:
- The datatype is ``np.uint8``
- The lower bound is 0 across all dimensions
- The upper bound is 255 across all dimensions
Args:
observation_space: The observation space to check
key: The observation shape key for warning
"""
if observation_space.dtype != np.uint8:
logger.warn(
@@ -49,8 +62,13 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
)
def _check_nan(env: gym.Env, check_inf: bool = True) -> None:
"""Check for NaN and Inf."""
def _check_nan(env: gym.Env, check_inf: bool = True):
"""Check whether the environment's observation or reward is NaN or Inf.
Args:
env: The environment to check
check_inf: Checks if the observation is infinity
"""
for _ in range(10):
action = env.action_space.sample()
observation, reward, done, _ = env.step(action)
@@ -70,19 +88,22 @@ def _check_nan(env: gym.Env, check_inf: bool = True) -> None:
def _check_obs(
obs: Union[tuple, dict, np.ndarray, int],
observation_space: spaces.Space,
observation_space: Space,
method_name: str,
) -> None:
):
"""Check that the observation returned by the environment correspond to the declared one.
Args:
obs: The observation to check
observation_space: The observation space of the observation
method_name: The method name that generated the observation
"""
Check that the observation returned by the environment
correspond to the declared one.
"""
if not isinstance(observation_space, spaces.Tuple):
if not isinstance(observation_space, Tuple):
assert not isinstance(
obs, tuple
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
if isinstance(observation_space, spaces.Discrete):
if isinstance(observation_space, Discrete):
assert isinstance(
obs, int
), f"The observation returned by `{method_name}()` method must be an int"
@@ -96,12 +117,16 @@ def _check_obs(
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the observation space is correctly formatted
when dealing with a ``Box()`` space. In particular, it checks:
def _check_box_obs(observation_space: Box, key: str = ""):
"""Check that the observation space is correctly formatted when dealing with a :class:`Box` space.
In particular, it checks:
- that the dimensions are big enough when it is an image, and that the type matches
- that the observation has an expected shape (warn the user if not)
Args:
observation_space: The Box observation space to check
key: The observation key
"""
# If image, check the low and high values, the type and the number of channels
# and the shape (minimal value)
@@ -137,14 +162,19 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
), "Agent's observation_space.high and observation_space have different shapes"
def _check_box_action(action_space: spaces.Box):
def _check_box_action(action_space: Box):
"""Checks that a :class:`Box` action space is defined in a sensible way.
Args:
action_space: A box action space
"""
if np.any(np.equal(action_space.low, -np.inf)):
logger.warn(
"Agent's minimum action space value is -infinity. This is probably too low."
)
if np.any(np.equal(action_space.high, np.inf)):
logger.warn(
"Agent's maxmimum action space value is infinity. This is probably too high"
"Agent's maximum action space value is infinity. This is probably too high"
)
if np.any(np.equal(action_space.low, action_space.high)):
logger.warn("Agent's maximum and minimum action space values are equal")
@@ -156,7 +186,12 @@ def _check_box_action(action_space: spaces.Box):
assert False, "Agent's action_space.high and action_space have different shapes"
def _check_normalized_action(action_space: spaces.Box):
def _check_normalized_action(action_space: Box):
"""Checks that a box action space is normalized.
Args:
action_space: A box action space
"""
if (
np.any(np.abs(action_space.low) != np.abs(action_space.high))
or np.any(np.abs(action_space.low) > 1)
@@ -168,16 +203,18 @@ def _check_normalized_action(action_space: spaces.Box):
)
def _check_returned_values(
env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space
) -> None:
"""
Check the returned values by the env when calling `.reset()` or `.step()` methods.
def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space):
"""Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods.
Args:
env: The environment
observation_space: The environment's observation space
action_space: The environment's action space
"""
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
obs = env.reset()
if isinstance(observation_space, spaces.Dict):
if isinstance(observation_space, Dict):
assert isinstance(
obs, dict
), "The observation returned by `reset()` must be a dictionary"
@@ -200,7 +237,7 @@ def _check_returned_values(
# Unpack
obs, reward, done, info = data
if isinstance(observation_space, spaces.Dict):
if isinstance(observation_space, Dict):
assert isinstance(
obs, dict
), "The observation returned by `step()` must be a dictionary"
@@ -223,10 +260,11 @@ def _check_returned_values(
), "The `info` returned by `step()` must be a python dictionary"
def _check_spaces(env: gym.Env) -> None:
"""
Check that the observation and action spaces are defined
and inherit from gym.spaces.Space.
def _check_spaces(env: gym.Env):
"""Check that the observation and action spaces are defined and inherit from :class:`gym.spaces.Space`.
Args:
env: The environment's observation and action space to check
"""
# Helper to link to the code, because gym has no proper documentation
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
@@ -238,25 +276,22 @@ def _check_spaces(env: gym.Env) -> None:
"You must specify an action space (cf gym.spaces)" + gym_spaces
)
assert isinstance(env.observation_space, spaces.Space), (
assert isinstance(env.observation_space, Space), (
"The observation space must inherit from gym.spaces" + gym_spaces
)
assert isinstance(env.action_space, spaces.Space), (
assert isinstance(env.action_space, Space), (
"The action space must inherit from gym.spaces" + gym_spaces
)
# Check render cannot be covered by CI
def _check_render(
env: gym.Env, warn: bool = True, headless: bool = False
) -> None: # pragma: no cover
"""
Check the declared render modes/fps and the `render()`/`close()`
method of the environment.
:param env: The environment to check
:param warn: Whether to output additional warnings
:param headless: Whether to disable render modes
that require a graphical interface. False by default.
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False):
"""Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.
Args:
env: The environment to check
warn: Whether to output additional warnings
headless: Whether to disable render modes that require a graphical interface. False by default.
"""
render_modes = env.metadata.get("render_modes")
if render_modes is None:
@@ -288,9 +323,12 @@ def _check_render(
env.close()
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
"""
Check that the environment can be reset with a random seed.
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None):
"""Check that the environment can be reset with a seed.
Args:
env: The environment to check
seed: The optional seed to use
"""
signature = inspect.signature(env.reset)
assert (
@@ -303,7 +341,7 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
raise AssertionError(
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. "
"The error was: " + str(e)
f"The error was: {e}"
)
if env.unwrapped.np_random is None:
@@ -322,7 +360,12 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
)
def _check_reset_info(env: gym.Env) -> None:
def _check_reset_info(env: gym.Env):
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
Args:
env: The environment to check
"""
signature = inspect.signature(env.reset)
assert (
"return_info" in signature.parameters or "kwargs" in signature.parameters
@@ -334,7 +377,7 @@ def _check_reset_info(env: gym.Env) -> None:
raise AssertionError(
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. "
"The error was: " + str(e)
f"The error was: {e}"
)
assert (
len(result) == 2
@@ -346,9 +389,11 @@ def _check_reset_info(env: gym.Env) -> None:
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
def _check_reset_options(env: gym.Env) -> None:
"""
Check that the environment can be reset with options.
def _check_reset_options(env: gym.Env):
"""Check that the environment can be reset with options.
Args:
env: The environment to check
"""
signature = inspect.signature(env.reset)
assert (
@@ -361,22 +406,22 @@ def _check_reset_options(env: gym.Env) -> None:
raise AssertionError(
"The environment cannot be reset with options, even though `options` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. "
"The error was: " + str(e)
f"The error was: {e}"
)
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
"""
Check that an environment follows Gym API.
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True):
"""Check that an environment follows Gym API.
This is particularly useful when using a custom environment.
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
for more information about the API.
It also optionally checks that the environment is compatible with Stable-Baselines.
:param env: The Gym environment that will be checked
:param warn: Whether to output additional warnings
mainly related to the interaction with Stable Baselines
:param skip_render_check: Whether to skip the checks for the render method.
True by default (useful for the CI)
Args:
env: The Gym environment that will be checked
warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
"""
assert isinstance(
env, gym.Env
@@ -393,15 +438,15 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
if warn:
obs_spaces = (
observation_space.spaces
if isinstance(observation_space, spaces.Dict)
if isinstance(observation_space, Dict)
else {"": observation_space}
)
for key, space in obs_spaces.items():
if isinstance(space, spaces.Box):
if isinstance(space, Box):
_check_box_obs(space, key)
# Check for the action space, it may lead to hard-to-debug issues
if isinstance(action_space, spaces.Box):
if isinstance(action_space, Box):
_check_box_action(action_space)
_check_normalized_action(action_space)

View File

@@ -1,33 +1,35 @@
"""Class for pickling and unpickling objects via their constructor arguments."""
class EzPickle:
"""Objects that are pickled and unpickled via their constructor
arguments.
"""Objects that are pickled and unpickled via their constructor arguments.
Example usage:
Example::
class Dog(Animal, EzPickle):
def __init__(self, furcolor, tailkind="bushy"):
Animal.__init__()
EzPickle.__init__(furcolor, tailkind)
...
>>> class Dog(Animal, EzPickle):
... def __init__(self, furcolor, tailkind="bushy"):
... Animal.__init__()
... EzPickle.__init__(furcolor, tailkind)
When this object is unpickled, a new Dog will be constructed by passing the provided
furcolor and tailkind into the constructor. However, philosophers are still not sure
whether it is still the same dog.
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
However, philosophers are still not sure whether it is still the same dog.
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo
and Atari.
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
"""
def __init__(self, *args, **kwargs):
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
self._ezpickle_args = args
self._ezpickle_kwargs = kwargs
def __getstate__(self):
"""Returns the object pickle state with args and kwargs."""
return {
"_ezpickle_args": self._ezpickle_args,
"_ezpickle_kwargs": self._ezpickle_kwargs,
}
def __setstate__(self, d):
"""Sets the object pickle state using d."""
out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
self.__dict__.update(out.__dict__)

View File

@@ -1,11 +1,18 @@
"""Utilities for visualising an environment."""
from __future__ import annotations
from collections import deque
from typing import Callable, Dict, Optional, Tuple, Union
import numpy as np
import pygame
from numpy.typing import NDArray
from pygame import Surface
from pygame.event import Event
from pygame.locals import VIDEORESIZE
from gym import Env, logger
from gym.core import ActType, ObsType
from gym.error import DependencyNotInstalled
from gym.logger import deprecation
try:
@@ -13,30 +20,31 @@ try:
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
except ImportError as e:
logger.warn(f"failed to set matplotlib backend, plotting will not work: {str(e)}")
plt = None
from collections import deque
from pygame.locals import VIDEORESIZE
from gym.core import ActType
except ImportError:
logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
matplotlib, plt = None, None
class MissingKeysToAction(Exception):
"""Raised when the environment does not have
a default keys_to_action mapping
"""
"""Raised when the environment does not have a default ``keys_to_action`` mapping."""
class PlayableGame:
"""Wraps an environment allowing keyboard inputs to interact with the environment."""
def __init__(
self,
env: Env,
keys_to_action: Optional[Dict[Tuple[int], int]] = None,
keys_to_action: Optional[dict[tuple[int], int]] = None,
zoom: Optional[float] = None,
):
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
Args:
env: The environment to play
keys_to_action: The dictionary of keyboard tuples and action value
zoom: If to zoom in on the environment render
"""
self.env = env
self.relevant_keys = self._get_relevant_keys(keys_to_action)
self.video_size = self._get_video_size(zoom)
@@ -45,7 +53,7 @@ class PlayableGame:
self.running = True
def _get_relevant_keys(
self, keys_to_action: Optional[Dict[Tuple[int], int]] = None
self, keys_to_action: Optional[dict[tuple[int], int]] = None
) -> set:
if keys_to_action is None:
if hasattr(self.env, "get_keys_to_action"):
@@ -60,7 +68,7 @@ class PlayableGame:
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
return relevant_keys
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
def _get_video_size(self, zoom: Optional[float] = None) -> tuple[int, int]:
# TODO: this needs to be updated when the render API change goes through
rendered = self.env.render(mode="rgb_array")
video_size = [rendered.shape[1], rendered.shape[0]]
@@ -70,7 +78,14 @@ class PlayableGame:
return video_size
def process_event(self, event: Event) -> None:
def process_event(self, event: Event):
"""Processes a PyGame event.
In particular, this function is used to keep track of which buttons are currently pressed and to exit the :func:`play` function when the PyGame window is closed.
Args:
event: The event to process
"""
if event.type == pygame.KEYDOWN:
if event.key in self.relevant_keys:
self.pressed_keys.append(event.key)
@@ -87,9 +102,17 @@ class PlayableGame:
def display_arr(
screen: Surface, arr: NDArray, video_size: Tuple[int, int], transpose: bool
screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
):
arr_min, arr_max = arr.min(), arr.max()
"""Displays a numpy array on screen.
Args:
screen: The screen to show the array on
arr: The array to show
video_size: The video size of the screen
transpose: If to transpose the array on the screen
"""
arr_min, arr_max = np.min(arr), np.max(arr)
arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
pyg_img = pygame.transform.scale(pyg_img, video_size)
@@ -108,60 +131,74 @@ def play(
):
"""Allows one to play the game using keyboard.
To simply play the game use:
Example::
play(gym.make("Pong-v4"))
>>> import gym
>>> from gym.utils.play import play
>>> play(gym.make("CarRacing-v1"), keys_to_action={"w": np.array([0, 0.7, 0]),
... "a": np.array([-1, 0, 0]),
... "s": np.array([0, 0, 1]),
... "d": np.array([1, 0, 0]),
... "wa": np.array([-1, 0.7, 0]),
... "dw": np.array([1, 0.7, 0]),
... "ds": np.array([1, 0, 1]),
... "as": np.array([-1, 0, 1]),
... }, noop=np.array([0,0,0]))
Above code works also if env is wrapped, so it's particularly useful in
Above code works also if the environment is wrapped, so it's particularly useful in
verifying that the frame-level preprocessing does not render the game
unplayable.
If you wish to plot real time statistics as you play, you can use
gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
for last 5 second of gameplay.
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
for last 150 steps.
def callback(obs_t, obs_tp1, action, rew, done, info):
return [rew,]
plotter = PlayPlot(callback, 30 * 5, ["reward"])
env = gym.make("Pong-v4")
play(env, callback=plotter.callback)
>>> def callback(obs_t, obs_tp1, action, rew, done, info):
... return [rew,]
>>> plotter = PlayPlot(callback, 150, ["reward"])
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
Arguments
---------
env: gym.Env
Environment to use for playing.
transpose: bool
If True the output of observation is transposed.
Defaults to true.
fps: int
Maximum number of steps of the environment to execute every second.
Defaults to 30.
zoom: float
Make screen edge this many times bigger
callback: lambda or None
Callback if a callback is provided it will be executed after
every step. It takes the following input:
obs_t: observation before performing action
obs_tp1: observation after performing action
action: action that was executed
rew: reward that was received
done: whether the environment is done or not
info: debug info
keys_to_action: dict: tuple(int) -> int or None
Mapping from keys pressed to action performed.
For example if pressed 'w' and space at the same time is supposed
to trigger action number 2 then key_to_action dict would look like this:
{
# ...
sorted(ord('w'), ord(' ')) -> 2
# ...
}
If None, default key_to_action mapping for that env is used, if provided.
seed: bool or None
Random seed used when resetting the environment. If None, no seed is used.
Args:
env: Environment to use for playing.
transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
zoom: Zoom the observation in, ``zoom`` amount, should be positive float
callback: If a callback is provided, it will be executed after every step. It takes the following input:
obs_t: observation before performing action
obs_tp1: observation after performing action
action: action that was executed
rew: reward that was received
done: whether the environment is done or not
info: debug info
keys_to_action: Mapping from keys pressed to action performed.
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
points of the keys, as a tuple of characters, or as a string where each character of the string represents
one key.
For example if pressing 'w' and space at the same time is supposed
to trigger action number 2 then the ``keys_to_action`` dict could look like this:
>>> {
... # ...
... (ord('w'), ord(' ')): 2
... # ...
... }
or like this:
>>> {
... # ...
... ("w", " "): 2
... # ...
... }
or like this:
>>> {
... # ...
... "w ": 2
... # ...
... }
If ``None``, the default ``keys_to_action`` mapping for that environment is used, if provided.
seed: Random seed used when resetting the environment. If None, no seed is used.
noop: The action used when no key input has been entered, or the entered key combination is unknown.
"""
env.reset(seed=seed)
@@ -208,7 +245,44 @@ def play(
class PlayPlot:
def __init__(self, callback, horizon_timesteps, plot_names):
"""Provides a callback to create live plots of arbitrary metrics when using :func:`play`.
This class is instantiated with a function that accepts information about a single environment transition:
- obs_t: observation before performing action
- obs_tp1: observation after performing action
- action: action that was executed
- rew: reward that was received
- done: whether the environment is done or not
- info: debug info
It should return a list of metrics that are computed from this data.
For instance, the function may look like this::
def compute_metrics(obs_t, obs_tp, action, reward, done, info):
return [reward, info["cumulative_reward"], np.linalg.norm(action)]
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
and uses the returned values to update live plots of the metrics.
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
>>> play(your_env, callback=plotter.callback)
"""
def __init__(
self, callback: callable, horizon_timesteps: int, plot_names: list[str]
):
"""Constructor of :class:`PlayPlot`.
The function ``callback`` that is passed to this constructor should return
a list of metrics that is of length ``len(plot_names)``.
Args:
callback: Function that computes metrics from environment transitions
horizon_timesteps: The time horizon used for the live plots
plot_names: List of plot titles
"""
deprecation(
"`PlayPlot` is marked as deprecated and will be removed in the near future."
)
@@ -216,7 +290,10 @@ class PlayPlot:
self.horizon_timesteps = horizon_timesteps
self.plot_names = plot_names
assert plt is not None, "matplotlib backend failed, plotting will not work"
if plt is None:
raise DependencyNotInstalled(
"matplotlib is not installed, run `pip install gym[other]`"
)
num_plots = len(self.plot_names)
self.fig, self.ax = plt.subplots(num_plots)
@@ -228,7 +305,25 @@ class PlayPlot:
self.cur_plot = [None for _ in range(num_plots)]
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
def callback(self, obs_t, obs_tp1, action, rew, done, info):
def callback(
self,
obs_t: ObsType,
obs_tp1: ObsType,
action: ActType,
rew: float,
done: bool,
info: dict,
):
"""The callback that calls the provided data callback and adds the data to the plots.
Args:
obs_t: The observation at time step t
obs_tp1: The observation at time step t+1
action: The action
rew: The reward
done: If the environment is done
info: The information from the environment
"""
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
for point, data_series in zip(points, self.data):
data_series.append(point)

View File

@@ -1,7 +1,10 @@
"""Set of random number generator functions: seeding, generator, hashing seeds."""
from __future__ import annotations
import hashlib
import os
import struct
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import numpy as np
@@ -9,7 +12,15 @@ from gym import error
from gym.logger import deprecation
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]:
"""Generates a random number generator from the seed and returns the Generator and seed.
Args:
seed: The seed used to create the generator
Returns:
The generator and resulting seed
"""
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
@@ -22,7 +33,10 @@ def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
# RandomNumberGenerator = np.random.Generator
class RandomNumberGenerator(np.random.Generator):
"""Random number generator class that inherits from numpy's random Generator class."""
def rand(self, *size):
"""Deprecated rand function using random."""
deprecation(
"Function `rng.rand(*size)` is marked as deprecated "
"and will be removed in the future. "
@@ -34,6 +48,7 @@ class RandomNumberGenerator(np.random.Generator):
random_sample = rand
def randn(self, *size):
"""Deprecated random standard normal function; use standard_normal instead."""
deprecation(
"Function `rng.randn(*size)` is marked as deprecated "
"and will be removed in the future. "
@@ -43,6 +58,7 @@ class RandomNumberGenerator(np.random.Generator):
return self.standard_normal(size)
def randint(self, low, high=None, size=None, dtype=int):
"""Deprecated random integer function; use integers instead."""
deprecation(
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
"and will be removed in the future. "
@@ -54,6 +70,7 @@ class RandomNumberGenerator(np.random.Generator):
random_integers = randint
def get_state(self):
"""Deprecated get rng state function; use bit_generator.state instead."""
deprecation(
"Function `rng.get_state()` is marked as deprecated "
"and will be removed in the future. "
@@ -63,6 +80,7 @@ class RandomNumberGenerator(np.random.Generator):
return self.bit_generator.state
def set_state(self, state):
"""Deprecated set rng state function; use bit_generator.state = state instead."""
deprecation(
"Function `rng.set_state(state)` is marked as deprecated "
"and will be removed in the future. "
@@ -72,6 +90,7 @@ class RandomNumberGenerator(np.random.Generator):
self.bit_generator.state = state
def seed(self, seed=None):
"""Deprecated seed function; use gym.utils.seeding.np_random(seed) instead."""
deprecation(
"Function `rng.seed(seed)` is marked as deprecated "
"and will be removed in the future. "
@@ -88,6 +107,7 @@ class RandomNumberGenerator(np.random.Generator):
seed.__doc__ = np.random.seed.__doc__
def __reduce__(self):
"""Reduces the Random Number Generator to a RandomNumberGenerator, init_args and additional args."""
# np.random.Generator defines __reduce__, but it's hard-coded to
# return a Generator instead of its subclass RandomNumberGenerator.
# We need to override it here, otherwise sampling from a Space will
@@ -119,20 +139,21 @@ RNG = RandomNumberGenerator
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
"""Any given evaluation is likely to have many PRNG's active at
once. (Most commonly, because the environment is running in
multiple processes.) There's literature indicating that having
linear correlations between seeds of multiple PRNG's can correlate
the outputs:
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://dl.acm.org/citation.cfm?id=1276928
Thus, for sanity we hash the seeds before using them. (This scheme
is likely not crypto-strength, but it should be good enough to get
rid of simple correlations.)
"""Any given evaluation is likely to have many PRNG's active at once.
(Most commonly, because the environment is running in multiple processes.)
There's literature indicating that having linear correlations between seeds of multiple PRNG's can correlate the outputs:
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://dl.acm.org/citation.cfm?id=1276928
Thus, for sanity we hash the seeds before using them. (This scheme is likely not crypto-strength, but it should be good enough to get rid of simple correlations.)
Args:
seed: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the hashed seed.
Returns:
The hashed seed
"""
deprecation(
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
@@ -144,12 +165,16 @@ def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
"""Create a strong random seed. Otherwise, Python 2 would seed using
the system time, which might be non-robust especially in the
presence of concurrency.
"""Create a strong random seed.
Otherwise, Python 2 would seed using the system time, which might be non-robust especially in the presence of concurrency.
Args:
a: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the seed.
Returns:
A seed
"""
deprecation(
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
@@ -185,7 +210,7 @@ def _bigint_from_bytes(bt: bytes) -> int:
return accum
def _int_list_from_bigint(bigint: int) -> List[int]:
def _int_list_from_bigint(bigint: int) -> list[int]:
deprecation(
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
)
@@ -195,7 +220,7 @@ def _int_list_from_bigint(bigint: int) -> List[int]:
elif bigint == 0:
return [0]
ints: List[int] = []
ints: list[int] = []
while bigint > 0:
bigint, mod = divmod(bigint, 2**32)
ints.append(mod)

View File

@@ -1,7 +1,7 @@
try:
from collections.abc import Iterable
except ImportError:
Iterable = (tuple, list)
"""Module for vector environments."""
from __future__ import annotations
from typing import Iterable, Optional, Union
from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv
@@ -10,40 +10,34 @@ from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
"""Create a vectorized environment from multiple copies of an environment,
from its id.
def make(
id: str,
num_envs: int = 1,
asynchronous: bool = True,
wrappers: Optional[Union[callable, list[callable]]] = None,
**kwargs,
) -> VectorEnv:
"""Create a vectorized environment from multiple copies of an environment, from its id.
Parameters
----------
id : str
The environment ID. This must be a valid ID from the registry.
Example::
num_envs : int
Number of copies of the environment.
>>> import gym
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
>>> env.reset()
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
dtype=float32)
asynchronous : bool
If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses
`multiprocessing`_ to run the environments in parallel). If ``False``,
wraps the environments in a :class:`SyncVectorEnv`.
Args:
id: The environment ID. This must be a valid ID from the registry.
num_envs: Number of copies of the environment.
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
**kwargs: Keyword arguments applied during gym.make
wrappers : callable, or iterable of callables, optional
If not ``None``, then apply the wrappers to each internal
environment during creation.
Returns
-------
:class:`gym.vector.VectorEnv`
Returns:
The vectorized environment.
Example
-------
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
>>> env.reset()
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
dtype=float32)
"""
from gym.envs import make as make_

View File

@@ -1,13 +1,18 @@
"""An async vector environment."""
from __future__ import annotations
import multiprocessing as mp
import sys
import time
from copy import deepcopy
from enum import Enum
from typing import List, Optional, Union
from typing import Optional, Sequence, Union
import numpy as np
import gym
from gym import logger
from gym.core import ObsType
from gym.error import (
AlreadyPendingCallError,
ClosedEnvironmentError,
@@ -37,69 +42,13 @@ class AsyncState(Enum):
class AsyncVectorEnv(VectorEnv):
"""Vectorized environment that runs multiple environments in parallel. It
uses `multiprocessing`_ processes, and pipes for communication.
"""Vectorized environment that runs multiple environments in parallel.
Parameters
----------
env_fns : iterable of callable
Functions that create the environments.
It uses ``multiprocessing`` processes, and pipes for communication.
observation_space : :class:`gym.spaces.Space`, optional
Observation space of a single environment. If ``None``, then the
observation space of the first environment is taken.
action_space : :class:`gym.spaces.Space`, optional
Action space of a single environment. If ``None``, then the action space
of the first environment is taken.
shared_memory : bool
If ``True``, then the observations from the worker processes are
communicated back through shared variables. This can improve the
efficiency if the observations are large (e.g. images).
copy : bool
If ``True``, then the :meth:`~AsyncVectorEnv.reset` and
:meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
context : str, optional
Context for `multiprocessing`_. If ``None``, then the default context is used.
daemon : bool
If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they
will quit if the head process quits. However, ``daemon=True`` prevents
subprocesses to spawn children, so for some environments you may want
to have it set to ``False``.
worker : callable, optional
If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance,
how resets on done are handled.
Warning
-------
:attr:`worker` is an advanced mode option. It provides a high degree of
flexibility and a high chance to shoot yourself in the foot; thus,
if you are writing your own worker, it is recommended to start from the code
for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
Raises
------
RuntimeError
If the observation space of some sub-environment does not match
:obj:`observation_space` (or, by default, the observation space of
the first sub-environment).
ValueError
If :obj:`observation_space` is a custom space (i.e. not a default
space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`,
or :class:`~gym.spaces.Dict`) and :obj:`shared_memory` is ``True``.
Example
-------
.. code-block::
Example::
>>> import gym
>>> env = gym.vector.AsyncVectorEnv([
... lambda: gym.make("Pendulum-v0", g=9.81),
... lambda: gym.make("Pendulum-v0", g=1.62)
@@ -111,15 +60,33 @@ class AsyncVectorEnv(VectorEnv):
def __init__(
self,
env_fns,
observation_space=None,
action_space=None,
shared_memory=True,
copy=True,
context=None,
daemon=True,
worker=None,
env_fns: Sequence[callable],
observation_space: Optional[gym.Space] = None,
action_space: Optional[gym.Space] = None,
shared_memory: bool = True,
copy: bool = True,
context: Optional[str] = None,
daemon: bool = True,
worker: Optional[callable] = None,
):
"""Vectorized environment that runs multiple environments in parallel.
Args:
env_fns: Functions that create the environments.
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images).
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children, so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
ValueError: If observation_space is a custom space (i.e. not a default space in Gym, such as gym.spaces.Box, gym.spaces.Discrete, or gym.spaces.Dict) and shared_memory is True.
"""
ctx = mp.get_context(context)
self.env_fns = env_fns
self.shared_memory = shared_memory
@@ -192,6 +159,11 @@ class AsyncVectorEnv(VectorEnv):
self._check_spaces()
def seed(self, seed=None):
"""Seeds the vector environments.
Args:
seed: The seeds to use with the environments
"""
super().seed(seed=seed)
self._assert_is_running()
if seed is None:
@@ -213,22 +185,24 @@ class AsyncVectorEnv(VectorEnv):
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
seed: Optional[Union[int, list[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Send the calls to :obj:`reset` to each sub-environment.
"""Send calls to the :obj:`reset` methods of the sub-environments.
Raises
------
ClosedEnvironmentError
If the environment was closed (if :meth:`close` was previously called).
To get the results of these calls, you may invoke :meth:`reset_wait`.
AlreadyPendingCallError
If the environment is already waiting for a pending call to another
method (e.g. :meth:`step_async`). This can be caused by two consecutive
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in
between.
Args:
seed: List of seeds for each environment
return_info: Whether to return information
options: The reset option
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`step_async`). This can be caused by two consecutive
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
"""
self._assert_is_running()
@@ -258,37 +232,26 @@ class AsyncVectorEnv(VectorEnv):
def reset_wait(
self,
timeout=None,
timeout: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""
Parameters
----------
timeout : int or float, optional
Number of seconds before the call to `reset_wait` times out. If
`None`, the call to `reset_wait` never times out.
seed: ignored
options: ignored
) -> Union[ObsType, tuple[ObsType, list[dict]]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Returns
-------
element of :attr:`~VectorEnv.observation_space`
A batch of observations from the vectorized environment.
infos : list of dicts containing metadata
Args:
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
seed: ignored
return_info: Whether to return information
options: ignored
Raises
------
ClosedEnvironmentError
If the environment was closed (if :meth:`close` was previously called).
Returns:
A tuple of batched observations and list of dictionaries
NoAsyncCallError
If :meth:`reset_wait` was called without any prior call to
:meth:`reset_async`.
TimeoutError
If :meth:`reset_wait` timed out.
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
TimeoutError: If :meth:`reset_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_RESET:
@@ -327,24 +290,18 @@ class AsyncVectorEnv(VectorEnv):
return deepcopy(self.observations) if self.copy else self.observations
def step_async(self, actions):
def step_async(self, actions: np.ndarray):
"""Send the calls to :obj:`step` to each sub-environment.
Parameters
----------
actions : element of :attr:`~VectorEnv.action_space`
Batch of actions.
Args:
actions: Batch of actions. element of :attr:`~VectorEnv.action_space`
Raises
------
ClosedEnvironmentError
If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError
If the environment is already waiting for a pending call to another
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
between.
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
between.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
@@ -358,40 +315,21 @@ class AsyncVectorEnv(VectorEnv):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
def step_wait(self, timeout=None):
def step_wait(
self, timeout: Optional[Union[int, float]] = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[dict]]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
Parameters
----------
timeout : int or float, optional
Number of seconds before the call to :meth:`step_wait` times out. If
``None``, the call to :meth:`step_wait` never times out.
Args:
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
Returns
-------
observations : element of :attr:`~VectorEnv.observation_space`
A batch of observations from the vectorized environment.
Returns:
The batched environment step information, obs, reward, done and info
rewards : :obj:`np.ndarray`, dtype :obj:`np.float_`
A vector of rewards from the vectorized environment.
dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
A vector whose entries indicate whether the episode has ended.
infos : list of dict
A list of auxiliary diagnostic information dicts from sub-environments.
Raises
------
ClosedEnvironmentError
If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError
If :meth:`step_wait` was called without any prior call to
:meth:`step_async`.
TimeoutError
If :meth:`step_wait` timed out.
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
TimeoutError: If :meth:`step_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_STEP:
@@ -425,18 +363,13 @@ class AsyncVectorEnv(VectorEnv):
infos,
)
def call_async(self, name, *args, **kwargs):
"""
Parameters
----------
name : string
Name of the method or property to call.
def call_async(self, name: str, *args, **kwargs):
"""Calls the method with name asynchronously and apply args and kwargs to the method.
*args
Arguments to apply to the method call.
**kwargs
Keywoard arguments to apply to the method call.
Args:
name: Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
@@ -450,19 +383,14 @@ class AsyncVectorEnv(VectorEnv):
pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout=None):
"""
Parameters
----------
timeout : int or float, optional
Number of seconds before the call to `step_wait` times out. If
`None` (default), the call to `step_wait` never times out.
def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list:
"""Calls all parent pipes and waits for the results.
Returns
-------
results : list
List of the results of the individual calls to the method or
property for each environment.
Args:
timeout: Number of seconds before the call to `call_wait` times out. If `None` (default), the call to `call_wait` never times out.
Returns:
List of the results of the individual calls to the method or property for each environment.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_CALL:
@@ -483,17 +411,14 @@ class AsyncVectorEnv(VectorEnv):
return results
def set_attr(self, name, values):
"""
Parameters
----------
name : string
Name of the property to be set in each individual environment.
def set_attr(self, name: str, values: Union[list, tuple, object]):
"""Sets an attribute of the sub-environments.
values : list, tuple, or object
Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
Args:
name: Name of the property to be set in each individual environment.
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
"""
self._assert_is_running()
if not isinstance(values, (list, tuple)):
@@ -517,25 +442,19 @@ class AsyncVectorEnv(VectorEnv):
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def close_extras(self, timeout=None, terminate=False):
"""Close the environments & clean up the extra resources
(processes and pipes).
def close_extras(
self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
):
"""Close the environments & clean up the extra resources (processes and pipes).
Parameters
----------
timeout : int or float, optional
Number of seconds before the call to :meth:`close` times out. If ``None``,
the call to :meth:`close` never times out. If the call to :meth:`close`
times out, then all processes are terminated.
Args:
timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
the call to :meth:`close` never times out. If the call to :meth:`close`
times out, then all processes are terminated.
terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.
terminate : bool
If ``True``, then the :meth:`close` operation is forced and all processes
are terminated.
Raises
------
TimeoutError
If :meth:`close` timed out.
Raises:
TimeoutError: If :meth:`close` timed out.
"""
timeout = 0 if terminate else timeout
try:
@@ -626,6 +545,7 @@ class AsyncVectorEnv(VectorEnv):
raise exctype(value)
def __del__(self):
"""On deleting the object, checks that the vector environment is closed."""
if not getattr(self, "closed", True) and hasattr(self, "_state"):
self.close(terminate=True)

View File

@@ -1,8 +1,12 @@
"""A synchronous vector environment."""
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional, Union
from typing import Any, Iterator, Optional, Sequence, Union
import numpy as np
from gym.spaces import Space
from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv
@@ -12,35 +16,9 @@ __all__ = ["SyncVectorEnv"]
class SyncVectorEnv(VectorEnv):
"""Vectorized environment that serially runs multiple environments.
Parameters
----------
env_fns : iterable of callable
Functions that create the environments.
observation_space : :class:`gym.spaces.Space`, optional
Observation space of a single environment. If ``None``, then the
observation space of the first environment is taken.
action_space : :class:`gym.spaces.Space`, optional
Action space of a single environment. If ``None``, then the action space
of the first environment is taken.
copy : bool
If ``True``, then the :meth:`reset` and :meth:`step` methods return a
copy of the observations.
Raises
------
RuntimeError
If the observation space of some sub-environment does not match
:obj:`observation_space` (or, by default, the observation space of
the first sub-environment).
Example
-------
.. code-block::
Example::
>>> import gym
>>> env = gym.vector.SyncVectorEnv([
... lambda: gym.make("Pendulum-v0", g=9.81),
... lambda: gym.make("Pendulum-v0", g=1.62)
@@ -50,7 +28,24 @@ class SyncVectorEnv(VectorEnv):
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32)
"""
def __init__(self, env_fns, observation_space=None, action_space=None, copy=True):
def __init__(
self,
env_fns: Iterator[callable],
observation_space: Space = None,
action_space: Space = None,
copy: bool = True,
):
"""Vectorized environment that serially runs multiple environments.
Args:
env_fns: iterable of callable functions that create the environments.
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
"""
self.env_fns = env_fns
self.envs = [env_fn() for env_fn in env_fns]
self.copy = copy
@@ -60,7 +55,7 @@ class SyncVectorEnv(VectorEnv):
observation_space = observation_space or self.envs[0].observation_space
action_space = action_space or self.envs[0].action_space
super().__init__(
num_envs=len(env_fns),
num_envs=len(self.envs),
observation_space=observation_space,
action_space=action_space,
)
@@ -73,7 +68,12 @@ class SyncVectorEnv(VectorEnv):
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
"""Sets the seed in all sub-environments.
Args:
seed: The seed
"""
super().seed(seed=seed)
if seed is None:
seed = [None for _ in range(self.num_envs)]
@@ -86,10 +86,20 @@ class SyncVectorEnv(VectorEnv):
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
seed: Optional[Union[int, list[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
seed: The reset environment seed
return_info: Whether to return information
options: Option information for the environment reset
Returns:
The reset observation of the environment and reset information
"""
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
@@ -128,9 +138,15 @@ class SyncVectorEnv(VectorEnv):
), data_list
def step_async(self, actions):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
self._actions = iterate(self.action_space, actions)
def step_wait(self):
"""Steps through each of the environments returning the batched results.
Returns:
The batched environment step results
"""
observations, infos = [], []
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
observation, self._rewards[i], self._dones[i], info = env.step(action)
@@ -150,7 +166,17 @@ class SyncVectorEnv(VectorEnv):
infos,
)
def call(self, name, *args, **kwargs):
def call(self, name, *args, **kwargs) -> tuple:
"""Calls the method with name and applies args and kwargs.
Args:
name: The method name
*args: The method args
**kwargs: The method kwargs
Returns:
Tuple of results
"""
results = []
for env in self.envs:
function = getattr(env, name)
@@ -161,7 +187,15 @@ class SyncVectorEnv(VectorEnv):
return tuple(results)
def set_attr(self, name, values):
def set_attr(self, name: str, values: Union[list, tuple, Any]):
"""Sets an attribute of the sub-environments.
Args:
name: The property name to change
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise, a single value is set for all environments.
"""
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
@@ -178,7 +212,7 @@ class SyncVectorEnv(VectorEnv):
"""Close the environments."""
[env.close() for env in self.envs]
def _check_spaces(self):
def _check_spaces(self) -> bool:
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
raise RuntimeError(
@@ -194,5 +228,4 @@ class SyncVectorEnv(VectorEnv):
"action spaces from all environments must be equal."
)
else:
return True
return True

View File

@@ -1,3 +1,4 @@
"""Module for gym vector utils."""
from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
from gym.vector.utils.shared_memory import (

View File

@@ -1,3 +1,4 @@
"""Miscellaneous utilities."""
import contextlib
import os
@@ -5,28 +6,35 @@ __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
class CloudpickleWrapper:
def __init__(self, fn):
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""
def __init__(self, fn: callable):
"""Cloudpickle wrapper for a function."""
self.fn = fn
def __getstate__(self):
"""Get the state using `cloudpickle.dumps(self.fn)`."""
import cloudpickle
return cloudpickle.dumps(self.fn)
def __setstate__(self, ob):
"""Sets the state by unpickling `ob` into `self.fn`."""
import pickle
self.fn = pickle.loads(ob)
def __call__(self):
"""Calls the function `self.fn` with no arguments."""
return self.fn()
@contextlib.contextmanager
def clear_mpi_env_vars():
"""
`from mpi4py import MPI` will call `MPI_Init` by default. If the child
process has MPI environment variables, MPI will think that the child process
"""Clears the MPI environment variables.
`from mpi4py import MPI` will call `MPI_Init` by default.
If the child process has MPI environment variables, MPI will think that the child process
is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables

View File

@@ -1,5 +1,7 @@
"""Numpy utility functions: concatenate space samples and create empty array."""
from collections import OrderedDict
from functools import singledispatch
from typing import Iterable, Union
import numpy as np
@@ -9,36 +11,29 @@ __all__ = ["concatenate", "create_empty_array"]
@singledispatch
def concatenate(space, items, out):
def concatenate(
space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
) -> Union[tuple, dict, np.ndarray]:
"""Concatenate multiple samples from space into a single object.
Parameters
----------
items : iterable of samples of `space`
Samples to be concatenated.
Example::
out : tuple, dict, or `np.ndarray`
>>> from gym.spaces import Box
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
>>> out = np.zeros((2, 3), dtype=np.float32)
>>> items = [space.sample() for _ in range(2)]
>>> concatenate(space, items, out)
array([[0.6348213 , 0.28607962, 0.60760117],
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
Returns:
The output object. This object is a (possibly nested) numpy array.
space : `gym.spaces.Space` instance
Observation space of a single environment in the vectorized environment.
Returns
-------
out : tuple, dict, or `np.ndarray`
The output object. This object is a (possibly nested) numpy array.
Example
-------
>>> from gym.spaces import Box
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
>>> out = np.zeros((2, 3), dtype=np.float32)
>>> items = [space.sample() for _ in range(2)]
>>> concatenate(items, out, space)
array([[0.6348213 , 0.28607962, 0.60760117],
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
"""
assert isinstance(items, (list, tuple))
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
)
@@ -76,38 +71,30 @@ def _concatenate_custom(space, items, out):
@singledispatch
def create_empty_array(space, n=1, fn=np.zeros):
def create_empty_array(
space: Space, n: int = 1, fn: callable = np.zeros
) -> Union[tuple, dict, np.ndarray]:
"""Create an empty (possibly nested) numpy array.
Parameters
----------
space : `gym.spaces.Space` instance
Observation space of a single environment in the vectorized environment.
Example::
n : int
Number of environments in the vectorized environment. If `None`, creates
an empty sample from `space`.
>>> from gym.spaces import Box, Dict
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)),
('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
fn : callable
Function to apply when creating the empty numpy array. Examples of such
functions are `np.empty` or `np.zeros`.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
Returns
-------
out : tuple, dict, or `np.ndarray`
Returns:
The output object. This object is a (possibly nested) numpy array.
Example
-------
>>> from gym.spaces import Box, Dict
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)),
('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."

View File

@@ -1,44 +1,40 @@
"""Utility functions for vector environments to share memory between processes."""
import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Union
import numpy as np
from gym.error import CustomSpaceError
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
@singledispatch
def create_shared_memory(space, n=1, ctx=mp):
"""Create a shared memory object, to be shared across processes. This
eventually contains the observations from the vectorized environment.
def create_shared_memory(
space: Space, n: int = 1, ctx=mp
) -> Union[dict, tuple, mp.Array]:
"""Create a shared memory object, to be shared across processes.
Parameters
----------
space : `gym.spaces.Space` instance
Observation space of a single environment in the vectorized environment.
This eventually contains the observations from the vectorized environment.
n : int
Number of environments in the vectorized environment (i.e. the number
of processes).
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment (i.e. the number of processes).
ctx: The multiprocessing context (defaults to the `multiprocessing` module)
ctx : `multiprocessing` context
Context for multiprocessing.
Returns
-------
shared_memory : dict, tuple, or `multiprocessing.Array` instance
Shared object across processes.
Returns:
The shared memory object, shared across processes.
"""
raise CustomSpaceError(
"Cannot create a shared memory for space with "
"type `{}`. Shared memory only supports "
f"type `{type(space)}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
"Gym spaces."
)
@@ -46,7 +42,7 @@ def create_shared_memory(space, n=1, ctx=mp):
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n=1, ctx=mp):
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
dtype = space.dtype.char
if dtype in "?":
dtype = c_bool
@@ -54,7 +50,7 @@ def _create_base_shared_memory(space, n=1, ctx=mp):
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n=1, ctx=mp):
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@@ -71,39 +67,32 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
@singledispatch
def read_from_shared_memory(space, shared_memory, n=1):
def read_from_shared_memory(
space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1
) -> Union[dict, tuple, np.ndarray]:
"""Read the batch of observations from shared memory as a numpy array.
Parameters
----------
shared_memory : dict, tuple, or `multiprocessing.Array` instance
Shared object across processes. This contains the observations from the
vectorized environment. This object is created with `create_shared_memory`.
.. note::
The numpy array objects returned by `read_from_shared_memory` shares the
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
space : `gym.spaces.Space` instance
Observation space of a single environment in the vectorized environment.
Args:
space: Observation space of a single environment in the vectorized environment.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
n: Number of environments in the vectorized environment (i.e. the number of processes).
n : int
Number of environments in the vectorized environment (i.e. the number
of processes).
Returns
-------
observations : dict, tuple or `np.ndarray` instance
Returns:
Batch of observations as a (possibly nested) numpy array.
Notes
-----
The numpy array objects returned by `read_from_shared_memory` shares the
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
"""
raise CustomSpaceError(
"Cannot read from a shared memory for space with "
"type `{}`. Shared memory only supports "
f"type `{type(space)}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
"Gym spaces."
)
@@ -111,14 +100,14 @@ def read_from_shared_memory(space, shared_memory, n=1):
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n=1):
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space, shared_memory, n=1):
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
return tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
@@ -126,7 +115,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n=1):
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n=1):
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
return OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
@@ -136,34 +125,26 @@ def _read_dict_from_shared_memory(space, shared_memory, n=1):
@singledispatch
def write_to_shared_memory(space, index, value, shared_memory):
def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: Union[dict, tuple, mp.Array],
):
"""Write the observation of a single environment into shared memory.
Parameters
----------
index : int
Index of the environment (must be in `[0, num_envs)`).
value : sample from `space`
Observation of the single environment to write to shared memory.
shared_memory : dict, tuple, or `multiprocessing.Array` instance
Shared object across processes. This contains the observations from the
vectorized environment. This object is created with `create_shared_memory`.
space : `gym.spaces.Space` instance
Observation space of a single environment in the vectorized environment.
Returns
-------
`None`
Args:
space: Observation space of a single environment in the vectorized environment.
index: Index of the environment (must be in `[0, num_envs)`).
value: Observation of the single environment to write to shared memory.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment. This object is created with `create_shared_memory`.
"""
raise CustomSpaceError(
"Cannot write to a shared memory for space with "
"type `{}`. Shared memory only supports "
f"type `{type(space)}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
"Gym spaces."
)

View File

@@ -1,6 +1,8 @@
"""Utility functions for gym spaces: batch space and iterator."""
from collections import OrderedDict
from copy import deepcopy
from functools import singledispatch
from typing import Iterator
import numpy as np
@@ -12,32 +14,25 @@ __all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
@singledispatch
def batch_space(space, n=1):
def batch_space(space: Space, n: int = 1) -> Space:
"""Create a (batched) space, containing multiple copies of a single space.
Parameters
----------
space : `gym.spaces.Space` instance
Space (e.g. the observation space) for a single environment in the
vectorized environment.
Example::
n : int
Number of environments in the vectorized environment.
>>> from gym.spaces import Box, Dict
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
... })
>>> batch_space(space, n=5)
Dict(position:Box(5, 3), velocity:Box(5, 2))
Returns
-------
batched_space : `gym.spaces.Space` instance
Space (e.g. the observation space) for a batch of environments in the
vectorized environment.
Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
n: Number of environments in the vectorized environment.
Example
-------
>>> from gym.spaces import Box, Dict
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> batch_space(space, n=5)
Dict(position:Box(5, 3), velocity:Box(5, 2))
Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
"""
raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
@@ -126,41 +121,35 @@ def _batch_space_custom(space, n=1):
@singledispatch
def iterate(space, items):
def iterate(space: Space, items) -> Iterator:
"""Iterate over the elements of a (batched) space.
Parameters
----------
space : `gym.spaces.Space` instance
Space to which `items` belong to.
Example::
items : samples of `space`
Items to be iterated over.
>>> from gym.spaces import Box, Dict
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
>>> items = space.sample()
>>> it = iterate(space, items)
>>> next(it)
{'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32),
'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
>>> next(it)
{'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32),
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
>>> next(it)
StopIteration
Returns
-------
iterator : `Iterable` instance
Args:
space: Space to which `items` belong.
items: Items to be iterated over.
Returns:
Iterator over the elements in `items`.
Example
-------
>>> from gym.spaces import Box, Dict
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
>>> items = space.sample()
>>> it = iterate(space, items)
>>> next(it)
{'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32),
'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
>>> next(it)
{'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32),
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
>>> next(it)
StopIteration
"""
raise ValueError(
"Space of type `{}` is not a valid `gym.Space` " "instance.".format(type(space))
f"Space of type `{type(space)}` is not a valid `gym.Space` " "instance."
)

View File

@@ -1,4 +1,7 @@
from typing import List, Optional, Union
"""Base class for vectorized environments."""
from __future__ import annotations
from typing import Any, Optional, Union
import gym
from gym.logger import deprecation
@@ -8,32 +11,28 @@ __all__ = ["VectorEnv"]
class VectorEnv(gym.Env):
r"""Base class for vectorized environments. Runs multiple independent copies of the
same environment in parallel. This is not the same as 1 environment that has multiple
sub components, but it is many copies of the same base env.
"""Base class for vectorized environments. Runs multiple independent copies of the same environment in parallel.
Each observation returned from vectorized environment is a batch of observations
for each parallel environment. And :meth:`step` is also expected to receive a batch of
actions for each parallel environment.
This is not the same as 1 environment that has multiple subcomponents, but it is many copies of the same base env.
.. note::
Each observation returned from vectorized environment is a batch of observations for each parallel environment.
And :meth:`step` is also expected to receive a batch of actions for each parallel environment.
Notes:
All parallel environments should share the identical observation and action spaces.
In other words, a vector of multiple different environments is not supported.
Parameters
----------
num_envs : int
Number of environments in the vectorized environment.
observation_space : :class:`gym.spaces.Space`
Observation space of a single environment.
action_space : :class:`gym.spaces.Space`
Action space of a single environment.
"""
def __init__(self, num_envs, observation_space, action_space):
def __init__(
self, num_envs: int, observation_space: gym.Space, action_space: gym.Space
):
"""Base class for vectorized environments.
Args:
num_envs: Number of environments in the vectorized environment.
observation_space: Observation space of a single environment.
action_space: Action space of a single environment.
"""
self.num_envs = num_envs
self.is_vector_env = True
self.observation_space = batch_space(observation_space, n=num_envs)
@@ -49,141 +48,134 @@ class VectorEnv(gym.Env):
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
seed: Optional[Union[int, list[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Reset the sub-environments asynchronously.
This method will return ``None``. A call to :meth:`reset_async` should be followed by a call to :meth:`reset_wait` to retrieve the results.
"""
pass
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
seed: Optional[Union[int, list[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Retrieves the results of a :meth:`reset_async` call.
A call to this method must always be preceded by a call to :meth:`reset_async`.
"""
raise NotImplementedError()
def reset(
self,
*,
seed: Optional[Union[int, List[int]]] = None,
seed: Optional[Union[int, list[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
r"""Reset all parallel environments and return a batch of initial observations.
"""Reset all parallel environments and return a batch of initial observations.
Returns
-------
observations : element of :attr:`observation_space`
Args:
seed: The environment reset seeds
return_info: Whether to return the reset info
options: Option information for the environment reset
Returns:
A batch of observations from the vectorized environment.
"""
self.reset_async(seed=seed, return_info=return_info, options=options)
return self.reset_wait(seed=seed, return_info=return_info, options=options)
def step_async(self, actions):
"""Asynchronously performs steps in the sub-environments.
The results can be retrieved via a call to :meth:`step_wait`.
"""
pass
def step_wait(self, **kwargs):
"""Retrieves the results of a :meth:`step_async` call.
A call to this method must always be preceded by a call to :meth:`step_async`.
"""
raise NotImplementedError()
def step(self, actions):
r"""Take an action for each parallel environment.
"""Take an action for each parallel environment.
Parameters
----------
actions : element of :attr:`action_space`
Batch of actions.
Args:
actions: Batch of actions, an element of :attr:`action_space`.
Returns
-------
observations : element of :attr:`observation_space`
A batch of observations from the vectorized environment.
rewards : :obj:`np.ndarray`, dtype :obj:`np.float_`
A vector of rewards from the vectorized environment.
dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
A vector whose entries indicate whether the episode has ended.
infos : list of dict
A list of auxiliary diagnostic information dicts from each parallel environment.
Returns:
Batch of observations, rewards, dones and infos.
"""
self.step_async(actions)
return self.step_wait()
def call_async(self, name, *args, **kwargs):
"""Calls a method name for each parallel environment asynchronously."""
pass
def call_wait(self, **kwargs):
"""After calling a method in :meth:`call_async`, this function collects the results."""
raise NotImplementedError()
def call(self, name, *args, **kwargs):
def call(self, name: str, *args, **kwargs) -> list[Any]:
"""Call a method, or get a property, from each parallel environment.
Parameters
----------
name : string
Name of the method or property to call.
Args:
name (str): Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
*args
Arguments to apply to the method call.
**kwargs
Keywoard arguments to apply to the method call.
Returns
-------
results : list
List of the results of the individual calls to the method or
property for each environment.
Returns:
List of the results of the individual calls to the method or property for each environment.
"""
self.call_async(name, *args, **kwargs)
return self.call_wait()
def get_attr(self, name):
def get_attr(self, name: str):
"""Get a property from each parallel environment.
Parameters
----------
name : string
Name of the property to be get from each individual environment.
Args:
name (str): Name of the property to get from each individual environment.
Returns:
The value of the property ``name`` for each environment.
"""
return self.call(name)
def set_attr(self, name, values):
"""Set a property in each parallel environment.
def set_attr(self, name: str, values: Union[list, tuple, object]):
"""Set a property in each sub-environment.
Parameters
----------
name : string
Name of the property to be set in each individual environment.
values : list, tuple, or object
Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
Args:
name (str): Name of the property to be set in each individual environment.
values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual environment, otherwise a single value
is set for all environments.
"""
raise NotImplementedError()
def close_extras(self, **kwargs):
r"""Clean up the extra resources e.g. beyond what's in this base class."""
"""Clean up the extra resources e.g. beyond what's in this base class."""
pass
def close(self, **kwargs):
r"""Close all parallel environments and release resources.
"""Close all parallel environments and release resources.
It also closes all the existing image viewers, then calls :meth:`close_extras` and sets
:attr:`closed` to ``True``.
.. warning::
Warnings:
This function itself does not close the environments, it should be handled
in :meth:`close_extras`. This is generic for both synchronous and asynchronous
vectorized environments.
.. note::
Notes:
This will be automatically called when the environment is garbage collected or the program exits.
"""
@@ -197,14 +189,12 @@ class VectorEnv(gym.Env):
def seed(self, seed=None):
"""Set the random seed in all parallel environments.
Parameters
----------
seed : list of int, or int, optional
Random seed for each parallel environment. If ``seed`` is a list of
length ``num_envs``, then the items of the list are chosen as random
seeds. If ``seed`` is an int, then each parallel environment uses the random
seed ``seed + n``, where ``n`` is the index of the parallel environment
(between ``0`` and ``num_envs - 1``).
Args:
seed: Random seed for each parallel environment. If ``seed`` is a list of
length ``num_envs``, then the items of the list are chosen as random
seeds. If ``seed`` is an int, then each parallel environment uses the random
seed ``seed + n``, where ``n`` is the index of the parallel environment
(between ``0`` and ``num_envs - 1``).
"""
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
@@ -212,10 +202,12 @@ class VectorEnv(gym.Env):
)
def __del__(self):
"""Closes the vector environment."""
if not getattr(self, "closed", True):
self.close()
def __repr__(self):
"""Returns a string representation of the vector environment using the class name, number of environments and environment spec id."""
if self.spec is None:
return f"{self.__class__.__name__}({self.num_envs})"
else:
@@ -223,19 +215,17 @@ class VectorEnv(gym.Env):
class VectorEnvWrapper(VectorEnv):
r"""Wraps the vectorized environment to allow a modular transformation.
"""Wraps the vectorized environment to allow a modular transformation.
This class is the base class for all wrappers for vectorized environments. The subclass
could override some methods to change the behavior of the original vectorized environment
without touching the original code.
.. note::
Notes:
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""
def __init__(self, env):
def __init__(self, env: VectorEnv):
assert isinstance(env, VectorEnv)
self.env = env

View File

@@ -17,7 +17,7 @@ extras = {
"classic_control": ["pygame==2.1.0"],
"mujoco": ["mujoco_py>=1.50, <2.0"],
"toy_text": ["pygame==2.1.0", "scipy>=1.4.1"],
"other": ["lz4>=3.1.0", "opencv-python>=3.0"],
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0"],
}
# Meta dependency groups.