Update experimental wrappers (#176)

This commit is contained in:
Mark Towers
2022-12-05 19:14:56 +00:00
committed by GitHub
parent 1a381bcd0d
commit 848b7097bf
39 changed files with 1140 additions and 806 deletions

View File

@@ -22,9 +22,9 @@ repos:
hooks: hooks:
- id: codespell - id: codespell
args: args:
- --ignore-words-list=nd,reacher,thist,ths, ure, referenc,wile - --ignore-words-list=nd,reacher,thist,ths,ure,referenc,wile
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 6.0.0 rev: 5.0.4
hooks: hooks:
- id: flake8 - id: flake8
args: args:

View File

@@ -27,8 +27,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorised versions of the wrappers will be provided. * In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorised versions of the wrappers will be provided.
### Lambda Observation Wrappers ### Observation Wrappers
```{eval-rst} ```{eval-rst}
.. py:currentmodule:: gymnasium .. py:currentmodule:: gymnasium
@@ -44,61 +43,60 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- VectorLambdaObservation - VectorLambdaObservation
- No - No
* - :class:`wrappers.FilterObservation` * - :class:`wrappers.FilterObservation`
- :class:`experimental.wrappers.FilterObservation` - :class:`experimental.wrappers.FilterObservationV0`
- VectorFilterObservation (*) - VectorFilterObservation (*)
- Yes - Yes
* - :class:`wrappers.FlattenObservation` * - :class:`wrappers.FlattenObservation`
- `:class:`experimental.wrappers.FlattenObservation` - :class:`experimental.wrappers.FlattenObservationV0`
- VectorFlattenObservation (*) - VectorFlattenObservation (*)
- No - No
* - :class:`wrappers.GrayScaleObservation` * - :class:`wrappers.GrayScaleObservation`
- `:class:`experimental.wrappers.GrayscaleObservation` - :class:`experimental.wrappers.GrayscaleObservationV0`
- VectorGrayscaleObservation (*) - VectorGrayscaleObservation (*)
- Yes - Yes
* - :class:`wrappers.ResizeObservation` * - :class:`wrappers.ResizeObservation`
- :class:`experimental.wrappers.ResizeObservation` - :class:`experimental.wrappers.ResizeObservationV0`
- VectorResizeObservation (*) - VectorResizeObservation (*)
- Yes - Yes
* - Not Implemented * - Not Implemented
- :class:`experimental.wrappers.ReshapeObservation` - :class:`experimental.wrappers.ReshapeObservationV0`
- VectorReshapeObservation (*) - VectorReshapeObservation (*)
- Yes - Yes
* - Not Implemented * - Not Implemented
- :class:`experimental.wrappers.RescaleObservation` - :class:`experimental.wrappers.RescaleObservationV0`
- VectorRescaleObservation (*) - VectorRescaleObservation (*)
- Yes - Yes
* - Not Implemented * - Not Implemented
- :class:`experimental.wrappers.DtypeObservation` - :class:`experimental.wrappers.DtypeObservationV0`
- VectorDtypeObservation (*) - VectorDtypeObservation (*)
- Yes - Yes
* - :class:`wrappers.PixelObservationWrapper` * - :class:`wrappers.PixelObservationWrapper`
- PixelObservation - PixelObservation
- VectorPixelObservation - VectorPixelObservation
- No - No
* - :class:`NormalizeObservation` * - :class:`wrappers.NormalizeObservation`
- NormalizeObservation - NormalizeObservation
- VectorNormalizeObservation - VectorNormalizeObservation
- No - No
* - :class:`TimeAwareObservation` * - :class:`wrappers.TimeAwareObservation`
- TimeAwareObservation - :class:`experimental.wrappers.TimeAwareObservationV0`
- VectorTimeAwareObservation - VectorTimeAwareObservation
- No - No
* - :class:`FrameStack` * - :class:`wrappers.FrameStack`
- FrameStackObservation - FrameStackObservation
- VectorFrameStackObservation - VectorFrameStackObservation
- No - No
* - Not Implemented * - Not Implemented
- DelayObservation - :class:`experimental.wrappers.DelayObservationV0`
- VectorDelayObservation - VectorDelayObservation
- No - No
* - :class:`AtariPreprocessing` * - :class:`wrappers.AtariPreprocessing`
- AtariPreprocessing - AtariPreprocessing
- Not Implemented - Not Implemented
- No - No
``` ```
### Lambda Action Wrappers ### Action Wrappers
```{eval-rst} ```{eval-rst}
.. py:currentmodule:: gymnasium .. py:currentmodule:: gymnasium
@@ -114,25 +112,20 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- VectorLambdaAction - VectorLambdaAction
- No - No
* - :class:`wrappers.ClipAction` * - :class:`wrappers.ClipAction`
- ClipAction - :class:`experimental.wrappers.ClipActionV0`
- VectorClipAction (*) - VectorClipAction (*)
- Yes - Yes
* - :class:`wrappers.RescaleAction` * - :class:`wrappers.RescaleAction`
- RescaleAction - :class:`experimental.wrappers.RescaleActionV0`
- VectorRescaleAction (*) - VectorRescaleAction (*)
- Yes - Yes
* - Not Implemented * - Not Implemented
- NanAction - :class:`experimental.wrappers.StickyActionV0`
- VectorNanAction (*)
- Yes
* - Not Implemented
- StickyAction
- VectorStickyAction - VectorStickyAction
- No - No
``` ```
### Lambda Reward Wrappers ### Reward Wrappers
```{eval-rst} ```{eval-rst}
.. py:currentmodule:: gymnasium .. py:currentmodule:: gymnasium
@@ -175,7 +168,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- VectorPassiveEnvChecker - VectorPassiveEnvChecker
* - :class:`wrappers.OrderEnforcing` * - :class:`wrappers.OrderEnforcing`
- OrderEnforcing - OrderEnforcing
- VectorOrderEnforcing (*) - VectorOrderEnforcing
* - :class:`wrappers.EnvCompatibility` * - :class:`wrappers.EnvCompatibility`
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_ - Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
- Not Implemented - Not Implemented
@@ -189,10 +182,10 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- HumanRendering - HumanRendering
- Not Implemented - Not Implemented
* - Not Implemented * - Not Implemented
- :class:`experimental.wrappers.JaxToNumpy` - :class:`experimental.wrappers.JaxToNumpyV0`
- VectorJaxToNumpy (*) - VectorJaxToNumpy (*)
* - Not Implemented * - Not Implemented
- :class:`experimental.wrappers.JaxToTorch` - :class:`experimental.wrappers.JaxToTorchV0`
- VectorJaxToTorch (*) - VectorJaxToTorch (*)
``` ```

View File

@@ -12,9 +12,6 @@ title: Functional
.. autofunction:: gymnasium.experimental.FuncEnv.initial .. autofunction:: gymnasium.experimental.FuncEnv.initial
.. autofunction:: gymnasium.experimental.FuncEnv.transition .. autofunction:: gymnasium.experimental.FuncEnv.transition
.. autofunction:: gymnasium.experimental.FuncEnv.observation
.. autofunction:: gymnasium.experimental.FuncEnv.initial
.. autofunction:: gymnasium.experimental.FuncEnv.observation .. autofunction:: gymnasium.experimental.FuncEnv.observation
.. autofunction:: gymnasium.experimental.FuncEnv.reward .. autofunction:: gymnasium.experimental.FuncEnv.reward
.. autofunction:: gymnasium.experimental.FuncEnv.terminal .. autofunction:: gymnasium.experimental.FuncEnv.terminal
@@ -33,4 +30,8 @@ title: Functional
```{eval-rst} ```{eval-rst}
... autoclass:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv ... autoclass:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv
... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.reset
... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.step
... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.render
``` ```

View File

@@ -1,6 +1,6 @@
# Wrappers # Wrappers
## Lambda Observation Wrappers ## Observation Wrappers
```{eval-rst} ```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0 .. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
@@ -11,24 +11,6 @@
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0 .. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0 .. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0 .. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
```
## Lambda Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
```
## Lambda Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
```
## Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0 .. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0 .. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
``` ```
@@ -36,11 +18,22 @@
## Action Wrappers ## Action Wrappers
```{eval-rst} ```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0
.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0 .. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
``` ```
# Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
```
## Common Wrappers ## Common Wrappers
```{eval-rst} ```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
``` ```

View File

@@ -27,7 +27,7 @@ __all__ = [
"register", "register",
"registry", "registry",
"pprint_registry", "pprint_registry",
# root files # module folders
"envs", "envs",
"spaces", "spaces",
"utils", "utils",

View File

@@ -1,2 +1,2 @@
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
from gymnasium.envs.phys2d.pendulum import PendulumFunctional from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv

View File

@@ -1,6 +1,7 @@
"""Root __init__ of the gym experimental wrappers.""" """Root __init__ of the gym experimental wrappers."""
from gymnasium.experimental import functional, wrappers
from gymnasium.experimental.functional import FuncEnv from gymnasium.experimental.functional import FuncEnv
@@ -8,6 +9,8 @@ __all__ = [
# Functional # Functional
"FuncEnv", "FuncEnv",
"functional", "functional",
# Wrapper # Wrappers
"wrappers", "wrappers",
# Vector
# "vector",
] ]

View File

@@ -10,31 +10,59 @@ from gymnasium.experimental.wrappers.lambda_action import (
ClipActionV0, ClipActionV0,
RescaleActionV0, RescaleActionV0,
) )
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0 from gymnasium.experimental.wrappers.lambda_observations import (
LambdaObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
ResizeObservationV0,
ReshapeObservationV0,
RescaleObservationV0,
DtypeObservationV0,
)
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0 from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0 from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0 from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
from gymnasium.experimental.wrappers.time_aware_observation import ( from gymnasium.experimental.wrappers.stateful_observation import (
TimeAwareObservationV0, TimeAwareObservationV0,
DelayObservationV0,
) )
from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0
__all__ = [ __all__ = [
"ArgType", # --- Observation wrappers ---
# Lambda Action "LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
# "PixelObservationV0",
# "NormalizeObservationV0",
"TimeAwareObservationV0",
# "FrameStackV0",
"DelayObservationV0",
# "AtariPreprocessingV0"
# --- Action Wrappers ---
"LambdaActionV0", "LambdaActionV0",
"StickyActionV0",
"ClipActionV0", "ClipActionV0",
"RescaleActionV0", "RescaleActionV0",
# Lambda Observation # "NanAction",
"LambdaObservationV0", "StickyActionV0",
"DelayObservationV0", # --- Reward wrappers ---
"TimeAwareObservationV0",
# Lambda Reward
"LambdaRewardV0", "LambdaRewardV0",
"ClipRewardV0", "ClipRewardV0",
# Jax conversion wrappers # "RescaleRewardV0",
# "NormalizeRewardV0",
# --- Common ---
# "AutoReset",
# "PassiveEnvChecker",
# "OrderEnforcing",
# "RecordEpisodeStatistics",
# "RenderCollection",
# "HumanRendering",
"JaxToNumpyV0", "JaxToNumpyV0",
"JaxToTorchV0", "JaxToTorchV0",
] ]

