From 848b7097bfbb7088bcb39e725363ae76f7b07890 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 5 Dec 2022 19:14:56 +0000 Subject: [PATCH] Update experimental wrappers (#176) --- .pre-commit-config.yaml | 4 +- docs/api/experimental.md | 51 ++-- docs/api/experimental/functional.md | 7 +- docs/api/experimental/wrappers.md | 33 +-- gymnasium/__init__.py | 2 +- gymnasium/envs/phys2d/__init__.py | 4 +- gymnasium/experimental/__init__.py | 5 +- gymnasium/experimental/wrappers/__init__.py | 54 +++- .../wrappers/delay_observation.py | 35 --- .../experimental/wrappers/lambda_action.py | 95 ++++-- .../wrappers/lambda_observations.py | 182 ++++++++---- .../experimental/wrappers/lambda_reward.py | 19 +- .../experimental/wrappers/numpy_to_jax.py | 104 ++++--- .../experimental/wrappers/stateful_action.py | 56 ++++ .../wrappers/stateful_observation.py | 200 +++++++++++++ .../experimental/wrappers/sticky_action.py | 40 --- .../wrappers/time_aware_observation.py | 113 -------- .../experimental/wrappers/torch_to_jax.py | 151 ++++++---- gymnasium/spaces/sequence.py | 1 + pyproject.toml | 2 +- tests/envs/test_make.py | 4 +- .../experimental/wrappers/test_clip_action.py | 48 --- .../wrappers/test_delay_observation.py | 37 --- .../wrappers/test_lambda_action.py | 102 ++++--- .../wrappers/test_lambda_observation.py | 273 +++++++++++++++--- .../wrappers/test_numpy_to_jax.py | 2 +- .../wrappers/test_rescale_action.py | 52 ---- ...icky_action.py => test_stateful_action.py} | 6 +- .../wrappers/test_stateful_observation.py | 89 ++++++ .../wrappers/test_time_aware_observation.py | 99 ------- .../wrappers/test_torch_to_jax.py | 10 +- tests/test_core.py | 2 +- tests/testing_env.py | 30 +- tests/utils/test_env_checker.py | 10 +- tests/utils/test_passive_env_checker.py | 10 +- tests/utils/test_play.py | 2 +- tests/vector/test_vector_env.py | 4 +- tests/wrappers/test_atari_preprocessing.py | 4 +- tests/wrappers/test_passive_env_checker.py | 4 +- 39 files changed, 1140 insertions(+), 806 deletions(-) delete mode 100644 gymnasium/experimental/wrappers/delay_observation.py create mode 100644 gymnasium/experimental/wrappers/stateful_action.py create mode 100644 gymnasium/experimental/wrappers/stateful_observation.py delete mode 100644 gymnasium/experimental/wrappers/sticky_action.py delete mode 100644 gymnasium/experimental/wrappers/time_aware_observation.py delete mode 100644 tests/experimental/wrappers/test_clip_action.py delete mode 100644 tests/experimental/wrappers/test_delay_observation.py delete mode 100644 tests/experimental/wrappers/test_rescale_action.py rename tests/experimental/wrappers/{test_sticky_action.py => test_stateful_action.py} (84%) create mode 100644 tests/experimental/wrappers/test_stateful_observation.py delete mode 100644 tests/experimental/wrappers/test_time_aware_observation.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 59b2529f0..b9b440c4e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,9 +22,9 @@ repos: hooks: - id: codespell args: - - --ignore-words-list=nd,reacher,thist,ths, ure, referenc,wile + - --ignore-words-list=nd,reacher,thist,ths,ure,referenc,wile - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 5.0.4 hooks: - id: flake8 args: diff --git a/docs/api/experimental.md b/docs/api/experimental.md index 9ff8c199d..59e0e6255 100644 --- a/docs/api/experimental.md +++ b/docs/api/experimental.md @@ -27,8 +27,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t * In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorised versions of the wrappers will be provided. -### Lambda Observation Wrappers - +### Observation Wrappers ```{eval-rst} .. py:currentmodule:: gymnasium @@ -44,61 +43,60 @@ Gymnasium already contains a large collection of wrappers, but we believe that t - VectorLambdaObservation - No * - :class:`wrappers.FilterObservation` - - :class:`experimental.wrappers.FilterObservation` + - :class:`experimental.wrappers.FilterObservationV0` - VectorFilterObservation (*) - Yes * - :class:`wrappers.FlattenObservation` - - `:class:`experimental.wrappers.FlattenObservation` + - :class:`experimental.wrappers.FlattenObservationV0` - VectorFlattenObservation (*) - No * - :class:`wrappers.GrayScaleObservation` - - `:class:`experimental.wrappers.GrayscaleObservation` + - :class:`experimental.wrappers.GrayscaleObservationV0` - VectorGrayscaleObservation (*) - Yes * - :class:`wrappers.ResizeObservation` - - :class:`experimental.wrappers.ResizeObservation` + - :class:`experimental.wrappers.ResizeObservationV0` - VectorResizeObservation (*) - Yes * - Not Implemented - - :class:`experimental.wrappers.ReshapeObservation` + - :class:`experimental.wrappers.ReshapeObservationV0` - VectorReshapeObservation (*) - Yes * - Not Implemented - - :class:`experimental.wrappers.RescaleObservation` + - :class:`experimental.wrappers.RescaleObservationV0` - VectorRescaleObservation (*) - Yes * - Not Implemented - - :class:`experimental.wrappers.DtypeObservation` + - :class:`experimental.wrappers.DtypeObservationV0` - VectorDtypeObservation (*) - Yes * - :class:`wrappers.PixelObservationWrapper` - PixelObservation - VectorPixelObservation - No - * - :class:`NormalizeObservation` + * - :class:`wrappers.NormalizeObservation` - NormalizeObservation - VectorNormalizeObservation - No - * - :class:`TimeAwareObservation` - - TimeAwareObservation + * - :class:`wrappers.TimeAwareObservation` + - :class:`experimental.wrappers.TimeAwareObservationV0` - VectorTimeAwareObservation - No - * - :class:`FrameStack` + * - :class:`wrappers.FrameStack` - FrameStackObservation - VectorFrameStackObservation - No * - Not Implemented - - DelayObservation + - :class:`experimental.wrappers.DelayObservationV0` - VectorDelayObservation - No - * - :class:`AtariPreprocessing` + * - :class:`wrappers.AtariPreprocessing` - AtariPreprocessing - Not Implemented - No ``` -### Lambda Action Wrappers - +### Action Wrappers ```{eval-rst} .. py:currentmodule:: gymnasium @@ -114,25 +112,20 @@ Gymnasium already contains a large collection of wrappers, but we believe that t - VectorLambdaAction - No * - :class:`wrappers.ClipAction` - - ClipAction + - :class:`experimental.wrappers.ClipActionV0` - VectorClipAction (*) - Yes * - :class:`wrappers.RescaleAction` - - RescaleAction + - :class:`experimental.wrappers.RescaleActionV0` - VectorRescaleAction (*) - Yes * - Not Implemented - - NanAction - - VectorNanAction (*) - - Yes - * - Not Implemented - - StickyAction + - :class:`experimental.wrappers.StickyActionV0` - VectorStickyAction - No ``` -### Lambda Reward Wrappers - +### Reward Wrappers ```{eval-rst} .. py:currentmodule:: gymnasium @@ -175,7 +168,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t - VectorPassiveEnvChecker * - :class:`wrappers.OrderEnforcing` - OrderEnforcing - - VectorOrderEnforcing (*) + - VectorOrderEnforcing * - :class:`wrappers.EnvCompatibility` - Moved to `shimmy `_ - Not Implemented @@ -189,10 +182,10 @@ Gymnasium already contains a large collection of wrappers, but we believe that t - HumanRendering - Not Implemented * - Not Implemented - - :class:`experimental.wrappers.JaxToNumpy` + - :class:`experimental.wrappers.JaxToNumpyV0` - VectorJaxToNumpy (*) * - Not Implemented - - :class:`experimental.wrappers.JaxToTorch` + - :class:`experimental.wrappers.JaxToTorchV0` - VectorJaxToTorch (*) ``` diff --git a/docs/api/experimental/functional.md b/docs/api/experimental/functional.md index 974ce01ea..b3dd2569f 100644 --- a/docs/api/experimental/functional.md +++ b/docs/api/experimental/functional.md @@ -12,9 +12,6 @@ title: Functional .. autofunction:: gymnasium.experimental.FuncEnv.initial .. autofunction:: gymnasium.experimental.FuncEnv.transition -.. autofunction:: gymnasium.experimental.FuncEnv.observation -.. autofunction:: gymnasium.experimental.FuncEnv.initial - .. autofunction:: gymnasium.experimental.FuncEnv.observation .. autofunction:: gymnasium.experimental.FuncEnv.reward .. autofunction:: gymnasium.experimental.FuncEnv.terminal @@ -33,4 +30,8 @@ title: Functional ```{eval-rst} ... autoclass:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv + +... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.reset +... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.step +... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.render ``` diff --git a/docs/api/experimental/wrappers.md b/docs/api/experimental/wrappers.md index 865afbc09..7acb7a276 100644 --- a/docs/api/experimental/wrappers.md +++ b/docs/api/experimental/wrappers.md @@ -1,6 +1,6 @@ # Wrappers -## Lambda Observation Wrappers +## Observation Wrappers ```{eval-rst} .. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0 @@ -11,24 +11,6 @@ .. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0 .. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0 .. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0 -``` - -## Lambda Action Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0 -``` - -## Lambda Reward Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0 -.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 -``` - -## Observation Wrappers - -```{eval-rst} .. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0 .. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0 ``` @@ -36,11 +18,22 @@ ## Action Wrappers ```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0 +.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0 +.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0 .. autoclass:: gymnasium.experimental.wrappers.StickyActionV0 ``` +# Reward Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0 +.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 +``` + ## Common Wrappers ```{eval-rst} - +.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0 +.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0 ``` diff --git a/gymnasium/__init__.py b/gymnasium/__init__.py index 4241023dd..36d60c80f 100644 --- a/gymnasium/__init__.py +++ b/gymnasium/__init__.py @@ -27,7 +27,7 @@ __all__ = [ "register", "registry", "pprint_registry", - # root files + # module folders "envs", "spaces", "utils", diff --git a/gymnasium/envs/phys2d/__init__.py b/gymnasium/envs/phys2d/__init__.py index f00c65601..8ff4b205c 100644 --- a/gymnasium/envs/phys2d/__init__.py +++ b/gymnasium/envs/phys2d/__init__.py @@ -1,2 +1,2 @@ -from gymnasium.envs.phys2d.cartpole import CartPoleFunctional -from gymnasium.envs.phys2d.pendulum import PendulumFunctional +from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv +from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv diff --git a/gymnasium/experimental/__init__.py b/gymnasium/experimental/__init__.py index e348e26d7..dbbdf267d 100644 --- a/gymnasium/experimental/__init__.py +++ b/gymnasium/experimental/__init__.py @@ -1,6 +1,7 @@ """Root __init__ of the gym experimental wrappers.""" +from gymnasium.experimental import functional, wrappers from gymnasium.experimental.functional import FuncEnv @@ -8,6 +9,8 @@ __all__ = [ # Functional "FuncEnv", "functional", - # Wrapper + # Wrappers "wrappers", + # Vector + # "vector", ] diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 988247db6..3a2a52b61 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -10,31 +10,59 @@ from gymnasium.experimental.wrappers.lambda_action import ( ClipActionV0, RescaleActionV0, ) -from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0 +from gymnasium.experimental.wrappers.lambda_observations import ( + LambdaObservationV0, + FilterObservationV0, + FlattenObservationV0, + GrayscaleObservationV0, + ResizeObservationV0, + ReshapeObservationV0, + RescaleObservationV0, + DtypeObservationV0, +) from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0 from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0 -from gymnasium.experimental.wrappers.sticky_action import StickyActionV0 -from gymnasium.experimental.wrappers.time_aware_observation import ( +from gymnasium.experimental.wrappers.stateful_action import StickyActionV0 +from gymnasium.experimental.wrappers.stateful_observation import ( TimeAwareObservationV0, + DelayObservationV0, ) -from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0 __all__ = [ - "ArgType", - # Lambda Action + # --- Observation wrappers --- + "LambdaObservationV0", + "FilterObservationV0", + "FlattenObservationV0", + "GrayscaleObservationV0", + "ResizeObservationV0", + "ReshapeObservationV0", + "RescaleObservationV0", + "DtypeObservationV0", + # "PixelObservationV0", + # "NormalizeObservationV0", + "TimeAwareObservationV0", + # "FrameStackV0", + "DelayObservationV0", + # "AtariPreprocessingV0" + # --- Action Wrappers --- "LambdaActionV0", - "StickyActionV0", "ClipActionV0", "RescaleActionV0", - # Lambda Observation - "LambdaObservationV0", - "DelayObservationV0", - "TimeAwareObservationV0", - # Lambda Reward + # "NanAction", + "StickyActionV0", + # --- Reward wrappers --- "LambdaRewardV0", "ClipRewardV0", - # Jax conversion wrappers + # "RescaleRewardV0", + # "NormalizeRewardV0", + # --- Common --- + # "AutoReset", + # "PassiveEnvChecker", + # "OrderEnforcing", + # "RecordEpisodeStatistics", + # "RenderCollection", + # "HumanRendering", "JaxToNumpyV0", "JaxToTorchV0", ] diff --git a/gymnasium/experimental/wrappers/delay_observation.py b/gymnasium/experimental/wrappers/delay_observation.py deleted file mode 100644 index 2903d1154..000000000 --- a/gymnasium/experimental/wrappers/delay_observation.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Wrapper for delaying the returned observation.""" - -from collections import deque - -import jumpy as jp - -import gymnasium as gym -from gymnasium.core import ObsType - - -class DelayObservationV0(gym.ObservationWrapper): - """Wrapper which adds a delay to the returned observation.""" - - def __init__(self, env: gym.Env, delay: int): - """Initialize the DelayObservation wrapper. - - Args: - env (Env): the wrapped environment - delay (int): number of timesteps for delaying the observation. - Before reaching the `delay` number of timesteps, - returned observation is an array of zeros with the - same shape of the observation space. - """ - super().__init__(env) - self.delay = delay - self.observation_queue = deque() - - def observation(self, observation: ObsType) -> ObsType: - """Return the delayed observation.""" - self.observation_queue.append(observation) - - if len(self.observation_queue) > self.delay: - return self.observation_queue.popleft() - - return jp.zeros_like(observation) diff --git a/gymnasium/experimental/wrappers/lambda_action.py b/gymnasium/experimental/wrappers/lambda_action.py index 9cb7c4a30..181129807 100644 --- a/gymnasium/experimental/wrappers/lambda_action.py +++ b/gymnasium/experimental/wrappers/lambda_action.py @@ -1,13 +1,19 @@ -"""Lambda action wrapper which apply a function to the provided action.""" -from typing import Any, Callable, Union +"""A collection of wrappers that all use the LambdaAction class. + +* ``LambdaAction`` - Transforms the actions based on a function +* ``ClipAction`` - Clips the action within a bounds +* ``RescaleAction`` - Rescales the action within a minimum and maximum actions +""" +from __future__ import annotations + +from typing import Callable import jumpy as jp import numpy as np import gymnasium as gym -from gymnasium import spaces -from gymnasium.core import ActType -from gymnasium.experimental.wrappers import ArgType +from gymnasium.core import ActType, WrapperActType +from gymnasium.spaces import Box, Space class LambdaActionV0(gym.ActionWrapper): @@ -16,19 +22,23 @@ class LambdaActionV0(gym.ActionWrapper): def __init__( self, env: gym.Env, - func: Callable[[ArgType], Any], + func: Callable[[WrapperActType], ActType], + action_space: Space | None, ): """Initialize LambdaAction. Args: - env (Env): The gymnasium environment - func (Callable): function to apply to action + env: The gymnasium environment + func: Function to apply to ``step`` ``action`` + action_space: The updated action space of the wrapper given the function. """ super().__init__(env) + if action_space is not None: + self.action_space = action_space self.func = func - def action(self, action: ActType) -> Any: + def action(self, action: WrapperActType) -> ActType: """Apply function to action.""" return self.func(action) @@ -53,14 +63,19 @@ class ClipActionV0(LambdaActionV0): Args: env: The environment to apply the wrapper """ - assert isinstance(env.action_space, spaces.Box) + assert isinstance(env.action_space, Box) + super().__init__( env, lambda action: jp.clip(action, env.action_space.low, env.action_space.high), + Box( + -np.inf, + np.inf, + shape=env.action_space.shape, + dtype=env.action_space.dtype, + ), ) - self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape) - class RescaleActionV0(LambdaActionV0): """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. @@ -86,8 +101,8 @@ class RescaleActionV0(LambdaActionV0): def __init__( self, env: gym.Env, - min_action: Union[float, int, np.ndarray], - max_action: Union[float, int, np.ndarray], + min_action: float | int | np.ndarray, + max_action: float | int | np.ndarray, ): """Initializes the :class:`RescaleAction` wrapper. @@ -96,28 +111,44 @@ class RescaleActionV0(LambdaActionV0): min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. """ - assert isinstance( - env.action_space, spaces.Box - ), f"expected Box action space, got {type(env.action_space)}" - assert np.less_equal(min_action, max_action).all(), (min_action, max_action) - - low = env.action_space.low - high = env.action_space.high - - self.min_action = np.full( - env.action_space.shape, min_action, dtype=env.action_space.dtype + assert isinstance(env.action_space, Box) + assert not np.any(env.action_space.low == np.inf) and not np.any( + env.action_space.high == np.inf ) - self.max_action = np.full( - env.action_space.shape, max_action, dtype=env.action_space.dtype + + if not isinstance(min_action, np.ndarray): + assert np.issubdtype(type(min_action), np.integer) or np.issubdtype( + type(max_action), np.floating + ) + min_action = np.full(env.action_space.shape, min_action) + + assert min_action.shape == env.action_space.shape + assert not np.any(min_action == np.inf) + + if not isinstance(max_action, np.ndarray): + assert np.issubdtype(type(max_action), np.integer) or np.issubdtype( + type(max_action), np.floating + ) + max_action = np.full(env.action_space.shape, max_action) + assert max_action.shape == env.action_space.shape + assert not np.any(max_action == np.inf) + + assert isinstance(env.action_space, Box) + assert np.all(np.less_equal(min_action, max_action)) + + # Imagine the x-axis between the old Box and the y-axis being the new Box + gradient = (env.action_space.high - env.action_space.low) / ( + max_action - min_action ) + intercept = gradient * -min_action + env.action_space.low super().__init__( env, - lambda action: jp.clip( - low - + (high - low) - * ((action - self.min_action) / (self.max_action - self.min_action)), - low, - high, + lambda action: gradient * action + intercept, + Box( + low=min_action, + high=max_action, + shape=env.action_space.shape, + dtype=env.action_space.dtype, ), ) diff --git a/gymnasium/experimental/wrappers/lambda_observations.py b/gymnasium/experimental/wrappers/lambda_observations.py index 311ac3cd1..c7bc77218 100644 --- a/gymnasium/experimental/wrappers/lambda_observations.py +++ b/gymnasium/experimental/wrappers/lambda_observations.py @@ -1,17 +1,27 @@ -"""Lambda observation wrappers which apply a function to the observation.""" +"""A collection of observation wrappers using a lambda function. + +* ``LambdaObservation`` - Transforms the observation with a function +* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys +* ``FlattenObservation`` - Flattens the observations +* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation +* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation) +* ``ReshapeObservation`` - Reshapes an array-based observation +* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value +* ``DtypeObservation`` - Convert a observation dtype +""" from __future__ import annotations from typing import Any, Callable, Sequence +from typing_extensions import Final import jumpy as jp import numpy as np -import numpy.typing as npt import gymnasium as gym from gymnasium import spaces from gymnasium.core import ObsType from gymnasium.error import DependencyNotInstalled -from gymnasium.spaces import utils +from gymnasium.spaces import Box, utils class LambdaObservationV0(gym.ObservationWrapper): @@ -71,32 +81,82 @@ class FilterObservationV0(LambdaObservationV0): ({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {}) """ - def __init__(self, env: gym.Env, filter_keys: Sequence[str]): + def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]): """Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys.""" - if not isinstance(env.observation_space, spaces.Dict): + assert isinstance(filter_keys, Sequence) + + # Filters for dictionary space + if isinstance(env.observation_space, spaces.Dict): + assert all(isinstance(key, str) for key in filter_keys) + + if any( + key not in env.observation_space.spaces.keys() for key in filter_keys + ): + missing_keys = [ + key + for key in filter_keys + if key not in env.observation_space.spaces.keys() + ] + raise ValueError( + "All the `filter_keys` must be included in the observation space.\n" + f"Filter keys: {filter_keys}\n" + f"Observation keys: {list(env.observation_space.spaces.keys())}\n" + f"Missing keys: {missing_keys}" + ) + + new_observation_space = spaces.Dict( + {key: env.observation_space[key] for key in filter_keys} + ) + if len(new_observation_space) == 0: + raise ValueError( + "The observation space is empty due to filtering all keys." + ) + + super().__init__( + env, + lambda obs: {key: obs[key] for key in filter_keys}, + new_observation_space, + ) + # Filter for tuple observation + elif isinstance(env.observation_space, spaces.Tuple): + assert all(isinstance(key, int) for key in filter_keys) + assert len(set(filter_keys)) == len( + filter_keys + ), f"Duplicate keys exist, filter_keys: {filter_keys}" + + if any( + 0 < key and key >= len(env.observation_space) for key in filter_keys + ): + missing_index = [ + key + for key in filter_keys + if 0 < key and key >= len(env.observation_space) + ] + raise ValueError( + "All the `filter_keys` must be included in the length of the observation space.\n" + f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, " + f"missing indexes: {missing_index}" + ) + + new_observation_spaces = spaces.Tuple( + env.observation_space[key] for key in filter_keys + ) + if len(new_observation_spaces) == 0: + raise ValueError( + "The observation space is empty due to filtering all keys." + ) + + super().__init__( + env, + lambda obs: tuple(obs[key] for key in filter_keys), + new_observation_spaces, + ) + else: raise ValueError( - f"FilterObservation wrapper is only usable with dict observations, actual type: {type(env.observation_space)}" + f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}" ) - if any(key not in env.observation_space.keys() for key in filter_keys): - missing_keys = [ - key for key in filter_keys if key not in env.observation_space.keys() - ] - raise ValueError( - "All the filter_keys must be included in the original observation space.\n" - f"Filter keys: {filter_keys}\n" - f"Observation keys: {list(env.observation_space.keys())}\n" - f"Missing keys: {missing_keys}" - ) - - new_observation_space = spaces.Dict( - {key: env.observation_space[key] for key in filter_keys} - ) - super().__init__( - env, - lambda obs: {key: obs[key] for key in filter_keys}, - new_observation_space, - ) + self.filter_keys: Final[Sequence[str | int]] = filter_keys class FlattenObservationV0(LambdaObservationV0): @@ -117,9 +177,10 @@ class FlattenObservationV0(LambdaObservationV0): def __init__(self, env: gym.Env): """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.""" - flattened_space = utils.flatten_space(env.observation_space) super().__init__( - env, lambda obs: utils.flatten(flattened_space, obs), flattened_space + env, + lambda obs: utils.flatten(env.observation_space, obs), + utils.flatten_space(env.observation_space), ) @@ -154,7 +215,7 @@ class GrayscaleObservationV0(LambdaObservationV0): and env.observation_space.dtype == np.uint8 ) - self.keep_dim = keep_dim + self.keep_dim: Final[bool] = keep_dim if keep_dim: new_observation_space = spaces.Box( low=0, @@ -167,7 +228,8 @@ class GrayscaleObservationV0(LambdaObservationV0): lambda obs: jp.expand_dims( jp.sum( jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1 - ) + ).astype(np.uint8), + axis=-1, ), new_observation_space, ) @@ -179,7 +241,7 @@ class GrayscaleObservationV0(LambdaObservationV0): env, lambda obs: jp.sum( jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1 - ), + ).astype(np.uint8), new_observation_space, ) @@ -215,7 +277,7 @@ class ResizeObservationV0(LambdaObservationV0): "opencv is not install, run `pip install gymnasium[other]`" ) - self.shape = tuple(shape) + self.shape: Final[tuple[int, ...]] = tuple(shape) new_observation_space = spaces.Box( low=0, high=255, shape=self.shape + env.observation_space.shape[2:] @@ -237,7 +299,7 @@ class ReshapeObservationV0(LambdaObservationV0): assert isinstance(shape, tuple) assert all(np.issubdtype(type(elem), np.integer) for elem in shape) - assert all(x > 0 for x in shape) + assert all(x > 0 or x == -1 for x in shape) new_observation_space = spaces.Box( low=np.reshape(np.ravel(env.observation_space.low), shape), @@ -245,9 +307,8 @@ class ReshapeObservationV0(LambdaObservationV0): shape=shape, dtype=env.observation_space.dtype, ) - super().__init__( - env, lambda obs: jp.reshape(obs, self.shape), new_observation_space - ) + self.shape = shape + super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space) class RescaleObservationV0(LambdaObservationV0): @@ -256,18 +317,23 @@ class RescaleObservationV0(LambdaObservationV0): def __init__( self, env: gym.Env, - min_obs: tuple[np.floating, np.integer, np.ndarray], - max_obs: tuple[np.floating, np.integer, np.ndarray], + min_obs: np.floating | np.integer | np.ndarray, + max_obs: np.floating | np.integer | np.ndarray, ): """Constructor that requires the env observation spaces to be a :class:`Box`.""" assert isinstance(env.observation_space, spaces.Box) + assert not np.any(env.observation_space.low == np.inf) and not np.any( + env.observation_space.high == np.inf + ) if not isinstance(min_obs, np.ndarray): assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype( type(max_obs), np.floating ) min_obs = np.full(env.observation_space.shape, min_obs) - assert min_obs.shape == env.observation_space.shape + assert ( + min_obs.shape == env.observation_space.shape + ), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}" assert not np.any(min_obs == np.inf) if not isinstance(max_obs, np.ndarray): @@ -278,52 +344,66 @@ class RescaleObservationV0(LambdaObservationV0): assert max_obs.shape == env.observation_space.shape assert not np.any(max_obs == np.inf) - env_low = env.observation_space.low - env_high = env.observation_space.high + self.min_obs = min_obs + self.max_obs = max_obs + + # Imagine the x-axis between the old Box and the y-axis being the new Box + gradient = (max_obs - min_obs) / ( + env.observation_space.high - env.observation_space.low + ) + intercept = gradient * -env.observation_space.low + min_obs - new_observation_space = spaces.Box(low=min_obs, high=max_obs) super().__init__( env, - lambda obs: env_low - + (env_high - env_low) * ((obs - min_obs) / (max_obs - min_obs)), - new_observation_space, + lambda obs: gradient * obs + intercept, + Box( + low=min_obs, + high=max_obs, + shape=env.observation_space.shape, + dtype=env.observation_space.dtype, + ), ) class DtypeObservationV0(LambdaObservationV0): """Observation wrapper for transforming the dtype of an observation.""" - def __init__(self, env: gym.Env, dtype: npt.DTypeLike): + def __init__(self, env: gym.Env, dtype: Any): """Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces.""" assert isinstance( env.observation_space, (spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary), ) - dtype = np.dtype(dtype) + self.dtype = dtype if isinstance(env.observation_space, spaces.Box): new_observation_space = spaces.Box( low=env.observation_space.low, high=env.observation_space.high, shape=env.observation_space.shape, - dtype=dtype.__name__, + dtype=self.dtype, ) elif isinstance(env.observation_space, spaces.Discrete): new_observation_space = spaces.Box( low=env.observation_space.start, high=env.observation_space.start + env.observation_space.n, shape=(), - dtype=dtype.__name__, + dtype=self.dtype, ) elif isinstance(env.observation_space, spaces.MultiDiscrete): new_observation_space = spaces.MultiDiscrete( - env.observation_space.nvec, dtype=dtype.__name__ + env.observation_space.nvec, dtype=dtype ) elif isinstance(env.observation_space, spaces.MultiBinary): new_observation_space = spaces.Box( - low=0, high=1, shape=env.observation_space.shape, dtype=dtype.__name__ + low=0, + high=1, + shape=env.observation_space.shape, + dtype=self.dtype, ) else: - raise TypeError + raise TypeError( + "DtypeObservation is only compatible with value / array-based observations." + ) super().__init__(env, lambda obs: dtype(obs), new_observation_space) diff --git a/gymnasium/experimental/wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py index 2a0393812..111f1157c 100644 --- a/gymnasium/experimental/wrappers/lambda_reward.py +++ b/gymnasium/experimental/wrappers/lambda_reward.py @@ -1,12 +1,17 @@ -"""Lambda reward wrappers which apply a function to the reward.""" +"""A collection of wrappers for modifying the reward. -from typing import Any, Callable, Optional, Union +* ``LambdaReward`` - Transforms the reward by a function +* ``ClipReward`` - Clips the reward between a minimum and maximum value +""" + +from __future__ import annotations + +from typing import Callable, SupportsFloat import numpy as np import gymnasium as gym from gymnasium.error import InvalidBound -from gymnasium.experimental.wrappers import ArgType class LambdaRewardV0(gym.RewardWrapper): @@ -26,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper): def __init__( self, env: gym.Env, - func: Callable[[ArgType], Any], + func: Callable[[SupportsFloat], SupportsFloat], ): """Initialize LambdaRewardV0 wrapper. @@ -38,7 +43,7 @@ class LambdaRewardV0(gym.RewardWrapper): self.func = func - def reward(self, reward: Union[float, int, np.ndarray]) -> Any: + def reward(self, reward: SupportsFloat) -> SupportsFloat: """Apply function to reward. Args: @@ -64,8 +69,8 @@ class ClipRewardV0(LambdaRewardV0): def __init__( self, env: gym.Env, - min_reward: Optional[Union[float, np.ndarray]] = None, - max_reward: Optional[Union[float, np.ndarray]] = None, + min_reward: float | np.ndarray | None = None, + max_reward: float | np.ndarray | None = None, ): """Initialize ClipRewardsV0 wrapper. diff --git a/gymnasium/experimental/wrappers/numpy_to_jax.py b/gymnasium/experimental/wrappers/numpy_to_jax.py index f352766be..fbcbd0ebf 100644 --- a/gymnasium/experimental/wrappers/numpy_to_jax.py +++ b/gymnasium/experimental/wrappers/numpy_to_jax.py @@ -6,70 +6,90 @@ import numbers from collections import abc from typing import Any, Iterable, Mapping, SupportsFloat -import jax.numpy as jnp import numpy as np from gymnasium import Env, Wrapper from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType +from gymnasium.error import DependencyNotInstalled + + +try: + import jax.numpy as jnp +except ImportError: + # We handle the error internal to the relative functions + jnp = None @functools.singledispatch def numpy_to_jax(value: Any) -> Any: """Converts a value to a Jax DeviceArray.""" - raise Exception( - f"No conversion for Numpy to Jax registered for type: {type(value)}" - ) + if jnp is None: + raise DependencyNotInstalled( + "Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`" + ) + else: + raise Exception( + f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github." + ) -@numpy_to_jax.register(numbers.Number) -@numpy_to_jax.register(np.ndarray) -def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray: - """Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray.""" - return jnp.array(value) +if jnp is not None: + @numpy_to_jax.register(numbers.Number) + @numpy_to_jax.register(np.ndarray) + def _number_ndarray_numpy_to_jax( + value: np.ndarray | numbers.Number, + ) -> jnp.DeviceArray: + """Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray.""" + assert jnp is not None + return jnp.array(value) -@numpy_to_jax.register(abc.Mapping) -def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays.""" - return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()}) + @numpy_to_jax.register(abc.Mapping) + def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: + """Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays.""" + return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()}) - -@numpy_to_jax.register(abc.Iterable) -def _iterable_numpy_to_jax( - value: Iterable[np.ndarray | Any], -) -> Iterable[jnp.DeviceArray | Any]: - """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays.""" - return type(value)(numpy_to_jax(v) for v in value) + @numpy_to_jax.register(abc.Iterable) + def _iterable_numpy_to_jax( + value: Iterable[np.ndarray | Any], + ) -> Iterable[jnp.DeviceArray | Any]: + """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays.""" + return type(value)(numpy_to_jax(v) for v in value) @functools.singledispatch def jax_to_numpy(value: Any) -> Any: """Converts a value to a numpy array.""" - raise Exception( - f"No conversion for Jax to Numpy registered for type: {type(value)}" - ) + if jnp is None: + raise DependencyNotInstalled( + "Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`" + ) + else: + raise Exception( + f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github." + ) -@jax_to_numpy.register(jnp.DeviceArray) -def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray: - """Converts a Jax DeviceArray to a numpy array.""" - return np.array(value) +if jnp is not None: + @jax_to_numpy.register(jnp.DeviceArray) + def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray: + """Converts a Jax DeviceArray to a numpy array.""" + return np.array(value) -@jax_to_numpy.register(abc.Mapping) -def _mapping_jax_to_numpy( - value: Mapping[str, jnp.DeviceArray | Any] -) -> Mapping[str, np.ndarray | Any]: - """Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays.""" - return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()}) + @jax_to_numpy.register(abc.Mapping) + def _mapping_jax_to_numpy( + value: Mapping[str, jnp.DeviceArray | Any] + ) -> Mapping[str, np.ndarray | Any]: + """Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays.""" + return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()}) - -@jax_to_numpy.register(abc.Iterable) -def _iterable_jax_to_numpy( - value: Iterable[np.ndarray | Any], -) -> Iterable[jnp.DeviceArray | Any]: - """Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays.""" - return type(value)(jax_to_numpy(v) for v in value) + @jax_to_numpy.register(abc.Iterable) + def _iterable_jax_to_numpy( + value: Iterable[np.ndarray | Any], + ) -> Iterable[jnp.DeviceArray | Any]: + """Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays.""" + return type(value)(jax_to_numpy(v) for v in value) class JaxToNumpyV0(Wrapper): @@ -88,6 +108,10 @@ class JaxToNumpyV0(Wrapper): Args: env: the environment to wrap """ + if jnp is None: + raise DependencyNotInstalled( + "Jax is not installed, run `pip install gymnasium[jax]`" + ) super().__init__(env) def step( diff --git a/gymnasium/experimental/wrappers/stateful_action.py b/gymnasium/experimental/wrappers/stateful_action.py new file mode 100644 index 000000000..7527c12a4 --- /dev/null +++ b/gymnasium/experimental/wrappers/stateful_action.py @@ -0,0 +1,56 @@ +"""A collection of stateful action wrappers. + +* StickyAction - There is a probability that the action is taken again +""" +from __future__ import annotations + +from typing import Any, SupportsFloat + +import gymnasium as gym +from gymnasium.core import WrapperActType, WrapperObsType +from gymnasium.error import InvalidProbability + + +class StickyActionV0(gym.Wrapper): + """Wrapper which adds a probability of repeating the previous action. + + This wrapper follows the implementation proposed by `Machado et al., 2018 `_ + in Section 5.2 on page 12. + """ + + def __init__(self, env: gym.Env, repeat_action_probability: float): + """Initialize StickyAction wrapper. + + Args: + env (Env): the wrapped environment + repeat_action_probability (int | float): a probability of repeating the old action. + """ + if not 0 <= repeat_action_probability < 1: + raise InvalidProbability( + f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}" + ) + + super().__init__(env) + self.repeat_action_probability = repeat_action_probability + self.last_action: WrapperActType | None = None + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Reset the environment.""" + self.last_action = None + + return super().reset(seed=seed, options=options) + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Execute the action.""" + if ( + self.last_action is not None + and self.np_random.uniform() < self.repeat_action_probability + ): + action = self.last_action + + self.last_action = action + return action diff --git a/gymnasium/experimental/wrappers/stateful_observation.py b/gymnasium/experimental/wrappers/stateful_observation.py new file mode 100644 index 000000000..ffc96de89 --- /dev/null +++ b/gymnasium/experimental/wrappers/stateful_observation.py @@ -0,0 +1,200 @@ +"""A collection of stateful observation wrappers. + +* DelayObservation - A wrapper for delaying the returned observation +* TimeAwareObservation - A wrapper for adding time aware observations to environment observation +""" +from __future__ import annotations + +from collections import deque +from typing import Any, SupportsFloat +from typing_extensions import Final + +import jumpy as jp +import numpy as np + +import gymnasium as gym +import gymnasium.spaces as spaces +from gymnasium.core import ActType, ObsType, WrapperObsType +from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple + + +class DelayObservationV0(gym.ObservationWrapper): + """Wrapper which adds a delay to the returned observation.""" + + def __init__(self, env: gym.Env, delay: int): + """Initialize the DelayObservation wrapper. + + Args: + env (Env): the wrapped environment + delay (int): number of timesteps for delaying the observation. + Before reaching the `delay` number of timesteps, + returned observation is an array of zeros with the + same shape of the observation space. + """ + assert isinstance(env.observation_space, (Box, MultiBinary, MultiDiscrete)) + assert 0 < delay + + self.delay: Final[int] = delay + self.observation_queue: Final[deque] = deque() + + super().__init__(env) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment, clearing the observation queue.""" + self.observation_queue.clear() + + return super().reset(seed=seed, options=options) + + def observation(self, observation: ObsType) -> ObsType: + """Return the delayed observation.""" + self.observation_queue.append(observation) + + if len(self.observation_queue) > self.delay: + return self.observation_queue.popleft() + + return jp.zeros_like(observation) + + +class TimeAwareObservationV0(gym.ObservationWrapper): + """Augment the observation with time information of the episode. + + Time can be represented as a normalized value between [0,1] + or by the number of timesteps remaining before truncation occurs. + + For environments with ``Dict`` or ``Tuple`` observation spaces, by default, + the time information is automatically added in the key `"time"` and + as the final element in the tuple. + + Example: + >>> import gymnasium as gym + >>> from gymnasium.experimental.wrappers import TimeAwareObservationV0 + >>> env = gym.make('CartPole-v1') + >>> env = TimeAwareObservationV0(env) + >>> env.observation_space + Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32)) + >>> _ = env.reset() + >>> env.step(env.action_space.sample())[0] + OrderedDict([('obs', + ... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)), + ... ('time', array([0.002]))]) + + Flatten observation space example: + >>> env = gym.make('CartPole-v1') + >>> env = TimeAwareObservationV0(env, flatten=True) + >>> env.observation_space + Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32) + >>> _ = env.reset() + >>> env.step(env.action_space.sample())[0] + array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32) + """ + + def __init__( + self, + env: gym.Env, + flatten: bool = False, + normalize_time: bool = True, + *, + dict_time_key: str = "time", + ): + """Initialize :class:`TimeAwareObservationV0`. + + Args: + env: The environment to apply the wrapper + flatten: Flatten the observation to a `Box` of a single dimension + normalize_time: if `True` return time in the range [0,1] + otherwise return time as remaining timesteps before truncation + dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`. + """ + super().__init__(env) + self.flatten: Final[bool] = flatten + self.normalize_time: Final[bool] = normalize_time + + if hasattr(env, "_max_episode_steps"): + self.max_timesteps = getattr(env, "_max_episode_steps") + elif env.spec is not None and env.spec.max_episode_steps is not None: + self.max_timesteps = env.spec.max_episode_steps + else: + raise ValueError( + "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`." + ) + + self.timesteps: int = 0 + + # Find the normalized time space + if self.normalize_time: + self._time_preprocess_func = lambda time: time / self.max_timesteps + time_space = Box(0.0, 1.0) + else: + self._time_preprocess_func = lambda time: self.max_timesteps - time + time_space = Box(0, self.max_timesteps, dtype=np.int32) + + # Find the observation space + if isinstance(env.observation_space, Dict): + assert dict_time_key not in env.observation_space.keys() + observation_space = Dict( + {dict_time_key: time_space}, **env.observation_space.spaces + ) + self._append_data_func = lambda obs, time: {**obs, dict_time_key: time} + elif isinstance(env.observation_space, Tuple): + observation_space = Tuple(env.observation_space.spaces + (time_space,)) + self._append_data_func = lambda obs, time: obs + (time,) + else: + observation_space = Dict(obs=env.observation_space, time=time_space) + self._append_data_func = lambda obs, time: {"obs": obs, "time": time} + + # If to flatten the observation space + if self.flatten: + self.observation_space = spaces.flatten_space(observation_space) + self._obs_postprocess_func = lambda obs: spaces.flatten( + observation_space, obs + ) + else: + self.observation_space = observation_space + self._obs_postprocess_func = lambda obs: obs + + def observation(self, observation: ObsType) -> WrapperObsType: + """Adds to the observation with the current time information. + + Args: + observation: The observation to add the time step to + + Returns: + The observation with the time information appended to + """ + return self._obs_postprocess_func( + self._append_data_func( + observation, self._time_preprocess_func(self.timesteps) + ) + ) + + def step( + self, action: ActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment, incrementing the time step. + + Args: + action: The action to take + + Returns: + The environment's step using the action. + """ + self.timesteps += 1 + return super().step(action) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Reset the environment setting the time to zero. + + Args: + seed: The seed to reset the environment + options: The options used to reset the environment + + Returns: + The reset environment + """ + self.timesteps = 0 + + return super().reset(seed=seed, options=options) diff --git a/gymnasium/experimental/wrappers/sticky_action.py b/gymnasium/experimental/wrappers/sticky_action.py deleted file mode 100644 index 586e2ff1b..000000000 --- a/gymnasium/experimental/wrappers/sticky_action.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Wrapper which adds a probability of repeating the previous executed action.""" -from typing import Union - -import gymnasium as gym -from gymnasium.core import ActType -from gymnasium.error import InvalidProbability - - -class StickyActionV0(gym.ActionWrapper): - """Wrapper which adds a probability of repeating the previous action.""" - - def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]): - """Initialize StickyAction wrapper. - - Args: - env (Env): the wrapped environment - repeat_action_probability (int | float): a proability of repeating the old action. - """ - if not 0 <= repeat_action_probability < 1: - raise InvalidProbability( - f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}" - ) - super().__init__(env) - self.repeat_action_probability = repeat_action_probability - self.old_action = None - - def action(self, action: ActType): - """Execute the action.""" - if ( - self.old_action is not None - and self.np_random.uniform() < self.repeat_action_probability - ): - action = self.old_action - self.old_action = action - return action - - def reset(self, **kwargs): - """Reset the environment.""" - self.old_action = None - return super().reset(**kwargs) diff --git a/gymnasium/experimental/wrappers/time_aware_observation.py b/gymnasium/experimental/wrappers/time_aware_observation.py deleted file mode 100644 index 021d55468..000000000 --- a/gymnasium/experimental/wrappers/time_aware_observation.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Wrapper for adding time aware observations to environment observation.""" -from collections import OrderedDict - -import gymnasium as gym -import gymnasium.spaces as spaces -from gymnasium.core import ActType, ObsType -from gymnasium.spaces import Box, Dict - - -class TimeAwareObservationV0(gym.ObservationWrapper): - """Augment the observation with time information of the episode. - - Time can be represented as a normalized value between [0,1] - or by the number of timesteps remaining before truncation occurs. - - Example: - >>> import gym - >>> from gym.wrappers import TimeAwareObservationV0 - >>> env = gym.make('CartPole-v1') - >>> env = TimeAwareObservationV0(env) - >>> env.observation_space - Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32)) - >>> _ = env.reset() - >>> env.step(env.action_space.sample())[0] - OrderedDict([('obs', - ... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)), - ... ('time', array([0.002]))]) - - Flatten observation space example: - >>> env = gym.make('CartPole-v1') - >>> env = TimeAwareObservationV0(env, flatten=True) - >>> env.observation_space - Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32) - >>> _ = env.reset() - >>> env.step(env.action_space.sample())[0] - array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32) - """ - - def __init__(self, env: gym.Env, flatten=False, normalize_time=True): - """Initialize :class:`TimeAwareObservationV0`. - - Args: - env: The environment to apply the wrapper - flatten: Flatten the observation to a `Box` of a single dimension - normalize_time: if `True` return time in the range [0,1] - otherwise return time as remaining timesteps before truncation - """ - super().__init__(env) - self.flatten = flatten - self.normalize_time = normalize_time - self.max_timesteps = getattr(env, "_max_episode_steps") - - if self.normalize_time: - self._get_time_observation = lambda: self.timesteps / self.max_timesteps - time_space = Box(0, 1) - else: - self._get_time_observation = lambda: self.max_timesteps - self.timesteps - time_space = Box(0, self.max_timesteps) - - self.time_aware_observation_space = Dict( - obs=env.observation_space, time=time_space - ) - - if self.flatten: - self.observation_space = spaces.flatten_space( - self.time_aware_observation_space - ) - self._observation_postprocess = lambda observation: spaces.flatten( - self.time_aware_observation_space, observation - ) - else: - self.observation_space = self.time_aware_observation_space - self._observation_postprocess = lambda observation: observation - - def observation(self, observation: ObsType): - """Adds to the observation with the current time information. - - Args: - observation: The observation to add the time step to - - Returns: - The observation with the time information appended to - """ - time_observation = self._get_time_observation() - observation = OrderedDict(obs=observation, time=time_observation) - - return self._observation_postprocess(observation) - - def step(self, action: ActType): - """Steps through the environment, incrementing the time step. - - Args: - action: The action to take - - Returns: - The environment's step using the action. - """ - self.timesteps += 1 - observation, reward, terminated, truncated, info = super().step(action) - - return observation, reward, terminated, truncated, info - - def reset(self, **kwargs): - """Reset the environment setting the time to zero. - - Args: - **kwargs: Kwargs to apply to env.reset() - - Returns: - The reset environment - """ - self.timesteps = 0 - return super().reset(**kwargs) diff --git a/gymnasium/experimental/wrappers/torch_to_jax.py b/gymnasium/experimental/wrappers/torch_to_jax.py index 2a449cf7d..36686e217 100644 --- a/gymnasium/experimental/wrappers/torch_to_jax.py +++ b/gymnasium/experimental/wrappers/torch_to_jax.py @@ -14,92 +14,121 @@ import numbers from collections import abc from typing import Any, Iterable, Mapping, SupportsFloat, Union -import jax.numpy as jnp -from jax import dlpack as jax_dlpack - from gymnasium import Env, Wrapper from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.error import DependencyNotInstalled from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy +try: + import jax.numpy as jnp + from jax import dlpack as jax_dlpack +except ImportError: + jnp, jax_dlpack = None, None + try: import torch from torch.utils import dlpack as torch_dlpack + + Device = Union[str, torch.device] except ImportError: - raise DependencyNotInstalled("torch is not installed, run `pip install torch`") - - -Device = Union[str, torch.device] + torch, torch_dlpack, Device = None, None, None @functools.singledispatch def torch_to_jax(value: Any) -> Any: """Converts a PyTorch Tensor into a Jax DeviceArray.""" - raise Exception( - f"No conversion for PyTorch to Jax registered for type: {type(value)}" - ) + if torch is None: + raise DependencyNotInstalled( + "Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`" + ) + elif jnp is None: + raise DependencyNotInstalled( + "Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`" + ) + else: + raise Exception( + f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github." + ) -@torch_to_jax.register(numbers.Number) -def _number_torch_to_jax(value: numbers.Number) -> Any: - return jnp.array(value) +if torch is not None and jnp is not None: + @torch_to_jax.register(numbers.Number) + def _number_torch_to_jax(value: numbers.Number) -> Any: + """Convert a python number (int, float, complex) to a jax array.""" + assert jnp is not None + return jnp.array(value) -@torch_to_jax.register(torch.Tensor) -def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray: - """Converts a PyTorch Tensor into a Jax DeviceArray.""" - tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage] - tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage] - return tensor + @torch_to_jax.register(torch.Tensor) + def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray: + """Converts a PyTorch Tensor into a Jax DeviceArray.""" + assert torch_dlpack is not None and jax_dlpack is not None + tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage] + value + ) + tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage] + tensor + ) + return tensor + @torch_to_jax.register(abc.Mapping) + def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: + """Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays.""" + return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) -@torch_to_jax.register(abc.Mapping) -def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays.""" - return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) - - -@torch_to_jax.register(abc.Iterable) -def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]: - """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays.""" - return type(value)(torch_to_jax(v) for v in value) + @torch_to_jax.register(abc.Iterable) + def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]: + """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays.""" + return type(value)(torch_to_jax(v) for v in value) @functools.singledispatch def jax_to_torch(value: Any, device: Device | None = None) -> Any: """Converts a Jax DeviceArray into a PyTorch Tensor.""" - raise Exception( - f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}" - ) + if torch is None: + raise DependencyNotInstalled( + "Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`" + ) + elif jnp is None: + raise DependencyNotInstalled( + "Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`" + ) + else: + raise Exception( + f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github." + ) -@jax_to_torch.register(jnp.DeviceArray) -def _devicearray_jax_to_torch( - value: jnp.DeviceArray, device: Device | None = None -) -> torch.Tensor: - """Converts a Jax DeviceArray into a PyTorch Tensor.""" - dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage] - tensor = torch_dlpack.from_dlpack(dlpack) - if device: - return tensor.to(device=device) - return tensor +if torch is not None and jnp is not None: + @jax_to_torch.register(jnp.DeviceArray) + def _devicearray_jax_to_torch( + value: jnp.DeviceArray, device: Device | None = None + ) -> torch.Tensor: + """Converts a Jax DeviceArray into a PyTorch Tensor.""" + assert jax_dlpack is not None and torch_dlpack is not None + dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage] + value + ) + tensor = torch_dlpack.from_dlpack(dlpack) + if device: + return tensor.to(device=device) + return tensor -@jax_to_torch.register(abc.Mapping) -def _jax_mapping_to_torch( - value: Mapping[str, Any], device: Device | None = None -) -> Mapping[str, Any]: - """Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors.""" - return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()}) + @jax_to_torch.register(abc.Mapping) + def _jax_mapping_to_torch( + value: Mapping[str, Any], device: Device | None = None + ) -> Mapping[str, Any]: + """Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors.""" + return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()}) - -@jax_to_torch.register(abc.Iterable) -def _jax_iterable_to_torch( - value: Iterable[Any], device: Device | None = None -) -> Iterable[Any]: - """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.""" - return type(value)(jax_to_torch(v, device) for v in value) + @jax_to_torch.register(abc.Iterable) + def _jax_iterable_to_torch( + value: Iterable[Any], device: Device | None = None + ) -> Iterable[Any]: + """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.""" + return type(value)(jax_to_torch(v, device) for v in value) class JaxToTorchV0(Wrapper): @@ -107,7 +136,8 @@ class JaxToTorchV0(Wrapper): Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. - For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. + Note: + For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. """ def __init__(self, env: Env, device: Device | None = None): @@ -117,6 +147,15 @@ class JaxToTorchV0(Wrapper): env: The Jax-based environment to wrap device: The device the torch Tensors should be moved to """ + if torch is None: + raise DependencyNotInstalled( + "Torch is not installed, run `pip install torch`" + ) + elif jnp is None: + raise DependencyNotInstalled( + "Jax is not installed, run `pip install gymnasium[jax]`" + ) + super().__init__(env) self.device: Device | None = device diff --git a/gymnasium/spaces/sequence.py b/gymnasium/spaces/sequence.py index dfb74f99f..9c19942f6 100644 --- a/gymnasium/spaces/sequence.py +++ b/gymnasium/spaces/sequence.py @@ -76,6 +76,7 @@ class Sequence(Space[typing.Tuple[Any, ...]]): * ``None`` The length will be randomly drawn from a geometric distribution * ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array. * ``int`` for a fixed length sample + The second element of the mask tuple `sample` mask specifies a mask that is applied when sampling elements from the base space. The mask is applied for each feature space sample. diff --git a/pyproject.toml b/pyproject.toml index ae86c1c19..cb791f12a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "jax-jumpy >=0.2.0", "cloudpickle >=1.2.0", "importlib-metadata >=4.8.0; python_version < '3.10'", - "typing-extensions >=4.3.0; python_version == '3.7'", + "typing-extensions >=4.3.0", "gymnasium-notices >=0.0.1", "shimmy >=0.1.0,<1.0", ] diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 8ce67babb..8957f72ed 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -19,7 +19,7 @@ from gymnasium.wrappers.env_checker import PassiveEnvChecker from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.utils import all_testing_env_specs from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv -from tests.testing_env import GenericTestEnv, old_step_fn +from tests.testing_env import GenericTestEnv, old_step_func from tests.wrappers.utils import has_wrapper @@ -155,7 +155,7 @@ def test_make_disable_env_checker(): def test_apply_api_compatibility(): gym.register( "testing-old-env", - lambda: GenericTestEnv(step_fn=old_step_fn), + lambda: GenericTestEnv(step_func=old_step_func), apply_api_compatibility=True, max_episode_steps=3, ) diff --git a/tests/experimental/wrappers/test_clip_action.py b/tests/experimental/wrappers/test_clip_action.py deleted file mode 100644 index 3c2efcd11..000000000 --- a/tests/experimental/wrappers/test_clip_action.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Test suite for LambdaActionV0.""" -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.experimental.wrappers import ClipActionV0 - - -SEED = 42 - - -@pytest.mark.parametrize( - ("env", "action_unclipped_env", "action_clipped_env"), - ( - [ - # MountainCar action space: Box(-1.0, 1.0, (1,), float32) - gym.make("MountainCarContinuous-v0"), - np.array([1]), - np.array([1.5]), - ], - [ - # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) - gym.make("BipedalWalker-v3"), - np.array([1, 1, 1, 1]), - np.array([10, 10, 10, 10]), - ], - [ - # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) - gym.make("BipedalWalker-v3"), - np.array([0.5, 0.5, 1, 1]), - np.array([0.5, 0.5, 10, 10]), - ], - ), -) -def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env): - """Tests if actions out of bound are correctly clipped. - - Tests whether out of bound actions for the wrapped - environments are correctly clipped. - """ - env.reset(seed=SEED) - obs, _, _, _, _ = env.step(action_unclipped_env) - - env.reset(seed=SEED) - wrapped_env = ClipActionV0(env) - wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env) - - assert np.alltrue(obs == wrapped_obs) diff --git a/tests/experimental/wrappers/test_delay_observation.py b/tests/experimental/wrappers/test_delay_observation.py deleted file mode 100644 index 4f8058612..000000000 --- a/tests/experimental/wrappers/test_delay_observation.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np - -import gymnasium as gym -from gymnasium.experimental.wrappers import DelayObservationV0 - - -SEED = 42 - -DELAY = 3 -NUM_STEPS = 5 - - -def test_delay_observation(): - env = gym.make("CartPole-v1") - env.action_space.seed(SEED) - env.reset(seed=SEED) - - undelayed_observations = [] - for _ in range(NUM_STEPS): - obs, _, _, _, _ = env.step(env.action_space.sample()) - undelayed_observations.append(obs) - - env.action_space.seed(SEED) - env.reset(seed=SEED) - env = DelayObservationV0(env, delay=DELAY) - - delayed_observations = [] - for i in range(NUM_STEPS): - obs, _, _, _, _ = env.step(env.action_space.sample()) - if i < DELAY - 1: - assert np.all(obs == 0) - delayed_observations.append(obs) - - assert np.alltrue( - np.array(delayed_observations[DELAY:]) - == np.array(undelayed_observations[: DELAY - 1]) - ) diff --git a/tests/experimental/wrappers/test_lambda_action.py b/tests/experimental/wrappers/test_lambda_action.py index a3ffa4d57..e70a63de8 100644 --- a/tests/experimental/wrappers/test_lambda_action.py +++ b/tests/experimental/wrappers/test_lambda_action.py @@ -1,60 +1,78 @@ -"""Test suite for LambdaActionV0.""" +"""Test suit for lambda action wrappers: LambdaAction, ClipAction, RescaleAction.""" import numpy as np -import pytest -import gymnasium as gym -from gymnasium.error import InvalidAction -from gymnasium.experimental.wrappers import LambdaActionV0 +from gymnasium.experimental.wrappers import ( + ClipActionV0, + LambdaActionV0, + RescaleActionV0, +) from gymnasium.spaces import Box from tests.testing_env import GenericTestEnv -NUM_ENVS = 3 -BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64) +SEED = 42 -def generic_step_fn(self, action): +def _record_action_step_func(self, action): return 0, 0, False, False, {"action": action} -@pytest.mark.parametrize( - ("env", "func", "action", "expected"), - [ - ( - GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn), - lambda action: action + 2, - 1, - 3, - ), - ], -) -def test_lambda_action_v0(env, func, action, expected): - """Tests lambda action. - Tests if function is correctly applied to environment's action. - """ - wrapped_env = LambdaActionV0(env, func) - _, _, _, _, info = wrapped_env.step(action) - executed_action = info["action"] +def test_lambda_action_wrapper(): + """Tests LambdaAction through checking that the action taken is transformed by function.""" + env = GenericTestEnv(step_func=_record_action_step_func) + wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3)) - assert executed_action == expected + sampled_action = wrapped_env.action_space.sample() + assert sampled_action not in env.action_space + + _, _, _, _, info = wrapped_env.step(sampled_action) + assert info["action"] in env.action_space + assert sampled_action - 2 == info["action"] -def test_lambda_action_v0_within_vector(): - """Tests lambda action in vectorized environments. - Tests if function is correctly applied to environment's action - in vectorized environment. - """ - env = gym.vector.make( - "CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False +def test_clip_action_wrapper(): + """Test that the action is correctly clipped to the base environment action space.""" + env = GenericTestEnv( + action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])), + step_func=_record_action_step_func, ) - action = np.ones(NUM_ENVS, dtype=np.float64) + wrapped_env = ClipActionV0(env) - wrapped_env = LambdaActionV0(env, lambda action: action.astype(int)) - wrapped_env.reset() + sampled_action = np.array([-1, 5, 3.5], dtype=np.float32) + assert sampled_action not in env.action_space + assert sampled_action in wrapped_env.action_space - wrapped_env.step(action) + _, _, _, _, info = wrapped_env.step(sampled_action) + assert np.all(info["action"] in env.action_space) + assert np.all(info["action"] == np.array([0, 2, 3.5])) - # unwrapped env should raise exception because it does not - # support float actions - with pytest.raises(InvalidAction): - env.step(action) + +def test_rescale_action_wrapper(): + """Test that the action is rescale within a min / max bound.""" + env = GenericTestEnv( + step_func=_record_action_step_func, + action_space=Box(np.array([0, 1]), np.array([1, 3])), + ) + wrapped_env = RescaleActionV0( + env, min_action=np.array([-5, 0]), max_action=np.array([5, 1]) + ) + assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1])) + + for sample_action, expected_action in ( + ( + np.array([0.0, 0.5], dtype=np.float32), + np.array([0.5, 2.0], dtype=np.float32), + ), + ( + np.array([-5.0, 0.0], dtype=np.float32), + np.array([0.0, 1.0], dtype=np.float32), + ), + ( + np.array([5.0, 1.0], dtype=np.float32), + np.array([1.0, 3.0], dtype=np.float32), + ), + ): + assert sample_action in wrapped_env.action_space + + _, _, _, _, info = wrapped_env.step(sample_action) + assert np.all(info["action"] == expected_action) diff --git a/tests/experimental/wrappers/test_lambda_observation.py b/tests/experimental/wrappers/test_lambda_observation.py index ac861f801..430748935 100644 --- a/tests/experimental/wrappers/test_lambda_observation.py +++ b/tests/experimental/wrappers/test_lambda_observation.py @@ -1,59 +1,250 @@ -"""Test suite for LambdaObservationV0.""" +"""Test suite for lambda observation wrappers: """ import numpy as np import gymnasium as gym -from gymnasium.experimental.wrappers import LambdaObservationV0 -from gymnasium.spaces import Box +from gymnasium.experimental.wrappers import ( + DtypeObservationV0, + FilterObservationV0, + FlattenObservationV0, + GrayscaleObservationV0, + LambdaObservationV0, + RescaleObservationV0, + ReshapeObservationV0, + ResizeObservationV0, +) +from gymnasium.spaces import Box, Dict, Tuple +from tests.testing_env import GenericTestEnv -NUM_ENVS = 3 -BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64) - SEED = 42 -DISCRETE_ACTION = 1 -def test_lambda_observation_v0(): - """Tests lambda observation. +def _record_random_obs_reset(self: gym.Env, seed=None, options=None): + obs = self.observation_space.sample() + return obs, {"obs": obs} - Tests if function is correctly applied to environment's observation. - """ - env = gym.make("CartPole-v1") - env.reset(seed=SEED) - obs, _, _, _, _ = env.step(DISCRETE_ACTION) - observation_shift = 1 +def _record_random_obs_step(self: gym.Env, action): + obs = self.observation_space.sample() + return obs, 0, False, False, {"obs": obs} - env.reset(seed=SEED) - wrapped_env = LambdaObservationV0( - env, lambda observation: observation + observation_shift, None + +def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}): + return options["obs"], {"obs": options["obs"]} + + +def _record_action_obs_step(self: gym.Env, action): + return action, 0, False, False, {"obs": action} + + +def _check_obs( + env: gym.Env, + wrapped_env: gym.Wrapper, + transformed_obs, + original_obs, + strict: bool = True, +): + assert ( + transformed_obs in wrapped_env.observation_space + ), f"{transformed_obs}, {wrapped_env.observation_space}" + assert ( + original_obs in env.observation_space + ), f"{original_obs}, {env.observation_space}" + + if strict: + assert ( + transformed_obs not in env.observation_space + ), f"{transformed_obs}, {env.observation_space}" + assert ( + original_obs not in wrapped_env.observation_space + ), f"{original_obs}, {wrapped_env.observation_space}" + + +def test_lambda_observation_wrapper(): + """Tests lambda observation that the function is applied to both the reset and step observation.""" + env = GenericTestEnv( + reset_func=_record_action_obs_reset, step_func=_record_action_obs_step ) - wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION) + wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3)) - assert np.alltrue(wrapped_obs == obs + observation_shift) + obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)}) + _check_obs(env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32)) + _check_obs(env, wrapped_env, obs, info["obs"]) -def test_lambda_observation_v0_within_vector(): - """Tests lambda observation in vectorized environments. - - Tests if function is correctly applied to environment's observation - in vectorized environment. - """ - env = gym.vector.make( - "CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False - ) - env.reset(seed=SEED) - obs, _, _, _, _ = env.step(np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)])) - - observation_shift = 1 - - env.reset(seed=SEED) - wrapped_env = LambdaObservationV0( - env, lambda observation: observation + observation_shift, None - ) - wrapped_obs, _, _, _, _ = wrapped_env.step( - np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)]) +def test_filter_observation_wrapper(): + """Tests ``FilterObservation`` that the right keys are filtered.""" + dict_env = GenericTestEnv( + observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)), + reset_func=_record_random_obs_reset, + step_func=_record_random_obs_step, ) - assert np.alltrue(wrapped_obs == obs + observation_shift) + wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3")) + obs, info = wrapped_env.reset() + assert list(obs.keys()) == ["arm_1", "arm_3"] + assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] + _check_obs(dict_env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + assert list(obs.keys()) == ["arm_1", "arm_3"] + assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] + _check_obs(dict_env, wrapped_env, obs, info["obs"]) + + # Test tuple environments + tuple_env = GenericTestEnv( + observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))), + reset_func=_record_random_obs_reset, + step_func=_record_random_obs_step, + ) + wrapped_env = FilterObservationV0(tuple_env, (2,)) + + obs, info = wrapped_env.reset() + assert len(obs) == 1 and len(info["obs"]) == 3 + _check_obs(tuple_env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + assert len(obs) == 1 and len(info["obs"]) == 3 + _check_obs(tuple_env, wrapped_env, obs, info["obs"]) + + +def test_flatten_observation_wrapper(): + """Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly.""" + env = GenericTestEnv( + observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)), + reset_func=_record_random_obs_reset, + step_func=_record_random_obs_step, + ) + print(env.observation_space) + wrapped_env = FlattenObservationV0(env) + print(wrapped_env.observation_space) + + obs, info = wrapped_env.reset() + _check_obs(env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + _check_obs(env, wrapped_env, obs, info["obs"]) + + +def test_grayscale_observation_wrapper(): + """Tests the ``GrayscaleObservation`` that the observation is grayscale.""" + env = GenericTestEnv( + observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8), + reset_func=_record_random_obs_reset, + step_func=_record_random_obs_step, + ) + wrapped_env = GrayscaleObservationV0(env) + + obs, info = wrapped_env.reset() + _check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (25, 25) + + obs, _, _, _, info = wrapped_env.step(None) + _check_obs(env, wrapped_env, obs, info["obs"]) + + # Keep_dim + wrapped_env = GrayscaleObservationV0(env, keep_dim=True) + + obs, info = wrapped_env.reset() + _check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (25, 25, 1) + + obs, _, _, _, info = wrapped_env.step(None) + _check_obs(env, wrapped_env, obs, info["obs"]) + + +def test_resize_observation_wrapper(): + """Test the ``ResizeObservation`` that the observation has changed size""" + env = GenericTestEnv( + observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8), + reset_func=_record_random_obs_reset, + step_func=_record_random_obs_step, + ) + wrapped_env = ResizeObservationV0(env, (25, 25)) + + obs, info = wrapped_env.reset() + _check_obs(env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + _check_obs(env, wrapped_env, obs, info["obs"]) + + +def test_reshape_observation_wrapper(): + """Test the ``ReshapeObservation`` wrapper.""" + env = GenericTestEnv( + observation_space=Box(0, 1, shape=(2, 3, 2)), + reset_func=_record_random_obs_reset, + step_func=_record_random_obs_step, + ) + wrapped_env = ReshapeObservationV0(env, (6, 2)) + + obs, info = wrapped_env.reset() + _check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (6, 2) + + obs, _, _, _, info = wrapped_env.step(None) + _check_obs(env, wrapped_env, obs, info["obs"]) + assert obs.shape == (6, 2) + + +def test_rescale_observation(): + """Test the ``RescaleObservation`` wrapper""" + env = GenericTestEnv( + observation_space=Box( + np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32) + ), + reset_func=_record_action_obs_reset, + step_func=_record_action_obs_step, + ) + wrapped_env = RescaleObservationV0( + env, + min_obs=np.array([-5, 0], dtype=np.float32), + max_obs=np.array([5, 1], dtype=np.float32), + ) + assert wrapped_env.observation_space == Box( + np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32) + ) + + for sample_obs, expected_obs in ( + ( + np.array([0.5, 2.0], dtype=np.float32), + np.array([0.0, 0.5], dtype=np.float32), + ), + ( + np.array([0.0, 1.0], dtype=np.float32), + np.array([-5.0, 0.0], dtype=np.float32), + ), + ( + np.array([1.0, 3.0], dtype=np.float32), + np.array([5.0, 1.0], dtype=np.float32), + ), + ): + assert sample_obs in env.observation_space + assert expected_obs in wrapped_env.observation_space + + obs, info = wrapped_env.reset(options={"obs": sample_obs}) + assert np.all(obs == expected_obs) + _check_obs(env, wrapped_env, obs, info["obs"], strict=False) + + obs, _, _, _, info = wrapped_env.step(sample_obs) + assert np.all(obs == expected_obs) + _check_obs(env, wrapped_env, obs, info["obs"], strict=False) + + +def test_dtype_observation(): + """Test ``DtypeObservation`` that the""" + env = GenericTestEnv( + reset_func=_record_random_obs_reset, step_func=_record_random_obs_step + ) + wrapped_env = DtypeObservationV0(env, dtype=np.uint8) + + obs, info = wrapped_env.reset() + assert obs.dtype != info["obs"].dtype + assert obs.dtype == np.uint8 + + obs, _, _, _, info = wrapped_env.step(None) + assert obs.dtype != info["obs"].dtype + assert obs.dtype == np.uint8 diff --git a/tests/experimental/wrappers/test_numpy_to_jax.py b/tests/experimental/wrappers/test_numpy_to_jax.py index f687b2177..d5abaa116 100644 --- a/tests/experimental/wrappers/test_numpy_to_jax.py +++ b/tests/experimental/wrappers/test_numpy_to_jax.py @@ -55,7 +55,7 @@ def jax_step_func(self, action): def test_jax_to_numpy(): - jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func) + jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func) # Check that the reset and step for jax environment are as expected obs, info = jax_env.reset() diff --git a/tests/experimental/wrappers/test_rescale_action.py b/tests/experimental/wrappers/test_rescale_action.py deleted file mode 100644 index d689340e8..000000000 --- a/tests/experimental/wrappers/test_rescale_action.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Test suite for RescaleActionV0.""" -import jax -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.experimental.wrappers import RescaleActionV0 - - -SEED = 42 - - -@pytest.mark.parametrize( - ("env", "low", "high", "action", "scaled_action"), - [ - ( - # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) - gym.make("BipedalWalker-v3"), - -0.5, - 0.5, - np.array([1, 1, 1, 1]), - np.array([0.5, 0.5, 0.5, 0.5]), - ), - ( - # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) - gym.make("BipedalWalker-v3"), - -0.5, - 0.5, - jax.numpy.array([1, 1, 1, 1]), - jax.numpy.array([0.5, 0.5, 0.5, 0.5]), - ), - ( - # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) - gym.make("BipedalWalker-v3"), - np.array([-0.5, -0.5, -1, -1], dtype=np.float32), - np.array([0.5, 0.5, 1, 1], dtype=np.float32), - jax.numpy.array([1, 1, 1, 1]), - jax.numpy.array([0.5, 0.5, 1, 1]), - ), - ], -) -def test_rescale_actions_v0_box(env, low, high, action, scaled_action): - """Test action rescaling.""" - env.reset(seed=SEED) - obs, _, _, _, _ = env.step(action) - - env.reset(seed=SEED) - wrapped_env = RescaleActionV0(env, low, high) - - obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action) - - assert np.alltrue(obs == obs_scaled) diff --git a/tests/experimental/wrappers/test_sticky_action.py b/tests/experimental/wrappers/test_stateful_action.py similarity index 84% rename from tests/experimental/wrappers/test_sticky_action.py rename to tests/experimental/wrappers/test_stateful_action.py index d04667296..3bc1254e8 100644 --- a/tests/experimental/wrappers/test_sticky_action.py +++ b/tests/experimental/wrappers/test_stateful_action.py @@ -17,7 +17,9 @@ def step_fn(self, action): def test_sticky_action(): - env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5) + env = StickyActionV0( + GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5 + ) env.reset(seed=SEED) env.action_space.seed(SEED) @@ -34,7 +36,7 @@ def test_sticky_action(): previous_action = input_action -@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5]) +@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5]) def test_sticky_action_raise(repeat_action_probability): with pytest.raises(InvalidProbability): StickyActionV0( diff --git a/tests/experimental/wrappers/test_stateful_observation.py b/tests/experimental/wrappers/test_stateful_observation.py new file mode 100644 index 000000000..fedbf19c6 --- /dev/null +++ b/tests/experimental/wrappers/test_stateful_observation.py @@ -0,0 +1,89 @@ +"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation.""" + +import numpy as np + +import gymnasium as gym +from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0 +from gymnasium.spaces import Box, Dict, Tuple +from tests.testing_env import GenericTestEnv + + +NUM_STEPS = 20 +SEED = 0 + +DELAY = 3 + + +def test_time_aware_observation_wrapper(): + """Tests the time aware observation wrapper.""" + # Test the environment observation space with Dict, Tuple and other + env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3))) + wrapped_env = TimeAwareObservationV0(env) + assert isinstance(wrapped_env.observation_space, Dict) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}" + + env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3)))) + wrapped_env = TimeAwareObservationV0(env) + assert isinstance(wrapped_env.observation_space, Tuple) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert len(reset_obs) == 3 and len(step_obs) == 3 + + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservationV0(env) + assert isinstance(wrapped_env.observation_space, Dict) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert isinstance(reset_obs, dict) and isinstance(step_obs, dict) + assert "obs" in reset_obs and "obs" in step_obs + assert "time" in reset_obs and "time" in step_obs + + # Tests the flatten parameter + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservationV0(env, flatten=True) + assert isinstance(wrapped_env.observation_space, Box) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert reset_obs.shape == (2,) and step_obs.shape == (2,) + + # Tests the normalize_time parameter + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservationV0(env, normalize_time=False) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert reset_obs["time"] == 100 and step_obs["time"] == 99 + + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservationV0(env, normalize_time=True) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01 + + +def test_delay_observation_wrapper(): + env = gym.make("CartPole-v1") + env.action_space.seed(SEED) + env.reset(seed=SEED) + + undelayed_observations = [] + for _ in range(NUM_STEPS): + obs, _, _, _, _ = env.step(env.action_space.sample()) + undelayed_observations.append(obs) + + env = DelayObservationV0(env, delay=DELAY) + env.action_space.seed(SEED) + env.reset(seed=SEED) + + delayed_observations = [] + for i in range(NUM_STEPS): + obs, _, _, _, _ = env.step(env.action_space.sample()) + delayed_observations.append(obs) + if i < DELAY - 1: + assert np.all(obs == 0) + + undelayed_observations = np.array(undelayed_observations) + delayed_observations = np.array(delayed_observations) + + assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY]) diff --git a/tests/experimental/wrappers/test_time_aware_observation.py b/tests/experimental/wrappers/test_time_aware_observation.py deleted file mode 100644 index d63210eea..000000000 --- a/tests/experimental/wrappers/test_time_aware_observation.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Test suite for TimeAwareobservationV0.""" - -from collections import OrderedDict - -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.experimental.wrappers import TimeAwareObservationV0 -from gymnasium.spaces import Box, Dict - - -NUM_STEPS = 20 -SEED = 0 - - -@pytest.mark.parametrize( - "env", - [ - gym.make("CartPole-v1", disable_env_checker=True), - gym.make("CarRacing-v2", disable_env_checker=True), - ], -) -def test_time_aware_observation_creation(env): - """Test TimeAwareObservationV0 wrapper creation. - - This test checks if wrapped env with TimeAwareObservationV0 - is correctly created. - """ - wrapped_env = TimeAwareObservationV0(env) - obs, _ = wrapped_env.reset() - - assert isinstance(wrapped_env.observation_space, Dict) - assert isinstance(obs, OrderedDict) - assert np.all(obs["time"] == 0) - assert env.observation_space == wrapped_env.observation_space["obs"] - - -@pytest.mark.parametrize("normalize_time", [True, False]) -@pytest.mark.parametrize("flatten", [False, True]) -@pytest.mark.parametrize( - "env", - [ - gym.make("CartPole-v1", disable_env_checker=True), - gym.make("CarRacing-v2", disable_env_checker=True, continuous=False), - ], -) -def test_time_aware_observation_step(env, flatten, normalize_time): - """Test TimeAwareObservationV0 step. - - This test checks if wrapped env with TimeAwareObservationV0 - steps correctly. - """ - env.action_space.seed(SEED) - max_timesteps = env._max_episode_steps - - wrapped_env = TimeAwareObservationV0( - env, flatten=flatten, normalize_time=normalize_time - ) - wrapped_env.reset(seed=SEED) - - for timestep in range(1, NUM_STEPS): - action = env.action_space.sample() - observation, _, terminated, _, _ = wrapped_env.step(action) - - expected_time_obs = ( - timestep / max_timesteps if normalize_time else max_timesteps - timestep - ) - - if flatten: - assert np.allclose(observation[-1], expected_time_obs) - else: - assert np.allclose(observation["time"], expected_time_obs) - - if terminated: - break - - -@pytest.mark.parametrize( - "env", - [ - gym.make("CartPole-v1", disable_env_checker=True), - gym.make("CarRacing-v2", disable_env_checker=True), - ], -) -def test_time_aware_observation_creation_flatten(env): - """Test TimeAwareObservationV0 wrapper creation with `flatten=True`. - - This test checks if wrapped env with TimeAwareObservationV0 - is correctly created when the `flatten` parameter is set to `True`. - When flattened, the observation space should be a 1 dimension `Box` - with time appended to the end. - """ - wrapped_env = TimeAwareObservationV0(env, flatten=True) - obs, _ = wrapped_env.reset() - - assert isinstance(wrapped_env.observation_space, Box) - assert isinstance(obs, np.ndarray) - assert env.observation_space == wrapped_env.time_aware_observation_space["obs"] diff --git a/tests/experimental/wrappers/test_torch_to_jax.py b/tests/experimental/wrappers/test_torch_to_jax.py index cb6cdf030..c2e524dc5 100644 --- a/tests/experimental/wrappers/test_torch_to_jax.py +++ b/tests/experimental/wrappers/test_torch_to_jax.py @@ -9,6 +9,7 @@ from tests.testing_env import GenericTestEnv def torch_data_equivalence(data_1, data_2) -> bool: + """Return if two variables are equivalent that might contain ``torch.Tensor``.""" if type(data_1) == type(data_2): if isinstance(data_1, dict): return data_1.keys() == data_2.keys() and all( @@ -56,14 +57,15 @@ def torch_data_equivalence(data_1, data_2) -> bool: ) def test_roundtripping(value, expected_value): """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.""" - assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value) + roundtripped_value = jax_to_torch(torch_to_jax(value)) + assert torch_data_equivalence(roundtripped_value, expected_value) -def jax_reset_func(self, seed=None, options=None): +def _jax_reset_func(self, seed=None, options=None): return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])} -def jax_step_func(self, action): +def _jax_step_func(self, action): assert isinstance(action, jnp.DeviceArray), type(action) return ( jnp.array([1, 2, 3]), @@ -75,7 +77,7 @@ def jax_step_func(self, action): def test_jax_to_torch(): - env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func) + env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func) # Check that the reset and step for jax environment are as expected obs, info = env.reset() diff --git a/tests/test_core.py b/tests/test_core.py index 92ad59df8..432fde48f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -278,7 +278,7 @@ def test_wrapper_types(): obs, _, _, _, _ = observation_env.step(0) assert obs == np.array([1]) - env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {})) + env = GenericTestEnv(step_func=lambda self, action: (action, 0, False, False, {})) action_env = ExampleActionWrapper(env) obs, _, _, _, _ = action_env.step(0) assert obs == np.array([1]) diff --git a/tests/testing_env.py b/tests/testing_env.py index 8e7458168..dfebba114 100644 --- a/tests/testing_env.py +++ b/tests/testing_env.py @@ -8,7 +8,7 @@ from gymnasium.core import ActType, ObsType from gymnasium.envs.registration import EnvSpec -def basic_reset_fn( +def basic_reset_func( self, *, seed: Optional[int] = None, @@ -20,17 +20,17 @@ def basic_reset_fn( return self.observation_space.sample(), {"options": options} -def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: +def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """A step function that follows the new step api that will pass the environment check using random actions from the observation space.""" return self.observation_space.sample(), 0, False, False, {} -def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: +def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: """A step function that follows the old step api that will pass the environment check using random actions from the observation space.""" return self.observation_space.sample(), 0, False, {} -def basic_render_fn(self): +def basic_render_func(self): """Basic render fn that does nothing.""" pass @@ -43,12 +43,14 @@ class GenericTestEnv(gym.Env): self, action_space: spaces.Space = spaces.Box(0, 1, (1,)), observation_space: spaces.Space = spaces.Box(0, 1, (1,)), - reset_fn: callable = basic_reset_fn, - step_fn: callable = new_step_fn, - render_fn: callable = basic_render_fn, + reset_func: callable = basic_reset_func, + step_func: callable = new_step_func, + render_func: callable = basic_render_func, metadata: Dict[str, Any] = {"render_modes": []}, render_mode: Optional[str] = None, - spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"), + spec: EnvSpec = EnvSpec( + "TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100 + ), ): self.metadata = metadata self.render_mode = render_mode @@ -59,12 +61,12 @@ class GenericTestEnv(gym.Env): if action_space is not None: self.action_space = action_space - if reset_fn is not None: - self.reset = types.MethodType(reset_fn, self) - if step_fn is not None: - self.step = types.MethodType(step_fn, self) - if render_fn is not None: - self.render = types.MethodType(render_fn, self) + if reset_func is not None: + self.reset = types.MethodType(reset_func, self) + if step_func is not None: + self.step = types.MethodType(step_func, self) + if render_func is not None: + self.render = types.MethodType(render_func, self) def reset( self, diff --git a/tests/utils/test_env_checker.py b/tests/utils/test_env_checker.py index eaed92273..498ede073 100644 --- a/tests/utils/test_env_checker.py +++ b/tests/utils/test_env_checker.py @@ -112,10 +112,10 @@ def test_check_reset_seed(test, func: callable, message: str): with pytest.warns( UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$" ): - check_reset_seed(GenericTestEnv(reset_fn=func)) + check_reset_seed(GenericTestEnv(reset_func=func)) else: with pytest.raises(test, match=f"^{re.escape(message)}$"): - check_reset_seed(GenericTestEnv(reset_fn=func)) + check_reset_seed(GenericTestEnv(reset_func=func)) def _deprecated_return_info( @@ -179,7 +179,7 @@ def test_check_reset_return_type(test, func: callable, message: str): """Tests the check `env.reset()` function has a correct return type.""" with pytest.raises(test, match=f"^{re.escape(message)}$"): - check_reset_return_type(GenericTestEnv(reset_fn=func)) + check_reset_return_type(GenericTestEnv(reset_func=func)) @pytest.mark.parametrize( @@ -198,7 +198,7 @@ def test_check_reset_return_info_deprecation(test, func: callable, message: str) """Tests that return_info has been correct deprecated as an argument to `env.reset()`.""" with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"): - check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func)) + check_reset_return_info_deprecation(GenericTestEnv(reset_func=func)) def test_check_seed_deprecation(): @@ -236,7 +236,7 @@ def test_check_reset_options(): "The `reset` method does not provide an `options` or `**kwargs` keyword argument" ), ): - check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {}))) + check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {}))) @pytest.mark.parametrize( diff --git a/tests/utils/test_passive_env_checker.py b/tests/utils/test_passive_env_checker.py index 5c60ddcce..96710ca40 100644 --- a/tests/utils/test_passive_env_checker.py +++ b/tests/utils/test_passive_env_checker.py @@ -303,11 +303,11 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D with pytest.warns( UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$" ): - env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs) + env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs) else: with warnings.catch_warnings(record=True) as caught_warnings: with pytest.raises(test, match=f"^{re.escape(message)}$"): - env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs) + env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs) assert len(caught_warnings) == 0 @@ -383,11 +383,11 @@ def test_passive_env_step_checker( with pytest.warns( UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$" ): - env_step_passive_checker(GenericTestEnv(step_fn=func), 0) + env_step_passive_checker(GenericTestEnv(step_func=func), 0) else: with warnings.catch_warnings(record=True) as caught_warnings: with pytest.raises(test, match=f"^{re.escape(message)}$"): - env_step_passive_checker(GenericTestEnv(step_fn=func), 0) + env_step_passive_checker(GenericTestEnv(step_func=func), 0) assert len(caught_warnings) == 0, caught_warnings @@ -416,7 +416,7 @@ def test_passive_env_step_checker( GenericTestEnv( metadata={"render_modes": ["Testing mode"], "render_fps": None}, render_mode="Testing mode", - render_fn=lambda self: 0, + render_func=lambda self: 0, ), "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.", ], diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py index b8e5bef4d..39a10c249 100644 --- a/tests/utils/test_play.py +++ b/tests/utils/test_play.py @@ -21,7 +21,7 @@ IRRELEVANT_KEY = 1 PlayableEnv = partial( GenericTestEnv, metadata={"render_modes": ["rgb_array"]}, - render_fn=lambda self: np.ones((10, 10, 3)), + render_func=lambda self: np.ones((10, 10, 3)), ) diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index 801ccb98e..16ad5eba6 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -82,8 +82,8 @@ def test_final_obs_info(vectoriser): return GenericTestEnv( action_space=Discrete(4), observation_space=Discrete(4), - reset_fn=reset_fn, - step_fn=lambda self, action: ( + reset_func=reset_fn, + step_func=lambda self, action: ( action if action < 3 else 0, 0, action >= 3, diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index b451f528f..3c8121588 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -3,7 +3,7 @@ import pytest from gymnasium.spaces import Box, Discrete from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility -from tests.testing_env import GenericTestEnv, old_step_fn +from tests.testing_env import GenericTestEnv, old_step_func class AleTesting: @@ -34,7 +34,7 @@ class AtariTestingEnv(GenericTestEnv): low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1 ), action_space=Discrete(3, seed=1), - step_fn=old_step_fn, + step_func=old_step_func, ) self.ale = AleTesting() diff --git a/tests/wrappers/test_passive_env_checker.py b/tests/wrappers/test_passive_env_checker.py index fbc5c1cec..f62c09d63 100644 --- a/tests/wrappers/test_passive_env_checker.py +++ b/tests/wrappers/test_passive_env_checker.py @@ -68,8 +68,8 @@ def _step_failure(self, action): def test_api_failures(): env = GenericTestEnv( - reset_fn=_reset_failure, - step_fn=_step_failure, + reset_func=_reset_failure, + step_func=_step_failure, metadata={"render_modes": "error"}, ) env = PassiveEnvChecker(env)