diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f408c99ae..ea16d0eae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: rev: 6.1.1 # pick a git hash / tag to point to hooks: - id: pydocstyle - exclude: ^(gym/version.py)|(gym/(envs|utils|vector)/)|(tests/) + exclude: ^(gym/version.py)|(gym/envs/)|(tests/) args: - --source - --explain diff --git a/gym/envs/classic_control/acrobot.py b/gym/envs/classic_control/acrobot.py index a3d598c43..059aea7d2 100644 --- a/gym/envs/classic_control/acrobot.py +++ b/gym/envs/classic_control/acrobot.py @@ -398,27 +398,26 @@ def rk4(derivs, y0, t): yourself stranded on a system w/o scipy. Otherwise use :func:`scipy.integrate`. + Example: + + >>> ### 2D system + >>> def derivs(x): + ... d1 = x[0] + 2*x[1] + ... d2 = -3*x[0] + 4*x[1] + ... return (d1, d2) + >>> dt = 0.0005 + >>> t = arange(0.0, 2.0, dt) + >>> y0 = (1,2) + >>> yout = rk4(derivs, y0, t) + + If you have access to scipy, you should probably be using the + :func:`scipy.integrate` tools rather than this function. + This would then require re-adding the time variable to the signature of derivs. + Args: derivs: the derivative of the system and has the signature ``dy = derivs(yi)`` y0: initial state vector t: sample times - args: additional arguments passed to the derivative function - kwargs: additional keyword arguments passed to the derivative function - - Example 1 :: - ### 2D system - def derivs(x): - d1 = x[0] + 2*x[1] - d2 = -3*x[0] + 4*x[1] - return (d1, d2) - dt = 0.0005 - t = arange(0.0, 2.0, dt) - y0 = (1,2) - yout = rk4(derivs6, y0, t) - - If you have access to scipy, you should probably be using the - scipy.integrate tools rather than this function. - This would then require re-adding the time variable to the signature of derivs. 
Returns: yout: Runge-Kutta approximation of the ODE diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 118cc6d06..5f5d33d32 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -499,10 +499,16 @@ def make( """ Create an environment according to the given ID. + Warnings: + In v0.24, `gym.utils.env_checker.check_env` is run for every initialised environment. + This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to validate + that they follow the gym API. To disable this feature, set parameter `disable_env_checker=True`. + Args: id: Name of the environment. max_episode_steps: Maximum length of an episode (TimeLimit wrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). + disable_env_checker: Whether to disable the environment checker kwargs: Additional arguments to pass to the environment constructor. Returns: An instance of the environment. diff --git a/gym/spaces/dict.py b/gym/spaces/dict.py index fd1e82e58..40cb69f68 100644 --- a/gym/spaces/dict.py +++ b/gym/spaces/dict.py @@ -17,25 +17,27 @@ class Dict(Space[TypingDict[str, Space]], Mapping): Elements of this space are (ordered) dictionaries of elements from the constituent spaces. - Example usage:: + Example usage: - >>> observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)}) + >>> from gym.spaces import Dict, Discrete + >>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)}) >>> observation_space.sample() OrderedDict([('position', 1), ('velocity', 2)]) Example usage [nested]:: - >>> spaces.Dict( + >>> from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete + >>> Dict( ... { - ... "ext_controller": spaces.MultiDiscrete((5, 2, 2)), - ... "inner_state": spaces.Dict( + ... "ext_controller": MultiDiscrete([5, 2, 2]), + ... "inner_state": Dict( ... { - ... "charge": spaces.Discrete(100), - ...
"system_checks": spaces.MultiBinary(10), - ... "job_status": spaces.Dict( + ... "charge": Discrete(100), + ... "system_checks": MultiBinary(10), + ... "job_status": Dict( ... { - ... "task": spaces.Discrete(5), - ... "progress": spaces.Box(low=0, high=100, shape=()), + ... "task": Discrete(5), + ... "progress": Box(low=0, high=100, shape=()), ... } ... ), ... } @@ -63,9 +65,10 @@ class Dict(Space[TypingDict[str, Space]], Mapping): Example:: - >>> spaces.Dict({"position": spaces.Box(-1, 1, shape=(2,)), "color": spaces.Discrete(3)}) + >>> from gym.spaces import Box, Discrete + >>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)}) Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32)) - >>> spaces.Dict(position=spaces.Box(-1, 1, shape=(2,)), color=spaces.Discrete(3)) + >>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3)) Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32)) Args: diff --git a/gym/spaces/multi_binary.py b/gym/spaces/multi_binary.py index 7ca18924d..36abef52b 100644 --- a/gym/spaces/multi_binary.py +++ b/gym/spaces/multi_binary.py @@ -16,11 +16,11 @@ class MultiBinary(Space[np.ndarray]): Example Usage:: - >>> self.observation_space = spaces.MultiBinary(5) - >>> self.observation_space.sample() + >>> observation_space = MultiBinary(5) + >>> observation_space.sample() array([0, 1, 0, 1, 0], dtype=int8) - >>> self.observation_space = spaces.MultiBinary([3, 2]) - >>> self.observation_space.sample() + >>> observation_space = MultiBinary([3, 2]) + >>> observation_space.sample() array([[0, 0], [0, 1], [1, 1]], dtype=int8) diff --git a/gym/spaces/tuple.py b/gym/spaces/tuple.py index 51c9a187a..000a619b2 100644 --- a/gym/spaces/tuple.py +++ b/gym/spaces/tuple.py @@ -16,8 +16,9 @@ class Tuple(Space[tuple], Sequence): Example usage:: - >> observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Box(-1, 1, shape=(2,)))) - >> observation_space.sample() + >>> from gym.spaces import Box, Discrete + >>> observation_space = 
Tuple((Discrete(2), Box(-1, 1, shape=(2,)))) + >>> observation_space.sample() (0, array([0.03633198, 0.42370757], dtype=float32)) """ diff --git a/gym/spaces/utils.py b/gym/spaces/utils.py index 25b2b5480..f7a238b3a 100644 --- a/gym/spaces/utils.py +++ b/gym/spaces/utils.py @@ -25,8 +25,9 @@ def flatdim(space: Space) -> int: Example usage:: - >>> s = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)}) - >>> spaces.flatdim(s) + >>> from gym.spaces import Dict, Discrete + >>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)}) + >>> flatdim(space) 5 """ raise NotImplementedError(f"Unknown space: `{space}`") @@ -195,8 +196,7 @@ def flatten_space(space: Space) -> Box: Example that recursively flattens a dict:: - >>> space = Dict({"position": Discrete(2), - ... "velocity": Box(0, 1, shape=(2, 2))}) + >>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))}) >>> flatten_space(space) Box(6,) >>> flatten(space, space.sample()) in flatten_space(space) diff --git a/gym/utils/colorize.py b/gym/utils/colorize.py index e25f45f68..6674ded5d 100644 --- a/gym/utils/colorize.py +++ b/gym/utils/colorize.py @@ -1,5 +1,6 @@ -"""A set of common utilities used within the environments. These are -not intended as API functions, and will not remain stable over time. +"""A set of common utilities used within the environments. + +These are not intended as API functions, and will not remain stable over time. """ color2num = dict( @@ -15,12 +16,20 @@ color2num = dict( ) -def colorize(string, color, bold=False, highlight=False): - """Return string surrounded by appropriate terminal color codes to - print colorized text. Valid colors: gray, red, green, yellow, - blue, magenta, cyan, white, crimson - """ +def colorize( + string: str, color: str, bold: bool = False, highlight: bool = False +) -> str: + """Returns string surrounded by appropriate terminal colour codes to print colourised text.
+ Args: + string: The message to colourise + color: Literal values are gray, red, green, yellow, blue, magenta, cyan, white, crimson + bold: If to bold the string + highlight: If to highlight the string + + Returns: + Colourised string + """ attr = [] num = color2num[color] if highlight: diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 399c5b10e..ff7fe7d70 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -1,4 +1,5 @@ -""" +"""A set of functions for checking an environment details. + This file is originally from the Stable Baselines3 repository hosted on GitHub (https://github.com/DLR-RM/stable-baselines3/) Original Author: Antonin Raffin @@ -16,21 +17,33 @@ from typing import Optional, Union import numpy as np import gym -from gym import logger, spaces +from gym import logger +from gym.spaces import Box, Dict, Discrete, Space, Tuple -def _is_numpy_array_space(space: spaces.Space) -> bool: +def _is_numpy_array_space(space: Space) -> bool: + """Checks if a space can be represented as a single numpy array (e.g. Dict and Tuple spaces return False). + + Args: + space: The space to check + + Returns: + Returns False if the provided space is not representable as a single numpy array """ - Returns False if provided space is not representable as a single numpy array - (e.g. Dict and Tuple spaces return False) - """ - return not isinstance(space, (spaces.Dict, spaces.Tuple)) + return not isinstance(space, (Dict, Tuple)) -def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: - """ - Check that the input adheres to general standards - when the observation is apparently an image. +def _check_image_input(observation_space: Box, key: str = ""): + """Check whether an observation space of type :class:`Box` adheres to general standards for spaces that represent images. 
+ + It will check that: + - The datatype is ``np.uint8`` + - The lower bound is 0 across all dimensions + - The upper bound is 255 across all dimensions + + Args: + observation_space: The observation space to check + key: The observation shape key for warning """ if observation_space.dtype != np.uint8: logger.warn( @@ -49,8 +62,13 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: ) -def _check_nan(env: gym.Env, check_inf: bool = True) -> None: - """Check for NaN and Inf.""" +def _check_nan(env: gym.Env, check_inf: bool = True): + """Check if the environment observation, reward are NaN and Inf. + + Args: + env: The environment to check + check_inf: Checks if the observation is infinity + """ for _ in range(10): action = env.action_space.sample() observation, reward, done, _ = env.step(action) @@ -70,19 +88,22 @@ def _check_nan(env: gym.Env, check_inf: bool = True) -> None: def _check_obs( obs: Union[tuple, dict, np.ndarray, int], - observation_space: spaces.Space, + observation_space: Space, method_name: str, -) -> None: +): + """Check that the observation returned by the environment correspond to the declared one. + + Args: + obs: The observation to check + observation_space: The observation space of the observation + method_name: The method name that generated the observation """ - Check that the observation returned by the environment - correspond to the declared one. 
- """ - if not isinstance(observation_space, spaces.Tuple): + if not isinstance(observation_space, Tuple): assert not isinstance( obs, tuple ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple" - if isinstance(observation_space, spaces.Discrete): + if isinstance(observation_space, Discrete): assert isinstance( obs, int ), f"The observation returned by `{method_name}()` method must be an int" @@ -96,12 +117,16 @@ def _check_obs( ), f"The observation returned by the `{method_name}()` method does not match the given observation space" -def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: - """ - Check that the observation space is correctly formatted - when dealing with a ``Box()`` space. In particular, it checks: +def _check_box_obs(observation_space: Box, key: str = ""): + """Check that the observation space is correctly formatted when dealing with a :class:`Box` space. + + In particular, it checks: - that the dimensions are big enough when it is an image, and that the type matches - that the observation has an expected shape (warn the user if not) + + Args: + observation_space: Checks if the Box observation space + key: The observation key """ # If image, check the low and high values, the type and the number of channels # and the shape (minimal value) @@ -137,14 +162,19 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: ), "Agent's observation_space.high and observation_space have different shapes" -def _check_box_action(action_space: spaces.Box): +def _check_box_action(action_space: Box): + """Checks that a :class:`Box` action space is defined in a sensible way. + + Args: + action_space: A box action space + """ if np.any(np.equal(action_space.low, -np.inf)): logger.warn( "Agent's minimum action space value is -infinity. This is probably too low." ) if np.any(np.equal(action_space.high, np.inf)): logger.warn( - "Agent's maxmimum action space value is infinity. 
This is probably too high" + "Agent's maximum action space value is infinity. This is probably too high" ) if np.any(np.equal(action_space.low, action_space.high)): logger.warn("Agent's maximum and minimum action space values are equal") @@ -156,7 +186,12 @@ def _check_box_action(action_space: spaces.Box): assert False, "Agent's action_space.high and action_space have different shapes" -def _check_normalized_action(action_space: spaces.Box): +def _check_normalized_action(action_space: Box): + """Checks that a box action space is normalized. + + Args: + action_space: A box action space + """ if ( np.any(np.abs(action_space.low) != np.abs(action_space.high)) or np.any(np.abs(action_space.low) > 1) @@ -168,16 +203,18 @@ def _check_normalized_action(action_space: spaces.Box): ) -def _check_returned_values( - env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space -) -> None: - """ - Check the returned values by the env when calling `.reset()` or `.step()` methods. +def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space): + """Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods. 
+ + Args: + env: The environment + observation_space: The environment's observation space + action_space: The environment's action space """ # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists obs = env.reset() - if isinstance(observation_space, spaces.Dict): + if isinstance(observation_space, Dict): assert isinstance( obs, dict ), "The observation returned by `reset()` must be a dictionary" @@ -200,7 +237,7 @@ def _check_returned_values( # Unpack obs, reward, done, info = data - if isinstance(observation_space, spaces.Dict): + if isinstance(observation_space, Dict): assert isinstance( obs, dict ), "The observation returned by `step()` must be a dictionary" @@ -223,10 +260,11 @@ def _check_returned_values( ), "The `info` returned by `step()` must be a python dictionary" -def _check_spaces(env: gym.Env) -> None: - """ - Check that the observation and action spaces are defined - and inherit from gym.spaces.Space. +def _check_spaces(env: gym.Env): + """Check that the observation and action spaces are defined and inherit from :class:`gym.spaces.Space`. 
+ + Args: + env: The environment's observation and action space to check """ # Helper to link to the code, because gym has no proper documentation gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/" @@ -238,25 +276,22 @@ def _check_spaces(env: gym.Env) -> None: "You must specify an action space (cf gym.spaces)" + gym_spaces ) - assert isinstance(env.observation_space, spaces.Space), ( + assert isinstance(env.observation_space, Space), ( "The observation space must inherit from gym.spaces" + gym_spaces ) - assert isinstance(env.action_space, spaces.Space), ( + assert isinstance(env.action_space, Space), ( "The action space must inherit from gym.spaces" + gym_spaces ) # Check render cannot be covered by CI -def _check_render( - env: gym.Env, warn: bool = True, headless: bool = False -) -> None: # pragma: no cover - """ - Check the declared render modes/fps and the `render()`/`close()` - method of the environment. - :param env: The environment to check - :param warn: Whether to output additional warnings - :param headless: Whether to disable render modes - that require a graphical interface. False by default. +def _check_render(env: gym.Env, warn: bool = True, headless: bool = False): + """Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment. + + Args: + env: The environment to check + warn: Whether to output additional warnings + headless: Whether to disable render modes that require a graphical interface. False by default. """ render_modes = env.metadata.get("render_modes") if render_modes is None: @@ -288,9 +323,12 @@ def _check_render( env.close() -def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None: - """ - Check that the environment can be reset with a random seed. +def _check_reset_seed(env: gym.Env, seed: Optional[int] = None): + """Check that the environment can be reset with a seed. 
+ + Args: + env: The environment to check + seed: The optional seed to use """ signature = inspect.signature(env.reset) assert ( @@ -303,7 +341,7 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None: raise AssertionError( "The environment cannot be reset with a random seed, even though `seed` or `kwargs` " "appear in the signature. This should never happen, please report this issue. " - "The error was: " + str(e) + f"The error was: {e}" ) if env.unwrapped.np_random is None: @@ -322,7 +360,12 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None) -> None: ) -def _check_reset_info(env: gym.Env) -> None: +def _check_reset_info(env: gym.Env): + """Checks that :meth:`reset` supports the ``return_info`` keyword. + + Args: + env: The environment to check + """ signature = inspect.signature(env.reset) assert ( "return_info" in signature.parameters or "kwargs" in signature.parameters @@ -334,7 +377,7 @@ def _check_reset_info(env: gym.Env) -> None: raise AssertionError( "The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` " "appear in the signature. This should never happen, please report this issue. " - "The error was: " + str(e) + f"The error was: {e}" ) assert ( len(result) == 2 @@ -346,9 +389,11 @@ def _check_reset_info(env: gym.Env) -> None: ), "The second element returned by `env.reset(return_info=True)` was not a dictionary" -def _check_reset_options(env: gym.Env) -> None: - """ - Check that the environment can be reset with options. +def _check_reset_options(env: gym.Env): + """Check that the environment can be reset with options. + + Args: + env: The environment to check """ signature = inspect.signature(env.reset) assert ( @@ -361,22 +406,22 @@ def _check_reset_options(env: gym.Env) -> None: raise AssertionError( "The environment cannot be reset with options, even though `options` or `kwargs` " "appear in the signature. This should never happen, please report this issue. 
" - "The error was: " + str(e) + f"The error was: {e}" ) -def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None: - """ - Check that an environment follows Gym API. +def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True): + """Check that an environment follows Gym API. + This is particularly useful when using a custom environment. Please take a look at https://github.com/openai/gym/blob/master/gym/core.py for more information about the API. It also optionally check that the environment is compatible with Stable-Baselines. - :param env: The Gym environment that will be checked - :param warn: Whether to output additional warnings - mainly related to the interaction with Stable Baselines - :param skip_render_check: Whether to skip the checks for the render method. - True by default (useful for the CI) + + Args: + env: The Gym environment that will be checked + warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines + skip_render_check: Whether to skip the checks for the render method. 
True by default (useful for the CI) """ assert isinstance( env, gym.Env @@ -393,15 +438,15 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - if warn: obs_spaces = ( observation_space.spaces - if isinstance(observation_space, spaces.Dict) + if isinstance(observation_space, Dict) else {"": observation_space} ) for key, space in obs_spaces.items(): - if isinstance(space, spaces.Box): + if isinstance(space, Box): _check_box_obs(space, key) # Check for the action space, it may lead to hard-to-debug issues - if isinstance(action_space, spaces.Box): + if isinstance(action_space, Box): _check_box_action(action_space) _check_normalized_action(action_space) diff --git a/gym/utils/ezpickle.py b/gym/utils/ezpickle.py index f72127d99..9a601dba6 100644 --- a/gym/utils/ezpickle.py +++ b/gym/utils/ezpickle.py @@ -1,33 +1,35 @@ +"""Class for pickling and unpickling objects via their constructor arguments.""" + + class EzPickle: - """Objects that are pickled and unpickled via their constructor - arguments. + """Objects that are pickled and unpickled via their constructor arguments. - Example usage: + Example:: - class Dog(Animal, EzPickle): - def __init__(self, furcolor, tailkind="bushy"): - Animal.__init__() - EzPickle.__init__(furcolor, tailkind) - ... + >>> class Dog(Animal, EzPickle): + ... def __init__(self, furcolor, tailkind="bushy"): + ... Animal.__init__() + ... EzPickle.__init__(furcolor, tailkind) - When this object is unpickled, a new Dog will be constructed by passing the provided - furcolor and tailkind into the constructor. However, philosophers are still not sure - whether it is still the same dog. + When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor. + However, philosophers are still not sure whether it is still the same dog. - This is generally needed only for environments which wrap C/C++ code, such as MuJoCo - and Atari. 
+ This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari. """ def __init__(self, *args, **kwargs): + """Uses the ``args`` and ``kwargs`` from the object's constructor for pickling.""" self._ezpickle_args = args self._ezpickle_kwargs = kwargs def __getstate__(self): + """Returns the object pickle state with args and kwargs.""" return { "_ezpickle_args": self._ezpickle_args, "_ezpickle_kwargs": self._ezpickle_kwargs, } def __setstate__(self, d): + """Sets the object pickle state using d.""" out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"]) self.__dict__.update(out.__dict__) diff --git a/gym/utils/play.py b/gym/utils/play.py index 0823e1a5b..d58fc15f2 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -1,11 +1,18 @@ +"""Utilities for visualising an environment.""" +from __future__ import annotations + +from collections import deque from typing import Callable, Dict, Optional, Tuple, Union +import numpy as np import pygame -from numpy.typing import NDArray from pygame import Surface from pygame.event import Event +from pygame.locals import VIDEORESIZE from gym import Env, logger +from gym.core import ActType, ObsType +from gym.error import DependencyNotInstalled from gym.logger import deprecation try: @@ -13,30 +20,31 @@ try: matplotlib.use("TkAgg") import matplotlib.pyplot as plt -except ImportError as e: - logger.warn(f"failed to set matplotlib backend, plotting will not work: {str(e)}") - plt = None - -from collections import deque - -from pygame.locals import VIDEORESIZE - -from gym.core import ActType +except ImportError: + logger.warn("Matplotlib is not installed, run `pip install gym[other]`") + matplotlib, plt = None, None class MissingKeysToAction(Exception): - """Raised when the environment does not have - a default keys_to_action mapping - """ + """Raised when the environment does not have a default ``keys_to_action`` mapping.""" class PlayableGame: + """Wraps an environment allowing keyboard inputs
to interact with the environment.""" + def __init__( self, env: Env, - keys_to_action: Optional[Dict[Tuple[int], int]] = None, + keys_to_action: Optional[dict[tuple[int], int]] = None, zoom: Optional[float] = None, ): + """Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment. + + Args: + env: The environment to play + keys_to_action: The dictionary of keyboard tuples and action value + zoom: If to zoom in on the environment render + """ self.env = env self.relevant_keys = self._get_relevant_keys(keys_to_action) self.video_size = self._get_video_size(zoom) @@ -45,7 +53,7 @@ class PlayableGame: self.running = True def _get_relevant_keys( - self, keys_to_action: Optional[Dict[Tuple[int], int]] = None + self, keys_to_action: Optional[dict[tuple[int], int]] = None ) -> set: if keys_to_action is None: if hasattr(self.env, "get_keys_to_action"): @@ -60,7 +68,7 @@ class PlayableGame: relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), [])) return relevant_keys - def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]: + def _get_video_size(self, zoom: Optional[float] = None) -> tuple[int, int]: # TODO: this needs to be updated when the render API change goes through rendered = self.env.render(mode="rgb_array") video_size = [rendered.shape[1], rendered.shape[0]] @@ -70,7 +78,14 @@ class PlayableGame: return video_size - def process_event(self, event: Event) -> None: + def process_event(self, event: Event): + """Processes a PyGame event. + + In particular, this function is used to keep track of which buttons are currently pressed and to exit the :func:`play` function when the PyGame window is closed. 
+ + Args: + event: The event to process + """ if event.type == pygame.KEYDOWN: if event.key in self.relevant_keys: self.pressed_keys.append(event.key) @@ -87,9 +102,17 @@ class PlayableGame: def display_arr( - screen: Surface, arr: NDArray, video_size: Tuple[int, int], transpose: bool + screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool ): - arr_min, arr_max = arr.min(), arr.max() + """Displays a numpy array on screen. + + Args: + screen: The screen to show the array on + arr: The array to show + video_size: The video size of the screen + transpose: If to transpose the array on the screen + """ + arr_min, arr_max = np.min(arr), np.max(arr) arr = 255.0 * (arr - arr_min) / (arr_max - arr_min) pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr) pyg_img = pygame.transform.scale(pyg_img, video_size) @@ -108,60 +131,74 @@ def play( ): """Allows one to play the game using keyboard. - To simply play the game use: + Example:: - play(gym.make("Pong-v4")) + >>> import gym + >>> from gym.utils.play import play + >>> play(gym.make("CarRacing-v1"), keys_to_action={"w": np.array([0, 0.7, 0]), + ... "a": np.array([-1, 0, 0]), + ... "s": np.array([0, 0, 1]), + ... "d": np.array([1, 0, 0]), + ... "wa": np.array([-1, 0.7, 0]), + ... "dw": np.array([1, 0.7, 0]), + ... "ds": np.array([1, 0, 1]), + ... "as": np.array([-1, 0, 1]), + ... }, noop=np.array([0,0,0])) - Above code works also if env is wrapped, so it's particularly useful in + + Above code works also if the environment is wrapped, so it's particularly useful in verifying that the frame-level preprocessing does not render the game unplayable. If you wish to plot real time statistics as you play, you can use - gym.utils.play.PlayPlot. Here's a sample code for plotting the reward - for last 5 second of gameplay. + :class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward + for last 150 steps. 
- def callback(obs_t, obs_tp1, action, rew, done, info): - return [rew,] - plotter = PlayPlot(callback, 30 * 5, ["reward"]) - - env = gym.make("Pong-v4") - play(env, callback=plotter.callback) + >>> def callback(obs_t, obs_tp1, action, rew, done, info): + ... return [rew,] + >>> plotter = PlayPlot(callback, 150, ["reward"]) + >>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback) - Arguments - --------- - env: gym.Env - Environment to use for playing. - transpose: bool - If True the output of observation is transposed. - Defaults to true. - fps: int - Maximum number of steps of the environment to execute every second. - Defaults to 30. - zoom: float - Make screen edge this many times bigger - callback: lambda or None - Callback if a callback is provided it will be executed after - every step. It takes the following input: - obs_t: observation before performing action - obs_tp1: observation after performing action - action: action that was executed - rew: reward that was received - done: whether the environment is done or not - info: debug info - keys_to_action: dict: tuple(int) -> int or None - Mapping from keys pressed to action performed. - For example if pressed 'w' and space at the same time is supposed - to trigger action number 2 then key_to_action dict would look like this: - - { - # ... - sorted(ord('w'), ord(' ')) -> 2 - # ... - } - If None, default key_to_action mapping for that env is used, if provided. - seed: bool or None - Random seed used when resetting the environment. If None, no seed is used. + Args: + env: Environment to use for playing. + transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``. + fps: Maximum number of steps of the environment executed every second. If ``None`` (the default), + ``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
+ zoom: Zoom the observation in, ``zoom`` amount, should be a positive float + callback: If a callback is provided, it will be executed after every step. It takes the following input: + obs_t: observation before performing action + obs_tp1: observation after performing action + action: action that was executed + rew: reward that was received + done: whether the environment is done or not + info: debug info + keys_to_action: Mapping from keys pressed to action performed. + Different formats are supported: Key combinations can either be expressed as a tuple of unicode code + points of the keys, as a tuple of characters, or as a string where each character of the string represents + one key. + For example if pressing 'w' and space at the same time is supposed + to trigger action number 2 then ``keys_to_action`` dict could look like this: + >>> { + ... # ... + ... (ord('w'), ord(' ')): 2 + ... # ... + ... } + or like this: + >>> { + ... # ... + ... ("w", " "): 2 + ... # ... + ... } + or like this: + >>> { + ... # ... + ... "w ": 2 + ... # ... + ... } + If ``None``, default ``keys_to_action`` mapping for that environment is used, if provided. + seed: Random seed used when resetting the environment. If None, no seed is used. + noop: The action used when no key input has been entered, or the entered key combination is unknown. """ env.reset(seed=seed) @@ -208,7 +245,44 @@ def play( class PlayPlot: - def __init__(self, callback, horizon_timesteps, plot_names): + """Provides a callback to create live plots of arbitrary metrics when using :func:`play`. + + This class is instantiated with a function that accepts information about a single environment transition: + - obs_t: observation before performing action + - obs_tp1: observation after performing action + - action: action that was executed + - rew: reward that was received + - done: whether the environment is done or not + - info: debug info + + It should return a list of metrics that are computed from this data.
+ For instance, the function may look like this:: + + def compute_metrics(obs_t, obs_tp, action, reward, done, info): + return [reward, info["cumulative_reward"], np.linalg.norm(action)] + + :class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function + and uses the returned values to update live plots of the metrics. + + Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play:: + + >>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"]) + >>> play(your_env, callback=plotter.callback) + """ + + def __init__( + self, callback: callable, horizon_timesteps: int, plot_names: list[str] + ): + """Constructor of :class:`PlayPlot`. + + The function ``callback`` that is passed to this constructor should return + a list of metrics that is of length ``len(plot_names)``. + + Args: + callback: Function that computes metrics from environment transitions + horizon_timesteps: The time horizon used for the live plots + plot_names: List of plot titles + """ deprecation( "`PlayPlot` is marked as deprecated and will be removed in the near future." 
) @@ -216,7 +290,10 @@ class PlayPlot: self.horizon_timesteps = horizon_timesteps self.plot_names = plot_names - assert plt is not None, "matplotlib backend failed, plotting will not work" + if plt is None: + raise DependencyNotInstalled( + "matplotlib is not installed, run `pip install gym[other]`" + ) num_plots = len(self.plot_names) self.fig, self.ax = plt.subplots(num_plots) @@ -228,7 +305,25 @@ class PlayPlot: self.cur_plot = [None for _ in range(num_plots)] self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] - def callback(self, obs_t, obs_tp1, action, rew, done, info): + def callback( + self, + obs_t: ObsType, + obs_tp1: ObsType, + action: ActType, + rew: float, + done: bool, + info: dict, + ): + """The callback that calls the provided data callback and adds the data to the plots. + + Args: + obs_t: The observation at time step t + obs_tp1: The observation at time step t+1 + action: The action + rew: The reward + done: If the environment is done + info: The information from the environment + """ points = self.data_callback(obs_t, obs_tp1, action, rew, done, info) for point, data_series in zip(points, self.data): data_series.append(point) diff --git a/gym/utils/seeding.py b/gym/utils/seeding.py index c1986ba60..a6ee786d4 100644 --- a/gym/utils/seeding.py +++ b/gym/utils/seeding.py @@ -1,7 +1,10 @@ +"""Set of random number generator functions: seeding, generator, hashing seeds.""" +from __future__ import annotations + import hashlib import os import struct -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np @@ -9,7 +12,15 @@ from gym import error from gym.logger import deprecation -def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any]: +def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]: + """Generates a random number generator from the seed and returns the Generator and seed. 
+ + Args: + seed: The seed used to create the generator + + Returns: + The generator and resulting seed + """ if seed is not None and not (isinstance(seed, int) and 0 <= seed): raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}") @@ -22,7 +33,10 @@ def np_random(seed: Optional[int] = None) -> Tuple["RandomNumberGenerator", Any] # TODO: Remove this class and make it alias to `Generator` in a future Gym release # RandomNumberGenerator = np.random.Generator class RandomNumberGenerator(np.random.Generator): + """Random number generator class that inherits from numpy's random Generator class.""" + def rand(self, *size): + """Deprecated rand function using random.""" deprecation( "Function `rng.rand(*size)` is marked as deprecated " "and will be removed in the future. " @@ -34,6 +48,7 @@ class RandomNumberGenerator(np.random.Generator): random_sample = rand def randn(self, *size): + """Deprecated random standard normal function use standard_normal.""" deprecation( "Function `rng.randn(*size)` is marked as deprecated " "and will be removed in the future. " @@ -43,6 +58,7 @@ class RandomNumberGenerator(np.random.Generator): return self.standard_normal(size) def randint(self, low, high=None, size=None, dtype=int): + """Deprecated random integer function use integers.""" deprecation( "Function `rng.randint(low, [high, size, dtype])` is marked as deprecated " "and will be removed in the future. " @@ -54,6 +70,7 @@ class RandomNumberGenerator(np.random.Generator): random_integers = randint def get_state(self): + """Deprecated get rng state use bit_generator.state.""" deprecation( "Function `rng.get_state()` is marked as deprecated " "and will be removed in the future. 
" @@ -63,6 +80,7 @@ class RandomNumberGenerator(np.random.Generator): return self.bit_generator.state def set_state(self, state): + """Deprecated set rng state function use bit_generator.state = state.""" deprecation( "Function `rng.set_state(state)` is marked as deprecated " "and will be removed in the future. " @@ -72,6 +90,7 @@ class RandomNumberGenerator(np.random.Generator): self.bit_generator.state = state def seed(self, seed=None): + """Deprecated seed function use gym.utils.seeding.np_random(seed).""" deprecation( "Function `rng.seed(seed)` is marked as deprecated " "and will be removed in the future. " @@ -88,6 +107,7 @@ class RandomNumberGenerator(np.random.Generator): seed.__doc__ = np.random.seed.__doc__ def __reduce__(self): + """Reduces the Random Number Generator to a RandomNumberGenerator, init_args and additional args.""" # np.random.Generator defines __reduce__, but it's hard-coded to # return a Generator instead of its subclass RandomNumberGenerator. # We need to override it here, otherwise sampling from a Space will @@ -119,20 +139,21 @@ RNG = RandomNumberGenerator def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int: - """Any given evaluation is likely to have many PRNG's active at - once. (Most commonly, because the environment is running in - multiple processes.) There's literature indicating that having - linear correlations between seeds of multiple PRNG's can correlate - the outputs: - http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/ - http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be - http://dl.acm.org/citation.cfm?id=1276928 - Thus, for sanity we hash the seeds before using them. (This scheme - is likely not crypto-strength, but it should be good enough to get - rid of simple correlations.) + """Any given evaluation is likely to have many PRNG's active at once. + + (Most commonly, because the environment is running in multiple processes.) 
+ There's literature indicating that having linear correlations between seeds of multiple PRNG's can correlate the outputs: + http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/ + http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be + http://dl.acm.org/citation.cfm?id=1276928 + Thus, for sanity we hash the seeds before using them. (This scheme is likely not crypto-strength, but it should be good enough to get rid of simple correlations.) + Args: seed: None seeds from an operating system specific randomness source. max_bytes: Maximum number of bytes to use in the hashed seed. + + Returns: + The hashed seed """ deprecation( "Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. " @@ -144,12 +165,16 @@ def hash_seed(seed: Optional[int] = None, max_bytes: int = 8) -> int: def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int: - """Create a strong random seed. Otherwise, Python 2 would seed using - the system time, which might be non-robust especially in the - presence of concurrency. + """Create a strong random seed. + + Otherwise, Python 2 would seed using the system time, which might be non-robust especially in the presence of concurrency. + Args: a: None seeds from an operating system specific randomness source. max_bytes: Maximum number of bytes to use in the seed. + + Returns: + A seed """ deprecation( "Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. " @@ -185,7 +210,7 @@ def _bigint_from_bytes(bt: bytes) -> int: return accum -def _int_list_from_bigint(bigint: int) -> List[int]: +def _int_list_from_bigint(bigint: int) -> list[int]: deprecation( "Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. 
" ) @@ -195,7 +220,7 @@ def _int_list_from_bigint(bigint: int) -> List[int]: elif bigint == 0: return [0] - ints: List[int] = [] + ints: list[int] = [] while bigint > 0: bigint, mod = divmod(bigint, 2**32) ints.append(mod) diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index 4f3b94814..45b0be061 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -1,7 +1,7 @@ -try: - from collections.abc import Iterable -except ImportError: - Iterable = (tuple, list) +"""Module for vector environments.""" +from __future__ import annotations + +from typing import Iterable, Optional, Union from gym.vector.async_vector_env import AsyncVectorEnv from gym.vector.sync_vector_env import SyncVectorEnv @@ -10,40 +10,34 @@ from gym.vector.vector_env import VectorEnv, VectorEnvWrapper __all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"] -def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs): - """Create a vectorized environment from multiple copies of an environment, - from its id. +def make( + id: str, + num_envs: int = 1, + asynchronous: bool = True, + wrappers: Optional[Union[callable, list[callable]]] = None, + **kwargs, +) -> VectorEnv: + """Create a vectorized environment from multiple copies of an environment, from its id. - Parameters - ---------- - id : str - The environment ID. This must be a valid ID from the registry. + Example:: - num_envs : int - Number of copies of the environment. + >>> import gym + >>> env = gym.vector.make('CartPole-v1', num_envs=3) + >>> env.reset() + array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827], + [ 0.03073904, 0.00145001, -0.03088818, -0.03131252], + [ 0.03468829, 0.01500225, 0.01230312, 0.01825218]], + dtype=float32) - asynchronous : bool - If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses - `multiprocessing`_ to run the environments in parallel). If ``False``, - wraps the environments in a :class:`SyncVectorEnv`. 
+ Args: + id: The environment ID. This must be a valid ID from the registry. + num_envs: Number of copies of the environment. + asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`. + wrappers: If not ``None``, then apply the wrappers to each internal environment during creation. + **kwargs: Keyword arguments applied during gym.make - wrappers : callable, or iterable of callables, optional - If not ``None``, then apply the wrappers to each internal - environment during creation. - - Returns - ------- - :class:`gym.vector.VectorEnv` + Returns: The vectorized environment. - - Example - ------- - >>> env = gym.vector.make('CartPole-v1', num_envs=3) - >>> env.reset() - array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827], - [ 0.03073904, 0.00145001, -0.03088818, -0.03131252], - [ 0.03468829, 0.01500225, 0.01230312, 0.01825218]], - dtype=float32) """ from gym.envs import make as make_ diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 565c75137..6d593bdfd 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -1,13 +1,18 @@ +"""An async vector environment.""" +from __future__ import annotations + import multiprocessing as mp import sys import time from copy import deepcopy from enum import Enum -from typing import List, Optional, Union +from typing import Optional, Sequence, Union import numpy as np +import gym from gym import logger +from gym.core import ObsType from gym.error import ( AlreadyPendingCallError, ClosedEnvironmentError, @@ -37,69 +42,13 @@ class AsyncState(Enum): class AsyncVectorEnv(VectorEnv): - """Vectorized environment that runs multiple environments in parallel. It - uses `multiprocessing`_ processes, and pipes for communication. + """Vectorized environment that runs multiple environments in parallel. 
- Parameters - ---------- - env_fns : iterable of callable - Functions that create the environments. + It uses ``multiprocessing`` processes, and pipes for communication. - observation_space : :class:`gym.spaces.Space`, optional - Observation space of a single environment. If ``None``, then the - observation space of the first environment is taken. - - action_space : :class:`gym.spaces.Space`, optional - Action space of a single environment. If ``None``, then the action space - of the first environment is taken. - - shared_memory : bool - If ``True``, then the observations from the worker processes are - communicated back through shared variables. This can improve the - efficiency if the observations are large (e.g. images). - - copy : bool - If ``True``, then the :meth:`~AsyncVectorEnv.reset` and - :meth:`~AsyncVectorEnv.step` methods return a copy of the observations. - - context : str, optional - Context for `multiprocessing`_. If ``None``, then the default context is used. - - daemon : bool - If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they - will quit if the head process quits. However, ``daemon=True`` prevents - subprocesses to spawn children, so for some environments you may want - to have it set to ``False``. - - worker : callable, optional - If set, then use that worker in a subprocess instead of a default one. - Can be useful to override some inner vector env logic, for instance, - how resets on done are handled. - - Warning - ------- - :attr:`worker` is an advanced mode option. It provides a high degree of - flexibility and a high chance to shoot yourself in the foot; thus, - if you are writing your own worker, it is recommended to start from the code - for ``_worker`` (or ``_worker_shared_memory``) method, and add changes. - - Raises - ------ - RuntimeError - If the observation space of some sub-environment does not match - :obj:`observation_space` (or, by default, the observation space of - the first sub-environment). 
- - ValueError - If :obj:`observation_space` is a custom space (i.e. not a default - space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`, - or :class:`~gym.spaces.Dict`) and :obj:`shared_memory` is ``True``. - - Example - ------- - - .. code-block:: + Example:: + >>> import gym >>> env = gym.vector.AsyncVectorEnv([ ... lambda: gym.make("Pendulum-v0", g=9.81), ... lambda: gym.make("Pendulum-v0", g=1.62) @@ -111,15 +60,33 @@ class AsyncVectorEnv(VectorEnv): def __init__( self, - env_fns, - observation_space=None, - action_space=None, - shared_memory=True, - copy=True, - context=None, - daemon=True, - worker=None, + env_fns: Sequence[callable], + observation_space: Optional[gym.Space] = None, + action_space: Optional[gym.Space] = None, + shared_memory: bool = True, + copy: bool = True, + context: Optional[str] = None, + daemon: bool = True, + worker: Optional[callable] = None, ): + """Vectorized environment that runs multiple environments in parallel. + + Args: + env_fns: Functions that create the environments. + observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken. + action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken. + shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images). + copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods return a copy of the observations. + context: Context for `multiprocessing`_. If ``None``, then the default context is used. + daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children, so for some environments you may want to have it set to ``False``. 
+ worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled. + + Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes. + + Raises: + RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment). + ValueError: If observation_space is a custom space (i.e. not a default space in Gym, such as gym.spaces.Box, gym.spaces.Discrete, or gym.spaces.Dict) and shared_memory is True. + """ ctx = mp.get_context(context) self.env_fns = env_fns self.shared_memory = shared_memory @@ -192,6 +159,11 @@ class AsyncVectorEnv(VectorEnv): self._check_spaces() def seed(self, seed=None): + """Seeds the vector environments. + + Args: + seed: The seeds to use with the environments + """ super().seed(seed=seed) self._assert_is_running() if seed is None: @@ -213,22 +185,24 @@ class AsyncVectorEnv(VectorEnv): def reset_async( self, - seed: Optional[Union[int, List[int]]] = None, + seed: Optional[Union[int, list[int]]] = None, return_info: bool = False, options: Optional[dict] = None, ): - """Send the calls to :obj:`reset` to each sub-environment. + """Send calls to the :obj:`reset` methods of the sub-environments. - Raises - ------ - ClosedEnvironmentError - If the environment was closed (if :meth:`close` was previously called). + To get the results of these calls, you may invoke :meth:`reset_wait`. - AlreadyPendingCallError - If the environment is already waiting for a pending call to another - method (e.g. :meth:`step_async`). This can be caused by two consecutive - calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in - between. 
+ Args: + seed: List of seeds for each environment + return_info: Whether to return information + options: The reset option + + Raises: + ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). + AlreadyPendingCallError: If the environment is already waiting for a pending call to another + method (e.g. :meth:`step_async`). This can be caused by two consecutive + calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between. """ self._assert_is_running() @@ -258,37 +232,26 @@ class AsyncVectorEnv(VectorEnv): def reset_wait( self, - timeout=None, + timeout: Optional[Union[int, float]] = None, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None, - ): - """ - Parameters - ---------- - timeout : int or float, optional - Number of seconds before the call to `reset_wait` times out. If - `None`, the call to `reset_wait` never times out. - seed: ignored - options: ignored + ) -> Union[ObsType, tuple[ObsType, list[dict]]]: + """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. - Returns - ------- - element of :attr:`~VectorEnv.observation_space` - A batch of observations from the vectorized environment. - infos : list of dicts containing metadata + Args: + timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out. + seed: ignored + return_info: Whether to return information + options: ignored - Raises - ------ - ClosedEnvironmentError - If the environment was closed (if :meth:`close` was previously called). + Returns: + A tuple of batched observations and list of dictionaries - NoAsyncCallError - If :meth:`reset_wait` was called without any prior call to - :meth:`reset_async`. - - TimeoutError - If :meth:`reset_wait` timed out. + Raises: + ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). 
+ NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`. + TimeoutError: If :meth:`reset_wait` timed out. """ self._assert_is_running() if self._state != AsyncState.WAITING_RESET: @@ -327,24 +290,18 @@ class AsyncVectorEnv(VectorEnv): return deepcopy(self.observations) if self.copy else self.observations - def step_async(self, actions): + def step_async(self, actions: np.ndarray): """Send the calls to :obj:`step` to each sub-environment. - Parameters - ---------- - actions : element of :attr:`~VectorEnv.action_space` - Batch of actions. + Args: + actions: Batch of actions. element of :attr:`~VectorEnv.action_space` - Raises - ------ - ClosedEnvironmentError - If the environment was closed (if :meth:`close` was previously called). - - AlreadyPendingCallError - If the environment is already waiting for a pending call to another - method (e.g. :meth:`reset_async`). This can be caused by two consecutive - calls to :meth:`step_async`, with no call to :meth:`step_wait` in - between. + Raises: + ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). + AlreadyPendingCallError: If the environment is already waiting for a pending call to another + method (e.g. :meth:`reset_async`). This can be caused by two consecutive + calls to :meth:`step_async`, with no call to :meth:`step_wait` in + between. """ self._assert_is_running() if self._state != AsyncState.DEFAULT: @@ -358,40 +315,21 @@ class AsyncVectorEnv(VectorEnv): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP - def step_wait(self, timeout=None): + def step_wait( + self, timeout: Optional[Union[int, float]] = None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[dict]]: """Wait for the calls to :obj:`step` in each sub-environment to finish. - Parameters - ---------- - timeout : int or float, optional - Number of seconds before the call to :meth:`step_wait` times out. 
If - ``None``, the call to :meth:`step_wait` never times out. + Args: + timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out. - Returns - ------- - observations : element of :attr:`~VectorEnv.observation_space` - A batch of observations from the vectorized environment. + Returns: + The batched environment step information, obs, reward, done and info - rewards : :obj:`np.ndarray`, dtype :obj:`np.float_` - A vector of rewards from the vectorized environment. - - dones : :obj:`np.ndarray`, dtype :obj:`np.bool_` - A vector whose entries indicate whether the episode has ended. - - infos : list of dict - A list of auxiliary diagnostic information dicts from sub-environments. - - Raises - ------ - ClosedEnvironmentError - If the environment was closed (if :meth:`close` was previously called). - - NoAsyncCallError - If :meth:`step_wait` was called without any prior call to - :meth:`step_async`. - - TimeoutError - If :meth:`step_wait` timed out. + Raises: + ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). + NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`. + TimeoutError: If :meth:`step_wait` timed out. """ self._assert_is_running() if self._state != AsyncState.WAITING_STEP: @@ -425,18 +363,13 @@ class AsyncVectorEnv(VectorEnv): infos, ) - def call_async(self, name, *args, **kwargs): - """ - Parameters - ---------- - name : string - Name of the method or property to call. + def call_async(self, name: str, *args, **kwargs): + """Calls the method with name asynchronously and apply args and kwargs to the method. - *args - Arguments to apply to the method call. - - **kwargs - Keywoard arguments to apply to the method call. + Args: + name: Name of the method or property to call. + *args: Arguments to apply to the method call. + **kwargs: Keyword arguments to apply to the method call. 
""" self._assert_is_running() if self._state != AsyncState.DEFAULT: @@ -450,19 +383,14 @@ class AsyncVectorEnv(VectorEnv): pipe.send(("_call", (name, args, kwargs))) self._state = AsyncState.WAITING_CALL - def call_wait(self, timeout=None): - """ - Parameters - ---------- - timeout : int or float, optional - Number of seconds before the call to `step_wait` times out. If - `None` (default), the call to `step_wait` never times out. + def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list: + """Calls all parent pipes and waits for the results. - Returns - ------- - results : list - List of the results of the individual calls to the method or - property for each environment. + Args: + timeout: Number of seconds before the call to `step_wait` times out. If `None` (default), the call to `step_wait` never times out. + + Returns: + List of the results of the individual calls to the method or property for each environment. """ self._assert_is_running() if self._state != AsyncState.WAITING_CALL: @@ -483,17 +411,14 @@ class AsyncVectorEnv(VectorEnv): return results - def set_attr(self, name, values): - """ - Parameters - ---------- - name : string - Name of the property to be set in each individual environment. + def set_attr(self, name: str, values: Union[list, tuple, object]): + """Sets an attribute of the sub-environments. - values : list, tuple, or object - Values of the property to be set to. If `values` is a list or - tuple, then it corresponds to the values for each individual - environment, otherwise a single value is set for all environments. + Args: + name: Name of the property to be set in each individual environment. + values: Values of the property to be set to. If ``values`` is a list or + tuple, then it corresponds to the values for each individual + environment, otherwise a single value is set for all environments. 
""" self._assert_is_running() if not isinstance(values, (list, tuple)): @@ -517,25 +442,19 @@ class AsyncVectorEnv(VectorEnv): _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) - def close_extras(self, timeout=None, terminate=False): - """Close the environments & clean up the extra resources - (processes and pipes). + def close_extras( + self, timeout: Optional[Union[int, float]] = None, terminate: bool = False + ): + """Close the environments & clean up the extra resources (processes and pipes). - Parameters - ---------- - timeout : int or float, optional - Number of seconds before the call to :meth:`close` times out. If ``None``, - the call to :meth:`close` never times out. If the call to :meth:`close` - times out, then all processes are terminated. + Args: + timeout: Number of seconds before the call to :meth:`close` times out. If ``None``, + the call to :meth:`close` never times out. If the call to :meth:`close` + times out, then all processes are terminated. + terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated. - terminate : bool - If ``True``, then the :meth:`close` operation is forced and all processes - are terminated. - - Raises - ------ - TimeoutError - If :meth:`close` timed out. + Raises: + TimeoutError: If :meth:`close` timed out. 
""" timeout = 0 if terminate else timeout try: @@ -626,6 +545,7 @@ class AsyncVectorEnv(VectorEnv): raise exctype(value) def __del__(self): + """On deleting the object, checks that the vector environment is closed.""" if not getattr(self, "closed", True) and hasattr(self, "_state"): self.close(terminate=True) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index dc2e8c345..61629bcaf 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -1,8 +1,12 @@ +"""A synchronous vector environment.""" +from __future__ import annotations + from copy import deepcopy -from typing import List, Optional, Union +from typing import Any, Iterator, Optional, Sequence, Union import numpy as np +from gym.spaces import Space from gym.vector.utils import concatenate, create_empty_array, iterate from gym.vector.vector_env import VectorEnv @@ -12,35 +16,9 @@ __all__ = ["SyncVectorEnv"] class SyncVectorEnv(VectorEnv): """Vectorized environment that serially runs multiple environments. - Parameters - ---------- - env_fns : iterable of callable - Functions that create the environments. - - observation_space : :class:`gym.spaces.Space`, optional - Observation space of a single environment. If ``None``, then the - observation space of the first environment is taken. - - action_space : :class:`gym.spaces.Space`, optional - Action space of a single environment. If ``None``, then the action space - of the first environment is taken. - - copy : bool - If ``True``, then the :meth:`reset` and :meth:`step` methods return a - copy of the observations. - - Raises - ------ - RuntimeError - If the observation space of some sub-environment does not match - :obj:`observation_space` (or, by default, the observation space of - the first sub-environment). - - Example - ------- - - .. code-block:: + Example:: + >>> import gym >>> env = gym.vector.SyncVectorEnv([ ... lambda: gym.make("Pendulum-v0", g=9.81), ... 
lambda: gym.make("Pendulum-v0", g=1.62) @@ -50,7 +28,24 @@ class SyncVectorEnv(VectorEnv): [-0.85009176, 0.5266346 , 0.60007906]], dtype=float32) """ - def __init__(self, env_fns, observation_space=None, action_space=None, copy=True): + def __init__( + self, + env_fns: Iterator[callable], + observation_space: Space = None, + action_space: Space = None, + copy: bool = True, + ): + """Vectorized environment that serially runs multiple environments. + + Args: + env_fns: iterable of callable functions that create the environments. + observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken. + action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken. + copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations. + + Raises: + RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment). + """ self.env_fns = env_fns self.envs = [env_fn() for env_fn in env_fns] self.copy = copy @@ -60,7 +55,7 @@ class SyncVectorEnv(VectorEnv): observation_space = observation_space or self.envs[0].observation_space action_space = action_space or self.envs[0].action_space super().__init__( - num_envs=len(env_fns), + num_envs=len(self.envs), observation_space=observation_space, action_space=action_space, ) @@ -73,7 +68,12 @@ class SyncVectorEnv(VectorEnv): self._dones = np.zeros((self.num_envs,), dtype=np.bool_) self._actions = None - def seed(self, seed=None): + def seed(self, seed: Optional[Union[int, Sequence[int]]] = None): + """Sets the seed in all sub-environments. 
+ + Args: + seed: The seed + """ super().seed(seed=seed) if seed is None: seed = [None for _ in range(self.num_envs)] @@ -86,10 +86,20 @@ class SyncVectorEnv(VectorEnv): def reset_wait( self, - seed: Optional[Union[int, List[int]]] = None, + seed: Optional[Union[int, list[int]]] = None, return_info: bool = False, options: Optional[dict] = None, ): + """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. + + Args: + seed: The reset environment seed + return_info: If to return information + options: Option information for the environment reset + + Returns: + The reset observation of the environment and reset information + """ if seed is None: seed = [None for _ in range(self.num_envs)] if isinstance(seed, int): @@ -128,9 +138,15 @@ class SyncVectorEnv(VectorEnv): ), data_list def step_async(self, actions): + """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version.""" self._actions = iterate(self.action_space, actions) def step_wait(self): + """Steps through each of the environments returning the batched results. + + Returns: + The batched environment step results + """ observations, infos = [], [] for i, (env, action) in enumerate(zip(self.envs, self._actions)): observation, self._rewards[i], self._dones[i], info = env.step(action) @@ -150,7 +166,17 @@ class SyncVectorEnv(VectorEnv): infos, ) - def call(self, name, *args, **kwargs): + def call(self, name, *args, **kwargs) -> tuple: + """Calls the method with name and applies args and kwargs. + + Args: + name: The method name + *args: The method args + **kwargs: The method kwargs + + Returns: + Tuple of results + """ results = [] for env in self.envs: function = getattr(env, name) @@ -161,7 +187,15 @@ class SyncVectorEnv(VectorEnv): return tuple(results) - def set_attr(self, name, values): + def set_attr(self, name: str, values: Union[list, tuple, Any]): + """Sets an attribute of the sub-environments. 
+ + Args: + name: The property name to change + values: Values of the property to be set to. If ``values`` is a list or + tuple, then it corresponds to the values for each individual + environment, otherwise, a single value is set for all environments. + """ if not isinstance(values, (list, tuple)): values = [values for _ in range(self.num_envs)] if len(values) != self.num_envs: @@ -178,7 +212,7 @@ class SyncVectorEnv(VectorEnv): """Close the environments.""" [env.close() for env in self.envs] - def _check_spaces(self): + def _check_spaces(self) -> bool: for env in self.envs: if not (env.observation_space == self.single_observation_space): raise RuntimeError( @@ -194,5 +228,4 @@ class SyncVectorEnv(VectorEnv): "action spaces from all environments must be equal." ) - else: - return True + return True diff --git a/gym/vector/utils/__init__.py b/gym/vector/utils/__init__.py index a3420f394..01a58445d 100644 --- a/gym/vector/utils/__init__.py +++ b/gym/vector/utils/__init__.py @@ -1,3 +1,4 @@ +"""Module for gym vector utils.""" from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars from gym.vector.utils.numpy_utils import concatenate, create_empty_array from gym.vector.utils.shared_memory import ( diff --git a/gym/vector/utils/misc.py b/gym/vector/utils/misc.py index 30af0cfa0..61cd1b0cc 100644 --- a/gym/vector/utils/misc.py +++ b/gym/vector/utils/misc.py @@ -1,3 +1,4 @@ +"""Miscellaneous utilities.""" import contextlib import os @@ -5,28 +6,35 @@ __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"] class CloudpickleWrapper: - def __init__(self, fn): + """Wrapper that uses cloudpickle to pickle and unpickle the result.""" + + def __init__(self, fn: callable): + """Cloudpickle wrapper for a function.""" self.fn = fn def __getstate__(self): + """Get the state using `cloudpickle.dumps(self.fn)`.""" import cloudpickle return cloudpickle.dumps(self.fn) def __setstate__(self, ob): + """Sets the state with obs.""" import pickle self.fn = pickle.loads(ob) 
def __call__(self): + """Calls the function `self.fn` with no arguments.""" return self.fn() @contextlib.contextmanager def clear_mpi_env_vars(): - """ - `from mpi4py import MPI` will call `MPI_Init` by default. If the child - process has MPI environment variables, MPI will think that the child process + """Clears the MPI of environment variables. + + `from mpi4py import MPI` will call `MPI_Init` by default. + If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. This context manager is a hacky way to clear those environment variables diff --git a/gym/vector/utils/numpy_utils.py b/gym/vector/utils/numpy_utils.py index af6295cf5..8cdc46434 100644 --- a/gym/vector/utils/numpy_utils.py +++ b/gym/vector/utils/numpy_utils.py @@ -1,5 +1,7 @@ +"""Numpy utility functions: concatenate space samples and create empty array.""" from collections import OrderedDict from functools import singledispatch +from typing import Iterable, Union import numpy as np @@ -9,36 +11,29 @@ __all__ = ["concatenate", "create_empty_array"] @singledispatch -def concatenate(space, items, out): +def concatenate( + space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray] +) -> Union[tuple, dict, np.ndarray]: """Concatenate multiple samples from space into a single object. - Parameters - ---------- - items : iterable of samples of `space` - Samples to be concatenated. + Example:: - out : tuple, dict, or `np.ndarray` + >>> from gym.spaces import Box + >>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32) + >>> out = np.zeros((2, 3), dtype=np.float32) + >>> items = [space.sample() for _ in range(2)] + >>> concatenate(space, items, out) + array([[0.6348213 , 0.28607962, 0.60760117], + [0.87383074, 0.192658 , 0.2148103 ]], dtype=float32) + + Args: + space: Observation space of a single environment in the vectorized environment. + items: Samples to be concatenated. 
+ out: The output object. This object is a (possibly nested) numpy array. + + Returns: The output object. This object is a (possibly nested) numpy array. - - space : `gym.spaces.Space` instance - Observation space of a single environment in the vectorized environment. - - Returns - ------- - out : tuple, dict, or `np.ndarray` - The output object. This object is a (possibly nested) numpy array. - - Example - ------- - >>> from gym.spaces import Box - >>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32) - >>> out = np.zeros((2, 3), dtype=np.float32) - >>> items = [space.sample() for _ in range(2)] - >>> concatenate(items, out, space) - array([[0.6348213 , 0.28607962, 0.60760117], - [0.87383074, 0.192658 , 0.2148103 ]], dtype=float32) """ - assert isinstance(items, (list, tuple)) raise ValueError( f"Space of type `{type(space)}` is not a valid `gym.Space` instance." ) @@ -76,38 +71,30 @@ def _concatenate_custom(space, items, out): @singledispatch -def create_empty_array(space, n=1, fn=np.zeros): +def create_empty_array( + space: Space, n: int = 1, fn: callable = np.zeros +) -> Union[tuple, dict, np.ndarray]: """Create an empty (possibly nested) numpy array. - Parameters - ---------- - space : `gym.spaces.Space` instance - Observation space of a single environment in the vectorized environment. + Example:: - n : int - Number of environments in the vectorized environment. If `None`, creates - an empty sample from `space`. + >>> from gym.spaces import Box, Dict + >>> space = Dict({ + ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), + ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)}) + >>> create_empty_array(space, n=2, fn=np.zeros) + OrderedDict([('position', array([[0., 0., 0.], + [0., 0., 0.]], dtype=float32)), + ('velocity', array([[0., 0.], + [0., 0.]], dtype=float32))]) - fn : callable - Function to apply when creating the empty numpy array. Examples of such - functions are `np.empty` or `np.zeros`. 
+ Args: + space: Observation space of a single environment in the vectorized environment. + n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`. + fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`. - Returns - ------- - out : tuple, dict, or `np.ndarray` + Returns: The output object. This object is a (possibly nested) numpy array. - - Example - ------- - >>> from gym.spaces import Box, Dict - >>> space = Dict({ - ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), - ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)}) - >>> create_empty_array(space, n=2, fn=np.zeros) - OrderedDict([('position', array([[0., 0., 0.], - [0., 0., 0.]], dtype=float32)), - ('velocity', array([[0., 0.], - [0., 0.]], dtype=float32))]) """ raise ValueError( f"Space of type `{type(space)}` is not a valid `gym.Space` instance." diff --git a/gym/vector/utils/shared_memory.py b/gym/vector/utils/shared_memory.py index dd0c9beb6..a47abcb00 100644 --- a/gym/vector/utils/shared_memory.py +++ b/gym/vector/utils/shared_memory.py @@ -1,44 +1,40 @@ +"""Utility functions for vector environments to share memory between processes.""" import multiprocessing as mp from collections import OrderedDict from ctypes import c_bool from functools import singledispatch +from typing import Union import numpy as np from gym.error import CustomSpaceError -from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple +from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple __all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"] @singledispatch -def create_shared_memory(space, n=1, ctx=mp): - """Create a shared memory object, to be shared across processes. This - eventually contains the observations from the vectorized environment. 
+def create_shared_memory( + space: Space, n: int = 1, ctx=mp +) -> Union[dict, tuple, mp.Array]: + """Create a shared memory object, to be shared across processes. - Parameters - ---------- - space : `gym.spaces.Space` instance - Observation space of a single environment in the vectorized environment. + This eventually contains the observations from the vectorized environment. - n : int - Number of environments in the vectorized environment (i.e. the number - of processes). + Args: + space: Observation space of a single environment in the vectorized environment. + n: Number of environments in the vectorized environment (i.e. the number of processes). + ctx: The multiprocess module - ctx : `multiprocessing` context - Context for multiprocessing. - - Returns - ------- - shared_memory : dict, tuple, or `multiprocessing.Array` instance - Shared object across processes. + Returns: + shared_memory for the shared object across processes. """ raise CustomSpaceError( "Cannot create a shared memory for space with " - "type `{}`. Shared memory only supports " + f"type `{type(space)}`. Shared memory only supports " "default Gym spaces (e.g. `Box`, `Tuple`, " "`Dict`, etc...), and does not support custom " - "Gym spaces.".format(type(space)) + "Gym spaces." 
) @@ -46,7 +42,7 @@ def create_shared_memory(space, n=1, ctx=mp): @create_shared_memory.register(Discrete) @create_shared_memory.register(MultiDiscrete) @create_shared_memory.register(MultiBinary) -def _create_base_shared_memory(space, n=1, ctx=mp): +def _create_base_shared_memory(space, n: int = 1, ctx=mp): dtype = space.dtype.char if dtype in "?": dtype = c_bool @@ -54,7 +50,7 @@ def _create_base_shared_memory(space, n=1, ctx=mp): @create_shared_memory.register(Tuple) -def _create_tuple_shared_memory(space, n=1, ctx=mp): +def _create_tuple_shared_memory(space, n: int = 1, ctx=mp): return tuple( create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces ) @@ -71,39 +67,32 @@ def _create_dict_shared_memory(space, n=1, ctx=mp): @singledispatch -def read_from_shared_memory(space, shared_memory, n=1): +def read_from_shared_memory( + space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1 +) -> Union[dict, tuple, np.ndarray]: """Read the batch of observations from shared memory as a numpy array. - Parameters - ---------- - shared_memory : dict, tuple, or `multiprocessing.Array` instance - Shared object across processes. This contains the observations from the - vectorized environment. This object is created with `create_shared_memory`. + ..notes:: + The numpy array objects returned by `read_from_shared_memory` shares the + memory of `shared_memory`. Any changes to `shared_memory` are forwarded + to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`. - space : `gym.spaces.Space` instance - Observation space of a single environment in the vectorized environment. + Args: + space: Observation space of a single environment in the vectorized environment. + shared_memory: Shared object across processes. This contains the observations from the vectorized environment. + This object is created with `create_shared_memory`. + n: Number of environments in the vectorized environment (i.e. the number of processes). 
- n : int - Number of environments in the vectorized environment (i.e. the number - of processes). - - Returns - ------- - observations : dict, tuple or `np.ndarray` instance + Returns: Batch of observations as a (possibly nested) numpy array. - Notes - ----- - The numpy array objects returned by `read_from_shared_memory` shares the - memory of `shared_memory`. Any changes to `shared_memory` are forwarded - to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`. """ raise CustomSpaceError( "Cannot read from a shared memory for space with " - "type `{}`. Shared memory only supports " + f"type `{type(space)}`. Shared memory only supports " "default Gym spaces (e.g. `Box`, `Tuple`, " "`Dict`, etc...), and does not support custom " - "Gym spaces.".format(type(space)) + "Gym spaces." ) @@ -111,14 +100,14 @@ def read_from_shared_memory(space, shared_memory, n=1): @read_from_shared_memory.register(Discrete) @read_from_shared_memory.register(MultiDiscrete) @read_from_shared_memory.register(MultiBinary) -def _read_base_from_shared_memory(space, shared_memory, n=1): +def _read_base_from_shared_memory(space, shared_memory, n: int = 1): return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape( (n,) + space.shape ) @read_from_shared_memory.register(Tuple) -def _read_tuple_from_shared_memory(space, shared_memory, n=1): +def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1): return tuple( read_from_shared_memory(subspace, memory, n=n) for (memory, subspace) in zip(shared_memory, space.spaces) @@ -126,7 +115,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n=1): @read_from_shared_memory.register(Dict) -def _read_dict_from_shared_memory(space, shared_memory, n=1): +def _read_dict_from_shared_memory(space, shared_memory, n: int = 1): return OrderedDict( [ (key, read_from_shared_memory(subspace, shared_memory[key], n=n)) @@ -136,34 +125,26 @@ def _read_dict_from_shared_memory(space, shared_memory, n=1): @singledispatch 
-def write_to_shared_memory(space, index, value, shared_memory): +def write_to_shared_memory( + space: Space, + index: int, + value: np.ndarray, + shared_memory: Union[dict, tuple, mp.Array], +): """Write the observation of a single environment into shared memory. - Parameters - ---------- - index : int - Index of the environment (must be in `[0, num_envs)`). - - value : sample from `space` - Observation of the single environment to write to shared memory. - - shared_memory : dict, tuple, or `multiprocessing.Array` instance - Shared object across processes. This contains the observations from the - vectorized environment. This object is created with `create_shared_memory`. - - space : `gym.spaces.Space` instance - Observation space of a single environment in the vectorized environment. - - Returns - ------- - `None` + Args: + space: Observation space of a single environment in the vectorized environment. + index: Index of the environment (must be in `[0, num_envs)`). + value: Observation of the single environment to write to shared memory. + shared_memory: Shared object across processes. This contains the observations from the vectorized environment. This object is created with `create_shared_memory`. """ raise CustomSpaceError( "Cannot write to a shared memory for space with " - "type `{}`. Shared memory only supports " + f"type `{type(space)}`. Shared memory only supports " "default Gym spaces (e.g. `Box`, `Tuple`, " "`Dict`, etc...), and does not support custom " - "Gym spaces.".format(type(space)) + "Gym spaces." 
) diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index b5de8b963..e7e114b35 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -1,6 +1,8 @@ +"""Utility functions for gym spaces: batch space and iterator.""" from collections import OrderedDict from copy import deepcopy from functools import singledispatch +from typing import Iterator import numpy as np @@ -12,32 +14,25 @@ __all__ = ["_BaseGymSpaces", "batch_space", "iterate"] @singledispatch -def batch_space(space, n=1): +def batch_space(space: Space, n: int = 1) -> Space: """Create a (batched) space, containing multiple copies of a single space. - Parameters - ---------- - space : `gym.spaces.Space` instance - Space (e.g. the observation space) for a single environment in the - vectorized environment. + Example:: - n : int - Number of environments in the vectorized environment. + >>> from gym.spaces import Box, Dict + >>> space = Dict({ + ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), + ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32) + ... }) + >>> batch_space(space, n=5) + Dict(position:Box(5, 3), velocity:Box(5, 2)) - Returns - ------- - batched_space : `gym.spaces.Space` instance - Space (e.g. the observation space) for a batch of environments in the - vectorized environment. + Args: + space: Space (e.g. the observation space) for a single environment in the vectorized environment. + n: Number of environments in the vectorized environment. - Example - ------- - >>> from gym.spaces import Box, Dict - >>> space = Dict({ - ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), - ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)}) - >>> batch_space(space, n=5) - Dict(position:Box(5, 3), velocity:Box(5, 2)) + Returns: + Space (e.g. the observation space) for a batch of environments in the vectorized environment. """ raise ValueError( f"Cannot batch space with type `{type(space)}`. 
The space must be a valid `gym.Space` instance." @@ -126,41 +121,35 @@ def _batch_space_custom(space, n=1): @singledispatch -def iterate(space, items): +def iterate(space: Space, items) -> Iterator: """Iterate over the elements of a (batched) space. - Parameters - ---------- - space : `gym.spaces.Space` instance - Space to which `items` belong to. + Example:: - items : samples of `space` - Items to be iterated over. + >>> from gym.spaces import Box, Dict + >>> space = Dict({ + ... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32), + ... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)}) + >>> items = space.sample() + >>> it = iterate(space, items) + >>> next(it) + {'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32), + 'velocity': array([0.35848552, 0.1533453 ], dtype=float32)} + >>> next(it) + {'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32), + 'velocity': array([0.7975036 , 0.93317133], dtype=float32)} + >>> next(it) + StopIteration - Returns - ------- - iterator : `Iterable` instance + Args: + space: Space to which `items` belong to. + items: Items to be iterated over. + + Returns: Iterator over the elements in `items`. - - Example - ------- - >>> from gym.spaces import Box, Dict - >>> space = Dict({ - ... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32), - ... 
'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)}) - >>> items = space.sample() - >>> it = iterate(space, items) - >>> next(it) - {'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32), - 'velocity': array([0.35848552, 0.1533453 ], dtype=float32)} - >>> next(it) - {'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32), - 'velocity': array([0.7975036 , 0.93317133], dtype=float32)} - >>> next(it) - StopIteration """ raise ValueError( - "Space of type `{}` is not a valid `gym.Space` " "instance.".format(type(space)) + f"Space of type `{type(space)}` is not a valid `gym.Space` " "instance." ) diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index a404e76f4..96fe4495f 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Union +"""Base class for vectorized environments.""" +from __future__ import annotations + +from typing import Any, Optional, Union import gym from gym.logger import deprecation @@ -8,32 +11,28 @@ __all__ = ["VectorEnv"] class VectorEnv(gym.Env): - r"""Base class for vectorized environments. Runs multiple independent copies of the - same environment in parallel. This is not the same as 1 environment that has multiple - sub components, but it is many copies of the same base env. + """Base class for vectorized environments. Runs multiple independent copies of the same environment in parallel. - Each observation returned from vectorized environment is a batch of observations - for each parallel environment. And :meth:`step` is also expected to receive a batch of - actions for each parallel environment. + This is not the same as 1 environment that has multiple subcomponents, but it is many copies of the same base env. - .. note:: + Each observation returned from vectorized environment is a batch of observations for each parallel environment. 
+ And :meth:`step` is also expected to receive a batch of actions for each parallel environment. + Notes: All parallel environments should share the identical observation and action spaces. In other words, a vector of multiple different environments is not supported. - - Parameters - ---------- - num_envs : int - Number of environments in the vectorized environment. - - observation_space : :class:`gym.spaces.Space` - Observation space of a single environment. - - action_space : :class:`gym.spaces.Space` - Action space of a single environment. """ - def __init__(self, num_envs, observation_space, action_space): + def __init__( + self, num_envs: int, observation_space: gym.Space, action_space: gym.Space + ): + """Base class for vectorized environments. + + Args: + num_envs: Number of environments in the vectorized environment. + observation_space: Observation space of a single environment. + action_space: Action space of a single environment. + """ self.num_envs = num_envs self.is_vector_env = True self.observation_space = batch_space(observation_space, n=num_envs) @@ -49,141 +48,134 @@ class VectorEnv(gym.Env): def reset_async( self, - seed: Optional[Union[int, List[int]]] = None, + seed: Optional[Union[int, list[int]]] = None, return_info: bool = False, options: Optional[dict] = None, ): + """Reset the sub-environments asynchronously. + + This method will return ``None``. A call to :meth:`reset_async` should be followed by a call to :meth:`reset_wait` to retrieve the results. + """ pass def reset_wait( self, - seed: Optional[Union[int, List[int]]] = None, + seed: Optional[Union[int, list[int]]] = None, return_info: bool = False, options: Optional[dict] = None, ): + """Retrieves the results of a :meth:`reset_async` call. + + A call to this method must always be preceded by a call to :meth:`reset_async`. 
+ """ raise NotImplementedError() def reset( self, *, - seed: Optional[Union[int, List[int]]] = None, + seed: Optional[Union[int, list[int]]] = None, return_info: bool = False, options: Optional[dict] = None, ): - r"""Reset all parallel environments and return a batch of initial observations. + """Reset all parallel environments and return a batch of initial observations. - Returns - ------- - observations : element of :attr:`observation_space` + Args: + seed: The environment reset seeds + return_info: If to return the info + options: If to return the options + + Returns: A batch of observations from the vectorized environment. """ self.reset_async(seed=seed, return_info=return_info, options=options) return self.reset_wait(seed=seed, return_info=return_info, options=options) def step_async(self, actions): + """Asynchronously performs steps in the sub-environments. + + The results can be retrieved via a call to :meth:`step_wait`. + """ pass def step_wait(self, **kwargs): + """Retrieves the results of a :meth:`step_async` call. + + A call to this method must always be preceded by a call to :meth:`step_async`. + """ raise NotImplementedError() def step(self, actions): - r"""Take an action for each parallel environment. + """Take an action for each parallel environment. - Parameters - ---------- - actions : element of :attr:`action_space` - Batch of actions. + Args: + actions: element of :attr:`action_space` Batch of actions. - Returns - ------- - observations : element of :attr:`observation_space` - A batch of observations from the vectorized environment. - - rewards : :obj:`np.ndarray`, dtype :obj:`np.float_` - A vector of rewards from the vectorized environment. - - dones : :obj:`np.ndarray`, dtype :obj:`np.bool_` - A vector whose entries indicate whether the episode has ended. - - infos : list of dict - A list of auxiliary diagnostic information dicts from each parallel environment. 
+ Returns: + Batch of observations, rewards, done and infos """ - self.step_async(actions) return self.step_wait() def call_async(self, name, *args, **kwargs): + """Calls a method name for each parallel environment asynchronously.""" pass def call_wait(self, **kwargs): + """After calling a method in :meth:`call_async`, this function collects the results.""" raise NotImplementedError() - def call(self, name, *args, **kwargs): + def call(self, name: str, *args, **kwargs) -> list[Any]: """Call a method, or get a property, from each parallel environment. - Parameters - ---------- - name : string - Name of the method or property to call. + Args: + name (str): Name of the method or property to call. + *args: Arguments to apply to the method call. + **kwargs: Keyword arguments to apply to the method call. - *args - Arguments to apply to the method call. - - **kwargs - Keywoard arguments to apply to the method call. - - Returns - ------- - results : list - List of the results of the individual calls to the method or - property for each environment. + Returns: + List of the results of the individual calls to the method or property for each environment. """ self.call_async(name, *args, **kwargs) return self.call_wait() - def get_attr(self, name): + def get_attr(self, name: str): """Get a property from each parallel environment. - Parameters - ---------- - name : string - Name of the property to be get from each individual environment. + Args: + name (str): Name of the property to be get from each individual environment. + + Returns: + The property with name """ return self.call(name) - def set_attr(self, name, values): - """Set a property in each parallel environment. + def set_attr(self, name: str, values: Union[list, tuple, object]): + """Set a property in each sub-environment. - Parameters - ---------- - name : string - Name of the property to be set in each individual environment. - - values : list, tuple, or object - Values of the property to be set to. 
If `values` is a list or - tuple, then it corresponds to the values for each individual - environment, otherwise a single value is set for all environments. + Args: + name (str): Name of the property to be set in each individual environment. + values (list, tuple, or object): Values of the property to be set to. If `values` is a list or + tuple, then it corresponds to the values for each individual environment, otherwise a single value + is set for all environments. """ raise NotImplementedError() def close_extras(self, **kwargs): - r"""Clean up the extra resources e.g. beyond what's in this base class.""" + """Clean up the extra resources e.g. beyond what's in this base class.""" pass def close(self, **kwargs): - r"""Close all parallel environments and release resources. + """Close all parallel environments and release resources. It also closes all the existing image viewers, then calls :meth:`close_extras` and set :attr:`closed` as ``True``. - .. warning:: - + Warnings: This function itself does not close the environments, it should be handled in :meth:`close_extras`. This is generic for both synchronous and asynchronous vectorized environments. - .. note:: - + Notes: This will be automatically called when garbage collected or program exited. """ @@ -197,14 +189,12 @@ class VectorEnv(gym.Env): def seed(self, seed=None): """Set the random seed in all parallel environments. - Parameters - ---------- - seed : list of int, or int, optional - Random seed for each parallel environment. If ``seed`` is a list of - length ``num_envs``, then the items of the list are chosen as random - seeds. If ``seed`` is an int, then each parallel environment uses the random - seed ``seed + n``, where ``n`` is the index of the parallel environment - (between ``0`` and ``num_envs - 1``). + Args: + seed: Random seed for each parallel environment. If ``seed`` is a list of + length ``num_envs``, then the items of the list are chosen as random + seeds. 
If ``seed`` is an int, then each parallel environment uses the random + seed ``seed + n``, where ``n`` is the index of the parallel environment + (between ``0`` and ``num_envs - 1``). """ deprecation( "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. " @@ -212,10 +202,12 @@ class VectorEnv(gym.Env): ) def __del__(self): + """Closes the vector environment.""" if not getattr(self, "closed", True): self.close() def __repr__(self): + """Returns a string representation of the vector environment using the class name, number of environments and environment spec id.""" if self.spec is None: return f"{self.__class__.__name__}({self.num_envs})" else: @@ -223,19 +215,17 @@ class VectorEnv(gym.Env): class VectorEnvWrapper(VectorEnv): - r"""Wraps the vectorized environment to allow a modular transformation. + """Wraps the vectorized environment to allow a modular transformation. This class is the base class for all wrappers for vectorized environments. The subclass could override some methods to change the behavior of the original vectorized environment without touching the original code. - .. note:: - + Notes: Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. - """ - def __init__(self, env): + def __init__(self, env: VectorEnv): assert isinstance(env, VectorEnv) self.env = env diff --git a/setup.py b/setup.py index fb44e5d4d..7cd9d6321 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ extras = { "classic_control": ["pygame==2.1.0"], "mujoco": ["mujoco_py>=1.50, <2.0"], "toy_text": ["pygame==2.1.0", "scipy>=1.4.1"], - "other": ["lz4>=3.1.0", "opencv-python>=3.0"], + "other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0"], } # Meta dependency groups. 
diff --git a/tests/wrappers/flatten_test.py b/tests/wrappers/test_flatten.py similarity index 100% rename from tests/wrappers/flatten_test.py rename to tests/wrappers/test_flatten.py diff --git a/tests/wrappers/nested_dict_test.py b/tests/wrappers/test_nested_dict.py similarity index 100% rename from tests/wrappers/nested_dict_test.py rename to tests/wrappers/test_nested_dict.py