View File

@@ -1,35 +0,0 @@
"""Wrapper for delaying the returned observation."""
from collections import deque
import jumpy as jp
import gymnasium as gym
from gymnasium.core import ObsType
class DelayObservationV0(gym.ObservationWrapper):
"""Wrapper which adds a delay to the returned observation."""
def __init__(self, env: gym.Env, delay: int):
"""Initialize the DelayObservation wrapper.
Args:
env (Env): the wrapped environment
delay (int): number of timesteps for delaying the observation.
Before reaching the `delay` number of timesteps,
returned observation is an array of zeros with the
same shape of the observation space.
"""
super().__init__(env)
self.delay = delay
self.observation_queue = deque()
def observation(self, observation: ObsType) -> ObsType:
"""Return the delayed observation."""
self.observation_queue.append(observation)
if len(self.observation_queue) > self.delay:
return self.observation_queue.popleft()
return jp.zeros_like(observation)

View File

@@ -1,13 +1,19 @@
"""Lambda action wrapper which apply a function to the provided action.""" """A collection of wrappers that all use the LambdaAction class.
from typing import Any, Callable, Union
* ``LambdaAction`` - Transforms the actions based on a function
* ``ClipAction`` - Clips the action within a bounds
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
"""
from __future__ import annotations
from typing import Callable
import jumpy as jp import jumpy as jp
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium import spaces from gymnasium.core import ActType, WrapperActType
from gymnasium.core import ActType from gymnasium.spaces import Box, Space
from gymnasium.experimental.wrappers import ArgType
class LambdaActionV0(gym.ActionWrapper): class LambdaActionV0(gym.ActionWrapper):
@@ -16,19 +22,23 @@ class LambdaActionV0(gym.ActionWrapper):
def __init__( def __init__(
self, self,
env: gym.Env, env: gym.Env,
func: Callable[[ArgType], Any], func: Callable[[WrapperActType], ActType],
action_space: Space | None,
): ):
"""Initialize LambdaAction. """Initialize LambdaAction.
Args: Args:
env (Env): The gymnasium environment env: The gymnasium environment
func (Callable): function to apply to action func: Function to apply to ``step`` ``action``
action_space: The updated action space of the wrapper given the function.
""" """
super().__init__(env) super().__init__(env)
if action_space is not None:
self.action_space = action_space
self.func = func self.func = func
def action(self, action: ActType) -> Any: def action(self, action: WrapperActType) -> ActType:
"""Apply function to action.""" """Apply function to action."""
return self.func(action) return self.func(action)
@@ -53,14 +63,19 @@ class ClipActionV0(LambdaActionV0):
Args: Args:
env: The environment to apply the wrapper env: The environment to apply the wrapper
""" """
assert isinstance(env.action_space, spaces.Box) assert isinstance(env.action_space, Box)
super().__init__( super().__init__(
env, env,
lambda action: jp.clip(action, env.action_space.low, env.action_space.high), lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
Box(
-np.inf,
np.inf,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
) )
self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape)
class RescaleActionV0(LambdaActionV0): class RescaleActionV0(LambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
@@ -86,8 +101,8 @@ class RescaleActionV0(LambdaActionV0):
def __init__( def __init__(
self, self,
env: gym.Env, env: gym.Env,
min_action: Union[float, int, np.ndarray], min_action: float | int | np.ndarray,
max_action: Union[float, int, np.ndarray], max_action: float | int | np.ndarray,
): ):
"""Initializes the :class:`RescaleAction` wrapper. """Initializes the :class:`RescaleAction` wrapper.
@@ -96,28 +111,44 @@ class RescaleActionV0(LambdaActionV0):
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
""" """
assert isinstance( assert isinstance(env.action_space, Box)
env.action_space, spaces.Box assert not np.any(env.action_space.low == np.inf) and not np.any(
), f"expected Box action space, got {type(env.action_space)}" env.action_space.high == np.inf
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
low = env.action_space.low
high = env.action_space.high
self.min_action = np.full(
env.action_space.shape, min_action, dtype=env.action_space.dtype
) )
self.max_action = np.full(
env.action_space.shape, max_action, dtype=env.action_space.dtype if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)
assert min_action.shape == env.action_space.shape
assert not np.any(min_action == np.inf)
if not isinstance(max_action, np.ndarray):
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
max_action = np.full(env.action_space.shape, max_action)
assert max_action.shape == env.action_space.shape
assert not np.any(max_action == np.inf)
assert isinstance(env.action_space, Box)
assert np.all(np.less_equal(min_action, max_action))
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (env.action_space.high - env.action_space.low) / (
max_action - min_action
) )
intercept = gradient * -min_action + env.action_space.low
super().__init__( super().__init__(
env, env,
lambda action: jp.clip( lambda action: gradient * action + intercept,
low Box(
+ (high - low) low=min_action,
* ((action - self.min_action) / (self.max_action - self.min_action)), high=max_action,
low, shape=env.action_space.shape,
high, dtype=env.action_space.dtype,
), ),
) )

View File

