diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9b440c4e..3c113707a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: rev: 6.1.1 hooks: - id: pydocstyle - exclude: ^(gymnasium/envs/)|(tests/)|(docs/) + exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/) args: - --source - --explain diff --git a/docs/api/experimental.md b/docs/api/experimental.md index 59e0e6255..764940e38 100644 --- a/docs/api/experimental.md +++ b/docs/api/experimental.md @@ -14,7 +14,9 @@ experimental/vector_wrappers ## Functional Environments +```{eval-rst} The gymnasium ``Env`` provides high flexibility for the implementation of individual environments however this can complicate parallelism of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv` where each part of environment has its own function related to it. 
+``` ## Wrappers @@ -36,64 +38,32 @@ Gymnasium already contains a large collection of wrappers, but we believe that t * - Old name - New name - - Vector version - - Tree structure * - :class:`wrappers.TransformObservation` - :class:`experimental.wrappers.LambdaObservationV0` - - VectorLambdaObservation - - No * - :class:`wrappers.FilterObservation` - :class:`experimental.wrappers.FilterObservationV0` - - VectorFilterObservation (*) - - Yes * - :class:`wrappers.FlattenObservation` - :class:`experimental.wrappers.FlattenObservationV0` - - VectorFlattenObservation (*) - - No * - :class:`wrappers.GrayScaleObservation` - :class:`experimental.wrappers.GrayscaleObservationV0` - - VectorGrayscaleObservation (*) - - Yes * - :class:`wrappers.ResizeObservation` - :class:`experimental.wrappers.ResizeObservationV0` - - VectorResizeObservation (*) - - Yes - * - Not Implemented + * - ``supersuit.reshape_v0`` - :class:`experimental.wrappers.ReshapeObservationV0` - - VectorReshapeObservation (*) - - Yes * - Not Implemented - :class:`experimental.wrappers.RescaleObservationV0` - - VectorRescaleObservation (*) - - Yes - * - Not Implemented + * - ``supersuit.dtype_v0`` - :class:`experimental.wrappers.DtypeObservationV0` - - VectorDtypeObservation (*) - - Yes * - :class:`wrappers.PixelObservationWrapper` - - PixelObservation - - VectorPixelObservation - - No + - :class:`experimental.wrappers.PixelObservationV0` * - :class:`wrappers.NormalizeObservation` - - NormalizeObservation - - VectorNormalizeObservation - - No + - :class:`experimental.wrappers.NormalizeObservationV0` * - :class:`wrappers.TimeAwareObservation` - :class:`experimental.wrappers.TimeAwareObservationV0` - - VectorTimeAwareObservation - - No * - :class:`wrappers.FrameStack` - - FrameStackObservation - - VectorFrameStackObservation - - No - * - Not Implemented + - :class:`experimental.wrappers.FrameStackObservationV0` + * - ``supersuit.delay_observations_v0`` - :class:`experimental.wrappers.DelayObservationV0` - - 
VectorDelayObservation - - No - * - :class:`wrappers.AtariPreprocessing` - - AtariPreprocessing - - Not Implemented - - No ``` ### Action Wrappers @@ -105,24 +75,14 @@ Gymnasium already contains a large collection of wrappers, but we believe that t * - Old name - New name - - Vector version - - Tree structure - * - Not Implemented + * - ``supersuit.action_lambda_v1`` - :class:`experimental.wrappers.LambdaActionV0` - - VectorLambdaAction - - No * - :class:`wrappers.ClipAction` - :class:`experimental.wrappers.ClipActionV0` - - VectorClipAction (*) - - Yes * - :class:`wrappers.RescaleAction` - :class:`experimental.wrappers.RescaleActionV0` - - VectorRescaleAction (*) - - Yes - * - Not Implemented + * - ``supersuit.sticky_actions_v0`` - :class:`experimental.wrappers.StickyActionV0` - - VectorStickyAction - - No ``` ### Reward Wrappers @@ -134,19 +94,12 @@ Gymnasium already contains a large collection of wrappers, but we believe that t * - Old name - New name - - Vector version * - :class:`wrappers.TransformReward` - :class:`experimental.wrappers.LambdaRewardV0` - - VectorLambdaReward - * - Not Implemented + * - ``supersuit.clip_reward_v0`` - :class:`experimental.wrappers.ClipRewardV0` - - VectorClipReward (*) - * - Not Implemented - - RescaleReward - - VectorRescaleReward (*) * - :class:`wrappers.NormalizeReward` - - NormalizeReward - - VectorNormalizeReward + - :class:`experimental.wrappers.NormalizeRewardV0` ``` ### Common Wrappers @@ -159,37 +112,21 @@ Gymnasium already contains a large collection of wrappers, but we believe that t * - Old name - New name - - Vector version * - :class:`wrappers.AutoResetWrapper` - - AutoReset - - VectorAutoReset + - :class:`experimental.wrappers.AutoresetV0` * - :class:`wrappers.PassiveEnvChecker` - - PassiveEnvChecker - - VectorPassiveEnvChecker + - :class:`experimental.wrappers.PassiveEnvCheckerV0` * - :class:`wrappers.OrderEnforcing` - - OrderEnforcing - - VectorOrderEnforcing + - :class:`experimental.wrappers.OrderEnforcingV0` * 
- :class:`wrappers.EnvCompatibility` - Moved to `shimmy `_ - - Not Implemented * - :class:`wrappers.RecordEpisodeStatistics` - - RecordEpisodeStatistics - - VectorRecordEpisodeStatistics - * - :class:`wrappers.RenderCollection` - - RenderCollection - - VectorRenderCollection - * - :class:`wrappers.HumanRendering` - - HumanRendering - - Not Implemented - * - Not Implemented - - :class:`experimental.wrappers.JaxToNumpyV0` - - VectorJaxToNumpy (*) - * - Not Implemented - - :class:`experimental.wrappers.JaxToTorchV0` - - VectorJaxToTorch (*) + - :class:`experimental.wrappers.RecordEpisodeStatisticsV0` + * - :class:`wrappers.AtariPreprocessing` + - :class:`experimental.wrappers.AtariPreprocessingV0` ``` -### Vector Only Wrappers +### Rendering Wrappers ```{eval-rst} .. py:currentmodule:: gymnasium @@ -199,8 +136,22 @@ Gymnasium already contains a large collection of wrappers, but we believe that t * - Old name - New name - * - :class:`wrappers.VectorListInfo` - - VectorListInfo + * - :class:`wrappers.RecordVideo` + - :class:`experimental.wrappers.RecordVideoV0` + * - :class:`wrappers.HumanRendering` + - :class:`experimental.wrappers.HumanRenderingV0` + * - :class:`wrappers.RenderCollection` + - :class:`experimental.wrappers.RenderCollectionV0` +``` + +### Environment data conversion + +```{eval-rst} +.. py:currentmodule:: gymnasium + +* :class:`experimental.wrappers.JaxToNumpyV0` +* :class:`experimental.wrappers.JaxToTorchV0` +* :class:`experimental.wrappers.NumpyToTorchV0` ``` ## Vector Environment diff --git a/docs/api/experimental/wrappers.md b/docs/api/experimental/wrappers.md index 7acb7a276..f3c2e9e18 100644 --- a/docs/api/experimental/wrappers.md +++ b/docs/api/experimental/wrappers.md @@ -11,8 +11,12 @@ .. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0 .. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0 .. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0 +..
autoclass:: gymnasium.experimental.wrappers.PixelObservationV0 +.. autoclass:: gymnasium.experimental.wrappers.NormalizeObservationV0 .. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0 +.. autoclass:: gymnasium.experimental.wrappers.FrameStackObservationV0 .. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0 +.. autoclass:: gymnasium.experimental.wrappers.AtariPreprocessingV0 ``` ## Action Wrappers @@ -24,16 +28,34 @@ .. autoclass:: gymnasium.experimental.wrappers.StickyActionV0 ``` -# Reward Wrappers +## Reward Wrappers ```{eval-rst} .. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0 .. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 +.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV0 ``` -## Common Wrappers +## Other Wrappers +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.AutoresetV0 +.. autoclass:: gymnasium.experimental.wrappers.PassiveEnvCheckerV0 +.. autoclass:: gymnasium.experimental.wrappers.OrderEnforcingV0 +.. autoclass:: gymnasium.experimental.wrappers.RecordEpisodeStatisticsV0 +``` + +## Rendering Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.RecordVideoV0 +.. autoclass:: gymnasium.experimental.wrappers.HumanRenderingV0 +.. autoclass:: gymnasium.experimental.wrappers.RenderCollectionV0 +``` + +## Environment data conversion ```{eval-rst} .. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0 .. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0 +.. 
autoclass:: gymnasium.experimental.wrappers.NumpyToTorchV0 ``` diff --git a/gymnasium/core.py b/gymnasium/core.py index e15137e07..d5ad5f1b4 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -356,7 +356,7 @@ class Wrapper(Env[WrapperObsType, WrapperActType]): def step( self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.""" return self.env.step(action) diff --git a/gymnasium/envs/phys2d/__init__.py b/gymnasium/envs/phys2d/__init__.py index 8ff4b205c..d85a43918 100644 --- a/gymnasium/envs/phys2d/__init__.py +++ b/gymnasium/envs/phys2d/__init__.py @@ -1,2 +1,3 @@ +"""Module for 2d physics environments with functional and environment implementations.""" from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index 7e03ef4e0..dfe4d3a70 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -1,8 +1,7 @@ -""" -Implementation of a Jax-accelerated cartpole environment. 
-""" +"""Implementation of a Jax-accelerated cartpole environment.""" +from __future__ import annotations -from typing import Optional, Tuple, Union +from typing import Any, Tuple import jax import jax.numpy as jnp @@ -74,7 +73,7 @@ class CartPoleFunctional( ) def transition( - self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None + self, state: jnp.ndarray, action: int | jnp.ndarray, rng: None = None ) -> StateType: """Cartpole transition.""" x, x_dot, theta, theta_dot = state @@ -106,6 +105,7 @@ class CartPoleFunctional( return state def terminal(self, state: jnp.ndarray) -> jnp.ndarray: + """Checks if the state is terminal.""" x, _, theta, _ = state terminated = ( @@ -120,6 +120,7 @@ class CartPoleFunctional( def reward( self, state: StateType, action: ActType, next_state: StateType ) -> jnp.ndarray: + """Computes the reward for the state transition using the action.""" x, _, theta, _ = state terminated = ( @@ -136,8 +137,8 @@ class CartPoleFunctional( self, state: StateType, render_state: RenderStateType, - ) -> Tuple[RenderStateType, np.ndarray]: - + ) -> tuple[RenderStateType, np.ndarray]: + """Renders an image of the state using the render state.""" try: import pygame from pygame import gfxdraw @@ -210,6 +211,7 @@ class CartPoleFunctional( def render_init( self, screen_width: int = 600, screen_height: int = 400 ) -> RenderStateType: + """Initialises the render state for a screen width and height.""" try: import pygame except ImportError as e: @@ -224,6 +226,7 @@ class CartPoleFunctional( return screen, clock def render_close(self, render_state: RenderStateType) -> None: + """Closes the render state.""" try: import pygame except ImportError as e: @@ -235,20 +238,24 @@ class CartPoleFunctional( class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle): + """Jax-based implementation of the CartPole environment.""" metadata = {"render_modes": ["rgb_array"], "render_fps": 50} - def __init__(self, render_mode: Optional[str] = None, **kwargs): + def 
__init__(self, render_mode: str | None = None, **kwargs: Any): + """Constructor for the CartPole where the kwargs are applied to the functional environment.""" EzPickle.__init__(self, render_mode=render_mode, **kwargs) + env = CartPoleFunctional(**kwargs) env.transform(jax.jit) + action_space = env.action_space observation_space = env.observation_space - metadata = {"render_modes": ["rgb_array"], "render_fps": 50} + super().__init__( env, observation_space=observation_space, action_space=action_space, - metadata=metadata, + metadata=self.metadata, render_mode=render_mode, ) diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index d530b4313..9a1f1f8db 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -1,8 +1,8 @@ -""" -Implementation of a Jax-accelerated pendulum environment. -""" +"""Implementation of a Jax-accelerated pendulum environment.""" +from __future__ import annotations + from os import path -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple import jax import jax.numpy as jnp @@ -22,7 +22,7 @@ RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] class PendulumFunctional( FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType] ): - """Pendulum but in jax and functional.""" + """Pendulum but in jax and functional structure.""" max_speed = 8 max_torque = 2.0 @@ -44,7 +44,7 @@ class PendulumFunctional( return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape) def transition( - self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None + self, state: jnp.ndarray, action: int | jnp.ndarray, rng: None = None ) -> jnp.ndarray: """Pendulum transition.""" th, thdot = state # th := theta @@ -65,10 +65,12 @@ class PendulumFunctional( return new_state def observation(self, state: jnp.ndarray) -> jnp.ndarray: + """Generates an observation based on the state.""" theta, thetadot = state 
return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot]) def reward(self, state: StateType, action: ActType, next_state: StateType) -> float: + """Generates the reward based on the state, action and next state.""" th, thdot = state # th := theta u = action @@ -80,13 +82,15 @@ class PendulumFunctional( return -costs def terminal(self, state: StateType) -> bool: + """Determines if the state is a terminal state.""" return False def render_image( self, state: StateType, - render_state: Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]], # type: ignore # noqa: F821 - ) -> Tuple[RenderStateType, np.ndarray]: + render_state: tuple[pygame.Surface, pygame.time.Clock, float | None], # type: ignore # noqa: F821 + ) -> tuple[RenderStateType, np.ndarray]: + """Renders an RGB image.""" try: import pygame from pygame import gfxdraw @@ -159,6 +163,7 @@ class PendulumFunctional( def render_init( self, screen_width: int = 600, screen_height: int = 400 ) -> RenderStateType: + """Initialises the render state.""" try: import pygame except ImportError as e: @@ -172,7 +177,8 @@ class PendulumFunctional( return screen, clock, None - def render_close(self, render_state: RenderStateType) -> None: + def render_close(self, render_state: RenderStateType): + """Closes the render state.""" try: import pygame except ImportError as e: @@ -184,21 +190,24 @@ class PendulumFunctional( class PendulumJaxEnv(FunctionalJaxEnv, EzPickle): + """Jax-based pendulum environment using the functional version as base.""" metadata = {"render_modes": ["rgb_array"], "render_fps": 30} - def __init__(self, render_mode: Optional[str] = None, **kwargs): + def __init__(self, render_mode: str | None = None, **kwargs: Any): + """Constructor where the kwargs are passed to the base environment to modify the parameters.""" EzPickle.__init__(self, render_mode=render_mode, **kwargs) + env = PendulumFunctional(**kwargs) env.transform(jax.jit) + action_space = env.action_space observation_space = 
env.observation_space - metadata = {"render_modes": ["rgb_array"], "render_fps": 30} super().__init__( env, observation_space=observation_space, action_space=action_space, - metadata=metadata, + metadata=self.metadata, render_mode=render_mode, ) diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index ceaed04e8..a4a995e20 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -1,3 +1,6 @@ +"""Functions for registering environments within gymnasium using public functions ``make``, ``register`` and ``spec``.""" +from __future__ import annotations + import contextlib import copy import difflib @@ -8,18 +11,7 @@ import sys import warnings from collections import defaultdict from dataclasses import dataclass, field -from typing import ( - Callable, - Dict, - Iterable, - List, - Optional, - Sequence, - SupportsFloat, - Tuple, - Union, - overload, -) +from typing import Any, Callable, Iterable, Sequence, SupportsFloat, overload import numpy as np @@ -53,7 +45,7 @@ ENV_ID_RE = re.compile( def load(name: str) -> callable: - """Loads an environment with name and returns an environment creation function + """Loads an environment with name and returns an environment creation function. Args: name: The environment name @@ -67,7 +59,7 @@ def load(name: str) -> callable: return fn -def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]: +def parse_env_id(id: str) -> tuple[str | None, str, int | None]: """Parse environment ID string format. This format is true today, but it's *not* an official spec. @@ -98,7 +90,7 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]: return namespace, name, version -def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str: +def get_env_id(ns: str | None, name: str, version: int | None) -> str: """Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`. 
Args: @@ -109,7 +101,6 @@ def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str: Returns: The environment id """ - full_name = name if version is not None: full_name += f"-v{version}" @@ -134,14 +125,14 @@ class EnvSpec: """ id: str - entry_point: Union[Callable, str] + entry_point: Callable | str # Environment attributes - reward_threshold: Optional[float] = field(default=None) + reward_threshold: float | None = field(default=None) nondeterministic: bool = field(default=False) # Wrappers - max_episode_steps: Optional[int] = field(default=None) + max_episode_steps: int | None = field(default=None) order_enforce: bool = field(default=True) autoreset: bool = field(default=False) disable_env_checker: bool = field(default=False) @@ -151,20 +142,22 @@ class EnvSpec: kwargs: dict = field(default_factory=dict) # post-init attributes - namespace: Optional[str] = field(init=False) + namespace: str | None = field(init=False) name: str = field(init=False) - version: Optional[int] = field(init=False) + version: int | None = field(init=False) def __post_init__(self): + """Calls after the spec is created to extract the namespace, name and version from the id.""" # Initialize namespace, name, version self.namespace, self.name, self.version = parse_env_id(self.id) - def make(self, **kwargs) -> Env: + def make(self, **kwargs: Any) -> Env: + """Calls ``make`` using the environment spec and any keyword arguments.""" # For compatibility purposes return make(self, **kwargs) -def _check_namespace_exists(ns: Optional[str]): +def _check_namespace_exists(ns: str | None): """Check if a namespace exists. If it doesn't, print a helpful error message.""" if ns is None: return @@ -186,7 +179,7 @@ def _check_namespace_exists(ns: Optional[str]): raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}") -def _check_name_exists(ns: Optional[str], name: str): +def _check_name_exists(ns: str | None, name: str): """Check if an env exists in a namespace. 
If it doesn't, print a helpful error message.""" _check_namespace_exists(ns) names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns} @@ -203,8 +196,9 @@ def _check_name_exists(ns: Optional[str], name: str): ) -def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]): +def _check_version_exists(ns: str | None, name: str, version: int | None): """Check if an env version exists in a namespace. If it doesn't, print a helpful error message. + This is a complete test whether an environment identifier is valid, and will provide the best available hints. Args: @@ -258,8 +252,9 @@ def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]): ) -def find_highest_version(ns: Optional[str], name: str) -> Optional[int]: - version: List[int] = [ +def find_highest_version(ns: str | None, name: str) -> int | None: + """Finds the highest registered version of the environment in the registry.""" + version: list[int] = [ spec_.version for spec_ in registry.values() if spec_.namespace == ns and spec_.name == name and spec_.version is not None @@ -268,6 +263,11 @@ def find_highest_version(ns: Optional[str], name: str) -> Optional[int]: def load_env_plugins(entry_point: str = "gymnasium.envs") -> None: + """Load modules (plugins) registered under the entry-point group named by ``entry_point``. + + Args: + entry_point: The name of the entry-point group to load plugins from. + """ # Load third-party environments for plugin in metadata.entry_points(group=entry_point): # Python 3.8 doesn't support plugin.module, plugin.attr @@ -323,37 +323,37 @@ def make(id: EnvSpec, **kwargs) -> Env: ... # Classic control # ---------------------------------------- @overload -def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload -def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... @overload -def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ... +def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ... @overload -def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ... +def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ... @overload -def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... # Box2d # ---------------------------------------- @overload -def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... @overload -def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ... +def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ... @overload -def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ... +def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ... # Toy Text # ---------------------------------------- @overload -def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... 
@overload -def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... @overload -def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... @overload -def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ... +def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... # Mujoco @@ -376,8 +376,8 @@ def make(id: Literal[ # Global registry of environments. Meant to be accessed through `register` and `make` -registry: Dict[str, EnvSpec] = {} -current_namespace: Optional[str] = None +registry: dict[str, EnvSpec] = {} +current_namespace: str | None = None def _check_spec_register(spec: EnvSpec): @@ -445,6 +445,7 @@ def _check_metadata(metadata_: dict): @contextlib.contextmanager def namespace(ns: str): + """Context manager for modifying the current namespace.""" global current_namespace old_namespace = current_namespace current_namespace = ns @@ -454,10 +455,10 @@ def namespace(ns: str): def register( id: str, - entry_point: Union[Callable, str], - reward_threshold: Optional[float] = None, + entry_point: Callable | str, + reward_threshold: float | None = None, nondeterministic: bool = False, - max_episode_steps: Optional[int] = None, + max_episode_steps: int | None = None, order_enforce: bool = True, autoreset: bool = False, disable_env_checker: bool = False, @@ -521,11 +522,11 @@ def register( def make( - id: Union[str, EnvSpec], - max_episode_steps: Optional[int] = None, + id: str | EnvSpec, + max_episode_steps: int | None = None, autoreset: bool = False, - apply_api_compatibility: Optional[bool] = None, - disable_env_checker: Optional[bool] = None, + apply_api_compatibility: bool | None = None, + 
disable_env_checker: bool | None = None, **kwargs, ) -> Env: """Create an environment according to the given ID. @@ -706,9 +707,9 @@ def spec(env_id: str) -> EnvSpec: def pprint_registry( _registry: dict = registry, num_cols: int = 3, - exclude_namespaces: Optional[List[str]] = None, + exclude_namespaces: list[str] | None = None, disable_print: bool = False, -) -> Optional[str]: +) -> str | None: """Pretty print the environments in the registry. Args: @@ -718,7 +719,6 @@ def pprint_registry( disable_print: Whether to return a string of all the namespaces and environment IDs instead of printing it to console. """ - # Defaultdict to store environment names according to namespace. namespace_envs = defaultdict(lambda: []) max_justify = float("-inf") diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 3a2a52b61..483952d60 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -19,14 +19,34 @@ from gymnasium.experimental.wrappers.lambda_observations import ( ReshapeObservationV0, RescaleObservationV0, DtypeObservationV0, + PixelObservationV0, + NormalizeObservationV0, ) -from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 -from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0 -from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0 +from gymnasium.experimental.wrappers.lambda_reward import ( + ClipRewardV0, + LambdaRewardV0, + NormalizeRewardV0, +) +from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0 +from gymnasium.experimental.wrappers.jax_to_torch import JaxToTorchV0 +from gymnasium.experimental.wrappers.numpy_to_torch import NumpyToTorchV0 from gymnasium.experimental.wrappers.stateful_action import StickyActionV0 from gymnasium.experimental.wrappers.stateful_observation import ( TimeAwareObservationV0, DelayObservationV0, + FrameStackObservationV0, +) +from 
gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0 +from gymnasium.experimental.wrappers.common import ( + PassiveEnvCheckerV0, + OrderEnforcingV0, + AutoresetV0, + RecordEpisodeStatisticsV0, +) +from gymnasium.experimental.wrappers.rendering import ( + RenderCollectionV0, + RecordVideoV0, + HumanRenderingV0, ) __all__ = [ @@ -39,12 +59,12 @@ __all__ = [ "ReshapeObservationV0", "RescaleObservationV0", "DtypeObservationV0", - # "PixelObservationV0", - # "NormalizeObservationV0", + "PixelObservationV0", + "NormalizeObservationV0", "TimeAwareObservationV0", - # "FrameStackV0", + "FrameStackObservationV0", "DelayObservationV0", - # "AtariPreprocessingV0" + "AtariPreprocessingV0", # --- Action Wrappers --- "LambdaActionV0", "ClipActionV0", @@ -54,15 +74,18 @@ __all__ = [ # --- Reward wrappers --- "LambdaRewardV0", "ClipRewardV0", - # "RescaleRewardV0", - # "NormalizeRewardV0", + "NormalizeRewardV0", # --- Common --- - # "AutoReset", - # "PassiveEnvChecker", - # "OrderEnforcing", - # "RecordEpisodeStatistics", - # "RenderCollection", - # "HumanRendering", + "AutoresetV0", + "PassiveEnvCheckerV0", + "OrderEnforcingV0", + "RecordEpisodeStatisticsV0", + # --- Rendering --- + "RenderCollectionV0", + "RecordVideoV0", + "HumanRenderingV0", + # --- Data Conversion --- "JaxToNumpyV0", "JaxToTorchV0", + "NumpyToTorchV0", ] diff --git a/gymnasium/experimental/wrappers/atari_preprocessing.py b/gymnasium/experimental/wrappers/atari_preprocessing.py new file mode 100644 index 000000000..7dfeaf1a8 --- /dev/null +++ b/gymnasium/experimental/wrappers/atari_preprocessing.py @@ -0,0 +1,193 @@ +"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018.""" +import numpy as np + +import gymnasium as gym +from gymnasium.spaces import Box + + +try: + import cv2 +except ImportError: + cv2 = None + + +class AtariPreprocessingV0(gym.Wrapper): + """Atari 2600 preprocessing wrapper. 
+ + This class follows the guidelines in Machado et al. (2018), + "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents". + + Specifically, the following preprocessing stages apply to the Atari environment: + + - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops. + - Frame skipping: The number of frames skipped between steps, 4 by default + - Max-pooling: Pools over the most recent two observations from the frame skips + - Termination signal when a life is lost: When the agent loses a life, the environment is terminated. + Turned off by default. Not recommended by Machado et al. (2018). + - Resize to a square image: Resizes the Atari environment's original observation shape from 210x160 to 84x84 by default + - Grayscale observation: Whether the observation is colour or greyscale, by default, greyscale. + - Scale observation: Whether to scale the observation to [0, 1) or leave it in [0, 255), by default, not scaled. + """ + + def __init__( + self, + env: gym.Env, + noop_max: int = 30, + frame_skip: int = 4, + screen_size: int = 84, + terminal_on_life_loss: bool = False, + grayscale_obs: bool = True, + grayscale_newaxis: bool = False, + scale_obs: bool = False, + ): + """Wrapper for Atari 2600 preprocessing. + + Args: + env (Env): The environment to apply the preprocessing + noop_max (int): For No-op reset, the maximum number of no-op actions taken at reset; set to 0 to turn off. + frame_skip (int): The number of frames between new observations, affecting the frequency at which the agent experiences the game. + screen_size (int): resize Atari frame + terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a + life is lost. + grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation + is returned.
+ grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to + grayscale observations to make them 3-dimensional. + scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory + optimization benefits of FrameStack Wrapper. + + Raises: + DependencyNotInstalled: opencv-python package not installed + ValueError: Disable frame-skipping in the original env + """ + super().__init__(env) + if cv2 is None: + raise gym.error.DependencyNotInstalled( + "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari" + ) + assert frame_skip > 0 + assert screen_size > 0 + assert noop_max >= 0 + if frame_skip > 1: + if ( + env.spec is not None + and "NoFrameskip" not in env.spec.id + and getattr(env.unwrapped, "_frameskip", None) != 1 + ): + raise ValueError( + "Disable frame-skipping in the original env. Otherwise, more than one " + "frame-skip will happen as through this wrapper" + ) + self.noop_max = noop_max + assert env.unwrapped.get_action_meanings()[0] == "NOOP" + + self.frame_skip = frame_skip + self.screen_size = screen_size + self.terminal_on_life_loss = terminal_on_life_loss + self.grayscale_obs = grayscale_obs + self.grayscale_newaxis = grayscale_newaxis + self.scale_obs = scale_obs + + # buffer of most recent two observations for max pooling + assert isinstance(env.observation_space, Box) + if grayscale_obs: + self.obs_buffer = [ + np.empty(env.observation_space.shape[:2], dtype=np.uint8), + np.empty(env.observation_space.shape[:2], dtype=np.uint8), + ] + else: + self.obs_buffer = [ + np.empty(env.observation_space.shape, dtype=np.uint8), + np.empty(env.observation_space.shape, dtype=np.uint8), + ] + + self.lives = 0 + self.game_over = False + + _low, _high, _obs_dtype = ( + (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32) + ) + _shape = (screen_size, screen_size, 1 if grayscale_obs else 3) + if grayscale_obs and not grayscale_newaxis: + 
_shape = _shape[:-1] # Remove channel axis + self.observation_space = Box( + low=_low, high=_high, shape=_shape, dtype=_obs_dtype + ) + + @property + def ale(self): + """Make ale as a class property to avoid serialization error.""" + return self.env.unwrapped.ale + + def step(self, action): + """Applies the preprocessing for an :meth:`env.step`.""" + total_reward, terminated, truncated, info = 0.0, False, False, {} + + for t in range(self.frame_skip): + _, reward, terminated, truncated, info = self.env.step(action) + total_reward += reward + self.game_over = terminated + + if self.terminal_on_life_loss: + new_lives = self.ale.lives() + terminated = terminated or new_lives < self.lives + self.game_over = terminated + self.lives = new_lives + + if terminated or truncated: + break + if t == self.frame_skip - 2: + if self.grayscale_obs: + self.ale.getScreenGrayscale(self.obs_buffer[1]) + else: + self.ale.getScreenRGB(self.obs_buffer[1]) + elif t == self.frame_skip - 1: + if self.grayscale_obs: + self.ale.getScreenGrayscale(self.obs_buffer[0]) + else: + self.ale.getScreenRGB(self.obs_buffer[0]) + return self._get_obs(), total_reward, terminated, truncated, info + + def reset(self, **kwargs): + """Resets the environment using preprocessing.""" + # NoopReset + _, reset_info = self.env.reset(**kwargs) + + noops = ( + self.env.unwrapped.np_random.integers(1, self.noop_max + 1) + if self.noop_max > 0 + else 0 + ) + for _ in range(noops): + _, _, terminated, truncated, step_info = self.env.step(0) + reset_info.update(step_info) + if terminated or truncated: + _, reset_info = self.env.reset(**kwargs) + + self.lives = self.ale.lives() + if self.grayscale_obs: + self.ale.getScreenGrayscale(self.obs_buffer[0]) + else: + self.ale.getScreenRGB(self.obs_buffer[0]) + self.obs_buffer[1].fill(0) + + return self._get_obs(), reset_info + + def _get_obs(self): + if self.frame_skip > 1: # more efficient in-place pooling + np.maximum(self.obs_buffer[0], self.obs_buffer[1], 
out=self.obs_buffer[0]) + assert cv2 is not None + obs = cv2.resize( + self.obs_buffer[0], + (self.screen_size, self.screen_size), + interpolation=cv2.INTER_AREA, + ) + + if self.scale_obs: + obs = np.asarray(obs, dtype=np.float32) / 255.0 + else: + obs = np.asarray(obs, dtype=np.uint8) + + if self.grayscale_obs and self.grayscale_newaxis: + obs = np.expand_dims(obs, axis=-1) # Add a channel axis + return obs diff --git a/gymnasium/experimental/wrappers/common.py b/gymnasium/experimental/wrappers/common.py new file mode 100644 index 000000000..214aff835 --- /dev/null +++ b/gymnasium/experimental/wrappers/common.py @@ -0,0 +1,281 @@ +"""A collection of common wrappers. + +* ``AutoresetV0`` - Auto-resets the environment +* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data +* ``OrderEnforcingV0`` - Enforces the order of function calls to environments +* ``RecordEpisodeStatisticsV0`` - Records the episode statistics +""" +from __future__ import annotations + +import time +from collections import deque +from typing import Any, SupportsFloat + +import numpy as np + +import gymnasium as gym +from gymnasium import Env +from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType +from gymnasium.error import ResetNeeded +from gymnasium.utils.passive_env_checker import ( + check_action_space, + check_observation_space, + env_render_passive_checker, + env_reset_passive_checker, + env_step_passive_checker, +) + + +class AutoresetV0(gym.Wrapper): + """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.""" + + def __init__(self, env: gym.Env): + """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. 
+ + Args: + env (gym.Env): The environment to apply the wrapper + """ + super().__init__(env) + self._episode_ended: bool = False + self._reset_options: dict[str, Any] | None = None + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step. + + Args: + action: The action to take + + Returns: + The autoreset environment :meth:`step` + """ + if self._episode_ended: + obs, info = super().reset(options=self._reset_options) + self._episode_ended = True + return obs, 0, False, False, info + else: + obs, reward, terminated, truncated, info = super().step(action) + self._episode_ended = terminated or truncated + return obs, reward, terminated, truncated, info + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment, saving the options used.""" + self._episode_ended = False + self._reset_options = options + return super().reset(seed=seed, options=self._reset_options) + + +class PassiveEnvCheckerV0(gym.Wrapper): + """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API.""" + + def __init__(self, env: Env[ObsType, ActType]): + """Initialises the wrapper with the environments, run the observation and action space tests.""" + super().__init__(env) + + assert hasattr( + env, "action_space" + ), "The environment must specify an action space. https://gymnasium.farama.org/content/environment_creation/" + check_action_space(env.action_space) + assert hasattr( + env, "observation_space" + ), "The environment must specify an observation space. 
https://gymnasium.farama.org/content/environment_creation/" + check_observation_space(env.observation_space) + + self._checked_reset: bool = False + self._checked_step: bool = False + self._checked_render: bool = False + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment that on the first call will run the `passive_env_step_check`.""" + if self._checked_step is False: + self._checked_step = True + return env_step_passive_checker(self.env, action) + else: + return self.env.step(action) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment that on the first call will run the `passive_env_reset_check`.""" + if self._checked_reset is False: + self._checked_reset = True + return env_reset_passive_checker(self.env, seed=seed, options=options) + else: + return self.env.reset(seed=seed, options=options) + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Renders the environment that on the first call will run the `passive_env_render_check`.""" + if self._checked_render is False: + self._checked_render = True + return env_render_passive_checker(self.env) + else: + return self.env.render() + + +class OrderEnforcingV0(gym.Wrapper): + """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. 
+ + Example: + >>> from gymnasium.envs.classic_control import CartPoleEnv + >>> env = CartPoleEnv() + >>> env = OrderEnforcingV0(env) + >>> env.step(0) + ResetNeeded: Cannot call env.step() before calling env.reset() + >>> env.render() + ResetNeeded: Cannot call env.render() before calling env.reset() + >>> env.reset() + >>> env.render() + >>> env.step(0) + """ + + def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): + """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. + + Args: + env: The environment to wrap + disable_render_order_enforcing: If to disable render order enforcing + """ + super().__init__(env) + self._has_reset: bool = False + self._disable_render_order_enforcing: bool = disable_render_order_enforcing + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Steps through the environment with `kwargs`.""" + if not self._has_reset: + raise ResetNeeded("Cannot call env.step() before calling env.reset()") + return super().step(action) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment with `kwargs`.""" + self._has_reset = True + return super().reset(seed=seed, options=options) + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Renders the environment with `kwargs`.""" + if not self._disable_render_order_enforcing and not self._has_reset: + raise ResetNeeded( + "Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, " + "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper." + ) + return super().render() + + @property + def has_reset(self): + """Returns if the environment has been reset before.""" + return self._has_reset + + +class RecordEpisodeStatisticsV0(gym.Wrapper): + """This wrapper will keep track of cumulative rewards and episode lengths. 
+ + At the end of an episode, the statistics of the episode will be added to ``info`` + using the key ``episode``. If using a vectorized environment also the key + ``_episode`` is used which indicates whether the env at the respective index has + the episode statistics. + + After the completion of an episode, ``info`` will look like this:: + + >>> info = { + ... ... + ... "episode": { + ... "r": "", + ... "l": "", + ... "t": "" + ... }, + ... } + + For a vectorized environments the output will be in the form of:: + + >>> infos = { + ... ... + ... "episode": { + ... "r": "", + ... "l": "", + ... "t": "" + ... }, + ... "_episode": "" + ... } + + Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via + :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively. + + Attributes: + episode_reward_buffer: The cumulative rewards of the last ``deque_size``-many episodes + episode_length_buffer: The lengths of the last ``deque_size``-many episodes + """ + + def __init__( + self, + env: Env[ObsType, ActType], + buffer_length: int | None = 100, + stats_key: str = "episode", + ): + """This wrapper will keep track of cumulative rewards and episode lengths. 
+ + Args: + env (Env): The environment to apply the wrapper + buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue` + stats_key: The info key for the episode statistics + """ + super().__init__(env) + + self._stats_key = stats_key + + self.episode_count = 0 + self.episode_start_time: float = -1 + self.episode_reward: float = -1 + self.episode_length: int = -1 + + self.episode_time_length_buffer = deque(maxlen=buffer_length) + self.episode_reward_buffer = deque(maxlen=buffer_length) + self.episode_length_buffer = deque(maxlen=buffer_length) + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment, recording the episode statistics.""" + obs, reward, terminated, truncated, info = super().step(action) + + self.episode_reward += reward + self.episode_length += 1 + + if terminated or truncated: + assert self._stats_key not in info + + episode_time_length = np.round( + time.perf_counter() - self.episode_start_time, 6 + ) + info[self._stats_key] = { + "r": self.episode_reward, + "l": self.episode_length, + "t": episode_time_length, + } + + self.episode_time_length_buffer.append(episode_time_length) + self.episode_reward_buffer.append(self.episode_reward) + self.episode_length_buffer.append(self.episode_length) + + self.episode_count += 1 + + return obs, reward, terminated, truncated, info + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment using seed and options and resets the episode rewards and lengths.""" + obs, info = super().reset(seed=seed, options=options) + + self.episode_start_time = time.perf_counter() + self.episode_reward = 0 + self.episode_length = 0 + + return obs, info diff --git a/gymnasium/experimental/wrappers/numpy_to_jax.py b/gymnasium/experimental/wrappers/jax_to_numpy.py similarity index 100% rename from 
gymnasium/experimental/wrappers/numpy_to_jax.py rename to gymnasium/experimental/wrappers/jax_to_numpy.py diff --git a/gymnasium/experimental/wrappers/torch_to_jax.py b/gymnasium/experimental/wrappers/jax_to_torch.py similarity index 99% rename from gymnasium/experimental/wrappers/torch_to_jax.py rename to gymnasium/experimental/wrappers/jax_to_torch.py index 36686e217..961549a3b 100644 --- a/gymnasium/experimental/wrappers/torch_to_jax.py +++ b/gymnasium/experimental/wrappers/jax_to_torch.py @@ -17,7 +17,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union from gymnasium import Env, Wrapper from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy +from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy try: diff --git a/gymnasium/experimental/wrappers/lambda_observations.py b/gymnasium/experimental/wrappers/lambda_observations.py index 2b30a4a8b..0ab46fea8 100644 --- a/gymnasium/experimental/wrappers/lambda_observations.py +++ b/gymnasium/experimental/wrappers/lambda_observations.py @@ -1,13 +1,15 @@ """A collection of observation wrappers using a lambda function. 
-* ``LambdaObservation`` - Transforms the observation with a function -* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys -* ``FlattenObservation`` - Flattens the observations -* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation -* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation) -* ``ReshapeObservation`` - Reshapes an array-based observation -* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value -* ``DtypeObservation`` - Convert a observation dtype +* ``LambdaObservationV0`` - Transforms the observation with a function +* ``FilterObservationV0`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys +* ``FlattenObservationV0`` - Flattens the observations +* ``GrayscaleObservationV0`` - Converts a RGB observation to a grayscale observation +* ``ResizeObservationV0`` - Resizes an array-based observation (normally a RGB observation) +* ``ReshapeObservationV0`` - Reshapes an array-based observation +* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value +* ``DtypeObservationV0`` - Convert an observation to a dtype +* ``PixelObservationV0`` - Allows the observation to the rendered frame +* ``NormalizeObservationV0`` - Normalized the observations to a mean and """ from __future__ import annotations @@ -18,10 +20,11 @@ import jumpy as jp import numpy as np import gymnasium as gym -from gymnasium import spaces -from gymnasium.core import ObsType +from gymnasium import Env, spaces +from gymnasium.core import ActType, ObservationWrapper, ObsType, WrapperObsType from gymnasium.error import DependencyNotInstalled -from gymnasium.spaces import Box, utils +from gymnasium.experimental.wrappers.utils import RunningMeanStd +from gymnasium.spaces import Box, Dict, utils class LambdaObservationV0(gym.ObservationWrapper): @@ -407,3 +410,83 @@ class DtypeObservationV0(LambdaObservationV0): ) 
super().__init__(env, lambda obs: dtype(obs), new_observation_space) + + +class PixelObservationV0(LambdaObservationV0): + """Augment observations by pixel values. + + Observations of this wrapper will be dictionaries of images. + You can also choose to add the observation of the base environment to this dictionary. + In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary + of rendered images will be updated with the base environment's observation. If, however, the observation + space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box` + space) will be added to the dictionary under the key "state". + """ + + def __init__( + self, + env: Env[ObsType, ActType], + pixels_only: bool = True, + pixels_key: str = "pixels", + obs_key: str = "state", + ): + """Initializes a new pixel Wrapper. + + Args: + env: The environment to wrap. + pixels_only (bool): If `True` (default), the original observation returned + by the wrapped environment will be discarded, and a dictionary + observation will only include pixels. If `False`, the + observation dictionary will contain both the original + observations and the pixel observations. + pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels" + obs_key: Optional custom string specifying the obs key. 
Defaults to "state" + """ + assert env.render_mode is not None and env.render_mode != "human" + env.reset() + pixels = env.render() + assert pixels is not None and isinstance(pixels, np.ndarray) + pixel_space = Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8) + + if pixels_only: + obs_space = pixel_space + super().__init__(env, lambda _: self.render(), obs_space) + elif isinstance(env.observation_space, Dict): + assert pixels_key not in env.observation_space.spaces.keys() + + obs_space = Dict({pixels_key: pixel_space, **env.observation_space.spaces}) + super().__init__( + env, lambda obs: {pixels_key: self.render(), **obs_space}, obs_space + ) + else: + obs_space = Dict({obs_key: env.observation_space, pixels_key: pixel_space}) + super().__init__( + env, lambda obs: {obs_key: obs, pixels_key: self.render()}, obs_space + ) + + +class NormalizeObservationV0(ObservationWrapper): + """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. + + Note: + The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was + newly instantiated or the policy was changed recently. + """ + + def __init__(self, env: gym.Env, epsilon: float = 1e-8): + """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. + + Args: + env (Env): The environment to apply the wrapper + epsilon: A stability parameter that is used when scaling the observations. 
+ """ + super().__init__(env) + self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) + self.epsilon = epsilon + + def observation(self, observation: ObsType) -> WrapperObsType: + """Normalises the observation using the running mean and variance of the observations.""" + self.obs_rms.update(observation) + return (observation - self.obs_rms.mean) / np.sqrt( + self.obs_rms.var + self.epsilon + ) diff --git a/gymnasium/experimental/wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py index 111f1157c..19717a81f 100644 --- a/gymnasium/experimental/wrappers/lambda_reward.py +++ b/gymnasium/experimental/wrappers/lambda_reward.py @@ -6,12 +6,14 @@ from __future__ import annotations -from typing import Callable, SupportsFloat +from typing import Any, Callable, SupportsFloat import numpy as np import gymnasium as gym +from gymnasium.core import WrapperActType, WrapperObsType from gymnasium.error import InvalidBound +from gymnasium.experimental.wrappers.utils import RunningMeanStd class LambdaRewardV0(gym.RewardWrapper): @@ -89,3 +91,48 @@ class ClipRewardV0(LambdaRewardV0): ) super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)) + + +class NormalizeRewardV0(gym.Wrapper): + r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. + + The exponential moving average will have variance :math:`(1 - \gamma)^2`. + + Note: + The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly + instantiated or the policy was changed recently. + """ + + def __init__( + self, + env: gym.Env, + gamma: float = 0.99, + epsilon: float = 1e-8, + ): + """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. + + Args: + env (env): The environment to apply the wrapper + epsilon (float): A stability parameter + gamma (float): The discount factor that is used in the exponential moving average. 
+ """ + super().__init__(env) + self.rewards_running_means = RunningMeanStd(shape=()) + self.discounted_reward: float = 0.0 + self.gamma = gamma + self.epsilon = epsilon + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment, normalizing the reward returned.""" + obs, reward, terminated, truncated, info = super().step(action) + self.discounted_reward = self.discounted_reward * self.gamma * ( + 1 - terminated + ) + float(reward) + return obs, self.normalize(float(reward)), terminated, truncated, info + + def normalize(self, reward): + """Normalizes the rewards with the running mean rewards and their variance.""" + self.rewards_running_means.update(self.discounted_reward) + return reward / np.sqrt(self.rewards_running_means.var + self.epsilon) diff --git a/gymnasium/experimental/wrappers/numpy_to_torch.py b/gymnasium/experimental/wrappers/numpy_to_torch.py new file mode 100644 index 000000000..f123455d2 --- /dev/null +++ b/gymnasium/experimental/wrappers/numpy_to_torch.py @@ -0,0 +1,158 @@ +"""Helper functions and wrapper class for converting between PyTorch and NumPy.""" +from __future__ import annotations + +import functools +import numbers +from collections import abc +from typing import Any, Iterable, Mapping, SupportsFloat, Union + +import numpy as np + +from gymnasium import Env, Wrapper +from gymnasium.core import WrapperActType, WrapperObsType +from gymnasium.error import DependencyNotInstalled + + +try: + import torch + + Device = Union[str, torch.device] +except ImportError: + torch, Device = None, None + + +@functools.singledispatch +def torch_to_numpy(value: Any) -> Any: + """Converts a PyTorch Tensor into a NumPy Array.""" + if torch is None: + raise DependencyNotInstalled( + "Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`" + ) + else: + raise Exception( + f"No known conversion for Torch type ({type(value)}) to NumPy 
registered. Report as issue on github." + ) + + +if torch is not None: + + @torch_to_numpy.register(numbers.Number) + @torch_to_numpy.register(torch.Tensor) + def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any: + """Convert a python number (int, float, complex) and torch.Tensor to a numpy array.""" + return np.array(value) + + @torch_to_numpy.register(abc.Mapping) + def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]: + """Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays.""" + return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()}) + + @torch_to_numpy.register(abc.Iterable) + def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]: + """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays.""" + return type(value)(torch_to_numpy(v) for v in value) + + +@functools.singledispatch +def numpy_to_torch(value: Any, device: Device | None = None) -> Any: + """Converts a Jax DeviceArray into a PyTorch Tensor.""" + if torch is None: + raise DependencyNotInstalled( + "Torch is not installed therefore cannot call `numpy_to_torch`, run `pip install torch`" + ) + else: + raise Exception( + f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github." 
+ ) + + +if torch is not None: + + @numpy_to_torch.register(np.ndarray) + def _numpy_to_torch( + value: np.ndarray, device: Device | None = None + ) -> torch.Tensor: + """Converts a Jax DeviceArray into a PyTorch Tensor.""" + assert torch is not None + tensor = torch.tensor(value) + if device: + return tensor.to(device=device) + return tensor + + @numpy_to_torch.register(abc.Mapping) + def _numpy_mapping_to_torch( + value: Mapping[str, Any], device: Device | None = None + ) -> Mapping[str, Any]: + """Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors.""" + return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()}) + + @numpy_to_torch.register(abc.Iterable) + def _numpy_iterable_to_torch( + value: Iterable[Any], device: Device | None = None + ) -> Iterable[Any]: + """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.""" + return type(value)(numpy_to_torch(v, device) for v in value) + + +class NumpyToTorchV0(Wrapper): + """Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors. + + Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. + + Note: + For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. + """ + + def __init__(self, env: Env, device: Device | None = None): + """Wrapper class to change inputs and outputs of environment to PyTorch tensors. + + Args: + env: The Jax-based environment to wrap + device: The device the torch Tensors should be moved to + """ + if torch is None: + raise DependencyNotInstalled( + "Torch is not installed, run `pip install torch`" + ) + + super().__init__(env) + self.device: Device | None = device + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Performs the given action within the environment. 
+ + Args: + action: The action to perform as a PyTorch Tensor + + Returns: + The next observation, reward, termination, truncation, and extra info + """ + jax_action = torch_to_numpy(action) + obs, reward, terminated, truncated, info = self.env.step(jax_action) + + return ( + numpy_to_torch(obs, self.device), + float(reward), + bool(terminated), + bool(truncated), + numpy_to_torch(info, self.device), + ) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment returning PyTorch-based observation and info. + + Args: + seed: The seed for resetting the environment + options: The options for resetting the environment, these are converted to jax arrays. + + Returns: + PyTorch-based observations and info + """ + if options: + options = torch_to_numpy(options) + + return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device) diff --git a/gymnasium/experimental/wrappers/rendering.py b/gymnasium/experimental/wrappers/rendering.py new file mode 100644 index 000000000..860a21e25 --- /dev/null +++ b/gymnasium/experimental/wrappers/rendering.py @@ -0,0 +1,219 @@ +"""A collections of rendering-based wrappers. 
+ +* ``RenderCollectionV0`` - Collects rendered frames into a list +* ``RecordVideoV0`` - Records a video of the environments +* ``HumanRenderingV0`` - Provides human rendering of environments with ``"rgb_array"`` +""" +from __future__ import annotations + +from copy import deepcopy +from typing import Any, SupportsFloat + +import numpy as np + +import gymnasium as gym +from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType +from gymnasium.error import DependencyNotInstalled + + +class RenderCollectionV0(gym.Wrapper): + """Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``.""" + + def __init__( + self, + env: gym.Env[ObsType, ActType], + pop_frames: bool = True, + reset_clean: bool = True, + ): + """Initialize a :class:`RenderCollection` instance. + + Args: + env: The environment that is being wrapped + pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``. + reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``. 
+ """ + super().__init__(env) + assert env.render_mode is not None + assert not env.render_mode.endswith("_list") + + self.frame_list: list[RenderFrame] = [] + self.pop_frames = pop_frames + self.reset_clean = reset_clean + + self.metadata = deepcopy(self.env.metadata) + if f"{self.env.render_mode}_list" not in self.metadata["render_modes"]: + self.metadata["render_modes"].append(f"{self.env.render_mode}_list") + + @property + def render_mode(self): + """Returns the collection render_mode name.""" + return f"{self.env.render_mode}_list" + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Perform a step in the base environment and collect a frame.""" + output = super().step(action) + self.frame_list.append(super().render()) + return output + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Reset the base environment, eventually clear the frame_list, and collect a frame.""" + output = super().reset(seed=seed, options=options) + + if self.reset_clean: + self.frame_list = [] + self.frame_list.append(super().render()) + + return output + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Returns the collection of frames and, if pop_frames = True, clears it.""" + frames = self.frame_list + if self.pop_frames: + self.frame_list = [] + + return frames + + +class RecordVideoV0(gym.Wrapper): + """Record a video of an environment.""" + + pass + + +class HumanRenderingV0(gym.Wrapper): + """Performs human rendering for an environment that only supports "rgb_array"rendering. + + This wrapper is particularly useful when you have implemented an environment that can produce + RGB images but haven't implemented any code to render the images to the screen. + If you want to use this wrapper with your environments, remember to specify ``"render_fps"`` + in the metadata of your environment. 
+ + The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``. + + Example: + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") + >>> wrapped = HumanRenderingV0(env) + >>> wrapped.reset() # This will start rendering to the screen + + The wrapper can also be applied directly when the environment is instantiated, simply by passing + ``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not + implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``). + + Example: + >>> env = gym.make("NoNativeRendering-v2", render_mode="human") # NoNativeRendering-v0 doesn't implement human-rendering natively + >>> env.reset() # This will start rendering to the screen + + Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method + will always return an empty list: + + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list") + >>> wrapped = HumanRenderingV0(env) + >>> wrapped.reset() + >>> env.render() + [] # env.render() will always return an empty list! + + """ + + def __init__(self, env): + """Initialize a :class:`HumanRendering` instance. 
+ + Args: + env: The environment that is being wrapped + """ + super().__init__(env) + assert env.render_mode in [ + "rgb_array", + "rgb_array_list", + ], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'" + assert ( + "render_fps" in env.metadata + ), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper" + + self.screen_size = None + self.window = None + self.clock = None + + if "human" not in self.metadata["render_modes"]: + self.metadata = deepcopy(self.env.metadata) + self.metadata["render_modes"].append("human") + + @property + def render_mode(self): + """Always returns ``'human'``.""" + return "human" + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Perform a step in the base environment and render a frame to the screen.""" + result = super().step(action) + self._render_frame() + return result + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Reset the base environment and render a frame to the screen.""" + result = super().reset(seed=seed, options=options) + self._render_frame() + return result + + def render(self): + """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`.""" + return None + + def _render_frame(self): + """Fetch the last frame from the base environment and render it to the screen.""" + try: + import pygame + except ImportError: + raise DependencyNotInstalled( + "pygame is not installed, run `pip install gymnasium[box2d]`" + ) + if self.env.render_mode == "rgb_array_list": + last_rgb_array = self.env.render() + assert isinstance(last_rgb_array, list) + last_rgb_array = last_rgb_array[-1] + elif self.env.render_mode == "rgb_array": + last_rgb_array = self.env.render() + else: + raise Exception( + f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', 
actual render mode: {self.env.render_mode}" + ) + assert isinstance(last_rgb_array, np.ndarray) + + rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2)) + + if self.screen_size is None: + self.screen_size = rgb_array.shape[:2] + + assert ( + self.screen_size == rgb_array.shape[:2] + ), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}" + + if self.window is None: + pygame.init() + pygame.display.init() + self.window = pygame.display.set_mode(self.screen_size) + + if self.clock is None: + self.clock = pygame.time.Clock() + + surf = pygame.surfarray.make_surface(rgb_array) + self.window.blit(surf, (0, 0)) + pygame.event.pump() + self.clock.tick(self.metadata["render_fps"]) + pygame.display.flip() + + def close(self): + """Close the rendering window.""" + if self.window is not None: + import pygame + + pygame.display.quit() + pygame.quit() + super().close() diff --git a/gymnasium/experimental/wrappers/stateful_action.py b/gymnasium/experimental/wrappers/stateful_action.py index 7527c12a4..5d3a821c5 100644 --- a/gymnasium/experimental/wrappers/stateful_action.py +++ b/gymnasium/experimental/wrappers/stateful_action.py @@ -1,17 +1,14 @@ -"""A collection of stateful action wrappers. - -* StickyAction - There is a probability that the action is taken again -""" +"""``StickyAction`` wrapper - There is a probability that the action is taken again.""" from __future__ import annotations -from typing import Any, SupportsFloat +from typing import Any import gymnasium as gym -from gymnasium.core import WrapperActType, WrapperObsType +from gymnasium.core import ActionWrapper, ActType, WrapperActType, WrapperObsType from gymnasium.error import InvalidProbability -class StickyActionV0(gym.Wrapper): +class StickyActionV0(ActionWrapper): """Wrapper which adds a probability of repeating the previous action. 
This wrapper follows the implementation proposed by `Machado et al., 2018 `_ @@ -42,9 +39,7 @@ class StickyActionV0(gym.Wrapper): return super().reset(seed=seed, options=options) - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + def action(self, action: WrapperActType) -> ActType: """Execute the action.""" if ( self.last_action is not None diff --git a/gymnasium/experimental/wrappers/stateful_observation.py b/gymnasium/experimental/wrappers/stateful_observation.py index ffc96de89..f6f31537b 100644 --- a/gymnasium/experimental/wrappers/stateful_observation.py +++ b/gymnasium/experimental/wrappers/stateful_observation.py @@ -1,7 +1,9 @@ """A collection of stateful observation wrappers. -* DelayObservation - A wrapper for delaying the returned observation -* TimeAwareObservation - A wrapper for adding time aware observations to environment observation +* ``DelayObservationV0`` - A wrapper for delaying the returned observation +* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation +* ``FrameStackObservationV0`` - Frame stack the observations +* ``AtariPreprocessingV0`` - Preprocessing wrapper for atari environments """ from __future__ import annotations @@ -14,8 +16,10 @@ import numpy as np import gymnasium as gym import gymnasium.spaces as spaces -from gymnasium.core import ActType, ObsType, WrapperObsType +from gymnasium import Env +from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple +from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate class DelayObservationV0(gym.ObservationWrapper): @@ -31,7 +35,9 @@ class DelayObservationV0(gym.ObservationWrapper): returned observation is an array of zeros with the same shape of the observation space. 
""" - assert isinstance(env.observation_space, (Box, MultiBinary, MultiDiscrete)) + assert isinstance( + env.observation_space, (Box, MultiBinary, MultiDiscrete) + ), type(env.observation_space) assert 0 < delay self.delay: Final[int] = delay @@ -134,9 +140,9 @@ class TimeAwareObservationV0(gym.ObservationWrapper): if isinstance(env.observation_space, Dict): assert dict_time_key not in env.observation_space.keys() observation_space = Dict( - {dict_time_key: time_space}, **env.observation_space.spaces + {dict_time_key: time_space, **env.observation_space.spaces} ) - self._append_data_func = lambda obs, time: {**obs, dict_time_key: time} + self._append_data_func = lambda obs, time: {dict_time_key: time, **obs} elif isinstance(env.observation_space, Tuple): observation_space = Tuple(env.observation_space.spaces + (time_space,)) self._append_data_func = lambda obs, time: obs + (time,) @@ -198,3 +204,101 @@ class TimeAwareObservationV0(gym.ObservationWrapper): self.timesteps = 0 return super().reset(seed=seed, options=options) + + +class FrameStackObservationV0(gym.Wrapper): + """Observation wrapper that stacks the observations in a rolling manner. + + For example, if the number of stacks is 4, then the returned observation contains + the most recent 4 observations. For environment 'Pendulum-v1', the original observation + is an array with shape [3], so if we stack 4 observations, the processed observation + has shape [4, 3]. + + Note: + - After :meth:`reset` is called, the frame buffer will be filled with the initial observation. + I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames. 
+ + Example: + >>> import gymnasium as gym + >>> env = gym.make('CarRacing-v1') + >>> env = FrameStack(env, 4) + >>> env.observation_space + Box(4, 96, 96, 3) + >>> obs = env.reset() + >>> obs.shape + (4, 96, 96, 3) + """ + + def __init__(self, env: Env[ObsType, ActType], stack_size: int): + """Observation wrapper that stacks the observations in a rolling manner. + + Args: + env: The environment to apply the wrapper + stack_size: The number of frames to stack + """ + assert np.issubdtype(type(stack_size), np.integer) + assert stack_size > 0 + + super().__init__(env) + + self.observation_space = batch_space(env.observation_space, n=stack_size) + self.stack_size = stack_size + + self.stacked_obs_array = create_empty_array(env.observation_space, n=stack_size) + self.stacked_obs = self._init_stacked_obs() + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment, appending the observation to the frame buffer. + + Args: + action: The action to step through the environment with + + Returns: + Stacked observations, reward, terminated, truncated, and info from the environment + """ + obs, reward, terminated, truncated, info = super().step(action) + self.stacked_obs.rotate(1) + self.stacked_obs[0] = obs + + return ( + concatenate( + self.observation_space, self.stacked_obs, self.stacked_obs_array + ), + reward, + terminated, + truncated, + info, + ) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Reset the environment, returning the stacked observation and info. 
+ + Args: + seed: The environment seed + options: The reset options + + Returns: + The stacked observations and info + """ + obs, info = super().reset(seed=seed, options=options) + self.stacked_obs = self._init_stacked_obs() + self.stacked_obs[0] = obs + + return ( + concatenate( + self.observation_space, self.stacked_obs, self.stacked_obs_array + ), + info, + ) + + def _init_stacked_obs(self) -> deque: + return deque( + iterate( + self.observation_space, + create_empty_array(self.env.observation_space, n=self.stack_size), + ) + ) diff --git a/gymnasium/experimental/wrappers/utils.py b/gymnasium/experimental/wrappers/utils.py new file mode 100644 index 000000000..08485fc12 --- /dev/null +++ b/gymnasium/experimental/wrappers/utils.py @@ -0,0 +1,43 @@ +"""Utility functions for the wrappers.""" +import numpy as np + + +class RunningMeanStd: + """Tracks the mean, variance and count of values.""" + + # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + def __init__(self, epsilon=1e-4, shape=()): + """Tracks the mean, variance and count of values.""" + self.mean = np.zeros(shape, "float64") + self.var = np.ones(shape, "float64") + self.count = epsilon + + def update(self, x): + """Updates the mean, var and count from a batch of samples.""" + batch_mean = np.mean(x, axis=0) + batch_var = np.var(x, axis=0) + batch_count = x.shape[0] + self.update_from_moments(batch_mean, batch_var, batch_count) + + def update_from_moments(self, batch_mean, batch_var, batch_count): + """Updates from batch mean, variance and count moments.""" + self.mean, self.var, self.count = update_mean_var_count_from_moments( + self.mean, self.var, self.count, batch_mean, batch_var, batch_count + ) + + +def update_mean_var_count_from_moments( + mean, var, count, batch_mean, batch_var, batch_count +): + """Updates the mean, var and count using the previous mean, var, count and batch values.""" + delta = batch_mean - mean + tot_count = count + batch_count + + new_mean = 
mean + delta * batch_count / tot_count + m_a = var * count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + return new_mean, new_var, new_count diff --git a/gymnasium/wrappers/normalize.py b/gymnasium/wrappers/normalize.py index ab0b4645e..4231717de 100644 --- a/gymnasium/wrappers/normalize.py +++ b/gymnasium/wrappers/normalize.py @@ -93,7 +93,7 @@ class NormalizeObservation(gym.Wrapper): return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon) -class NormalizeReward(gym.Wrapper): +class NormalizeReward(gym.core.Wrapper): r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. The exponential moving average will have variance :math:`(1 - \gamma)^2`. @@ -129,10 +129,8 @@ class NormalizeReward(gym.Wrapper): obs, rews, terminateds, truncateds, infos = self.env.step(action) if not self.is_vector_env: rews = np.array([rews]) - self.returns = self.returns * self.gamma + rews + self.returns = self.returns * self.gamma * (1 - terminateds) + rews rews = self.normalize(rews) - dones = np.logical_or(terminateds, truncateds) - self.returns[dones] = 0.0 if not self.is_vector_env: rews = rews[0] return obs, rews, terminateds, truncateds, infos diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..36676323d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Testing for Gymnasium.""" diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py index e69de29bb..203b13d74 100644 --- a/tests/experimental/__init__.py +++ b/tests/experimental/__init__.py @@ -0,0 +1 @@ +"""Testing suite for ``gymnasium.experimental``.""" diff --git a/tests/experimental/functional/__init__.py b/tests/experimental/functional/__init__.py index e69de29bb..ea8350ae9 100644 --- a/tests/experimental/functional/__init__.py +++ b/tests/experimental/functional/__init__.py @@ -0,0 +1 
@@ +"""Module for functional environment API.""" diff --git a/tests/experimental/functional/test_jax.py b/tests/experimental/functional/test_func_jax_env.py similarity index 90% rename from tests/experimental/functional/test_jax.py rename to tests/experimental/functional/test_func_jax_env.py index 284b1a748..c934fa5b4 100644 --- a/tests/experimental/functional/test_jax.py +++ b/tests/experimental/functional/test_func_jax_env.py @@ -1,7 +1,8 @@ +"""Test the functional jax environment.""" + import jax import jax.numpy as jnp import jax.random as jrng -import numpy as np import pytest from gymnasium.envs.phys2d.cartpole import CartPoleFunctional @@ -9,7 +10,8 @@ from gymnasium.envs.phys2d.pendulum import PendulumFunctional @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) -def test_normal(env_class): +def test_without_transform(env_class): + """Tests the environment without transforming the environment.""" env = env_class() rng = jrng.PRNGKey(0) @@ -42,6 +44,7 @@ def test_normal(env_class): @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) def test_jit(env_class): + """Tests jitting the functional instance functions.""" env = env_class() rng = jrng.PRNGKey(0) @@ -75,6 +78,7 @@ def test_jit(env_class): @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) def test_vmap(env_class): + """Tests vmap of functional instance functions with transform.""" env = env_class() num_envs = 10 rng = jrng.split(jrng.PRNGKey(0), num_envs) @@ -98,7 +102,7 @@ def test_vmap(env_class): assert reward.shape == (num_envs,) assert reward.dtype == jnp.float32 assert terminal.shape == (num_envs,) - assert terminal.dtype == np.bool + assert terminal.dtype == bool assert isinstance(obs, jnp.ndarray) assert obs.dtype == jnp.float32 diff --git a/tests/experimental/functional/test_core.py b/tests/experimental/functional/test_functional.py similarity index 73% rename from tests/experimental/functional/test_core.py 
rename to tests/experimental/functional/test_functional.py index d1282bd4e..800441874 100644 --- a/tests/experimental/functional/test_core.py +++ b/tests/experimental/functional/test_functional.py @@ -1,4 +1,7 @@ -from typing import Any, Dict, Optional +"""Tests the functional api.""" +from __future__ import annotations + +from typing import Any import numpy as np @@ -6,29 +9,41 @@ from gymnasium.experimental import FuncEnv class GenericTestFuncEnv(FuncEnv): - def __init__(self, options: Optional[Dict[str, Any]] = None): + """Generic testing functional environment.""" + + def __init__(self, options: dict[str, Any] | None = None): + """Constructor that allows generic options to be set on the environment.""" super().__init__(options) def initial(self, rng: Any) -> np.ndarray: + """Testing initial function.""" return np.array([0, 0], dtype=np.float32) def observation(self, state: np.ndarray) -> np.ndarray: + """Testing observation function.""" return state def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray: + """Testing transition function.""" return state + np.array([0, action], dtype=np.float32) def reward(self, state: np.ndarray, action: int, next_state: np.ndarray) -> float: + """Testing reward function.""" return 1.0 if next_state[1] > 0 else 0.0 def terminal(self, state: np.ndarray) -> bool: + """Testing terminal function.""" return state[1] > 0 -def test_api(): +def test_functional_api(): + """Tests the core functional api specification using a generic testing environment.""" env = GenericTestFuncEnv() + state = env.initial(None) + obs = env.observation(state) + assert state.shape == (2,) assert state.dtype == np.float32 assert obs.shape == (2,) diff --git a/tests/experimental/wrappers/__init__.py b/tests/experimental/wrappers/__init__.py index e69de29bb..a100571ed 100644 --- a/tests/experimental/wrappers/__init__.py +++ b/tests/experimental/wrappers/__init__.py @@ -0,0 +1 @@ +"""Experimental wrapper module.""" diff --git 
a/tests/experimental/wrappers/human_rendering.py b/tests/experimental/wrappers/human_rendering.py new file mode 100644 index 000000000..2cce4d70f --- /dev/null +++ b/tests/experimental/wrappers/human_rendering.py @@ -0,0 +1 @@ +"""Test suite for HumanRenderingV0.""" diff --git a/tests/experimental/wrappers/test_atari_preprocessing.py b/tests/experimental/wrappers/test_atari_preprocessing.py new file mode 100644 index 000000000..fa0aac8d2 --- /dev/null +++ b/tests/experimental/wrappers/test_atari_preprocessing.py @@ -0,0 +1 @@ +"""Test suite for AtariPreprocessingV0.""" diff --git a/tests/experimental/wrappers/test_autoreset.py b/tests/experimental/wrappers/test_autoreset.py new file mode 100644 index 000000000..948815732 --- /dev/null +++ b/tests/experimental/wrappers/test_autoreset.py @@ -0,0 +1 @@ +"""Test suite for AutoresetV0.""" diff --git a/tests/experimental/wrappers/test_clip_action.py b/tests/experimental/wrappers/test_clip_action.py new file mode 100644 index 000000000..2f80a1183 --- /dev/null +++ b/tests/experimental/wrappers/test_clip_action.py @@ -0,0 +1,25 @@ +"""Test suite for ClipActionV0.""" + +import numpy as np + +from gymnasium.experimental.wrappers import ClipActionV0 +from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import record_action_step +from tests.testing_env import GenericTestEnv + + +def test_clip_action_wrapper(): + """Test that the action is correctly clipped to the base environment action space.""" + env = GenericTestEnv( + action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])), + step_func=record_action_step, + ) + wrapped_env = ClipActionV0(env) + + sampled_action = np.array([-1, 5, 3.5], dtype=np.float32) + assert sampled_action not in env.action_space + assert sampled_action in wrapped_env.action_space + + _, _, _, _, info = wrapped_env.step(sampled_action) + assert np.all(info["action"] in env.action_space) + assert np.all(info["action"] == np.array([0, 2, 3.5])) diff --git 
a/tests/experimental/wrappers/test_clip_reward.py b/tests/experimental/wrappers/test_clip_reward.py new file mode 100644 index 000000000..b88290daf --- /dev/null +++ b/tests/experimental/wrappers/test_clip_reward.py @@ -0,0 +1,67 @@ +"""Test suite for ClipRewardV0.""" +import numpy as np +import pytest + +import gymnasium as gym +from gymnasium.error import InvalidBound +from gymnasium.experimental.wrappers import ClipRewardV0 +from tests.envs.test_envs import SEED +from tests.experimental.wrappers.test_lambda_rewards import ( + DISCRETE_ACTION, + ENV_ID, + NUM_ENVS, +) + + +@pytest.mark.parametrize( + ("lower_bound", "upper_bound", "expected_reward"), + [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)], +) +def test_clip_reward(lower_bound, upper_bound, expected_reward): + """Test reward clipping. + + Test if reward is correctly clipped accordingly to the input args. + """ + env = gym.make(ENV_ID) + env = ClipRewardV0(env, lower_bound, upper_bound) + env.reset(seed=SEED) + _, rew, _, _, _ = env.step(DISCRETE_ACTION) + + assert rew == expected_reward + + +@pytest.mark.parametrize( + ("lower_bound", "upper_bound", "expected_reward"), + [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)], +) +def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward): + """Test reward clipping in vectorized environment. + + Test if reward is correctly clipped accordingly to the input args in a vectorized environment. + """ + actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)] + + env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS) + env = ClipRewardV0(env, lower_bound, upper_bound) + env.reset(seed=SEED) + + _, rew, _, _, _ = env.step(actions) + + assert np.alltrue(rew == expected_reward) + + +@pytest.mark.parametrize( + ("lower_bound", "upper_bound"), + [(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))], +) +def test_clip_reward_incorrect_params(lower_bound, upper_bound): + """Test reward clipping with incorrect params. 
+ + Test whether passing wrong params to clip_rewards correctly raise an exception. + clip_rewards should raise an exception if, both low and upper bound of reward are `None` + or if upper bound is lower than lower bound. + """ + env = gym.make(ENV_ID) + + with pytest.raises(InvalidBound): + ClipRewardV0(env, lower_bound, upper_bound) diff --git a/tests/experimental/wrappers/test_delay_observation.py b/tests/experimental/wrappers/test_delay_observation.py new file mode 100644 index 000000000..6fce4e7ff --- /dev/null +++ b/tests/experimental/wrappers/test_delay_observation.py @@ -0,0 +1,34 @@ +"""Test suite for DelayObservationV0.""" +import numpy as np + +import gymnasium as gym +from gymnasium.experimental.wrappers import DelayObservationV0 +from tests.experimental.wrappers.utils import DELAY, NUM_STEPS, SEED + + +def test_delay_observation_wrapper(): + """Tests the delay observation wrapper.""" + env = gym.make("CartPole-v1") + env.action_space.seed(SEED) + env.reset(seed=SEED) + + undelayed_observations = [] + for _ in range(NUM_STEPS): + obs, _, _, _, _ = env.step(env.action_space.sample()) + undelayed_observations.append(obs) + + env = DelayObservationV0(env, delay=DELAY) + env.action_space.seed(SEED) + env.reset(seed=SEED) + + delayed_observations = [] + for i in range(NUM_STEPS): + obs, _, _, _, _ = env.step(env.action_space.sample()) + delayed_observations.append(obs) + if i < DELAY - 1: + assert np.all(obs == 0) + + undelayed_observations = np.array(undelayed_observations) + delayed_observations = np.array(delayed_observations) + + assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY]) diff --git a/tests/experimental/wrappers/test_dtype_observation.py b/tests/experimental/wrappers/test_dtype_observation.py new file mode 100644 index 000000000..1d6233638 --- /dev/null +++ b/tests/experimental/wrappers/test_dtype_observation.py @@ -0,0 +1,25 @@ +"""Test suite for DtypeObservationV0.""" +import numpy as np + +from 
gymnasium.experimental.wrappers import DtypeObservationV0 +from tests.experimental.wrappers.utils import ( + record_random_obs_reset, + record_random_obs_step, +) +from tests.testing_env import GenericTestEnv + + +def test_dtype_observation(): + """Test ``DtypeObservation`` that the dtype is corrected modified.""" + env = GenericTestEnv( + reset_func=record_random_obs_reset, step_func=record_random_obs_step + ) + wrapped_env = DtypeObservationV0(env, dtype=np.uint8) + + obs, info = wrapped_env.reset() + assert obs.dtype != info["obs"].dtype + assert obs.dtype == np.uint8 + + obs, _, _, _, info = wrapped_env.step(None) + assert obs.dtype != info["obs"].dtype + assert obs.dtype == np.uint8 diff --git a/tests/experimental/wrappers/test_filter_observation.py b/tests/experimental/wrappers/test_filter_observation.py new file mode 100644 index 000000000..5232cb840 --- /dev/null +++ b/tests/experimental/wrappers/test_filter_observation.py @@ -0,0 +1,45 @@ +"""Test suite for FilterObservationV0.""" +from gymnasium.experimental.wrappers import FilterObservationV0 +from gymnasium.spaces import Box, Dict, Tuple +from tests.experimental.wrappers.utils import ( + check_obs, + record_random_obs_reset, + record_random_obs_step, +) +from tests.testing_env import GenericTestEnv + + +def test_filter_observation_wrapper(): + """Tests ``FilterObservation`` that the right keys are filtered.""" + dict_env = GenericTestEnv( + observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) + + wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3")) + obs, info = wrapped_env.reset() + assert list(obs.keys()) == ["arm_1", "arm_3"] + assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] + check_obs(dict_env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + assert list(obs.keys()) == ["arm_1", "arm_3"] + assert list(info["obs"].keys()) == ["arm_1", "arm_2", 
"arm_3"] + check_obs(dict_env, wrapped_env, obs, info["obs"]) + + # Test tuple environments + tuple_env = GenericTestEnv( + observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) + wrapped_env = FilterObservationV0(tuple_env, (2,)) + + obs, info = wrapped_env.reset() + assert len(obs) == 1 and len(info["obs"]) == 3 + check_obs(tuple_env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + assert len(obs) == 1 and len(info["obs"]) == 3 + check_obs(tuple_env, wrapped_env, obs, info["obs"]) diff --git a/tests/experimental/wrappers/test_flatten_observation.py b/tests/experimental/wrappers/test_flatten_observation.py new file mode 100644 index 000000000..7212261bc --- /dev/null +++ b/tests/experimental/wrappers/test_flatten_observation.py @@ -0,0 +1,27 @@ +"""Test suite for FlattenObservationV0.""" +from gymnasium.experimental.wrappers import FlattenObservationV0 +from gymnasium.spaces import Box, Dict +from tests.experimental.wrappers.utils import ( + check_obs, + record_random_obs_reset, + record_random_obs_step, +) +from tests.testing_env import GenericTestEnv + + +def test_flatten_observation_wrapper(): + """Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly.""" + env = GenericTestEnv( + observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) + print(env.observation_space) + wrapped_env = FlattenObservationV0(env) + print(wrapped_env.observation_space) + + obs, info = wrapped_env.reset() + check_obs(env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/experimental/wrappers/test_frame_stack_observation.py b/tests/experimental/wrappers/test_frame_stack_observation.py new file mode 100644 index 000000000..ea42bc448 --- /dev/null +++ 
b/tests/experimental/wrappers/test_frame_stack_observation.py @@ -0,0 +1 @@ +"""Test suite for FrameStackObservationV0.""" diff --git a/tests/experimental/wrappers/test_grayscale_observation.py b/tests/experimental/wrappers/test_grayscale_observation.py new file mode 100644 index 000000000..74a7a6a17 --- /dev/null +++ b/tests/experimental/wrappers/test_grayscale_observation.py @@ -0,0 +1,38 @@ +"""Test suite for GrayscaleObservationV0.""" +import numpy as np + +from gymnasium.experimental.wrappers import GrayscaleObservationV0 +from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import ( + check_obs, + record_random_obs_reset, + record_random_obs_step, +) +from tests.testing_env import GenericTestEnv + + +def test_grayscale_observation_wrapper(): + """Tests the ``GrayscaleObservation`` that the observation is grayscale.""" + env = GenericTestEnv( + observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) + wrapped_env = GrayscaleObservationV0(env) + + obs, info = wrapped_env.reset() + check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (25, 25) + + obs, _, _, _, info = wrapped_env.step(None) + check_obs(env, wrapped_env, obs, info["obs"]) + + # Keep_dim + wrapped_env = GrayscaleObservationV0(env, keep_dim=True) + + obs, info = wrapped_env.reset() + check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (25, 25, 1) + + obs, _, _, _, info = wrapped_env.step(None) + check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/experimental/wrappers/test_numpy_to_jax.py b/tests/experimental/wrappers/test_jax_to_numpy.py similarity index 91% rename from tests/experimental/wrappers/test_numpy_to_jax.py rename to tests/experimental/wrappers/test_jax_to_numpy.py index d5abaa116..25c5ee62a 100644 --- a/tests/experimental/wrappers/test_numpy_to_jax.py +++ b/tests/experimental/wrappers/test_jax_to_numpy.py @@ -1,9 +1,11 @@ +"""Test suite 
for JaxToNumpyV0.""" + import jax.numpy as jnp import numpy as np import pytest from gymnasium.experimental.wrappers import JaxToNumpyV0 -from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax +from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax from gymnasium.utils.env_checker import data_equivalence from tests.testing_env import GenericTestEnv @@ -40,10 +42,12 @@ def test_roundtripping(value, expected_value): def jax_reset_func(self, seed=None, options=None): + """A jax-based reset function.""" return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])} def jax_step_func(self, action): + """A jax-based step function.""" assert isinstance(action, jnp.DeviceArray), type(action) return ( jnp.array([1, 2, 3]), @@ -54,7 +58,8 @@ def jax_step_func(self, action): ) -def test_jax_to_numpy(): +def test_jax_to_numpy_wrapper(): + """Tests the ``JaxToNumpyV0`` wrapper.""" jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func) # Check that the reset and step for jax environment are as expected diff --git a/tests/experimental/wrappers/test_torch_to_jax.py b/tests/experimental/wrappers/test_jax_to_torch.py similarity index 95% rename from tests/experimental/wrappers/test_torch_to_jax.py rename to tests/experimental/wrappers/test_jax_to_torch.py index c2e524dc5..a2313ae90 100644 --- a/tests/experimental/wrappers/test_torch_to_jax.py +++ b/tests/experimental/wrappers/test_jax_to_torch.py @@ -1,10 +1,12 @@ +"""Test suite for TorchToJaxV0.""" + import jax.numpy as jnp import numpy as np import pytest import torch from gymnasium.experimental.wrappers import JaxToTorchV0 -from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax +from gymnasium.experimental.wrappers.jax_to_torch import jax_to_torch, torch_to_jax from tests.testing_env import GenericTestEnv @@ -76,7 +78,8 @@ def _jax_step_func(self, action): ) -def test_jax_to_torch(): +def test_jax_to_torch_wrapper(): + 
"""Tests the `JaxToTorchV0` wrapper.""" env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func) # Check that the reset and step for jax environment are as expected diff --git a/tests/experimental/wrappers/test_lambda_action.py b/tests/experimental/wrappers/test_lambda_action.py index e70a63de8..81429ee3e 100644 --- a/tests/experimental/wrappers/test_lambda_action.py +++ b/tests/experimental/wrappers/test_lambda_action.py @@ -1,25 +1,14 @@ -"""Test suit for lambda action wrappers: LambdaAction, ClipAction, RescaleAction.""" -import numpy as np +"""Test suite for LambdaActionV0.""" -from gymnasium.experimental.wrappers import ( - ClipActionV0, - LambdaActionV0, - RescaleActionV0, -) +from gymnasium.experimental.wrappers import LambdaActionV0 from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import record_action_step from tests.testing_env import GenericTestEnv -SEED = 42 - - -def _record_action_step_func(self, action): - return 0, 0, False, False, {"action": action} - - def test_lambda_action_wrapper(): """Tests LambdaAction through checking that the action taken is transformed by function.""" - env = GenericTestEnv(step_func=_record_action_step_func) + env = GenericTestEnv(step_func=record_action_step) wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3)) sampled_action = wrapped_env.action_space.sample() @@ -28,51 +17,3 @@ def test_lambda_action_wrapper(): _, _, _, _, info = wrapped_env.step(sampled_action) assert info["action"] in env.action_space assert sampled_action - 2 == info["action"] - - -def test_clip_action_wrapper(): - """Test that the action is correctly clipped to the base environment action space.""" - env = GenericTestEnv( - action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])), - step_func=_record_action_step_func, - ) - wrapped_env = ClipActionV0(env) - - sampled_action = np.array([-1, 5, 3.5], dtype=np.float32) - assert sampled_action not in env.action_space - assert sampled_action 
in wrapped_env.action_space - - _, _, _, _, info = wrapped_env.step(sampled_action) - assert np.all(info["action"] in env.action_space) - assert np.all(info["action"] == np.array([0, 2, 3.5])) - - -def test_rescale_action_wrapper(): - """Test that the action is rescale within a min / max bound.""" - env = GenericTestEnv( - step_func=_record_action_step_func, - action_space=Box(np.array([0, 1]), np.array([1, 3])), - ) - wrapped_env = RescaleActionV0( - env, min_action=np.array([-5, 0]), max_action=np.array([5, 1]) - ) - assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1])) - - for sample_action, expected_action in ( - ( - np.array([0.0, 0.5], dtype=np.float32), - np.array([0.5, 2.0], dtype=np.float32), - ), - ( - np.array([-5.0, 0.0], dtype=np.float32), - np.array([0.0, 1.0], dtype=np.float32), - ), - ( - np.array([5.0, 1.0], dtype=np.float32), - np.array([1.0, 3.0], dtype=np.float32), - ), - ): - assert sample_action in wrapped_env.action_space - - _, _, _, _, info = wrapped_env.step(sample_action) - assert np.all(info["action"] == expected_action) diff --git a/tests/experimental/wrappers/test_lambda_observation.py b/tests/experimental/wrappers/test_lambda_observation.py index 430748935..af26e7b89 100644 --- a/tests/experimental/wrappers/test_lambda_observation.py +++ b/tests/experimental/wrappers/test_lambda_observation.py @@ -1,250 +1,26 @@ -"""Test suite for lambda observation wrappers: """ +"""Test suite for lambda observation wrappers.""" import numpy as np -import gymnasium as gym -from gymnasium.experimental.wrappers import ( - DtypeObservationV0, - FilterObservationV0, - FlattenObservationV0, - GrayscaleObservationV0, - LambdaObservationV0, - RescaleObservationV0, - ReshapeObservationV0, - ResizeObservationV0, +from gymnasium.experimental.wrappers import LambdaObservationV0 +from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import ( + check_obs, + record_action_as_obs_step, + record_obs_reset, ) -from 
gymnasium.spaces import Box, Dict, Tuple from tests.testing_env import GenericTestEnv -SEED = 42 - - -def _record_random_obs_reset(self: gym.Env, seed=None, options=None): - obs = self.observation_space.sample() - return obs, {"obs": obs} - - -def _record_random_obs_step(self: gym.Env, action): - obs = self.observation_space.sample() - return obs, 0, False, False, {"obs": obs} - - -def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}): - return options["obs"], {"obs": options["obs"]} - - -def _record_action_obs_step(self: gym.Env, action): - return action, 0, False, False, {"obs": action} - - -def _check_obs( - env: gym.Env, - wrapped_env: gym.Wrapper, - transformed_obs, - original_obs, - strict: bool = True, -): - assert ( - transformed_obs in wrapped_env.observation_space - ), f"{transformed_obs}, {wrapped_env.observation_space}" - assert ( - original_obs in env.observation_space - ), f"{original_obs}, {env.observation_space}" - - if strict: - assert ( - transformed_obs not in env.observation_space - ), f"{transformed_obs}, {env.observation_space}" - assert ( - original_obs not in wrapped_env.observation_space - ), f"{original_obs}, {wrapped_env.observation_space}" - - def test_lambda_observation_wrapper(): """Tests lambda observation that the function is applied to both the reset and step observation.""" env = GenericTestEnv( - reset_func=_record_action_obs_reset, step_func=_record_action_obs_step + reset_func=record_obs_reset, step_func=record_action_as_obs_step ) - wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3)) + wrapped_env = LambdaObservationV0(env, lambda _obs: _obs + 2, Box(2, 3)) obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)}) - _check_obs(env, wrapped_env, obs, info["obs"]) + check_obs(env, wrapped_env, obs, info["obs"]) obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32)) - _check_obs(env, wrapped_env, obs, info["obs"]) - - -def 
test_filter_observation_wrapper(): - """Tests ``FilterObservation`` that the right keys are filtered.""" - dict_env = GenericTestEnv( - observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)), - reset_func=_record_random_obs_reset, - step_func=_record_random_obs_step, - ) - - wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3")) - obs, info = wrapped_env.reset() - assert list(obs.keys()) == ["arm_1", "arm_3"] - assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] - _check_obs(dict_env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - assert list(obs.keys()) == ["arm_1", "arm_3"] - assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] - _check_obs(dict_env, wrapped_env, obs, info["obs"]) - - # Test tuple environments - tuple_env = GenericTestEnv( - observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))), - reset_func=_record_random_obs_reset, - step_func=_record_random_obs_step, - ) - wrapped_env = FilterObservationV0(tuple_env, (2,)) - - obs, info = wrapped_env.reset() - assert len(obs) == 1 and len(info["obs"]) == 3 - _check_obs(tuple_env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - assert len(obs) == 1 and len(info["obs"]) == 3 - _check_obs(tuple_env, wrapped_env, obs, info["obs"]) - - -def test_flatten_observation_wrapper(): - """Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly.""" - env = GenericTestEnv( - observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)), - reset_func=_record_random_obs_reset, - step_func=_record_random_obs_step, - ) - print(env.observation_space) - wrapped_env = FlattenObservationV0(env) - print(wrapped_env.observation_space) - - obs, info = wrapped_env.reset() - _check_obs(env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - _check_obs(env, wrapped_env, obs, info["obs"]) - - -def test_grayscale_observation_wrapper(): - """Tests the 
``GrayscaleObservation`` that the observation is grayscale.""" - env = GenericTestEnv( - observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8), - reset_func=_record_random_obs_reset, - step_func=_record_random_obs_step, - ) - wrapped_env = GrayscaleObservationV0(env) - - obs, info = wrapped_env.reset() - _check_obs(env, wrapped_env, obs, info["obs"]) - assert obs.shape == (25, 25) - - obs, _, _, _, info = wrapped_env.step(None) - _check_obs(env, wrapped_env, obs, info["obs"]) - - # Keep_dim - wrapped_env = GrayscaleObservationV0(env, keep_dim=True) - - obs, info = wrapped_env.reset() - _check_obs(env, wrapped_env, obs, info["obs"]) - assert obs.shape == (25, 25, 1) - - obs, _, _, _, info = wrapped_env.step(None) - _check_obs(env, wrapped_env, obs, info["obs"]) - - -def test_resize_observation_wrapper(): - """Test the ``ResizeObservation`` that the observation has changed size""" - env = GenericTestEnv( - observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8), - reset_func=_record_random_obs_reset, - step_func=_record_random_obs_step, - ) - wrapped_env = ResizeObservationV0(env, (25, 25)) - - obs, info = wrapped_env.reset() - _check_obs(env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - _check_obs(env, wrapped_env, obs, info["obs"]) - - -def test_reshape_observation_wrapper(): - """Test the ``ReshapeObservation`` wrapper.""" - env = GenericTestEnv( - observation_space=Box(0, 1, shape=(2, 3, 2)), - reset_func=_record_random_obs_reset, - step_func=_record_random_obs_step, - ) - wrapped_env = ReshapeObservationV0(env, (6, 2)) - - obs, info = wrapped_env.reset() - _check_obs(env, wrapped_env, obs, info["obs"]) - assert obs.shape == (6, 2) - - obs, _, _, _, info = wrapped_env.step(None) - _check_obs(env, wrapped_env, obs, info["obs"]) - assert obs.shape == (6, 2) - - -def test_rescale_observation(): - """Test the ``RescaleObservation`` wrapper""" - env = GenericTestEnv( - observation_space=Box( - np.array([0, 1], 
dtype=np.float32), np.array([1, 3], dtype=np.float32) - ), - reset_func=_record_action_obs_reset, - step_func=_record_action_obs_step, - ) - wrapped_env = RescaleObservationV0( - env, - min_obs=np.array([-5, 0], dtype=np.float32), - max_obs=np.array([5, 1], dtype=np.float32), - ) - assert wrapped_env.observation_space == Box( - np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32) - ) - - for sample_obs, expected_obs in ( - ( - np.array([0.5, 2.0], dtype=np.float32), - np.array([0.0, 0.5], dtype=np.float32), - ), - ( - np.array([0.0, 1.0], dtype=np.float32), - np.array([-5.0, 0.0], dtype=np.float32), - ), - ( - np.array([1.0, 3.0], dtype=np.float32), - np.array([5.0, 1.0], dtype=np.float32), - ), - ): - assert sample_obs in env.observation_space - assert expected_obs in wrapped_env.observation_space - - obs, info = wrapped_env.reset(options={"obs": sample_obs}) - assert np.all(obs == expected_obs) - _check_obs(env, wrapped_env, obs, info["obs"], strict=False) - - obs, _, _, _, info = wrapped_env.step(sample_obs) - assert np.all(obs == expected_obs) - _check_obs(env, wrapped_env, obs, info["obs"], strict=False) - - -def test_dtype_observation(): - """Test ``DtypeObservation`` that the""" - env = GenericTestEnv( - reset_func=_record_random_obs_reset, step_func=_record_random_obs_step - ) - wrapped_env = DtypeObservationV0(env, dtype=np.uint8) - - obs, info = wrapped_env.reset() - assert obs.dtype != info["obs"].dtype - assert obs.dtype == np.uint8 - - obs, _, _, _, info = wrapped_env.step(None) - assert obs.dtype != info["obs"].dtype - assert obs.dtype == np.uint8 + check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/experimental/wrappers/test_lambda_rewards.py b/tests/experimental/wrappers/test_lambda_rewards.py index bc29da6c6..697c30c76 100644 --- a/tests/experimental/wrappers/test_lambda_rewards.py +++ b/tests/experimental/wrappers/test_lambda_rewards.py @@ -4,14 +4,8 @@ import numpy as np import pytest import gymnasium as gym -from 
gymnasium.error import InvalidBound -from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0 - - -ENV_ID = "CartPole-v1" -DISCRETE_ACTION = 0 -NUM_ENVS = 3 -SEED = 0 +from gymnasium.experimental.wrappers import LambdaRewardV0 +from tests.experimental.wrappers.utils import DISCRETE_ACTION, ENV_ID, NUM_ENVS, SEED @pytest.mark.parametrize( @@ -54,57 +48,3 @@ def test_lambda_reward_within_vector(reward_fn, expected_reward): _, rew, _, _, _ = env.step(actions) assert np.alltrue(rew == expected_reward) - - -@pytest.mark.parametrize( - ("lower_bound", "upper_bound", "expected_reward"), - [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)], -) -def test_clip_reward(lower_bound, upper_bound, expected_reward): - """Test reward clipping. - Test if reward is correctly clipped - accordingly to the input args. - """ - env = gym.make(ENV_ID) - env = ClipRewardV0(env, lower_bound, upper_bound) - env.reset(seed=SEED) - _, rew, _, _, _ = env.step(DISCRETE_ACTION) - - assert rew == expected_reward - - -@pytest.mark.parametrize( - ("lower_bound", "upper_bound", "expected_reward"), - [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)], -) -def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward): - """Test reward clipping in vectorized environment. - Test if reward is correctly clipped - accordingly to the input args in a vectorized environment. - """ - actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)] - - env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS) - env = ClipRewardV0(env, lower_bound, upper_bound) - env.reset(seed=SEED) - - _, rew, _, _, _ = env.step(actions) - - assert np.alltrue(rew == expected_reward) - - -@pytest.mark.parametrize( - ("lower_bound", "upper_bound"), - [(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))], -) -def test_clip_reward_incorrect_params(lower_bound, upper_bound): - """Test reward clipping with incorrect params. - Test whether passing wrong params to clip_rewards - correctly raise an exception. 
- clip_rewards should raise an exception if, both low and upper - bound of reward are `None` or if upper bound is lower than lower bound. - """ - env = gym.make(ENV_ID) - - with pytest.raises(InvalidBound): - env = ClipRewardV0(env, lower_bound, upper_bound) diff --git a/tests/experimental/wrappers/test_normalize_observation.py b/tests/experimental/wrappers/test_normalize_observation.py new file mode 100644 index 000000000..6e4b8c7c4 --- /dev/null +++ b/tests/experimental/wrappers/test_normalize_observation.py @@ -0,0 +1 @@ +"""Test suite for NormalizeObservationV0.""" diff --git a/tests/experimental/wrappers/test_normalize_reward.py b/tests/experimental/wrappers/test_normalize_reward.py new file mode 100644 index 000000000..845cd9059 --- /dev/null +++ b/tests/experimental/wrappers/test_normalize_reward.py @@ -0,0 +1 @@ +"""Test suite for NormalizeRewardV0.""" diff --git a/tests/experimental/wrappers/test_numpy_to_torch.py b/tests/experimental/wrappers/test_numpy_to_torch.py new file mode 100644 index 000000000..6cd9680ec --- /dev/null +++ b/tests/experimental/wrappers/test_numpy_to_torch.py @@ -0,0 +1 @@ +"""Test suite for NumpyToTorchV0.""" diff --git a/tests/experimental/wrappers/test_order_enforcing.py b/tests/experimental/wrappers/test_order_enforcing.py new file mode 100644 index 000000000..d513dc779 --- /dev/null +++ b/tests/experimental/wrappers/test_order_enforcing.py @@ -0,0 +1 @@ +"""Test suite for OrderEnforcingV0.""" diff --git a/tests/experimental/wrappers/test_passive_env_checker.py b/tests/experimental/wrappers/test_passive_env_checker.py new file mode 100644 index 000000000..cd10b83c9 --- /dev/null +++ b/tests/experimental/wrappers/test_passive_env_checker.py @@ -0,0 +1 @@ +"""Test suite for PassiveEnvCheckerV0.""" diff --git a/tests/experimental/wrappers/test_pixel_observation.py b/tests/experimental/wrappers/test_pixel_observation.py new file mode 100644 index 000000000..4df32ed93 --- /dev/null +++ 
b/tests/experimental/wrappers/test_pixel_observation.py @@ -0,0 +1 @@ +"""Test suite for PixelObservationV0.""" diff --git a/tests/experimental/wrappers/test_record_episode_statistics.py b/tests/experimental/wrappers/test_record_episode_statistics.py new file mode 100644 index 000000000..1f0ede6ae --- /dev/null +++ b/tests/experimental/wrappers/test_record_episode_statistics.py @@ -0,0 +1 @@ +"""Test suite for RecordEpisodeStatisticsV0.""" diff --git a/tests/experimental/wrappers/test_record_video.py b/tests/experimental/wrappers/test_record_video.py new file mode 100644 index 000000000..d79f672b6 --- /dev/null +++ b/tests/experimental/wrappers/test_record_video.py @@ -0,0 +1 @@ +"""Test suite for RecordVideoV0.""" diff --git a/tests/experimental/wrappers/test_render_collection.py b/tests/experimental/wrappers/test_render_collection.py new file mode 100644 index 000000000..9a2687519 --- /dev/null +++ b/tests/experimental/wrappers/test_render_collection.py @@ -0,0 +1 @@ +"""Test suite for RenderCollectionV0.""" diff --git a/tests/experimental/wrappers/test_rescale_action.py b/tests/experimental/wrappers/test_rescale_action.py new file mode 100644 index 000000000..efd40559d --- /dev/null +++ b/tests/experimental/wrappers/test_rescale_action.py @@ -0,0 +1,38 @@ +"""Test suite for RescaleActionV0.""" +import numpy as np + +from gymnasium.experimental.wrappers import RescaleActionV0 +from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import record_action_step +from tests.testing_env import GenericTestEnv + + +def test_rescale_action_wrapper(): + """Test that the action is rescaled within a min / max bound.""" + env = GenericTestEnv( + step_func=record_action_step, + action_space=Box(np.array([0, 1]), np.array([1, 3])), + ) + wrapped_env = RescaleActionV0( + env, min_action=np.array([-5, 0]), max_action=np.array([5, 1]) + ) + assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1])) + + for sample_action, expected_action in ( + ( +
np.array([0.0, 0.5], dtype=np.float32), + np.array([0.5, 2.0], dtype=np.float32), + ), + ( + np.array([-5.0, 0.0], dtype=np.float32), + np.array([0.0, 1.0], dtype=np.float32), + ), + ( + np.array([5.0, 1.0], dtype=np.float32), + np.array([1.0, 3.0], dtype=np.float32), + ), + ): + assert sample_action in wrapped_env.action_space + + _, _, _, _, info = wrapped_env.step(sample_action) + assert np.all(info["action"] == expected_action) diff --git a/tests/experimental/wrappers/test_rescale_observation.py b/tests/experimental/wrappers/test_rescale_observation.py new file mode 100644 index 000000000..fffb4a9b8 --- /dev/null +++ b/tests/experimental/wrappers/test_rescale_observation.py @@ -0,0 +1,55 @@ +"""Test suite for RescaleObservationV0.""" +import numpy as np + +from gymnasium.experimental.wrappers import RescaleObservationV0 +from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import ( + check_obs, + record_action_as_obs_step, + record_obs_reset, +) +from tests.testing_env import GenericTestEnv + + +def test_rescale_observation(): + """Test the ``RescaleObservation`` wrapper.""" + env = GenericTestEnv( + observation_space=Box( + np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32) + ), + reset_func=record_obs_reset, + step_func=record_action_as_obs_step, + ) + wrapped_env = RescaleObservationV0( + env, + min_obs=np.array([-5, 0], dtype=np.float32), + max_obs=np.array([5, 1], dtype=np.float32), + ) + assert wrapped_env.observation_space == Box( + np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32) + ) + + for sample_obs, expected_obs in ( + ( + np.array([0.5, 2.0], dtype=np.float32), + np.array([0.0, 0.5], dtype=np.float32), + ), + ( + np.array([0.0, 1.0], dtype=np.float32), + np.array([-5.0, 0.0], dtype=np.float32), + ), + ( + np.array([1.0, 3.0], dtype=np.float32), + np.array([5.0, 1.0], dtype=np.float32), + ), + ): + assert sample_obs in env.observation_space + assert expected_obs in 
wrapped_env.observation_space + + obs, info = wrapped_env.reset(options={"obs": sample_obs}) + assert np.all(obs == expected_obs) + check_obs(env, wrapped_env, obs, info["obs"], strict=False) + + obs, _, _, _, info = wrapped_env.step(sample_obs) + assert np.all(obs == expected_obs) + check_obs(env, wrapped_env, obs, info["obs"], strict=False) diff --git a/tests/experimental/wrappers/test_reshape_observation.py b/tests/experimental/wrappers/test_reshape_observation.py new file mode 100644 index 000000000..f42f759bd --- /dev/null +++ b/tests/experimental/wrappers/test_reshape_observation.py @@ -0,0 +1,27 @@ +"""Test suite for ReshapeObservationV0.""" +from gymnasium.experimental.wrappers import ReshapeObservationV0 +from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import ( + check_obs, + record_random_obs_reset, + record_random_obs_step, +) +from tests.testing_env import GenericTestEnv + + +def test_reshape_observation_wrapper(): + """Test the ``ReshapeObservation`` wrapper.""" + env = GenericTestEnv( + observation_space=Box(0, 1, shape=(2, 3, 2)), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) + wrapped_env = ReshapeObservationV0(env, (6, 2)) + + obs, info = wrapped_env.reset() + check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (6, 2) + + obs, _, _, _, info = wrapped_env.step(None) + check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (6, 2) diff --git a/tests/experimental/wrappers/test_resize_observation.py b/tests/experimental/wrappers/test_resize_observation.py new file mode 100644 index 000000000..ecf9c8f7a --- /dev/null +++ b/tests/experimental/wrappers/test_resize_observation.py @@ -0,0 +1,27 @@ +"""Test suite for ResizeObservationV0.""" +import numpy as np + +from gymnasium.experimental.wrappers import ResizeObservationV0 +from gymnasium.spaces import Box +from tests.experimental.wrappers.utils import ( + check_obs, + record_random_obs_reset, + record_random_obs_step, +)
+from tests.testing_env import GenericTestEnv + + +def test_resize_observation_wrapper(): + """Test the ``ResizeObservation`` that the observation has changed size.""" + env = GenericTestEnv( + observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) + wrapped_env = ResizeObservationV0(env, (25, 25)) + + obs, info = wrapped_env.reset() + check_obs(env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/experimental/wrappers/test_stateful_action.py b/tests/experimental/wrappers/test_sticky_action.py similarity index 54% rename from tests/experimental/wrappers/test_stateful_action.py rename to tests/experimental/wrappers/test_sticky_action.py index 3bc1254e8..efa51a0eb 100644 --- a/tests/experimental/wrappers/test_stateful_action.py +++ b/tests/experimental/wrappers/test_sticky_action.py @@ -1,43 +1,34 @@ """Test suite for StickyActionV0.""" +import numpy as np import pytest from gymnasium.error import InvalidProbability from gymnasium.experimental.wrappers import StickyActionV0 +from tests.experimental.wrappers.utils import NUM_STEPS, record_action_as_obs_step from tests.testing_env import GenericTestEnv -SEED = 42 - -DELAY = 3 -NUM_STEPS = 10 - - -def step_fn(self, action): - return action - - def test_sticky_action(): + """Tests the sticky action wrapper.""" env = StickyActionV0( - GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5 + GenericTestEnv(step_func=record_action_as_obs_step), + repeat_action_probability=0.5, ) - env.reset(seed=SEED) - env.action_space.seed(SEED) previous_action = None for _ in range(NUM_STEPS): input_action = env.action_space.sample() - executed_action = env.step(input_action) + executed_action, _, _, _, _ = env.step(input_action) - if executed_action != input_action: - assert executed_action == previous_action - else: - assert executed_action 
== input_action - - previous_action = input_action + assert np.all(executed_action == input_action) or np.all( + executed_action == previous_action + ) + previous_action = executed_action @pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5]) def test_sticky_action_raise(repeat_action_probability): + """Tests the sticky action wrapper with probabilities that should raise an error.""" with pytest.raises(InvalidProbability): StickyActionV0( GenericTestEnv(), repeat_action_probability=repeat_action_probability diff --git a/tests/experimental/wrappers/test_stateful_observation.py b/tests/experimental/wrappers/test_time_aware_observation.py similarity index 67% rename from tests/experimental/wrappers/test_stateful_observation.py rename to tests/experimental/wrappers/test_time_aware_observation.py index fedbf19c6..2584d5793 100644 --- a/tests/experimental/wrappers/test_stateful_observation.py +++ b/tests/experimental/wrappers/test_time_aware_observation.py @@ -1,19 +1,10 @@ -"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation.""" +"""Test suite for TimeAwareObservationV0.""" -import numpy as np - -import gymnasium as gym -from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0 +from gymnasium.experimental.wrappers import TimeAwareObservationV0 from gymnasium.spaces import Box, Dict, Tuple from tests.testing_env import GenericTestEnv -NUM_STEPS = 20 -SEED = 0 - -DELAY = 3 - - def test_time_aware_observation_wrapper(): """Tests the time aware observation wrapper.""" # Test the environment observation space with Dict, Tuple and other @@ -60,30 +51,3 @@ def test_time_aware_observation_wrapper(): reset_obs, _ = wrapped_env.reset() step_obs, _, _, _, _ = wrapped_env.step(None) assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01 - - -def test_delay_observation_wrapper(): - env = gym.make("CartPole-v1") - env.action_space.seed(SEED) - env.reset(seed=SEED) - - undelayed_observations = [] - for
_ in range(NUM_STEPS): - obs, _, _, _, _ = env.step(env.action_space.sample()) - undelayed_observations.append(obs) - - env = DelayObservationV0(env, delay=DELAY) - env.action_space.seed(SEED) - env.reset(seed=SEED) - - delayed_observations = [] - for i in range(NUM_STEPS): - obs, _, _, _, _ = env.step(env.action_space.sample()) - delayed_observations.append(obs) - if i < DELAY - 1: - assert np.all(obs == 0) - - undelayed_observations = np.array(undelayed_observations) - delayed_observations = np.array(delayed_observations) - - assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY]) diff --git a/tests/experimental/wrappers/utils.py b/tests/experimental/wrappers/utils.py new file mode 100644 index 000000000..7bbcfa42e --- /dev/null +++ b/tests/experimental/wrappers/utils.py @@ -0,0 +1,69 @@ +"""Utility functions for testing the experimental wrappers.""" +import gymnasium as gym + + +SEED = 42 +ENV_ID = "CartPole-v1" +DISCRETE_ACTION = 0 +NUM_ENVS = 3 +NUM_STEPS = 20 +DELAY = 3 + + +def record_obs_reset(self: gym.Env, seed=None, options: dict = None): + """Records and uses an observation passed through options.""" + return options["obs"], {"obs": options["obs"]} + + +def record_random_obs_reset(self: gym.Env, seed=None, options=None): + """Records random observation generated by the environment.""" + obs = self.observation_space.sample() + return obs, {"obs": obs} + + +def record_action_step(self: gym.Env, action): + """Records the actions passed to the environment.""" + return 0, 0, False, False, {"action": action} + + +def record_random_obs_step(self: gym.Env, action): + """Records the observation generated by the environment.""" + obs = self.observation_space.sample() + return obs, 0, False, False, {"obs": obs} + + +def record_action_as_obs_step(self: gym.Env, action): + """Uses the action as the observation.""" + return action, 0, False, False, {"obs": action} + + +def check_obs( + env: gym.Env, + wrapped_env: gym.Wrapper, + 
transformed_obs, + original_obs, + strict: bool = True, +): + """Checks the original and transformed observations against the environment and wrapped environment observation spaces. + + Args: + env: The base environment + wrapped_env: The wrapped environment + transformed_obs: The transformed observation by the wrapped environment + original_obs: The original observation by the base environment. + strict: Whether to also check that each observation is not contained in the other environment's observation space. + """ + assert ( + transformed_obs in wrapped_env.observation_space + ), f"{transformed_obs}, {wrapped_env.observation_space}" + assert ( + original_obs in env.observation_space + ), f"{original_obs}, {env.observation_space}" + + if strict: + assert ( + transformed_obs not in env.observation_space + ), f"{transformed_obs}, {env.observation_space}" + assert ( + original_obs not in wrapped_env.observation_space + ), f"{original_obs}, {wrapped_env.observation_space}" diff --git a/tests/test_core.py b/tests/test_core.py index 432fde48f..da2aa0e4c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -23,36 +23,44 @@ from tests.testing_env import GenericTestEnv class ArgumentEnv(Env): + """Testing environment that records the number of times the environment is created.""" + observation_space = spaces.Box(low=0, high=1, shape=(1,)) action_space = spaces.Box(low=0, high=1, shape=(1,)) calls = 0 - def __init__(self, arg): + def __init__(self, arg: Any): + """Constructor.""" self.calls += 1 self.arg = arg class UnittestEnv(Env): + """Example testing environment.""" + observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8) action_space = spaces.Discrete(3) def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + """Resets the environment.""" super().reset(seed=seed) return self.observation_space.sample(), {"info": "dummy"} def step(self, action): + """Steps through the environment.""" observation = self.observation_space.sample() # Dummy observation -
return (observation, 0.0, False, {}) + return observation, 0.0, False, {} class UnknownSpacesEnv(Env): - """This environment defines its observation & action spaces only - after the first call to reset. Although this pattern is sometimes - necessary when implementing a new environment (e.g. if it depends + """This environment defines its observation & action spaces only after the first call to reset. + + Although this pattern is sometimes necessary when implementing a new environment (e.g. if it depends on external resources), it is not encouraged. """ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + """Resets the environment.""" super().reset(seed=seed) self.observation_space = spaces.Box( low=0, high=255, shape=(64, 64, 3), dtype=np.uint8 @@ -61,25 +69,27 @@ class UnknownSpacesEnv(Env): return self.observation_space.sample(), {} # Dummy observation with info def step(self, action): + """Steps through the environment.""" observation = self.observation_space.sample() # Dummy observation - return (observation, 0.0, False, {}) + return observation, 0.0, False, {} class OldStyleEnv(Env): - """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)""" - - def __init__(self): - pass + """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now).""" def reset(self): + """Resets the environment.""" super().reset() return 0 def step(self, action): + """Steps through the environment.""" return 0, 0, False, {} class NewPropertyWrapper(Wrapper): + """Wrapper that tests setting a property.""" + def __init__( self, env, @@ -88,6 +98,15 @@ class NewPropertyWrapper(Wrapper): reward_range=None, metadata=None, ): + """New property wrapper. 
+ + Args: + env: The environment to wrap + observation_space: The observation space + action_space: The action space + reward_range: The reward range + metadata: The environment metadata + """ super().__init__(env) if observation_space is not None: # Only set the observation space if not None to test property forwarding @@ -101,6 +120,7 @@ class NewPropertyWrapper(Wrapper): def test_env_instantiation(): + """Tests the environment instantiation using ArgumentEnv.""" # This looks like a pretty trivial, but given our usage of # __new__, it's worth having. env = ArgumentEnv("arg") @@ -129,6 +149,7 @@ properties = [ @pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv]) @pytest.mark.parametrize("props", properties) def test_wrapper_property_forwarding(class_, props): + """Tests wrapper property forwarding.""" env = class_() env = NewPropertyWrapper(env, **props) @@ -147,6 +168,7 @@ def test_wrapper_property_forwarding(class_, props): def test_compatibility_with_old_style_env(): + """Test compatibility with old style environment.""" env = OldStyleEnv() env = OrderEnforcing(env) env = TimeLimit(env) @@ -158,13 +180,17 @@ def test_compatibility_with_old_style_env(): class ExampleEnv(Env): + """Example testing environment.""" + def __init__(self): + """Constructor for example environment.""" self.observation_space = Box(0, 1) self.action_space = Box(0, 1) def step( self, action: ActType ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: + """Steps through the environment.""" return 0, 0, False, False, {} def reset( @@ -173,10 +199,12 @@ class ExampleEnv(Env): seed: Optional[int] = None, options: Optional[dict] = None, ) -> Tuple[ObsType, dict]: + """Resets the environment.""" return 0, {} def test_gymnasium_env(): + """Tests a gymnasium environment.""" env = ExampleEnv() assert env.metadata == {"render_modes": []} @@ -187,7 +215,10 @@ def test_gymnasium_env(): class ExampleWrapper(Wrapper): + """An example testing wrapper.""" + def __init__(self, env: 
Env[ObsType, ActType]): + """Constructor that sets the reward.""" super().__init__(env) self.new_reward = 3 @@ -195,11 +226,13 @@ class ExampleWrapper(Wrapper): def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[WrapperObsType, Dict[str, Any]]: + """Resets the environment.""" return super().reset(seed=seed, options=options) def step( self, action: WrapperActType ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: + """Steps through the environment.""" obs, reward, termination, truncation, info = self.env.step(action) return obs, self.new_reward, termination, truncation, info @@ -209,6 +242,7 @@ class ExampleWrapper(Wrapper): def test_gymnasium_wrapper(): + """Tests the gymnasium wrapper works as expected.""" env = ExampleEnv() wrapper_env = ExampleWrapper(env) @@ -250,21 +284,31 @@ def test_gymnasium_wrapper(): class ExampleRewardWrapper(RewardWrapper): + """Example reward wrapper for testing.""" + def reward(self, reward: SupportsFloat) -> SupportsFloat: + """Reward function.""" return 1 class ExampleObservationWrapper(ObservationWrapper): + """Example observation wrapper for testing.""" + def observation(self, observation: ObsType) -> ObsType: + """Observation function.""" return np.array([1]) class ExampleActionWrapper(ActionWrapper): + """Example action wrapper for testing.""" + def action(self, action: ActType) -> ActType: + """Action function.""" return np.array([1]) def test_wrapper_types(): + """Tests the observation, action and reward wrapper examples.""" env = GenericTestEnv() reward_env = ExampleRewardWrapper(env) diff --git a/tests/testing_env.py b/tests/testing_env.py index dfebba114..a066f4306 100644 --- a/tests/testing_env.py +++ b/tests/testing_env.py @@ -1,6 +1,8 @@ """Provides a generic testing environment for use in tests with custom reset, step and render functions.""" +from __future__ import annotations + import types -from typing import Any, Dict, Optional, Tuple, Union +from typing
import Any import gymnasium as gym from gymnasium import spaces @@ -11,21 +13,21 @@ from gymnasium.envs.registration import EnvSpec def basic_reset_func( self, *, - seed: Optional[int] = None, - options: Optional[dict] = None, -) -> Union[ObsType, Tuple[ObsType, dict]]: + seed: int | None = None, + options: dict | None = None, +) -> ObsType | tuple[ObsType, dict]: """A basic reset function that will pass the environment check using random actions from the observation space.""" super(GenericTestEnv, self).reset(seed=seed) self.observation_space.seed(seed) return self.observation_space.sample(), {"options": options} -def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: +def new_step_func(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]: """A step function that follows the new step api that will pass the environment check using random actions from the observation space.""" return self.observation_space.sample(), 0, False, False, {} -def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: +def old_step_func(self, action: ActType) -> tuple[ObsType, float, bool, dict]: """A step function that follows the old step api that will pass the environment check using random actions from the observation space.""" return self.observation_space.sample(), 0, False, {} @@ -46,12 +48,24 @@ class GenericTestEnv(gym.Env): reset_func: callable = basic_reset_func, step_func: callable = new_step_func, render_func: callable = basic_render_func, - metadata: Dict[str, Any] = {"render_modes": []}, - render_mode: Optional[str] = None, + metadata: dict[str, Any] = {"render_modes": []}, + render_mode: str | None = None, spec: EnvSpec = EnvSpec( "TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100 ), ): + """Generic testing environment constructor. 
+ + Args: + action_space: The environment action space + observation_space: The environment observation space + reset_func: The environment reset function + step_func: The environment step function + render_func: The environment render function + metadata: The environment metadata + render_mode: The render mode of the environment + spec: The environment spec + """ self.metadata = metadata self.render_mode = render_mode self.spec = spec @@ -71,14 +85,17 @@ class GenericTestEnv(gym.Env): def reset( self, *, - seed: Optional[int] = None, - options: Optional[dict] = None, - ) -> Union[ObsType, Tuple[ObsType, dict]]: + seed: int | None = None, + options: dict | None = None, + ) -> ObsType | tuple[ObsType, dict]: + """Resets the environment.""" # If you need a default working reset function, use `basic_reset_fn` above raise NotImplementedError("TestingEnv reset_fn is not set.") - def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: + def step(self, action: ActType) -> tuple[ObsType, float, bool, dict[str, Any]]: + """Steps through the environment.""" raise NotImplementedError("TestingEnv step_fn is not set.") def render(self): + """Renders the environment.""" raise NotImplementedError("testingEnv render_fn is not set.")