Update experimental wrappers (#176)

This commit is contained in:
Mark Towers
2022-12-05 19:14:56 +00:00
committed by GitHub
parent 1a381bcd0d
commit 848b7097bf
39 changed files with 1140 additions and 806 deletions

View File

@@ -22,9 +22,9 @@ repos:
hooks:
- id: codespell
args:
- --ignore-words-list=nd,reacher,thist,ths, ure, referenc,wile
- --ignore-words-list=nd,reacher,thist,ths,ure,referenc,wile
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 5.0.4
hooks:
- id: flake8
args:

View File

@@ -27,8 +27,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* In v28, we aim to rewrite the VectorEnv to not inherit from Env; as a result, new vectorised versions of the wrappers will be provided.
### Lambda Observation Wrappers
### Observation Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
@@ -44,61 +43,60 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- VectorLambdaObservation
- No
* - :class:`wrappers.FilterObservation`
- :class:`experimental.wrappers.FilterObservation`
- :class:`experimental.wrappers.FilterObservationV0`
- VectorFilterObservation (*)
- Yes
* - :class:`wrappers.FlattenObservation`
- `:class:`experimental.wrappers.FlattenObservation`
- :class:`experimental.wrappers.FlattenObservationV0`
- VectorFlattenObservation (*)
- No
* - :class:`wrappers.GrayScaleObservation`
- `:class:`experimental.wrappers.GrayscaleObservation`
- :class:`experimental.wrappers.GrayscaleObservationV0`
- VectorGrayscaleObservation (*)
- Yes
* - :class:`wrappers.ResizeObservation`
- :class:`experimental.wrappers.ResizeObservation`
- :class:`experimental.wrappers.ResizeObservationV0`
- VectorResizeObservation (*)
- Yes
* - Not Implemented
- :class:`experimental.wrappers.ReshapeObservation`
- :class:`experimental.wrappers.ReshapeObservationV0`
- VectorReshapeObservation (*)
- Yes
* - Not Implemented
- :class:`experimental.wrappers.RescaleObservation`
- :class:`experimental.wrappers.RescaleObservationV0`
- VectorRescaleObservation (*)
- Yes
* - Not Implemented
- :class:`experimental.wrappers.DtypeObservation`
- :class:`experimental.wrappers.DtypeObservationV0`
- VectorDtypeObservation (*)
- Yes
* - :class:`wrappers.PixelObservationWrapper`
- PixelObservation
- VectorPixelObservation
- No
* - :class:`NormalizeObservation`
* - :class:`wrappers.NormalizeObservation`
- NormalizeObservation
- VectorNormalizeObservation
- No
* - :class:`TimeAwareObservation`
- TimeAwareObservation
* - :class:`wrappers.TimeAwareObservation`
- :class:`experimental.wrappers.TimeAwareObservationV0`
- VectorTimeAwareObservation
- No
* - :class:`FrameStack`
* - :class:`wrappers.FrameStack`
- FrameStackObservation
- VectorFrameStackObservation
- No
* - Not Implemented
- DelayObservation
- :class:`experimental.wrappers.DelayObservationV0`
- VectorDelayObservation
- No
* - :class:`AtariPreprocessing`
* - :class:`wrappers.AtariPreprocessing`
- AtariPreprocessing
- Not Implemented
- No
```
### Lambda Action Wrappers
### Action Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
@@ -114,25 +112,20 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- VectorLambdaAction
- No
* - :class:`wrappers.ClipAction`
- ClipAction
- :class:`experimental.wrappers.ClipActionV0`
- VectorClipAction (*)
- Yes
* - :class:`wrappers.RescaleAction`
- RescaleAction
- :class:`experimental.wrappers.RescaleActionV0`
- VectorRescaleAction (*)
- Yes
* - Not Implemented
- NanAction
- VectorNanAction (*)
- Yes
* - Not Implemented
- StickyAction
- :class:`experimental.wrappers.StickyActionV0`
- VectorStickyAction
- No
```
### Lambda Reward Wrappers
### Reward Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
@@ -175,7 +168,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- VectorPassiveEnvChecker
* - :class:`wrappers.OrderEnforcing`
- OrderEnforcing
- VectorOrderEnforcing (*)
- VectorOrderEnforcing
* - :class:`wrappers.EnvCompatibility`
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
- Not Implemented
@@ -189,10 +182,10 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
- HumanRendering
- Not Implemented
* - Not Implemented
- :class:`experimental.wrappers.JaxToNumpy`
- :class:`experimental.wrappers.JaxToNumpyV0`
- VectorJaxToNumpy (*)
* - Not Implemented
- :class:`experimental.wrappers.JaxToTorch`
- :class:`experimental.wrappers.JaxToTorchV0`
- VectorJaxToTorch (*)
```
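As a quick usage illustration of the table above (not part of this commit): the experimental wrappers are versioned drop-in replacements for the stable ones. The sketch below assumes the class names listed above are importable from `gymnasium.experimental.wrappers` and that `gymnasium[box2d]` is installed for CarRacing.

```python
# Minimal sketch, not part of this commit: swap a stable wrapper for its
# experimental, versioned counterpart from the table above.
import gymnasium as gym
from gymnasium.experimental.wrappers import GrayscaleObservationV0

env = gym.make("CarRacing-v2")                     # RGB observations, Box(0, 255, (96, 96, 3), uint8)
env = GrayscaleObservationV0(env, keep_dim=False)  # grayscale observations, Box(0, 255, (96, 96), uint8)

obs, info = env.reset(seed=42)
print(obs.shape)                                   # expected: (96, 96)
```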

View File

@@ -12,9 +12,6 @@ title: Functional
.. autofunction:: gymnasium.experimental.FuncEnv.initial
.. autofunction:: gymnasium.experimental.FuncEnv.transition
.. autofunction:: gymnasium.experimental.FuncEnv.observation
.. autofunction:: gymnasium.experimental.FuncEnv.initial
.. autofunction:: gymnasium.experimental.FuncEnv.observation
.. autofunction:: gymnasium.experimental.FuncEnv.reward
.. autofunction:: gymnasium.experimental.FuncEnv.terminal
@@ -33,4 +30,8 @@ title: Functional
```{eval-rst}
.. autoclass:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
.. autofunction:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv.reset
.. autofunction:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv.step
.. autofunction:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv.render
```

View File

@@ -1,6 +1,6 @@
# Wrappers
## Lambda Observation Wrappers
## Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
@@ -11,24 +11,6 @@
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
```
## Lambda Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
```
## Lambda Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
```
## Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
```
@@ -36,11 +18,22 @@
## Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0
.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
```
# Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
```
## Common Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
```

View File

@@ -27,7 +27,7 @@ __all__ = [
"register",
"registry",
"pprint_registry",
# root files
# module folders
"envs",
"spaces",
"utils",

View File

@@ -1,2 +1,2 @@
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv

View File

@@ -1,6 +1,7 @@
"""Root __init__ of the gym experimental wrappers."""
from gymnasium.experimental import functional, wrappers
from gymnasium.experimental.functional import FuncEnv
@@ -8,6 +9,8 @@ __all__ = [
# Functional
"FuncEnv",
"functional",
# Wrapper
# Wrappers
"wrappers",
# Vector
# "vector",
]

View File

@@ -10,31 +10,59 @@ from gymnasium.experimental.wrappers.lambda_action import (
ClipActionV0,
RescaleActionV0,
)
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
from gymnasium.experimental.wrappers.lambda_observations import (
LambdaObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
ResizeObservationV0,
ReshapeObservationV0,
RescaleObservationV0,
DtypeObservationV0,
)
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
from gymnasium.experimental.wrappers.time_aware_observation import (
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
from gymnasium.experimental.wrappers.stateful_observation import (
TimeAwareObservationV0,
DelayObservationV0,
)
from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0
__all__ = [
"ArgType",
# Lambda Action
# --- Observation wrappers ---
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
# "PixelObservationV0",
# "NormalizeObservationV0",
"TimeAwareObservationV0",
# "FrameStackV0",
"DelayObservationV0",
# "AtariPreprocessingV0"
# --- Action Wrappers ---
"LambdaActionV0",
"StickyActionV0",
"ClipActionV0",
"RescaleActionV0",
# Lambda Observation
"LambdaObservationV0",
"DelayObservationV0",
"TimeAwareObservationV0",
# Lambda Reward
# "NanAction",
"StickyActionV0",
# --- Reward wrappers ---
"LambdaRewardV0",
"ClipRewardV0",
# Jax conversion wrappers
# "RescaleRewardV0",
# "NormalizeRewardV0",
# --- Common ---
# "AutoReset",
# "PassiveEnvChecker",
# "OrderEnforcing",
# "RecordEpisodeStatistics",
# "RenderCollection",
# "HumanRendering",
"JaxToNumpyV0",
"JaxToTorchV0",
]

View File

@@ -1,35 +0,0 @@
"""Wrapper for delaying the returned observation."""
from collections import deque
import jumpy as jp
import gymnasium as gym
from gymnasium.core import ObsType
class DelayObservationV0(gym.ObservationWrapper):
"""Wrapper which adds a delay to the returned observation."""
def __init__(self, env: gym.Env, delay: int):
"""Initialize the DelayObservation wrapper.
Args:
env (Env): the wrapped environment
delay (int): number of timesteps for delaying the observation.
Before reaching the `delay` number of timesteps,
returned observation is an array of zeros with the
same shape of the observation space.
"""
super().__init__(env)
self.delay = delay
self.observation_queue = deque()
def observation(self, observation: ObsType) -> ObsType:
"""Return the delayed observation."""
self.observation_queue.append(observation)
if len(self.observation_queue) > self.delay:
return self.observation_queue.popleft()
return jp.zeros_like(observation)

View File

@@ -1,13 +1,19 @@
"""Lambda action wrapper which apply a function to the provided action."""
from typing import Any, Callable, Union
"""A collection of wrappers that all use the LambdaAction class.
* ``LambdaAction`` - Transforms the actions based on a function
* ``ClipAction`` - Clips the action within the action space bounds
* ``RescaleAction`` - Rescales the action between a minimum and maximum value
"""
from __future__ import annotations
from typing import Callable
import jumpy as jp
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ActType
from gymnasium.experimental.wrappers import ArgType
from gymnasium.core import ActType, WrapperActType
from gymnasium.spaces import Box, Space
class LambdaActionV0(gym.ActionWrapper):
@@ -16,19 +22,23 @@ class LambdaActionV0(gym.ActionWrapper):
def __init__(
self,
env: gym.Env,
func: Callable[[ArgType], Any],
func: Callable[[WrapperActType], ActType],
action_space: Space | None,
):
"""Initialize LambdaAction.
Args:
env (Env): The gymnasium environment
func (Callable): function to apply to action
env: The gymnasium environment
func: Function to apply to ``step`` ``action``
action_space: The updated action space of the wrapper given the function.
"""
super().__init__(env)
if action_space is not None:
self.action_space = action_space
self.func = func
def action(self, action: ActType) -> Any:
def action(self, action: WrapperActType) -> ActType:
"""Apply function to action."""
return self.func(action)
@@ -53,14 +63,19 @@ class ClipActionV0(LambdaActionV0):
Args:
env: The environment to apply the wrapper
"""
assert isinstance(env.action_space, spaces.Box)
assert isinstance(env.action_space, Box)
super().__init__(
env,
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
Box(
-np.inf,
np.inf,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
)
self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape)
class RescaleActionV0(LambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
@@ -86,8 +101,8 @@ class RescaleActionV0(LambdaActionV0):
def __init__(
self,
env: gym.Env,
min_action: Union[float, int, np.ndarray],
max_action: Union[float, int, np.ndarray],
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Initializes the :class:`RescaleAction` wrapper.
@@ -96,28 +111,44 @@ class RescaleActionV0(LambdaActionV0):
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
assert isinstance(
env.action_space, spaces.Box
), f"expected Box action space, got {type(env.action_space)}"
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
low = env.action_space.low
high = env.action_space.high
self.min_action = np.full(
env.action_space.shape, min_action, dtype=env.action_space.dtype
assert isinstance(env.action_space, Box)
assert not np.any(env.action_space.low == np.inf) and not np.any(
env.action_space.high == np.inf
)
self.max_action = np.full(
env.action_space.shape, max_action, dtype=env.action_space.dtype
if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)
assert min_action.shape == env.action_space.shape
assert not np.any(min_action == np.inf)
if not isinstance(max_action, np.ndarray):
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
max_action = np.full(env.action_space.shape, max_action)
assert max_action.shape == env.action_space.shape
assert not np.any(max_action == np.inf)
assert isinstance(env.action_space, Box)
assert np.all(np.less_equal(min_action, max_action))
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (env.action_space.high - env.action_space.low) / (
max_action - min_action
)
intercept = gradient * -min_action + env.action_space.low
super().__init__(
env,
lambda action: jp.clip(
low
+ (high - low)
* ((action - self.min_action) / (self.max_action - self.min_action)),
low,
high,
lambda action: gradient * action + intercept,
Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
)
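The rewritten `RescaleActionV0` above replaces the old clip-based formula with a per-dimension affine map, `gradient * action + intercept`, whose gradient and intercept are computed from the old and new bounds. A small standalone check of that arithmetic (the bounds below are illustrative and match the expectations in the test suite later in this commit):

```python
# Standalone check of the affine mapping used by RescaleActionV0 above:
# the new bounds map exactly onto the original action-space bounds.
import numpy as np

old_low, old_high = np.array([0.0, 1.0]), np.array([1.0, 3.0])        # env.action_space bounds
min_action, max_action = np.array([-5.0, 0.0]), np.array([5.0, 1.0])  # bounds exposed by the wrapper

gradient = (old_high - old_low) / (max_action - min_action)
intercept = gradient * -min_action + old_low

assert np.allclose(gradient * min_action + intercept, old_low)        # new low  -> old low
assert np.allclose(gradient * max_action + intercept, old_high)       # new high -> old high
print(gradient * np.array([0.0, 0.5]) + intercept)                    # expected: [0.5 2. ]
```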

View File

@@ -1,17 +1,27 @@
"""Lambda observation wrappers which apply a function to the observation."""
"""A collection of observation wrappers using a lambda function.
* ``LambdaObservation`` - Transforms the observation with a function
* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
* ``FlattenObservation`` - Flattens the observations
* ``GrayscaleObservation`` - Converts an RGB observation to a grayscale observation
* ``ResizeObservation`` - Resizes an array-based observation (normally an RGB observation)
* ``ReshapeObservation`` - Reshapes an array-based observation
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservation`` - Converts an observation's dtype
"""
from __future__ import annotations
from typing import Any, Callable, Sequence
from typing_extensions import Final
import jumpy as jp
import numpy as np
import numpy.typing as npt
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import utils
from gymnasium.spaces import Box, utils
class LambdaObservationV0(gym.ObservationWrapper):
@@ -71,32 +81,82 @@ class FilterObservationV0(LambdaObservationV0):
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {})
"""
def __init__(self, env: gym.Env, filter_keys: Sequence[str]):
def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
if not isinstance(env.observation_space, spaces.Dict):
assert isinstance(filter_keys, Sequence)
# Filters for dictionary space
if isinstance(env.observation_space, spaces.Dict):
assert all(isinstance(key, str) for key in filter_keys)
if any(
key not in env.observation_space.spaces.keys() for key in filter_keys
):
missing_keys = [
key
for key in filter_keys
if key not in env.observation_space.spaces.keys()
]
raise ValueError(
"All the `filter_keys` must be included in the observation space.\n"
f"Filter keys: {filter_keys}\n"
f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
f"Missing keys: {missing_keys}"
)
new_observation_space = spaces.Dict(
{key: env.observation_space[key] for key in filter_keys}
)
if len(new_observation_space) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
super().__init__(
env,
lambda obs: {key: obs[key] for key in filter_keys},
new_observation_space,
)
# Filter for tuple observation
elif isinstance(env.observation_space, spaces.Tuple):
assert all(isinstance(key, int) for key in filter_keys)
assert len(set(filter_keys)) == len(
filter_keys
), f"Duplicate keys exist, filter_keys: {filter_keys}"
if any(
0 < key and key >= len(env.observation_space) for key in filter_keys
):
missing_index = [
key
for key in filter_keys
if 0 < key and key >= len(env.observation_space)
]
raise ValueError(
"All the `filter_keys` must be included in the length of the observation space.\n"
f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
f"missing indexes: {missing_index}"
)
new_observation_spaces = spaces.Tuple(
env.observation_space[key] for key in filter_keys
)
if len(new_observation_spaces) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
super().__init__(
env,
lambda obs: tuple(obs[key] for key in filter_keys),
new_observation_spaces,
)
else:
raise ValueError(
f"FilterObservation wrapper is only usable with dict observations, actual type: {type(env.observation_space)}"
f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
)
if any(key not in env.observation_space.keys() for key in filter_keys):
missing_keys = [
key for key in filter_keys if key not in env.observation_space.keys()
]
raise ValueError(
"All the filter_keys must be included in the original observation space.\n"
f"Filter keys: {filter_keys}\n"
f"Observation keys: {list(env.observation_space.keys())}\n"
f"Missing keys: {missing_keys}"
)
new_observation_space = spaces.Dict(
{key: env.observation_space[key] for key in filter_keys}
)
super().__init__(
env,
lambda obs: {key: obs[key] for key in filter_keys},
new_observation_space,
)
self.filter_keys: Final[Sequence[str | int]] = filter_keys
class FlattenObservationV0(LambdaObservationV0):
@@ -117,9 +177,10 @@ class FlattenObservationV0(LambdaObservationV0):
def __init__(self, env: gym.Env):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
flattened_space = utils.flatten_space(env.observation_space)
super().__init__(
env, lambda obs: utils.flatten(flattened_space, obs), flattened_space
env,
lambda obs: utils.flatten(env.observation_space, obs),
utils.flatten_space(env.observation_space),
)
@@ -154,7 +215,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
and env.observation_space.dtype == np.uint8
)
self.keep_dim = keep_dim
self.keep_dim: Final[bool] = keep_dim
if keep_dim:
new_observation_space = spaces.Box(
low=0,
@@ -167,7 +228,8 @@ class GrayscaleObservationV0(LambdaObservationV0):
lambda obs: jp.expand_dims(
jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
)
).astype(np.uint8),
axis=-1,
),
new_observation_space,
)
@@ -179,7 +241,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
env,
lambda obs: jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
),
).astype(np.uint8),
new_observation_space,
)
@@ -215,7 +277,7 @@ class ResizeObservationV0(LambdaObservationV0):
"opencv is not install, run `pip install gymnasium[other]`"
)
self.shape = tuple(shape)
self.shape: Final[tuple[int, ...]] = tuple(shape)
new_observation_space = spaces.Box(
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
@@ -237,7 +299,7 @@ class ReshapeObservationV0(LambdaObservationV0):
assert isinstance(shape, tuple)
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
assert all(x > 0 for x in shape)
assert all(x > 0 or x == -1 for x in shape)
new_observation_space = spaces.Box(
low=np.reshape(np.ravel(env.observation_space.low), shape),
@@ -245,9 +307,8 @@ class ReshapeObservationV0(LambdaObservationV0):
shape=shape,
dtype=env.observation_space.dtype,
)
super().__init__(
env, lambda obs: jp.reshape(obs, self.shape), new_observation_space
)
self.shape = shape
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
class RescaleObservationV0(LambdaObservationV0):
@@ -256,18 +317,23 @@ class RescaleObservationV0(LambdaObservationV0):
def __init__(
self,
env: gym.Env,
min_obs: tuple[np.floating, np.integer, np.ndarray],
max_obs: tuple[np.floating, np.integer, np.ndarray],
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`."""
assert isinstance(env.observation_space, spaces.Box)
assert not np.any(env.observation_space.low == np.inf) and not np.any(
env.observation_space.high == np.inf
)
if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
min_obs = np.full(env.observation_space.shape, min_obs)
assert min_obs.shape == env.observation_space.shape
assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
assert not np.any(min_obs == np.inf)
if not isinstance(max_obs, np.ndarray):
@@ -278,52 +344,66 @@ class RescaleObservationV0(LambdaObservationV0):
assert max_obs.shape == env.observation_space.shape
assert not np.any(max_obs == np.inf)
env_low = env.observation_space.low
env_high = env.observation_space.high
self.min_obs = min_obs
self.max_obs = max_obs
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (max_obs - min_obs) / (
env.observation_space.high - env.observation_space.low
)
intercept = gradient * -env.observation_space.low + min_obs
new_observation_space = spaces.Box(low=min_obs, high=max_obs)
super().__init__(
env,
lambda obs: env_low
+ (env_high - env_low) * ((obs - min_obs) / (max_obs - min_obs)),
new_observation_space,
lambda obs: gradient * obs + intercept,
Box(
low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
),
)
class DtypeObservationV0(LambdaObservationV0):
"""Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: gym.Env, dtype: npt.DTypeLike):
def __init__(self, env: gym.Env, dtype: Any):
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
assert isinstance(
env.observation_space,
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
)
dtype = np.dtype(dtype)
self.dtype = dtype
if isinstance(env.observation_space, spaces.Box):
new_observation_space = spaces.Box(
low=env.observation_space.low,
high=env.observation_space.high,
shape=env.observation_space.shape,
dtype=dtype.__name__,
dtype=self.dtype,
)
elif isinstance(env.observation_space, spaces.Discrete):
new_observation_space = spaces.Box(
low=env.observation_space.start,
high=env.observation_space.start + env.observation_space.n,
shape=(),
dtype=dtype.__name__,
dtype=self.dtype,
)
elif isinstance(env.observation_space, spaces.MultiDiscrete):
new_observation_space = spaces.MultiDiscrete(
env.observation_space.nvec, dtype=dtype.__name__
env.observation_space.nvec, dtype=dtype
)
elif isinstance(env.observation_space, spaces.MultiBinary):
new_observation_space = spaces.Box(
low=0, high=1, shape=env.observation_space.shape, dtype=dtype.__name__
low=0,
high=1,
shape=env.observation_space.shape,
dtype=self.dtype,
)
else:
raise TypeError
raise TypeError(
"DtypeObservation is only compatible with value / array-based observations."
)
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
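`RescaleObservationV0` above uses the same gradient/intercept construction as the action wrapper. A minimal usage sketch (not part of this commit), assuming `Pendulum-v1` with its `Box([-1, -1, -8], [1, 1, 8], (3,), float32)` observation space:

```python
# Minimal sketch, not part of this commit: rescale Pendulum's observation into [-1, 1].
import gymnasium as gym
from gymnasium.experimental.wrappers import RescaleObservationV0

env = gym.make("Pendulum-v1")
env = RescaleObservationV0(env, min_obs=-1.0, max_obs=1.0)
print(env.observation_space)      # expected: Box(-1.0, 1.0, (3,), float32)

obs, info = env.reset(seed=0)
print(obs)                        # cos/sin components unchanged, angular velocity divided by 8
```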

View File

@@ -1,12 +1,17 @@
"""Lambda reward wrappers which apply a function to the reward."""
"""A collection of wrappers for modifying the reward.
from typing import Any, Callable, Optional, Union
* ``LambdaReward`` - Transforms the reward by a function
* ``ClipReward`` - Clips the reward between a minimum and maximum value
"""
from __future__ import annotations
from typing import Callable, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers import ArgType
class LambdaRewardV0(gym.RewardWrapper):
@@ -26,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper):
def __init__(
self,
env: gym.Env,
func: Callable[[ArgType], Any],
func: Callable[[SupportsFloat], SupportsFloat],
):
"""Initialize LambdaRewardV0 wrapper.
@@ -38,7 +43,7 @@ class LambdaRewardV0(gym.RewardWrapper):
self.func = func
def reward(self, reward: Union[float, int, np.ndarray]) -> Any:
def reward(self, reward: SupportsFloat) -> SupportsFloat:
"""Apply function to reward.
Args:
@@ -64,8 +69,8 @@ class ClipRewardV0(LambdaRewardV0):
def __init__(
self,
env: gym.Env,
min_reward: Optional[Union[float, np.ndarray]] = None,
max_reward: Optional[Union[float, np.ndarray]] = None,
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Initialize ClipRewardsV0 wrapper.

View File

@@ -6,70 +6,90 @@ import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat
import jax.numpy as jnp
import numpy as np
from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
try:
import jax.numpy as jnp
except ImportError:
# The error is handled internally by the relevant functions
jnp = None
@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
"""Converts a value to a Jax DeviceArray."""
raise Exception(
f"No conversion for Numpy to Jax registered for type: {type(value)}"
)
if jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
)
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
return jnp.array(value)
if jnp is not None:
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(
value: np.ndarray | numbers.Number,
) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
assert jnp is not None
return jnp.array(value)
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
return type(value)(numpy_to_jax(v) for v in value)
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
return type(value)(numpy_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
"""Converts a value to a numpy array."""
raise Exception(
f"No conversion for Jax to Numpy registered for type: {type(value)}"
)
if jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
)
@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
"""Converts a Jax DeviceArray to a numpy array."""
return np.array(value)
if jnp is not None:
@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
"""Converts a Jax DeviceArray to a numpy array."""
return np.array(value)
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jnp.DeviceArray | Any]
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jnp.DeviceArray | Any]
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
return type(value)(jax_to_numpy(v) for v in value)
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(Wrapper):
@@ -88,6 +108,10 @@ class JaxToNumpyV0(Wrapper):
Args:
env: the environment to wrap
"""
if jnp is None:
raise DependencyNotInstalled(
"Jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env)
def step(
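The restructuring above guards every converter behind the optional `jax` import, so `import gymnasium` keeps working when jax is missing and a `DependencyNotInstalled` error is only raised when a conversion is actually attempted. A condensed, hypothetical `to_jax` helper (not gymnasium API) illustrating the same lazy-import plus conditional `singledispatch` registration pattern:

```python
# Condensed illustration of the pattern used above; `to_jax` is a hypothetical helper,
# not part of gymnasium. The module imports fine without jax, and the type-specific
# converters are only registered when the import succeeds.
import functools
from typing import Any

import numpy as np

try:
    import jax.numpy as jnp
except ImportError:
    jnp = None  # the error is raised inside the converter instead


@functools.singledispatch
def to_jax(value: Any) -> Any:
    if jnp is None:
        raise ImportError("jax is not installed, run `pip install gymnasium[jax]`")
    raise Exception(f"No known conversion to jax for type: {type(value)}")


if jnp is not None:

    @to_jax.register(np.ndarray)
    def _ndarray_to_jax(value: np.ndarray):
        return jnp.array(value)
```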

View File

@@ -0,0 +1,56 @@
"""A collection of stateful action wrappers.
* StickyAction - There is a probability that the action is taken again
"""
from __future__ import annotations
from typing import Any, SupportsFloat
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.error import InvalidProbability
class StickyActionV0(gym.Wrapper):
"""Wrapper which adds a probability of repeating the previous action.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
in Section 5.2 on page 12.
"""
def __init__(self, env: gym.Env, repeat_action_probability: float):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a probability of repeating the old action.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)
super().__init__(env)
self.repeat_action_probability = repeat_action_probability
self.last_action: WrapperActType | None = None
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the environment."""
self.last_action = None
return super().reset(seed=seed, options=options)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Execute the action."""
if (
self.last_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
action = self.last_action
self.last_action = action
return super().step(action)
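A minimal usage sketch (not part of this commit) for the sticky-action wrapper above:

```python
# Minimal sketch, not part of this commit: with repeat_action_probability=0.25,
# roughly a quarter of the sampled actions are replaced by the previously
# executed action, following Machado et al. (2018).
import gymnasium as gym
from gymnasium.experimental.wrappers import StickyActionV0

env = StickyActionV0(gym.make("CartPole-v1"), repeat_action_probability=0.25)
env.reset(seed=0)
for _ in range(10):
    env.step(env.action_space.sample())
```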

View File

@@ -0,0 +1,200 @@
"""A collection of stateful observation wrappers.
* DelayObservation - A wrapper for delaying the returned observation
* TimeAwareObservation - A wrapper for adding time aware observations to environment observation
"""
from __future__ import annotations
from collections import deque
from typing import Any, SupportsFloat
from typing_extensions import Final
import jumpy as jp
import numpy as np
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
class DelayObservationV0(gym.ObservationWrapper):
"""Wrapper which adds a delay to the returned observation."""
def __init__(self, env: gym.Env, delay: int):
"""Initialize the DelayObservation wrapper.
Args:
env (Env): the wrapped environment
delay (int): number of timesteps for delaying the observation.
Before reaching the `delay` number of timesteps,
returned observation is an array of zeros with the
same shape of the observation space.
"""
assert isinstance(env.observation_space, (Box, MultiBinary, MultiDiscrete))
assert 0 < delay
self.delay: Final[int] = delay
self.observation_queue: Final[deque] = deque()
super().__init__(env)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment, clearing the observation queue."""
self.observation_queue.clear()
return super().reset(seed=seed, options=options)
def observation(self, observation: ObsType) -> ObsType:
"""Return the delayed observation."""
self.observation_queue.append(observation)
if len(self.observation_queue) > self.delay:
return self.observation_queue.popleft()
return jp.zeros_like(observation)
class TimeAwareObservationV0(gym.ObservationWrapper):
"""Augment the observation with time information of the episode.
Time can be represented as a normalized value between [0,1]
or by the number of timesteps remaining before truncation occurs.
For environments with a ``Dict`` observation space, the time information is
automatically added under the (by default) ``"time"`` key; for ``Tuple``
observation spaces, it is appended as the final element.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env)
>>> env.observation_space
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
OrderedDict([('obs',
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
... ('time', array([0.002]))])
Flatten observation space example:
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
"""
def __init__(
self,
env: gym.Env,
flatten: bool = False,
normalize_time: bool = True,
*,
dict_time_key: str = "time",
):
"""Initialize :class:`TimeAwareObservationV0`.
Args:
env: The environment to apply the wrapper
flatten: Flatten the observation to a `Box` of a single dimension
normalize_time: if `True` return time in the range [0,1]
otherwise return time as remaining timesteps before truncation
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
"""
super().__init__(env)
self.flatten: Final[bool] = flatten
self.normalize_time: Final[bool] = normalize_time
if hasattr(env, "_max_episode_steps"):
self.max_timesteps = getattr(env, "_max_episode_steps")
elif env.spec is not None and env.spec.max_episode_steps is not None:
self.max_timesteps = env.spec.max_episode_steps
else:
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)
self.timesteps: int = 0
# Find the normalized time space
if self.normalize_time:
self._time_preprocess_func = lambda time: time / self.max_timesteps
time_space = Box(0.0, 1.0)
else:
self._time_preprocess_func = lambda time: self.max_timesteps - time
time_space = Box(0, self.max_timesteps, dtype=np.int32)
# Find the observation space
if isinstance(env.observation_space, Dict):
assert dict_time_key not in env.observation_space.keys()
observation_space = Dict(
{dict_time_key: time_space}, **env.observation_space.spaces
)
self._append_data_func = lambda obs, time: {**obs, dict_time_key: time}
elif isinstance(env.observation_space, Tuple):
observation_space = Tuple(env.observation_space.spaces + (time_space,))
self._append_data_func = lambda obs, time: obs + (time,)
else:
observation_space = Dict(obs=env.observation_space, time=time_space)
self._append_data_func = lambda obs, time: {"obs": obs, "time": time}
# If to flatten the observation space
if self.flatten:
self.observation_space = spaces.flatten_space(observation_space)
self._obs_postprocess_func = lambda obs: spaces.flatten(
observation_space, obs
)
else:
self.observation_space = observation_space
self._obs_postprocess_func = lambda obs: obs
def observation(self, observation: ObsType) -> WrapperObsType:
"""Adds to the observation with the current time information.
Args:
observation: The observation to add the time step to
Returns:
The observation with the time information appended to
"""
return self._obs_postprocess_func(
self._append_data_func(
observation, self._time_preprocess_func(self.timesteps)
)
)
def step(
self, action: ActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, incrementing the time step.
Args:
action: The action to take
Returns:
The environment's step using the action.
"""
self.timesteps += 1
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the environment setting the time to zero.
Args:
seed: The seed to reset the environment
options: The options used to reset the environment
Returns:
The reset environment
"""
self.timesteps = 0
return super().reset(seed=seed, options=options)
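`DelayObservationV0` above has no docstring example; a minimal sketch (not part of this commit) of its behaviour on CartPole:

```python
# Minimal sketch, not part of this commit: the first `delay` observations are zeros,
# afterwards the wrapper returns the observation produced `delay` steps earlier.
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0

env = DelayObservationV0(gym.make("CartPole-v1"), delay=3)
obs, info = env.reset(seed=42)
assert np.all(obs == 0)                  # the internal queue is not yet filled

for _ in range(3):
    obs, *_ = env.step(env.action_space.sample())
print(obs)                               # the observation that reset produced, delayed by 3 steps
```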

View File

@@ -1,40 +0,0 @@
"""Wrapper which adds a probability of repeating the previous executed action."""
from typing import Union
import gymnasium as gym
from gymnasium.core import ActType
from gymnasium.error import InvalidProbability
class StickyActionV0(gym.ActionWrapper):
"""Wrapper which adds a probability of repeating the previous action."""
def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a proability of repeating the old action.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)
super().__init__(env)
self.repeat_action_probability = repeat_action_probability
self.old_action = None
def action(self, action: ActType):
"""Execute the action."""
if (
self.old_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
action = self.old_action
self.old_action = action
return action
def reset(self, **kwargs):
"""Reset the environment."""
self.old_action = None
return super().reset(**kwargs)

View File

@@ -1,113 +0,0 @@
"""Wrapper for adding time aware observations to environment observation."""
from collections import OrderedDict
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium.core import ActType, ObsType
from gymnasium.spaces import Box, Dict
class TimeAwareObservationV0(gym.ObservationWrapper):
"""Augment the observation with time information of the episode.
Time can be represented as a normalized value between [0,1]
or by the number of timesteps remaining before truncation occurs.
Example:
>>> import gym
>>> from gym.wrappers import TimeAwareObservationV0
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env)
>>> env.observation_space
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
OrderedDict([('obs',
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
... ('time', array([0.002]))])
Flatten observation space example:
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
"""
def __init__(self, env: gym.Env, flatten=False, normalize_time=True):
"""Initialize :class:`TimeAwareObservationV0`.
Args:
env: The environment to apply the wrapper
flatten: Flatten the observation to a `Box` of a single dimension
normalize_time: if `True` return time in the range [0,1]
otherwise return time as remaining timesteps before truncation
"""
super().__init__(env)
self.flatten = flatten
self.normalize_time = normalize_time
self.max_timesteps = getattr(env, "_max_episode_steps")
if self.normalize_time:
self._get_time_observation = lambda: self.timesteps / self.max_timesteps
time_space = Box(0, 1)
else:
self._get_time_observation = lambda: self.max_timesteps - self.timesteps
time_space = Box(0, self.max_timesteps)
self.time_aware_observation_space = Dict(
obs=env.observation_space, time=time_space
)
if self.flatten:
self.observation_space = spaces.flatten_space(
self.time_aware_observation_space
)
self._observation_postprocess = lambda observation: spaces.flatten(
self.time_aware_observation_space, observation
)
else:
self.observation_space = self.time_aware_observation_space
self._observation_postprocess = lambda observation: observation
def observation(self, observation: ObsType):
"""Adds to the observation with the current time information.
Args:
observation: The observation to add the time step to
Returns:
The observation with the time information appended to
"""
time_observation = self._get_time_observation()
observation = OrderedDict(obs=observation, time=time_observation)
return self._observation_postprocess(observation)
def step(self, action: ActType):
"""Steps through the environment, incrementing the time step.
Args:
action: The action to take
Returns:
The environment's step using the action.
"""
self.timesteps += 1
observation, reward, terminated, truncated, info = super().step(action)
return observation, reward, terminated, truncated, info
def reset(self, **kwargs):
"""Reset the environment setting the time to zero.
Args:
**kwargs: Kwargs to apply to env.reset()
Returns:
The reset environment
"""
self.timesteps = 0
return super().reset(**kwargs)

View File

@@ -14,92 +14,121 @@ import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
try:
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
except ImportError:
jnp, jax_dlpack = None, None
try:
import torch
from torch.utils import dlpack as torch_dlpack
Device = Union[str, torch.device]
except ImportError:
raise DependencyNotInstalled("torch is not installed, run `pip install torch`")
Device = Union[str, torch.device]
torch, torch_dlpack, Device = None, None, None
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
raise Exception(
f"No conversion for PyTorch to Jax registered for type: {type(value)}"
)
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
)
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
return jnp.array(value)
if torch is not None and jnp is not None:
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a jax array."""
assert jnp is not None
return jnp.array(value)
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
return tensor
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
assert torch_dlpack is not None and jax_dlpack is not None
tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
value
)
tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage]
tensor
)
return tensor
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_jax(v) for v in value)
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
raise Exception(
f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}"
)
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
)
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
tensor = torch_dlpack.from_dlpack(dlpack)
if device:
return tensor.to(device=device)
return tensor
if torch is not None and jnp is not None:
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
assert jax_dlpack is not None and torch_dlpack is not None
dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
value
)
tensor = torch_dlpack.from_dlpack(dlpack)
if device:
return tensor.to(device=device)
return tensor
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
return type(value)(jax_to_torch(v, device) for v in value)
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
return type(value)(jax_to_torch(v, device) for v in value)
class JaxToTorchV0(Wrapper):
@@ -107,7 +136,8 @@ class JaxToTorchV0(Wrapper):
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
Note:
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
"""
def __init__(self, env: Env, device: Device | None = None):
@@ -117,6 +147,15 @@ class JaxToTorchV0(Wrapper):
env: The Jax-based environment to wrap
device: The device the torch Tensors should be moved to
"""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env)
self.device: Device | None = device
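The converters above move data between frameworks through DLPack, typically without copying the underlying buffer. A standalone round-trip sketch (requires both `torch` and `jax`, using the same `dlpack` helpers imported at the top of the module):

```python
# Standalone sketch, not part of this commit: torch tensor -> jax array -> torch
# tensor via DLPack, the mechanism the converters above rely on.
import torch
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

t = torch.arange(4, dtype=torch.float32)
j = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(t))     # torch -> jax
t2 = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(j))    # jax -> torch
print(t2)                                                 # expected: tensor([0., 1., 2., 3.])
```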

View File

@@ -76,6 +76,7 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
* ``None`` The length will be randomly drawn from a geometric distribution
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
* ``int`` for a fixed length sample
The second element of the mask tuple `sample` mask specifies a mask that is applied when
sampling elements from the base space. The mask is applied for each feature space sample.

View File

@@ -29,7 +29,7 @@ dependencies = [
"jax-jumpy >=0.2.0",
"cloudpickle >=1.2.0",
"importlib-metadata >=4.8.0; python_version < '3.10'",
"typing-extensions >=4.3.0; python_version == '3.7'",
"typing-extensions >=4.3.0",
"gymnasium-notices >=0.0.1",
"shimmy >=0.1.0,<1.0",
]

View File

@@ -19,7 +19,7 @@ from gymnasium.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_env_specs
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
from tests.testing_env import GenericTestEnv, old_step_fn
from tests.testing_env import GenericTestEnv, old_step_func
from tests.wrappers.utils import has_wrapper
@@ -155,7 +155,7 @@ def test_make_disable_env_checker():
def test_apply_api_compatibility():
gym.register(
"testing-old-env",
lambda: GenericTestEnv(step_fn=old_step_fn),
lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=True,
max_episode_steps=3,
)

View File

@@ -1,48 +0,0 @@
"""Test suite for LambdaActionV0."""
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import ClipActionV0
SEED = 42
@pytest.mark.parametrize(
("env", "action_unclipped_env", "action_clipped_env"),
(
[
# MountainCar action space: Box(-1.0, 1.0, (1,), float32)
gym.make("MountainCarContinuous-v0"),
np.array([1]),
np.array([1.5]),
],
[
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([1, 1, 1, 1]),
np.array([10, 10, 10, 10]),
],
[
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([0.5, 0.5, 1, 1]),
np.array([0.5, 0.5, 10, 10]),
],
),
)
def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env):
"""Tests if actions out of bound are correctly clipped.
Tests whether out of bound actions for the wrapped
environments are correctly clipped.
"""
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(action_unclipped_env)
env.reset(seed=SEED)
wrapped_env = ClipActionV0(env)
wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env)
assert np.alltrue(obs == wrapped_obs)

View File

@@ -1,37 +0,0 @@
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0
SEED = 42
DELAY = 3
NUM_STEPS = 5
def test_delay_observation():
env = gym.make("CartPole-v1")
env.action_space.seed(SEED)
env.reset(seed=SEED)
undelayed_observations = []
for _ in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
undelayed_observations.append(obs)
env.action_space.seed(SEED)
env.reset(seed=SEED)
env = DelayObservationV0(env, delay=DELAY)
delayed_observations = []
for i in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
if i < DELAY - 1:
assert np.all(obs == 0)
delayed_observations.append(obs)
assert np.alltrue(
np.array(delayed_observations[DELAY:])
== np.array(undelayed_observations[: DELAY - 1])
)

View File

@@ -1,60 +1,78 @@
"""Test suite for LambdaActionV0."""
"""Test suit for lambda action wrappers: LambdaAction, ClipAction, RescaleAction."""
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.error import InvalidAction
from gymnasium.experimental.wrappers import LambdaActionV0
from gymnasium.experimental.wrappers import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
)
from gymnasium.spaces import Box
from tests.testing_env import GenericTestEnv
NUM_ENVS = 3
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
SEED = 42
def generic_step_fn(self, action):
def _record_action_step_func(self, action):
return 0, 0, False, False, {"action": action}
@pytest.mark.parametrize(
("env", "func", "action", "expected"),
[
(
GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn),
lambda action: action + 2,
1,
3,
),
],
)
def test_lambda_action_v0(env, func, action, expected):
"""Tests lambda action.
Tests if function is correctly applied to environment's action.
"""
wrapped_env = LambdaActionV0(env, func)
_, _, _, _, info = wrapped_env.step(action)
executed_action = info["action"]
def test_lambda_action_wrapper():
"""Tests LambdaAction through checking that the action taken is transformed by function."""
env = GenericTestEnv(step_func=_record_action_step_func)
wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
assert executed_action == expected
sampled_action = wrapped_env.action_space.sample()
assert sampled_action not in env.action_space
_, _, _, _, info = wrapped_env.step(sampled_action)
assert info["action"] in env.action_space
assert sampled_action - 2 == info["action"]
def test_lambda_action_v0_within_vector():
"""Tests lambda action in vectorized environments.
Tests if function is correctly applied to environment's action
in vectorized environment.
"""
env = gym.vector.make(
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
def test_clip_action_wrapper():
"""Test that the action is correctly clipped to the base environment action space."""
env = GenericTestEnv(
action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
step_func=_record_action_step_func,
)
action = np.ones(NUM_ENVS, dtype=np.float64)
wrapped_env = ClipActionV0(env)
wrapped_env = LambdaActionV0(env, lambda action: action.astype(int))
wrapped_env.reset()
sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
assert sampled_action not in env.action_space
assert sampled_action in wrapped_env.action_space
wrapped_env.step(action)
_, _, _, _, info = wrapped_env.step(sampled_action)
assert np.all(info["action"] in env.action_space)
assert np.all(info["action"] == np.array([0, 2, 3.5]))
# unwrapped env should raise exception because it does not
# support float actions
with pytest.raises(InvalidAction):
env.step(action)
def test_rescale_action_wrapper():
"""Test that the action is rescale within a min / max bound."""
env = GenericTestEnv(
step_func=_record_action_step_func,
action_space=Box(np.array([0, 1]), np.array([1, 3])),
)
wrapped_env = RescaleActionV0(
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
)
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
for sample_action, expected_action in (
(
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0], dtype=np.float32),
),
(
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0], dtype=np.float32),
),
(
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0], dtype=np.float32),
),
):
assert sample_action in wrapped_env.action_space
_, _, _, _, info = wrapped_env.step(sample_action)
assert np.all(info["action"] == expected_action)

View File

@@ -1,59 +1,250 @@
"""Test suite for LambdaObservationV0."""
"""Test suite for lambda observation wrappers: """
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaObservationV0
from gymnasium.spaces import Box
from gymnasium.experimental.wrappers import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
)
from gymnasium.spaces import Box, Dict, Tuple
from tests.testing_env import GenericTestEnv
NUM_ENVS = 3
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
SEED = 42
DISCRETE_ACTION = 1
def test_lambda_observation_v0():
"""Tests lambda observation.
def _record_random_obs_reset(self: gym.Env, seed=None, options=None):
obs = self.observation_space.sample()
return obs, {"obs": obs}
Tests if function is correctly applied to environment's observation.
"""
env = gym.make("CartPole-v1")
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(DISCRETE_ACTION)
observation_shift = 1
def _record_random_obs_step(self: gym.Env, action):
obs = self.observation_space.sample()
return obs, 0, False, False, {"obs": obs}
env.reset(seed=SEED)
wrapped_env = LambdaObservationV0(
env, lambda observation: observation + observation_shift, None
def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}):
return options["obs"], {"obs": options["obs"]}
def _record_action_obs_step(self: gym.Env, action):
return action, 0, False, False, {"obs": action}
def _check_obs(
env: gym.Env,
wrapped_env: gym.Wrapper,
transformed_obs,
original_obs,
strict: bool = True,
):
assert (
transformed_obs in wrapped_env.observation_space
), f"{transformed_obs}, {wrapped_env.observation_space}"
assert (
original_obs in env.observation_space
), f"{original_obs}, {env.observation_space}"
if strict:
assert (
transformed_obs not in env.observation_space
), f"{transformed_obs}, {env.observation_space}"
assert (
original_obs not in wrapped_env.observation_space
), f"{original_obs}, {wrapped_env.observation_space}"
def test_lambda_observation_wrapper():
"""Tests lambda observation that the function is applied to both the reset and step observation."""
env = GenericTestEnv(
reset_func=_record_action_obs_reset, step_func=_record_action_obs_step
)
wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION)
wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3))
assert np.alltrue(wrapped_obs == obs + observation_shift)
obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)})
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32))
_check_obs(env, wrapped_env, obs, info["obs"])
def test_lambda_observation_v0_within_vector():
"""Tests lambda observation in vectorized environments.
Tests if function is correctly applied to environment's observation
in vectorized environment.
"""
env = gym.vector.make(
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
)
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)]))
observation_shift = 1
env.reset(seed=SEED)
wrapped_env = LambdaObservationV0(
env, lambda observation: observation + observation_shift, None
)
wrapped_obs, _, _, _, _ = wrapped_env.step(
np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)])
def test_filter_observation_wrapper():
"""Tests ``FilterObservation`` that the right keys are filtered."""
dict_env = GenericTestEnv(
observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
assert np.alltrue(wrapped_obs == obs + observation_shift)
wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))
obs, info = wrapped_env.reset()
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
_check_obs(dict_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
_check_obs(dict_env, wrapped_env, obs, info["obs"])
# Test tuple environments
tuple_env = GenericTestEnv(
observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = FilterObservationV0(tuple_env, (2,))
obs, info = wrapped_env.reset()
assert len(obs) == 1 and len(info["obs"]) == 3
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert len(obs) == 1 and len(info["obs"]) == 3
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
def test_flatten_observation_wrapper():
"""Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly."""
env = GenericTestEnv(
observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = FlattenObservationV0(env)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_grayscale_observation_wrapper():
"""Tests the ``GrayscaleObservation`` that the observation is grayscale."""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = GrayscaleObservationV0(env)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
# Keep_dim
wrapped_env = GrayscaleObservationV0(env, keep_dim=True)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25, 1)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_resize_observation_wrapper():
"""Test the ``ResizeObservation`` that the observation has changed size"""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = ResizeObservationV0(env, (25, 25))
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_reshape_observation_wrapper():
"""Test the ``ReshapeObservation`` wrapper."""
env = GenericTestEnv(
observation_space=Box(0, 1, shape=(2, 3, 2)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = ReshapeObservationV0(env, (6, 2))
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)
def test_rescale_observation():
"""Test the ``RescaleObservation`` wrapper"""
env = GenericTestEnv(
observation_space=Box(
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
),
reset_func=_record_action_obs_reset,
step_func=_record_action_obs_step,
)
wrapped_env = RescaleObservationV0(
env,
min_obs=np.array([-5, 0], dtype=np.float32),
max_obs=np.array([5, 1], dtype=np.float32),
)
assert wrapped_env.observation_space == Box(
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
)
for sample_obs, expected_obs in (
(
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5], dtype=np.float32),
),
(
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0], dtype=np.float32),
),
(
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0], dtype=np.float32),
),
):
assert sample_obs in env.observation_space
assert expected_obs in wrapped_env.observation_space
obs, info = wrapped_env.reset(options={"obs": sample_obs})
assert np.all(obs == expected_obs)
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
obs, _, _, _, info = wrapped_env.step(sample_obs)
assert np.all(obs == expected_obs)
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
def test_dtype_observation():
"""Test ``DtypeObservation`` that the"""
env = GenericTestEnv(
reset_func=_record_random_obs_reset, step_func=_record_random_obs_step
)
wrapped_env = DtypeObservationV0(env, dtype=np.uint8)
obs, info = wrapped_env.reset()
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8
obs, _, _, _, info = wrapped_env.step(None)
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8
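
For reference, a minimal sketch (not part of the diff) chaining the image observation wrappers tested above; the CarRacing-v2 frame size and the resulting grayscale shape are assumptions, so the final print is illustrative only:

import gymnasium as gym
from gymnasium.experimental.wrappers import GrayscaleObservationV0, ResizeObservationV0

env = gym.make("CarRacing-v2", continuous=False)
env = ResizeObservationV0(env, (64, 64))  # shrink the RGB frame first
env = GrayscaleObservationV0(env)         # then drop the colour channels
obs, info = env.reset(seed=42)
print(obs.shape, env.observation_space)   # expected: a (64, 64) uint8 frame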

View File

@@ -55,7 +55,7 @@ def jax_step_func(self, action):
def test_jax_to_numpy():
jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)
# Check that the reset and step for jax environment are as expected
obs, info = jax_env.reset()

View File

@@ -1,52 +0,0 @@
"""Test suite for RescaleActionV0."""
import jax
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import RescaleActionV0
SEED = 42
@pytest.mark.parametrize(
("env", "low", "high", "action", "scaled_action"),
[
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
-0.5,
0.5,
np.array([1, 1, 1, 1]),
np.array([0.5, 0.5, 0.5, 0.5]),
),
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
-0.5,
0.5,
jax.numpy.array([1, 1, 1, 1]),
jax.numpy.array([0.5, 0.5, 0.5, 0.5]),
),
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([-0.5, -0.5, -1, -1], dtype=np.float32),
np.array([0.5, 0.5, 1, 1], dtype=np.float32),
jax.numpy.array([1, 1, 1, 1]),
jax.numpy.array([0.5, 0.5, 1, 1]),
),
],
)
def test_rescale_actions_v0_box(env, low, high, action, scaled_action):
"""Test action rescaling."""
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(action)
env.reset(seed=SEED)
wrapped_env = RescaleActionV0(env, low, high)
obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action)
assert np.alltrue(obs == obs_scaled)

View File

@@ -17,7 +17,9 @@ def step_fn(self, action):
def test_sticky_action():
env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5)
env = StickyActionV0(
GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5
)
env.reset(seed=SEED)
env.action_space.seed(SEED)
@@ -34,7 +36,7 @@ def test_sticky_action():
previous_action = input_action
@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5])
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
with pytest.raises(InvalidProbability):
StickyActionV0(

View File

@@ -0,0 +1,89 @@
"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation."""
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0
from gymnasium.spaces import Box, Dict, Tuple
from tests.testing_env import GenericTestEnv
NUM_STEPS = 20
SEED = 0
DELAY = 3
def test_time_aware_observation_wrapper():
"""Tests the time aware observation wrapper."""
# Test the observation space handling for Dict, Tuple and Box spaces
env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3)))
wrapped_env = TimeAwareObservationV0(env)
assert isinstance(wrapped_env.observation_space, Dict)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}"
env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3))))
wrapped_env = TimeAwareObservationV0(env)
assert isinstance(wrapped_env.observation_space, Tuple)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert len(reset_obs) == 3 and len(step_obs) == 3
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env)
assert isinstance(wrapped_env.observation_space, Dict)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert isinstance(reset_obs, dict) and isinstance(step_obs, dict)
assert "obs" in reset_obs and "obs" in step_obs
assert "time" in reset_obs and "time" in step_obs
# Tests the flatten parameter
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env, flatten=True)
assert isinstance(wrapped_env.observation_space, Box)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert reset_obs.shape == (2,) and step_obs.shape == (2,)
# Tests the normalize_time parameter
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env, normalize_time=False)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert reset_obs["time"] == 100 and step_obs["time"] == 99
env = GenericTestEnv(observation_space=Box(0, 1))
wrapped_env = TimeAwareObservationV0(env, normalize_time=True)
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01
def test_delay_observation_wrapper():
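"""Tests that the DelayObservation wrapper delays the observations by ``delay`` steps."""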
env = gym.make("CartPole-v1")
env.action_space.seed(SEED)
env.reset(seed=SEED)
undelayed_observations = []
for _ in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
undelayed_observations.append(obs)
env = DelayObservationV0(env, delay=DELAY)
env.action_space.seed(SEED)
env.reset(seed=SEED)
delayed_observations = []
for i in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
delayed_observations.append(obs)
if i < DELAY - 1:
assert np.all(obs == 0)
undelayed_observations = np.array(undelayed_observations)
delayed_observations = np.array(delayed_observations)
assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY])
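
For reference, a minimal sketch (not part of the diff) of the two stateful observation wrappers tested above; the flattened (5,) CartPole shape is an assumption based on the flatten behaviour checked in this file:

import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0

# flatten=True appends the (normalized) time to the end of CartPole's 4-element observation.
env = TimeAwareObservationV0(gym.make("CartPole-v1"), flatten=True)
obs, info = env.reset(seed=0)
print(obs.shape)  # expected: (5,)

# The first observations are zeros; afterwards observations lag the true state by `delay` steps.
env = DelayObservationV0(gym.make("CartPole-v1"), delay=3)
env.reset(seed=0)
for _ in range(5):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())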

View File

@@ -1,99 +0,0 @@
"""Test suite for TimeAwareobservationV0."""
from collections import OrderedDict
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import TimeAwareObservationV0
from gymnasium.spaces import Box, Dict
NUM_STEPS = 20
SEED = 0
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True),
],
)
def test_time_aware_observation_creation(env):
"""Test TimeAwareObservationV0 wrapper creation.
This test checks if wrapped env with TimeAwareObservationV0
is correctly created.
"""
wrapped_env = TimeAwareObservationV0(env)
obs, _ = wrapped_env.reset()
assert isinstance(wrapped_env.observation_space, Dict)
assert isinstance(obs, OrderedDict)
assert np.all(obs["time"] == 0)
assert env.observation_space == wrapped_env.observation_space["obs"]
@pytest.mark.parametrize("normalize_time", [True, False])
@pytest.mark.parametrize("flatten", [False, True])
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
],
)
def test_time_aware_observation_step(env, flatten, normalize_time):
"""Test TimeAwareObservationV0 step.
This test checks if wrapped env with TimeAwareObservationV0
steps correctly.
"""
env.action_space.seed(SEED)
max_timesteps = env._max_episode_steps
wrapped_env = TimeAwareObservationV0(
env, flatten=flatten, normalize_time=normalize_time
)
wrapped_env.reset(seed=SEED)
for timestep in range(1, NUM_STEPS):
action = env.action_space.sample()
observation, _, terminated, _, _ = wrapped_env.step(action)
expected_time_obs = (
timestep / max_timesteps if normalize_time else max_timesteps - timestep
)
if flatten:
assert np.allclose(observation[-1], expected_time_obs)
else:
assert np.allclose(observation["time"], expected_time_obs)
if terminated:
break
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True),
],
)
def test_time_aware_observation_creation_flatten(env):
"""Test TimeAwareObservationV0 wrapper creation with `flatten=True`.
This test checks if wrapped env with TimeAwareObservationV0
is correctly created when the `flatten` parameter is set to `True`.
When flattened, the observation space should be a 1 dimension `Box`
with time appended to the end.
"""
wrapped_env = TimeAwareObservationV0(env, flatten=True)
obs, _ = wrapped_env.reset()
assert isinstance(wrapped_env.observation_space, Box)
assert isinstance(obs, np.ndarray)
assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]

View File

@@ -9,6 +9,7 @@ from tests.testing_env import GenericTestEnv
def torch_data_equivalence(data_1, data_2) -> bool:
"""Return if two variables are equivalent that might contain ``torch.Tensor``."""
if type(data_1) == type(data_2):
if isinstance(data_1, dict):
return data_1.keys() == data_2.keys() and all(
@@ -56,14 +57,15 @@ def torch_data_equivalence(data_1, data_2) -> bool:
)
def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value)
roundtripped_value = jax_to_torch(torch_to_jax(value))
assert torch_data_equivalence(roundtripped_value, expected_value)
def jax_reset_func(self, seed=None, options=None):
def _jax_reset_func(self, seed=None, options=None):
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
def jax_step_func(self, action):
def _jax_step_func(self, action):
assert isinstance(action, jnp.DeviceArray), type(action)
return (
jnp.array([1, 2, 3]),
@@ -75,7 +77,7 @@ def jax_step_func(self, action):
def test_jax_to_torch():
env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)
# Check that the reset and step for jax environment are as expected
obs, info = env.reset()

View File

@@ -278,7 +278,7 @@ def test_wrapper_types():
obs, _, _, _, _ = observation_env.step(0)
assert obs == np.array([1])
env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {}))
env = GenericTestEnv(step_func=lambda self, action: (action, 0, False, False, {}))
action_env = ExampleActionWrapper(env)
obs, _, _, _, _ = action_env.step(0)
assert obs == np.array([1])

View File

@@ -8,7 +8,7 @@ from gymnasium.core import ActType, ObsType
from gymnasium.envs.registration import EnvSpec
def basic_reset_fn(
def basic_reset_func(
self,
*,
seed: Optional[int] = None,
@@ -20,17 +20,17 @@ def basic_reset_fn(
return self.observation_space.sample(), {"options": options}
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, False, {}
def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, {}
def basic_render_fn(self):
def basic_render_func(self):
"""Basic render fn that does nothing."""
pass
@@ -43,12 +43,14 @@ class GenericTestEnv(gym.Env):
self,
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
reset_fn: callable = basic_reset_fn,
step_fn: callable = new_step_fn,
render_fn: callable = basic_render_fn,
reset_func: callable = basic_reset_func,
step_func: callable = new_step_func,
render_func: callable = basic_render_func,
metadata: Dict[str, Any] = {"render_modes": []},
render_mode: Optional[str] = None,
spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"),
spec: EnvSpec = EnvSpec(
"TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100
),
):
self.metadata = metadata
self.render_mode = render_mode
@@ -59,12 +61,12 @@ class GenericTestEnv(gym.Env):
if action_space is not None:
self.action_space = action_space
if reset_fn is not None:
self.reset = types.MethodType(reset_fn, self)
if step_fn is not None:
self.step = types.MethodType(step_fn, self)
if render_fn is not None:
self.render = types.MethodType(render_fn, self)
if reset_func is not None:
self.reset = types.MethodType(reset_func, self)
if step_func is not None:
self.step = types.MethodType(step_func, self)
if render_func is not None:
self.render = types.MethodType(render_func, self)
def reset(
self,
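
For reference, a minimal sketch (not part of the diff) of constructing the test helper with the renamed keyword arguments (``reset_func``/``step_func``/``render_func`` instead of ``*_fn``); the echo step function below is hypothetical:

import numpy as np
from gymnasium.spaces import Box
from tests.testing_env import GenericTestEnv

def _echo_step_func(self, action):
    # Record the received action in the info dict so a test can inspect it.
    return self.observation_space.sample(), 0, False, False, {"action": action}

env = GenericTestEnv(
    action_space=Box(-1, 1, (2,), dtype=np.float32),
    step_func=_echo_step_func,
)
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(np.zeros(2, dtype=np.float32))
assert np.all(info["action"] == 0)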

View File

@@ -112,10 +112,10 @@ def test_check_reset_seed(test, func: callable, message: str):
with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
):
check_reset_seed(GenericTestEnv(reset_fn=func))
check_reset_seed(GenericTestEnv(reset_func=func))
else:
with pytest.raises(test, match=f"^{re.escape(message)}$"):
check_reset_seed(GenericTestEnv(reset_fn=func))
check_reset_seed(GenericTestEnv(reset_func=func))
def _deprecated_return_info(
@@ -179,7 +179,7 @@ def test_check_reset_return_type(test, func: callable, message: str):
"""Tests the check `env.reset()` function has a correct return type."""
with pytest.raises(test, match=f"^{re.escape(message)}$"):
check_reset_return_type(GenericTestEnv(reset_fn=func))
check_reset_return_type(GenericTestEnv(reset_func=func))
@pytest.mark.parametrize(
@@ -198,7 +198,7 @@ def test_check_reset_return_info_deprecation(test, func: callable, message: str)
"""Tests that return_info has been correct deprecated as an argument to `env.reset()`."""
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func))
check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))
def test_check_seed_deprecation():
@@ -236,7 +236,7 @@ def test_check_reset_options():
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
),
):
check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {})))
check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))
@pytest.mark.parametrize(

View File

@@ -303,11 +303,11 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
):
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
else:
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(test, match=f"^{re.escape(message)}$"):
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
assert len(caught_warnings) == 0
@@ -383,11 +383,11 @@ def test_passive_env_step_checker(
with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
):
env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
else:
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(test, match=f"^{re.escape(message)}$"):
env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
assert len(caught_warnings) == 0, caught_warnings
@@ -416,7 +416,7 @@ def test_passive_env_step_checker(
GenericTestEnv(
metadata={"render_modes": ["Testing mode"], "render_fps": None},
render_mode="Testing mode",
render_fn=lambda self: 0,
render_func=lambda self: 0,
),
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
],

View File

@@ -21,7 +21,7 @@ IRRELEVANT_KEY = 1
PlayableEnv = partial(
GenericTestEnv,
metadata={"render_modes": ["rgb_array"]},
render_fn=lambda self: np.ones((10, 10, 3)),
render_func=lambda self: np.ones((10, 10, 3)),
)

View File

@@ -82,8 +82,8 @@ def test_final_obs_info(vectoriser):
return GenericTestEnv(
action_space=Discrete(4),
observation_space=Discrete(4),
reset_fn=reset_fn,
step_fn=lambda self, action: (
reset_func=reset_fn,
step_func=lambda self, action: (
action if action < 3 else 0,
0,
action >= 3,

View File

@@ -3,7 +3,7 @@ import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility
from tests.testing_env import GenericTestEnv, old_step_fn
from tests.testing_env import GenericTestEnv, old_step_func
class AleTesting:
@@ -34,7 +34,7 @@ class AtariTestingEnv(GenericTestEnv):
low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1
),
action_space=Discrete(3, seed=1),
step_fn=old_step_fn,
step_func=old_step_func,
)
self.ale = AleTesting()

View File

@@ -68,8 +68,8 @@ def _step_failure(self, action):
def test_api_failures():
env = GenericTestEnv(
reset_fn=_reset_failure,
step_fn=_step_failure,
reset_func=_reset_failure,
step_func=_step_failure,
metadata={"render_modes": "error"},
)
env = PassiveEnvChecker(env)