@@ -1,17 +1,27 @@
"""Lambda observation wrappers which apply a function to the observation.""" """A collection of observation wrappers using a lambda function.
* ``LambdaObservation`` - Transforms the observation with a function
* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
* ``FlattenObservation`` - Flattens the observations
* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation
* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation)
* ``ReshapeObservation`` - Reshapes an array-based observation
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservation`` - Convert a observation dtype
"""
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Sequence from typing import Any, Callable, Sequence
from typing_extensions import Final
import jumpy as jp import jumpy as jp
import numpy as np import numpy as np
import numpy.typing as npt
import gymnasium as gym import gymnasium as gym
from gymnasium import spaces from gymnasium import spaces
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import utils from gymnasium.spaces import Box, utils
class LambdaObservationV0(gym.ObservationWrapper): class LambdaObservationV0(gym.ObservationWrapper):
@@ -71,32 +81,82 @@ class FilterObservationV0(LambdaObservationV0):
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {}) ({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {})
""" """
def __init__(self, env: gym.Env, filter_keys: Sequence[str]): def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys.""" """Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
if not isinstance(env.observation_space, spaces.Dict): assert isinstance(filter_keys, Sequence)
# Filters for dictionary space
if isinstance(env.observation_space, spaces.Dict):
assert all(isinstance(key, str) for key in filter_keys)
if any(
key not in env.observation_space.spaces.keys() for key in filter_keys
):
missing_keys = [
key
for key in filter_keys
if key not in env.observation_space.spaces.keys()
]
raise ValueError(
"All the `filter_keys` must be included in the observation space.\n"
f"Filter keys: {filter_keys}\n"
f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
f"Missing keys: {missing_keys}"
)
new_observation_space = spaces.Dict(
{key: env.observation_space[key] for key in filter_keys}
)
if len(new_observation_space) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
super().__init__(
env,
lambda obs: {key: obs[key] for key in filter_keys},
new_observation_space,
)
# Filter for tuple observation
elif isinstance(env.observation_space, spaces.Tuple):
assert all(isinstance(key, int) for key in filter_keys)
assert len(set(filter_keys)) == len(
filter_keys
), f"Duplicate keys exist, filter_keys: {filter_keys}"
if any(
0 < key and key >= len(env.observation_space) for key in filter_keys
):
missing_index = [
key
for key in filter_keys
if 0 < key and key >= len(env.observation_space)
]
raise ValueError(
"All the `filter_keys` must be included in the length of the observation space.\n"
f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
f"missing indexes: {missing_index}"
)
new_observation_spaces = spaces.Tuple(
env.observation_space[key] for key in filter_keys
)
if len(new_observation_spaces) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
super().__init__(
env,
lambda obs: tuple(obs[key] for key in filter_keys),
new_observation_spaces,
)
else:
raise ValueError( raise ValueError(
f"FilterObservation wrapper is only usable with dict observations, actual type: {type(env.observation_space)}" f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
) )
if any(key not in env.observation_space.keys() for key in filter_keys): self.filter_keys: Final[Sequence[str | int]] = filter_keys
missing_keys = [
key for key in filter_keys if key not in env.observation_space.keys()
]
raise ValueError(
"All the filter_keys must be included in the original observation space.\n"
f"Filter keys: {filter_keys}\n"
f"Observation keys: {list(env.observation_space.keys())}\n"
f"Missing keys: {missing_keys}"
)
new_observation_space = spaces.Dict(
{key: env.observation_space[key] for key in filter_keys}
)
super().__init__(
env,
lambda obs: {key: obs[key] for key in filter_keys},
new_observation_space,
)
class FlattenObservationV0(LambdaObservationV0): class FlattenObservationV0(LambdaObservationV0):
@@ -117,9 +177,10 @@ class FlattenObservationV0(LambdaObservationV0):
def __init__(self, env: gym.Env): def __init__(self, env: gym.Env):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.""" """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
flattened_space = utils.flatten_space(env.observation_space)
super().__init__( super().__init__(
env, lambda obs: utils.flatten(flattened_space, obs), flattened_space env,
lambda obs: utils.flatten(env.observation_space, obs),
utils.flatten_space(env.observation_space),
) )
@@ -154,7 +215,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
and env.observation_space.dtype == np.uint8 and env.observation_space.dtype == np.uint8
) )
self.keep_dim = keep_dim self.keep_dim: Final[bool] = keep_dim
if keep_dim: if keep_dim:
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=0, low=0,
@@ -167,7 +228,8 @@ class GrayscaleObservationV0(LambdaObservationV0):
lambda obs: jp.expand_dims( lambda obs: jp.expand_dims(
jp.sum( jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1 jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
) ).astype(np.uint8),
axis=-1,
), ),
new_observation_space, new_observation_space,
) )
@@ -179,7 +241,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
env, env,
lambda obs: jp.sum( lambda obs: jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1 jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
), ).astype(np.uint8),
new_observation_space, new_observation_space,
) )
@@ -215,7 +277,7 @@ class ResizeObservationV0(LambdaObservationV0):
"opencv is not install, run `pip install gymnasium[other]`" "opencv is not install, run `pip install gymnasium[other]`"
) )
self.shape = tuple(shape) self.shape: Final[tuple[int, ...]] = tuple(shape)
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=0, high=255, shape=self.shape + env.observation_space.shape[2:] low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
@@ -237,7 +299,7 @@ class ReshapeObservationV0(LambdaObservationV0):
assert isinstance(shape, tuple) assert isinstance(shape, tuple)
assert all(np.issubdtype(type(elem), np.integer) for elem in shape) assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
assert all(x > 0 for x in shape) assert all(x > 0 or x == -1 for x in shape)
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=np.reshape(np.ravel(env.observation_space.low), shape), low=np.reshape(np.ravel(env.observation_space.low), shape),
@@ -245,9 +307,8 @@ class ReshapeObservationV0(LambdaObservationV0):
shape=shape, shape=shape,
dtype=env.observation_space.dtype, dtype=env.observation_space.dtype,
) )
super().__init__( self.shape = shape
env, lambda obs: jp.reshape(obs, self.shape), new_observation_space super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
)
class RescaleObservationV0(LambdaObservationV0): class RescaleObservationV0(LambdaObservationV0):
@@ -256,18 +317,23 @@ class RescaleObservationV0(LambdaObservationV0):
def __init__( def __init__(
self, self,
env: gym.Env, env: gym.Env,
min_obs: tuple[np.floating, np.integer, np.ndarray], min_obs: np.floating | np.integer | np.ndarray,
max_obs: tuple[np.floating, np.integer, np.ndarray], max_obs: np.floating | np.integer | np.ndarray,
): ):
"""Constructor that requires the env observation spaces to be a :class:`Box`.""" """Constructor that requires the env observation spaces to be a :class:`Box`."""
assert isinstance(env.observation_space, spaces.Box) assert isinstance(env.observation_space, spaces.Box)
assert not np.any(env.observation_space.low == np.inf) and not np.any(
env.observation_space.high == np.inf
)
if not isinstance(min_obs, np.ndarray): if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype( assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating type(max_obs), np.floating
) )
min_obs = np.full(env.observation_space.shape, min_obs) min_obs = np.full(env.observation_space.shape, min_obs)
assert min_obs.shape == env.observation_space.shape assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
assert not np.any(min_obs == np.inf) assert not np.any(min_obs == np.inf)
if not isinstance(max_obs, np.ndarray): if not isinstance(max_obs, np.ndarray):
@@ -278,52 +344,66 @@ class RescaleObservationV0(LambdaObservationV0):
assert max_obs.shape == env.observation_space.shape assert max_obs.shape == env.observation_space.shape
assert not np.any(max_obs == np.inf) assert not np.any(max_obs == np.inf)
env_low = env.observation_space.low self.min_obs = min_obs
env_high = env.observation_space.high self.max_obs = max_obs
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (max_obs - min_obs) / (
env.observation_space.high - env.observation_space.low
)
intercept = gradient * -env.observation_space.low + min_obs
new_observation_space = spaces.Box(low=min_obs, high=max_obs)
super().__init__( super().__init__(
env, env,
lambda obs: env_low lambda obs: gradient * obs + intercept,
+ (env_high - env_low) * ((obs - min_obs) / (max_obs - min_obs)), Box(
new_observation_space, low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
),
) )
class DtypeObservationV0(LambdaObservationV0): class DtypeObservationV0(LambdaObservationV0):
"""Observation wrapper for transforming the dtype of an observation.""" """Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: gym.Env, dtype: npt.DTypeLike): def __init__(self, env: gym.Env, dtype: Any):
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces.""" """Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
assert isinstance( assert isinstance(
env.observation_space, env.observation_space,
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary), (spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
) )
dtype = np.dtype(dtype) self.dtype = dtype
if isinstance(env.observation_space, spaces.Box): if isinstance(env.observation_space, spaces.Box):
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=env.observation_space.low, low=env.observation_space.low,
high=env.observation_space.high, high=env.observation_space.high,
shape=env.observation_space.shape, shape=env.observation_space.shape,
dtype=dtype.__name__, dtype=self.dtype,
) )
elif isinstance(env.observation_space, spaces.Discrete): elif isinstance(env.observation_space, spaces.Discrete):
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=env.observation_space.start, low=env.observation_space.start,
high=env.observation_space.start + env.observation_space.n, high=env.observation_space.start + env.observation_space.n,
shape=(), shape=(),
dtype=dtype.__name__, dtype=self.dtype,
) )
elif isinstance(env.observation_space, spaces.MultiDiscrete): elif isinstance(env.observation_space, spaces.MultiDiscrete):
new_observation_space = spaces.MultiDiscrete( new_observation_space = spaces.MultiDiscrete(
env.observation_space.nvec, dtype=dtype.__name__ env.observation_space.nvec, dtype=dtype
) )
elif isinstance(env.observation_space, spaces.MultiBinary): elif isinstance(env.observation_space, spaces.MultiBinary):
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=0, high=1, shape=env.observation_space.shape, dtype=dtype.__name__ low=0,
high=1,
shape=env.observation_space.shape,
dtype=self.dtype,
) )
else: else:
raise TypeError raise TypeError(
"DtypeObservation is only compatible with value / array-based observations."
)
super().__init__(env, lambda obs: dtype(obs), new_observation_space) super().__init__(env, lambda obs: dtype(obs), new_observation_space)

View File

@@ -1,12 +1,17 @@
"""Lambda reward wrappers which apply a function to the reward.""" """A collection of wrappers for modifying the reward.
from typing import Any, Callable, Optional, Union * ``LambdaReward`` - Transforms the reward by a function
* ``ClipReward`` - Clips the reward between a minimum and maximum value
"""
from __future__ import annotations
from typing import Callable, SupportsFloat
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.error import InvalidBound from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers import ArgType
class LambdaRewardV0(gym.RewardWrapper): class LambdaRewardV0(gym.RewardWrapper):
@@ -26,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper):
def __init__( def __init__(
self, self,
env: gym.Env, env: gym.Env,
func: Callable[[ArgType], Any], func: Callable[[SupportsFloat], SupportsFloat],
): ):
"""Initialize LambdaRewardV0 wrapper. """Initialize LambdaRewardV0 wrapper.
@@ -38,7 +43,7 @@ class LambdaRewardV0(gym.RewardWrapper):
self.func = func self.func = func
def reward(self, reward: Union[float, int, np.ndarray]) -> Any: def reward(self, reward: SupportsFloat) -> SupportsFloat:
"""Apply function to reward. """Apply function to reward.
Args: Args:
@@ -64,8 +69,8 @@ class ClipRewardV0(LambdaRewardV0):
def __init__( def __init__(
self, self,
env: gym.Env, env: gym.Env,
min_reward: Optional[Union[float, np.ndarray]] = None, min_reward: float | np.ndarray | None = None,
max_reward: Optional[Union[float, np.ndarray]] = None, max_reward: float | np.ndarray | None = None,
): ):
"""Initialize ClipRewardsV0 wrapper. """Initialize ClipRewardsV0 wrapper.

View File

@@ -6,70 +6,90 @@ import numbers
from collections import abc from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat from typing import Any, Iterable, Mapping, SupportsFloat
import jax.numpy as jnp
import numpy as np import numpy as np
from gymnasium import Env, Wrapper from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
try:
import jax.numpy as jnp
except ImportError:
# We handle the error internal to the relative functions
jnp = None
@functools.singledispatch @functools.singledispatch
def numpy_to_jax(value: Any) -> Any: def numpy_to_jax(value: Any) -> Any:
"""Converts a value to a Jax DeviceArray.""" """Converts a value to a Jax DeviceArray."""
raise Exception( if jnp is None:
f"No conversion for Numpy to Jax registered for type: {type(value)}" raise DependencyNotInstalled(
) "Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
)
@numpy_to_jax.register(numbers.Number) if jnp is not None:
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
return jnp.array(value)
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(
value: np.ndarray | numbers.Number,
) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
assert jnp is not None
return jnp.array(value)
@numpy_to_jax.register(abc.Mapping) @numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays.""" """Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()}) return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Iterable)
@numpy_to_jax.register(abc.Iterable) def _iterable_numpy_to_jax(
def _iterable_numpy_to_jax( value: Iterable[np.ndarray | Any],
value: Iterable[np.ndarray | Any], ) -> Iterable[jnp.DeviceArray | Any]:
) -> Iterable[jnp.DeviceArray | Any]: """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays.""" return type(value)(numpy_to_jax(v) for v in value)
return type(value)(numpy_to_jax(v) for v in value)
@functools.singledispatch @functools.singledispatch
def jax_to_numpy(value: Any) -> Any: def jax_to_numpy(value: Any) -> Any:
"""Converts a value to a numpy array.""" """Converts a value to a numpy array."""
raise Exception( if jnp is None:
f"No conversion for Jax to Numpy registered for type: {type(value)}" raise DependencyNotInstalled(
) "Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
)
@jax_to_numpy.register(jnp.DeviceArray) if jnp is not None:
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
"""Converts a Jax DeviceArray to a numpy array."""
return np.array(value)
@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
"""Converts a Jax DeviceArray to a numpy array."""
return np.array(value)
@jax_to_numpy.register(abc.Mapping) @jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy( def _mapping_jax_to_numpy(
value: Mapping[str, jnp.DeviceArray | Any] value: Mapping[str, jnp.DeviceArray | Any]
) -> Mapping[str, np.ndarray | Any]: ) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays.""" """Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()}) return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Iterable)
@jax_to_numpy.register(abc.Iterable) def _iterable_jax_to_numpy(
def _iterable_jax_to_numpy( value: Iterable[np.ndarray | Any],
value: Iterable[np.ndarray | Any], ) -> Iterable[jnp.DeviceArray | Any]:
) -> Iterable[jnp.DeviceArray | Any]: """Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays.""" return type(value)(jax_to_numpy(v) for v in value)
return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(Wrapper): class JaxToNumpyV0(Wrapper):
@@ -88,6 +108,10 @@ class JaxToNumpyV0(Wrapper):
Args: Args:
env: the environment to wrap env: the environment to wrap
""" """
if jnp is None:
raise DependencyNotInstalled(
"Jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env) super().__init__(env)
def step( def step(

View File

@@ -0,0 +1,56 @@
"""A collection of stateful action wrappers.
* StickyAction - There is a probability that the action is taken again
"""
from __future__ import annotations
from typing import Any, SupportsFloat
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.error import InvalidProbability
class StickyActionV0(gym.Wrapper):
"""Wrapper which adds a probability of repeating the previous action.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
in Section 5.2 on page 12.
"""
def __init__(self, env: gym.Env, repeat_action_probability: float):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a probability of repeating the old action.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)
super().__init__(env)
self.repeat_action_probability = repeat_action_probability
self.last_action: WrapperActType | None = None
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the environment."""
self.last_action = None
return super().reset(seed=seed, options=options)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Execute the action."""
if (
self.last_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
action = self.last_action
self.last_action = action
return action

View File

@@ -0,0 +1,200 @@
"""A collection of stateful observation wrappers.
* DelayObservation - A wrapper for delaying the returned observation
* TimeAwareObservation - A wrapper for adding time aware observations to environment observation
"""
from __future__ import annotations
from collections import deque
from typing import Any, SupportsFloat
from typing_extensions import Final
import jumpy as jp
import numpy as np
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
class DelayObservationV0(gym.ObservationWrapper):
"""Wrapper which adds a delay to the returned observation."""
def __init__(self, env: gym.Env, delay: int):
"""Initialize the DelayObservation wrapper.
Args:
env (Env): the wrapped environment
delay (int): number of timesteps for delaying the observation.
Before reaching the `delay` number of timesteps,
returned observation is an array of zeros with the
same shape of the observation space.
"""
assert isinstance(env.observation_space, (Box, MultiBinary, MultiDiscrete))
assert 0 < delay
self.delay: Final[int] = delay
self.observation_queue: Final[deque] = deque()
super().__init__(env)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment, clearing the observation queue."""
self.observation_queue.clear()
return super().reset(seed=seed, options=options)
def observation(self, observation: ObsType) -> ObsType:
"""Return the delayed observation."""
self.observation_queue.append(observation)
if len(self.observation_queue) > self.delay:
return self.observation_queue.popleft()
return jp.zeros_like(observation)
class TimeAwareObservationV0(gym.ObservationWrapper):
"""Augment the observation with time information of the episode.
Time can be represented as a normalized value between [0,1]
or by the number of timesteps remaining before truncation occurs.
For environments with ``Dict`` or ``Tuple`` observation spaces, by default,
the time information is automatically added in the key `"time"` and
as the final element in the tuple.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env)
>>> env.observation_space
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
OrderedDict([('obs',
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
... ('time', array([0.002]))])
Flatten observation space example:
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
"""
def __init__(
self,
env: gym.Env,
flatten: bool = False,
normalize_time: bool = True,
*,
dict_time_key: str = "time",
):
"""Initialize :class:`TimeAwareObservationV0`.
Args:
env: The environment to apply the wrapper
flatten: Flatten the observation to a `Box` of a single dimension
normalize_time: if `True` return time in the range [0,1]
otherwise return time as remaining timesteps before truncation
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
"""
super().__init__(env)
self.flatten: Final[bool] = flatten
self.normalize_time: Final[bool] = normalize_time
if hasattr(env, "_max_episode_steps"):
self.max_timesteps = getattr(env, "_max_episode_steps")
elif env.spec is not None and env.spec.max_episode_steps is not None:
self.max_timesteps = env.spec.max_episode_steps
else:
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)
self.timesteps: int = 0
# Find the normalized time space
if self.normalize_time:
self._time_preprocess_func = lambda time: time / self.max_timesteps
time_space = Box(0.0, 1.0)
else:
self._time_preprocess_func = lambda time: self.max_timesteps - time
time_space = Box(0, self.max_timesteps, dtype=np.int32)
# Find the observation space
if isinstance(env.observation_space, Dict):
assert dict_time_key not in env.observation_space.keys()
observation_space = Dict(
{dict_time_key: time_space}, **env.observation_space.spaces
)
self._append_data_func = lambda obs, time: {**obs, dict_time_key: time}
elif isinstance(env.observation_space, Tuple):
observation_space = Tuple(env.observation_space.spaces + (time_space,))
self._append_data_func = lambda obs, time: obs + (time,)
else:
observation_space = Dict(obs=env.observation_space, time=time_space)
self._append_data_func = lambda obs, time: {"obs": obs, "time": time}
# If to flatten the observation space
if self.flatten:
self.observation_space = spaces.flatten_space(observation_space)
self._obs_postprocess_func = lambda obs: spaces.flatten(
observation_space, obs
)
else:
self.observation_space = observation_space
self._obs_postprocess_func = lambda obs: obs
def observation(self, observation: ObsType) -> WrapperObsType:
"""Adds to the observation with the current time information.
Args:
observation: The observation to add the time step to
Returns:
The observation with the time information appended to
"""
return self._obs_postprocess_func(
self._append_data_func(
observation, self._time_preprocess_func(self.timesteps)
)
)
def step(
self, action: ActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, incrementing the time step.
Args:
action: The action to take
Returns:
The environment's step using the action.
"""
self.timesteps += 1
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the environment setting the time to zero.
Args:
seed: The seed to reset the environment
options: The options used to reset the environment
Returns:
The reset environment
"""
self.timesteps = 0
return super().reset(seed=seed, options=options)

View File

@@ -1,40 +0,0 @@
"""Wrapper which adds a probability of repeating the previous executed action."""
from typing import Union
import gymnasium as gym
from gymnasium.core import ActType
from gymnasium.error import InvalidProbability
class StickyActionV0(gym.ActionWrapper):
"""Wrapper which adds a probability of repeating the previous action."""
def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a proability of repeating the old action.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)
super().__init__(env)
self.repeat_action_probability = repeat_action_probability
self.old_action = None
def action(self, action: ActType):
"""Execute the action."""
if (
self.old_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
action = self.old_action
self.old_action = action
return action
def reset(self, **kwargs):
"""Reset the environment."""
self.old_action = None
return super().reset(**kwargs)

View File

@@ -1,113 +0,0 @@
"""Wrapper for adding time aware observations to environment observation."""
from collections import OrderedDict
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium.core import ActType, ObsType
from gymnasium.spaces import Box, Dict
class TimeAwareObservationV0(gym.ObservationWrapper):
"""Augment the observation with time information of the episode.
Time can be represented as a normalized value between [0,1]
or by the number of timesteps remaining before truncation occurs.
Example:
>>> import gym
>>> from gym.wrappers import TimeAwareObservationV0
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env)
>>> env.observation_space
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
OrderedDict([('obs',
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
... ('time', array([0.002]))])
Flatten observation space example:
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
"""
def __init__(self, env: gym.Env, flatten=False, normalize_time=True):
"""Initialize :class:`TimeAwareObservationV0`.
Args:
env: The environment to apply the wrapper
flatten: Flatten the observation to a `Box` of a single dimension
normalize_time: if `True` return time in the range [0,1]
otherwise return time as remaining timesteps before truncation
"""
super().__init__(env)
self.flatten = flatten
self.normalize_time = normalize_time
self.max_timesteps = getattr(env, "_max_episode_steps")
if self.normalize_time:
self._get_time_observation = lambda: self.timesteps / self.max_timesteps
time_space = Box(0, 1)
else:
self._get_time_observation = lambda: self.max_timesteps - self.timesteps
time_space = Box(0, self.max_timesteps)
self.time_aware_observation_space = Dict(
obs=env.observation_space, time=time_space
)
if self.flatten:
self.observation_space = spaces.flatten_space(
self.time_aware_observation_space
)
self._observation_postprocess = lambda observation: spaces.flatten(
self.time_aware_observation_space, observation
)
else:
self.observation_space = self.time_aware_observation_space
self._observation_postprocess = lambda observation: observation
def observation(self, observation: ObsType):
"""Adds to the observation with the current time information.
Args:
observation: The observation to add the time step to
Returns:
The observation with the time information appended to
"""
time_observation = self._get_time_observation()
observation = OrderedDict(obs=observation, time=time_observation)
return self._observation_postprocess(observation)
def step(self, action: ActType):
"""Steps through the environment, incrementing the time step.
Args:
action: The action to take
Returns:
The environment's step using the action.
"""
self.timesteps += 1
observation, reward, terminated, truncated, info = super().step(action)
return observation, reward, terminated, truncated, info
def reset(self, **kwargs):
"""Reset the environment setting the time to zero.
Args:
**kwargs: Kwargs to apply to env.reset()
Returns:
The reset environment
"""
self.timesteps = 0
return super().reset(**kwargs)

View File

@@ -14,92 +14,121 @@ import numbers
from collections import abc from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union from typing import Any, Iterable, Mapping, SupportsFloat, Union
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
from gymnasium import Env, Wrapper from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
try:
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
except ImportError:
jnp, jax_dlpack = None, None
try: try:
import torch import torch
from torch.utils import dlpack as torch_dlpack from torch.utils import dlpack as torch_dlpack
Device = Union[str, torch.device]
except ImportError: except ImportError:
raise DependencyNotInstalled("torch is not installed, run `pip install torch`") torch, torch_dlpack, Device = None, None, None
Device = Union[str, torch.device]
@functools.singledispatch @functools.singledispatch
def torch_to_jax(value: Any) -> Any: def torch_to_jax(value: Any) -> Any:
"""Converts a PyTorch Tensor into a Jax DeviceArray.""" """Converts a PyTorch Tensor into a Jax DeviceArray."""
raise Exception( if torch is None:
f"No conversion for PyTorch to Jax registered for type: {type(value)}" raise DependencyNotInstalled(
) "Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
)
@torch_to_jax.register(numbers.Number) if torch is not None and jnp is not None:
def _number_torch_to_jax(value: numbers.Number) -> Any:
return jnp.array(value)
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a jax array."""
assert jnp is not None
return jnp.array(value)
@torch_to_jax.register(torch.Tensor) @torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray: def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
"""Converts a PyTorch Tensor into a Jax DeviceArray.""" """Converts a PyTorch Tensor into a Jax DeviceArray."""
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage] assert torch_dlpack is not None and jax_dlpack is not None
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage] tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
return tensor value
)
tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage]
tensor
)
return tensor
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
@torch_to_jax.register(abc.Mapping) @torch_to_jax.register(abc.Iterable)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays.""" """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) return type(value)(torch_to_jax(v) for v in value)
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_jax(v) for v in value)
@functools.singledispatch @functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any: def jax_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax DeviceArray into a PyTorch Tensor.""" """Converts a Jax DeviceArray into a PyTorch Tensor."""
raise Exception( if torch is None:
f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}" raise DependencyNotInstalled(
) "Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
)
@jax_to_torch.register(jnp.DeviceArray) if torch is not None and jnp is not None:
def _devicearray_jax_to_torch(
value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
tensor = torch_dlpack.from_dlpack(dlpack)
if device:
return tensor.to(device=device)
return tensor
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
assert jax_dlpack is not None and torch_dlpack is not None
dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
value
)
tensor = torch_dlpack.from_dlpack(dlpack)
if device:
return tensor.to(device=device)
return tensor
@jax_to_torch.register(abc.Mapping) @jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch( def _jax_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors.""" """Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()}) return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
@jax_to_torch.register(abc.Iterable)
@jax_to_torch.register(abc.Iterable) def _jax_iterable_to_torch(
def _jax_iterable_to_torch( value: Iterable[Any], device: Device | None = None
value: Iterable[Any], device: Device | None = None ) -> Iterable[Any]:
) -> Iterable[Any]: """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.""" return type(value)(jax_to_torch(v, device) for v in value)
return type(value)(jax_to_torch(v, device) for v in value)
class JaxToTorchV0(Wrapper): class JaxToTorchV0(Wrapper):
@@ -107,7 +136,8 @@ class JaxToTorchV0(Wrapper):
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. Note:
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
""" """
def __init__(self, env: Env, device: Device | None = None): def __init__(self, env: Env, device: Device | None = None):
@@ -117,6 +147,15 @@ class JaxToTorchV0(Wrapper):
env: The Jax-based environment to wrap env: The Jax-based environment to wrap
device: The device the torch Tensors should be moved to device: The device the torch Tensors should be moved to
""" """
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env) super().__init__(env)
self.device: Device | None = device self.device: Device | None = device

