Pydocstyle utils vector docstring (#2788)

* Added pydocstyle to pre-commit

* Added docstrings for tests and updated the tests for autoreset

* Add pydocstyle exclude folder to allow slowly adding new docstrings

* Add docstrings for setup.py and gym/__init__.py, core.py, error.py and logger.py

* Check that all unwrapped environments are of a particular wrapper type

* Reverted the import of gym.spaces.Space back to gym.spaces

* Fixed the __init__.py docstring

* Fixed the autoreset test

* Updated gym __init__.py top docstring

* Fix examples in docstrings

* Add docstrings and type hints where known to all functions and classes in gym/utils and gym/vector

* Remove unnecessary import

* Removed "unused error" and make APIerror deprecated at gym 1.0

* Add pydocstyle description to CONTRIBUTING.md

* Added docstrings section to CONTRIBUTING.md

* Added :meth: and :attr: keywords to docstrings

* Added :meth: and :attr: keywords to docstrings

* Imported annotations from __future__ to fix python 3.7

* Add __future__ import annotations for python 3.7

* isort

* Remove utils and vector from the pydocstyle exclude list for this PR, and spaces for the previous PR

* Update gym/envs/classic_control/acrobot.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/envs/classic_control/acrobot.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/envs/classic_control/acrobot.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/spaces/dict.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/env_checker.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/ezpickle.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/ezpickle.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Update gym/utils/play.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Pre-commit

* Updated docstrings with :meth:

* Updated docstrings with :meth:

* Update gym/utils/play.py

* Update gym/utils/play.py

* Update gym/utils/play.py

* Apply suggestions from code review

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* pre-commit

* Update gym/utils/play.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Updated fps and zoom parameter docstring

* Update play docstring

* Apply suggestions from code review

Added suggested corrections from @markus28

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Pre-commit magic

* Update the `gym.make` docstring with a warning for `env_checker`

* Updated and fixed vector docstrings

* Update test names to reflect the project filename style

Co-authored-by: Markus Krimmel <montcyril@gmail.com>
Mark Towers
2022-05-20 14:49:30 +01:00
committed by GitHub
parent 1b09191535
commit e2266025e6
24 changed files with 845 additions and 757 deletions

View File

@@ -30,7 +30,7 @@ repos:
rev: 6.1.1 # pick a git hash / tag to point to rev: 6.1.1 # pick a git hash / tag to point to
hooks: hooks:
- id: pydocstyle - id: pydocstyle
exclude: ^(gym/version.py)|(gym/(envs|utils|vector)/)|(tests/) exclude: ^(gym/version.py)|(gym/envs/)|(tests/)
args: args:
- --source - --source
- --explain - --explain

View File

@@ -398,27 +398,26 @@ def rk4(derivs, y0, t):
yourself stranded on a system w/o scipy. Otherwise use yourself stranded on a system w/o scipy. Otherwise use
:func:`scipy.integrate`. :func:`scipy.integrate`.
Example:
>>> ### 2D system
>>> def derivs(x):
... d1 = x[0] + 2*x[1]
... d2 = -3*x[0] + 4*x[1]
... return (d1, d2)
>>> dt = 0.0005
>>> t = arange(0.0, 2.0, dt)
>>> y0 = (1,2)
>>> yout = rk4(derivs, y0, t)
If you have access to scipy, you should probably be using the
:func:`scipy.integrate` tools rather than this function.
This would then require re-adding the time variable to the signature of derivs.
Args: Args:
derivs: the derivative of the system and has the signature ``dy = derivs(yi)`` derivs: the derivative of the system and has the signature ``dy = derivs(yi)``
y0: initial state vector y0: initial state vector
t: sample times t: sample times
args: additional arguments passed to the derivative function
kwargs: additional keyword arguments passed to the derivative function
Example 1 ::
### 2D system
def derivs(x):
d1 = x[0] + 2*x[1]
d2 = -3*x[0] + 4*x[1]
return (d1, d2)
dt = 0.0005
t = arange(0.0, 2.0, dt)
y0 = (1,2)
yout = rk4(derivs6, y0, t)
If you have access to scipy, you should probably be using the
scipy.integrate tools rather than this function.
This would then require re-adding the time variable to the signature of derivs.
Returns: Returns:
yout: Runge-Kutta approximation of the ODE yout: Runge-Kutta approximation of the ODE
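For context (not part of this diff), a minimal sketch of calling ``rk4`` as the rewritten docstring describes, assuming it is imported from ``gym/envs/classic_control/acrobot.py`` where it is defined::

    import numpy as np
    from gym.envs.classic_control.acrobot import rk4

    def derivs(x):
        # a simple linear 2D system, dx/dt = A @ x
        d1 = x[0] + 2 * x[1]
        d2 = -3 * x[0] + 4 * x[1]
        return d1, d2

    dt = 0.0005
    t = np.arange(0.0, 2.0, dt)
    y0 = (1, 2)
    yout = rk4(derivs, y0, t)  # Runge-Kutta approximation of the trajectory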

View File

@@ -499,10 +499,16 @@ def make(
""" """
Create an environment according to the given ID. Create an environment according to the given ID.
Warnings:
In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to validate
that they follow the gym API. To disable this feature, set the parameter `disable_env_checker=True`.
Args: Args:
id: Name of the environment. id: Name of the environment.
max_episode_steps: Maximum length of an episode (TimeLimit wrapper). max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
disable_env_checker: If to disable the environment checker
kwargs: Additional arguments to pass to the environment constructor. kwargs: Additional arguments to pass to the environment constructor.
Returns: Returns:
An instance of the environment. An instance of the environment.
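A short usage sketch for the new flag (the environment id is only an example)::

    import gym

    # From v0.24 the environment checker runs once when the environment is created;
    # skip it when reset/step are expensive or the environment is already known to
    # follow the gym API.
    env = gym.make("CartPole-v1", disable_env_checker=True)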

View File

@@ -17,25 +17,27 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
Elements of this space are (ordered) dictionaries of elements from the constituent spaces. Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
Example usage:: Example usage:
>>> observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)}) >>> from gym.spaces import Dict, Discrete
>>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
>>> observation_space.sample() >>> observation_space.sample()
OrderedDict([('position', 1), ('velocity', 2)]) OrderedDict([('position', 1), ('velocity', 2)])
Example usage [nested]:: Example usage [nested]::
>>> spaces.Dict( >>> from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete
>>> Dict(
... { ... {
... "ext_controller": spaces.MultiDiscrete((5, 2, 2)), ... "ext_controller": MultiDiscrete([5, 2, 2]),
... "inner_state": spaces.Dict( ... "inner_state": Dict(
... { ... {
... "charge": spaces.Discrete(100), ... "charge": Discrete(100),
... "system_checks": spaces.MultiBinary(10), ... "system_checks": MultiBinary(10),
... "job_status": spaces.Dict( ... "job_status": Dict(
... { ... {
... "task": spaces.Discrete(5), ... "task": Discrete(5),
... "progress": spaces.Box(low=0, high=100, shape=()), ... "progress": Box(low=0, high=100, shape=()),
... } ... }
... ), ... ),
... } ... }
@@ -63,9 +65,10 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
Example:: Example::
>>> spaces.Dict({"position": spaces.Box(-1, 1, shape=(2,)), "color": spaces.Discrete(3)}) >>> from gym.spaces import Box, Discrete
>>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)})
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32)) Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
>>> spaces.Dict(position=spaces.Box(-1, 1, shape=(2,)), color=spaces.Discrete(3)) >>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3))
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32)) Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
Args: Args:

View File

@@ -16,11 +16,11 @@ class MultiBinary(Space[np.ndarray]):
Example Usage:: Example Usage::
>>> self.observation_space = spaces.MultiBinary(5) >>> observation_space = MultiBinary(5)
>>> self.observation_space.sample() >>> observation_space.sample()
array([0, 1, 0, 1, 0], dtype=int8) array([0, 1, 0, 1, 0], dtype=int8)
>>> self.observation_space = spaces.MultiBinary([3, 2]) >>> observation_space = MultiBinary([3, 2])
>>> self.observation_space.sample() >>> observation_space.sample()
array([[0, 0], array([[0, 0],
[0, 1], [0, 1],
[1, 1]], dtype=int8) [1, 1]], dtype=int8)

View File

@@ -16,8 +16,9 @@ class Tuple(Space[tuple], Sequence):
Example usage:: Example usage::
>> observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Box(-1, 1, shape=(2,)))) >>> from gym.spaces import Box, Discrete
>> observation_space.sample() >>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))))
>>> observation_space.sample()
(0, array([0.03633198, 0.42370757], dtype=float32)) (0, array([0.03633198, 0.42370757], dtype=float32))
""" """

View File

@@ -25,8 +25,9 @@ def flatdim(space: Space) -> int:
Example usage:: Example usage::
>>> s = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)}) >>> from gym.spaces import Discrete
>>> spaces.flatdim(s) >>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
>>> flatdim(space)
5 5
""" """
raise NotImplementedError(f"Unknown space: `{space}`") raise NotImplementedError(f"Unknown space: `{space}`")
@@ -195,8 +196,7 @@ def flatten_space(space: Space) -> Box:
Example that recursively flattens a dict:: Example that recursively flattens a dict::
>>> space = Dict({"position": Discrete(2), >>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
... "velocity": Box(0, 1, shape=(2, 2))})
>>> flatten_space(space) >>> flatten_space(space)
Box(6,) Box(6,)
>>> flatten(space, space.sample()) in flatten_space(space) >>> flatten(space, space.sample()) in flatten_space(space)
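To make the relationship between ``flatdim``, ``flatten`` and ``flatten_space`` concrete, a small sketch (any nested space would behave similarly)::

    from gym.spaces import Box, Dict, Discrete, flatdim, flatten, flatten_space

    space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
    flatdim(space)                     # 6: a 2-wide one-hot plus 4 Box entries
    flat_space = flatten_space(space)  # a Box with shape (6,)
    flat_sample = flatten(space, space.sample())
    assert flat_sample in flat_space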

View File

@@ -1,5 +1,6 @@
"""A set of common utilities used within the environments. These are """A set of common utilities used within the environments.
not intended as API functions, and will not remain stable over time.
These are not intended as API functions, and will not remain stable over time.
""" """
color2num = dict( color2num = dict(
@@ -15,12 +16,20 @@ color2num = dict(
) )
def colorize(string, color, bold=False, highlight=False): def colorize(
"""Return string surrounded by appropriate terminal color codes to string: str, color: str, bold: bool = False, highlight: bool = False
print colorized text. Valid colors: gray, red, green, yellow, ) -> str:
blue, magenta, cyan, white, crimson """Returns string surrounded by appropriate terminal colour codes to print colourised text.
"""
Args:
string: The message to colourise
color: Literal values are gray, red, green, yellow, blue, magenta, cyan, white, crimson
bold: If to bold the string
highlight: If to highlight the string
Returns:
Colourised string
"""
attr = [] attr = []
num = color2num[color] num = color2num[color]
if highlight: if highlight:
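A quick sketch of the newly documented signature in use::

    from gym.utils.colorize import colorize

    # valid colours: gray, red, green, yellow, blue, magenta, cyan, white, crimson
    print(colorize("training diverged", "yellow", bold=True))
    print(colorize("invalid action received", "red", highlight=True))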

View File

@@ -1,4 +1,5 @@
""" """A set of functions for checking an environment details.
This file is originally from the Stable Baselines3 repository hosted on GitHub This file is originally from the Stable Baselines3 repository hosted on GitHub
(https://github.com/DLR-RM/stable-baselines3/) (https://github.com/DLR-RM/stable-baselines3/)
Original Author: Antonin Raffin Original Author: Antonin Raffin
@@ -16,21 +17,33 @@ from typing import Optional, Union
import numpy as np import numpy as np
import gym import gym
from gym import logger, spaces from gym import logger
from gym.spaces import Box, Dict, Discrete, Space, Tuple
def _is_numpy_array_space(space: spaces.Space) -> bool: def _is_numpy_array_space(space: Space) -> bool:
"""Checks if a space can be represented as a single numpy array (e.g. Dict and Tuple spaces return False).
Args:
space: The space to check
Returns:
Returns False if the provided space is not representable as a single numpy array
""" """
Returns False if provided space is not representable as a single numpy array return not isinstance(space, (Dict, Tuple))
(e.g. Dict and Tuple spaces return False)
"""
return not isinstance(space, (spaces.Dict, spaces.Tuple))
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: def _check_image_input(observation_space: Box, key: str = ""):
""" """Check whether an observation space of type :class:`Box` adheres to general standards for spaces that represent images.
Check that the input adheres to general standards
when the observation is apparently an image. It will check that:
- The datatype is ``np.uint8``
- The lower bound is 0 across all dimensions
- The upper bound is 255 across all dimensions
Args:
observation_space: The observation space to check
key: The observation shape key for warning
""" """
if observation_space.dtype != np.uint8: if observation_space.dtype != np.uint8:
logger.warn( logger.warn(
@@ -49,8 +62,13 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
) )
def _check_nan(env: gym.Env, check_inf: bool = True) -> None: def _check_nan(env: gym.Env, check_inf: bool = True):
"""Check for NaN and Inf.""" """Check if the environment observation, reward are NaN and Inf.
Args:
env: The environment to check
check_inf: Checks if the observation is infinity
"""
for _ in range(10): for _ in range(10):
action = env.action_space.sample() action = env.action_space.sample()
observation, reward, done, _ = env.step(action) observation, reward, done, _ = env.step(action)
@@ -70,19 +88,22 @@ def _check_nan(env: gym.Env, check_inf: bool = True) -> None:
def _check_obs( def _check_obs(
obs: Union[tuple, dict, np.ndarray, int], obs: Union[tuple, dict, np.ndarray, int],
observation_space: spaces.Space, observation_space: Space,
method_name: str, method_name: str,
) -> None: ):
"""Check that the observation returned by the environment correspond to the declared one.
Args:
obs: The observation to check
observation_space: The observation space of the observation
method_name: The method name that generated the observation
""" """
Check that the observation returned by the environment if not isinstance(observation_space, Tuple):
correspond to the declared one.
"""
if not isinstance(observation_space, spaces.Tuple):
assert not isinstance( assert not isinstance(
obs, tuple obs, tuple
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple" ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
if isinstance(observation_space, spaces.Discrete): if isinstance(observation_space, Discrete):
assert isinstance( assert isinstance(
obs, int obs, int
), f"The observation returned by `{method_name}()` method must be an int" ), f"The observation returned by `{method_name}()` method must be an int"
@@ -96,12 +117,16 @@ def _check_obs(
), f"The observation returned by the `{method_name}()` method does not match the given observation space" ), f"The observation returned by the `{method_name}()` method does not match the given observation space"
def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: def _check_box_obs(observation_space: Box, key: str = ""):
""" """Check that the observation space is correctly formatted when dealing with a :class:`Box` space.
Check that the observation space is correctly formatted
when dealing with a ``Box()`` space. In particular, it checks: In particular, it checks:
- that the dimensions are big enough when it is an image, and that the type matches - that the dimensions are big enough when it is an image, and that the type matches
- that the observation has an expected shape (warn the user if not) - that the observation has an expected shape (warn the user if not)
Args:
observation_space: The Box observation space to check
key: The observation key
""" """
# If image, check the low and high values, the type and the number of channels # If image, check the low and high values, the type and the number of channels
# and the shape (minimal value) # and the shape (minimal value)
@@ -137,14 +162,19 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
), "Agent's observation_space.high and observation_space have different shapes" ), "Agent's observation_space.high and observation_space have different shapes"
def _check_box_action(action_space: spaces.Box): def _check_box_action(action_space: Box):
"""Checks that a :class:`Box` action space is defined in a sensible way.
Args:
action_space: A box action space
"""
if np.any(np.equal(action_space.low, -np.inf)): if np.any(np.equal(action_space.low, -np.inf)):
logger.warn( logger.warn(
"Agent's minimum action space value is -infinity. This is probably too low." "Agent's minimum action space value is -infinity. This is probably too low."
) )
if np.any(np.equal(action_space.high, np.inf)): if np.any(np.equal(action_space.high, np.inf)):
logger.warn( logger.warn(
"Agent's maxmimum action space value is infinity. This is probably too high" "Agent's maximum action space value is infinity. This is probably too high"
) )
if np.any(np.equal(action_space.low, action_space.high)): if np.any(np.equal(action_space.low, action_space.high)):
logger.warn("Agent's maximum and minimum action space values are equal") logger.warn("Agent's maximum and minimum action space values are equal")
@@ -156,7 +186,12 @@ def _check_box_action(action_space: spaces.Box):
assert False, "Agent's action_space.high and action_space have different shapes" assert False, "Agent's action_space.high and action_space have different shapes"
def _check_normalized_action(action_space: spaces.Box): def _check_normalized_action(action_space: Box):
"""Checks that a box action space is normalized.
Args:
action_space: A box action space
"""
if ( if (
np.any(np.abs(action_space.low) != np.abs(action_space.high)) np.any(np.abs(action_space.low) != np.abs(action_space.high))
or np.any(np.abs(action_space.low) > 1) or np.any(np.abs(action_space.low) > 1)
@@ -168,16 +203,18 @@ def _check_normalized_action(action_space: spaces.Box):
) )
def _check_returned_values( def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space):
env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space """Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods.
) -> None:
""" Args:
Check the returned values by the env when calling `.reset()` or `.step()` methods. env: The environment
observation_space: The environment's observation space
action_space: The environment's action space
""" """
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
obs = env.reset() obs = env.reset()
if isinstance(observation_space, spaces.Dict): if isinstance(observation_space, Dict):
assert isinstance( assert isinstance(
obs, dict obs, dict
), "The observation returned by `reset()` must be a dictionary" ), "The observation returned by `reset()` must be a dictionary"
@@ -200,7 +237,7 @@ def _check_returned_values(
# Unpack # Unpack
obs, reward, done, info = data obs, reward, done, info = data
if isinstance(observation_space, spaces.Dict): if isinstance(observation_space, Dict):
assert isinstance( assert isinstance(
obs, dict obs, dict
), "The observation returned by `step()` must be a dictionary" ), "The observation returned by `step()` must be a dictionary"
@@ -223,10 +260,11 @@ def _check_returned_values(
), "The `info` returned by `step()` must be a python dictionary" ), "The `info` returned by `step()` must be a python dictionary"
def _check_spaces(env: gym.Env) -> None: def _check_spaces(env: gym.Env):
""" """Check that the observation and action spaces are defined and inherit from :class:`gym.spaces.Space`.
Check that the observation and action spaces are defined
and inherit from gym.spaces.Space. Args:
env: The environment's observation and action space to check
""" """
# Helper to link to the code, because gym has no proper documentation # Helper to link to the code, because gym has no proper documentation
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/" gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
@@ -238,25 +276,22 @@ def _check_spaces(env: gym.Env) -> None:
"You must specify an action space (cf gym.spaces)" + gym_spaces "You must specify an action space (cf gym.spaces)" + gym_spaces
) )
assert isinstance(env.observation_space, spaces.Space), ( assert isinstance(env.observation_space, Space), (
"The observation space must inherit from gym.spaces" + gym_spaces "The observation space must inherit from gym.spaces" + gym_spaces
) )
assert isinstance(env.action_space, spaces.Space), ( assert isinstance(env.action_space, Space), (
"The action space must inherit from gym.spaces" + gym_spaces "The action space must inherit from gym.spaces" + gym_spaces
) )
# Check render cannot be covered by CI # Check render cannot be covered by CI
def _check_render( def _check_render(env: gym.Env, warn: bool = True, headless: bool = False):
env: gym.Env, warn: bool = True, headless: bool = False """Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.
) -> None: # pragma: no cover
""" Args:
Check the declared render modes/fps and the `render()`/`close()` env: The environment to check
method of the environment. warn: Whether to output additional warnings
:param env: The environment to check headless: Whether to disable render modes that require a graphical interface. False by default.
:param warn: Whether to output additional warnings
:param headless: Whether to disable render modes
that require a graphical interface. False by default.
""" """
render_modes = env.metadata.get("render_modes") render_modes = env.metadata.get("render_modes")
if render_modes is None: if render_modes is None:
@@ -288,9 +323,12 @@ def _check_render(
env.close() env.close()
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None: def _check_reset_seed(env: gym.Env, seed: Optional[int] = None):
""" """Check that the environment can be reset with a seed.
Check that the environment can be reset with a random seed.
Args:
env: The environment to check
seed: The optional seed to use
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
assert ( assert (
@@ -303,7 +341,7 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
raise AssertionError( raise AssertionError(
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` " "The environment cannot be reset with a random seed, even though `seed` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. " "appear in the signature. This should never happen, please report this issue. "
"The error was: " + str(e) f"The error was: {e}"
) )
if env.unwrapped.np_random is None: if env.unwrapped.np_random is None:
@@ -322,7 +360,12 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
) )
def _check_reset_info(env: gym.Env) -> None: def _check_reset_info(env: gym.Env):
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
Args:
env: The environment to check
"""
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
assert ( assert (
"return_info" in signature.parameters or "kwargs" in signature.parameters "return_info" in signature.parameters or "kwargs" in signature.parameters
@@ -334,7 +377,7 @@ def _check_reset_info(env: gym.Env) -> None:
raise AssertionError( raise AssertionError(
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` " "The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. " "appear in the signature. This should never happen, please report this issue. "
"The error was: " + str(e) f"The error was: {e}"
) )
assert ( assert (
len(result) == 2 len(result) == 2
@@ -346,9 +389,11 @@ def _check_reset_info(env: gym.Env) -> None:
), "The second element returned by `env.reset(return_info=True)` was not a dictionary" ), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
def _check_reset_options(env: gym.Env) -> None: def _check_reset_options(env: gym.Env):
""" """Check that the environment can be reset with options.
Check that the environment can be reset with options.
Args:
env: The environment to check
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
assert ( assert (
@@ -361,22 +406,22 @@ def _check_reset_options(env: gym.Env) -> None:
raise AssertionError( raise AssertionError(
"The environment cannot be reset with options, even though `options` or `kwargs` " "The environment cannot be reset with options, even though `options` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. " "appear in the signature. This should never happen, please report this issue. "
"The error was: " + str(e) f"The error was: {e}"
) )
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None: def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True):
""" """Check that an environment follows Gym API.
Check that an environment follows Gym API.
This is particularly useful when using a custom environment. This is particularly useful when using a custom environment.
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
for more information about the API. for more information about the API.
It also optionally check that the environment is compatible with Stable-Baselines. It also optionally check that the environment is compatible with Stable-Baselines.
:param env: The Gym environment that will be checked
:param warn: Whether to output additional warnings Args:
mainly related to the interaction with Stable Baselines env: The Gym environment that will be checked
:param skip_render_check: Whether to skip the checks for the render method. warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines
True by default (useful for the CI) skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
""" """
assert isinstance( assert isinstance(
env, gym.Env env, gym.Env
@@ -393,15 +438,15 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
if warn: if warn:
obs_spaces = ( obs_spaces = (
observation_space.spaces observation_space.spaces
if isinstance(observation_space, spaces.Dict) if isinstance(observation_space, Dict)
else {"": observation_space} else {"": observation_space}
) )
for key, space in obs_spaces.items(): for key, space in obs_spaces.items():
if isinstance(space, spaces.Box): if isinstance(space, Box):
_check_box_obs(space, key) _check_box_obs(space, key)
# Check for the action space, it may lead to hard-to-debug issues # Check for the action space, it may lead to hard-to-debug issues
if isinstance(action_space, spaces.Box): if isinstance(action_space, Box):
_check_box_action(action_space) _check_box_action(action_space)
_check_normalized_action(action_space) _check_normalized_action(action_space)
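A usage sketch for ``check_env`` (the environment id is illustrative; any ``gym.Env`` instance works)::

    import gym
    from gym.utils.env_checker import check_env

    # Disable the automatic check performed by `gym.make` so the environment
    # is only checked once, then run the full checker explicitly.
    env = gym.make("CartPole-v1", disable_env_checker=True)
    check_env(env, warn=True, skip_render_check=True)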

View File

@@ -1,33 +1,35 @@
"""Class for pickling and unpickling objects via their constructor arguments."""
class EzPickle: class EzPickle:
"""Objects that are pickled and unpickled via their constructor """Objects that are pickled and unpickled via their constructor arguments.
arguments.
Example usage: Example::
class Dog(Animal, EzPickle): >>> class Dog(Animal, EzPickle):
def __init__(self, furcolor, tailkind="bushy"): ... def __init__(self, furcolor, tailkind="bushy"):
Animal.__init__() ... Animal.__init__()
EzPickle.__init__(furcolor, tailkind) ... EzPickle.__init__(furcolor, tailkind)
...
When this object is unpickled, a new Dog will be constructed by passing the provided When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
furcolor and tailkind into the constructor. However, philosophers are still not sure However, philosophers are still not sure whether it is still the same dog.
whether it is still the same dog.
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
and Atari.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
self._ezpickle_args = args self._ezpickle_args = args
self._ezpickle_kwargs = kwargs self._ezpickle_kwargs = kwargs
def __getstate__(self): def __getstate__(self):
"""Returns the object pickle state with args and kwargs."""
return { return {
"_ezpickle_args": self._ezpickle_args, "_ezpickle_args": self._ezpickle_args,
"_ezpickle_kwargs": self._ezpickle_kwargs, "_ezpickle_kwargs": self._ezpickle_kwargs,
} }
def __setstate__(self, d): def __setstate__(self, d):
"""Sets the object pickle state using d."""
out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"]) out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
self.__dict__.update(out.__dict__) self.__dict__.update(out.__dict__)
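A self-contained round-trip sketch of ``EzPickle`` (the ``Dog`` class is purely illustrative)::

    import pickle

    from gym.utils.ezpickle import EzPickle

    class Dog(EzPickle):
        def __init__(self, furcolor, tailkind="bushy"):
            # EzPickle stores the constructor arguments so unpickling rebuilds the object
            EzPickle.__init__(self, furcolor, tailkind)
            self.furcolor = furcolor
            self.tailkind = tailkind

    dog = pickle.loads(pickle.dumps(Dog("brown")))
    assert dog.furcolor == "brown" and dog.tailkind == "bushy"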

View File

@@ -1,11 +1,18 @@
"""Utilities of visualising an environment."""
from __future__ import annotations
from collections import deque
from typing import Callable, Dict, Optional, Tuple, Union from typing import Callable, Dict, Optional, Tuple, Union
import numpy as np
import pygame import pygame
from numpy.typing import NDArray
from pygame import Surface from pygame import Surface
from pygame.event import Event from pygame.event import Event
from pygame.locals import VIDEORESIZE
from gym import Env, logger from gym import Env, logger
from gym.core import ActType, ObsType
from gym.error import DependencyNotInstalled
from gym.logger import deprecation from gym.logger import deprecation
try: try:
@@ -13,30 +20,31 @@ try:
matplotlib.use("TkAgg") matplotlib.use("TkAgg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError as e: except ImportError:
logger.warn(f"failed to set matplotlib backend, plotting will not work: {str(e)}") logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
plt = None matplotlib, plt = None, None
from collections import deque
from pygame.locals import VIDEORESIZE
from gym.core import ActType
class MissingKeysToAction(Exception): class MissingKeysToAction(Exception):
"""Raised when the environment does not have """Raised when the environment does not have a default ``keys_to_action`` mapping."""
a default keys_to_action mapping
"""
class PlayableGame: class PlayableGame:
"""Wraps an environment allowing keyboard inputs to interact with the environment."""
def __init__( def __init__(
self, self,
env: Env, env: Env,
keys_to_action: Optional[Dict[Tuple[int], int]] = None, keys_to_action: Optional[dict[tuple[int], int]] = None,
zoom: Optional[float] = None, zoom: Optional[float] = None,
): ):
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
Args:
env: The environment to play
keys_to_action: The dictionary of keyboard tuples and action value
zoom: If to zoom in on the environment render
"""
self.env = env self.env = env
self.relevant_keys = self._get_relevant_keys(keys_to_action) self.relevant_keys = self._get_relevant_keys(keys_to_action)
self.video_size = self._get_video_size(zoom) self.video_size = self._get_video_size(zoom)
@@ -45,7 +53,7 @@ class PlayableGame:
self.running = True self.running = True
def _get_relevant_keys( def _get_relevant_keys(
self, keys_to_action: Optional[Dict[Tuple[int], int]] = None self, keys_to_action: Optional[dict[tuple[int], int]] = None
) -> set: ) -> set:
if keys_to_action is None: if keys_to_action is None:
if hasattr(self.env, "get_keys_to_action"): if hasattr(self.env, "get_keys_to_action"):
@@ -60,7 +68,7 @@ class PlayableGame:
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), [])) relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
return relevant_keys return relevant_keys
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]: def _get_video_size(self, zoom: Optional[float] = None) -> tuple[int, int]:
# TODO: this needs to be updated when the render API change goes through # TODO: this needs to be updated when the render API change goes through
rendered = self.env.render(mode="rgb_array") rendered = self.env.render(mode="rgb_array")
video_size = [rendered.shape[1], rendered.shape[0]] video_size = [rendered.shape[1], rendered.shape[0]]
@@ -70,7 +78,14 @@ class PlayableGame:
return video_size return video_size
def process_event(self, event: Event) -> None: def process_event(self, event: Event):
"""Processes a PyGame event.
In particular, this function is used to keep track of which buttons are currently pressed and to exit the :func:`play` function when the PyGame window is closed.
Args:
event: The event to process
"""
if event.type == pygame.KEYDOWN: if event.type == pygame.KEYDOWN:
if event.key in self.relevant_keys: if event.key in self.relevant_keys:
self.pressed_keys.append(event.key) self.pressed_keys.append(event.key)
@@ -87,9 +102,17 @@ class PlayableGame:
def display_arr( def display_arr(
screen: Surface, arr: NDArray, video_size: Tuple[int, int], transpose: bool screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
): ):
arr_min, arr_max = arr.min(), arr.max() """Displays a numpy array on screen.
Args:
screen: The screen to show the array on
arr: The array to show
video_size: The video size of the screen
transpose: If to transpose the array on the screen
"""
arr_min, arr_max = np.min(arr), np.max(arr)
arr = 255.0 * (arr - arr_min) / (arr_max - arr_min) arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr) pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
pyg_img = pygame.transform.scale(pyg_img, video_size) pyg_img = pygame.transform.scale(pyg_img, video_size)
@@ -108,60 +131,74 @@ def play(
): ):
"""Allows one to play the game using keyboard. """Allows one to play the game using keyboard.
To simply play the game use: Example::
play(gym.make("Pong-v4")) >>> import gym
>>> from gym.utils.play import play
>>> play(gym.make("CarRacing-v1"), keys_to_action={"w": np.array([0, 0.7, 0]),
... "a": np.array([-1, 0, 0]),
... "s": np.array([0, 0, 1]),
... "d": np.array([1, 0, 0]),
... "wa": np.array([-1, 0.7, 0]),
... "dw": np.array([1, 0.7, 0]),
... "ds": np.array([1, 0, 1]),
... "as": np.array([-1, 0, 1]),
... }, noop=np.array([0,0,0]))
Above code works also if env is wrapped, so it's particularly useful in
Above code works also if the environment is wrapped, so it's particularly useful in
verifying that the frame-level preprocessing does not render the game verifying that the frame-level preprocessing does not render the game
unplayable. unplayable.
If you wish to plot real time statistics as you play, you can use If you wish to plot real time statistics as you play, you can use
gym.utils.play.PlayPlot. Here's a sample code for plotting the reward :class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
for last 5 second of gameplay. for last 150 steps.
def callback(obs_t, obs_tp1, action, rew, done, info): >>> def callback(obs_t, obs_tp1, action, rew, done, info):
return [rew,] ... return [rew,]
plotter = PlayPlot(callback, 30 * 5, ["reward"]) >>> plotter = PlayPlot(callback, 150, ["reward"])
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
env = gym.make("Pong-v4")
play(env, callback=plotter.callback)
Arguments Args:
--------- env: Environment to use for playing.
env: gym.Env transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
Environment to use for playing. fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
transpose: bool ``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
If True the output of observation is transposed. zoom: Zoom the observation in, ``zoom`` amount, should be positive float
Defaults to true. callback: If a callback is provided, it will be executed after every step. It takes the following input:
fps: int obs_t: observation before performing action
Maximum number of steps of the environment to execute every second. obs_tp1: observation after performing action
Defaults to 30. action: action that was executed
zoom: float rew: reward that was received
Make screen edge this many times bigger done: whether the environment is done or not
callback: lambda or None info: debug info
Callback if a callback is provided it will be executed after keys_to_action: Mapping from keys pressed to action performed.
every step. It takes the following input: Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
obs_t: observation before performing action points of the keys, as a tuple of characters, or as a string where each character of the string represents
obs_tp1: observation after performing action one key.
action: action that was executed For example if pressing 'w' and space at the same time is supposed
rew: reward that was received to trigger action number 2 then ``key_to_action`` dict could look like this:
done: whether the environment is done or not >>> {
info: debug info ... # ...
keys_to_action: dict: tuple(int) -> int or None ... (ord('w'), ord(' ')): 2
Mapping from keys pressed to action performed. ... # ...
For example if pressed 'w' and space at the same time is supposed ... }
to trigger action number 2 then key_to_action dict would look like this: or like this:
>>> {
{ ... # ...
# ... ... ("w", " "): 2
sorted(ord('w'), ord(' ')) -> 2 ... # ...
# ... ... }
} or like this:
If None, default key_to_action mapping for that env is used, if provided. >>> {
seed: bool or None ... # ...
Random seed used when resetting the environment. If None, no seed is used. ... "w ": 2
... # ...
... }
If ``None``, default ``key_to_action`` mapping for that environment is used, if provided.
seed: Random seed used when resetting the environment. If None, no seed is used.
noop: The action used when no key input has been entered, or the entered key combination is unknown.
""" """
env.reset(seed=seed) env.reset(seed=seed)
@@ -208,7 +245,44 @@ def play(
class PlayPlot: class PlayPlot:
def __init__(self, callback, horizon_timesteps, plot_names): """Provides a callback to create live plots of arbitrary metrics when using :func:`play`.
This class is instantiated with a function that accepts information about a single environment transition:
- obs_t: observation before performing action
- obs_tp1: observation after performing action
- action: action that was executed
- rew: reward that was received
- done: whether the environment is done or not
- info: debug info
It should return a list of metrics that are computed from this data.
For instance, the function may look like this::
def compute_metrics(obs_t, obs_tp, action, reward, done, info):
return [reward, info["cumulative_reward"], np.linalg.norm(action)]
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
and uses the returned values to update live plots of the metrics.
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
>>> play(your_env, callback=plotter.callback)
"""
def __init__(
self, callback: callable, horizon_timesteps: int, plot_names: list[str]
):
"""Constructor of :class:`PlayPlot`.
The function ``callback`` that is passed to this constructor should return
a list of metrics that is of length ``len(plot_names)``.
Args:
callback: Function that computes metrics from environment transitions
horizon_timesteps: The time horizon used for the live plots
plot_names: List of plot titles
"""
deprecation( deprecation(
"`PlayPlot` is marked as deprecated and will be removed in the near future." "`PlayPlot` is marked as deprecated and will be removed in the near future."
) )
@@ -216,7 +290,10 @@ class PlayPlot:
self.horizon_timesteps = horizon_timesteps self.horizon_timesteps = horizon_timesteps
self.plot_names = plot_names self.plot_names = plot_names
assert plt is not None, "matplotlib backend failed, plotting will not work" if plt is None:
raise DependencyNotInstalled(
"matplotlib is not installed, run `pip install gym[other]`"
)
num_plots = len(self.plot_names) num_plots = len(self.plot_names)
self.fig, self.ax = plt.subplots(num_plots) self.fig, self.ax = plt.subplots(num_plots)
@@ -228,7 +305,25 @@ class PlayPlot:
self.cur_plot = [None for _ in range(num_plots)] self.cur_plot = [None for _ in range(num_plots)]
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
def callback(self, obs_t, obs_tp1, action, rew, done, info): def callback(
self,
obs_t: ObsType,
obs_tp1: ObsType,
action: ActType,
rew: float,
done: bool,
info: dict,
):
"""The callback that calls the provided data callback and adds the data to the plots.
Args:
obs_t: The observation at time step t
obs_tp1: The observation at time step t+1
action: The action
rew: The reward
done: If the environment is done
info: The information from the environment
"""
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info) points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
for point, data_series in zip(points, self.data): for point, data_series in zip(points, self.data):
data_series.append(point) data_series.append(point)

View File

@@ -1,7 +1,10 @@
"""Set of random number generator functions: seeding, generator, hashing seeds."""
from __future__ import annotations
import hashlib import hashlib
import os import os
import struct import struct
from typing import Any, List, Optional, Tuple, Union from typing import Any, Optional, Union
import numpy as np import numpy as np
@@ -9,7 +12,15 @@ from gym import error
from gym.logger import deprecation from gym.logger import deprecation
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]: def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]:
"""Generates a random number generator from the seed and returns the Generator and seed.
Args:
seed: The seed used to create the generator
Returns:
The generator and resulting seed
"""
if seed is not None and not (isinstance(seed, int) and 0 <= seed): if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}") raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
@@ -22,7 +33,10 @@ def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]
# TODO: Remove this class and make it alias to `Generator` in a future Gym release # TODO: Remove this class and make it alias to `Generator` in a future Gym release
# RandomNumberGenerator = np.random.Generator # RandomNumberGenerator = np.random.Generator
class RandomNumberGenerator(np.random.Generator): class RandomNumberGenerator(np.random.Generator):
"""Random number generator class that inherits from numpy's random Generator class."""
def rand(self, *size): def rand(self, *size):
"""Deprecated rand function using random."""
deprecation( deprecation(
"Function `rng.rand(*size)` is marked as deprecated " "Function `rng.rand(*size)` is marked as deprecated "
"and will be removed in the future. " "and will be removed in the future. "
@@ -34,6 +48,7 @@ class RandomNumberGenerator(np.random.Generator):
random_sample = rand random_sample = rand
def randn(self, *size): def randn(self, *size):
"""Deprecated random standard normal function use standard_normal."""
deprecation( deprecation(
"Function `rng.randn(*size)` is marked as deprecated " "Function `rng.randn(*size)` is marked as deprecated "
"and will be removed in the future. " "and will be removed in the future. "
@@ -43,6 +58,7 @@ class RandomNumberGenerator(np.random.Generator):
return self.standard_normal(size) return self.standard_normal(size)
def randint(self, low, high=None, size=None, dtype=int): def randint(self, low, high=None, size=None, dtype=int):
"""Deprecated random integer function use integers."""
deprecation( deprecation(
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated " "Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
"and will be removed in the future. " "and will be removed in the future. "
@@ -54,6 +70,7 @@ class RandomNumberGenerator(np.random.Generator):
random_integers = randint random_integers = randint
def get_state(self): def get_state(self):
"""Deprecated get rng state use bit_generator.state."""
deprecation( deprecation(
"Function `rng.get_state()` is marked as deprecated " "Function `rng.get_state()` is marked as deprecated "
"and will be removed in the future. " "and will be removed in the future. "
@@ -63,6 +80,7 @@ class RandomNumberGenerator(np.random.Generator):
return self.bit_generator.state return self.bit_generator.state
def set_state(self, state): def set_state(self, state):
"""Deprecated set rng state function use bit_generator.state = state."""
deprecation( deprecation(
"Function `rng.set_state(state)` is marked as deprecated " "Function `rng.set_state(state)` is marked as deprecated "
"and will be removed in the future. " "and will be removed in the future. "
@@ -72,6 +90,7 @@ class RandomNumberGenerator(np.random.Generator):
self.bit_generator.state = state self.bit_generator.state = state
def seed(self, seed=None): def seed(self, seed=None):
"""Deprecated seed function use gym.utils.seeding.np_random(seed)."""
deprecation( deprecation(
"Function `rng.seed(seed)` is marked as deprecated " "Function `rng.seed(seed)` is marked as deprecated "
"and will be removed in the future. " "and will be removed in the future. "
@@ -88,6 +107,7 @@ class RandomNumberGenerator(np.random.Generator):
seed.__doc__ = np.random.seed.__doc__ seed.__doc__ = np.random.seed.__doc__
def __reduce__(self): def __reduce__(self):
"""Reduces the Random Number Generator to a RandomNumberGenerator, init_args and additional args."""
# np.random.Generator defines __reduce__, but it's hard-coded to # np.random.Generator defines __reduce__, but it's hard-coded to
# return a Generator instead of its subclass RandomNumberGenerator. # return a Generator instead of its subclass RandomNumberGenerator.
# We need to override it here, otherwise sampling from a Space will # We need to override it here, otherwise sampling from a Space will
@@ -119,20 +139,21 @@ RNG = RandomNumberGenerator
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int: def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
"""Any given evaluation is likely to have many PRNG's active at """Any given evaluation is likely to have many PRNG's active at once.
once. (Most commonly, because the environment is running in
multiple processes.) There's literature indicating that having (Most commonly, because the environment is running in multiple processes.)
linear correlations between seeds of multiple PRNG's can correlate There's literature indicating that having linear correlations between seeds of multiple PRNG's can correlate the outputs:
the outputs: http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/ http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be http://dl.acm.org/citation.cfm?id=1276928
http://dl.acm.org/citation.cfm?id=1276928 Thus, for sanity we hash the seeds before using them. (This scheme is likely not crypto-strength, but it should be good enough to get rid of simple correlations.)
Thus, for sanity we hash the seeds before using them. (This scheme
is likely not crypto-strength, but it should be good enough to get
rid of simple correlations.)
Args: Args:
seed: None seeds from an operating system specific randomness source. seed: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the hashed seed. max_bytes: Maximum number of bytes to use in the hashed seed.
Returns:
The hashed seed
""" """
deprecation( deprecation(
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. " "Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
@@ -144,12 +165,16 @@ def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int: def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
"""Create a strong random seed. Otherwise, Python 2 would seed using """Create a strong random seed.
the system time, which might be non-robust especially in the
presence of concurrency. Otherwise, Python 2 would seed using the system time, which might be non-robust especially in the presence of concurrency.
Args: Args:
a: None seeds from an operating system specific randomness source. a: None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the seed. max_bytes: Maximum number of bytes to use in the seed.
Returns:
A seed
""" """
deprecation( deprecation(
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. " "Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
@@ -185,7 +210,7 @@ def _bigint_from_bytes(bt: bytes) -> int:
return accum return accum
def _int_list_from_bigint(bigint: int) -> List[int]: def _int_list_from_bigint(bigint: int) -> list[int]:
deprecation( deprecation(
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. " "Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
) )
@@ -195,7 +220,7 @@ def _int_list_from_bigint(bigint: int) -> List[int]:
elif bigint == 0: elif bigint == 0:
return [0] return [0]
ints: List[int] = [] ints: list[int] = []
while bigint > 0: while bigint > 0:
bigint, mod = divmod(bigint, 2**32) bigint, mod = divmod(bigint, 2**32)
ints.append(mod) ints.append(mod)

View File

@@ -1,7 +1,7 @@
try: """Module for vector environments."""
from collections.abc import Iterable from __future__ import annotations
except ImportError:
Iterable = (tuple, list) from typing import Iterable, Optional, Union
from gym.vector.async_vector_env import AsyncVectorEnv from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv from gym.vector.sync_vector_env import SyncVectorEnv
@@ -10,40 +10,34 @@ from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"] __all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs): def make(
"""Create a vectorized environment from multiple copies of an environment, id: str,
from its id. num_envs: int = 1,
asynchronous: bool = True,
wrappers: Optional[Union[callable, list[callable]]] = None,
**kwargs,
) -> VectorEnv:
"""Create a vectorized environment from multiple copies of an environment, from its id.
Parameters Example::
----------
id : str
The environment ID. This must be a valid ID from the registry.
num_envs : int >>> import gym
Number of copies of the environment. >>> env = gym.vector.make('CartPole-v1', num_envs=3)
>>> env.reset()
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
dtype=float32)
asynchronous : bool Args:
If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses id: The environment ID. This must be a valid ID from the registry.
`multiprocessing`_ to run the environments in parallel). If ``False``, num_envs: Number of copies of the environment.
wraps the environments in a :class:`SyncVectorEnv`. asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
**kwargs: Keywords arguments applied during gym.make
wrappers : callable, or iterable of callables, optional Returns:
If not ``None``, then apply the wrappers to each internal
environment during creation.
Returns
-------
:class:`gym.vector.VectorEnv`
The vectorized environment. The vectorized environment.
Example
-------
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
>>> env.reset()
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
dtype=float32)
""" """
from gym.envs import make as make_ from gym.envs import make as make_
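Stepping the vectorised environment created in the docstring example might look like this (ids and the synchronous flag are illustrative)::

    import gym

    envs = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=False)
    observations = envs.reset()
    # one action per sub-environment; step returns batched obs/rewards/dones/infos
    actions = envs.action_space.sample()
    observations, rewards, dones, infos = envs.step(actions)
    envs.close()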

View File

@@ -1,13 +1,18 @@
"""An async vector environment."""
from __future__ import annotations
import multiprocessing as mp import multiprocessing as mp
import sys import sys
import time import time
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from typing import List, Optional, Union from typing import Optional, Sequence, Union
import numpy as np import numpy as np
import gym
from gym import logger from gym import logger
from gym.core import ObsType
from gym.error import ( from gym.error import (
AlreadyPendingCallError, AlreadyPendingCallError,
ClosedEnvironmentError, ClosedEnvironmentError,
@@ -37,69 +42,13 @@ class AsyncState(Enum):
class AsyncVectorEnv(VectorEnv):
    """Vectorized environment that runs multiple environments in parallel.

    It uses ``multiprocessing`` processes, and pipes for communication.

    Example::

        >>> import gym
        >>> env = gym.vector.AsyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v0", g=9.81),
        ...     lambda: gym.make("Pendulum-v0", g=1.62)
@@ -111,15 +60,33 @@ class AsyncVectorEnv(VectorEnv):
    def __init__(
        self,
        env_fns: Sequence[callable],
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        shared_memory: bool = True,
        copy: bool = True,
        context: Optional[str] = None,
        daemon: bool = True,
        worker: Optional[callable] = None,
    ):
        """Vectorized environment that runs multiple environments in parallel.

        Args:
            env_fns: Functions that create the environments.
            observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
            action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
            shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images).
            copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
            context: Context for `multiprocessing`_. If ``None``, then the default context is used.
            daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children, so for some environments you may want to have it set to ``False``.
            worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled.

        Warnings:
            worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start from the code for ``_worker`` (or ``_worker_shared_memory``) and add changes.

        Raises:
            RuntimeError: If the observation space of some sub-environment does not match ``observation_space`` (or, by default, the observation space of the first sub-environment).
            ValueError: If ``observation_space`` is a custom space (i.e. not a default space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`, or :class:`~gym.spaces.Dict`) and ``shared_memory`` is ``True``.
        """
        ctx = mp.get_context(context)
        self.env_fns = env_fns
        self.shared_memory = shared_memory
@@ -192,6 +159,11 @@ class AsyncVectorEnv(VectorEnv):
        self._check_spaces()

    def seed(self, seed=None):
        """Seeds the vector environments.

        Args:
            seed: The seeds used with the environments.
        """
        super().seed(seed=seed)
        self._assert_is_running()
        if seed is None:
@@ -213,22 +185,24 @@ class AsyncVectorEnv(VectorEnv):
    def reset_async(
        self,
        seed: Optional[Union[int, list[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Send calls to the :obj:`reset` methods of the sub-environments.

        To get the results of these calls, you may invoke :meth:`reset_wait`.

        Args:
            seed: List of seeds for each environment.
            return_info: If ``True``, then return information from the environment resets.
            options: The reset options.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: If the environment is already waiting for a pending call to another
                method (e.g. :meth:`step_async`). This can be caused by two consecutive
                calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
        """
        self._assert_is_running()
@@ -258,37 +232,26 @@ class AsyncVectorEnv(VectorEnv):
    def reset_wait(
        self,
        timeout: Optional[Union[int, float]] = None,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ) -> Union[ObsType, tuple[ObsType, list[dict]]]:
        """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.

        Args:
            timeout: Number of seconds before the call to :meth:`reset_wait` times out. If ``None``, the call to :meth:`reset_wait` never times out.
            seed: ignored
            return_info: If ``True``, then return information from the environment resets.
            options: ignored

        Returns:
            A batch of observations, or a tuple of the batched observations and a list of info dictionaries if ``return_info`` is ``True``.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
            TimeoutError: If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
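A small sketch (not part of the diff) of the split reset documented above, assuming three CartPole copies:

    import gym

    envs = gym.vector.AsyncVectorEnv(
        [lambda: gym.make("CartPole-v1") for _ in range(3)]
    )
    envs.reset_async(seed=[1, 2, 3])  # dispatch reset to every worker
    obs = envs.reset_wait(timeout=5)  # collect the batched observations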
@@ -327,24 +290,18 @@ class AsyncVectorEnv(VectorEnv):
        return deepcopy(self.observations) if self.copy else self.observations

    def step_async(self, actions: np.ndarray):
        """Send the calls to :obj:`step` to each sub-environment.

        Args:
            actions: Batch of actions, an element of :attr:`~VectorEnv.action_space`.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: If the environment is already waiting for a pending call to another
                method (e.g. :meth:`reset_async`). This can be caused by two consecutive
                calls to :meth:`step_async`, with no call to :meth:`step_wait` in between.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
@@ -358,40 +315,21 @@ class AsyncVectorEnv(VectorEnv):
pipe.send(("step", action)) pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP self._state = AsyncState.WAITING_STEP
def step_wait(self, timeout=None): def step_wait(
self, timeout: Optional[Union[int, float]] = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[dict]]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish. """Wait for the calls to :obj:`step` in each sub-environment to finish.
Parameters Args:
---------- timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
timeout : int or float, optional
Number of seconds before the call to :meth:`step_wait` times out. If
``None``, the call to :meth:`step_wait` never times out.
Returns Returns:
------- The batched environment step information, obs, reward, done and info
observations : element of :attr:`~VectorEnv.observation_space`
A batch of observations from the vectorized environment.
rewards : :obj:`np.ndarray`, dtype :obj:`np.float_` Raises:
A vector of rewards from the vectorized environment. ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
dones : :obj:`np.ndarray`, dtype :obj:`np.bool_` TimeoutError: If :meth:`step_wait` timed out.
A vector whose entries indicate whether the episode has ended.
infos : list of dict
A list of auxiliary diagnostic information dicts from sub-environments.
Raises
------
ClosedEnvironmentError
If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError
If :meth:`step_wait` was called without any prior call to
:meth:`step_async`.
TimeoutError
If :meth:`step_wait` timed out.
""" """
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.WAITING_STEP: if self._state != AsyncState.WAITING_STEP:
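Continuing the sketch above, the split step call would look roughly like this (illustrative, not part of the diff):

    import numpy as np

    # `envs` is the three-copy AsyncVectorEnv from the previous sketch.
    envs.step_async(np.array([0, 1, 0]))                    # one action per sub-environment
    obs, rewards, dones, infos = envs.step_wait(timeout=5)  # batched step results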
@@ -425,18 +363,13 @@ class AsyncVectorEnv(VectorEnv):
            infos,
        )

    def call_async(self, name: str, *args, **kwargs):
        """Calls the method with name asynchronously and applies args and kwargs to the method.

        Args:
            name: Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
@@ -450,19 +383,14 @@ class AsyncVectorEnv(VectorEnv):
pipe.send(("_call", (name, args, kwargs))) pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout=None): def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list:
""" """Calls all parent pipes and waits for the results.
Parameters
----------
timeout : int or float, optional
Number of seconds before the call to `step_wait` times out. If
`None` (default), the call to `step_wait` never times out.
Returns Args:
------- timeout: Number of seconds before the call to `step_wait` times out. If `None` (default), the call to `step_wait` never times out.
results : list
List of the results of the individual calls to the method or Returns:
property for each environment. List of the results of the individual calls to the method or property for each environment.
""" """
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.WAITING_CALL: if self._state != AsyncState.WAITING_CALL:
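A rough illustration of the asynchronous call pattern (not part of the diff); it assumes `envs` is the AsyncVectorEnv from the earlier sketch:

    # `envs` is an AsyncVectorEnv of CartPole copies.
    envs.call_async("spec")            # query a method or property on every copy
    specs = envs.call_wait(timeout=5)  # list with one result per sub-environment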
@@ -483,17 +411,14 @@ class AsyncVectorEnv(VectorEnv):
        return results

    def set_attr(self, name: str, values: Union[list, tuple, object]):
        """Sets an attribute of the sub-environments.

        Args:
            name: Name of the property to be set in each individual environment.
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise a single value is set for all environments.
        """
        self._assert_is_running()
        if not isinstance(values, (list, tuple)):
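A hedged sketch of `set_attr`/`get_attr` (not part of the diff); the `marker` attribute is hypothetical and used purely for illustration:

    # `marker` is a hypothetical attribute, used only to show the API shape.
    envs.set_attr("marker", [0, 1, 2])  # one value per sub-environment
    markers = envs.get_attr("marker")   # [0, 1, 2]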
@@ -517,25 +442,19 @@ class AsyncVectorEnv(VectorEnv):
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def close_extras(
        self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
    ):
        """Close the environments & clean up the extra resources (processes and pipes).

        Args:
            timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
                the call to :meth:`close` never times out. If the call to :meth:`close`
                times out, then all processes are terminated.
            terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.

        Raises:
            TimeoutError: If :meth:`close` timed out.
        """
        timeout = 0 if terminate else timeout
        try:
@@ -626,6 +545,7 @@ class AsyncVectorEnv(VectorEnv):
            raise exctype(value)

    def __del__(self):
        """On deleting the object, checks that the vector environment is closed."""
        if not getattr(self, "closed", True) and hasattr(self, "_state"):
            self.close(terminate=True)


@@ -1,8 +1,12 @@
"""A synchronous vector environment."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Iterator, Optional, Sequence, Union

import numpy as np

from gym.spaces import Space
from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv
@@ -12,35 +16,9 @@ __all__ = ["SyncVectorEnv"]
class SyncVectorEnv(VectorEnv):
    """Vectorized environment that serially runs multiple environments.

    Example::

        >>> import gym
        >>> env = gym.vector.SyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v0", g=9.81),
        ...     lambda: gym.make("Pendulum-v0", g=1.62)
@@ -50,7 +28,24 @@ class SyncVectorEnv(VectorEnv):
               [-0.85009176,  0.5266346 ,  0.60007906]], dtype=float32)
    """

    def __init__(
        self,
        env_fns: Iterator[callable],
        observation_space: Space = None,
        action_space: Space = None,
        copy: bool = True,
    ):
        """Vectorized environment that serially runs multiple environments.

        Args:
            env_fns: iterable of callable functions that create the environments.
            observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
            action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
            copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.

        Raises:
            RuntimeError: If the observation space of some sub-environment does not match ``observation_space`` (or, by default, the observation space of the first sub-environment).
        """
        self.env_fns = env_fns
        self.envs = [env_fn() for env_fn in env_fns]
        self.copy = copy
@@ -60,7 +55,7 @@ class SyncVectorEnv(VectorEnv):
        observation_space = observation_space or self.envs[0].observation_space
        action_space = action_space or self.envs[0].action_space
        super().__init__(
            num_envs=len(self.envs),
            observation_space=observation_space,
            action_space=action_space,
        )
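A minimal usage sketch of `SyncVectorEnv` construction and stepping (not part of the diff), assuming two CartPole copies:

    import gym
    import numpy as np

    envs = gym.vector.SyncVectorEnv(
        [lambda: gym.make("CartPole-v1") for _ in range(2)]
    )
    obs = envs.reset(seed=42)
    obs, rewards, dones, infos = envs.step(np.array([1, 0]))  # batched results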
@@ -73,7 +68,12 @@ class SyncVectorEnv(VectorEnv):
        self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
        self._actions = None

    def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
        """Sets the seed in all sub-environments.

        Args:
            seed: The random seed to use for each sub-environment.
        """
        super().seed(seed=seed)
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
@@ -86,10 +86,20 @@ class SyncVectorEnv(VectorEnv):
    def reset_wait(
        self,
        seed: Optional[Union[int, list[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.

        Args:
            seed: The reset environment seed
            return_info: If ``True``, then return reset information
            options: Option information for the environment reset

        Returns:
            The reset observation of the environment and reset information
        """
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
@@ -128,9 +138,15 @@ class SyncVectorEnv(VectorEnv):
        ), data_list

    def step_async(self, actions):
        """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
        self._actions = iterate(self.action_space, actions)

    def step_wait(self):
        """Steps through each of the environments returning the batched results.

        Returns:
            The batched environment step results
        """
        observations, infos = [], []
        for i, (env, action) in enumerate(zip(self.envs, self._actions)):
            observation, self._rewards[i], self._dones[i], info = env.step(action)
@@ -150,7 +166,17 @@ class SyncVectorEnv(VectorEnv):
            infos,
        )

    def call(self, name, *args, **kwargs) -> tuple:
        """Calls the method with name and applies args and kwargs.

        Args:
            name: The method name
            *args: The method args
            **kwargs: The method kwargs

        Returns:
            Tuple of results
        """
        results = []
        for env in self.envs:
            function = getattr(env, name)
@@ -161,7 +187,15 @@ class SyncVectorEnv(VectorEnv):
        return tuple(results)

    def set_attr(self, name: str, values: Union[list, tuple, Any]):
        """Sets an attribute of the sub-environments.

        Args:
            name: The property name to change
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise, a single value is set for all environments.
        """
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
@@ -178,7 +212,7 @@ class SyncVectorEnv(VectorEnv):
"""Close the environments.""" """Close the environments."""
[env.close() for env in self.envs] [env.close() for env in self.envs]
def _check_spaces(self): def _check_spaces(self) -> bool:
for env in self.envs: for env in self.envs:
if not (env.observation_space == self.single_observation_space): if not (env.observation_space == self.single_observation_space):
raise RuntimeError( raise RuntimeError(
@@ -194,5 +228,4 @@ class SyncVectorEnv(VectorEnv):
"action spaces from all environments must be equal." "action spaces from all environments must be equal."
) )
else: return True
return True


@@ -1,3 +1,4 @@
"""Module for gym vector utils."""
from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
from gym.vector.utils.shared_memory import (


@@ -1,3 +1,4 @@
"""Miscellaneous utilities."""
import contextlib
import os
@@ -5,28 +6,35 @@ __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
class CloudpickleWrapper:
    """Wrapper that uses cloudpickle to pickle and unpickle the result."""

    def __init__(self, fn: callable):
        """Cloudpickle wrapper for a function."""
        self.fn = fn

    def __getstate__(self):
        """Get the state using `cloudpickle.dumps(self.fn)`."""
        import cloudpickle

        return cloudpickle.dumps(self.fn)

    def __setstate__(self, ob):
        """Sets the state with ``ob``."""
        import pickle

        self.fn = pickle.loads(ob)

    def __call__(self):
        """Calls the function `self.fn` with no arguments."""
        return self.fn()
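A short sketch (not part of the diff) of why `CloudpickleWrapper` exists: it lets lambda environment factories survive a pickle round trip:

    import pickle

    import gym
    from gym.vector.utils import CloudpickleWrapper

    env_fn = CloudpickleWrapper(lambda: gym.make("CartPole-v1"))
    restored = pickle.loads(pickle.dumps(env_fn))  # the lambda survives the round trip
    env = restored()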
@contextlib.contextmanager
def clear_mpi_env_vars():
    """Clears the MPI environment variables.

    `from mpi4py import MPI` will call `MPI_Init` by default.
    If the child process has MPI environment variables, MPI will think that the child process
    is an MPI process just like the parent and do bad things such as hang.

    This context manager is a hacky way to clear those environment variables
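A hedged usage sketch of the context manager (not part of the diff):

    import multiprocessing as mp

    from gym.vector.utils import clear_mpi_env_vars

    with clear_mpi_env_vars():
        # Workers spawned here do not inherit the parent's MPI environment variables.
        worker = mp.Process(target=print, args=("spawned without MPI env vars",))
        worker.start()
    worker.join()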


@@ -1,5 +1,7 @@
"""Numpy utility functions: concatenate space samples and create empty array."""
from collections import OrderedDict
from functools import singledispatch
from typing import Iterable, Union

import numpy as np
@@ -9,36 +11,29 @@ __all__ = ["concatenate", "create_empty_array"]
@singledispatch
def concatenate(
    space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
) -> Union[tuple, dict, np.ndarray]:
    """Concatenate multiple samples from space into a single object.

    Example::

        >>> from gym.spaces import Box
        >>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
        >>> out = np.zeros((2, 3), dtype=np.float32)
        >>> items = [space.sample() for _ in range(2)]
        >>> concatenate(space, items, out)
        array([[0.6348213 , 0.28607962, 0.60760117],
               [0.87383074, 0.192658  , 0.2148103 ]], dtype=float32)

    Args:
        space: Observation space of a single environment in the vectorized environment.
        items: Samples to be concatenated.
        out: The output object. This object is a (possibly nested) numpy array.

    Returns:
        The output object. This object is a (possibly nested) numpy array.
    """
    assert isinstance(items, (list, tuple))
    raise ValueError(
        f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
    )
@@ -76,38 +71,30 @@ def _concatenate_custom(space, items, out):
@singledispatch
def create_empty_array(
    space: Space, n: int = 1, fn: callable = np.zeros
) -> Union[tuple, dict, np.ndarray]:
    """Create an empty (possibly nested) numpy array.

    Example::

        >>> from gym.spaces import Box, Dict
        >>> space = Dict({
        ...     'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
        ...     'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
        >>> create_empty_array(space, n=2, fn=np.zeros)
        OrderedDict([('position', array([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)),
             ('velocity', array([[0., 0.],
               [0., 0.]], dtype=float32))])

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment. If ``None``, creates an empty sample from ``space``.
        fn: Function to apply when creating the empty numpy array. Examples of such functions are ``np.empty`` or ``np.zeros``.

    Returns:
        The output object. This object is a (possibly nested) numpy array.
    """
    raise ValueError(
        f"Space of type `{type(space)}` is not a valid `gym.Space` instance."


@@ -1,44 +1,40 @@
"""Utility functions for vector environments to share memory between processes."""
import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Union

import numpy as np

from gym.error import CustomSpaceError
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple

__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]


@singledispatch
def create_shared_memory(
    space: Space, n: int = 1, ctx=mp
) -> Union[dict, tuple, mp.Array]:
    """Create a shared memory object, to be shared across processes.

    This eventually contains the observations from the vectorized environment.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment (i.e. the number of processes).
        ctx: The multiprocessing context.

    Returns:
        shared_memory for the shared object across processes.
    """
    raise CustomSpaceError(
        "Cannot create a shared memory for space with "
        f"type `{type(space)}`. Shared memory only supports "
        "default Gym spaces (e.g. `Box`, `Tuple`, "
        "`Dict`, etc...), and does not support custom "
        "Gym spaces."
    )
@@ -46,7 +42,7 @@ def create_shared_memory(space, n=1, ctx=mp):
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
    dtype = space.dtype.char
    if dtype in "?":
        dtype = c_bool
@@ -54,7 +50,7 @@ def _create_base_shared_memory(space, n=1, ctx=mp):
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
    return tuple(
        create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
    )
@@ -71,39 +67,32 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
@singledispatch
def read_from_shared_memory(
    space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1
) -> Union[dict, tuple, np.ndarray]:
    """Read the batch of observations from shared memory as a numpy array.

    Notes:
        The numpy array objects returned by `read_from_shared_memory` share the
        memory of `shared_memory`. Any changes to `shared_memory` are forwarded
        to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
            This object is created with `create_shared_memory`.
        n: Number of environments in the vectorized environment (i.e. the number of processes).

    Returns:
        Batch of observations as a (possibly nested) numpy array.
    """
    raise CustomSpaceError(
        "Cannot read from a shared memory for space with "
        f"type `{type(space)}`. Shared memory only supports "
        "default Gym spaces (e.g. `Box`, `Tuple`, "
        "`Dict`, etc...), and does not support custom "
        "Gym spaces."
    )
@@ -111,14 +100,14 @@ def read_from_shared_memory(space, shared_memory, n=1):
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
    return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
        (n,) + space.shape
    )
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
    return tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for (memory, subspace) in zip(shared_memory, space.spaces)
@@ -126,7 +115,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n=1):
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
    return OrderedDict(
        [
            (key, read_from_shared_memory(subspace, shared_memory[key], n=n))
@@ -136,34 +125,26 @@ def _read_dict_from_shared_memory(space, shared_memory, n=1):
@singledispatch
def write_to_shared_memory(
    space: Space,
    index: int,
    value: np.ndarray,
    shared_memory: Union[dict, tuple, mp.Array],
):
    """Write the observation of a single environment into shared memory.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        index: Index of the environment (must be in `[0, num_envs)`).
        value: Observation of the single environment to write to shared memory.
        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
            This object is created with `create_shared_memory`.
    """
    raise CustomSpaceError(
        "Cannot write to a shared memory for space with "
        f"type `{type(space)}`. Shared memory only supports "
        "default Gym spaces (e.g. `Box`, `Tuple`, "
        "`Dict`, etc...), and does not support custom "
        "Gym spaces."
    )
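A hedged round-trip sketch of the three shared-memory helpers (not part of the diff), assuming they are re-exported by `gym.vector.utils`:

    import numpy as np

    from gym.spaces import Box
    from gym.vector.utils import (
        create_shared_memory,
        read_from_shared_memory,
        write_to_shared_memory,
    )

    space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
    shm = create_shared_memory(space, n=2)
    write_to_shared_memory(space, 0, space.sample(), shm)
    write_to_shared_memory(space, 1, space.sample(), shm)
    batch = read_from_shared_memory(space, shm, n=2)  # (2, 3) view over the shared buffer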


@@ -1,6 +1,8 @@
"""Utility functions for gym spaces: batch space and iterator."""
from collections import OrderedDict
from copy import deepcopy
from functools import singledispatch
from typing import Iterator

import numpy as np
@@ -12,32 +14,25 @@ __all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
@singledispatch
def batch_space(space: Space, n: int = 1) -> Space:
    """Create a (batched) space, containing multiple copies of a single space.

    Example::

        >>> from gym.spaces import Box, Dict
        >>> space = Dict({
        ...     'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
        ...     'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
        ... })
        >>> batch_space(space, n=5)
        Dict(position:Box(5, 3), velocity:Box(5, 2))

    Args:
        space: Space (e.g. the observation space) for a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment.

    Returns:
        Space (e.g. the observation space) for a batch of environments in the vectorized environment.
    """
    raise ValueError(
        f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
@@ -126,41 +121,35 @@ def _batch_space_custom(space, n=1):
@singledispatch
def iterate(space: Space, items) -> Iterator:
    """Iterate over the elements of a (batched) space.

    Example::

        >>> from gym.spaces import Box, Dict
        >>> space = Dict({
        ...     'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
        ...     'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
        >>> items = space.sample()
        >>> it = iterate(space, items)
        >>> next(it)
        {'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32),
        'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
        >>> next(it)
        {'position': array([-0.67958736, -0.49076623,  0.38661423], dtype=float32),
        'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
        >>> next(it)
        StopIteration

    Args:
        space: Space to which ``items`` belong to.
        items: Items to be iterated over.

    Returns:
        Iterator over the elements in ``items``.
    """
    raise ValueError(
        f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
    )
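A small sketch (not part of the diff) combining `batch_space` and `iterate`, assuming both are re-exported by `gym.vector.utils`:

    import numpy as np

    from gym.spaces import Box
    from gym.vector.utils import batch_space, iterate

    single_space = Box(low=-1, high=1, shape=(2,), dtype=np.float32)
    batched = batch_space(single_space, n=4)  # Box with shape (4, 2)
    actions = batched.sample()
    per_env_actions = list(iterate(batched, actions))  # one (2,) action per sub-environment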


@@ -1,4 +1,7 @@
"""Base class for vectorized environments."""
from __future__ import annotations

from typing import Any, Optional, Union

import gym
from gym.logger import deprecation
@@ -8,32 +11,28 @@ __all__ = ["VectorEnv"]
class VectorEnv(gym.Env):
    """Base class for vectorized environments. Runs multiple independent copies of the same environment in parallel.

    This is not the same as 1 environment that has multiple subcomponents, but it is many copies of the same base env.

    Each observation returned from vectorized environment is a batch of observations for each parallel environment.
    And :meth:`step` is also expected to receive a batch of actions for each parallel environment.

    Notes:
        All parallel environments should share the identical observation and action spaces.
        In other words, a vector of multiple different environments is not supported.
    """

    def __init__(
        self, num_envs: int, observation_space: gym.Space, action_space: gym.Space
    ):
        """Base class for vectorized environments.

        Args:
            num_envs: Number of environments in the vectorized environment.
            observation_space: Observation space of a single environment.
            action_space: Action space of a single environment.
        """
        self.num_envs = num_envs
        self.is_vector_env = True
        self.observation_space = batch_space(observation_space, n=num_envs)
@@ -49,141 +48,134 @@ class VectorEnv(gym.Env):
    def reset_async(
        self,
        seed: Optional[Union[int, list[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Reset the sub-environments asynchronously.

        This method will return ``None``. A call to :meth:`reset_async` should be followed
        by a call to :meth:`reset_wait` to retrieve the results.
        """
        pass

    def reset_wait(
        self,
        seed: Optional[Union[int, list[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Retrieves the results of a :meth:`reset_async` call.

        A call to this method must always be preceded by a call to :meth:`reset_async`.
        """
        raise NotImplementedError()

    def reset(
        self,
        *,
        seed: Optional[Union[int, list[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Reset all parallel environments and return a batch of initial observations.

        Args:
            seed: The environment reset seeds
            return_info: If ``True``, then return info for each sub-environment reset.
            options: Reset options passed to each sub-environment reset.

        Returns:
            A batch of observations from the vectorized environment.
        """
        self.reset_async(seed=seed, return_info=return_info, options=options)
        return self.reset_wait(seed=seed, return_info=return_info, options=options)

    def step_async(self, actions):
        """Asynchronously performs steps in the sub-environments.

        The results can be retrieved via a call to :meth:`step_wait`.
        """
        pass

    def step_wait(self, **kwargs):
        """Retrieves the results of a :meth:`step_async` call.

        A call to this method must always be preceded by a call to :meth:`step_async`.
        """
        raise NotImplementedError()

    def step(self, actions):
        """Take an action for each parallel environment.

        Args:
            actions: Batch of actions, an element of :attr:`action_space`.

        Returns:
            Batch of observations, rewards, dones and infos.
        """
        self.step_async(actions)
        return self.step_wait()

    def call_async(self, name, *args, **kwargs):
        """Calls a method name for each parallel environment asynchronously."""
        pass

    def call_wait(self, **kwargs):
        """After calling a method in :meth:`call_async`, this function collects the results."""
        raise NotImplementedError()

    def call(self, name: str, *args, **kwargs) -> list[Any]:
        """Call a method, or get a property, from each parallel environment.

        Args:
            name (str): Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Returns:
            List of the results of the individual calls to the method or property for each environment.
        """
        self.call_async(name, *args, **kwargs)
        return self.call_wait()

    def get_attr(self, name: str):
        """Get a property from each parallel environment.

        Args:
            name (str): Name of the property to get from each individual environment.

        Returns:
            The property with name
        """
        return self.call(name)

    def set_attr(self, name: str, values: Union[list, tuple, object]):
        """Set a property in each sub-environment.

        Args:
            name (str): Name of the property to be set in each individual environment.
            values (list, tuple, or object): Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual environment, otherwise a single value
                is set for all environments.
        """
        raise NotImplementedError()

    def close_extras(self, **kwargs):
        """Clean up the extra resources e.g. beyond what's in this base class."""
        pass

    def close(self, **kwargs):
        """Close all parallel environments and release resources.

        It also closes all the existing image viewers, then calls :meth:`close_extras` and set
        :attr:`closed` as ``True``.

        Warnings:
            This function itself does not close the environments, it should be handled
            in :meth:`close_extras`. This is generic for both synchronous and asynchronous
            vectorized environments.

        Notes:
            This will be automatically called when garbage collected or program exited.
        """
@@ -197,14 +189,12 @@ class VectorEnv(gym.Env):
    def seed(self, seed=None):
        """Set the random seed in all parallel environments.

        Args:
            seed: Random seed for each parallel environment. If ``seed`` is a list of
                length ``num_envs``, then the items of the list are chosen as random
                seeds. If ``seed`` is an int, then each parallel environment uses the random
                seed ``seed + n``, where ``n`` is the index of the parallel environment
                (between ``0`` and ``num_envs - 1``).
        """
        deprecation(
            "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
@@ -212,10 +202,12 @@ class VectorEnv(gym.Env):
        )

    def __del__(self):
        """Closes the vector environment."""
        if not getattr(self, "closed", True):
            self.close()

    def __repr__(self):
        """Returns a string representation of the vector environment using the class name, number of environments and environment spec id."""
        if self.spec is None:
            return f"{self.__class__.__name__}({self.num_envs})"
        else:
@@ -223,19 +215,17 @@ class VectorEnv(gym.Env):
class VectorEnvWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    Notes:
        Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
    """

    def __init__(self, env: VectorEnv):
        assert isinstance(env, VectorEnv)
        self.env = env
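A minimal subclass sketch (not part of the diff); `ClipReward` is a hypothetical wrapper and assumes `VectorEnvWrapper` forwards `step_async`/`step_wait` to the wrapped environment, as the base class here does:

    import gym
    import numpy as np
    from gym.vector.vector_env import VectorEnvWrapper


    class ClipReward(VectorEnvWrapper):
        """Hypothetical wrapper that clips the batched rewards to [-1, 1]."""

        def step_wait(self):
            obs, rewards, dones, infos = self.env.step_wait()
            return obs, np.clip(rewards, -1.0, 1.0), dones, infos


    envs = ClipReward(gym.vector.make("CartPole-v1", num_envs=2))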


@@ -17,7 +17,7 @@ extras = {
"classic_control": ["pygame==2.1.0"], "classic_control": ["pygame==2.1.0"],
"mujoco": ["mujoco_py>=1.50, <2.0"], "mujoco": ["mujoco_py>=1.50, <2.0"],
"toy_text": ["pygame==2.1.0", "scipy>=1.4.1"], "toy_text": ["pygame==2.1.0", "scipy>=1.4.1"],
"other": ["lz4>=3.1.0", "opencv-python>=3.0"], "other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0"],
} }
# Meta dependency groups. # Meta dependency groups.