mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Pydocstyle utils vector docstring (#2788)
* Added pydocstyle to pre-commit * Added docstrings for tests and updated the tests for autoreset * Add pydocstyle exclude folder to allow slowly adding new docstrings * Add docstrings for setup.py and gym/__init__.py, core.py, error.py and logger.py * Check that all unwrapped environment are of a particular wrapper type * Reverted back to import gym.spaces.Space to gym.spaces * Fixed the __init__.py docstring * Fixed autoreset autoreset test * Updated gym __init__.py top docstring * Fix examples in docstrings * Add docstrings and type hints where known to all functions and classes in gym/utils and gym/vector * Remove unnecessary import * Removed "unused error" and make APIerror deprecated at gym 1.0 * Add pydocstyle description to CONTRIBUTING.md * Added docstrings section to CONTRIBUTING.md * Added :meth: and :attr: keywords to docstrings * Added :meth: and :attr: keywords to docstrings * Imported annotations from __future__ to fix python 3.7 * Add __future__ import annotations for python 3.7 * isort * Remove utils and vectors for this PR and spaces for previous PR * Update gym/envs/classic_control/acrobot.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/envs/classic_control/acrobot.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/envs/classic_control/acrobot.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/spaces/dict.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/ezpickle.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update 
gym/utils/ezpickle.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/play.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Pre-commit * Updated docstrings with :meth: * Updated docstrings with :meth: * Update gym/utils/play.py * Update gym/utils/play.py * Update gym/utils/play.py * Apply suggestions from code review Co-authored-by: Markus Krimmel <montcyril@gmail.com> * pre-commit * Update gym/utils/play.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Updated fps and zoom parameter docstring * Update play docstring * Apply suggestions from code review Added suggested corrections from @markus28 Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Pre-commit magic * Update the `gym.make` docstring with a warning for `env_checker` * Updated and fixed vector docstrings * Update test names for reflect the project filename style Co-authored-by: Markus Krimmel <montcyril@gmail.com>
This commit is contained in:
@@ -30,7 +30,7 @@ repos:
|
||||
rev: 6.1.1 # pick a git hash / tag to point to
|
||||
hooks:
|
||||
- id: pydocstyle
|
||||
exclude: ^(gym/version.py)|(gym/(envs|utils|vector)/)|(tests/)
|
||||
exclude: ^(gym/version.py)|(gym/envs/)|(tests/)
|
||||
args:
|
||||
- --source
|
||||
- --explain
|
||||
|
@@ -398,27 +398,26 @@ def rk4(derivs, y0, t):
|
||||
yourself stranded on a system w/o scipy. Otherwise use
|
||||
:func:`scipy.integrate`.
|
||||
|
||||
Example:
|
||||
|
||||
>>> ### 2D system
|
||||
>>> def derivs(x):
|
||||
... d1 = x[0] + 2*x[1]
|
||||
... d2 = -3*x[0] + 4*x[1]
|
||||
... return (d1, d2)
|
||||
>>> dt = 0.0005
|
||||
>>> t = arange(0.0, 2.0, dt)
|
||||
>>> y0 = (1,2)
|
||||
>>> yout = rk4(derivs, y0, t)
|
||||
|
||||
If you have access to scipy, you should probably be using the
|
||||
:func:`scipy.integrate` tools rather than this function.
|
||||
This would then require re-adding the time variable to the signature of derivs.
|
||||
|
||||
Args:
|
||||
derivs: the derivative of the system and has the signature ``dy = derivs(yi)``
|
||||
y0: initial state vector
|
||||
t: sample times
|
||||
args: additional arguments passed to the derivative function
|
||||
kwargs: additional keyword arguments passed to the derivative function
|
||||
|
||||
Example 1 ::
|
||||
### 2D system
|
||||
def derivs(x):
|
||||
d1 = x[0] + 2*x[1]
|
||||
d2 = -3*x[0] + 4*x[1]
|
||||
return (d1, d2)
|
||||
dt = 0.0005
|
||||
t = arange(0.0, 2.0, dt)
|
||||
y0 = (1,2)
|
||||
yout = rk4(derivs6, y0, t)
|
||||
|
||||
If you have access to scipy, you should probably be using the
|
||||
scipy.integrate tools rather than this function.
|
||||
This would then require re-adding the time variable to the signature of derivs.
|
||||
|
||||
Returns:
|
||||
yout: Runge-Kutta approximation of the ODE
|
||||
|
@@ -499,10 +499,16 @@ def make(
|
||||
"""
|
||||
Create an environment according to the given ID.
|
||||
|
||||
Warnings:
|
||||
In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
|
||||
This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to validate
|
||||
if they follow the gym API. To disable this feature, set parameter `disable_env_checker=True`.
|
||||
|
||||
Args:
|
||||
id: Name of the environment.
|
||||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
||||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
||||
disable_env_checker: Whether to disable the environment checker
|
||||
kwargs: Additional arguments to pass to the environment constructor.
|
||||
Returns:
|
||||
An instance of the environment.
|
||||
|
@@ -17,25 +17,27 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
|
||||
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
||||
|
||||
Example usage::
|
||||
Example usage:
|
||||
|
||||
>>> observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
|
||||
>>> from gym.spaces import Dict, Discrete
|
||||
>>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||
>>> observation_space.sample()
|
||||
OrderedDict([('position', 1), ('velocity', 2)])
|
||||
|
||||
Example usage [nested]::
|
||||
|
||||
>>> spaces.Dict(
|
||||
>>> from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete
|
||||
>>> Dict(
|
||||
... {
|
||||
... "ext_controller": spaces.MultiDiscrete((5, 2, 2)),
|
||||
... "inner_state": spaces.Dict(
|
||||
... "ext_controller": MultiDiscrete([5, 2, 2]),
|
||||
... "inner_state": Dict(
|
||||
... {
|
||||
... "charge": spaces.Discrete(100),
|
||||
... "system_checks": spaces.MultiBinary(10),
|
||||
... "job_status": spaces.Dict(
|
||||
... "charge": Discrete(100),
|
||||
... "system_checks": MultiBinary(10),
|
||||
... "job_status": Dict(
|
||||
... {
|
||||
... "task": spaces.Discrete(5),
|
||||
... "progress": spaces.Box(low=0, high=100, shape=()),
|
||||
... "task": Discrete(5),
|
||||
... "progress": Box(low=0, high=100, shape=()),
|
||||
... }
|
||||
... ),
|
||||
... }
|
||||
@@ -63,9 +65,10 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
|
||||
Example::
|
||||
|
||||
>>> spaces.Dict({"position": spaces.Box(-1, 1, shape=(2,)), "color": spaces.Discrete(3)})
|
||||
>>> from gym.spaces import Box, Discrete
|
||||
>>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)})
|
||||
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
||||
>>> spaces.Dict(position=spaces.Box(-1, 1, shape=(2,)), color=spaces.Discrete(3))
|
||||
>>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3))
|
||||
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
||||
|
||||
Args:
|
||||
|
@@ -16,11 +16,11 @@ class MultiBinary(Space[np.ndarray]):
|
||||
|
||||
Example Usage::
|
||||
|
||||
>>> self.observation_space = spaces.MultiBinary(5)
|
||||
>>> self.observation_space.sample()
|
||||
>>> observation_space = MultiBinary(5)
|
||||
>>> observation_space.sample()
|
||||
array([0, 1, 0, 1, 0], dtype=int8)
|
||||
>>> self.observation_space = spaces.MultiBinary([3, 2])
|
||||
>>> self.observation_space.sample()
|
||||
>>> observation_space = MultiBinary([3, 2])
|
||||
>>> observation_space.sample()
|
||||
array([[0, 0],
|
||||
[0, 1],
|
||||
[1, 1]], dtype=int8)
|
||||
|
@@ -16,8 +16,9 @@ class Tuple(Space[tuple], Sequence):
|
||||
|
||||
Example usage::
|
||||
|
||||
>> observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Box(-1, 1, shape=(2,))))
|
||||
>> observation_space.sample()
|
||||
>>> from gym.spaces import Box, Discrete
|
||||
>>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))))
|
||||
>>> observation_space.sample()
|
||||
(0, array([0.03633198, 0.42370757], dtype=float32))
|
||||
"""
|
||||
|
||||
|
@@ -25,8 +25,9 @@ def flatdim(space: Space) -> int:
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> s = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
|
||||
>>> spaces.flatdim(s)
|
||||
>>> from gym.spaces import Discrete
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||
>>> flatdim(space)
|
||||
5
|
||||
"""
|
||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||
@@ -195,8 +196,7 @@ def flatten_space(space: Space) -> Box:
|
||||
|
||||
Example that recursively flattens a dict::
|
||||
|
||||
>>> space = Dict({"position": Discrete(2),
|
||||
... "velocity": Box(0, 1, shape=(2, 2))})
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
|
||||
>>> flatten_space(space)
|
||||
Box(6,)
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
"""A set of common utilities used within the environments. These are
|
||||
not intended as API functions, and will not remain stable over time.
|
||||
"""A set of common utilities used within the environments.
|
||||
|
||||
These are not intended as API functions, and will not remain stable over time.
|
||||
"""
|
||||
|
||||
color2num = dict(
|
||||
@@ -15,12 +16,20 @@ color2num = dict(
|
||||
)
|
||||
|
||||
|
||||
def colorize(string, color, bold=False, highlight=False):
|
||||
"""Return string surrounded by appropriate terminal color codes to
|
||||
print colorized text. Valid colors: gray, red, green, yellow,
|
||||
blue, magenta, cyan, white, crimson
|
||||
"""
|
||||
def colorize(
|
||||
string: str, color: str, bold: bool = False, highlight: bool = False
|
||||
) -> str:
|
||||
"""Returns string surrounded by appropriate terminal colour codes to print colourised text.
|
||||
|
||||
Args:
|
||||
string: The message to colourise
|
||||
color: Literal values are gray, red, green, yellow, blue, magenta, cyan, white, crimson
|
||||
bold: If to bold the string
|
||||
highlight: If to highlight the string
|
||||
|
||||
Returns:
|
||||
Colourised string
|
||||
"""
|
||||
attr = []
|
||||
num = color2num[color]
|
||||
if highlight:
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""
|
||||
"""A set of functions for checking an environment's details.
|
||||
|
||||
This file is originally from the Stable Baselines3 repository hosted on GitHub
|
||||
(https://github.com/DLR-RM/stable-baselines3/)
|
||||
Original Author: Antonin Raffin
|
||||
@@ -16,21 +17,33 @@ from typing import Optional, Union
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym import logger, spaces
|
||||
from gym import logger
|
||||
from gym.spaces import Box, Dict, Discrete, Space, Tuple
|
||||
|
||||
|
||||
def _is_numpy_array_space(space: spaces.Space) -> bool:
|
||||
def _is_numpy_array_space(space: Space) -> bool:
|
||||
"""Checks if a space can be represented as a single numpy array (e.g. Dict and Tuple spaces return False).
|
||||
|
||||
Args:
|
||||
space: The space to check
|
||||
|
||||
Returns:
|
||||
Returns False if the provided space is not representable as a single numpy array
|
||||
"""
|
||||
Returns False if provided space is not representable as a single numpy array
|
||||
(e.g. Dict and Tuple spaces return False)
|
||||
"""
|
||||
return not isinstance(space, (spaces.Dict, spaces.Tuple))
|
||||
return not isinstance(space, (Dict, Tuple))
|
||||
|
||||
|
||||
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
||||
"""
|
||||
Check that the input adheres to general standards
|
||||
when the observation is apparently an image.
|
||||
def _check_image_input(observation_space: Box, key: str = ""):
|
||||
"""Check whether an observation space of type :class:`Box` adheres to general standards for spaces that represent images.
|
||||
|
||||
It will check that:
|
||||
- The datatype is ``np.uint8``
|
||||
- The lower bound is 0 across all dimensions
|
||||
- The upper bound is 255 across all dimensions
|
||||
|
||||
Args:
|
||||
observation_space: The observation space to check
|
||||
key: The observation shape key for warning
|
||||
"""
|
||||
if observation_space.dtype != np.uint8:
|
||||
logger.warn(
|
||||
@@ -49,8 +62,13 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_nan(env: gym.Env, check_inf: bool = True) -> None:
|
||||
"""Check for NaN and Inf."""
|
||||
def _check_nan(env: gym.Env, check_inf: bool = True):
|
||||
"""Check whether the environment's observations or rewards are NaN or Inf.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
check_inf: Checks if the observation is infinity
|
||||
"""
|
||||
for _ in range(10):
|
||||
action = env.action_space.sample()
|
||||
observation, reward, done, _ = env.step(action)
|
||||
@@ -70,19 +88,22 @@ def _check_nan(env: gym.Env, check_inf: bool = True) -> None:
|
||||
|
||||
def _check_obs(
|
||||
obs: Union[tuple, dict, np.ndarray, int],
|
||||
observation_space: spaces.Space,
|
||||
observation_space: Space,
|
||||
method_name: str,
|
||||
) -> None:
|
||||
):
|
||||
"""Check that the observation returned by the environment corresponds to the declared one.
|
||||
|
||||
Args:
|
||||
obs: The observation to check
|
||||
observation_space: The observation space of the observation
|
||||
method_name: The method name that generated the observation
|
||||
"""
|
||||
Check that the observation returned by the environment
|
||||
correspond to the declared one.
|
||||
"""
|
||||
if not isinstance(observation_space, spaces.Tuple):
|
||||
if not isinstance(observation_space, Tuple):
|
||||
assert not isinstance(
|
||||
obs, tuple
|
||||
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
|
||||
|
||||
if isinstance(observation_space, spaces.Discrete):
|
||||
if isinstance(observation_space, Discrete):
|
||||
assert isinstance(
|
||||
obs, int
|
||||
), f"The observation returned by `{method_name}()` method must be an int"
|
||||
@@ -96,12 +117,16 @@ def _check_obs(
|
||||
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
|
||||
|
||||
|
||||
def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
|
||||
"""
|
||||
Check that the observation space is correctly formatted
|
||||
when dealing with a ``Box()`` space. In particular, it checks:
|
||||
def _check_box_obs(observation_space: Box, key: str = ""):
|
||||
"""Check that the observation space is correctly formatted when dealing with a :class:`Box` space.
|
||||
|
||||
In particular, it checks:
|
||||
- that the dimensions are big enough when it is an image, and that the type matches
|
||||
- that the observation has an expected shape (warn the user if not)
|
||||
|
||||
Args:
|
||||
observation_space: The Box observation space to check
|
||||
key: The observation key
|
||||
"""
|
||||
# If image, check the low and high values, the type and the number of channels
|
||||
# and the shape (minimal value)
|
||||
@@ -137,14 +162,19 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
|
||||
), "Agent's observation_space.high and observation_space have different shapes"
|
||||
|
||||
|
||||
def _check_box_action(action_space: spaces.Box):
|
||||
def _check_box_action(action_space: Box):
|
||||
"""Checks that a :class:`Box` action space is defined in a sensible way.
|
||||
|
||||
Args:
|
||||
action_space: A box action space
|
||||
"""
|
||||
if np.any(np.equal(action_space.low, -np.inf)):
|
||||
logger.warn(
|
||||
"Agent's minimum action space value is -infinity. This is probably too low."
|
||||
)
|
||||
if np.any(np.equal(action_space.high, np.inf)):
|
||||
logger.warn(
|
||||
"Agent's maxmimum action space value is infinity. This is probably too high"
|
||||
"Agent's maximum action space value is infinity. This is probably too high"
|
||||
)
|
||||
if np.any(np.equal(action_space.low, action_space.high)):
|
||||
logger.warn("Agent's maximum and minimum action space values are equal")
|
||||
@@ -156,7 +186,12 @@ def _check_box_action(action_space: spaces.Box):
|
||||
assert False, "Agent's action_space.high and action_space have different shapes"
|
||||
|
||||
|
||||
def _check_normalized_action(action_space: spaces.Box):
|
||||
def _check_normalized_action(action_space: Box):
|
||||
"""Checks that a box action space is normalized.
|
||||
|
||||
Args:
|
||||
action_space: A box action space
|
||||
"""
|
||||
if (
|
||||
np.any(np.abs(action_space.low) != np.abs(action_space.high))
|
||||
or np.any(np.abs(action_space.low) > 1)
|
||||
@@ -168,16 +203,18 @@ def _check_normalized_action(action_space: spaces.Box):
|
||||
)
|
||||
|
||||
|
||||
def _check_returned_values(
|
||||
env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space
|
||||
) -> None:
|
||||
"""
|
||||
Check the returned values by the env when calling `.reset()` or `.step()` methods.
|
||||
def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space):
|
||||
"""Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods.
|
||||
|
||||
Args:
|
||||
env: The environment
|
||||
observation_space: The environment's observation space
|
||||
action_space: The environment's action space
|
||||
"""
|
||||
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
|
||||
obs = env.reset()
|
||||
|
||||
if isinstance(observation_space, spaces.Dict):
|
||||
if isinstance(observation_space, Dict):
|
||||
assert isinstance(
|
||||
obs, dict
|
||||
), "The observation returned by `reset()` must be a dictionary"
|
||||
@@ -200,7 +237,7 @@ def _check_returned_values(
|
||||
# Unpack
|
||||
obs, reward, done, info = data
|
||||
|
||||
if isinstance(observation_space, spaces.Dict):
|
||||
if isinstance(observation_space, Dict):
|
||||
assert isinstance(
|
||||
obs, dict
|
||||
), "The observation returned by `step()` must be a dictionary"
|
||||
@@ -223,10 +260,11 @@ def _check_returned_values(
|
||||
), "The `info` returned by `step()` must be a python dictionary"
|
||||
|
||||
|
||||
def _check_spaces(env: gym.Env) -> None:
|
||||
"""
|
||||
Check that the observation and action spaces are defined
|
||||
and inherit from gym.spaces.Space.
|
||||
def _check_spaces(env: gym.Env):
|
||||
"""Check that the observation and action spaces are defined and inherit from :class:`gym.spaces.Space`.
|
||||
|
||||
Args:
|
||||
env: The environment's observation and action space to check
|
||||
"""
|
||||
# Helper to link to the code, because gym has no proper documentation
|
||||
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
|
||||
@@ -238,25 +276,22 @@ def _check_spaces(env: gym.Env) -> None:
|
||||
"You must specify an action space (cf gym.spaces)" + gym_spaces
|
||||
)
|
||||
|
||||
assert isinstance(env.observation_space, spaces.Space), (
|
||||
assert isinstance(env.observation_space, Space), (
|
||||
"The observation space must inherit from gym.spaces" + gym_spaces
|
||||
)
|
||||
assert isinstance(env.action_space, spaces.Space), (
|
||||
assert isinstance(env.action_space, Space), (
|
||||
"The action space must inherit from gym.spaces" + gym_spaces
|
||||
)
|
||||
|
||||
|
||||
# Check render cannot be covered by CI
|
||||
def _check_render(
|
||||
env: gym.Env, warn: bool = True, headless: bool = False
|
||||
) -> None: # pragma: no cover
|
||||
"""
|
||||
Check the declared render modes/fps and the `render()`/`close()`
|
||||
method of the environment.
|
||||
:param env: The environment to check
|
||||
:param warn: Whether to output additional warnings
|
||||
:param headless: Whether to disable render modes
|
||||
that require a graphical interface. False by default.
|
||||
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False):
|
||||
"""Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
warn: Whether to output additional warnings
|
||||
headless: Whether to disable render modes that require a graphical interface. False by default.
|
||||
"""
|
||||
render_modes = env.metadata.get("render_modes")
|
||||
if render_modes is None:
|
||||
@@ -288,9 +323,12 @@ def _check_render(
|
||||
env.close()
|
||||
|
||||
|
||||
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
|
||||
"""
|
||||
Check that the environment can be reset with a random seed.
|
||||
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None):
|
||||
"""Check that the environment can be reset with a seed.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
seed: The optional seed to use
|
||||
"""
|
||||
signature = inspect.signature(env.reset)
|
||||
assert (
|
||||
@@ -303,7 +341,7 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` "
|
||||
"appear in the signature. This should never happen, please report this issue. "
|
||||
"The error was: " + str(e)
|
||||
f"The error was: {e}"
|
||||
)
|
||||
|
||||
if env.unwrapped.np_random is None:
|
||||
@@ -322,7 +360,12 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_reset_info(env: gym.Env) -> None:
|
||||
def _check_reset_info(env: gym.Env):
|
||||
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
"""
|
||||
signature = inspect.signature(env.reset)
|
||||
assert (
|
||||
"return_info" in signature.parameters or "kwargs" in signature.parameters
|
||||
@@ -334,7 +377,7 @@ def _check_reset_info(env: gym.Env) -> None:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
|
||||
"appear in the signature. This should never happen, please report this issue. "
|
||||
"The error was: " + str(e)
|
||||
f"The error was: {e}"
|
||||
)
|
||||
assert (
|
||||
len(result) == 2
|
||||
@@ -346,9 +389,11 @@ def _check_reset_info(env: gym.Env) -> None:
|
||||
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
|
||||
|
||||
|
||||
def _check_reset_options(env: gym.Env) -> None:
|
||||
"""
|
||||
Check that the environment can be reset with options.
|
||||
def _check_reset_options(env: gym.Env):
|
||||
"""Check that the environment can be reset with options.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
"""
|
||||
signature = inspect.signature(env.reset)
|
||||
assert (
|
||||
@@ -361,22 +406,22 @@ def _check_reset_options(env: gym.Env) -> None:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with options, even though `options` or `kwargs` "
|
||||
"appear in the signature. This should never happen, please report this issue. "
|
||||
"The error was: " + str(e)
|
||||
f"The error was: {e}"
|
||||
)
|
||||
|
||||
|
||||
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
|
||||
"""
|
||||
Check that an environment follows Gym API.
|
||||
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True):
|
||||
"""Check that an environment follows Gym API.
|
||||
|
||||
This is particularly useful when using a custom environment.
|
||||
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
|
||||
for more information about the API.
|
||||
It also optionally checks that the environment is compatible with Stable-Baselines.
|
||||
:param env: The Gym environment that will be checked
|
||||
:param warn: Whether to output additional warnings
|
||||
mainly related to the interaction with Stable Baselines
|
||||
:param skip_render_check: Whether to skip the checks for the render method.
|
||||
True by default (useful for the CI)
|
||||
|
||||
Args:
|
||||
env: The Gym environment that will be checked
|
||||
warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines
|
||||
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
|
||||
"""
|
||||
assert isinstance(
|
||||
env, gym.Env
|
||||
@@ -393,15 +438,15 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
|
||||
if warn:
|
||||
obs_spaces = (
|
||||
observation_space.spaces
|
||||
if isinstance(observation_space, spaces.Dict)
|
||||
if isinstance(observation_space, Dict)
|
||||
else {"": observation_space}
|
||||
)
|
||||
for key, space in obs_spaces.items():
|
||||
if isinstance(space, spaces.Box):
|
||||
if isinstance(space, Box):
|
||||
_check_box_obs(space, key)
|
||||
|
||||
# Check for the action space, it may lead to hard-to-debug issues
|
||||
if isinstance(action_space, spaces.Box):
|
||||
if isinstance(action_space, Box):
|
||||
_check_box_action(action_space)
|
||||
_check_normalized_action(action_space)
|
||||
|
||||
|
@@ -1,33 +1,35 @@
|
||||
"""Class for pickling and unpickling objects via their constructor arguments."""
|
||||
|
||||
|
||||
class EzPickle:
|
||||
"""Objects that are pickled and unpickled via their constructor
|
||||
arguments.
|
||||
"""Objects that are pickled and unpickled via their constructor arguments.
|
||||
|
||||
Example usage:
|
||||
Example::
|
||||
|
||||
class Dog(Animal, EzPickle):
|
||||
def __init__(self, furcolor, tailkind="bushy"):
|
||||
Animal.__init__()
|
||||
EzPickle.__init__(furcolor, tailkind)
|
||||
...
|
||||
>>> class Dog(Animal, EzPickle):
|
||||
... def __init__(self, furcolor, tailkind="bushy"):
|
||||
... Animal.__init__()
|
||||
... EzPickle.__init__(furcolor, tailkind)
|
||||
|
||||
When this object is unpickled, a new Dog will be constructed by passing the provided
|
||||
furcolor and tailkind into the constructor. However, philosophers are still not sure
|
||||
whether it is still the same dog.
|
||||
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
|
||||
However, philosophers are still not sure whether it is still the same dog.
|
||||
|
||||
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo
|
||||
and Atari.
|
||||
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
|
||||
self._ezpickle_args = args
|
||||
self._ezpickle_kwargs = kwargs
|
||||
|
||||
def __getstate__(self):
|
||||
"""Returns the object pickle state with args and kwargs."""
|
||||
return {
|
||||
"_ezpickle_args": self._ezpickle_args,
|
||||
"_ezpickle_kwargs": self._ezpickle_kwargs,
|
||||
}
|
||||
|
||||
def __setstate__(self, d):
|
||||
"""Sets the object pickle state using d."""
|
||||
out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
|
||||
self.__dict__.update(out.__dict__)
|
||||
|
@@ -1,11 +1,18 @@
|
||||
"""Utilities for visualising an environment."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pygame
|
||||
from numpy.typing import NDArray
|
||||
from pygame import Surface
|
||||
from pygame.event import Event
|
||||
from pygame.locals import VIDEORESIZE
|
||||
|
||||
from gym import Env, logger
|
||||
from gym.core import ActType, ObsType
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.logger import deprecation
|
||||
|
||||
try:
|
||||
@@ -13,30 +20,31 @@ try:
|
||||
|
||||
matplotlib.use("TkAgg")
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError as e:
|
||||
logger.warn(f"failed to set matplotlib backend, plotting will not work: {str(e)}")
|
||||
plt = None
|
||||
|
||||
from collections import deque
|
||||
|
||||
from pygame.locals import VIDEORESIZE
|
||||
|
||||
from gym.core import ActType
|
||||
except ImportError:
|
||||
logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
|
||||
matplotlib, plt = None, None
|
||||
|
||||
|
||||
class MissingKeysToAction(Exception):
|
||||
"""Raised when the environment does not have
|
||||
a default keys_to_action mapping
|
||||
"""
|
||||
"""Raised when the environment does not have a default ``keys_to_action`` mapping."""
|
||||
|
||||
|
||||
class PlayableGame:
|
||||
"""Wraps an environment allowing keyboard inputs to interact with the environment."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: Env,
|
||||
keys_to_action: Optional[Dict[Tuple[int], int]] = None,
|
||||
keys_to_action: Optional[dict[tuple[int], int]] = None,
|
||||
zoom: Optional[float] = None,
|
||||
):
|
||||
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
||||
|
||||
Args:
|
||||
env: The environment to play
|
||||
keys_to_action: The dictionary of keyboard tuples and action value
|
||||
zoom: If to zoom in on the environment render
|
||||
"""
|
||||
self.env = env
|
||||
self.relevant_keys = self._get_relevant_keys(keys_to_action)
|
||||
self.video_size = self._get_video_size(zoom)
|
||||
@@ -45,7 +53,7 @@ class PlayableGame:
|
||||
self.running = True
|
||||
|
||||
def _get_relevant_keys(
|
||||
self, keys_to_action: Optional[Dict[Tuple[int], int]] = None
|
||||
self, keys_to_action: Optional[dict[tuple[int], int]] = None
|
||||
) -> set:
|
||||
if keys_to_action is None:
|
||||
if hasattr(self.env, "get_keys_to_action"):
|
||||
@@ -60,7 +68,7 @@ class PlayableGame:
|
||||
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
||||
return relevant_keys
|
||||
|
||||
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
||||
def _get_video_size(self, zoom: Optional[float] = None) -> tuple[int, int]:
|
||||
# TODO: this needs to be updated when the render API change goes through
|
||||
rendered = self.env.render(mode="rgb_array")
|
||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||
@@ -70,7 +78,14 @@ class PlayableGame:
|
||||
|
||||
return video_size
|
||||
|
||||
def process_event(self, event: Event) -> None:
|
||||
def process_event(self, event: Event):
|
||||
"""Processes a PyGame event.
|
||||
|
||||
In particular, this function is used to keep track of which buttons are currently pressed and to exit the :func:`play` function when the PyGame window is closed.
|
||||
|
||||
Args:
|
||||
event: The event to process
|
||||
"""
|
||||
if event.type == pygame.KEYDOWN:
|
||||
if event.key in self.relevant_keys:
|
||||
self.pressed_keys.append(event.key)
|
||||
@@ -87,9 +102,17 @@ class PlayableGame:
|
||||
|
||||
|
||||
def display_arr(
|
||||
screen: Surface, arr: NDArray, video_size: Tuple[int, int], transpose: bool
|
||||
screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
|
||||
):
|
||||
arr_min, arr_max = arr.min(), arr.max()
|
||||
"""Displays a numpy array on screen.
|
||||
|
||||
Args:
|
||||
screen: The screen to show the array on
|
||||
arr: The array to show
|
||||
video_size: The video size of the screen
|
||||
transpose: If to transpose the array on the screen
|
||||
"""
|
||||
arr_min, arr_max = np.min(arr), np.max(arr)
|
||||
arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
|
||||
pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
|
||||
pyg_img = pygame.transform.scale(pyg_img, video_size)
|
||||
@@ -108,60 +131,74 @@ def play(
|
||||
):
|
||||
"""Allows one to play the game using keyboard.
|
||||
|
||||
To simply play the game use:
|
||||
Example::
|
||||
|
||||
play(gym.make("Pong-v4"))
|
||||
>>> import gym
|
||||
>>> from gym.utils.play import play
|
||||
>>> play(gym.make("CarRacing-v1"), keys_to_action={"w": np.array([0, 0.7, 0]),
|
||||
... "a": np.array([-1, 0, 0]),
|
||||
... "s": np.array([0, 0, 1]),
|
||||
... "d": np.array([1, 0, 0]),
|
||||
... "wa": np.array([-1, 0.7, 0]),
|
||||
... "dw": np.array([1, 0.7, 0]),
|
||||
... "ds": np.array([1, 0, 1]),
|
||||
... "as": np.array([-1, 0, 1]),
|
||||
... }, noop=np.array([0,0,0]))
|
||||
|
||||
Above code works also if env is wrapped, so it's particularly useful in
|
||||
|
||||
Above code works also if the environment is wrapped, so it's particularly useful in
|
||||
verifying that the frame-level preprocessing does not render the game
|
||||
unplayable.
|
||||
|
||||
If you wish to plot real time statistics as you play, you can use
|
||||
gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
|
||||
for last 5 second of gameplay.
|
||||
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
|
||||
for the last 150 steps.
|
||||
|
||||
def callback(obs_t, obs_tp1, action, rew, done, info):
|
||||
return [rew,]
|
||||
plotter = PlayPlot(callback, 30 * 5, ["reward"])
|
||||
|
||||
env = gym.make("Pong-v4")
|
||||
play(env, callback=plotter.callback)
|
||||
>>> def callback(obs_t, obs_tp1, action, rew, done, info):
|
||||
... return [rew,]
|
||||
>>> plotter = PlayPlot(callback, 150, ["reward"])
|
||||
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
|
||||
|
||||
|
||||
Arguments
|
||||
---------
|
||||
env: gym.Env
|
||||
Environment to use for playing.
|
||||
transpose: bool
|
||||
If True the output of observation is transposed.
|
||||
Defaults to true.
|
||||
fps: int
|
||||
Maximum number of steps of the environment to execute every second.
|
||||
Defaults to 30.
|
||||
zoom: float
|
||||
Make screen edge this many times bigger
|
||||
callback: lambda or None
|
||||
Callback if a callback is provided it will be executed after
|
||||
every step. It takes the following input:
|
||||
obs_t: observation before performing action
|
||||
obs_tp1: observation after performing action
|
||||
action: action that was executed
|
||||
rew: reward that was received
|
||||
done: whether the environment is done or not
|
||||
info: debug info
|
||||
keys_to_action: dict: tuple(int) -> int or None
|
||||
Mapping from keys pressed to action performed.
|
||||
For example if pressed 'w' and space at the same time is supposed
|
||||
to trigger action number 2 then key_to_action dict would look like this:
|
||||
|
||||
{
|
||||
# ...
|
||||
sorted(ord('w'), ord(' ')) -> 2
|
||||
# ...
|
||||
}
|
||||
If None, default key_to_action mapping for that env is used, if provided.
|
||||
seed: bool or None
|
||||
Random seed used when resetting the environment. If None, no seed is used.
|
||||
Args:
|
||||
env: Environment to use for playing.
|
||||
transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
|
||||
fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
|
||||
``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
|
||||
zoom: Zoom the observation in, ``zoom`` amount, should be positive float
|
||||
callback: If a callback is provided, it will be executed after every step. It takes the following input:
|
||||
obs_t: observation before performing action
|
||||
obs_tp1: observation after performing action
|
||||
action: action that was executed
|
||||
rew: reward that was received
|
||||
done: whether the environment is done or not
|
||||
info: debug info
|
||||
keys_to_action: Mapping from keys pressed to action performed.
|
||||
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
|
||||
points of the keys, as a tuple of characters, or as a string where each character of the string represents
|
||||
one key.
|
||||
For example if pressing 'w' and space at the same time is supposed
|
||||
to trigger action number 2, then the ``keys_to_action`` dict could look like this:
|
||||
>>> {
|
||||
... # ...
|
||||
... (ord('w'), ord(' ')): 2
|
||||
... # ...
|
||||
... }
|
||||
or like this:
|
||||
>>> {
|
||||
... # ...
|
||||
... ("w", " "): 2
|
||||
... # ...
|
||||
... }
|
||||
or like this:
|
||||
>>> {
|
||||
... # ...
|
||||
... "w ": 2
|
||||
... # ...
|
||||
... }
|
||||
If ``None``, the default ``keys_to_action`` mapping for that environment is used, if provided.
|
||||
seed: Random seed used when resetting the environment. If None, no seed is used.
|
||||
noop: The action used when no key input has been entered, or the entered key combination is unknown.
|
||||
"""
|
||||
env.reset(seed=seed)
|
||||
|
||||
@@ -208,7 +245,44 @@ def play(
|
||||
|
||||
|
||||
class PlayPlot:
|
||||
def __init__(self, callback, horizon_timesteps, plot_names):
|
||||
"""Provides a callback to create live plots of arbitrary metrics when using :func:`play`.
|
||||
|
||||
This class is instantiated with a function that accepts information about a single environment transition:
|
||||
- obs_t: observation before performing action
|
||||
- obs_tp1: observation after performing action
|
||||
- action: action that was executed
|
||||
- rew: reward that was received
|
||||
- done: whether the environment is done or not
|
||||
- info: debug info
|
||||
|
||||
It should return a list of metrics that are computed from this data.
|
||||
For instance, the function may look like this::
|
||||
|
||||
def compute_metrics(obs_t, obs_tp, action, reward, done, info):
|
||||
return [reward, info["cumulative_reward"], np.linalg.norm(action)]
|
||||
|
||||
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
|
||||
and uses the returned values to update live plots of the metrics.
|
||||
|
||||
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
|
||||
|
||||
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
|
||||
>>> play(your_env, callback=plotter.callback)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, callback: callable, horizon_timesteps: int, plot_names: list[str]
|
||||
):
|
||||
"""Constructor of :class:`PlayPlot`.
|
||||
|
||||
The function ``callback`` that is passed to this constructor should return
|
||||
a list of metrics that is of length ``len(plot_names)``.
|
||||
|
||||
Args:
|
||||
callback: Function that computes metrics from environment transitions
|
||||
horizon_timesteps: The time horizon used for the live plots
|
||||
plot_names: List of plot titles
|
||||
"""
|
||||
deprecation(
|
||||
"`PlayPlot` is marked as deprecated and will be removed in the near future."
|
||||
)
|
||||
@@ -216,7 +290,10 @@ class PlayPlot:
|
||||
self.horizon_timesteps = horizon_timesteps
|
||||
self.plot_names = plot_names
|
||||
|
||||
assert plt is not None, "matplotlib backend failed, plotting will not work"
|
||||
if plt is None:
|
||||
raise DependencyNotInstalled(
|
||||
"matplotlib is not installed, run `pip install gym[other]`"
|
||||
)
|
||||
|
||||
num_plots = len(self.plot_names)
|
||||
self.fig, self.ax = plt.subplots(num_plots)
|
||||
@@ -228,7 +305,25 @@ class PlayPlot:
|
||||
self.cur_plot = [None for _ in range(num_plots)]
|
||||
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
||||
|
||||
def callback(self, obs_t, obs_tp1, action, rew, done, info):
|
||||
def callback(
|
||||
self,
|
||||
obs_t: ObsType,
|
||||
obs_tp1: ObsType,
|
||||
action: ActType,
|
||||
rew: float,
|
||||
done: bool,
|
||||
info: dict,
|
||||
):
|
||||
"""The callback that calls the provided data callback and adds the data to the plots.
|
||||
|
||||
Args:
|
||||
obs_t: The observation at time step t
|
||||
obs_tp1: The observation at time step t+1
|
||||
action: The action
|
||||
rew: The reward
|
||||
done: If the environment is done
|
||||
info: The information from the environment
|
||||
"""
|
||||
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
|
||||
for point, data_series in zip(points, self.data):
|
||||
data_series.append(point)
|
||||
|
@@ -1,7 +1,10 @@
|
||||
"""Set of random number generator functions: seeding, generator, hashing seeds."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import struct
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -9,7 +12,15 @@ from gym import error
|
||||
from gym.logger import deprecation
|
||||
|
||||
|
||||
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
|
||||
def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]:
|
||||
"""Generates a random number generator from the seed and returns the Generator and seed.
|
||||
|
||||
Args:
|
||||
seed: The seed used to create the generator
|
||||
|
||||
Returns:
|
||||
The generator and resulting seed
|
||||
"""
|
||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
|
||||
|
||||
@@ -22,7 +33,10 @@ def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]
|
||||
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
|
||||
# RandomNumberGenerator = np.random.Generator
|
||||
class RandomNumberGenerator(np.random.Generator):
|
||||
"""Random number generator class that inherits from numpy's random Generator class."""
|
||||
|
||||
def rand(self, *size):
|
||||
"""Deprecated rand function using random."""
|
||||
deprecation(
|
||||
"Function `rng.rand(*size)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
@@ -34,6 +48,7 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
random_sample = rand
|
||||
|
||||
def randn(self, *size):
|
||||
"""Deprecated random standard normal function use standard_normal."""
|
||||
deprecation(
|
||||
"Function `rng.randn(*size)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
@@ -43,6 +58,7 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
return self.standard_normal(size)
|
||||
|
||||
def randint(self, low, high=None, size=None, dtype=int):
|
||||
"""Deprecated random integer function use integers."""
|
||||
deprecation(
|
||||
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
@@ -54,6 +70,7 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
random_integers = randint
|
||||
|
||||
def get_state(self):
|
||||
"""Deprecated get rng state use bit_generator.state."""
|
||||
deprecation(
|
||||
"Function `rng.get_state()` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
@@ -63,6 +80,7 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
return self.bit_generator.state
|
||||
|
||||
def set_state(self, state):
|
||||
"""Deprecated set rng state function use bit_generator.state = state."""
|
||||
deprecation(
|
||||
"Function `rng.set_state(state)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
@@ -72,6 +90,7 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
self.bit_generator.state = state
|
||||
|
||||
def seed(self, seed=None):
|
||||
"""Deprecated seed function use gym.utils.seeding.np_random(seed)."""
|
||||
deprecation(
|
||||
"Function `rng.seed(seed)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
@@ -88,6 +107,7 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
seed.__doc__ = np.random.seed.__doc__
|
||||
|
||||
def __reduce__(self):
|
||||
"""Reduces the Random Number Generator to a RandomNumberGenerator, init_args and additional args."""
|
||||
# np.random.Generator defines __reduce__, but it's hard-coded to
|
||||
# return a Generator instead of its subclass RandomNumberGenerator.
|
||||
# We need to override it here, otherwise sampling from a Space will
|
||||
@@ -119,20 +139,21 @@ RNG = RandomNumberGenerator
|
||||
|
||||
|
||||
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
|
||||
"""Any given evaluation is likely to have many PRNG's active at
|
||||
once. (Most commonly, because the environment is running in
|
||||
multiple processes.) There's literature indicating that having
|
||||
linear correlations between seeds of multiple PRNG's can correlate
|
||||
the outputs:
|
||||
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
|
||||
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
|
||||
http://dl.acm.org/citation.cfm?id=1276928
|
||||
Thus, for sanity we hash the seeds before using them. (This scheme
|
||||
is likely not crypto-strength, but it should be good enough to get
|
||||
rid of simple correlations.)
|
||||
"""Any given evaluation is likely to have many PRNG's active at once.
|
||||
|
||||
(Most commonly, because the environment is running in multiple processes.)
|
||||
There's literature indicating that having linear correlations between seeds of multiple PRNG's can correlate the outputs:
|
||||
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
|
||||
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
|
||||
http://dl.acm.org/citation.cfm?id=1276928
|
||||
Thus, for sanity we hash the seeds before using them. (This scheme is likely not crypto-strength, but it should be good enough to get rid of simple correlations.)
|
||||
|
||||
Args:
|
||||
seed: None seeds from an operating system specific randomness source.
|
||||
max_bytes: Maximum number of bytes to use in the hashed seed.
|
||||
|
||||
Returns:
|
||||
The hashed seed
|
||||
"""
|
||||
deprecation(
|
||||
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
|
||||
@@ -144,12 +165,16 @@ def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
|
||||
|
||||
|
||||
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
|
||||
"""Create a strong random seed. Otherwise, Python 2 would seed using
|
||||
the system time, which might be non-robust especially in the
|
||||
presence of concurrency.
|
||||
"""Create a strong random seed.
|
||||
|
||||
Otherwise, Python 2 would seed using the system time, which might be non-robust especially in the presence of concurrency.
|
||||
|
||||
Args:
|
||||
a: None seeds from an operating system specific randomness source.
|
||||
max_bytes: Maximum number of bytes to use in the seed.
|
||||
|
||||
Returns:
|
||||
A seed
|
||||
"""
|
||||
deprecation(
|
||||
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
|
||||
@@ -185,7 +210,7 @@ def _bigint_from_bytes(bt: bytes) -> int:
|
||||
return accum
|
||||
|
||||
|
||||
def _int_list_from_bigint(bigint: int) -> List[int]:
|
||||
def _int_list_from_bigint(bigint: int) -> list[int]:
|
||||
deprecation(
|
||||
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
|
||||
)
|
||||
@@ -195,7 +220,7 @@ def _int_list_from_bigint(bigint: int) -> List[int]:
|
||||
elif bigint == 0:
|
||||
return [0]
|
||||
|
||||
ints: List[int] = []
|
||||
ints: list[int] = []
|
||||
while bigint > 0:
|
||||
bigint, mod = divmod(bigint, 2**32)
|
||||
ints.append(mod)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
try:
|
||||
from collections.abc import Iterable
|
||||
except ImportError:
|
||||
Iterable = (tuple, list)
|
||||
"""Module for vector environments."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
from gym.vector.async_vector_env import AsyncVectorEnv
|
||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||
@@ -10,40 +10,34 @@ from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
|
||||
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
|
||||
|
||||
|
||||
def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
|
||||
"""Create a vectorized environment from multiple copies of an environment,
|
||||
from its id.
|
||||
def make(
|
||||
id: str,
|
||||
num_envs: int = 1,
|
||||
asynchronous: bool = True,
|
||||
wrappers: Optional[Union[callable, list[callable]]] = None,
|
||||
**kwargs,
|
||||
) -> VectorEnv:
|
||||
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
id : str
|
||||
The environment ID. This must be a valid ID from the registry.
|
||||
Example::
|
||||
|
||||
num_envs : int
|
||||
Number of copies of the environment.
|
||||
>>> import gym
|
||||
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
|
||||
>>> env.reset()
|
||||
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
|
||||
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
|
||||
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
|
||||
dtype=float32)
|
||||
|
||||
asynchronous : bool
|
||||
If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses
|
||||
`multiprocessing`_ to run the environments in parallel). If ``False``,
|
||||
wraps the environments in a :class:`SyncVectorEnv`.
|
||||
Args:
|
||||
id: The environment ID. This must be a valid ID from the registry.
|
||||
num_envs: Number of copies of the environment.
|
||||
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
|
||||
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
|
||||
**kwargs: Keyword arguments applied during gym.make
|
||||
|
||||
wrappers : callable, or iterable of callables, optional
|
||||
If not ``None``, then apply the wrappers to each internal
|
||||
environment during creation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`gym.vector.VectorEnv`
|
||||
Returns:
|
||||
The vectorized environment.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
|
||||
>>> env.reset()
|
||||
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
|
||||
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
|
||||
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
|
||||
dtype=float32)
|
||||
"""
|
||||
from gym.envs import make as make_
|
||||
|
||||
|
@@ -1,13 +1,18 @@
|
||||
"""An async vector environment."""
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym import logger
|
||||
from gym.core import ObsType
|
||||
from gym.error import (
|
||||
AlreadyPendingCallError,
|
||||
ClosedEnvironmentError,
|
||||
@@ -37,69 +42,13 @@ class AsyncState(Enum):
|
||||
|
||||
|
||||
class AsyncVectorEnv(VectorEnv):
|
||||
"""Vectorized environment that runs multiple environments in parallel. It
|
||||
uses `multiprocessing`_ processes, and pipes for communication.
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_fns : iterable of callable
|
||||
Functions that create the environments.
|
||||
It uses ``multiprocessing`` processes, and pipes for communication.
|
||||
|
||||
observation_space : :class:`gym.spaces.Space`, optional
|
||||
Observation space of a single environment. If ``None``, then the
|
||||
observation space of the first environment is taken.
|
||||
|
||||
action_space : :class:`gym.spaces.Space`, optional
|
||||
Action space of a single environment. If ``None``, then the action space
|
||||
of the first environment is taken.
|
||||
|
||||
shared_memory : bool
|
||||
If ``True``, then the observations from the worker processes are
|
||||
communicated back through shared variables. This can improve the
|
||||
efficiency if the observations are large (e.g. images).
|
||||
|
||||
copy : bool
|
||||
If ``True``, then the :meth:`~AsyncVectorEnv.reset` and
|
||||
:meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
|
||||
|
||||
context : str, optional
|
||||
Context for `multiprocessing`_. If ``None``, then the default context is used.
|
||||
|
||||
daemon : bool
|
||||
If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they
|
||||
will quit if the head process quits. However, ``daemon=True`` prevents
|
||||
subprocesses to spawn children, so for some environments you may want
|
||||
to have it set to ``False``.
|
||||
|
||||
worker : callable, optional
|
||||
If set, then use that worker in a subprocess instead of a default one.
|
||||
Can be useful to override some inner vector env logic, for instance,
|
||||
how resets on done are handled.
|
||||
|
||||
Warning
|
||||
-------
|
||||
:attr:`worker` is an advanced mode option. It provides a high degree of
|
||||
flexibility and a high chance to shoot yourself in the foot; thus,
|
||||
if you are writing your own worker, it is recommended to start from the code
|
||||
for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If the observation space of some sub-environment does not match
|
||||
:obj:`observation_space` (or, by default, the observation space of
|
||||
the first sub-environment).
|
||||
|
||||
ValueError
|
||||
If :obj:`observation_space` is a custom space (i.e. not a default
|
||||
space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`,
|
||||
or :class:`~gym.spaces.Dict`) and :obj:`shared_memory` is ``True``.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
.. code-block::
|
||||
Example::
|
||||
|
||||
>>> import gym
|
||||
>>> env = gym.vector.AsyncVectorEnv([
|
||||
... lambda: gym.make("Pendulum-v0", g=9.81),
|
||||
... lambda: gym.make("Pendulum-v0", g=1.62)
|
||||
@@ -111,15 +60,33 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns,
|
||||
observation_space=None,
|
||||
action_space=None,
|
||||
shared_memory=True,
|
||||
copy=True,
|
||||
context=None,
|
||||
daemon=True,
|
||||
worker=None,
|
||||
env_fns: Sequence[callable],
|
||||
observation_space: Optional[gym.Space] = None,
|
||||
action_space: Optional[gym.Space] = None,
|
||||
shared_memory: bool = True,
|
||||
copy: bool = True,
|
||||
context: Optional[str] = None,
|
||||
daemon: bool = True,
|
||||
worker: Optional[callable] = None,
|
||||
):
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
Args:
|
||||
env_fns: Functions that create the environments.
|
||||
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
|
||||
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
|
||||
shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images).
|
||||
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
|
||||
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
|
||||
daemon: If ``True``, then subprocesses have the ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children, so for some environments you may want to have it set to ``False``.
|
||||
worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
|
||||
|
||||
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
|
||||
ValueError: If observation_space is a custom space (i.e. not a default space in Gym, such as gym.spaces.Box, gym.spaces.Discrete, or gym.spaces.Dict) and shared_memory is True.
|
||||
"""
|
||||
ctx = mp.get_context(context)
|
||||
self.env_fns = env_fns
|
||||
self.shared_memory = shared_memory
|
||||
@@ -192,6 +159,11 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._check_spaces()
|
||||
|
||||
def seed(self, seed=None):
|
||||
"""Seeds the vector environments.
|
||||
|
||||
Args:
|
||||
seed: The seeds to use with the environments
|
||||
"""
|
||||
super().seed(seed=seed)
|
||||
self._assert_is_running()
|
||||
if seed is None:
|
||||
@@ -213,22 +185,24 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
seed: Optional[Union[int, list[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Send the calls to :obj:`reset` to each sub-environment.
|
||||
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
||||
|
||||
Raises
|
||||
------
|
||||
ClosedEnvironmentError
|
||||
If the environment was closed (if :meth:`close` was previously called).
|
||||
To get the results of these calls, you may invoke :meth:`reset_wait`.
|
||||
|
||||
AlreadyPendingCallError
|
||||
If the environment is already waiting for a pending call to another
|
||||
method (e.g. :meth:`step_async`). This can be caused by two consecutive
|
||||
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in
|
||||
between.
|
||||
Args:
|
||||
seed: List of seeds for each environment
|
||||
return_info: If to return information
|
||||
options: The reset option
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
|
||||
method (e.g. :meth:`step_async`). This can be caused by two consecutive
|
||||
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
|
||||
@@ -258,37 +232,26 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
def reset_wait(
|
||||
self,
|
||||
timeout=None,
|
||||
timeout: Optional[Union[int, float]] = None,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to `reset_wait` times out. If
|
||||
`None`, the call to `reset_wait` never times out.
|
||||
seed: ignored
|
||||
options: ignored
|
||||
) -> Union[ObsType, tuple[ObsType, list[dict]]]:
|
||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||
|
||||
Returns
|
||||
-------
|
||||
element of :attr:`~VectorEnv.observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
infos : list of dicts containing metadata
|
||||
Args:
|
||||
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
|
||||
seed: ignored
|
||||
return_info: If to return information
|
||||
options: ignored
|
||||
|
||||
Raises
|
||||
------
|
||||
ClosedEnvironmentError
|
||||
If the environment was closed (if :meth:`close` was previously called).
|
||||
Returns:
|
||||
A tuple of batched observations and list of dictionaries
|
||||
|
||||
NoAsyncCallError
|
||||
If :meth:`reset_wait` was called without any prior call to
|
||||
:meth:`reset_async`.
|
||||
|
||||
TimeoutError
|
||||
If :meth:`reset_wait` timed out.
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
|
||||
TimeoutError: If :meth:`reset_wait` timed out.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_RESET:
|
||||
@@ -327,24 +290,18 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
|
||||
def step_async(self, actions):
|
||||
def step_async(self, actions: np.ndarray):
|
||||
"""Send the calls to :obj:`step` to each sub-environment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
actions : element of :attr:`~VectorEnv.action_space`
|
||||
Batch of actions.
|
||||
Args:
|
||||
actions: Batch of actions. element of :attr:`~VectorEnv.action_space`
|
||||
|
||||
Raises
|
||||
------
|
||||
ClosedEnvironmentError
|
||||
If the environment was closed (if :meth:`close` was previously called).
|
||||
|
||||
AlreadyPendingCallError
|
||||
If the environment is already waiting for a pending call to another
|
||||
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
|
||||
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
|
||||
between.
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
|
||||
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
|
||||
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
|
||||
between.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
@@ -358,40 +315,21 @@ class AsyncVectorEnv(VectorEnv):
|
||||
pipe.send(("step", action))
|
||||
self._state = AsyncState.WAITING_STEP
|
||||
|
||||
def step_wait(self, timeout=None):
|
||||
def step_wait(
|
||||
self, timeout: Optional[Union[int, float]] = None
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[dict]]:
|
||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to :meth:`step_wait` times out. If
|
||||
``None``, the call to :meth:`step_wait` never times out.
|
||||
Args:
|
||||
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : element of :attr:`~VectorEnv.observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
Returns:
|
||||
The batched environment step information: a tuple of (observations, rewards, dones, infos)
|
||||
|
||||
rewards : :obj:`np.ndarray`, dtype :obj:`np.float_`
|
||||
A vector of rewards from the vectorized environment.
|
||||
|
||||
dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
|
||||
A vector whose entries indicate whether the episode has ended.
|
||||
|
||||
infos : list of dict
|
||||
A list of auxiliary diagnostic information dicts from sub-environments.
|
||||
|
||||
Raises
|
||||
------
|
||||
ClosedEnvironmentError
|
||||
If the environment was closed (if :meth:`close` was previously called).
|
||||
|
||||
NoAsyncCallError
|
||||
If :meth:`step_wait` was called without any prior call to
|
||||
:meth:`step_async`.
|
||||
|
||||
TimeoutError
|
||||
If :meth:`step_wait` timed out.
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
|
||||
TimeoutError: If :meth:`step_wait` timed out.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_STEP:
|
||||
@@ -425,18 +363,13 @@ class AsyncVectorEnv(VectorEnv):
|
||||
infos,
|
||||
)
|
||||
|
||||
def call_async(self, name, *args, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
name : string
|
||||
Name of the method or property to call.
|
||||
def call_async(self, name: str, *args, **kwargs):
|
||||
"""Calls the method with name asynchronously and apply args and kwargs to the method.
|
||||
|
||||
*args
|
||||
Arguments to apply to the method call.
|
||||
|
||||
**kwargs
|
||||
Keywoard arguments to apply to the method call.
|
||||
Args:
|
||||
name: Name of the method or property to call.
|
||||
*args: Arguments to apply to the method call.
|
||||
**kwargs: Keyword arguments to apply to the method call.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
@@ -450,19 +383,14 @@ class AsyncVectorEnv(VectorEnv):
|
||||
pipe.send(("_call", (name, args, kwargs)))
|
||||
self._state = AsyncState.WAITING_CALL
|
||||
|
||||
def call_wait(self, timeout=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to `step_wait` times out. If
|
||||
`None` (default), the call to `step_wait` never times out.
|
||||
def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list:
|
||||
"""Calls all parent pipes and waits for the results.
|
||||
|
||||
Returns
|
||||
-------
|
||||
results : list
|
||||
List of the results of the individual calls to the method or
|
||||
property for each environment.
|
||||
Args:
|
||||
timeout: Number of seconds before the call to `call_wait` times out. If `None` (default), the call to `call_wait` never times out.
|
||||
|
||||
Returns:
|
||||
List of the results of the individual calls to the method or property for each environment.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_CALL:
|
||||
@@ -483,17 +411,14 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
return results
|
||||
|
||||
def set_attr(self, name, values):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
name : string
|
||||
Name of the property to be set in each individual environment.
|
||||
def set_attr(self, name: str, values: Union[list, tuple, object]):
|
||||
"""Sets an attribute of the sub-environments.
|
||||
|
||||
values : list, tuple, or object
|
||||
Values of the property to be set to. If `values` is a list or
|
||||
tuple, then it corresponds to the values for each individual
|
||||
environment, otherwise a single value is set for all environments.
|
||||
Args:
|
||||
name: Name of the property to be set in each individual environment.
|
||||
values: Values of the property to be set to. If ``values`` is a list or
|
||||
tuple, then it corresponds to the values for each individual
|
||||
environment, otherwise a single value is set for all environments.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if not isinstance(values, (list, tuple)):
|
||||
@@ -517,25 +442,19 @@ class AsyncVectorEnv(VectorEnv):
|
||||
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
|
||||
def close_extras(self, timeout=None, terminate=False):
|
||||
"""Close the environments & clean up the extra resources
|
||||
(processes and pipes).
|
||||
def close_extras(
|
||||
self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
|
||||
):
|
||||
"""Close the environments & clean up the extra resources (processes and pipes).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to :meth:`close` times out. If ``None``,
|
||||
the call to :meth:`close` never times out. If the call to :meth:`close`
|
||||
times out, then all processes are terminated.
|
||||
Args:
|
||||
timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
|
||||
the call to :meth:`close` never times out. If the call to :meth:`close`
|
||||
times out, then all processes are terminated.
|
||||
terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.
|
||||
|
||||
terminate : bool
|
||||
If ``True``, then the :meth:`close` operation is forced and all processes
|
||||
are terminated.
|
||||
|
||||
Raises
|
||||
------
|
||||
TimeoutError
|
||||
If :meth:`close` timed out.
|
||||
Raises:
|
||||
TimeoutError: If :meth:`close` timed out.
|
||||
"""
|
||||
timeout = 0 if terminate else timeout
|
||||
try:
|
||||
@@ -626,6 +545,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
raise exctype(value)
|
||||
|
||||
def __del__(self):
|
||||
"""On deleting the object, checks that the vector environment is closed."""
|
||||
if not getattr(self, "closed", True) and hasattr(self, "_state"):
|
||||
self.close(terminate=True)
|
||||
|
||||
|
@@ -1,8 +1,12 @@
|
||||
"""A synchronous vector environment."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Iterator, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Space
|
||||
from gym.vector.utils import concatenate, create_empty_array, iterate
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
|
||||
@@ -12,35 +16,9 @@ __all__ = ["SyncVectorEnv"]
|
||||
class SyncVectorEnv(VectorEnv):
|
||||
"""Vectorized environment that serially runs multiple environments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_fns : iterable of callable
|
||||
Functions that create the environments.
|
||||
|
||||
observation_space : :class:`gym.spaces.Space`, optional
|
||||
Observation space of a single environment. If ``None``, then the
|
||||
observation space of the first environment is taken.
|
||||
|
||||
action_space : :class:`gym.spaces.Space`, optional
|
||||
Action space of a single environment. If ``None``, then the action space
|
||||
of the first environment is taken.
|
||||
|
||||
copy : bool
|
||||
If ``True``, then the :meth:`reset` and :meth:`step` methods return a
|
||||
copy of the observations.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If the observation space of some sub-environment does not match
|
||||
:obj:`observation_space` (or, by default, the observation space of
|
||||
the first sub-environment).
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
.. code-block::
|
||||
Example::
|
||||
|
||||
>>> import gym
|
||||
>>> env = gym.vector.SyncVectorEnv([
|
||||
... lambda: gym.make("Pendulum-v0", g=9.81),
|
||||
... lambda: gym.make("Pendulum-v0", g=1.62)
|
||||
@@ -50,7 +28,24 @@ class SyncVectorEnv(VectorEnv):
|
||||
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32)
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns, observation_space=None, action_space=None, copy=True):
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: Iterator[callable],
|
||||
observation_space: Space = None,
|
||||
action_space: Space = None,
|
||||
copy: bool = True,
|
||||
):
|
||||
"""Vectorized environment that serially runs multiple environments.
|
||||
|
||||
Args:
|
||||
env_fns: iterable of callable functions that create the environments.
|
||||
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
|
||||
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
|
||||
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
|
||||
"""
|
||||
self.env_fns = env_fns
|
||||
self.envs = [env_fn() for env_fn in env_fns]
|
||||
self.copy = copy
|
||||
@@ -60,7 +55,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
observation_space = observation_space or self.envs[0].observation_space
|
||||
action_space = action_space or self.envs[0].action_space
|
||||
super().__init__(
|
||||
num_envs=len(env_fns),
|
||||
num_envs=len(self.envs),
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
)
|
||||
@@ -73,7 +68,12 @@ class SyncVectorEnv(VectorEnv):
|
||||
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._actions = None
|
||||
|
||||
def seed(self, seed=None):
|
||||
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
|
||||
"""Sets the seed in all sub-environments.
|
||||
|
||||
Args:
|
||||
seed: The seed
|
||||
"""
|
||||
super().seed(seed=seed)
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
@@ -86,10 +86,20 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
def reset_wait(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
seed: Optional[Union[int, list[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||
|
||||
Args:
|
||||
seed: The reset environment seed
|
||||
return_info: Whether to return the reset information along with the observations
|
||||
options: Option information for the environment reset
|
||||
|
||||
Returns:
|
||||
The reset observation of the environment and reset information
|
||||
"""
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
@@ -128,9 +138,15 @@ class SyncVectorEnv(VectorEnv):
|
||||
), data_list
|
||||
|
||||
def step_async(self, actions):
|
||||
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
||||
self._actions = iterate(self.action_space, actions)
|
||||
|
||||
def step_wait(self):
|
||||
"""Steps through each of the environments returning the batched results.
|
||||
|
||||
Returns:
|
||||
The batched environment step results
|
||||
"""
|
||||
observations, infos = [], []
|
||||
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
|
||||
observation, self._rewards[i], self._dones[i], info = env.step(action)
|
||||
@@ -150,7 +166,17 @@ class SyncVectorEnv(VectorEnv):
|
||||
infos,
|
||||
)
|
||||
|
||||
def call(self, name, *args, **kwargs):
|
||||
def call(self, name, *args, **kwargs) -> tuple:
|
||||
"""Calls the method with name and applies args and kwargs.
|
||||
|
||||
Args:
|
||||
name: The method name
|
||||
*args: The method args
|
||||
**kwargs: The method kwargs
|
||||
|
||||
Returns:
|
||||
Tuple of results
|
||||
"""
|
||||
results = []
|
||||
for env in self.envs:
|
||||
function = getattr(env, name)
|
||||
@@ -161,7 +187,15 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
return tuple(results)
|
||||
|
||||
def set_attr(self, name, values):
|
||||
def set_attr(self, name: str, values: Union[list, tuple, Any]):
|
||||
"""Sets an attribute of the sub-environments.
|
||||
|
||||
Args:
|
||||
name: The property name to change
|
||||
values: Values of the property to be set to. If ``values`` is a list or
|
||||
tuple, then it corresponds to the values for each individual
|
||||
environment, otherwise, a single value is set for all environments.
|
||||
"""
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values for _ in range(self.num_envs)]
|
||||
if len(values) != self.num_envs:
|
||||
@@ -178,7 +212,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
"""Close the environments."""
|
||||
[env.close() for env in self.envs]
|
||||
|
||||
def _check_spaces(self):
|
||||
def _check_spaces(self) -> bool:
|
||||
for env in self.envs:
|
||||
if not (env.observation_space == self.single_observation_space):
|
||||
raise RuntimeError(
|
||||
@@ -194,5 +228,4 @@ class SyncVectorEnv(VectorEnv):
|
||||
"action spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
else:
|
||||
return True
|
||||
return True
|
||||
|
@@ -1,3 +1,4 @@
|
||||
"""Module for gym vector utils."""
|
||||
from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
|
||||
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||
from gym.vector.utils.shared_memory import (
|
||||
|
@@ -1,3 +1,4 @@
|
||||
"""Miscellaneous utilities."""
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
@@ -5,28 +6,35 @@ __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
|
||||
|
||||
|
||||
class CloudpickleWrapper:
|
||||
def __init__(self, fn):
|
||||
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""
|
||||
|
||||
def __init__(self, fn: callable):
|
||||
"""Cloudpickle wrapper for a function."""
|
||||
self.fn = fn
|
||||
|
||||
def __getstate__(self):
|
||||
"""Get the state using `cloudpickle.dumps(self.fn)`."""
|
||||
import cloudpickle
|
||||
|
||||
return cloudpickle.dumps(self.fn)
|
||||
|
||||
def __setstate__(self, ob):
|
||||
"""Sets the state by unpickling ``ob``."""
|
||||
import pickle
|
||||
|
||||
self.fn = pickle.loads(ob)
|
||||
|
||||
def __call__(self):
|
||||
"""Calls the function `self.fn` with no arguments."""
|
||||
return self.fn()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def clear_mpi_env_vars():
|
||||
"""
|
||||
`from mpi4py import MPI` will call `MPI_Init` by default. If the child
|
||||
process has MPI environment variables, MPI will think that the child process
|
||||
"""Clears the MPI environment variables.
|
||||
|
||||
`from mpi4py import MPI` will call `MPI_Init` by default.
|
||||
If the child process has MPI environment variables, MPI will think that the child process
|
||||
is an MPI process just like the parent and do bad things such as hang.
|
||||
|
||||
This context manager is a hacky way to clear those environment variables
|
||||
|
@@ -1,5 +1,7 @@
|
||||
"""Numpy utility functions: concatenate space samples and create empty array."""
|
||||
from collections import OrderedDict
|
||||
from functools import singledispatch
|
||||
from typing import Iterable, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -9,36 +11,29 @@ __all__ = ["concatenate", "create_empty_array"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def concatenate(space, items, out):
|
||||
def concatenate(
|
||||
space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
|
||||
) -> Union[tuple, dict, np.ndarray]:
|
||||
"""Concatenate multiple samples from space into a single object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
items : iterable of samples of `space`
|
||||
Samples to be concatenated.
|
||||
Example::
|
||||
|
||||
out : tuple, dict, or `np.ndarray`
|
||||
>>> from gym.spaces import Box
|
||||
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
|
||||
>>> out = np.zeros((2, 3), dtype=np.float32)
|
||||
>>> items = [space.sample() for _ in range(2)]
|
||||
>>> concatenate(space, items, out)
|
||||
array([[0.6348213 , 0.28607962, 0.60760117],
|
||||
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
items: Samples to be concatenated.
|
||||
out: The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Returns:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : tuple, dict, or `np.ndarray`
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from gym.spaces import Box
|
||||
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
|
||||
>>> out = np.zeros((2, 3), dtype=np.float32)
|
||||
>>> items = [space.sample() for _ in range(2)]
|
||||
>>> concatenate(items, out, space)
|
||||
array([[0.6348213 , 0.28607962, 0.60760117],
|
||||
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
|
||||
"""
|
||||
assert isinstance(items, (list, tuple))
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
||||
)
|
||||
@@ -76,38 +71,30 @@ def _concatenate_custom(space, items, out):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def create_empty_array(space, n=1, fn=np.zeros):
|
||||
def create_empty_array(
|
||||
space: Space, n: int = 1, fn: callable = np.zeros
|
||||
) -> Union[tuple, dict, np.ndarray]:
|
||||
"""Create an empty (possibly nested) numpy array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
Example::
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment. If `None`, creates
|
||||
an empty sample from `space`.
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> create_empty_array(space, n=2, fn=np.zeros)
|
||||
OrderedDict([('position', array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)),
|
||||
('velocity', array([[0., 0.],
|
||||
[0., 0.]], dtype=float32))])
|
||||
|
||||
fn : callable
|
||||
Function to apply when creating the empty numpy array. Examples of such
|
||||
functions are `np.empty` or `np.zeros`.
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
|
||||
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : tuple, dict, or `np.ndarray`
|
||||
Returns:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> create_empty_array(space, n=2, fn=np.zeros)
|
||||
OrderedDict([('position', array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)),
|
||||
('velocity', array([[0., 0.],
|
||||
[0., 0.]], dtype=float32))])
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
||||
|
@@ -1,44 +1,40 @@
|
||||
"""Utility functions for vector environments to share memory between processes."""
|
||||
import multiprocessing as mp
|
||||
from collections import OrderedDict
|
||||
from ctypes import c_bool
|
||||
from functools import singledispatch
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.error import CustomSpaceError
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
||||
|
||||
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def create_shared_memory(space, n=1, ctx=mp):
|
||||
"""Create a shared memory object, to be shared across processes. This
|
||||
eventually contains the observations from the vectorized environment.
|
||||
def create_shared_memory(
|
||||
space: Space, n: int = 1, ctx=mp
|
||||
) -> Union[dict, tuple, mp.Array]:
|
||||
"""Create a shared memory object, to be shared across processes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
This eventually contains the observations from the vectorized environment.
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment (i.e. the number
|
||||
of processes).
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment (i.e. the number of processes).
|
||||
ctx: The multiprocessing context (defaults to the ``multiprocessing`` module)
|
||||
|
||||
ctx : `multiprocessing` context
|
||||
Context for multiprocessing.
|
||||
|
||||
Returns
|
||||
-------
|
||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
||||
Shared object across processes.
|
||||
Returns:
|
||||
The shared object across processes (a dict, tuple, or `multiprocessing.Array` instance).
|
||||
"""
|
||||
raise CustomSpaceError(
|
||||
"Cannot create a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
f"type `{type(space)}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
"Gym spaces."
|
||||
)
|
||||
|
||||
|
||||
@@ -46,7 +42,7 @@ def create_shared_memory(space, n=1, ctx=mp):
|
||||
@create_shared_memory.register(Discrete)
|
||||
@create_shared_memory.register(MultiDiscrete)
|
||||
@create_shared_memory.register(MultiBinary)
|
||||
def _create_base_shared_memory(space, n=1, ctx=mp):
|
||||
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
|
||||
dtype = space.dtype.char
|
||||
if dtype in "?":
|
||||
dtype = c_bool
|
||||
@@ -54,7 +50,7 @@ def _create_base_shared_memory(space, n=1, ctx=mp):
|
||||
|
||||
|
||||
@create_shared_memory.register(Tuple)
|
||||
def _create_tuple_shared_memory(space, n=1, ctx=mp):
|
||||
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
|
||||
return tuple(
|
||||
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
||||
)
|
||||
@@ -71,39 +67,32 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def read_from_shared_memory(space, shared_memory, n=1):
|
||||
def read_from_shared_memory(
|
||||
space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1
|
||||
) -> Union[dict, tuple, np.ndarray]:
|
||||
"""Read the batch of observations from shared memory as a numpy array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
||||
Shared object across processes. This contains the observations from the
|
||||
vectorized environment. This object is created with `create_shared_memory`.
|
||||
.. note::
|
||||
The numpy array objects returned by `read_from_shared_memory` shares the
|
||||
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
|
||||
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
|
||||
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
|
||||
This object is created with `create_shared_memory`.
|
||||
n: Number of environments in the vectorized environment (i.e. the number of processes).
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment (i.e. the number
|
||||
of processes).
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : dict, tuple or `np.ndarray` instance
|
||||
Returns:
|
||||
Batch of observations as a (possibly nested) numpy array.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The numpy array objects returned by `read_from_shared_memory` shares the
|
||||
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
|
||||
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
|
||||
"""
|
||||
raise CustomSpaceError(
|
||||
"Cannot read from a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
f"type `{type(space)}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
"Gym spaces."
|
||||
)
|
||||
|
||||
|
||||
@@ -111,14 +100,14 @@ def read_from_shared_memory(space, shared_memory, n=1):
|
||||
@read_from_shared_memory.register(Discrete)
|
||||
@read_from_shared_memory.register(MultiDiscrete)
|
||||
@read_from_shared_memory.register(MultiBinary)
|
||||
def _read_base_from_shared_memory(space, shared_memory, n=1):
|
||||
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
|
||||
(n,) + space.shape
|
||||
)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Tuple)
|
||||
def _read_tuple_from_shared_memory(space, shared_memory, n=1):
|
||||
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
return tuple(
|
||||
read_from_shared_memory(subspace, memory, n=n)
|
||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
||||
@@ -126,7 +115,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n=1):
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Dict)
|
||||
def _read_dict_from_shared_memory(space, shared_memory, n=1):
|
||||
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
||||
@@ -136,34 +125,26 @@ def _read_dict_from_shared_memory(space, shared_memory, n=1):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def write_to_shared_memory(space, index, value, shared_memory):
|
||||
def write_to_shared_memory(
|
||||
space: Space,
|
||||
index: int,
|
||||
value: np.ndarray,
|
||||
shared_memory: Union[dict, tuple, mp.Array],
|
||||
):
|
||||
"""Write the observation of a single environment into shared memory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index : int
|
||||
Index of the environment (must be in `[0, num_envs)`).
|
||||
|
||||
value : sample from `space`
|
||||
Observation of the single environment to write to shared memory.
|
||||
|
||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
||||
Shared object across processes. This contains the observations from the
|
||||
vectorized environment. This object is created with `create_shared_memory`.
|
||||
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
`None`
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
index: Index of the environment (must be in `[0, num_envs)`).
|
||||
value: Observation of the single environment to write to shared memory.
|
||||
shared_memory: Shared object across processes. This contains the observations from the vectorized environment. This object is created with `create_shared_memory`.
|
||||
"""
|
||||
raise CustomSpaceError(
|
||||
"Cannot write to a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
f"type `{type(space)}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
"Gym spaces."
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
"""Utility functions for gym spaces: batch space and iterator."""
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -12,32 +14,25 @@ __all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def batch_space(space, n=1):
|
||||
def batch_space(space: Space, n: int = 1) -> Space:
|
||||
"""Create a (batched) space, containing multiple copies of a single space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
space : `gym.spaces.Space` instance
|
||||
Space (e.g. the observation space) for a single environment in the
|
||||
vectorized environment.
|
||||
Example::
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment.
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
|
||||
... })
|
||||
>>> batch_space(space, n=5)
|
||||
Dict(position:Box(5, 3), velocity:Box(5, 2))
|
||||
|
||||
Returns
|
||||
-------
|
||||
batched_space : `gym.spaces.Space` instance
|
||||
Space (e.g. the observation space) for a batch of environments in the
|
||||
vectorized environment.
|
||||
Args:
|
||||
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> batch_space(space, n=5)
|
||||
Dict(position:Box(5, 3), velocity:Box(5, 2))
|
||||
Returns:
|
||||
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
||||
@@ -126,41 +121,35 @@ def _batch_space_custom(space, n=1):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def iterate(space, items):
|
||||
def iterate(space: Space, items) -> Iterator:
|
||||
"""Iterate over the elements of a (batched) space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
space : `gym.spaces.Space` instance
|
||||
Space to which `items` belong to.
|
||||
Example::
|
||||
|
||||
items : samples of `space`
|
||||
Items to be iterated over.
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
|
||||
>>> items = space.sample()
|
||||
>>> it = iterate(space, items)
|
||||
>>> next(it)
|
||||
{'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32),
|
||||
'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
|
||||
>>> next(it)
|
||||
{'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32),
|
||||
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
|
||||
>>> next(it)
|
||||
StopIteration
|
||||
|
||||
Returns
|
||||
-------
|
||||
iterator : `Iterable` instance
|
||||
Args:
|
||||
space: Space to which `items` belong.
|
||||
items: Items to be iterated over.
|
||||
|
||||
Returns:
|
||||
Iterator over the elements in `items`.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
|
||||
>>> items = space.sample()
|
||||
>>> it = iterate(space, items)
|
||||
>>> next(it)
|
||||
{'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32),
|
||||
'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
|
||||
>>> next(it)
|
||||
{'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32),
|
||||
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
|
||||
>>> next(it)
|
||||
StopIteration
|
||||
"""
|
||||
raise ValueError(
|
||||
"Space of type `{}` is not a valid `gym.Space` " "instance.".format(type(space))
|
||||
f"Space of type `{type(space)}` is not a valid `gym.Space` " "instance."
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,4 +1,7 @@
|
||||
from typing import List, Optional, Union
|
||||
"""Base class for vectorized environments."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import gym
|
||||
from gym.logger import deprecation
|
||||
@@ -8,32 +11,28 @@ __all__ = ["VectorEnv"]
|
||||
|
||||
|
||||
class VectorEnv(gym.Env):
|
||||
r"""Base class for vectorized environments. Runs multiple independent copies of the
|
||||
same environment in parallel. This is not the same as 1 environment that has multiple
|
||||
sub components, but it is many copies of the same base env.
|
||||
"""Base class for vectorized environments. Runs multiple independent copies of the same environment in parallel.
|
||||
|
||||
Each observation returned from vectorized environment is a batch of observations
|
||||
for each parallel environment. And :meth:`step` is also expected to receive a batch of
|
||||
actions for each parallel environment.
|
||||
This is not the same as 1 environment that has multiple subcomponents, but it is many copies of the same base env.
|
||||
|
||||
.. note::
|
||||
Each observation returned from vectorized environment is a batch of observations for each parallel environment.
|
||||
And :meth:`step` is also expected to receive a batch of actions for each parallel environment.
|
||||
|
||||
Notes:
|
||||
All parallel environments should share the identical observation and action spaces.
|
||||
In other words, a vector of multiple different environments is not supported.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_envs : int
|
||||
Number of environments in the vectorized environment.
|
||||
|
||||
observation_space : :class:`gym.spaces.Space`
|
||||
Observation space of a single environment.
|
||||
|
||||
action_space : :class:`gym.spaces.Space`
|
||||
Action space of a single environment.
|
||||
"""
|
||||
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
def __init__(
|
||||
self, num_envs: int, observation_space: gym.Space, action_space: gym.Space
|
||||
):
|
||||
"""Base class for vectorized environments.
|
||||
|
||||
Args:
|
||||
num_envs: Number of environments in the vectorized environment.
|
||||
observation_space: Observation space of a single environment.
|
||||
action_space: Action space of a single environment.
|
||||
"""
|
||||
self.num_envs = num_envs
|
||||
self.is_vector_env = True
|
||||
self.observation_space = batch_space(observation_space, n=num_envs)
|
||||
@@ -49,141 +48,134 @@ class VectorEnv(gym.Env):
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
seed: Optional[Union[int, list[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Reset the sub-environments asynchronously.
|
||||
|
||||
This method will return ``None``. A call to :meth:`reset_async` should be followed by a call to :meth:`reset_wait` to retrieve the results.
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_wait(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
seed: Optional[Union[int, list[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Retrieves the results of a :meth:`reset_async` call.
|
||||
|
||||
A call to this method must always be preceded by a call to :meth:`reset_async`.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
seed: Optional[Union[int, list[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
r"""Reset all parallel environments and return a batch of initial observations.
|
||||
"""Reset all parallel environments and return a batch of initial observations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : element of :attr:`observation_space`
|
||||
Args:
|
||||
seed: The environment reset seeds
|
||||
return_info: Whether to return the info
|
||||
options: Optional reset options, forwarded to :meth:`reset_async` and :meth:`reset_wait`
|
||||
|
||||
Returns:
|
||||
A batch of observations from the vectorized environment.
|
||||
"""
|
||||
self.reset_async(seed=seed, return_info=return_info, options=options)
|
||||
return self.reset_wait(seed=seed, return_info=return_info, options=options)
|
||||
|
||||
def step_async(self, actions):
|
||||
"""Asynchronously performs steps in the sub-environments.
|
||||
|
||||
The results can be retrieved via a call to :meth:`step_wait`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def step_wait(self, **kwargs):
|
||||
"""Retrieves the results of a :meth:`step_async` call.
|
||||
|
||||
A call to this method must always be preceded by a call to :meth:`step_async`.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def step(self, actions):
|
||||
r"""Take an action for each parallel environment.
|
||||
"""Take an action for each parallel environment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
actions : element of :attr:`action_space`
|
||||
Batch of actions.
|
||||
Args:
|
||||
actions: Batch of actions, an element of :attr:`action_space`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : element of :attr:`observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
|
||||
rewards : :obj:`np.ndarray`, dtype :obj:`np.float_`
|
||||
A vector of rewards from the vectorized environment.
|
||||
|
||||
dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
|
||||
A vector whose entries indicate whether the episode has ended.
|
||||
|
||||
infos : list of dict
|
||||
A list of auxiliary diagnostic information dicts from each parallel environment.
|
||||
Returns:
|
||||
Batch of observations, rewards, done and infos
|
||||
"""
|
||||
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def call_async(self, name, *args, **kwargs):
|
||||
"""Calls a method name for each parallel environment asynchronously."""
|
||||
pass
|
||||
|
||||
def call_wait(self, **kwargs):
|
||||
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def call(self, name, *args, **kwargs):
|
||||
def call(self, name: str, *args, **kwargs) -> list[Any]:
|
||||
"""Call a method, or get a property, from each parallel environment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : string
|
||||
Name of the method or property to call.
|
||||
Args:
|
||||
name (str): Name of the method or property to call.
|
||||
*args: Arguments to apply to the method call.
|
||||
**kwargs: Keyword arguments to apply to the method call.
|
||||
|
||||
*args
|
||||
Arguments to apply to the method call.
|
||||
|
||||
**kwargs
|
||||
Keywoard arguments to apply to the method call.
|
||||
|
||||
Returns
|
||||
-------
|
||||
results : list
|
||||
List of the results of the individual calls to the method or
|
||||
property for each environment.
|
||||
Returns:
|
||||
List of the results of the individual calls to the method or property for each environment.
|
||||
"""
|
||||
self.call_async(name, *args, **kwargs)
|
||||
return self.call_wait()
|
||||
|
||||
def get_attr(self, name):
|
||||
def get_attr(self, name: str):
|
||||
"""Get a property from each parallel environment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : string
|
||||
Name of the property to be get from each individual environment.
|
||||
Args:
|
||||
name (str): Name of the property to get from each individual environment.
|
||||
|
||||
Returns:
|
||||
The value of the property ``name``, as returned by :meth:`call` for each environment.
|
||||
"""
|
||||
return self.call(name)
|
||||
|
||||
def set_attr(self, name, values):
|
||||
"""Set a property in each parallel environment.
|
||||
def set_attr(self, name: str, values: Union[list, tuple, object]):
|
||||
"""Set a property in each sub-environment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : string
|
||||
Name of the property to be set in each individual environment.
|
||||
|
||||
values : list, tuple, or object
|
||||
Values of the property to be set to. If `values` is a list or
|
||||
tuple, then it corresponds to the values for each individual
|
||||
environment, otherwise a single value is set for all environments.
|
||||
Args:
|
||||
name (str): Name of the property to be set in each individual environment.
|
||||
values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
|
||||
tuple, then it corresponds to the values for each individual environment, otherwise a single value
|
||||
is set for all environments.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def close_extras(self, **kwargs):
|
||||
r"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||
"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||
pass
|
||||
|
||||
def close(self, **kwargs):
|
||||
r"""Close all parallel environments and release resources.
|
||||
"""Close all parallel environments and release resources.
|
||||
|
||||
It also closes all the existing image viewers, then calls :meth:`close_extras` and set
|
||||
:attr:`closed` as ``True``.
|
||||
|
||||
.. warning::
|
||||
|
||||
Warnings:
|
||||
This function itself does not close the environments, it should be handled
|
||||
in :meth:`close_extras`. This is generic for both synchronous and asynchronous
|
||||
vectorized environments.
|
||||
|
||||
.. note::
|
||||
|
||||
Notes:
|
||||
This will be automatically called when garbage collected or program exited.
|
||||
|
||||
"""
|
||||
@@ -197,14 +189,12 @@ class VectorEnv(gym.Env):
|
||||
def seed(self, seed=None):
|
||||
"""Set the random seed in all parallel environments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seed : list of int, or int, optional
|
||||
Random seed for each parallel environment. If ``seed`` is a list of
|
||||
length ``num_envs``, then the items of the list are chosen as random
|
||||
seeds. If ``seed`` is an int, then each parallel environment uses the random
|
||||
seed ``seed + n``, where ``n`` is the index of the parallel environment
|
||||
(between ``0`` and ``num_envs - 1``).
|
||||
Args:
|
||||
seed: Random seed for each parallel environment. If ``seed`` is a list of
|
||||
length ``num_envs``, then the items of the list are chosen as random
|
||||
seeds. If ``seed`` is an int, then each parallel environment uses the random
|
||||
seed ``seed + n``, where ``n`` is the index of the parallel environment
|
||||
(between ``0`` and ``num_envs - 1``).
|
||||
"""
|
||||
deprecation(
|
||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
||||
@@ -212,10 +202,12 @@ class VectorEnv(gym.Env):
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
"""Closes the vector environment."""
|
||||
if not getattr(self, "closed", True):
|
||||
self.close()
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns a string representation of the vector environment using the class name, number of environments and environment spec id."""
|
||||
if self.spec is None:
|
||||
return f"{self.__class__.__name__}({self.num_envs})"
|
||||
else:
|
||||
@@ -223,19 +215,17 @@ class VectorEnv(gym.Env):
|
||||
|
||||
|
||||
class VectorEnvWrapper(VectorEnv):
|
||||
r"""Wraps the vectorized environment to allow a modular transformation.
|
||||
"""Wraps the vectorized environment to allow a modular transformation.
|
||||
|
||||
This class is the base class for all wrappers for vectorized environments. The subclass
|
||||
could override some methods to change the behavior of the original vectorized environment
|
||||
without touching the original code.
|
||||
|
||||
.. note::
|
||||
|
||||
Notes:
|
||||
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
def __init__(self, env: VectorEnv):
|
||||
assert isinstance(env, VectorEnv)
|
||||
self.env = env
|
||||
|
||||
|
2
setup.py
2
setup.py
@@ -17,7 +17,7 @@ extras = {
|
||||
"classic_control": ["pygame==2.1.0"],
|
||||
"mujoco": ["mujoco_py>=1.50, <2.0"],
|
||||
"toy_text": ["pygame==2.1.0", "scipy>=1.4.1"],
|
||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0"],
|
||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0"],
|
||||
}
|
||||
|
||||
# Meta dependency groups.
|
||||
|
Reference in New Issue
Block a user