View File

@@ -76,6 +76,7 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
* ``None`` The length will be randomly drawn from a geometric distribution * ``None`` The length will be randomly drawn from a geometric distribution
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array. * ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
* ``int`` for a fixed length sample * ``int`` for a fixed length sample
The second element of the mask tuple `sample` mask specifies a mask that is applied when The second element of the mask tuple `sample` mask specifies a mask that is applied when
sampling elements from the base space. The mask is applied for each feature space sample. sampling elements from the base space. The mask is applied for each feature space sample.

View File

@@ -29,7 +29,7 @@ dependencies = [
"jax-jumpy >=0.2.0", "jax-jumpy >=0.2.0",
"cloudpickle >=1.2.0", "cloudpickle >=1.2.0",
"importlib-metadata >=4.8.0; python_version < '3.10'", "importlib-metadata >=4.8.0; python_version < '3.10'",
"typing-extensions >=4.3.0; python_version == '3.7'", "typing-extensions >=4.3.0",
"gymnasium-notices >=0.0.1", "gymnasium-notices >=0.0.1",
"shimmy >=0.1.0,<1.0", "shimmy >=0.1.0,<1.0",
] ]

View File

@@ -19,7 +19,7 @@ from gymnasium.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_env_specs from tests.envs.utils import all_testing_env_specs
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
from tests.testing_env import GenericTestEnv, old_step_fn from tests.testing_env import GenericTestEnv, old_step_func
from tests.wrappers.utils import has_wrapper from tests.wrappers.utils import has_wrapper
@@ -155,7 +155,7 @@ def test_make_disable_env_checker():
def test_apply_api_compatibility(): def test_apply_api_compatibility():
gym.register( gym.register(
"testing-old-env", "testing-old-env",
lambda: GenericTestEnv(step_fn=old_step_fn), lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=True, apply_api_compatibility=True,
max_episode_steps=3, max_episode_steps=3,
) )

