mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Pydocstyle utils vector docstring (#2788)
* Added pydocstyle to pre-commit * Added docstrings for tests and updated the tests for autoreset * Add pydocstyle exclude folder to allow slowly adding new docstrings * Add docstrings for setup.py and gym/__init__.py, core.py, error.py and logger.py * Check that all unwrapped environment are of a particular wrapper type * Reverted back to import gym.spaces.Space to gym.spaces * Fixed the __init__.py docstring * Fixed autoreset autoreset test * Updated gym __init__.py top docstring * Fix examples in docstrings * Add docstrings and type hints where known to all functions and classes in gym/utils and gym/vector * Remove unnecessary import * Removed "unused error" and make APIerror deprecated at gym 1.0 * Add pydocstyle description to CONTRIBUTING.md * Added docstrings section to CONTRIBUTING.md * Added :meth: and :attr: keywords to docstrings * Added :meth: and :attr: keywords to docstrings * Imported annotations from __future__ to fix python 3.7 * Add __future__ import annotations for python 3.7 * isort * Remove utils and vectors for this PR and spaces for previous PR * Update gym/envs/classic_control/acrobot.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/envs/classic_control/acrobot.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/envs/classic_control/acrobot.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/spaces/dict.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/env_checker.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/ezpickle.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/ezpickle.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Update gym/utils/play.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Pre-commit * Updated docstrings with :meth: * Updated docstrings with :meth: * Update gym/utils/play.py * Update gym/utils/play.py * Update gym/utils/play.py * Apply suggestions from code review Co-authored-by: Markus Krimmel <montcyril@gmail.com> * pre-commit * Update gym/utils/play.py Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Updated fps and zoom parameter docstring * Update play docstring * Apply suggestions from code review Added suggested corrections from @markus28 Co-authored-by: Markus Krimmel <montcyril@gmail.com> * Pre-commit magic * Update the `gym.make` docstring with a warning for `env_checker` * Updated and fixed vector docstrings * Update test names for reflect the project filename style Co-authored-by: Markus Krimmel <montcyril@gmail.com>
This commit is contained in:
@@ -30,7 +30,7 @@ repos:
|
|||||||
rev: 6.1.1 # pick a git hash / tag to point to
|
rev: 6.1.1 # pick a git hash / tag to point to
|
||||||
hooks:
|
hooks:
|
||||||
- id: pydocstyle
|
- id: pydocstyle
|
||||||
exclude: ^(gym/version.py)|(gym/(envs|utils|vector)/)|(tests/)
|
exclude: ^(gym/version.py)|(gym/envs/)|(tests/)
|
||||||
args:
|
args:
|
||||||
- --source
|
- --source
|
||||||
- --explain
|
- --explain
|
||||||
|
@@ -398,27 +398,26 @@ def rk4(derivs, y0, t):
|
|||||||
yourself stranded on a system w/o scipy. Otherwise use
|
yourself stranded on a system w/o scipy. Otherwise use
|
||||||
:func:`scipy.integrate`.
|
:func:`scipy.integrate`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> ### 2D system
|
||||||
|
>>> def derivs(x):
|
||||||
|
... d1 = x[0] + 2*x[1]
|
||||||
|
... d2 = -3*x[0] + 4*x[1]
|
||||||
|
... return (d1, d2)
|
||||||
|
>>> dt = 0.0005
|
||||||
|
>>> t = arange(0.0, 2.0, dt)
|
||||||
|
>>> y0 = (1,2)
|
||||||
|
>>> yout = rk4(derivs, y0, t)
|
||||||
|
|
||||||
|
If you have access to scipy, you should probably be using the
|
||||||
|
:func:`scipy.integrate` tools rather than this function.
|
||||||
|
This would then require re-adding the time variable to the signature of derivs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
derivs: the derivative of the system and has the signature ``dy = derivs(yi)``
|
derivs: the derivative of the system and has the signature ``dy = derivs(yi)``
|
||||||
y0: initial state vector
|
y0: initial state vector
|
||||||
t: sample times
|
t: sample times
|
||||||
args: additional arguments passed to the derivative function
|
|
||||||
kwargs: additional keyword arguments passed to the derivative function
|
|
||||||
|
|
||||||
Example 1 ::
|
|
||||||
### 2D system
|
|
||||||
def derivs(x):
|
|
||||||
d1 = x[0] + 2*x[1]
|
|
||||||
d2 = -3*x[0] + 4*x[1]
|
|
||||||
return (d1, d2)
|
|
||||||
dt = 0.0005
|
|
||||||
t = arange(0.0, 2.0, dt)
|
|
||||||
y0 = (1,2)
|
|
||||||
yout = rk4(derivs6, y0, t)
|
|
||||||
|
|
||||||
If you have access to scipy, you should probably be using the
|
|
||||||
scipy.integrate tools rather than this function.
|
|
||||||
This would then require re-adding the time variable to the signature of derivs.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
yout: Runge-Kutta approximation of the ODE
|
yout: Runge-Kutta approximation of the ODE
|
||||||
|
@@ -499,10 +499,16 @@ def make(
|
|||||||
"""
|
"""
|
||||||
Create an environment according to the given ID.
|
Create an environment according to the given ID.
|
||||||
|
|
||||||
|
Warnings:
|
||||||
|
In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
|
||||||
|
This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to valid
|
||||||
|
if they follow the gym API. To disable this feature, set parameter `disable_env_checker=True`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
id: Name of the environment.
|
id: Name of the environment.
|
||||||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
||||||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
||||||
|
disable_env_checker: If to disable the environment checker
|
||||||
kwargs: Additional arguments to pass to the environment constructor.
|
kwargs: Additional arguments to pass to the environment constructor.
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the environment.
|
An instance of the environment.
|
||||||
|
@@ -17,25 +17,27 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
||||||
|
|
||||||
Example usage::
|
Example usage:
|
||||||
|
|
||||||
>>> observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
|
>>> from gym.spaces import Dict, Discrete
|
||||||
|
>>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||||
>>> observation_space.sample()
|
>>> observation_space.sample()
|
||||||
OrderedDict([('position', 1), ('velocity', 2)])
|
OrderedDict([('position', 1), ('velocity', 2)])
|
||||||
|
|
||||||
Example usage [nested]::
|
Example usage [nested]::
|
||||||
|
|
||||||
>>> spaces.Dict(
|
>>> from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete
|
||||||
|
>>> Dict(
|
||||||
... {
|
... {
|
||||||
... "ext_controller": spaces.MultiDiscrete((5, 2, 2)),
|
... "ext_controller": MultiDiscrete([5, 2, 2]),
|
||||||
... "inner_state": spaces.Dict(
|
... "inner_state": Dict(
|
||||||
... {
|
... {
|
||||||
... "charge": spaces.Discrete(100),
|
... "charge": Discrete(100),
|
||||||
... "system_checks": spaces.MultiBinary(10),
|
... "system_checks": MultiBinary(10),
|
||||||
... "job_status": spaces.Dict(
|
... "job_status": Dict(
|
||||||
... {
|
... {
|
||||||
... "task": spaces.Discrete(5),
|
... "task": Discrete(5),
|
||||||
... "progress": spaces.Box(low=0, high=100, shape=()),
|
... "progress": Box(low=0, high=100, shape=()),
|
||||||
... }
|
... }
|
||||||
... ),
|
... ),
|
||||||
... }
|
... }
|
||||||
@@ -63,9 +65,10 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> spaces.Dict({"position": spaces.Box(-1, 1, shape=(2,)), "color": spaces.Discrete(3)})
|
>>> from gym.spaces import Box, Discrete
|
||||||
|
>>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)})
|
||||||
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
||||||
>>> spaces.Dict(position=spaces.Box(-1, 1, shape=(2,)), color=spaces.Discrete(3))
|
>>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3))
|
||||||
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -16,11 +16,11 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
|
|
||||||
Example Usage::
|
Example Usage::
|
||||||
|
|
||||||
>>> self.observation_space = spaces.MultiBinary(5)
|
>>> observation_space = MultiBinary(5)
|
||||||
>>> self.observation_space.sample()
|
>>> observation_space.sample()
|
||||||
array([0, 1, 0, 1, 0], dtype=int8)
|
array([0, 1, 0, 1, 0], dtype=int8)
|
||||||
>>> self.observation_space = spaces.MultiBinary([3, 2])
|
>>> observation_space = MultiBinary([3, 2])
|
||||||
>>> self.observation_space.sample()
|
>>> observation_space.sample()
|
||||||
array([[0, 0],
|
array([[0, 0],
|
||||||
[0, 1],
|
[0, 1],
|
||||||
[1, 1]], dtype=int8)
|
[1, 1]], dtype=int8)
|
||||||
|
@@ -16,8 +16,9 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
>> observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Box(-1, 1, shape=(2,))))
|
>>> from gym.spaces import Box, Discrete
|
||||||
>> observation_space.sample()
|
>>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))))
|
||||||
|
>>> observation_space.sample()
|
||||||
(0, array([0.03633198, 0.42370757], dtype=float32))
|
(0, array([0.03633198, 0.42370757], dtype=float32))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@@ -25,8 +25,9 @@ def flatdim(space: Space) -> int:
|
|||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
>>> s = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
|
>>> from gym.spaces import Discrete
|
||||||
>>> spaces.flatdim(s)
|
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||||
|
>>> flatdim(space)
|
||||||
5
|
5
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||||
@@ -195,8 +196,7 @@ def flatten_space(space: Space) -> Box:
|
|||||||
|
|
||||||
Example that recursively flattens a dict::
|
Example that recursively flattens a dict::
|
||||||
|
|
||||||
>>> space = Dict({"position": Discrete(2),
|
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
|
||||||
... "velocity": Box(0, 1, shape=(2, 2))})
|
|
||||||
>>> flatten_space(space)
|
>>> flatten_space(space)
|
||||||
Box(6,)
|
Box(6,)
|
||||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
"""A set of common utilities used within the environments. These are
|
"""A set of common utilities used within the environments.
|
||||||
not intended as API functions, and will not remain stable over time.
|
|
||||||
|
These are not intended as API functions, and will not remain stable over time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
color2num = dict(
|
color2num = dict(
|
||||||
@@ -15,12 +16,20 @@ color2num = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def colorize(string, color, bold=False, highlight=False):
|
def colorize(
|
||||||
"""Return string surrounded by appropriate terminal color codes to
|
string: str, color: str, bold: bool = False, highlight: bool = False
|
||||||
print colorized text. Valid colors: gray, red, green, yellow,
|
) -> str:
|
||||||
blue, magenta, cyan, white, crimson
|
"""Returns string surrounded by appropriate terminal colour codes to print colourised text.
|
||||||
"""
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
string: The message to colourise
|
||||||
|
color: Literal values are gray, red, green, yellow, blue, magenta, cyan, white, crimson
|
||||||
|
bold: If to bold the string
|
||||||
|
highlight: If to highlight the string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Colourised string
|
||||||
|
"""
|
||||||
attr = []
|
attr = []
|
||||||
num = color2num[color]
|
num = color2num[color]
|
||||||
if highlight:
|
if highlight:
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
"""
|
"""A set of functions for checking an environment details.
|
||||||
|
|
||||||
This file is originally from the Stable Baselines3 repository hosted on GitHub
|
This file is originally from the Stable Baselines3 repository hosted on GitHub
|
||||||
(https://github.com/DLR-RM/stable-baselines3/)
|
(https://github.com/DLR-RM/stable-baselines3/)
|
||||||
Original Author: Antonin Raffin
|
Original Author: Antonin Raffin
|
||||||
@@ -16,21 +17,33 @@ from typing import Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import logger, spaces
|
from gym import logger
|
||||||
|
from gym.spaces import Box, Dict, Discrete, Space, Tuple
|
||||||
|
|
||||||
|
|
||||||
def _is_numpy_array_space(space: spaces.Space) -> bool:
|
def _is_numpy_array_space(space: Space) -> bool:
|
||||||
|
"""Checks if a space can be represented as a single numpy array (e.g. Dict and Tuple spaces return False).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: The space to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Returns False if the provided space is not representable as a single numpy array
|
||||||
"""
|
"""
|
||||||
Returns False if provided space is not representable as a single numpy array
|
return not isinstance(space, (Dict, Tuple))
|
||||||
(e.g. Dict and Tuple spaces return False)
|
|
||||||
"""
|
|
||||||
return not isinstance(space, (spaces.Dict, spaces.Tuple))
|
|
||||||
|
|
||||||
|
|
||||||
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
def _check_image_input(observation_space: Box, key: str = ""):
|
||||||
"""
|
"""Check whether an observation space of type :class:`Box` adheres to general standards for spaces that represent images.
|
||||||
Check that the input adheres to general standards
|
|
||||||
when the observation is apparently an image.
|
It will check that:
|
||||||
|
- The datatype is ``np.uint8``
|
||||||
|
- The lower bound is 0 across all dimensions
|
||||||
|
- The upper bound is 255 across all dimensions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation_space: The observation space to check
|
||||||
|
key: The observation shape key for warning
|
||||||
"""
|
"""
|
||||||
if observation_space.dtype != np.uint8:
|
if observation_space.dtype != np.uint8:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
@@ -49,8 +62,13 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_nan(env: gym.Env, check_inf: bool = True) -> None:
|
def _check_nan(env: gym.Env, check_inf: bool = True):
|
||||||
"""Check for NaN and Inf."""
|
"""Check if the environment observation, reward are NaN and Inf.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to check
|
||||||
|
check_inf: Checks if the observation is infinity
|
||||||
|
"""
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
observation, reward, done, _ = env.step(action)
|
observation, reward, done, _ = env.step(action)
|
||||||
@@ -70,19 +88,22 @@ def _check_nan(env: gym.Env, check_inf: bool = True) -> None:
|
|||||||
|
|
||||||
def _check_obs(
|
def _check_obs(
|
||||||
obs: Union[tuple, dict, np.ndarray, int],
|
obs: Union[tuple, dict, np.ndarray, int],
|
||||||
observation_space: spaces.Space,
|
observation_space: Space,
|
||||||
method_name: str,
|
method_name: str,
|
||||||
) -> None:
|
):
|
||||||
|
"""Check that the observation returned by the environment correspond to the declared one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: The observation to check
|
||||||
|
observation_space: The observation space of the observation
|
||||||
|
method_name: The method name that generated the observation
|
||||||
"""
|
"""
|
||||||
Check that the observation returned by the environment
|
if not isinstance(observation_space, Tuple):
|
||||||
correspond to the declared one.
|
|
||||||
"""
|
|
||||||
if not isinstance(observation_space, spaces.Tuple):
|
|
||||||
assert not isinstance(
|
assert not isinstance(
|
||||||
obs, tuple
|
obs, tuple
|
||||||
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
|
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
|
||||||
|
|
||||||
if isinstance(observation_space, spaces.Discrete):
|
if isinstance(observation_space, Discrete):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
obs, int
|
obs, int
|
||||||
), f"The observation returned by `{method_name}()` method must be an int"
|
), f"The observation returned by `{method_name}()` method must be an int"
|
||||||
@@ -96,12 +117,16 @@ def _check_obs(
|
|||||||
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
|
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
|
||||||
|
|
||||||
|
|
||||||
def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
|
def _check_box_obs(observation_space: Box, key: str = ""):
|
||||||
"""
|
"""Check that the observation space is correctly formatted when dealing with a :class:`Box` space.
|
||||||
Check that the observation space is correctly formatted
|
|
||||||
when dealing with a ``Box()`` space. In particular, it checks:
|
In particular, it checks:
|
||||||
- that the dimensions are big enough when it is an image, and that the type matches
|
- that the dimensions are big enough when it is an image, and that the type matches
|
||||||
- that the observation has an expected shape (warn the user if not)
|
- that the observation has an expected shape (warn the user if not)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation_space: Checks if the Box observation space
|
||||||
|
key: The observation key
|
||||||
"""
|
"""
|
||||||
# If image, check the low and high values, the type and the number of channels
|
# If image, check the low and high values, the type and the number of channels
|
||||||
# and the shape (minimal value)
|
# and the shape (minimal value)
|
||||||
@@ -137,14 +162,19 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
|
|||||||
), "Agent's observation_space.high and observation_space have different shapes"
|
), "Agent's observation_space.high and observation_space have different shapes"
|
||||||
|
|
||||||
|
|
||||||
def _check_box_action(action_space: spaces.Box):
|
def _check_box_action(action_space: Box):
|
||||||
|
"""Checks that a :class:`Box` action space is defined in a sensible way.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_space: A box action space
|
||||||
|
"""
|
||||||
if np.any(np.equal(action_space.low, -np.inf)):
|
if np.any(np.equal(action_space.low, -np.inf)):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Agent's minimum action space value is -infinity. This is probably too low."
|
"Agent's minimum action space value is -infinity. This is probably too low."
|
||||||
)
|
)
|
||||||
if np.any(np.equal(action_space.high, np.inf)):
|
if np.any(np.equal(action_space.high, np.inf)):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Agent's maxmimum action space value is infinity. This is probably too high"
|
"Agent's maximum action space value is infinity. This is probably too high"
|
||||||
)
|
)
|
||||||
if np.any(np.equal(action_space.low, action_space.high)):
|
if np.any(np.equal(action_space.low, action_space.high)):
|
||||||
logger.warn("Agent's maximum and minimum action space values are equal")
|
logger.warn("Agent's maximum and minimum action space values are equal")
|
||||||
@@ -156,7 +186,12 @@ def _check_box_action(action_space: spaces.Box):
|
|||||||
assert False, "Agent's action_space.high and action_space have different shapes"
|
assert False, "Agent's action_space.high and action_space have different shapes"
|
||||||
|
|
||||||
|
|
||||||
def _check_normalized_action(action_space: spaces.Box):
|
def _check_normalized_action(action_space: Box):
|
||||||
|
"""Checks that a box action space is normalized.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_space: A box action space
|
||||||
|
"""
|
||||||
if (
|
if (
|
||||||
np.any(np.abs(action_space.low) != np.abs(action_space.high))
|
np.any(np.abs(action_space.low) != np.abs(action_space.high))
|
||||||
or np.any(np.abs(action_space.low) > 1)
|
or np.any(np.abs(action_space.low) > 1)
|
||||||
@@ -168,16 +203,18 @@ def _check_normalized_action(action_space: spaces.Box):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_returned_values(
|
def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space):
|
||||||
env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space
|
"""Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods.
|
||||||
) -> None:
|
|
||||||
"""
|
Args:
|
||||||
Check the returned values by the env when calling `.reset()` or `.step()` methods.
|
env: The environment
|
||||||
|
observation_space: The environment's observation space
|
||||||
|
action_space: The environment's action space
|
||||||
"""
|
"""
|
||||||
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
|
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
if isinstance(observation_space, spaces.Dict):
|
if isinstance(observation_space, Dict):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
obs, dict
|
obs, dict
|
||||||
), "The observation returned by `reset()` must be a dictionary"
|
), "The observation returned by `reset()` must be a dictionary"
|
||||||
@@ -200,7 +237,7 @@ def _check_returned_values(
|
|||||||
# Unpack
|
# Unpack
|
||||||
obs, reward, done, info = data
|
obs, reward, done, info = data
|
||||||
|
|
||||||
if isinstance(observation_space, spaces.Dict):
|
if isinstance(observation_space, Dict):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
obs, dict
|
obs, dict
|
||||||
), "The observation returned by `step()` must be a dictionary"
|
), "The observation returned by `step()` must be a dictionary"
|
||||||
@@ -223,10 +260,11 @@ def _check_returned_values(
|
|||||||
), "The `info` returned by `step()` must be a python dictionary"
|
), "The `info` returned by `step()` must be a python dictionary"
|
||||||
|
|
||||||
|
|
||||||
def _check_spaces(env: gym.Env) -> None:
|
def _check_spaces(env: gym.Env):
|
||||||
"""
|
"""Check that the observation and action spaces are defined and inherit from :class:`gym.spaces.Space`.
|
||||||
Check that the observation and action spaces are defined
|
|
||||||
and inherit from gym.spaces.Space.
|
Args:
|
||||||
|
env: The environment's observation and action space to check
|
||||||
"""
|
"""
|
||||||
# Helper to link to the code, because gym has no proper documentation
|
# Helper to link to the code, because gym has no proper documentation
|
||||||
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
|
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
|
||||||
@@ -238,25 +276,22 @@ def _check_spaces(env: gym.Env) -> None:
|
|||||||
"You must specify an action space (cf gym.spaces)" + gym_spaces
|
"You must specify an action space (cf gym.spaces)" + gym_spaces
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(env.observation_space, spaces.Space), (
|
assert isinstance(env.observation_space, Space), (
|
||||||
"The observation space must inherit from gym.spaces" + gym_spaces
|
"The observation space must inherit from gym.spaces" + gym_spaces
|
||||||
)
|
)
|
||||||
assert isinstance(env.action_space, spaces.Space), (
|
assert isinstance(env.action_space, Space), (
|
||||||
"The action space must inherit from gym.spaces" + gym_spaces
|
"The action space must inherit from gym.spaces" + gym_spaces
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Check render cannot be covered by CI
|
# Check render cannot be covered by CI
|
||||||
def _check_render(
|
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False):
|
||||||
env: gym.Env, warn: bool = True, headless: bool = False
|
"""Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.
|
||||||
) -> None: # pragma: no cover
|
|
||||||
"""
|
Args:
|
||||||
Check the declared render modes/fps and the `render()`/`close()`
|
env: The environment to check
|
||||||
method of the environment.
|
warn: Whether to output additional warnings
|
||||||
:param env: The environment to check
|
headless: Whether to disable render modes that require a graphical interface. False by default.
|
||||||
:param warn: Whether to output additional warnings
|
|
||||||
:param headless: Whether to disable render modes
|
|
||||||
that require a graphical interface. False by default.
|
|
||||||
"""
|
"""
|
||||||
render_modes = env.metadata.get("render_modes")
|
render_modes = env.metadata.get("render_modes")
|
||||||
if render_modes is None:
|
if render_modes is None:
|
||||||
@@ -288,9 +323,12 @@ def _check_render(
|
|||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
|
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None):
|
||||||
"""
|
"""Check that the environment can be reset with a seed.
|
||||||
Check that the environment can be reset with a random seed.
|
|
||||||
|
Args:
|
||||||
|
env: The environment to check
|
||||||
|
seed: The optional seed to use
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
assert (
|
assert (
|
||||||
@@ -303,7 +341,7 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
|
|||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` "
|
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` "
|
||||||
"appear in the signature. This should never happen, please report this issue. "
|
"appear in the signature. This should never happen, please report this issue. "
|
||||||
"The error was: " + str(e)
|
f"The error was: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if env.unwrapped.np_random is None:
|
if env.unwrapped.np_random is None:
|
||||||
@@ -322,7 +360,12 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_reset_info(env: gym.Env) -> None:
|
def _check_reset_info(env: gym.Env):
|
||||||
|
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to check
|
||||||
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
assert (
|
assert (
|
||||||
"return_info" in signature.parameters or "kwargs" in signature.parameters
|
"return_info" in signature.parameters or "kwargs" in signature.parameters
|
||||||
@@ -334,7 +377,7 @@ def _check_reset_info(env: gym.Env) -> None:
|
|||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
|
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
|
||||||
"appear in the signature. This should never happen, please report this issue. "
|
"appear in the signature. This should never happen, please report this issue. "
|
||||||
"The error was: " + str(e)
|
f"The error was: {e}"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
len(result) == 2
|
len(result) == 2
|
||||||
@@ -346,9 +389,11 @@ def _check_reset_info(env: gym.Env) -> None:
|
|||||||
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
|
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
|
||||||
|
|
||||||
|
|
||||||
def _check_reset_options(env: gym.Env) -> None:
|
def _check_reset_options(env: gym.Env):
|
||||||
"""
|
"""Check that the environment can be reset with options.
|
||||||
Check that the environment can be reset with options.
|
|
||||||
|
Args:
|
||||||
|
env: The environment to check
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
assert (
|
assert (
|
||||||
@@ -361,22 +406,22 @@ def _check_reset_options(env: gym.Env) -> None:
|
|||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"The environment cannot be reset with options, even though `options` or `kwargs` "
|
"The environment cannot be reset with options, even though `options` or `kwargs` "
|
||||||
"appear in the signature. This should never happen, please report this issue. "
|
"appear in the signature. This should never happen, please report this issue. "
|
||||||
"The error was: " + str(e)
|
f"The error was: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
|
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True):
|
||||||
"""
|
"""Check that an environment follows Gym API.
|
||||||
Check that an environment follows Gym API.
|
|
||||||
This is particularly useful when using a custom environment.
|
This is particularly useful when using a custom environment.
|
||||||
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
|
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
|
||||||
for more information about the API.
|
for more information about the API.
|
||||||
It also optionally check that the environment is compatible with Stable-Baselines.
|
It also optionally check that the environment is compatible with Stable-Baselines.
|
||||||
:param env: The Gym environment that will be checked
|
|
||||||
:param warn: Whether to output additional warnings
|
Args:
|
||||||
mainly related to the interaction with Stable Baselines
|
env: The Gym environment that will be checked
|
||||||
:param skip_render_check: Whether to skip the checks for the render method.
|
warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines
|
||||||
True by default (useful for the CI)
|
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
|
||||||
"""
|
"""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
env, gym.Env
|
env, gym.Env
|
||||||
@@ -393,15 +438,15 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
|
|||||||
if warn:
|
if warn:
|
||||||
obs_spaces = (
|
obs_spaces = (
|
||||||
observation_space.spaces
|
observation_space.spaces
|
||||||
if isinstance(observation_space, spaces.Dict)
|
if isinstance(observation_space, Dict)
|
||||||
else {"": observation_space}
|
else {"": observation_space}
|
||||||
)
|
)
|
||||||
for key, space in obs_spaces.items():
|
for key, space in obs_spaces.items():
|
||||||
if isinstance(space, spaces.Box):
|
if isinstance(space, Box):
|
||||||
_check_box_obs(space, key)
|
_check_box_obs(space, key)
|
||||||
|
|
||||||
# Check for the action space, it may lead to hard-to-debug issues
|
# Check for the action space, it may lead to hard-to-debug issues
|
||||||
if isinstance(action_space, spaces.Box):
|
if isinstance(action_space, Box):
|
||||||
_check_box_action(action_space)
|
_check_box_action(action_space)
|
||||||
_check_normalized_action(action_space)
|
_check_normalized_action(action_space)
|
||||||
|
|
||||||
|
@@ -1,33 +1,35 @@
|
|||||||
|
"""Class for pickling and unpickling objects via their constructor arguments."""
|
||||||
|
|
||||||
|
|
||||||
class EzPickle:
|
class EzPickle:
|
||||||
"""Objects that are pickled and unpickled via their constructor
|
"""Objects that are pickled and unpickled via their constructor arguments.
|
||||||
arguments.
|
|
||||||
|
|
||||||
Example usage:
|
Example::
|
||||||
|
|
||||||
class Dog(Animal, EzPickle):
|
>>> class Dog(Animal, EzPickle):
|
||||||
def __init__(self, furcolor, tailkind="bushy"):
|
... def __init__(self, furcolor, tailkind="bushy"):
|
||||||
Animal.__init__()
|
... Animal.__init__()
|
||||||
EzPickle.__init__(furcolor, tailkind)
|
... EzPickle.__init__(furcolor, tailkind)
|
||||||
...
|
|
||||||
|
|
||||||
When this object is unpickled, a new Dog will be constructed by passing the provided
|
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
|
||||||
furcolor and tailkind into the constructor. However, philosophers are still not sure
|
However, philosophers are still not sure whether it is still the same dog.
|
||||||
whether it is still the same dog.
|
|
||||||
|
|
||||||
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo
|
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
|
||||||
and Atari.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
|
||||||
self._ezpickle_args = args
|
self._ezpickle_args = args
|
||||||
self._ezpickle_kwargs = kwargs
|
self._ezpickle_kwargs = kwargs
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
|
"""Returns the object pickle state with args and kwargs."""
|
||||||
return {
|
return {
|
||||||
"_ezpickle_args": self._ezpickle_args,
|
"_ezpickle_args": self._ezpickle_args,
|
||||||
"_ezpickle_kwargs": self._ezpickle_kwargs,
|
"_ezpickle_kwargs": self._ezpickle_kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __setstate__(self, d):
|
def __setstate__(self, d):
|
||||||
|
"""Sets the object pickle state using d."""
|
||||||
out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
|
out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
|
||||||
self.__dict__.update(out.__dict__)
|
self.__dict__.update(out.__dict__)
|
||||||
|
@@ -1,11 +1,18 @@
|
|||||||
|
"""Utilities of visualising an environment."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
from typing import Callable, Dict, Optional, Tuple, Union
|
from typing import Callable, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pygame
|
import pygame
|
||||||
from numpy.typing import NDArray
|
|
||||||
from pygame import Surface
|
from pygame import Surface
|
||||||
from pygame.event import Event
|
from pygame.event import Event
|
||||||
|
from pygame.locals import VIDEORESIZE
|
||||||
|
|
||||||
from gym import Env, logger
|
from gym import Env, logger
|
||||||
|
from gym.core import ActType, ObsType
|
||||||
|
from gym.error import DependencyNotInstalled
|
||||||
from gym.logger import deprecation
|
from gym.logger import deprecation
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -13,30 +20,31 @@ try:
|
|||||||
|
|
||||||
matplotlib.use("TkAgg")
|
matplotlib.use("TkAgg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
logger.warn(f"failed to set matplotlib backend, plotting will not work: {str(e)}")
|
logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
|
||||||
plt = None
|
matplotlib, plt = None, None
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from pygame.locals import VIDEORESIZE
|
|
||||||
|
|
||||||
from gym.core import ActType
|
|
||||||
|
|
||||||
|
|
||||||
class MissingKeysToAction(Exception):
|
class MissingKeysToAction(Exception):
|
||||||
"""Raised when the environment does not have
|
"""Raised when the environment does not have a default ``keys_to_action`` mapping."""
|
||||||
a default keys_to_action mapping
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class PlayableGame:
|
class PlayableGame:
|
||||||
|
"""Wraps an environment allowing keyboard inputs to interact with the environment."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: Env,
|
env: Env,
|
||||||
keys_to_action: Optional[Dict[Tuple[int], int]] = None,
|
keys_to_action: Optional[dict[tuple[int], int]] = None,
|
||||||
zoom: Optional[float] = None,
|
zoom: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to play
|
||||||
|
keys_to_action: The dictionary of keyboard tuples and action value
|
||||||
|
zoom: If to zoom in on the environment render
|
||||||
|
"""
|
||||||
self.env = env
|
self.env = env
|
||||||
self.relevant_keys = self._get_relevant_keys(keys_to_action)
|
self.relevant_keys = self._get_relevant_keys(keys_to_action)
|
||||||
self.video_size = self._get_video_size(zoom)
|
self.video_size = self._get_video_size(zoom)
|
||||||
@@ -45,7 +53,7 @@ class PlayableGame:
|
|||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
def _get_relevant_keys(
|
def _get_relevant_keys(
|
||||||
self, keys_to_action: Optional[Dict[Tuple[int], int]] = None
|
self, keys_to_action: Optional[dict[tuple[int], int]] = None
|
||||||
) -> set:
|
) -> set:
|
||||||
if keys_to_action is None:
|
if keys_to_action is None:
|
||||||
if hasattr(self.env, "get_keys_to_action"):
|
if hasattr(self.env, "get_keys_to_action"):
|
||||||
@@ -60,7 +68,7 @@ class PlayableGame:
|
|||||||
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
||||||
return relevant_keys
|
return relevant_keys
|
||||||
|
|
||||||
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
def _get_video_size(self, zoom: Optional[float] = None) -> tuple[int, int]:
|
||||||
# TODO: this needs to be updated when the render API change goes through
|
# TODO: this needs to be updated when the render API change goes through
|
||||||
rendered = self.env.render(mode="rgb_array")
|
rendered = self.env.render(mode="rgb_array")
|
||||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||||
@@ -70,7 +78,14 @@ class PlayableGame:
|
|||||||
|
|
||||||
return video_size
|
return video_size
|
||||||
|
|
||||||
def process_event(self, event: Event) -> None:
|
def process_event(self, event: Event):
|
||||||
|
"""Processes a PyGame event.
|
||||||
|
|
||||||
|
In particular, this function is used to keep track of which buttons are currently pressed and to exit the :func:`play` function when the PyGame window is closed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event to process
|
||||||
|
"""
|
||||||
if event.type == pygame.KEYDOWN:
|
if event.type == pygame.KEYDOWN:
|
||||||
if event.key in self.relevant_keys:
|
if event.key in self.relevant_keys:
|
||||||
self.pressed_keys.append(event.key)
|
self.pressed_keys.append(event.key)
|
||||||
@@ -87,9 +102,17 @@ class PlayableGame:
|
|||||||
|
|
||||||
|
|
||||||
def display_arr(
|
def display_arr(
|
||||||
screen: Surface, arr: NDArray, video_size: Tuple[int, int], transpose: bool
|
screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
|
||||||
):
|
):
|
||||||
arr_min, arr_max = arr.min(), arr.max()
|
"""Displays a numpy array on screen.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
screen: The screen to show the array on
|
||||||
|
arr: The array to show
|
||||||
|
video_size: The video size of the screen
|
||||||
|
transpose: If to transpose the array on the screen
|
||||||
|
"""
|
||||||
|
arr_min, arr_max = np.min(arr), np.max(arr)
|
||||||
arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
|
arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
|
||||||
pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
|
pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
|
||||||
pyg_img = pygame.transform.scale(pyg_img, video_size)
|
pyg_img = pygame.transform.scale(pyg_img, video_size)
|
||||||
@@ -108,60 +131,74 @@ def play(
|
|||||||
):
|
):
|
||||||
"""Allows one to play the game using keyboard.
|
"""Allows one to play the game using keyboard.
|
||||||
|
|
||||||
To simply play the game use:
|
Example::
|
||||||
|
|
||||||
play(gym.make("Pong-v4"))
|
>>> import gym
|
||||||
|
>>> from gym.utils.play import play
|
||||||
|
>>> play(gym.make("CarRacing-v1"), keys_to_action={"w": np.array([0, 0.7, 0]),
|
||||||
|
... "a": np.array([-1, 0, 0]),
|
||||||
|
... "s": np.array([0, 0, 1]),
|
||||||
|
... "d": np.array([1, 0, 0]),
|
||||||
|
... "wa": np.array([-1, 0.7, 0]),
|
||||||
|
... "dw": np.array([1, 0.7, 0]),
|
||||||
|
... "ds": np.array([1, 0, 1]),
|
||||||
|
... "as": np.array([-1, 0, 1]),
|
||||||
|
... }, noop=np.array([0,0,0]))
|
||||||
|
|
||||||
Above code works also if env is wrapped, so it's particularly useful in
|
|
||||||
|
Above code works also if the environment is wrapped, so it's particularly useful in
|
||||||
verifying that the frame-level preprocessing does not render the game
|
verifying that the frame-level preprocessing does not render the game
|
||||||
unplayable.
|
unplayable.
|
||||||
|
|
||||||
If you wish to plot real time statistics as you play, you can use
|
If you wish to plot real time statistics as you play, you can use
|
||||||
gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
|
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
|
||||||
for last 5 second of gameplay.
|
for last 150 steps.
|
||||||
|
|
||||||
def callback(obs_t, obs_tp1, action, rew, done, info):
|
>>> def callback(obs_t, obs_tp1, action, rew, done, info):
|
||||||
return [rew,]
|
... return [rew,]
|
||||||
plotter = PlayPlot(callback, 30 * 5, ["reward"])
|
>>> plotter = PlayPlot(callback, 150, ["reward"])
|
||||||
|
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
|
||||||
env = gym.make("Pong-v4")
|
|
||||||
play(env, callback=plotter.callback)
|
|
||||||
|
|
||||||
|
|
||||||
Arguments
|
Args:
|
||||||
---------
|
env: Environment to use for playing.
|
||||||
env: gym.Env
|
transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
|
||||||
Environment to use for playing.
|
fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
|
||||||
transpose: bool
|
``env.metadata["render_fps""]`` (or 30, if the environment does not specify "render_fps") is used.
|
||||||
If True the output of observation is transposed.
|
zoom: Zoom the observation in, ``zoom`` amount, should be positive float
|
||||||
Defaults to true.
|
callback: If a callback is provided, it will be executed after every step. It takes the following input:
|
||||||
fps: int
|
|
||||||
Maximum number of steps of the environment to execute every second.
|
|
||||||
Defaults to 30.
|
|
||||||
zoom: float
|
|
||||||
Make screen edge this many times bigger
|
|
||||||
callback: lambda or None
|
|
||||||
Callback if a callback is provided it will be executed after
|
|
||||||
every step. It takes the following input:
|
|
||||||
obs_t: observation before performing action
|
obs_t: observation before performing action
|
||||||
obs_tp1: observation after performing action
|
obs_tp1: observation after performing action
|
||||||
action: action that was executed
|
action: action that was executed
|
||||||
rew: reward that was received
|
rew: reward that was received
|
||||||
done: whether the environment is done or not
|
done: whether the environment is done or not
|
||||||
info: debug info
|
info: debug info
|
||||||
keys_to_action: dict: tuple(int) -> int or None
|
keys_to_action: Mapping from keys pressed to action performed.
|
||||||
Mapping from keys pressed to action performed.
|
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
|
||||||
For example if pressed 'w' and space at the same time is supposed
|
points of the keys, as a tuple of characters, or as a string where each character of the string represents
|
||||||
to trigger action number 2 then key_to_action dict would look like this:
|
one key.
|
||||||
|
For example if pressing 'w' and space at the same time is supposed
|
||||||
{
|
to trigger action number 2 then ``key_to_action`` dict could look like this:
|
||||||
# ...
|
>>> {
|
||||||
sorted(ord('w'), ord(' ')) -> 2
|
... # ...
|
||||||
# ...
|
... (ord('w'), ord(' ')): 2
|
||||||
}
|
... # ...
|
||||||
If None, default key_to_action mapping for that env is used, if provided.
|
... }
|
||||||
seed: bool or None
|
or like this:
|
||||||
Random seed used when resetting the environment. If None, no seed is used.
|
>>> {
|
||||||
|
... # ...
|
||||||
|
... ("w", " "): 2
|
||||||
|
... # ...
|
||||||
|
... }
|
||||||
|
or like this:
|
||||||
|
>>> {
|
||||||
|
... # ...
|
||||||
|
... "w ": 2
|
||||||
|
... # ...
|
||||||
|
... }
|
||||||
|
If ``None``, default ``key_to_action`` mapping for that environment is used, if provided.
|
||||||
|
seed: Random seed used when resetting the environment. If None, no seed is used.
|
||||||
|
noop: The action used when no key input has been entered, or the entered key combination is unknown.
|
||||||
"""
|
"""
|
||||||
env.reset(seed=seed)
|
env.reset(seed=seed)
|
||||||
|
|
||||||
@@ -208,7 +245,44 @@ def play(
|
|||||||
|
|
||||||
|
|
||||||
class PlayPlot:
|
class PlayPlot:
|
||||||
def __init__(self, callback, horizon_timesteps, plot_names):
|
"""Provides a callback to create live plots of arbitrary metrics when using :func:`play`.
|
||||||
|
|
||||||
|
This class is instantiated with a function that accepts information about a single environment transition:
|
||||||
|
- obs_t: observation before performing action
|
||||||
|
- obs_tp1: observation after performing action
|
||||||
|
- action: action that was executed
|
||||||
|
- rew: reward that was received
|
||||||
|
- done: whether the environment is done or not
|
||||||
|
- info: debug info
|
||||||
|
|
||||||
|
It should return a list of metrics that are computed from this data.
|
||||||
|
For instance, the function may look like this::
|
||||||
|
|
||||||
|
def compute_metrics(obs_t, obs_tp, action, reward, done, info):
|
||||||
|
return [reward, info["cumulative_reward"], np.linalg.norm(action)]
|
||||||
|
|
||||||
|
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
|
||||||
|
and uses the returned values to update live plots of the metrics.
|
||||||
|
|
||||||
|
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
|
||||||
|
|
||||||
|
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
|
||||||
|
>>> play(your_env, callback=plotter.callback)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, callback: callable, horizon_timesteps: int, plot_names: list[str]
|
||||||
|
):
|
||||||
|
"""Constructor of :class:`PlayPlot`.
|
||||||
|
|
||||||
|
The function ``callback`` that is passed to this constructor should return
|
||||||
|
a list of metrics that is of length ``len(plot_names)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Function that computes metrics from environment transitions
|
||||||
|
horizon_timesteps: The time horizon used for the live plots
|
||||||
|
plot_names: List of plot titles
|
||||||
|
"""
|
||||||
deprecation(
|
deprecation(
|
||||||
"`PlayPlot` is marked as deprecated and will be removed in the near future."
|
"`PlayPlot` is marked as deprecated and will be removed in the near future."
|
||||||
)
|
)
|
||||||
@@ -216,7 +290,10 @@ class PlayPlot:
|
|||||||
self.horizon_timesteps = horizon_timesteps
|
self.horizon_timesteps = horizon_timesteps
|
||||||
self.plot_names = plot_names
|
self.plot_names = plot_names
|
||||||
|
|
||||||
assert plt is not None, "matplotlib backend failed, plotting will not work"
|
if plt is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"matplotlib is not installed, run `pip install gym[other]`"
|
||||||
|
)
|
||||||
|
|
||||||
num_plots = len(self.plot_names)
|
num_plots = len(self.plot_names)
|
||||||
self.fig, self.ax = plt.subplots(num_plots)
|
self.fig, self.ax = plt.subplots(num_plots)
|
||||||
@@ -228,7 +305,25 @@ class PlayPlot:
|
|||||||
self.cur_plot = [None for _ in range(num_plots)]
|
self.cur_plot = [None for _ in range(num_plots)]
|
||||||
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
||||||
|
|
||||||
def callback(self, obs_t, obs_tp1, action, rew, done, info):
|
def callback(
|
||||||
|
self,
|
||||||
|
obs_t: ObsType,
|
||||||
|
obs_tp1: ObsType,
|
||||||
|
action: ActType,
|
||||||
|
rew: float,
|
||||||
|
done: bool,
|
||||||
|
info: dict,
|
||||||
|
):
|
||||||
|
"""The callback that calls the provided data callback and adds the data to the plots.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs_t: The observation at time step t
|
||||||
|
obs_tp1: The observation at time step t+1
|
||||||
|
action: The action
|
||||||
|
rew: The reward
|
||||||
|
done: If the environment is done
|
||||||
|
info: The information from the environment
|
||||||
|
"""
|
||||||
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
|
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
|
||||||
for point, data_series in zip(points, self.data):
|
for point, data_series in zip(points, self.data):
|
||||||
data_series.append(point)
|
data_series.append(point)
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
|
"""Set of random number generator functions: seeding, generator, hashing seeds."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -9,7 +12,15 @@ from gym import error
|
|||||||
from gym.logger import deprecation
|
from gym.logger import deprecation
|
||||||
|
|
||||||
|
|
||||||
def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]:
|
def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]:
|
||||||
|
"""Generates a random number generator from the seed and returns the Generator and seed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed used to create the generator
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The generator and resulting seed
|
||||||
|
"""
|
||||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||||
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
|
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
|
||||||
|
|
||||||
@@ -22,7 +33,10 @@ def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]
|
|||||||
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
|
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
|
||||||
# RandomNumberGenerator = np.random.Generator
|
# RandomNumberGenerator = np.random.Generator
|
||||||
class RandomNumberGenerator(np.random.Generator):
|
class RandomNumberGenerator(np.random.Generator):
|
||||||
|
"""Random number generator class that inherits from numpy's random Generator class."""
|
||||||
|
|
||||||
def rand(self, *size):
|
def rand(self, *size):
|
||||||
|
"""Deprecated rand function using random."""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `rng.rand(*size)` is marked as deprecated "
|
"Function `rng.rand(*size)` is marked as deprecated "
|
||||||
"and will be removed in the future. "
|
"and will be removed in the future. "
|
||||||
@@ -34,6 +48,7 @@ class RandomNumberGenerator(np.random.Generator):
|
|||||||
random_sample = rand
|
random_sample = rand
|
||||||
|
|
||||||
def randn(self, *size):
|
def randn(self, *size):
|
||||||
|
"""Deprecated random standard normal function use standard_normal."""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `rng.randn(*size)` is marked as deprecated "
|
"Function `rng.randn(*size)` is marked as deprecated "
|
||||||
"and will be removed in the future. "
|
"and will be removed in the future. "
|
||||||
@@ -43,6 +58,7 @@ class RandomNumberGenerator(np.random.Generator):
|
|||||||
return self.standard_normal(size)
|
return self.standard_normal(size)
|
||||||
|
|
||||||
def randint(self, low, high=None, size=None, dtype=int):
|
def randint(self, low, high=None, size=None, dtype=int):
|
||||||
|
"""Deprecated random integer function use integers."""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
|
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
|
||||||
"and will be removed in the future. "
|
"and will be removed in the future. "
|
||||||
@@ -54,6 +70,7 @@ class RandomNumberGenerator(np.random.Generator):
|
|||||||
random_integers = randint
|
random_integers = randint
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
|
"""Deprecated get rng state use bit_generator.state."""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `rng.get_state()` is marked as deprecated "
|
"Function `rng.get_state()` is marked as deprecated "
|
||||||
"and will be removed in the future. "
|
"and will be removed in the future. "
|
||||||
@@ -63,6 +80,7 @@ class RandomNumberGenerator(np.random.Generator):
|
|||||||
return self.bit_generator.state
|
return self.bit_generator.state
|
||||||
|
|
||||||
def set_state(self, state):
|
def set_state(self, state):
|
||||||
|
"""Deprecated set rng state function use bit_generator.state = state."""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `rng.set_state(state)` is marked as deprecated "
|
"Function `rng.set_state(state)` is marked as deprecated "
|
||||||
"and will be removed in the future. "
|
"and will be removed in the future. "
|
||||||
@@ -72,6 +90,7 @@ class RandomNumberGenerator(np.random.Generator):
|
|||||||
self.bit_generator.state = state
|
self.bit_generator.state = state
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
|
"""Deprecated seed function use gym.utils.seeding.np_random(seed)."""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `rng.seed(seed)` is marked as deprecated "
|
"Function `rng.seed(seed)` is marked as deprecated "
|
||||||
"and will be removed in the future. "
|
"and will be removed in the future. "
|
||||||
@@ -88,6 +107,7 @@ class RandomNumberGenerator(np.random.Generator):
|
|||||||
seed.__doc__ = np.random.seed.__doc__
|
seed.__doc__ = np.random.seed.__doc__
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
|
"""Reduces the Random Number Generator to a RandomNumberGenerator, init_args and additional args."""
|
||||||
# np.random.Generator defines __reduce__, but it's hard-coded to
|
# np.random.Generator defines __reduce__, but it's hard-coded to
|
||||||
# return a Generator instead of its subclass RandomNumberGenerator.
|
# return a Generator instead of its subclass RandomNumberGenerator.
|
||||||
# We need to override it here, otherwise sampling from a Space will
|
# We need to override it here, otherwise sampling from a Space will
|
||||||
@@ -119,20 +139,21 @@ RNG = RandomNumberGenerator
|
|||||||
|
|
||||||
|
|
||||||
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
|
def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
|
||||||
"""Any given evaluation is likely to have many PRNG's active at
|
"""Any given evaluation is likely to have many PRNG's active at once.
|
||||||
once. (Most commonly, because the environment is running in
|
|
||||||
multiple processes.) There's literature indicating that having
|
(Most commonly, because the environment is running in multiple processes.)
|
||||||
linear correlations between seeds of multiple PRNG's can correlate
|
There's literature indicating that having linear correlations between seeds of multiple PRNG's can correlate the outputs:
|
||||||
the outputs:
|
|
||||||
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
|
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
|
||||||
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
|
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
|
||||||
http://dl.acm.org/citation.cfm?id=1276928
|
http://dl.acm.org/citation.cfm?id=1276928
|
||||||
Thus, for sanity we hash the seeds before using them. (This scheme
|
Thus, for sanity we hash the seeds before using them. (This scheme is likely not crypto-strength, but it should be good enough to get rid of simple correlations.)
|
||||||
is likely not crypto-strength, but it should be good enough to get
|
|
||||||
rid of simple correlations.)
|
|
||||||
Args:
|
Args:
|
||||||
seed: None seeds from an operating system specific randomness source.
|
seed: None seeds from an operating system specific randomness source.
|
||||||
max_bytes: Maximum number of bytes to use in the hashed seed.
|
max_bytes: Maximum number of bytes to use in the hashed seed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The hashed seed
|
||||||
"""
|
"""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
|
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
|
||||||
@@ -144,12 +165,16 @@ def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
|
def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
|
||||||
"""Create a strong random seed. Otherwise, Python 2 would seed using
|
"""Create a strong random seed.
|
||||||
the system time, which might be non-robust especially in the
|
|
||||||
presence of concurrency.
|
Otherwise, Python 2 would seed using the system time, which might be non-robust especially in the presence of concurrency.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: None seeds from an operating system specific randomness source.
|
a: None seeds from an operating system specific randomness source.
|
||||||
max_bytes: Maximum number of bytes to use in the seed.
|
max_bytes: Maximum number of bytes to use in the seed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A seed
|
||||||
"""
|
"""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
|
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
|
||||||
@@ -185,7 +210,7 @@ def _bigint_from_bytes(bt: bytes) -> int:
|
|||||||
return accum
|
return accum
|
||||||
|
|
||||||
|
|
||||||
def _int_list_from_bigint(bigint: int) -> List[int]:
|
def _int_list_from_bigint(bigint: int) -> list[int]:
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
|
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
|
||||||
)
|
)
|
||||||
@@ -195,7 +220,7 @@ def _int_list_from_bigint(bigint: int) -> List[int]:
|
|||||||
elif bigint == 0:
|
elif bigint == 0:
|
||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
ints: List[int] = []
|
ints: list[int] = []
|
||||||
while bigint > 0:
|
while bigint > 0:
|
||||||
bigint, mod = divmod(bigint, 2**32)
|
bigint, mod = divmod(bigint, 2**32)
|
||||||
ints.append(mod)
|
ints.append(mod)
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
try:
|
"""Module for vector environments."""
|
||||||
from collections.abc import Iterable
|
from __future__ import annotations
|
||||||
except ImportError:
|
|
||||||
Iterable = (tuple, list)
|
from typing import Iterable, Optional, Union
|
||||||
|
|
||||||
from gym.vector.async_vector_env import AsyncVectorEnv
|
from gym.vector.async_vector_env import AsyncVectorEnv
|
||||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||||
@@ -10,40 +10,34 @@ from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
|
|||||||
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
|
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
|
||||||
|
|
||||||
|
|
||||||
def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
|
def make(
|
||||||
"""Create a vectorized environment from multiple copies of an environment,
|
id: str,
|
||||||
from its id.
|
num_envs: int = 1,
|
||||||
|
asynchronous: bool = True,
|
||||||
|
wrappers: Optional[Union[callable, list[callable]]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> VectorEnv:
|
||||||
|
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||||
|
|
||||||
Parameters
|
Example::
|
||||||
----------
|
|
||||||
id : str
|
|
||||||
The environment ID. This must be a valid ID from the registry.
|
|
||||||
|
|
||||||
num_envs : int
|
>>> import gym
|
||||||
Number of copies of the environment.
|
|
||||||
|
|
||||||
asynchronous : bool
|
|
||||||
If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses
|
|
||||||
`multiprocessing`_ to run the environments in parallel). If ``False``,
|
|
||||||
wraps the environments in a :class:`SyncVectorEnv`.
|
|
||||||
|
|
||||||
wrappers : callable, or iterable of callables, optional
|
|
||||||
If not ``None``, then apply the wrappers to each internal
|
|
||||||
environment during creation.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
:class:`gym.vector.VectorEnv`
|
|
||||||
The vectorized environment.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
|
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
|
||||||
>>> env.reset()
|
>>> env.reset()
|
||||||
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
|
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
|
||||||
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
|
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
|
||||||
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
|
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
|
||||||
dtype=float32)
|
dtype=float32)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: The environment ID. This must be a valid ID from the registry.
|
||||||
|
num_envs: Number of copies of the environment.
|
||||||
|
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
|
||||||
|
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
|
||||||
|
**kwargs: Keywords arguments applied during gym.make
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The vectorized environment.
|
||||||
"""
|
"""
|
||||||
from gym.envs import make as make_
|
from gym.envs import make as make_
|
||||||
|
|
||||||
|
@@ -1,13 +1,18 @@
|
|||||||
|
"""An async vector environment."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Union
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import gym
|
||||||
from gym import logger
|
from gym import logger
|
||||||
|
from gym.core import ObsType
|
||||||
from gym.error import (
|
from gym.error import (
|
||||||
AlreadyPendingCallError,
|
AlreadyPendingCallError,
|
||||||
ClosedEnvironmentError,
|
ClosedEnvironmentError,
|
||||||
@@ -37,69 +42,13 @@ class AsyncState(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class AsyncVectorEnv(VectorEnv):
|
class AsyncVectorEnv(VectorEnv):
|
||||||
"""Vectorized environment that runs multiple environments in parallel. It
|
"""Vectorized environment that runs multiple environments in parallel.
|
||||||
uses `multiprocessing`_ processes, and pipes for communication.
|
|
||||||
|
|
||||||
Parameters
|
It uses ``multiprocessing`` processes, and pipes for communication.
|
||||||
----------
|
|
||||||
env_fns : iterable of callable
|
|
||||||
Functions that create the environments.
|
|
||||||
|
|
||||||
observation_space : :class:`gym.spaces.Space`, optional
|
Example::
|
||||||
Observation space of a single environment. If ``None``, then the
|
|
||||||
observation space of the first environment is taken.
|
|
||||||
|
|
||||||
action_space : :class:`gym.spaces.Space`, optional
|
|
||||||
Action space of a single environment. If ``None``, then the action space
|
|
||||||
of the first environment is taken.
|
|
||||||
|
|
||||||
shared_memory : bool
|
|
||||||
If ``True``, then the observations from the worker processes are
|
|
||||||
communicated back through shared variables. This can improve the
|
|
||||||
efficiency if the observations are large (e.g. images).
|
|
||||||
|
|
||||||
copy : bool
|
|
||||||
If ``True``, then the :meth:`~AsyncVectorEnv.reset` and
|
|
||||||
:meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
|
|
||||||
|
|
||||||
context : str, optional
|
|
||||||
Context for `multiprocessing`_. If ``None``, then the default context is used.
|
|
||||||
|
|
||||||
daemon : bool
|
|
||||||
If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they
|
|
||||||
will quit if the head process quits. However, ``daemon=True`` prevents
|
|
||||||
subprocesses to spawn children, so for some environments you may want
|
|
||||||
to have it set to ``False``.
|
|
||||||
|
|
||||||
worker : callable, optional
|
|
||||||
If set, then use that worker in a subprocess instead of a default one.
|
|
||||||
Can be useful to override some inner vector env logic, for instance,
|
|
||||||
how resets on done are handled.
|
|
||||||
|
|
||||||
Warning
|
|
||||||
-------
|
|
||||||
:attr:`worker` is an advanced mode option. It provides a high degree of
|
|
||||||
flexibility and a high chance to shoot yourself in the foot; thus,
|
|
||||||
if you are writing your own worker, it is recommended to start from the code
|
|
||||||
for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
RuntimeError
|
|
||||||
If the observation space of some sub-environment does not match
|
|
||||||
:obj:`observation_space` (or, by default, the observation space of
|
|
||||||
the first sub-environment).
|
|
||||||
|
|
||||||
ValueError
|
|
||||||
If :obj:`observation_space` is a custom space (i.e. not a default
|
|
||||||
space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`,
|
|
||||||
or :class:`~gym.spaces.Dict`) and :obj:`shared_memory` is ``True``.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
|
>>> import gym
|
||||||
>>> env = gym.vector.AsyncVectorEnv([
|
>>> env = gym.vector.AsyncVectorEnv([
|
||||||
... lambda: gym.make("Pendulum-v0", g=9.81),
|
... lambda: gym.make("Pendulum-v0", g=9.81),
|
||||||
... lambda: gym.make("Pendulum-v0", g=1.62)
|
... lambda: gym.make("Pendulum-v0", g=1.62)
|
||||||
@@ -111,15 +60,33 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env_fns,
|
env_fns: Sequence[callable],
|
||||||
observation_space=None,
|
observation_space: Optional[gym.Space] = None,
|
||||||
action_space=None,
|
action_space: Optional[gym.Space] = None,
|
||||||
shared_memory=True,
|
shared_memory: bool = True,
|
||||||
copy=True,
|
copy: bool = True,
|
||||||
context=None,
|
context: Optional[str] = None,
|
||||||
daemon=True,
|
daemon: bool = True,
|
||||||
worker=None,
|
worker: Optional[callable] = None,
|
||||||
):
|
):
|
||||||
|
"""Vectorized environment that runs multiple environments in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_fns: Functions that create the environments.
|
||||||
|
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
|
||||||
|
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
|
||||||
|
shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images).
|
||||||
|
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
|
||||||
|
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
|
||||||
|
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children, so for some environments you may want to have it set to ``False``.
|
||||||
|
worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
|
||||||
|
|
||||||
|
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
|
||||||
|
ValueError: If observation_space is a custom space (i.e. not a default space in Gym, such as gym.spaces.Box, gym.spaces.Discrete, or gym.spaces.Dict) and shared_memory is True.
|
||||||
|
"""
|
||||||
ctx = mp.get_context(context)
|
ctx = mp.get_context(context)
|
||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
self.shared_memory = shared_memory
|
self.shared_memory = shared_memory
|
||||||
@@ -192,6 +159,11 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._check_spaces()
|
self._check_spaces()
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
|
"""Seeds the vector environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seeds use with the environments
|
||||||
|
"""
|
||||||
super().seed(seed=seed)
|
super().seed(seed=seed)
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if seed is None:
|
if seed is None:
|
||||||
@@ -213,22 +185,24 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, list[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Send the calls to :obj:`reset` to each sub-environment.
|
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
||||||
|
|
||||||
Raises
|
To get the results of these calls, you may invoke :meth:`reset_wait`.
|
||||||
------
|
|
||||||
ClosedEnvironmentError
|
|
||||||
If the environment was closed (if :meth:`close` was previously called).
|
|
||||||
|
|
||||||
AlreadyPendingCallError
|
Args:
|
||||||
If the environment is already waiting for a pending call to another
|
seed: List of seeds for each environment
|
||||||
|
return_info: If to return information
|
||||||
|
options: The reset option
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||||
|
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
|
||||||
method (e.g. :meth:`step_async`). This can be caused by two consecutive
|
method (e.g. :meth:`step_async`). This can be caused by two consecutive
|
||||||
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in
|
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
|
||||||
between.
|
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
|
|
||||||
@@ -258,37 +232,26 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
timeout=None,
|
timeout: Optional[Union[int, float]] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
) -> Union[ObsType, tuple[ObsType, list[dict]]]:
|
||||||
"""
|
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||||
Parameters
|
|
||||||
----------
|
Args:
|
||||||
timeout : int or float, optional
|
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
|
||||||
Number of seconds before the call to `reset_wait` times out. If
|
|
||||||
`None`, the call to `reset_wait` never times out.
|
|
||||||
seed: ignored
|
seed: ignored
|
||||||
|
return_info: If to return information
|
||||||
options: ignored
|
options: ignored
|
||||||
|
|
||||||
Returns
|
Returns:
|
||||||
-------
|
A tuple of batched observations and list of dictionaries
|
||||||
element of :attr:`~VectorEnv.observation_space`
|
|
||||||
A batch of observations from the vectorized environment.
|
|
||||||
infos : list of dicts containing metadata
|
|
||||||
|
|
||||||
Raises
|
Raises:
|
||||||
------
|
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||||
ClosedEnvironmentError
|
NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
|
||||||
If the environment was closed (if :meth:`close` was previously called).
|
TimeoutError: If :meth:`reset_wait` timed out.
|
||||||
|
|
||||||
NoAsyncCallError
|
|
||||||
If :meth:`reset_wait` was called without any prior call to
|
|
||||||
:meth:`reset_async`.
|
|
||||||
|
|
||||||
TimeoutError
|
|
||||||
If :meth:`reset_wait` timed out.
|
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.WAITING_RESET:
|
if self._state != AsyncState.WAITING_RESET:
|
||||||
@@ -327,21 +290,15 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
return deepcopy(self.observations) if self.copy else self.observations
|
return deepcopy(self.observations) if self.copy else self.observations
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions: np.ndarray):
|
||||||
"""Send the calls to :obj:`step` to each sub-environment.
|
"""Send the calls to :obj:`step` to each sub-environment.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
actions: Batch of actions. element of :attr:`~VectorEnv.action_space`
|
||||||
actions : element of :attr:`~VectorEnv.action_space`
|
|
||||||
Batch of actions.
|
|
||||||
|
|
||||||
Raises
|
Raises:
|
||||||
------
|
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||||
ClosedEnvironmentError
|
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
|
||||||
If the environment was closed (if :meth:`close` was previously called).
|
|
||||||
|
|
||||||
AlreadyPendingCallError
|
|
||||||
If the environment is already waiting for a pending call to another
|
|
||||||
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
|
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
|
||||||
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
|
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
|
||||||
between.
|
between.
|
||||||
@@ -358,40 +315,21 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
pipe.send(("step", action))
|
pipe.send(("step", action))
|
||||||
self._state = AsyncState.WAITING_STEP
|
self._state = AsyncState.WAITING_STEP
|
||||||
|
|
||||||
def step_wait(self, timeout=None):
|
def step_wait(
|
||||||
|
self, timeout: Optional[Union[int, float]] = None
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[dict]]:
|
||||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
|
||||||
timeout : int or float, optional
|
|
||||||
Number of seconds before the call to :meth:`step_wait` times out. If
|
|
||||||
``None``, the call to :meth:`step_wait` never times out.
|
|
||||||
|
|
||||||
Returns
|
Returns:
|
||||||
-------
|
The batched environment step information, obs, reward, done and info
|
||||||
observations : element of :attr:`~VectorEnv.observation_space`
|
|
||||||
A batch of observations from the vectorized environment.
|
|
||||||
|
|
||||||
rewards : :obj:`np.ndarray`, dtype :obj:`np.float_`
|
Raises:
|
||||||
A vector of rewards from the vectorized environment.
|
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||||
|
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
|
||||||
dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
|
TimeoutError: If :meth:`step_wait` timed out.
|
||||||
A vector whose entries indicate whether the episode has ended.
|
|
||||||
|
|
||||||
infos : list of dict
|
|
||||||
A list of auxiliary diagnostic information dicts from sub-environments.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ClosedEnvironmentError
|
|
||||||
If the environment was closed (if :meth:`close` was previously called).
|
|
||||||
|
|
||||||
NoAsyncCallError
|
|
||||||
If :meth:`step_wait` was called without any prior call to
|
|
||||||
:meth:`step_async`.
|
|
||||||
|
|
||||||
TimeoutError
|
|
||||||
If :meth:`step_wait` timed out.
|
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.WAITING_STEP:
|
if self._state != AsyncState.WAITING_STEP:
|
||||||
@@ -425,18 +363,13 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
infos,
|
infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_async(self, name, *args, **kwargs):
|
def call_async(self, name: str, *args, **kwargs):
|
||||||
"""
|
"""Calls the method with name asynchronously and apply args and kwargs to the method.
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : string
|
|
||||||
Name of the method or property to call.
|
|
||||||
|
|
||||||
*args
|
Args:
|
||||||
Arguments to apply to the method call.
|
name: Name of the method or property to call.
|
||||||
|
*args: Arguments to apply to the method call.
|
||||||
**kwargs
|
**kwargs: Keyword arguments to apply to the method call.
|
||||||
Keywoard arguments to apply to the method call.
|
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.DEFAULT:
|
if self._state != AsyncState.DEFAULT:
|
||||||
@@ -450,19 +383,14 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
pipe.send(("_call", (name, args, kwargs)))
|
pipe.send(("_call", (name, args, kwargs)))
|
||||||
self._state = AsyncState.WAITING_CALL
|
self._state = AsyncState.WAITING_CALL
|
||||||
|
|
||||||
def call_wait(self, timeout=None):
|
def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list:
|
||||||
"""
|
"""Calls all parent pipes and waits for the results.
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
timeout : int or float, optional
|
|
||||||
Number of seconds before the call to `step_wait` times out. If
|
|
||||||
`None` (default), the call to `step_wait` never times out.
|
|
||||||
|
|
||||||
Returns
|
Args:
|
||||||
-------
|
timeout: Number of seconds before the call to `step_wait` times out. If `None` (default), the call to `step_wait` never times out.
|
||||||
results : list
|
|
||||||
List of the results of the individual calls to the method or
|
Returns:
|
||||||
property for each environment.
|
List of the results of the individual calls to the method or property for each environment.
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.WAITING_CALL:
|
if self._state != AsyncState.WAITING_CALL:
|
||||||
@@ -483,15 +411,12 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def set_attr(self, name, values):
|
def set_attr(self, name: str, values: Union[list, tuple, object]):
|
||||||
"""
|
"""Sets an attribute of the sub-environments.
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : string
|
|
||||||
Name of the property to be set in each individual environment.
|
|
||||||
|
|
||||||
values : list, tuple, or object
|
Args:
|
||||||
Values of the property to be set to. If `values` is a list or
|
name: Name of the property to be set in each individual environment.
|
||||||
|
values: Values of the property to be set to. If ``values`` is a list or
|
||||||
tuple, then it corresponds to the values for each individual
|
tuple, then it corresponds to the values for each individual
|
||||||
environment, otherwise a single value is set for all environments.
|
environment, otherwise a single value is set for all environments.
|
||||||
"""
|
"""
|
||||||
@@ -517,25 +442,19 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||||
self._raise_if_errors(successes)
|
self._raise_if_errors(successes)
|
||||||
|
|
||||||
def close_extras(self, timeout=None, terminate=False):
|
def close_extras(
|
||||||
"""Close the environments & clean up the extra resources
|
self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
|
||||||
(processes and pipes).
|
):
|
||||||
|
"""Close the environments & clean up the extra resources (processes and pipes).
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
|
||||||
timeout : int or float, optional
|
|
||||||
Number of seconds before the call to :meth:`close` times out. If ``None``,
|
|
||||||
the call to :meth:`close` never times out. If the call to :meth:`close`
|
the call to :meth:`close` never times out. If the call to :meth:`close`
|
||||||
times out, then all processes are terminated.
|
times out, then all processes are terminated.
|
||||||
|
terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.
|
||||||
|
|
||||||
terminate : bool
|
Raises:
|
||||||
If ``True``, then the :meth:`close` operation is forced and all processes
|
TimeoutError: If :meth:`close` timed out.
|
||||||
are terminated.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
TimeoutError
|
|
||||||
If :meth:`close` timed out.
|
|
||||||
"""
|
"""
|
||||||
timeout = 0 if terminate else timeout
|
timeout = 0 if terminate else timeout
|
||||||
try:
|
try:
|
||||||
@@ -626,6 +545,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
raise exctype(value)
|
raise exctype(value)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
"""On deleting the object, checks that the vector environment is closed."""
|
||||||
if not getattr(self, "closed", True) and hasattr(self, "_state"):
|
if not getattr(self, "closed", True) and hasattr(self, "_state"):
|
||||||
self.close(terminate=True)
|
self.close(terminate=True)
|
||||||
|
|
||||||
|
@@ -1,8 +1,12 @@
|
|||||||
|
"""A synchronous vector environment."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import List, Optional, Union
|
from typing import Any, Iterator, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from gym.spaces import Space
|
||||||
from gym.vector.utils import concatenate, create_empty_array, iterate
|
from gym.vector.utils import concatenate, create_empty_array, iterate
|
||||||
from gym.vector.vector_env import VectorEnv
|
from gym.vector.vector_env import VectorEnv
|
||||||
|
|
||||||
@@ -12,35 +16,9 @@ __all__ = ["SyncVectorEnv"]
|
|||||||
class SyncVectorEnv(VectorEnv):
|
class SyncVectorEnv(VectorEnv):
|
||||||
"""Vectorized environment that serially runs multiple environments.
|
"""Vectorized environment that serially runs multiple environments.
|
||||||
|
|
||||||
Parameters
|
Example::
|
||||||
----------
|
|
||||||
env_fns : iterable of callable
|
|
||||||
Functions that create the environments.
|
|
||||||
|
|
||||||
observation_space : :class:`gym.spaces.Space`, optional
|
|
||||||
Observation space of a single environment. If ``None``, then the
|
|
||||||
observation space of the first environment is taken.
|
|
||||||
|
|
||||||
action_space : :class:`gym.spaces.Space`, optional
|
|
||||||
Action space of a single environment. If ``None``, then the action space
|
|
||||||
of the first environment is taken.
|
|
||||||
|
|
||||||
copy : bool
|
|
||||||
If ``True``, then the :meth:`reset` and :meth:`step` methods return a
|
|
||||||
copy of the observations.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
RuntimeError
|
|
||||||
If the observation space of some sub-environment does not match
|
|
||||||
:obj:`observation_space` (or, by default, the observation space of
|
|
||||||
the first sub-environment).
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
|
>>> import gym
|
||||||
>>> env = gym.vector.SyncVectorEnv([
|
>>> env = gym.vector.SyncVectorEnv([
|
||||||
... lambda: gym.make("Pendulum-v0", g=9.81),
|
... lambda: gym.make("Pendulum-v0", g=9.81),
|
||||||
... lambda: gym.make("Pendulum-v0", g=1.62)
|
... lambda: gym.make("Pendulum-v0", g=1.62)
|
||||||
@@ -50,7 +28,24 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32)
|
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env_fns, observation_space=None, action_space=None, copy=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_fns: Iterator[callable],
|
||||||
|
observation_space: Space = None,
|
||||||
|
action_space: Space = None,
|
||||||
|
copy: bool = True,
|
||||||
|
):
|
||||||
|
"""Vectorized environment that serially runs multiple environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_fns: iterable of callable functions that create the environments.
|
||||||
|
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
|
||||||
|
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
|
||||||
|
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
|
||||||
|
"""
|
||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
self.envs = [env_fn() for env_fn in env_fns]
|
self.envs = [env_fn() for env_fn in env_fns]
|
||||||
self.copy = copy
|
self.copy = copy
|
||||||
@@ -60,7 +55,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
observation_space = observation_space or self.envs[0].observation_space
|
observation_space = observation_space or self.envs[0].observation_space
|
||||||
action_space = action_space or self.envs[0].action_space
|
action_space = action_space or self.envs[0].action_space
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_envs=len(env_fns),
|
num_envs=len(self.envs),
|
||||||
observation_space=observation_space,
|
observation_space=observation_space,
|
||||||
action_space=action_space,
|
action_space=action_space,
|
||||||
)
|
)
|
||||||
@@ -73,7 +68,12 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||||
self._actions = None
|
self._actions = None
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
|
||||||
|
"""Sets the seed in all sub-environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed
|
||||||
|
"""
|
||||||
super().seed(seed=seed)
|
super().seed(seed=seed)
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = [None for _ in range(self.num_envs)]
|
seed = [None for _ in range(self.num_envs)]
|
||||||
@@ -86,10 +86,20 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, list[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The reset environment seed
|
||||||
|
return_info: If to return information
|
||||||
|
options: Option information for the environment reset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The reset observation of the environment and reset information
|
||||||
|
"""
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = [None for _ in range(self.num_envs)]
|
seed = [None for _ in range(self.num_envs)]
|
||||||
if isinstance(seed, int):
|
if isinstance(seed, int):
|
||||||
@@ -128,9 +138,15 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
), data_list
|
), data_list
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
|
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
||||||
self._actions = iterate(self.action_space, actions)
|
self._actions = iterate(self.action_space, actions)
|
||||||
|
|
||||||
def step_wait(self):
|
def step_wait(self):
|
||||||
|
"""Steps through each of the environments returning the batched results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The batched environment step results
|
||||||
|
"""
|
||||||
observations, infos = [], []
|
observations, infos = [], []
|
||||||
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
|
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
|
||||||
observation, self._rewards[i], self._dones[i], info = env.step(action)
|
observation, self._rewards[i], self._dones[i], info = env.step(action)
|
||||||
@@ -150,7 +166,17 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
infos,
|
infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
def call(self, name, *args, **kwargs):
|
def call(self, name, *args, **kwargs) -> tuple:
|
||||||
|
"""Calls the method with name and applies args and kwargs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The method name
|
||||||
|
*args: The method args
|
||||||
|
**kwargs: The method kwargs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of results
|
||||||
|
"""
|
||||||
results = []
|
results = []
|
||||||
for env in self.envs:
|
for env in self.envs:
|
||||||
function = getattr(env, name)
|
function = getattr(env, name)
|
||||||
@@ -161,7 +187,15 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
return tuple(results)
|
return tuple(results)
|
||||||
|
|
||||||
def set_attr(self, name, values):
|
def set_attr(self, name: str, values: Union[list, tuple, Any]):
|
||||||
|
"""Sets an attribute of the sub-environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The property name to change
|
||||||
|
values: Values of the property to be set to. If ``values`` is a list or
|
||||||
|
tuple, then it corresponds to the values for each individual
|
||||||
|
environment, otherwise, a single value is set for all environments.
|
||||||
|
"""
|
||||||
if not isinstance(values, (list, tuple)):
|
if not isinstance(values, (list, tuple)):
|
||||||
values = [values for _ in range(self.num_envs)]
|
values = [values for _ in range(self.num_envs)]
|
||||||
if len(values) != self.num_envs:
|
if len(values) != self.num_envs:
|
||||||
@@ -178,7 +212,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
"""Close the environments."""
|
"""Close the environments."""
|
||||||
[env.close() for env in self.envs]
|
[env.close() for env in self.envs]
|
||||||
|
|
||||||
def _check_spaces(self):
|
def _check_spaces(self) -> bool:
|
||||||
for env in self.envs:
|
for env in self.envs:
|
||||||
if not (env.observation_space == self.single_observation_space):
|
if not (env.observation_space == self.single_observation_space):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -194,5 +228,4 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
"action spaces from all environments must be equal."
|
"action spaces from all environments must be equal."
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
return True
|
return True
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Module for gym vector utils."""
|
||||||
from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
|
from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
|
||||||
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||||
from gym.vector.utils.shared_memory import (
|
from gym.vector.utils.shared_memory import (
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Miscellaneous utilities."""
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -5,28 +6,35 @@ __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
|
|||||||
|
|
||||||
|
|
||||||
class CloudpickleWrapper:
|
class CloudpickleWrapper:
|
||||||
def __init__(self, fn):
|
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""
|
||||||
|
|
||||||
|
def __init__(self, fn: callable):
|
||||||
|
"""Cloudpickle wrapper for a function."""
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
|
"""Get the state using `cloudpickle.dumps(self.fn)`."""
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
|
|
||||||
return cloudpickle.dumps(self.fn)
|
return cloudpickle.dumps(self.fn)
|
||||||
|
|
||||||
def __setstate__(self, ob):
|
def __setstate__(self, ob):
|
||||||
|
"""Sets the state with obs."""
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
self.fn = pickle.loads(ob)
|
self.fn = pickle.loads(ob)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
|
"""Calls the function `self.fn` with no arguments."""
|
||||||
return self.fn()
|
return self.fn()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def clear_mpi_env_vars():
|
def clear_mpi_env_vars():
|
||||||
"""
|
"""Clears the MPI of environment variables.
|
||||||
`from mpi4py import MPI` will call `MPI_Init` by default. If the child
|
|
||||||
process has MPI environment variables, MPI will think that the child process
|
`from mpi4py import MPI` will call `MPI_Init` by default.
|
||||||
|
If the child process has MPI environment variables, MPI will think that the child process
|
||||||
is an MPI process just like the parent and do bad things such as hang.
|
is an MPI process just like the parent and do bad things such as hang.
|
||||||
|
|
||||||
This context manager is a hacky way to clear those environment variables
|
This context manager is a hacky way to clear those environment variables
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
|
"""Numpy utility functions: concatenate space samples and create empty array."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import singledispatch
|
from functools import singledispatch
|
||||||
|
from typing import Iterable, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -9,36 +11,29 @@ __all__ = ["concatenate", "create_empty_array"]
|
|||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def concatenate(space, items, out):
|
def concatenate(
|
||||||
|
space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
|
||||||
|
) -> Union[tuple, dict, np.ndarray]:
|
||||||
"""Concatenate multiple samples from space into a single object.
|
"""Concatenate multiple samples from space into a single object.
|
||||||
|
|
||||||
Parameters
|
Example::
|
||||||
----------
|
|
||||||
items : iterable of samples of `space`
|
|
||||||
Samples to be concatenated.
|
|
||||||
|
|
||||||
out : tuple, dict, or `np.ndarray`
|
|
||||||
The output object. This object is a (possibly nested) numpy array.
|
|
||||||
|
|
||||||
space : `gym.spaces.Space` instance
|
|
||||||
Observation space of a single environment in the vectorized environment.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
out : tuple, dict, or `np.ndarray`
|
|
||||||
The output object. This object is a (possibly nested) numpy array.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
>>> from gym.spaces import Box
|
>>> from gym.spaces import Box
|
||||||
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
|
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
|
||||||
>>> out = np.zeros((2, 3), dtype=np.float32)
|
>>> out = np.zeros((2, 3), dtype=np.float32)
|
||||||
>>> items = [space.sample() for _ in range(2)]
|
>>> items = [space.sample() for _ in range(2)]
|
||||||
>>> concatenate(items, out, space)
|
>>> concatenate(space, items, out)
|
||||||
array([[0.6348213 , 0.28607962, 0.60760117],
|
array([[0.6348213 , 0.28607962, 0.60760117],
|
||||||
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
|
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: Observation space of a single environment in the vectorized environment.
|
||||||
|
items: Samples to be concatenated.
|
||||||
|
out: The output object. This object is a (possibly nested) numpy array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output object. This object is a (possibly nested) numpy array.
|
||||||
"""
|
"""
|
||||||
assert isinstance(items, (list, tuple))
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
||||||
)
|
)
|
||||||
@@ -76,29 +71,13 @@ def _concatenate_custom(space, items, out):
|
|||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def create_empty_array(space, n=1, fn=np.zeros):
|
def create_empty_array(
|
||||||
|
space: Space, n: int = 1, fn: callable = np.zeros
|
||||||
|
) -> Union[tuple, dict, np.ndarray]:
|
||||||
"""Create an empty (possibly nested) numpy array.
|
"""Create an empty (possibly nested) numpy array.
|
||||||
|
|
||||||
Parameters
|
Example::
|
||||||
----------
|
|
||||||
space : `gym.spaces.Space` instance
|
|
||||||
Observation space of a single environment in the vectorized environment.
|
|
||||||
|
|
||||||
n : int
|
|
||||||
Number of environments in the vectorized environment. If `None`, creates
|
|
||||||
an empty sample from `space`.
|
|
||||||
|
|
||||||
fn : callable
|
|
||||||
Function to apply when creating the empty numpy array. Examples of such
|
|
||||||
functions are `np.empty` or `np.zeros`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
out : tuple, dict, or `np.ndarray`
|
|
||||||
The output object. This object is a (possibly nested) numpy array.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
>>> from gym.spaces import Box, Dict
|
>>> from gym.spaces import Box, Dict
|
||||||
>>> space = Dict({
|
>>> space = Dict({
|
||||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||||
@@ -108,6 +87,14 @@ def create_empty_array(space, n=1, fn=np.zeros):
|
|||||||
[0., 0., 0.]], dtype=float32)),
|
[0., 0., 0.]], dtype=float32)),
|
||||||
('velocity', array([[0., 0.],
|
('velocity', array([[0., 0.],
|
||||||
[0., 0.]], dtype=float32))])
|
[0., 0.]], dtype=float32))])
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: Observation space of a single environment in the vectorized environment.
|
||||||
|
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
|
||||||
|
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output object. This object is a (possibly nested) numpy array.
|
||||||
"""
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
||||||
|
@@ -1,44 +1,40 @@
|
|||||||
|
"""Utility functions for vector environments to share memory between processes."""
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from ctypes import c_bool
|
from ctypes import c_bool
|
||||||
from functools import singledispatch
|
from functools import singledispatch
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym.error import CustomSpaceError
|
from gym.error import CustomSpaceError
|
||||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
|
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
||||||
|
|
||||||
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
|
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def create_shared_memory(space, n=1, ctx=mp):
|
def create_shared_memory(
|
||||||
"""Create a shared memory object, to be shared across processes. This
|
space: Space, n: int = 1, ctx=mp
|
||||||
eventually contains the observations from the vectorized environment.
|
) -> Union[dict, tuple, mp.Array]:
|
||||||
|
"""Create a shared memory object, to be shared across processes.
|
||||||
|
|
||||||
Parameters
|
This eventually contains the observations from the vectorized environment.
|
||||||
----------
|
|
||||||
space : `gym.spaces.Space` instance
|
|
||||||
Observation space of a single environment in the vectorized environment.
|
|
||||||
|
|
||||||
n : int
|
Args:
|
||||||
Number of environments in the vectorized environment (i.e. the number
|
space: Observation space of a single environment in the vectorized environment.
|
||||||
of processes).
|
n: Number of environments in the vectorized environment (i.e. the number of processes).
|
||||||
|
ctx: The multiprocess module
|
||||||
|
|
||||||
ctx : `multiprocessing` context
|
Returns:
|
||||||
Context for multiprocessing.
|
shared_memory for the shared object across processes.
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
|
||||||
Shared object across processes.
|
|
||||||
"""
|
"""
|
||||||
raise CustomSpaceError(
|
raise CustomSpaceError(
|
||||||
"Cannot create a shared memory for space with "
|
"Cannot create a shared memory for space with "
|
||||||
"type `{}`. Shared memory only supports "
|
f"type `{type(space)}`. Shared memory only supports "
|
||||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||||
"`Dict`, etc...), and does not support custom "
|
"`Dict`, etc...), and does not support custom "
|
||||||
"Gym spaces.".format(type(space))
|
"Gym spaces."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -46,7 +42,7 @@ def create_shared_memory(space, n=1, ctx=mp):
|
|||||||
@create_shared_memory.register(Discrete)
|
@create_shared_memory.register(Discrete)
|
||||||
@create_shared_memory.register(MultiDiscrete)
|
@create_shared_memory.register(MultiDiscrete)
|
||||||
@create_shared_memory.register(MultiBinary)
|
@create_shared_memory.register(MultiBinary)
|
||||||
def _create_base_shared_memory(space, n=1, ctx=mp):
|
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
|
||||||
dtype = space.dtype.char
|
dtype = space.dtype.char
|
||||||
if dtype in "?":
|
if dtype in "?":
|
||||||
dtype = c_bool
|
dtype = c_bool
|
||||||
@@ -54,7 +50,7 @@ def _create_base_shared_memory(space, n=1, ctx=mp):
|
|||||||
|
|
||||||
|
|
||||||
@create_shared_memory.register(Tuple)
|
@create_shared_memory.register(Tuple)
|
||||||
def _create_tuple_shared_memory(space, n=1, ctx=mp):
|
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
|
||||||
return tuple(
|
return tuple(
|
||||||
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
||||||
)
|
)
|
||||||
@@ -71,39 +67,32 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
|
|||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def read_from_shared_memory(space, shared_memory, n=1):
|
def read_from_shared_memory(
|
||||||
|
space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1
|
||||||
|
) -> Union[dict, tuple, np.ndarray]:
|
||||||
"""Read the batch of observations from shared memory as a numpy array.
|
"""Read the batch of observations from shared memory as a numpy array.
|
||||||
|
|
||||||
Parameters
|
..notes::
|
||||||
----------
|
|
||||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
|
||||||
Shared object across processes. This contains the observations from the
|
|
||||||
vectorized environment. This object is created with `create_shared_memory`.
|
|
||||||
|
|
||||||
space : `gym.spaces.Space` instance
|
|
||||||
Observation space of a single environment in the vectorized environment.
|
|
||||||
|
|
||||||
n : int
|
|
||||||
Number of environments in the vectorized environment (i.e. the number
|
|
||||||
of processes).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
observations : dict, tuple or `np.ndarray` instance
|
|
||||||
Batch of observations as a (possibly nested) numpy array.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
The numpy array objects returned by `read_from_shared_memory` shares the
|
The numpy array objects returned by `read_from_shared_memory` shares the
|
||||||
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
|
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
|
||||||
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
|
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: Observation space of a single environment in the vectorized environment.
|
||||||
|
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
|
||||||
|
This object is created with `create_shared_memory`.
|
||||||
|
n: Number of environments in the vectorized environment (i.e. the number of processes).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Batch of observations as a (possibly nested) numpy array.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
raise CustomSpaceError(
|
raise CustomSpaceError(
|
||||||
"Cannot read from a shared memory for space with "
|
"Cannot read from a shared memory for space with "
|
||||||
"type `{}`. Shared memory only supports "
|
f"type `{type(space)}`. Shared memory only supports "
|
||||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||||
"`Dict`, etc...), and does not support custom "
|
"`Dict`, etc...), and does not support custom "
|
||||||
"Gym spaces.".format(type(space))
|
"Gym spaces."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -111,14 +100,14 @@ def read_from_shared_memory(space, shared_memory, n=1):
|
|||||||
@read_from_shared_memory.register(Discrete)
|
@read_from_shared_memory.register(Discrete)
|
||||||
@read_from_shared_memory.register(MultiDiscrete)
|
@read_from_shared_memory.register(MultiDiscrete)
|
||||||
@read_from_shared_memory.register(MultiBinary)
|
@read_from_shared_memory.register(MultiBinary)
|
||||||
def _read_base_from_shared_memory(space, shared_memory, n=1):
|
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
|
||||||
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
|
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
|
||||||
(n,) + space.shape
|
(n,) + space.shape
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@read_from_shared_memory.register(Tuple)
|
@read_from_shared_memory.register(Tuple)
|
||||||
def _read_tuple_from_shared_memory(space, shared_memory, n=1):
|
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
|
||||||
return tuple(
|
return tuple(
|
||||||
read_from_shared_memory(subspace, memory, n=n)
|
read_from_shared_memory(subspace, memory, n=n)
|
||||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
for (memory, subspace) in zip(shared_memory, space.spaces)
|
||||||
@@ -126,7 +115,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n=1):
|
|||||||
|
|
||||||
|
|
||||||
@read_from_shared_memory.register(Dict)
|
@read_from_shared_memory.register(Dict)
|
||||||
def _read_dict_from_shared_memory(space, shared_memory, n=1):
|
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
||||||
@@ -136,34 +125,26 @@ def _read_dict_from_shared_memory(space, shared_memory, n=1):
|
|||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def write_to_shared_memory(space, index, value, shared_memory):
|
def write_to_shared_memory(
|
||||||
|
space: Space,
|
||||||
|
index: int,
|
||||||
|
value: np.ndarray,
|
||||||
|
shared_memory: Union[dict, tuple, mp.Array],
|
||||||
|
):
|
||||||
"""Write the observation of a single environment into shared memory.
|
"""Write the observation of a single environment into shared memory.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
space: Observation space of a single environment in the vectorized environment.
|
||||||
index : int
|
index: Index of the environment (must be in `[0, num_envs)`).
|
||||||
Index of the environment (must be in `[0, num_envs)`).
|
value: Observation of the single environment to write to shared memory.
|
||||||
|
shared_memory: Shared object across processes. This contains the observations from the vectorized environment. This object is created with `create_shared_memory`.
|
||||||
value : sample from `space`
|
|
||||||
Observation of the single environment to write to shared memory.
|
|
||||||
|
|
||||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
|
||||||
Shared object across processes. This contains the observations from the
|
|
||||||
vectorized environment. This object is created with `create_shared_memory`.
|
|
||||||
|
|
||||||
space : `gym.spaces.Space` instance
|
|
||||||
Observation space of a single environment in the vectorized environment.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
`None`
|
|
||||||
"""
|
"""
|
||||||
raise CustomSpaceError(
|
raise CustomSpaceError(
|
||||||
"Cannot write to a shared memory for space with "
|
"Cannot write to a shared memory for space with "
|
||||||
"type `{}`. Shared memory only supports "
|
f"type `{type(space)}`. Shared memory only supports "
|
||||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||||
"`Dict`, etc...), and does not support custom "
|
"`Dict`, etc...), and does not support custom "
|
||||||
"Gym spaces.".format(type(space))
|
"Gym spaces."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
|
"""Utility functions for gym spaces: batch space and iterator."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import singledispatch
|
from functools import singledispatch
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -12,32 +14,25 @@ __all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
|
|||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def batch_space(space, n=1):
|
def batch_space(space: Space, n: int = 1) -> Space:
|
||||||
"""Create a (batched) space, containing multiple copies of a single space.
|
"""Create a (batched) space, containing multiple copies of a single space.
|
||||||
|
|
||||||
Parameters
|
Example::
|
||||||
----------
|
|
||||||
space : `gym.spaces.Space` instance
|
|
||||||
Space (e.g. the observation space) for a single environment in the
|
|
||||||
vectorized environment.
|
|
||||||
|
|
||||||
n : int
|
|
||||||
Number of environments in the vectorized environment.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
batched_space : `gym.spaces.Space` instance
|
|
||||||
Space (e.g. the observation space) for a batch of environments in the
|
|
||||||
vectorized environment.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
>>> from gym.spaces import Box, Dict
|
>>> from gym.spaces import Box, Dict
|
||||||
>>> space = Dict({
|
>>> space = Dict({
|
||||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
|
||||||
|
... })
|
||||||
>>> batch_space(space, n=5)
|
>>> batch_space(space, n=5)
|
||||||
Dict(position:Box(5, 3), velocity:Box(5, 2))
|
Dict(position:Box(5, 3), velocity:Box(5, 2))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
|
||||||
|
n: Number of environments in the vectorized environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
|
||||||
"""
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
||||||
@@ -126,24 +121,11 @@ def _batch_space_custom(space, n=1):
|
|||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def iterate(space, items):
|
def iterate(space: Space, items) -> Iterator:
|
||||||
"""Iterate over the elements of a (batched) space.
|
"""Iterate over the elements of a (batched) space.
|
||||||
|
|
||||||
Parameters
|
Example::
|
||||||
----------
|
|
||||||
space : `gym.spaces.Space` instance
|
|
||||||
Space to which `items` belong to.
|
|
||||||
|
|
||||||
items : samples of `space`
|
|
||||||
Items to be iterated over.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
iterator : `Iterable` instance
|
|
||||||
Iterator over the elements in `items`.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
>>> from gym.spaces import Box, Dict
|
>>> from gym.spaces import Box, Dict
|
||||||
>>> space = Dict({
|
>>> space = Dict({
|
||||||
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
|
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
|
||||||
@@ -158,9 +140,16 @@ def iterate(space, items):
|
|||||||
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
|
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
|
||||||
>>> next(it)
|
>>> next(it)
|
||||||
StopIteration
|
StopIteration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: Space to which `items` belong to.
|
||||||
|
items: Items to be iterated over.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterator over the elements in `items`.
|
||||||
"""
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Space of type `{}` is not a valid `gym.Space` " "instance.".format(type(space))
|
f"Space of type `{type(space)}` is not a valid `gym.Space` " "instance."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,4 +1,7 @@
|
|||||||
from typing import List, Optional, Union
|
"""Base class for vectorized environments."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.logger import deprecation
|
from gym.logger import deprecation
|
||||||
@@ -8,32 +11,28 @@ __all__ = ["VectorEnv"]
|
|||||||
|
|
||||||
|
|
||||||
class VectorEnv(gym.Env):
|
class VectorEnv(gym.Env):
|
||||||
r"""Base class for vectorized environments. Runs multiple independent copies of the
|
"""Base class for vectorized environments. Runs multiple independent copies of the same environment in parallel.
|
||||||
same environment in parallel. This is not the same as 1 environment that has multiple
|
|
||||||
sub components, but it is many copies of the same base env.
|
|
||||||
|
|
||||||
Each observation returned from vectorized environment is a batch of observations
|
This is not the same as 1 environment that has multiple subcomponents, but it is many copies of the same base env.
|
||||||
for each parallel environment. And :meth:`step` is also expected to receive a batch of
|
|
||||||
actions for each parallel environment.
|
|
||||||
|
|
||||||
.. note::
|
Each observation returned from vectorized environment is a batch of observations for each parallel environment.
|
||||||
|
And :meth:`step` is also expected to receive a batch of actions for each parallel environment.
|
||||||
|
|
||||||
|
Notes:
|
||||||
All parallel environments should share the identical observation and action spaces.
|
All parallel environments should share the identical observation and action spaces.
|
||||||
In other words, a vector of multiple different environments is not supported.
|
In other words, a vector of multiple different environments is not supported.
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
num_envs : int
|
|
||||||
Number of environments in the vectorized environment.
|
|
||||||
|
|
||||||
observation_space : :class:`gym.spaces.Space`
|
|
||||||
Observation space of a single environment.
|
|
||||||
|
|
||||||
action_space : :class:`gym.spaces.Space`
|
|
||||||
Action space of a single environment.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_envs, observation_space, action_space):
|
def __init__(
|
||||||
|
self, num_envs: int, observation_space: gym.Space, action_space: gym.Space
|
||||||
|
):
|
||||||
|
"""Base class for vectorized environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_envs: Number of environments in the vectorized environment.
|
||||||
|
observation_space: Observation space of a single environment.
|
||||||
|
action_space: Action space of a single environment.
|
||||||
|
"""
|
||||||
self.num_envs = num_envs
|
self.num_envs = num_envs
|
||||||
self.is_vector_env = True
|
self.is_vector_env = True
|
||||||
self.observation_space = batch_space(observation_space, n=num_envs)
|
self.observation_space = batch_space(observation_space, n=num_envs)
|
||||||
@@ -49,141 +48,134 @@ class VectorEnv(gym.Env):
|
|||||||
|
|
||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, list[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
"""Reset the sub-environments asynchronously.
|
||||||
|
|
||||||
|
This method will return ``None``. A call to :meth:`reset_async` should be followed by a call to :meth:`reset_wait` to retrieve the results.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, list[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
"""Retrieves the results of a :meth:`reset_async` call.
|
||||||
|
|
||||||
|
A call to this method must always be preceded by a call to :meth:`reset_async`.
|
||||||
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, list[int]]] = None,
|
||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
r"""Reset all parallel environments and return a batch of initial observations.
|
"""Reset all parallel environments and return a batch of initial observations.
|
||||||
|
|
||||||
Returns
|
Args:
|
||||||
-------
|
seed: The environment reset seeds
|
||||||
observations : element of :attr:`observation_space`
|
return_info: If to return the info
|
||||||
|
options: If to return the options
|
||||||
|
|
||||||
|
Returns:
|
||||||
A batch of observations from the vectorized environment.
|
A batch of observations from the vectorized environment.
|
||||||
"""
|
"""
|
||||||
self.reset_async(seed=seed, return_info=return_info, options=options)
|
self.reset_async(seed=seed, return_info=return_info, options=options)
|
||||||
return self.reset_wait(seed=seed, return_info=return_info, options=options)
|
return self.reset_wait(seed=seed, return_info=return_info, options=options)
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
|
"""Asynchronously performs steps in the sub-environments.
|
||||||
|
|
||||||
|
The results can be retrieved via a call to :meth:`step_wait`.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def step_wait(self, **kwargs):
|
def step_wait(self, **kwargs):
|
||||||
|
"""Retrieves the results of a :meth:`step_async` call.
|
||||||
|
|
||||||
|
A call to this method must always be preceded by a call to :meth:`step_async`.
|
||||||
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
r"""Take an action for each parallel environment.
|
"""Take an action for each parallel environment.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
actions: element of :attr:`action_space` Batch of actions.
|
||||||
actions : element of :attr:`action_space`
|
|
||||||
Batch of actions.
|
|
||||||
|
|
||||||
Returns
|
Returns:
|
||||||
-------
|
Batch of observations, rewards, done and infos
|
||||||
observations : element of :attr:`observation_space`
|
|
||||||
A batch of observations from the vectorized environment.
|
|
||||||
|
|
||||||
rewards : :obj:`np.ndarray`, dtype :obj:`np.float_`
|
|
||||||
A vector of rewards from the vectorized environment.
|
|
||||||
|
|
||||||
dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
|
|
||||||
A vector whose entries indicate whether the episode has ended.
|
|
||||||
|
|
||||||
infos : list of dict
|
|
||||||
A list of auxiliary diagnostic information dicts from each parallel environment.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.step_async(actions)
|
self.step_async(actions)
|
||||||
return self.step_wait()
|
return self.step_wait()
|
||||||
|
|
||||||
def call_async(self, name, *args, **kwargs):
|
def call_async(self, name, *args, **kwargs):
|
||||||
|
"""Calls a method name for each parallel environment asynchronously."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def call_wait(self, **kwargs):
|
def call_wait(self, **kwargs):
|
||||||
|
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def call(self, name, *args, **kwargs):
|
def call(self, name: str, *args, **kwargs) -> list[Any]:
|
||||||
"""Call a method, or get a property, from each parallel environment.
|
"""Call a method, or get a property, from each parallel environment.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
name (str): Name of the method or property to call.
|
||||||
name : string
|
*args: Arguments to apply to the method call.
|
||||||
Name of the method or property to call.
|
**kwargs: Keyword arguments to apply to the method call.
|
||||||
|
|
||||||
*args
|
Returns:
|
||||||
Arguments to apply to the method call.
|
List of the results of the individual calls to the method or property for each environment.
|
||||||
|
|
||||||
**kwargs
|
|
||||||
Keywoard arguments to apply to the method call.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
results : list
|
|
||||||
List of the results of the individual calls to the method or
|
|
||||||
property for each environment.
|
|
||||||
"""
|
"""
|
||||||
self.call_async(name, *args, **kwargs)
|
self.call_async(name, *args, **kwargs)
|
||||||
return self.call_wait()
|
return self.call_wait()
|
||||||
|
|
||||||
def get_attr(self, name):
|
def get_attr(self, name: str):
|
||||||
"""Get a property from each parallel environment.
|
"""Get a property from each parallel environment.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
name (str): Name of the property to be get from each individual environment.
|
||||||
name : string
|
|
||||||
Name of the property to be get from each individual environment.
|
Returns:
|
||||||
|
The property with name
|
||||||
"""
|
"""
|
||||||
return self.call(name)
|
return self.call(name)
|
||||||
|
|
||||||
def set_attr(self, name, values):
|
def set_attr(self, name: str, values: Union[list, tuple, object]):
|
||||||
"""Set a property in each parallel environment.
|
"""Set a property in each sub-environment.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
name (str): Name of the property to be set in each individual environment.
|
||||||
name : string
|
values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
|
||||||
Name of the property to be set in each individual environment.
|
tuple, then it corresponds to the values for each individual environment, otherwise a single value
|
||||||
|
is set for all environments.
|
||||||
values : list, tuple, or object
|
|
||||||
Values of the property to be set to. If `values` is a list or
|
|
||||||
tuple, then it corresponds to the values for each individual
|
|
||||||
environment, otherwise a single value is set for all environments.
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def close_extras(self, **kwargs):
|
def close_extras(self, **kwargs):
|
||||||
r"""Clean up the extra resources e.g. beyond what's in this base class."""
|
"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def close(self, **kwargs):
|
def close(self, **kwargs):
|
||||||
r"""Close all parallel environments and release resources.
|
"""Close all parallel environments and release resources.
|
||||||
|
|
||||||
It also closes all the existing image viewers, then calls :meth:`close_extras` and set
|
It also closes all the existing image viewers, then calls :meth:`close_extras` and set
|
||||||
:attr:`closed` as ``True``.
|
:attr:`closed` as ``True``.
|
||||||
|
|
||||||
.. warning::
|
Warnings:
|
||||||
|
|
||||||
This function itself does not close the environments, it should be handled
|
This function itself does not close the environments, it should be handled
|
||||||
in :meth:`close_extras`. This is generic for both synchronous and asynchronous
|
in :meth:`close_extras`. This is generic for both synchronous and asynchronous
|
||||||
vectorized environments.
|
vectorized environments.
|
||||||
|
|
||||||
.. note::
|
Notes:
|
||||||
|
|
||||||
This will be automatically called when garbage collected or program exited.
|
This will be automatically called when garbage collected or program exited.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -197,10 +189,8 @@ class VectorEnv(gym.Env):
|
|||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
"""Set the random seed in all parallel environments.
|
"""Set the random seed in all parallel environments.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
seed: Random seed for each parallel environment. If ``seed`` is a list of
|
||||||
seed : list of int, or int, optional
|
|
||||||
Random seed for each parallel environment. If ``seed`` is a list of
|
|
||||||
length ``num_envs``, then the items of the list are chosen as random
|
length ``num_envs``, then the items of the list are chosen as random
|
||||||
seeds. If ``seed`` is an int, then each parallel environment uses the random
|
seeds. If ``seed`` is an int, then each parallel environment uses the random
|
||||||
seed ``seed + n``, where ``n`` is the index of the parallel environment
|
seed ``seed + n``, where ``n`` is the index of the parallel environment
|
||||||
@@ -212,10 +202,12 @@ class VectorEnv(gym.Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
"""Closes the vector environment."""
|
||||||
if not getattr(self, "closed", True):
|
if not getattr(self, "closed", True):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
"""Returns a string representation of the vector environment using the class name, number of environments and environment spec id."""
|
||||||
if self.spec is None:
|
if self.spec is None:
|
||||||
return f"{self.__class__.__name__}({self.num_envs})"
|
return f"{self.__class__.__name__}({self.num_envs})"
|
||||||
else:
|
else:
|
||||||
@@ -223,19 +215,17 @@ class VectorEnv(gym.Env):
|
|||||||
|
|
||||||
|
|
||||||
class VectorEnvWrapper(VectorEnv):
|
class VectorEnvWrapper(VectorEnv):
|
||||||
r"""Wraps the vectorized environment to allow a modular transformation.
|
"""Wraps the vectorized environment to allow a modular transformation.
|
||||||
|
|
||||||
This class is the base class for all wrappers for vectorized environments. The subclass
|
This class is the base class for all wrappers for vectorized environments. The subclass
|
||||||
could override some methods to change the behavior of the original vectorized environment
|
could override some methods to change the behavior of the original vectorized environment
|
||||||
without touching the original code.
|
without touching the original code.
|
||||||
|
|
||||||
.. note::
|
Notes:
|
||||||
|
|
||||||
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env: VectorEnv):
|
||||||
assert isinstance(env, VectorEnv)
|
assert isinstance(env, VectorEnv)
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
|
2
setup.py
2
setup.py
@@ -17,7 +17,7 @@ extras = {
|
|||||||
"classic_control": ["pygame==2.1.0"],
|
"classic_control": ["pygame==2.1.0"],
|
||||||
"mujoco": ["mujoco_py>=1.50, <2.0"],
|
"mujoco": ["mujoco_py>=1.50, <2.0"],
|
||||||
"toy_text": ["pygame==2.1.0", "scipy>=1.4.1"],
|
"toy_text": ["pygame==2.1.0", "scipy>=1.4.1"],
|
||||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0"],
|
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Meta dependency groups.
|
# Meta dependency groups.
|
||||||
|
Reference in New Issue
Block a user