View File

@@ -1,48 +0,0 @@
"""Test suite for LambdaActionV0."""
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import ClipActionV0
SEED = 42
@pytest.mark.parametrize(
("env", "action_unclipped_env", "action_clipped_env"),
(
[
# MountainCar action space: Box(-1.0, 1.0, (1,), float32)
gym.make("MountainCarContinuous-v0"),
np.array([1]),
np.array([1.5]),
],
[
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([1, 1, 1, 1]),
np.array([10, 10, 10, 10]),
],
[
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([0.5, 0.5, 1, 1]),
np.array([0.5, 0.5, 10, 10]),
],
),
)
def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env):
"""Tests if actions out of bound are correctly clipped.
Tests whether out of bound actions for the wrapped
environments are correctly clipped.
"""
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(action_unclipped_env)
env.reset(seed=SEED)
wrapped_env = ClipActionV0(env)
wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env)
assert np.alltrue(obs == wrapped_obs)

View File

@@ -1,37 +0,0 @@
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0
SEED = 42
DELAY = 3
NUM_STEPS = 5
def test_delay_observation():
env = gym.make("CartPole-v1")
env.action_space.seed(SEED)
env.reset(seed=SEED)
undelayed_observations = []
for _ in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
undelayed_observations.append(obs)
env.action_space.seed(SEED)
env.reset(seed=SEED)
env = DelayObservationV0(env, delay=DELAY)
delayed_observations = []
for i in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
if i < DELAY - 1:
assert np.all(obs == 0)
delayed_observations.append(obs)
assert np.alltrue(
np.array(delayed_observations[DELAY:])
== np.array(undelayed_observations[: DELAY - 1])
)

View File

@@ -1,60 +1,78 @@
"""Test suite for LambdaActionV0.""" """Test suit for lambda action wrappers: LambdaAction, ClipAction, RescaleAction."""
import numpy as np import numpy as np
import pytest
import gymnasium as gym from gymnasium.experimental.wrappers import (
from gymnasium.error import InvalidAction ClipActionV0,
from gymnasium.experimental.wrappers import LambdaActionV0 LambdaActionV0,
RescaleActionV0,
)
from gymnasium.spaces import Box from gymnasium.spaces import Box
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
NUM_ENVS = 3 SEED = 42
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
def generic_step_fn(self, action): def _record_action_step_func(self, action):
return 0, 0, False, False, {"action": action} return 0, 0, False, False, {"action": action}
@pytest.mark.parametrize( def test_lambda_action_wrapper():
("env", "func", "action", "expected"), """Tests LambdaAction through checking that the action taken is transformed by function."""
[ env = GenericTestEnv(step_func=_record_action_step_func)
( wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn),
lambda action: action + 2,
1,
3,
),
],
)
def test_lambda_action_v0(env, func, action, expected):
"""Tests lambda action.
Tests if function is correctly applied to environment's action.
"""
wrapped_env = LambdaActionV0(env, func)
_, _, _, _, info = wrapped_env.step(action)
executed_action = info["action"]
assert executed_action == expected sampled_action = wrapped_env.action_space.sample()
assert sampled_action not in env.action_space
_, _, _, _, info = wrapped_env.step(sampled_action)
assert info["action"] in env.action_space
assert sampled_action - 2 == info["action"]
def test_lambda_action_v0_within_vector(): def test_clip_action_wrapper():
"""Tests lambda action in vectorized environments. """Test that the action is correctly clipped to the base environment action space."""
Tests if function is correctly applied to environment's action env = GenericTestEnv(
in vectorized environment. action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
""" step_func=_record_action_step_func,
env = gym.vector.make(
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
) )
action = np.ones(NUM_ENVS, dtype=np.float64) wrapped_env = ClipActionV0(env)
wrapped_env = LambdaActionV0(env, lambda action: action.astype(int)) sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
wrapped_env.reset() assert sampled_action not in env.action_space
assert sampled_action in wrapped_env.action_space
wrapped_env.step(action) _, _, _, _, info = wrapped_env.step(sampled_action)
assert np.all(info["action"] in env.action_space)
assert np.all(info["action"] == np.array([0, 2, 3.5]))
# unwrapped env should raise exception because it does not
# support float actions def test_rescale_action_wrapper():
with pytest.raises(InvalidAction): """Test that the action is rescale within a min / max bound."""
env.step(action) env = GenericTestEnv(
step_func=_record_action_step_func,
action_space=Box(np.array([0, 1]), np.array([1, 3])),
)
wrapped_env = RescaleActionV0(
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
)
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
for sample_action, expected_action in (
(
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0], dtype=np.float32),
),
(
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0], dtype=np.float32),
),
(
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0], dtype=np.float32),
),
):
assert sample_action in wrapped_env.action_space
_, _, _, _, info = wrapped_env.step(sample_action)
assert np.all(info["action"] == expected_action)

View File

@@ -1,59 +1,250 @@
"""Test suite for LambdaObservationV0.""" """Test suite for lambda observation wrappers: """
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaObservationV0 from gymnasium.experimental.wrappers import (
from gymnasium.spaces import Box DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
)
from gymnasium.spaces import Box, Dict, Tuple
from tests.testing_env import GenericTestEnv
NUM_ENVS = 3
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
SEED = 42 SEED = 42
DISCRETE_ACTION = 1
def test_lambda_observation_v0(): def _record_random_obs_reset(self: gym.Env, seed=None, options=None):
"""Tests lambda observation. obs = self.observation_space.sample()
return obs, {"obs": obs}
Tests if function is correctly applied to environment's observation.
"""
env = gym.make("CartPole-v1")
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(DISCRETE_ACTION)
observation_shift = 1 def _record_random_obs_step(self: gym.Env, action):
obs = self.observation_space.sample()
return obs, 0, False, False, {"obs": obs}
env.reset(seed=SEED)
wrapped_env = LambdaObservationV0( def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}):
env, lambda observation: observation + observation_shift, None return options["obs"], {"obs": options["obs"]}
def _record_action_obs_step(self: gym.Env, action):
return action, 0, False, False, {"obs": action}
def _check_obs(
env: gym.Env,
wrapped_env: gym.Wrapper,
transformed_obs,
original_obs,
strict: bool = True,
):
assert (
transformed_obs in wrapped_env.observation_space
), f"{transformed_obs}, {wrapped_env.observation_space}"
assert (
original_obs in env.observation_space
), f"{original_obs}, {env.observation_space}"
if strict:
assert (
transformed_obs not in env.observation_space
), f"{transformed_obs}, {env.observation_space}"
assert (
original_obs not in wrapped_env.observation_space
), f"{original_obs}, {wrapped_env.observation_space}"
def test_lambda_observation_wrapper():
"""Tests lambda observation that the function is applied to both the reset and step observation."""
env = GenericTestEnv(
reset_func=_record_action_obs_reset, step_func=_record_action_obs_step
) )
wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION) wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3))
assert np.alltrue(wrapped_obs == obs + observation_shift) obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)})
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32))
_check_obs(env, wrapped_env, obs, info["obs"])
def test_lambda_observation_v0_within_vector(): def test_filter_observation_wrapper():
"""Tests lambda observation in vectorized environments. """Tests ``FilterObservation`` that the right keys are filtered."""
dict_env = GenericTestEnv(
Tests if function is correctly applied to environment's observation observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
in vectorized environment. reset_func=_record_random_obs_reset,
""" step_func=_record_random_obs_step,
env = gym.vector.make(
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
)
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)]))
observation_shift = 1
env.reset(seed=SEED)
wrapped_env = LambdaObservationV0(
env, lambda observation: observation + observation_shift, None
)
wrapped_obs, _, _, _, _ = wrapped_env.step(
np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)])
) )
assert np.alltrue(wrapped_obs == obs + observation_shift) wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))
obs, info = wrapped_env.reset()
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
_check_obs(dict_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
_check_obs(dict_env, wrapped_env, obs, info["obs"])
# Test tuple environments
tuple_env = GenericTestEnv(
observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = FilterObservationV0(tuple_env, (2,))
obs, info = wrapped_env.reset()
assert len(obs) == 1 and len(info["obs"]) == 3
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert len(obs) == 1 and len(info["obs"]) == 3
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
def test_flatten_observation_wrapper():
"""Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly."""
env = GenericTestEnv(
observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
print(env.observation_space)
wrapped_env = FlattenObservationV0(env)
print(wrapped_env.observation_space)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_grayscale_observation_wrapper():
"""Tests the ``GrayscaleObservation`` that the observation is grayscale."""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = GrayscaleObservationV0(env)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
# Keep_dim
wrapped_env = GrayscaleObservationV0(env, keep_dim=True)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25, 1)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_resize_observation_wrapper():
"""Test the ``ResizeObservation`` that the observation has changed size"""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = ResizeObservationV0(env, (25, 25))
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_reshape_observation_wrapper():
"""Test the ``ReshapeObservation`` wrapper."""
env = GenericTestEnv(
observation_space=Box(0, 1, shape=(2, 3, 2)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = ReshapeObservationV0(env, (6, 2))
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)
def test_rescale_observation():
"""Test the ``RescaleObservation`` wrapper"""
env = GenericTestEnv(
observation_space=Box(
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
),
reset_func=_record_action_obs_reset,
step_func=_record_action_obs_step,
)
wrapped_env = RescaleObservationV0(
env,
min_obs=np.array([-5, 0], dtype=np.float32),
max_obs=np.array([5, 1], dtype=np.float32),
)
assert wrapped_env.observation_space == Box(
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
)
for sample_obs, expected_obs in (
(
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5], dtype=np.float32),
),
(
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0], dtype=np.float32),
),
(
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0], dtype=np.float32),
),
):
assert sample_obs in env.observation_space
assert expected_obs in wrapped_env.observation_space
obs, info = wrapped_env.reset(options={"obs": sample_obs})
assert np.all(obs == expected_obs)
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
obs, _, _, _, info = wrapped_env.step(sample_obs)
assert np.all(obs == expected_obs)
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
def test_dtype_observation():
"""Test ``DtypeObservation`` that the"""
env = GenericTestEnv(
reset_func=_record_random_obs_reset, step_func=_record_random_obs_step
)
wrapped_env = DtypeObservationV0(env, dtype=np.uint8)
obs, info = wrapped_env.reset()
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8
obs, _, _, _, info = wrapped_env.step(None)
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8

View File

@@ -55,7 +55,7 @@ def jax_step_func(self, action):
def test_jax_to_numpy(): def test_jax_to_numpy():
jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func) jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)
# Check that the reset and step for jax environment are as expected # Check that the reset and step for jax environment are as expected
obs, info = jax_env.reset() obs, info = jax_env.reset()

View File

@@ -1,52 +0,0 @@
"""Test suite for RescaleActionV0."""
import jax
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import RescaleActionV0
SEED = 42
@pytest.mark.parametrize(
("env", "low", "high", "action", "scaled_action"),
[
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
-0.5,
0.5,
np.array([1, 1, 1, 1]),
np.array([0.5, 0.5, 0.5, 0.5]),
),
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
-0.5,
0.5,
jax.numpy.array([1, 1, 1, 1]),
jax.numpy.array([0.5, 0.5, 0.5, 0.5]),
),
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([-0.5, -0.5, -1, -1], dtype=np.float32),
np.array([0.5, 0.5, 1, 1], dtype=np.float32),
jax.numpy.array([1, 1, 1, 1]),
jax.numpy.array([0.5, 0.5, 1, 1]),
),
],
)
def test_rescale_actions_v0_box(env, low, high, action, scaled_action):
"""Test action rescaling."""
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(action)
env.reset(seed=SEED)
wrapped_env = RescaleActionV0(env, low, high)
obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action)
assert np.alltrue(obs == obs_scaled)

View File

@@ -17,7 +17,9 @@ def step_fn(self, action):
def test_sticky_action(): def test_sticky_action():
env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5) env = StickyActionV0(
GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5
)
env.reset(seed=SEED) env.reset(seed=SEED)
env.action_space.seed(SEED) env.action_space.seed(SEED)
@@ -34,7 +36,7 @@ def test_sticky_action():
previous_action = input_action previous_action = input_action
@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5]) @pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability): def test_sticky_action_raise(repeat_action_probability):
with pytest.raises(InvalidProbability): with pytest.raises(InvalidProbability):
StickyActionV0( StickyActionV0(

View File

@@ -0,0 +1,89 @@
"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation."""
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0
from gymnasium.spaces import Box, Dict, Tuple
from tests.testing_env import GenericTestEnv
NUM_STEPS = 20
SEED = 0
DELAY = 3
def test_time_aware_observation_wrapper():
"""Tests the time aware observation wrapper."""
# Test the environment observation space with Dict, Tuple and other
env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3)))
wrapped_env = TimeAwareObservationV0(env)
assert isinstance(wrapped_env.observation_space, Dict)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}"
env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3))))
wrapped_env = TimeAwareObservationV0(env)
assert isinstance(wrapped_env.observation_space, Tuple)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert len(reset_obs) == 3 and len(step_obs) == 3
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env)
assert isinstance(wrapped_env.observation_space, Dict)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert isinstance(reset_obs, dict) and isinstance(step_obs, dict)
assert "obs" in reset_obs and "obs" in step_obs
assert "time" in reset_obs and "time" in step_obs
# Tests the flatten parameter
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env, flatten=True)
assert isinstance(wrapped_env.observation_space, Box)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert reset_obs.shape == (2,) and step_obs.shape == (2,)
# Tests the normalize_time parameter
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env, normalize_time=False)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert reset_obs["time"] == 100 and step_obs["time"] == 99
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env, normalize_time=True)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01
def test_delay_observation_wrapper():
env = gym.make("CartPole-v1")
env.action_space.seed(SEED)
env.reset(seed=SEED)
undelayed_observations = []
for _ in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
undelayed_observations.append(obs)
env = DelayObservationV0(env, delay=DELAY)
env.action_space.seed(SEED)
env.reset(seed=SEED)
delayed_observations = []
for i in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
delayed_observations.append(obs)
if i < DELAY - 1:
assert np.all(obs == 0)
undelayed_observations = np.array(undelayed_observations)
delayed_observations = np.array(delayed_observations)
assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY])

View File

@@ -1,99 +0,0 @@
"""Test suite for TimeAwareobservationV0."""
from collections import OrderedDict
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import TimeAwareObservationV0
from gymnasium.spaces import Box, Dict
NUM_STEPS = 20
SEED = 0
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True),
],
)
def test_time_aware_observation_creation(env):
"""Test TimeAwareObservationV0 wrapper creation.
This test checks if wrapped env with TimeAwareObservationV0
is correctly created.
"""
wrapped_env = TimeAwareObservationV0(env)
obs, _ = wrapped_env.reset()
assert isinstance(wrapped_env.observation_space, Dict)
assert isinstance(obs, OrderedDict)
assert np.all(obs["time"] == 0)
assert env.observation_space == wrapped_env.observation_space["obs"]
@pytest.mark.parametrize("normalize_time", [True, False])
@pytest.mark.parametrize("flatten", [False, True])
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
],
)
def test_time_aware_observation_step(env, flatten, normalize_time):
"""Test TimeAwareObservationV0 step.
This test checks if wrapped env with TimeAwareObservationV0
steps correctly.
"""
env.action_space.seed(SEED)
max_timesteps = env._max_episode_steps
wrapped_env = TimeAwareObservationV0(
env, flatten=flatten, normalize_time=normalize_time
)
wrapped_env.reset(seed=SEED)
for timestep in range(1, NUM_STEPS):
action = env.action_space.sample()
observation, _, terminated, _, _ = wrapped_env.step(action)
expected_time_obs = (
timestep / max_timesteps if normalize_time else max_timesteps - timestep
)
if flatten:
assert np.allclose(observation[-1], expected_time_obs)
else:
assert np.allclose(observation["time"], expected_time_obs)
if terminated:
break
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True),
],
)
def test_time_aware_observation_creation_flatten(env):
"""Test TimeAwareObservationV0 wrapper creation with `flatten=True`.
This test checks if wrapped env with TimeAwareObservationV0
is correctly created when the `flatten` parameter is set to `True`.
When flattened, the observation space should be a 1 dimension `Box`
with time appended to the end.
"""
wrapped_env = TimeAwareObservationV0(env, flatten=True)
obs, _ = wrapped_env.reset()
assert isinstance(wrapped_env.observation_space, Box)
assert isinstance(obs, np.ndarray)
assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]

View File

@@ -9,6 +9,7 @@ from tests.testing_env import GenericTestEnv
def torch_data_equivalence(data_1, data_2) -> bool: def torch_data_equivalence(data_1, data_2) -> bool:
"""Return if two variables are equivalent that might contain ``torch.Tensor``."""
if type(data_1) == type(data_2): if type(data_1) == type(data_2):
if isinstance(data_1, dict): if isinstance(data_1, dict):
return data_1.keys() == data_2.keys() and all( return data_1.keys() == data_2.keys() and all(
@@ -56,14 +57,15 @@ def torch_data_equivalence(data_1, data_2) -> bool:
) )
def test_roundtripping(value, expected_value): def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.""" """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value) roundtripped_value = jax_to_torch(torch_to_jax(value))
assert torch_data_equivalence(roundtripped_value, expected_value)
def jax_reset_func(self, seed=None, options=None): def _jax_reset_func(self, seed=None, options=None):
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])} return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
def jax_step_func(self, action): def _jax_step_func(self, action):
assert isinstance(action, jnp.DeviceArray), type(action) assert isinstance(action, jnp.DeviceArray), type(action)
return ( return (
jnp.array([1, 2, 3]), jnp.array([1, 2, 3]),
@@ -75,7 +77,7 @@ def jax_step_func(self, action):
def test_jax_to_torch(): def test_jax_to_torch():
env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func) env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)
# Check that the reset and step for jax environment are as expected # Check that the reset and step for jax environment are as expected
obs, info = env.reset() obs, info = env.reset()

View File

@@ -278,7 +278,7 @@ def test_wrapper_types():
obs, _, _, _, _ = observation_env.step(0) obs, _, _, _, _ = observation_env.step(0)
assert obs == np.array([1]) assert obs == np.array([1])
env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {})) env = GenericTestEnv(step_func=lambda self, action: (action, 0, False, False, {}))
action_env = ExampleActionWrapper(env) action_env = ExampleActionWrapper(env)
obs, _, _, _, _ = action_env.step(0) obs, _, _, _, _ = action_env.step(0)
assert obs == np.array([1]) assert obs == np.array([1])

View File

@@ -8,7 +8,7 @@ from gymnasium.core import ActType, ObsType
from gymnasium.envs.registration import EnvSpec from gymnasium.envs.registration import EnvSpec
def basic_reset_fn( def basic_reset_func(
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
@@ -20,17 +20,17 @@ def basic_reset_fn(
return self.observation_space.sample(), {"options": options} return self.observation_space.sample(), {"options": options}
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""A step function that follows the new step api that will pass the environment check using random actions from the observation space.""" """A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, False, {} return self.observation_space.sample(), 0, False, False, {}
def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space.""" """A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, {} return self.observation_space.sample(), 0, False, {}
def basic_render_fn(self): def basic_render_func(self):
"""Basic render fn that does nothing.""" """Basic render fn that does nothing."""
pass pass
@@ -43,12 +43,14 @@ class GenericTestEnv(gym.Env):
self, self,
action_space: spaces.Space = spaces.Box(0, 1, (1,)), action_space: spaces.Space = spaces.Box(0, 1, (1,)),
observation_space: spaces.Space = spaces.Box(0, 1, (1,)), observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
reset_fn: callable = basic_reset_fn, reset_func: callable = basic_reset_func,
step_fn: callable = new_step_fn, step_func: callable = new_step_func,
render_fn: callable = basic_render_fn, render_func: callable = basic_render_func,
metadata: Dict[str, Any] = {"render_modes": []}, metadata: Dict[str, Any] = {"render_modes": []},
render_mode: Optional[str] = None, render_mode: Optional[str] = None,
spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"), spec: EnvSpec = EnvSpec(
"TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100
),
): ):
self.metadata = metadata self.metadata = metadata
self.render_mode = render_mode self.render_mode = render_mode
@@ -59,12 +61,12 @@ class GenericTestEnv(gym.Env):
if action_space is not None: if action_space is not None:
self.action_space = action_space self.action_space = action_space
if reset_fn is not None: if reset_func is not None:
self.reset = types.MethodType(reset_fn, self) self.reset = types.MethodType(reset_func, self)
if step_fn is not None: if step_func is not None:
self.step = types.MethodType(step_fn, self) self.step = types.MethodType(step_func, self)
if render_fn is not None: if render_func is not None:
self.render = types.MethodType(render_fn, self) self.render = types.MethodType(render_func, self)
def reset( def reset(
self, self,

View File

@@ -112,10 +112,10 @@ def test_check_reset_seed(test, func: callable, message: str):
with pytest.warns( with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$" UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
): ):
check_reset_seed(GenericTestEnv(reset_fn=func)) check_reset_seed(GenericTestEnv(reset_func=func))
else: else:
with pytest.raises(test, match=f"^{re.escape(message)}$"): with pytest.raises(test, match=f"^{re.escape(message)}$"):
check_reset_seed(GenericTestEnv(reset_fn=func)) check_reset_seed(GenericTestEnv(reset_func=func))
def _deprecated_return_info( def _deprecated_return_info(
@@ -179,7 +179,7 @@ def test_check_reset_return_type(test, func: callable, message: str):
"""Tests the check `env.reset()` function has a correct return type.""" """Tests the check `env.reset()` function has a correct return type."""
with pytest.raises(test, match=f"^{re.escape(message)}$"): with pytest.raises(test, match=f"^{re.escape(message)}$"):
check_reset_return_type(GenericTestEnv(reset_fn=func)) check_reset_return_type(GenericTestEnv(reset_func=func))
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -198,7 +198,7 @@ def test_check_reset_return_info_deprecation(test, func: callable, message: str)
"""Tests that return_info has been correct deprecated as an argument to `env.reset()`.""" """Tests that return_info has been correct deprecated as an argument to `env.reset()`."""
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"): with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func)) check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))
def test_check_seed_deprecation(): def test_check_seed_deprecation():
@@ -236,7 +236,7 @@ def test_check_reset_options():
"The `reset` method does not provide an `options` or `**kwargs` keyword argument" "The `reset` method does not provide an `options` or `**kwargs` keyword argument"
), ),
): ):
check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {}))) check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -303,11 +303,11 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
with pytest.warns( with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$" UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
): ):
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs) env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
else: else:
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(test, match=f"^{re.escape(message)}$"): with pytest.raises(test, match=f"^{re.escape(message)}$"):
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs) env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
assert len(caught_warnings) == 0 assert len(caught_warnings) == 0
@@ -383,11 +383,11 @@ def test_passive_env_step_checker(
with pytest.warns( with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$" UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
): ):
env_step_passive_checker(GenericTestEnv(step_fn=func), 0) env_step_passive_checker(GenericTestEnv(step_func=func), 0)
else: else:
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(test, match=f"^{re.escape(message)}$"): with pytest.raises(test, match=f"^{re.escape(message)}$"):
env_step_passive_checker(GenericTestEnv(step_fn=func), 0) env_step_passive_checker(GenericTestEnv(step_func=func), 0)
assert len(caught_warnings) == 0, caught_warnings assert len(caught_warnings) == 0, caught_warnings
@@ -416,7 +416,7 @@ def test_passive_env_step_checker(
GenericTestEnv( GenericTestEnv(
metadata={"render_modes": ["Testing mode"], "render_fps": None}, metadata={"render_modes": ["Testing mode"], "render_fps": None},
render_mode="Testing mode", render_mode="Testing mode",
render_fn=lambda self: 0, render_func=lambda self: 0,
), ),
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.", "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
], ],

View File

@@ -21,7 +21,7 @@ IRRELEVANT_KEY = 1
PlayableEnv = partial( PlayableEnv = partial(
GenericTestEnv, GenericTestEnv,
metadata={"render_modes": ["rgb_array"]}, metadata={"render_modes": ["rgb_array"]},
render_fn=lambda self: np.ones((10, 10, 3)), render_func=lambda self: np.ones((10, 10, 3)),
) )

View File

@@ -82,8 +82,8 @@ def test_final_obs_info(vectoriser):
return GenericTestEnv( return GenericTestEnv(
action_space=Discrete(4), action_space=Discrete(4),
observation_space=Discrete(4), observation_space=Discrete(4),
reset_fn=reset_fn, reset_func=reset_fn,
step_fn=lambda self, action: ( step_func=lambda self, action: (
action if action < 3 else 0, action if action < 3 else 0,
0, 0,
action >= 3, action >= 3,

View File

@@ -3,7 +3,7 @@ import pytest
from gymnasium.spaces import Box, Discrete from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility
from tests.testing_env import GenericTestEnv, old_step_fn from tests.testing_env import GenericTestEnv, old_step_func
class AleTesting: class AleTesting:
@@ -34,7 +34,7 @@ class AtariTestingEnv(GenericTestEnv):
low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1 low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1
), ),
action_space=Discrete(3, seed=1), action_space=Discrete(3, seed=1),
step_fn=old_step_fn, step_func=old_step_func,
) )
self.ale = AleTesting() self.ale = AleTesting()

View File

@@ -68,8 +68,8 @@ def _step_failure(self, action):
def test_api_failures(): def test_api_failures():
env = GenericTestEnv( env = GenericTestEnv(
reset_fn=_reset_failure, reset_func=_reset_failure,
step_fn=_step_failure, step_func=_step_failure,
metadata={"render_modes": "error"}, metadata={"render_modes": "error"},
) )
env = PassiveEnvChecker(env) env = PassiveEnvChecker(env)