mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-30 21:34:30 +00:00
Update experimental wrappers (#176)
This commit is contained in:
@@ -22,9 +22,9 @@ repos:
|
||||
hooks:
|
||||
- id: codespell
|
||||
args:
|
||||
- --ignore-words-list=nd,reacher,thist,ths, ure, referenc,wile
|
||||
- --ignore-words-list=nd,reacher,thist,ths,ure,referenc,wile
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
rev: 5.0.4
|
||||
hooks:
|
||||
- id: flake8
|
||||
args:
|
||||
|
@@ -27,8 +27,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
||||
|
||||
* In v28, we aim to rewrite the VectorEnv to not inherit from Env; as a result, new vectorised versions of the wrappers will be provided.
|
||||
|
||||
### Lambda Observation Wrappers
|
||||
|
||||
### Observation Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
@@ -44,61 +43,60 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
||||
- VectorLambdaObservation
|
||||
- No
|
||||
* - :class:`wrappers.FilterObservation`
|
||||
- :class:`experimental.wrappers.FilterObservation`
|
||||
- :class:`experimental.wrappers.FilterObservationV0`
|
||||
- VectorFilterObservation (*)
|
||||
- Yes
|
||||
* - :class:`wrappers.FlattenObservation`
|
||||
- `:class:`experimental.wrappers.FlattenObservation`
|
||||
- :class:`experimental.wrappers.FlattenObservationV0`
|
||||
- VectorFlattenObservation (*)
|
||||
- No
|
||||
* - :class:`wrappers.GrayScaleObservation`
|
||||
- `:class:`experimental.wrappers.GrayscaleObservation`
|
||||
- :class:`experimental.wrappers.GrayscaleObservationV0`
|
||||
- VectorGrayscaleObservation (*)
|
||||
- Yes
|
||||
* - :class:`wrappers.ResizeObservation`
|
||||
- :class:`experimental.wrappers.ResizeObservation`
|
||||
- :class:`experimental.wrappers.ResizeObservationV0`
|
||||
- VectorResizeObservation (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.ReshapeObservation`
|
||||
- :class:`experimental.wrappers.ReshapeObservationV0`
|
||||
- VectorReshapeObservation (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.RescaleObservation`
|
||||
- :class:`experimental.wrappers.RescaleObservationV0`
|
||||
- VectorRescaleObservation (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.DtypeObservation`
|
||||
- :class:`experimental.wrappers.DtypeObservationV0`
|
||||
- VectorDtypeObservation (*)
|
||||
- Yes
|
||||
* - :class:`wrappers.PixelObservationWrapper`
|
||||
- PixelObservation
|
||||
- VectorPixelObservation
|
||||
- No
|
||||
* - :class:`NormalizeObservation`
|
||||
* - :class:`wrappers.NormalizeObservation`
|
||||
- NormalizeObservation
|
||||
- VectorNormalizeObservation
|
||||
- No
|
||||
* - :class:`TimeAwareObservation`
|
||||
- TimeAwareObservation
|
||||
* - :class:`wrappers.TimeAwareObservation`
|
||||
- :class:`experimental.wrappers.TimeAwareObservationV0`
|
||||
- VectorTimeAwareObservation
|
||||
- No
|
||||
* - :class:`FrameStack`
|
||||
* - :class:`wrappers.FrameStack`
|
||||
- FrameStackObservation
|
||||
- VectorFrameStackObservation
|
||||
- No
|
||||
* - Not Implemented
|
||||
- DelayObservation
|
||||
- :class:`experimental.wrappers.DelayObservationV0`
|
||||
- VectorDelayObservation
|
||||
- No
|
||||
* - :class:`AtariPreprocessing`
|
||||
* - :class:`wrappers.AtariPreprocessing`
|
||||
- AtariPreprocessing
|
||||
- Not Implemented
|
||||
- No
|
||||
```
|
||||
|
||||
### Lambda Action Wrappers
|
||||
|
||||
### Action Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
@@ -114,25 +112,20 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
||||
- VectorLambdaAction
|
||||
- No
|
||||
* - :class:`wrappers.ClipAction`
|
||||
- ClipAction
|
||||
- :class:`experimental.wrappers.ClipActionV0`
|
||||
- VectorClipAction (*)
|
||||
- Yes
|
||||
* - :class:`wrappers.RescaleAction`
|
||||
- RescaleAction
|
||||
- :class:`experimental.wrappers.RescaleActionV0`
|
||||
- VectorRescaleAction (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- NanAction
|
||||
- VectorNanAction (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- StickyAction
|
||||
- :class:`experimental.wrappers.StickyActionV0`
|
||||
- VectorStickyAction
|
||||
- No
|
||||
```
|
||||
|
||||
### Lambda Reward Wrappers
|
||||
|
||||
### Reward Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
@@ -175,7 +168,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
||||
- VectorPassiveEnvChecker
|
||||
* - :class:`wrappers.OrderEnforcing`
|
||||
- OrderEnforcing
|
||||
- VectorOrderEnforcing (*)
|
||||
- VectorOrderEnforcing
|
||||
* - :class:`wrappers.EnvCompatibility`
|
||||
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
||||
- Not Implemented
|
||||
@@ -189,10 +182,10 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
||||
- HumanRendering
|
||||
- Not Implemented
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.JaxToNumpy`
|
||||
- :class:`experimental.wrappers.JaxToNumpyV0`
|
||||
- VectorJaxToNumpy (*)
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.JaxToTorch`
|
||||
- :class:`experimental.wrappers.JaxToTorchV0`
|
||||
- VectorJaxToTorch (*)
|
||||
```
|
||||
|
||||
|
@@ -12,9 +12,6 @@ title: Functional
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.initial
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.transition
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.observation
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.initial
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.observation
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.reward
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.terminal
|
||||
@@ -33,4 +30,8 @@ title: Functional
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
|
||||
|
||||
.. autofunction:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv.reset
|
||||
.. autofunction:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv.step
|
||||
.. autofunction:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv.render
|
||||
```
|
||||
|
@@ -1,6 +1,6 @@
|
||||
# Wrappers
|
||||
|
||||
## Lambda Observation Wrappers
|
||||
## Observation Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
|
||||
@@ -11,24 +11,6 @@
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
|
||||
```
|
||||
|
||||
## Lambda Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
|
||||
```
|
||||
|
||||
## Lambda Reward Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||
```
|
||||
|
||||
## Observation Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
||||
```
|
||||
@@ -36,11 +18,22 @@
|
||||
## Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
||||
```
|
||||
|
||||
## Reward Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||
```
|
||||
|
||||
## Common Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
|
||||
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
|
||||
```
|
||||
|
@@ -27,7 +27,7 @@ __all__ = [
|
||||
"register",
|
||||
"registry",
|
||||
"pprint_registry",
|
||||
# root files
|
||||
# module folders
|
||||
"envs",
|
||||
"spaces",
|
||||
"utils",
|
||||
|
@@ -1,2 +1,2 @@
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Root __init__ of the gym experimental wrappers."""
|
||||
|
||||
|
||||
from gymnasium.experimental import functional, wrappers
|
||||
from gymnasium.experimental.functional import FuncEnv
|
||||
|
||||
|
||||
@@ -8,6 +9,8 @@ __all__ = [
|
||||
# Functional
|
||||
"FuncEnv",
|
||||
"functional",
|
||||
# Wrapper
|
||||
# Wrappers
|
||||
"wrappers",
|
||||
# Vector
|
||||
# "vector",
|
||||
]
|
||||
|
@@ -10,31 +10,59 @@ from gymnasium.experimental.wrappers.lambda_action import (
|
||||
ClipActionV0,
|
||||
RescaleActionV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
||||
from gymnasium.experimental.wrappers.lambda_observations import (
|
||||
LambdaObservationV0,
|
||||
FilterObservationV0,
|
||||
FlattenObservationV0,
|
||||
GrayscaleObservationV0,
|
||||
ResizeObservationV0,
|
||||
ReshapeObservationV0,
|
||||
RescaleObservationV0,
|
||||
DtypeObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
|
||||
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
|
||||
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
|
||||
from gymnasium.experimental.wrappers.time_aware_observation import (
|
||||
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
||||
from gymnasium.experimental.wrappers.stateful_observation import (
|
||||
TimeAwareObservationV0,
|
||||
DelayObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0
|
||||
|
||||
__all__ = [
|
||||
"ArgType",
|
||||
# Lambda Action
|
||||
# --- Observation wrappers ---
|
||||
"LambdaObservationV0",
|
||||
"FilterObservationV0",
|
||||
"FlattenObservationV0",
|
||||
"GrayscaleObservationV0",
|
||||
"ResizeObservationV0",
|
||||
"ReshapeObservationV0",
|
||||
"RescaleObservationV0",
|
||||
"DtypeObservationV0",
|
||||
# "PixelObservationV0",
|
||||
# "NormalizeObservationV0",
|
||||
"TimeAwareObservationV0",
|
||||
# "FrameStackV0",
|
||||
"DelayObservationV0",
|
||||
# "AtariPreprocessingV0"
|
||||
# --- Action Wrappers ---
|
||||
"LambdaActionV0",
|
||||
"StickyActionV0",
|
||||
"ClipActionV0",
|
||||
"RescaleActionV0",
|
||||
# Lambda Observation
|
||||
"LambdaObservationV0",
|
||||
"DelayObservationV0",
|
||||
"TimeAwareObservationV0",
|
||||
# Lambda Reward
|
||||
# "NanAction",
|
||||
"StickyActionV0",
|
||||
# --- Reward wrappers ---
|
||||
"LambdaRewardV0",
|
||||
"ClipRewardV0",
|
||||
# Jax conversion wrappers
|
||||
# "RescaleRewardV0",
|
||||
# "NormalizeRewardV0",
|
||||
# --- Common ---
|
||||
# "AutoReset",
|
||||
# "PassiveEnvChecker",
|
||||
# "OrderEnforcing",
|
||||
# "RecordEpisodeStatistics",
|
||||
# "RenderCollection",
|
||||
# "HumanRendering",
|
||||
"JaxToNumpyV0",
|
||||
"JaxToTorchV0",
|
||||
]
|
||||
|
@@ -1,35 +0,0 @@
|
||||
"""Wrapper for delaying the returned observation."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import jumpy as jp
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
|
||||
class DelayObservationV0(gym.ObservationWrapper):
|
||||
"""Wrapper which adds a delay to the returned observation."""
|
||||
|
||||
def __init__(self, env: gym.Env, delay: int):
|
||||
"""Initialize the DelayObservation wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): the wrapped environment
|
||||
delay (int): number of timesteps for delaying the observation.
|
||||
Before reaching the `delay` number of timesteps,
|
||||
returned observation is an array of zeros with the
|
||||
same shape of the observation space.
|
||||
"""
|
||||
super().__init__(env)
|
||||
self.delay = delay
|
||||
self.observation_queue = deque()
|
||||
|
||||
def observation(self, observation: ObsType) -> ObsType:
|
||||
"""Return the delayed observation."""
|
||||
self.observation_queue.append(observation)
|
||||
|
||||
if len(self.observation_queue) > self.delay:
|
||||
return self.observation_queue.popleft()
|
||||
|
||||
return jp.zeros_like(observation)
|
@@ -1,13 +1,19 @@
|
||||
"""Lambda action wrapper which apply a function to the provided action."""
|
||||
from typing import Any, Callable, Union
|
||||
"""A collection of wrappers that all use the LambdaAction class.
|
||||
|
||||
* ``LambdaAction`` - Transforms the actions based on a function
|
||||
* ``ClipAction`` - Clips the action within the bounds of the action space
|
||||
* ``RescaleAction`` - Rescales the action to lie between a minimum and maximum action
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import jumpy as jp
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType
|
||||
from gymnasium.experimental.wrappers import ArgType
|
||||
from gymnasium.core import ActType, WrapperActType
|
||||
from gymnasium.spaces import Box, Space
|
||||
|
||||
|
||||
class LambdaActionV0(gym.ActionWrapper):
|
||||
@@ -16,19 +22,23 @@ class LambdaActionV0(gym.ActionWrapper):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
func: Callable[[ArgType], Any],
|
||||
func: Callable[[WrapperActType], ActType],
|
||||
action_space: Space | None,
|
||||
):
|
||||
"""Initialize LambdaAction.
|
||||
|
||||
Args:
|
||||
env (Env): The gymnasium environment
|
||||
func (Callable): function to apply to action
|
||||
env: The gymnasium environment
|
||||
func: Function to apply to ``step`` ``action``
|
||||
action_space: The updated action space of the wrapper given the function.
|
||||
"""
|
||||
super().__init__(env)
|
||||
if action_space is not None:
|
||||
self.action_space = action_space
|
||||
|
||||
self.func = func
|
||||
|
||||
def action(self, action: ActType) -> Any:
|
||||
def action(self, action: WrapperActType) -> ActType:
|
||||
"""Apply function to action."""
|
||||
return self.func(action)
|
||||
|
||||
@@ -53,14 +63,19 @@ class ClipActionV0(LambdaActionV0):
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
assert isinstance(env.action_space, spaces.Box)
|
||||
assert isinstance(env.action_space, Box)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
|
||||
Box(
|
||||
-np.inf,
|
||||
np.inf,
|
||||
shape=env.action_space.shape,
|
||||
dtype=env.action_space.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape)
|
||||
|
||||
|
||||
class RescaleActionV0(LambdaActionV0):
|
||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||
@@ -86,8 +101,8 @@ class RescaleActionV0(LambdaActionV0):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
min_action: Union[float, int, np.ndarray],
|
||||
max_action: Union[float, int, np.ndarray],
|
||||
min_action: float | int | np.ndarray,
|
||||
max_action: float | int | np.ndarray,
|
||||
):
|
||||
"""Initializes the :class:`RescaleAction` wrapper.
|
||||
|
||||
@@ -96,28 +111,44 @@ class RescaleActionV0(LambdaActionV0):
|
||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||
"""
|
||||
assert isinstance(
|
||||
env.action_space, spaces.Box
|
||||
), f"expected Box action space, got {type(env.action_space)}"
|
||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
||||
|
||||
low = env.action_space.low
|
||||
high = env.action_space.high
|
||||
|
||||
self.min_action = np.full(
|
||||
env.action_space.shape, min_action, dtype=env.action_space.dtype
|
||||
assert isinstance(env.action_space, Box)
|
||||
assert not np.any(env.action_space.low == np.inf) and not np.any(
|
||||
env.action_space.high == np.inf
|
||||
)
|
||||
self.max_action = np.full(
|
||||
env.action_space.shape, max_action, dtype=env.action_space.dtype
|
||||
|
||||
if not isinstance(min_action, np.ndarray):
|
||||
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
|
||||
type(max_action), np.floating
|
||||
)
|
||||
min_action = np.full(env.action_space.shape, min_action)
|
||||
|
||||
assert min_action.shape == env.action_space.shape
|
||||
assert not np.any(min_action == np.inf)
|
||||
|
||||
if not isinstance(max_action, np.ndarray):
|
||||
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
|
||||
type(max_action), np.floating
|
||||
)
|
||||
max_action = np.full(env.action_space.shape, max_action)
|
||||
assert max_action.shape == env.action_space.shape
|
||||
assert not np.any(max_action == np.inf)
|
||||
|
||||
assert isinstance(env.action_space, Box)
|
||||
assert np.all(np.less_equal(min_action, max_action))
|
||||
|
||||
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||
gradient = (env.action_space.high - env.action_space.low) / (
|
||||
max_action - min_action
|
||||
)
|
||||
intercept = gradient * -min_action + env.action_space.low
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda action: jp.clip(
|
||||
low
|
||||
+ (high - low)
|
||||
* ((action - self.min_action) / (self.max_action - self.min_action)),
|
||||
low,
|
||||
high,
|
||||
lambda action: gradient * action + intercept,
|
||||
Box(
|
||||
low=min_action,
|
||||
high=max_action,
|
||||
shape=env.action_space.shape,
|
||||
dtype=env.action_space.dtype,
|
||||
),
|
||||
)
|
||||
|
@@ -1,17 +1,27 @@
|
||||
"""Lambda observation wrappers which apply a function to the observation."""
|
||||
"""A collection of observation wrappers using a lambda function.
|
||||
|
||||
* ``LambdaObservation`` - Transforms the observation with a function
|
||||
* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
|
||||
* ``FlattenObservation`` - Flattens the observations
|
||||
* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation
|
||||
* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation)
|
||||
* ``ReshapeObservation`` - Reshapes an array-based observation
|
||||
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
|
||||
* ``DtypeObservation`` - Converts an observation's dtype
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Sequence
|
||||
from typing_extensions import Final
|
||||
|
||||
import jumpy as jp
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.spaces import utils
|
||||
from gymnasium.spaces import Box, utils
|
||||
|
||||
|
||||
class LambdaObservationV0(gym.ObservationWrapper):
|
||||
@@ -71,32 +81,82 @@ class FilterObservationV0(LambdaObservationV0):
|
||||
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {})
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, filter_keys: Sequence[str]):
|
||||
def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
|
||||
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
||||
if not isinstance(env.observation_space, spaces.Dict):
|
||||
assert isinstance(filter_keys, Sequence)
|
||||
|
||||
# Filters for dictionary space
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
assert all(isinstance(key, str) for key in filter_keys)
|
||||
|
||||
if any(
|
||||
key not in env.observation_space.spaces.keys() for key in filter_keys
|
||||
):
|
||||
missing_keys = [
|
||||
key
|
||||
for key in filter_keys
|
||||
if key not in env.observation_space.spaces.keys()
|
||||
]
|
||||
raise ValueError(
|
||||
"All the `filter_keys` must be included in the observation space.\n"
|
||||
f"Filter keys: {filter_keys}\n"
|
||||
f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
|
||||
f"Missing keys: {missing_keys}"
|
||||
)
|
||||
|
||||
new_observation_space = spaces.Dict(
|
||||
{key: env.observation_space[key] for key in filter_keys}
|
||||
)
|
||||
if len(new_observation_space) == 0:
|
||||
raise ValueError(
|
||||
"The observation space is empty due to filtering all keys."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: {key: obs[key] for key in filter_keys},
|
||||
new_observation_space,
|
||||
)
|
||||
# Filter for tuple observation
|
||||
elif isinstance(env.observation_space, spaces.Tuple):
|
||||
assert all(isinstance(key, int) for key in filter_keys)
|
||||
assert len(set(filter_keys)) == len(
|
||||
filter_keys
|
||||
), f"Duplicate keys exist, filter_keys: {filter_keys}"
|
||||
|
||||
if any(
|
||||
0 < key and key >= len(env.observation_space) for key in filter_keys
|
||||
):
|
||||
missing_index = [
|
||||
key
|
||||
for key in filter_keys
|
||||
if 0 < key and key >= len(env.observation_space)
|
||||
]
|
||||
raise ValueError(
|
||||
"All the `filter_keys` must be included in the length of the observation space.\n"
|
||||
f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
|
||||
f"missing indexes: {missing_index}"
|
||||
)
|
||||
|
||||
new_observation_spaces = spaces.Tuple(
|
||||
env.observation_space[key] for key in filter_keys
|
||||
)
|
||||
if len(new_observation_spaces) == 0:
|
||||
raise ValueError(
|
||||
"The observation space is empty due to filtering all keys."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: tuple(obs[key] for key in filter_keys),
|
||||
new_observation_spaces,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"FilterObservation wrapper is only usable with dict observations, actual type: {type(env.observation_space)}"
|
||||
f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
|
||||
)
|
||||
|
||||
if any(key not in env.observation_space.keys() for key in filter_keys):
|
||||
missing_keys = [
|
||||
key for key in filter_keys if key not in env.observation_space.keys()
|
||||
]
|
||||
raise ValueError(
|
||||
"All the filter_keys must be included in the original observation space.\n"
|
||||
f"Filter keys: {filter_keys}\n"
|
||||
f"Observation keys: {list(env.observation_space.keys())}\n"
|
||||
f"Missing keys: {missing_keys}"
|
||||
)
|
||||
|
||||
new_observation_space = spaces.Dict(
|
||||
{key: env.observation_space[key] for key in filter_keys}
|
||||
)
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: {key: obs[key] for key in filter_keys},
|
||||
new_observation_space,
|
||||
)
|
||||
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||
|
||||
|
||||
class FlattenObservationV0(LambdaObservationV0):
|
||||
@@ -117,9 +177,10 @@ class FlattenObservationV0(LambdaObservationV0):
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
||||
flattened_space = utils.flatten_space(env.observation_space)
|
||||
super().__init__(
|
||||
env, lambda obs: utils.flatten(flattened_space, obs), flattened_space
|
||||
env,
|
||||
lambda obs: utils.flatten(env.observation_space, obs),
|
||||
utils.flatten_space(env.observation_space),
|
||||
)
|
||||
|
||||
|
||||
@@ -154,7 +215,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
||||
and env.observation_space.dtype == np.uint8
|
||||
)
|
||||
|
||||
self.keep_dim = keep_dim
|
||||
self.keep_dim: Final[bool] = keep_dim
|
||||
if keep_dim:
|
||||
new_observation_space = spaces.Box(
|
||||
low=0,
|
||||
@@ -167,7 +228,8 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
||||
lambda obs: jp.expand_dims(
|
||||
jp.sum(
|
||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||
)
|
||||
).astype(np.uint8),
|
||||
axis=-1,
|
||||
),
|
||||
new_observation_space,
|
||||
)
|
||||
@@ -179,7 +241,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
||||
env,
|
||||
lambda obs: jp.sum(
|
||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||
),
|
||||
).astype(np.uint8),
|
||||
new_observation_space,
|
||||
)
|
||||
|
||||
@@ -215,7 +277,7 @@ class ResizeObservationV0(LambdaObservationV0):
|
||||
"opencv is not install, run `pip install gymnasium[other]`"
|
||||
)
|
||||
|
||||
self.shape = tuple(shape)
|
||||
self.shape: Final[tuple[int, ...]] = tuple(shape)
|
||||
|
||||
new_observation_space = spaces.Box(
|
||||
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
||||
@@ -237,7 +299,7 @@ class ReshapeObservationV0(LambdaObservationV0):
|
||||
|
||||
assert isinstance(shape, tuple)
|
||||
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
|
||||
assert all(x > 0 for x in shape)
|
||||
assert all(x > 0 or x == -1 for x in shape)
|
||||
|
||||
new_observation_space = spaces.Box(
|
||||
low=np.reshape(np.ravel(env.observation_space.low), shape),
|
||||
@@ -245,9 +307,8 @@ class ReshapeObservationV0(LambdaObservationV0):
|
||||
shape=shape,
|
||||
dtype=env.observation_space.dtype,
|
||||
)
|
||||
super().__init__(
|
||||
env, lambda obs: jp.reshape(obs, self.shape), new_observation_space
|
||||
)
|
||||
self.shape = shape
|
||||
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
|
||||
|
||||
|
||||
class RescaleObservationV0(LambdaObservationV0):
|
||||
@@ -256,18 +317,23 @@ class RescaleObservationV0(LambdaObservationV0):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
min_obs: tuple[np.floating, np.integer, np.ndarray],
|
||||
max_obs: tuple[np.floating, np.integer, np.ndarray],
|
||||
min_obs: np.floating | np.integer | np.ndarray,
|
||||
max_obs: np.floating | np.integer | np.ndarray,
|
||||
):
|
||||
"""Constructor that requires the env observation spaces to be a :class:`Box`."""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert not np.any(env.observation_space.low == np.inf) and not np.any(
|
||||
env.observation_space.high == np.inf
|
||||
)
|
||||
|
||||
if not isinstance(min_obs, np.ndarray):
|
||||
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
|
||||
type(max_obs), np.floating
|
||||
)
|
||||
min_obs = np.full(env.observation_space.shape, min_obs)
|
||||
assert min_obs.shape == env.observation_space.shape
|
||||
assert (
|
||||
min_obs.shape == env.observation_space.shape
|
||||
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
|
||||
assert not np.any(min_obs == np.inf)
|
||||
|
||||
if not isinstance(max_obs, np.ndarray):
|
||||
@@ -278,52 +344,66 @@ class RescaleObservationV0(LambdaObservationV0):
|
||||
assert max_obs.shape == env.observation_space.shape
|
||||
assert not np.any(max_obs == np.inf)
|
||||
|
||||
env_low = env.observation_space.low
|
||||
env_high = env.observation_space.high
|
||||
self.min_obs = min_obs
|
||||
self.max_obs = max_obs
|
||||
|
||||
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||
gradient = (max_obs - min_obs) / (
|
||||
env.observation_space.high - env.observation_space.low
|
||||
)
|
||||
intercept = gradient * -env.observation_space.low + min_obs
|
||||
|
||||
new_observation_space = spaces.Box(low=min_obs, high=max_obs)
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: env_low
|
||||
+ (env_high - env_low) * ((obs - min_obs) / (max_obs - min_obs)),
|
||||
new_observation_space,
|
||||
lambda obs: gradient * obs + intercept,
|
||||
Box(
|
||||
low=min_obs,
|
||||
high=max_obs,
|
||||
shape=env.observation_space.shape,
|
||||
dtype=env.observation_space.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DtypeObservationV0(LambdaObservationV0):
|
||||
"""Observation wrapper for transforming the dtype of an observation."""
|
||||
|
||||
def __init__(self, env: gym.Env, dtype: npt.DTypeLike):
|
||||
def __init__(self, env: gym.Env, dtype: Any):
|
||||
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
|
||||
assert isinstance(
|
||||
env.observation_space,
|
||||
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
||||
)
|
||||
|
||||
dtype = np.dtype(dtype)
|
||||
self.dtype = dtype
|
||||
if isinstance(env.observation_space, spaces.Box):
|
||||
new_observation_space = spaces.Box(
|
||||
low=env.observation_space.low,
|
||||
high=env.observation_space.high,
|
||||
shape=env.observation_space.shape,
|
||||
dtype=dtype.__name__,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
elif isinstance(env.observation_space, spaces.Discrete):
|
||||
new_observation_space = spaces.Box(
|
||||
low=env.observation_space.start,
|
||||
high=env.observation_space.start + env.observation_space.n,
|
||||
shape=(),
|
||||
dtype=dtype.__name__,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
elif isinstance(env.observation_space, spaces.MultiDiscrete):
|
||||
new_observation_space = spaces.MultiDiscrete(
|
||||
env.observation_space.nvec, dtype=dtype.__name__
|
||||
env.observation_space.nvec, dtype=dtype
|
||||
)
|
||||
elif isinstance(env.observation_space, spaces.MultiBinary):
|
||||
new_observation_space = spaces.Box(
|
||||
low=0, high=1, shape=env.observation_space.shape, dtype=dtype.__name__
|
||||
low=0,
|
||||
high=1,
|
||||
shape=env.observation_space.shape,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
raise TypeError
|
||||
raise TypeError(
|
||||
"DtypeObservation is only compatible with value / array-based observations."
|
||||
)
|
||||
|
||||
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
||||
|
@@ -1,12 +1,17 @@
|
||||
"""Lambda reward wrappers which apply a function to the reward."""
|
||||
"""A collection of wrappers for modifying the reward.
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
* ``LambdaReward`` - Transforms the reward by a function
|
||||
* ``ClipReward`` - Clips the reward between a minimum and maximum value
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.error import InvalidBound
|
||||
from gymnasium.experimental.wrappers import ArgType
|
||||
|
||||
|
||||
class LambdaRewardV0(gym.RewardWrapper):
|
||||
@@ -26,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
func: Callable[[ArgType], Any],
|
||||
func: Callable[[SupportsFloat], SupportsFloat],
|
||||
):
|
||||
"""Initialize LambdaRewardV0 wrapper.
|
||||
|
||||
@@ -38,7 +43,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
||||
|
||||
self.func = func
|
||||
|
||||
def reward(self, reward: Union[float, int, np.ndarray]) -> Any:
|
||||
def reward(self, reward: SupportsFloat) -> SupportsFloat:
|
||||
"""Apply function to reward.
|
||||
|
||||
Args:
|
||||
@@ -64,8 +69,8 @@ class ClipRewardV0(LambdaRewardV0):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
min_reward: Optional[Union[float, np.ndarray]] = None,
|
||||
max_reward: Optional[Union[float, np.ndarray]] = None,
|
||||
min_reward: float | np.ndarray | None = None,
|
||||
max_reward: float | np.ndarray | None = None,
|
||||
):
|
||||
"""Initialize ClipRewardsV0 wrapper.
|
||||
|
||||
|
@@ -6,70 +6,90 @@ import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env, Wrapper
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
except ImportError:
|
||||
# We handle the error internal to the relative functions
|
||||
jnp = None
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def numpy_to_jax(value: Any) -> Any:
|
||||
"""Converts a value to a Jax DeviceArray."""
|
||||
raise Exception(
|
||||
f"No conversion for Numpy to Jax registered for type: {type(value)}"
|
||||
)
|
||||
if jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@numpy_to_jax.register(numbers.Number)
|
||||
@numpy_to_jax.register(np.ndarray)
|
||||
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
|
||||
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
||||
return jnp.array(value)
|
||||
if jnp is not None:
|
||||
|
||||
@numpy_to_jax.register(numbers.Number)
|
||||
@numpy_to_jax.register(np.ndarray)
|
||||
def _number_ndarray_numpy_to_jax(
|
||||
value: np.ndarray | numbers.Number,
|
||||
) -> jnp.DeviceArray:
|
||||
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value)
|
||||
|
||||
@numpy_to_jax.register(abc.Mapping)
|
||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||
@numpy_to_jax.register(abc.Mapping)
|
||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@numpy_to_jax.register(abc.Iterable)
|
||||
def _iterable_numpy_to_jax(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(numpy_to_jax(v) for v in value)
|
||||
@numpy_to_jax.register(abc.Iterable)
|
||||
def _iterable_numpy_to_jax(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(numpy_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def jax_to_numpy(value: Any) -> Any:
|
||||
"""Converts a value to a numpy array."""
|
||||
raise Exception(
|
||||
f"No conversion for Jax to Numpy registered for type: {type(value)}"
|
||||
)
|
||||
if jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@jax_to_numpy.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
||||
"""Converts a Jax DeviceArray to a numpy array."""
|
||||
return np.array(value)
|
||||
if jnp is not None:
|
||||
|
||||
@jax_to_numpy.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
||||
"""Converts a Jax DeviceArray to a numpy array."""
|
||||
return np.array(value)
|
||||
|
||||
@jax_to_numpy.register(abc.Mapping)
|
||||
def _mapping_jax_to_numpy(
|
||||
value: Mapping[str, jnp.DeviceArray | Any]
|
||||
) -> Mapping[str, np.ndarray | Any]:
|
||||
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||
@jax_to_numpy.register(abc.Mapping)
|
||||
def _mapping_jax_to_numpy(
|
||||
value: Mapping[str, jnp.DeviceArray | Any]
|
||||
) -> Mapping[str, np.ndarray | Any]:
|
||||
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@jax_to_numpy.register(abc.Iterable)
|
||||
def _iterable_jax_to_numpy(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(jax_to_numpy(v) for v in value)
|
||||
@jax_to_numpy.register(abc.Iterable)
|
||||
def _iterable_jax_to_numpy(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(jax_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
class JaxToNumpyV0(Wrapper):
|
||||
@@ -88,6 +108,10 @@ class JaxToNumpyV0(Wrapper):
|
||||
Args:
|
||||
env: the environment to wrap
|
||||
"""
|
||||
if jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
super().__init__(env)
|
||||
|
||||
def step(
|
||||
|
56
gymnasium/experimental/wrappers/stateful_action.py
Normal file
56
gymnasium/experimental/wrappers/stateful_action.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""A collection of stateful action wrappers.
|
||||
|
||||
* StickyAction - There is a probability that the action is taken again
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, SupportsFloat
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import WrapperActType, WrapperObsType
|
||||
from gymnasium.error import InvalidProbability
|
||||
|
||||
|
||||
class StickyActionV0(gym.Wrapper):
    """Wrapper which adds a probability of repeating the previous action.

    This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
    in Section 5.2 on page 12.
    """

    def __init__(self, env: gym.Env, repeat_action_probability: float):
        """Initialize StickyAction wrapper.

        Args:
            env (Env): the wrapped environment
            repeat_action_probability (int | float): a probability of repeating the old action.

        Raises:
            InvalidProbability: if ``repeat_action_probability`` is not in the interval ``[0, 1)``.
        """
        if not 0 <= repeat_action_probability < 1:
            raise InvalidProbability(
                f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
            )

        super().__init__(env)
        self.repeat_action_probability = repeat_action_probability
        # ``None`` marks that no action has been taken since the last reset,
        # in which case there is nothing to repeat.
        self.last_action: WrapperActType | None = None

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment and forget the previously taken action."""
        self.last_action = None

        return super().reset(seed=seed, options=options)

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Step the environment, possibly replacing ``action`` with the previous one.

        Args:
            action: the action requested by the caller

        Returns:
            The result of stepping the wrapped environment with the action that was
            actually executed (either ``action`` or the previous action).
        """
        # The random draw only happens once a previous action exists, so the
        # first step of an episode always executes the requested action.
        if (
            self.last_action is not None
            and self.np_random.uniform() < self.repeat_action_probability
        ):
            action = self.last_action

        self.last_action = action
        # Forward the (possibly replaced) action to the wrapped environment;
        # returning the action itself would break the declared 5-tuple return.
        return super().step(action)
|
200
gymnasium/experimental/wrappers/stateful_observation.py
Normal file
200
gymnasium/experimental/wrappers/stateful_observation.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""A collection of stateful observation wrappers.
|
||||
|
||||
* DelayObservation - A wrapper for delaying the returned observation
|
||||
* TimeAwareObservation - A wrapper for adding time aware observations to environment observation
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any, SupportsFloat
|
||||
from typing_extensions import Final
|
||||
|
||||
import jumpy as jp
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
import gymnasium.spaces as spaces
|
||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
|
||||
|
||||
|
||||
class DelayObservationV0(gym.ObservationWrapper):
    """Observation wrapper that hands observations back ``delay`` calls late."""

    def __init__(self, env: gym.Env, delay: int):
        """Set up the delay buffer around ``env``.

        Args:
            env (Env): the wrapped environment; its observation space must be a
                ``Box``, ``MultiBinary`` or ``MultiDiscrete`` space.
            delay (int): strictly positive number of timesteps by which the
                observation is delayed. Until that many observations have been
                buffered, a zero-filled value shaped like the current
                observation is returned instead.
        """
        supported_spaces = (Box, MultiBinary, MultiDiscrete)
        assert isinstance(env.observation_space, supported_spaces)
        assert delay > 0

        self.delay: Final[int] = delay
        # FIFO buffer of observations that have not been released yet
        self.observation_queue: Final[deque] = deque()

        super().__init__(env)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Empty the delay buffer and then reset the wrapped environment."""
        self.observation_queue.clear()

        return super().reset(seed=seed, options=options)

    def observation(self, observation: ObsType) -> ObsType:
        """Buffer ``observation`` and return the one that is now due.

        Args:
            observation: the newest observation from the wrapped environment

        Returns:
            The oldest buffered observation once more than ``delay`` observations
            are queued, otherwise zeros shaped like ``observation``.
        """
        pending = self.observation_queue
        pending.append(observation)

        # Buffer not yet deeper than the delay: nothing is due, emit a placeholder.
        if len(pending) <= self.delay:
            return jp.zeros_like(observation)

        return pending.popleft()
|
||||
|
||||
|
||||
class TimeAwareObservationV0(gym.ObservationWrapper):
    """Augment the observation with time information of the episode.

    Time can be represented as a normalized value between [0,1]
    or by the number of timesteps remaining before truncation occurs.

    For environments with ``Dict`` or ``Tuple`` observation spaces, by default,
    the time information is automatically added in the key `"time"` and
    as the final element in the tuple.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservationV0(env)
        >>> env.observation_space
        Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 1.0, (1,), float32))
        >>> _ = env.reset()
        >>> env.step(env.action_space.sample())[0]
        {'obs': array([ 0.02866629,  0.2310988 , -0.02614601, -0.2600732 ], dtype=float32),
        ...  'time': array([0.002], dtype=float32)}

    Flatten observation space example:
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservationV0(env, flatten=True)
        >>> env.observation_space
        Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32)
        >>> _ = env.reset()
        >>> env.step(env.action_space.sample())[0]
        array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
    """

    def __init__(
        self,
        env: gym.Env,
        flatten: bool = False,
        normalize_time: bool = True,
        *,
        dict_time_key: str = "time",
    ):
        """Initialize :class:`TimeAwareObservationV0`.

        Args:
            env: The environment to apply the wrapper
            flatten: Flatten the observation to a `Box` of a single dimension
            normalize_time: if `True` return time in the range [0,1]
                otherwise return time as remaining timesteps before truncation
            dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.

        Raises:
            ValueError: if the maximum number of episode steps can be found neither on
                ``env._max_episode_steps`` nor on ``env.spec.max_episode_steps``.
        """
        super().__init__(env)
        self.flatten: Final[bool] = flatten
        self.normalize_time: Final[bool] = normalize_time

        # Prefer the limit exposed on the environment, fall back to the spec
        if hasattr(env, "_max_episode_steps"):
            self.max_timesteps = getattr(env, "_max_episode_steps")
        elif env.spec is not None and env.spec.max_episode_steps is not None:
            self.max_timesteps = env.spec.max_episode_steps
        else:
            raise ValueError(
                "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
            )

        self.timesteps: int = 0

        # Find the normalized time space.
        # The time value is wrapped in a length-1 array of the space's dtype so
        # that the produced observation matches the declared `(1,)`-shaped Box.
        if self.normalize_time:
            self._time_preprocess_func = lambda time: np.array(
                [time / self.max_timesteps], dtype=np.float32
            )
            time_space = Box(0.0, 1.0, dtype=np.float32)
        else:
            self._time_preprocess_func = lambda time: np.array(
                [self.max_timesteps - time], dtype=np.int32
            )
            time_space = Box(0, self.max_timesteps, dtype=np.int32)

        # Find the observation space
        if isinstance(env.observation_space, Dict):
            assert dict_time_key not in env.observation_space.keys()
            observation_space = Dict(
                {dict_time_key: time_space}, **env.observation_space.spaces
            )
            self._append_data_func = lambda obs, time: {**obs, dict_time_key: time}
        elif isinstance(env.observation_space, Tuple):
            observation_space = Tuple(env.observation_space.spaces + (time_space,))
            self._append_data_func = lambda obs, time: obs + (time,)
        else:
            # Any other space is wrapped in a two-entry Dict
            observation_space = Dict(obs=env.observation_space, time=time_space)
            self._append_data_func = lambda obs, time: {"obs": obs, "time": time}

        # If to flatten the observation space
        if self.flatten:
            self.observation_space = spaces.flatten_space(observation_space)
            self._obs_postprocess_func = lambda obs: spaces.flatten(
                observation_space, obs
            )
        else:
            self.observation_space = observation_space
            self._obs_postprocess_func = lambda obs: obs

    def observation(self, observation: ObsType) -> WrapperObsType:
        """Adds to the observation with the current time information.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time information appended to
        """
        time_observation = self._time_preprocess_func(self.timesteps)
        return self._obs_postprocess_func(
            self._append_data_func(observation, time_observation)
        )

    def step(
        self, action: ActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step using the action.
        """
        # Increment first so the returned observation carries the updated time
        self.timesteps += 1

        return super().step(action)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment setting the time to zero.

        Args:
            seed: The seed to reset the environment
            options: The options used to reset the environment

        Returns:
            The reset environment
        """
        self.timesteps = 0

        return super().reset(seed=seed, options=options)
|
@@ -1,40 +0,0 @@
|
||||
"""Wrapper which adds a probability of repeating the previous executed action."""
|
||||
from typing import Union
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType
|
||||
from gymnasium.error import InvalidProbability
|
||||
|
||||
|
||||
class StickyActionV0(gym.ActionWrapper):
    """Action wrapper that may replay the previously executed action."""

    def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]):
        """Validate the repeat probability and set up the wrapper state.

        Args:
            env (Env): the wrapped environment
            repeat_action_probability (int | float): a probability of repeating the old action.

        Raises:
            InvalidProbability: if ``repeat_action_probability`` is not in ``[0, 1)``.
        """
        probability_is_valid = 0 <= repeat_action_probability < 1
        if not probability_is_valid:
            raise InvalidProbability(
                f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
            )
        super().__init__(env)
        self.repeat_action_probability = repeat_action_probability
        # ``None`` until an action has been taken in the current episode
        self.old_action = None

    def action(self, action: ActType):
        """Return the action to execute, possibly the previous one.

        Args:
            action: the action requested by the caller

        Returns:
            The previous action when one exists and the random draw falls below
            ``repeat_action_probability``, otherwise ``action`` unchanged.
        """
        previous = self.old_action
        # Only draw a random number once there is a previous action to repeat
        if previous is not None:
            if self.np_random.uniform() < self.repeat_action_probability:
                action = previous
        self.old_action = action
        return action

    def reset(self, **kwargs):
        """Forget the previous action, then reset the wrapped environment."""
        self.old_action = None
        return super().reset(**kwargs)
|
@@ -1,113 +0,0 @@
|
||||
"""Wrapper for adding time aware observations to environment observation."""
|
||||
from collections import OrderedDict
|
||||
|
||||
import gymnasium as gym
|
||||
import gymnasium.spaces as spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces import Box, Dict
|
||||
|
||||
|
||||
class TimeAwareObservationV0(gym.ObservationWrapper):
    """Augment the observation with time information of the episode.

    Time can be represented as a normalized value between [0,1]
    or by the number of timesteps remaining before truncation occurs.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservationV0(env)
        >>> env.observation_space
        Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
        >>> _ = env.reset()
        >>> env.step(env.action_space.sample())[0]
        OrderedDict([('obs',
        ... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
        ... ('time', array([0.002]))])

    Flatten observation space example:
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservationV0(env, flatten=True)
        >>> env.observation_space
        Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
        >>> _ = env.reset()
        >>> env.step(env.action_space.sample())[0]
        array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
    """

    def __init__(self, env: gym.Env, flatten: bool = False, normalize_time: bool = True):
        """Initialize :class:`TimeAwareObservationV0`.

        Args:
            env: The environment to apply the wrapper
            flatten: Flatten the observation to a `Box` of a single dimension
            normalize_time: if `True` return time in the range [0,1]
                otherwise return time as remaining timesteps before truncation

        Raises:
            AttributeError: if ``env`` has no ``_max_episode_steps`` attribute.
        """
        super().__init__(env)
        self.flatten = flatten
        self.normalize_time = normalize_time
        self.max_timesteps = getattr(env, "_max_episode_steps")
        # Start the counter here so that `observation` can be evaluated even
        # before the first call to `reset`.
        self.timesteps = 0

        # NOTE(review): the time value produced below is a Python scalar while the
        # docstring example shows the time Box with shape (1,) - confirm whether
        # it should be wrapped in a length-1 array.
        if self.normalize_time:
            self._get_time_observation = lambda: self.timesteps / self.max_timesteps
            time_space = Box(0, 1)
        else:
            self._get_time_observation = lambda: self.max_timesteps - self.timesteps
            time_space = Box(0, self.max_timesteps)

        self.time_aware_observation_space = Dict(
            obs=env.observation_space, time=time_space
        )

        if self.flatten:
            self.observation_space = spaces.flatten_space(
                self.time_aware_observation_space
            )
            self._observation_postprocess = lambda observation: spaces.flatten(
                self.time_aware_observation_space, observation
            )
        else:
            self.observation_space = self.time_aware_observation_space
            self._observation_postprocess = lambda observation: observation

    def observation(self, observation: ObsType):
        """Adds to the observation with the current time information.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time information appended to
        """
        time_observation = self._get_time_observation()
        observation = OrderedDict(obs=observation, time=time_observation)

        return self._observation_postprocess(observation)

    def step(self, action: ActType):
        """Steps through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step using the action.
        """
        # Increment first so the returned observation carries the updated time
        self.timesteps += 1
        observation, reward, terminated, truncated, info = super().step(action)

        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the environment setting the time to zero.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            The reset environment
        """
        self.timesteps = 0
        return super().reset(**kwargs)
|
@@ -14,92 +14,121 @@ import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import dlpack as jax_dlpack
|
||||
|
||||
from gymnasium import Env, Wrapper
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
|
||||
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
from jax import dlpack as jax_dlpack
|
||||
except ImportError:
|
||||
jnp, jax_dlpack = None, None
|
||||
|
||||
try:
|
||||
import torch
|
||||
from torch.utils import dlpack as torch_dlpack
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled("torch is not installed, run `pip install torch`")
|
||||
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
torch, torch_dlpack, Device = None, None, None
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def torch_to_jax(value: Any) -> Any:
|
||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||
raise Exception(
|
||||
f"No conversion for PyTorch to Jax registered for type: {type(value)}"
|
||||
)
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
|
||||
)
|
||||
elif jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@torch_to_jax.register(numbers.Number)
|
||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||
return jnp.array(value)
|
||||
if torch is not None and jnp is not None:
|
||||
|
||||
@torch_to_jax.register(numbers.Number)
|
||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||
"""Convert a python number (int, float, complex) to a jax array."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value)
|
||||
|
||||
@torch_to_jax.register(torch.Tensor)
|
||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
|
||||
return tensor
|
||||
@torch_to_jax.register(torch.Tensor)
|
||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||
assert torch_dlpack is not None and jax_dlpack is not None
|
||||
tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||
value
|
||||
)
|
||||
tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||
tensor
|
||||
)
|
||||
return tensor
|
||||
|
||||
@torch_to_jax.register(abc.Mapping)
|
||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||
|
||||
@torch_to_jax.register(abc.Mapping)
|
||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@torch_to_jax.register(abc.Iterable)
|
||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(torch_to_jax(v) for v in value)
|
||||
@torch_to_jax.register(abc.Iterable)
|
||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(torch_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
raise Exception(
|
||||
f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}"
|
||||
)
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`"
|
||||
)
|
||||
elif jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@jax_to_torch.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_torch(
|
||||
value: jnp.DeviceArray, device: Device | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||
tensor = torch_dlpack.from_dlpack(dlpack)
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
if torch is not None and jnp is not None:
|
||||
|
||||
@jax_to_torch.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_torch(
|
||||
value: jnp.DeviceArray, device: Device | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
assert jax_dlpack is not None and torch_dlpack is not None
|
||||
dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||
value
|
||||
)
|
||||
tensor = torch_dlpack.from_dlpack(dlpack)
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
|
||||
@jax_to_torch.register(abc.Mapping)
|
||||
def _jax_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||
@jax_to_torch.register(abc.Mapping)
|
||||
def _jax_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||
|
||||
|
||||
@jax_to_torch.register(abc.Iterable)
|
||||
def _jax_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
||||
return type(value)(jax_to_torch(v, device) for v in value)
|
||||
@jax_to_torch.register(abc.Iterable)
|
||||
def _jax_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
||||
return type(value)(jax_to_torch(v, device) for v in value)
|
||||
|
||||
|
||||
class JaxToTorchV0(Wrapper):
|
||||
@@ -107,7 +136,8 @@ class JaxToTorchV0(Wrapper):
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||
|
||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||
Note:
|
||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, device: Device | None = None):
|
||||
@@ -117,6 +147,15 @@ class JaxToTorchV0(Wrapper):
|
||||
env: The Jax-based environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed, run `pip install torch`"
|
||||
)
|
||||
elif jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
|
||||
super().__init__(env)
|
||||
self.device: Device | None = device
|
||||
|
||||
|
@@ -76,6 +76,7 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
|
||||
* ``None`` The length will be randomly drawn from a geometric distribution
|
||||
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
|
||||
* ``int`` for a fixed length sample
|
||||
|
||||
The second element of the mask tuple `sample` mask specifies a mask that is applied when
|
||||
sampling elements from the base space. The mask is applied for each feature space sample.
|
||||
|
||||
|
@@ -29,7 +29,7 @@ dependencies = [
|
||||
"jax-jumpy >=0.2.0",
|
||||
"cloudpickle >=1.2.0",
|
||||
"importlib-metadata >=4.8.0; python_version < '3.10'",
|
||||
"typing-extensions >=4.3.0; python_version == '3.7'",
|
||||
"typing-extensions >=4.3.0",
|
||||
"gymnasium-notices >=0.0.1",
|
||||
"shimmy >=0.1.0,<1.0",
|
||||
]
|
||||
|
@@ -19,7 +19,7 @@ from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
||||
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
|
||||
from tests.envs.utils import all_testing_env_specs
|
||||
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
||||
from tests.testing_env import GenericTestEnv, old_step_fn
|
||||
from tests.testing_env import GenericTestEnv, old_step_func
|
||||
from tests.wrappers.utils import has_wrapper
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ def test_make_disable_env_checker():
|
||||
def test_apply_api_compatibility():
|
||||
gym.register(
|
||||
"testing-old-env",
|
||||
lambda: GenericTestEnv(step_fn=old_step_fn),
|
||||
lambda: GenericTestEnv(step_func=old_step_func),
|
||||
apply_api_compatibility=True,
|
||||
max_episode_steps=3,
|
||||
)
|
||||
|
@@ -1,48 +0,0 @@
|
||||
"""Test suite for LambdaActionV0."""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import ClipActionV0
|
||||
|
||||
|
||||
SEED = 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("env", "action_unclipped_env", "action_clipped_env"),
|
||||
(
|
||||
[
|
||||
# MountainCar action space: Box(-1.0, 1.0, (1,), float32)
|
||||
gym.make("MountainCarContinuous-v0"),
|
||||
np.array([1]),
|
||||
np.array([1.5]),
|
||||
],
|
||||
[
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
np.array([1, 1, 1, 1]),
|
||||
np.array([10, 10, 10, 10]),
|
||||
],
|
||||
[
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
np.array([0.5, 0.5, 1, 1]),
|
||||
np.array([0.5, 0.5, 10, 10]),
|
||||
],
|
||||
),
|
||||
)
|
||||
def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env):
|
||||
"""Tests if actions out of bound are correctly clipped.
|
||||
|
||||
Tests whether out of bound actions for the wrapped
|
||||
environments are correctly clipped.
|
||||
"""
|
||||
env.reset(seed=SEED)
|
||||
obs, _, _, _, _ = env.step(action_unclipped_env)
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = ClipActionV0(env)
|
||||
wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env)
|
||||
|
||||
assert np.alltrue(obs == wrapped_obs)
|
@@ -1,37 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import DelayObservationV0
|
||||
|
||||
|
||||
SEED = 42
|
||||
|
||||
DELAY = 3
|
||||
NUM_STEPS = 5
|
||||
|
||||
|
||||
def test_delay_observation():
|
||||
env = gym.make("CartPole-v1")
|
||||
env.action_space.seed(SEED)
|
||||
env.reset(seed=SEED)
|
||||
|
||||
undelayed_observations = []
|
||||
for _ in range(NUM_STEPS):
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
undelayed_observations.append(obs)
|
||||
|
||||
env.action_space.seed(SEED)
|
||||
env.reset(seed=SEED)
|
||||
env = DelayObservationV0(env, delay=DELAY)
|
||||
|
||||
delayed_observations = []
|
||||
for i in range(NUM_STEPS):
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
if i < DELAY - 1:
|
||||
assert np.all(obs == 0)
|
||||
delayed_observations.append(obs)
|
||||
|
||||
assert np.alltrue(
|
||||
np.array(delayed_observations[DELAY:])
|
||||
== np.array(undelayed_observations[: DELAY - 1])
|
||||
)
|
@@ -1,60 +1,78 @@
|
||||
"""Test suite for LambdaActionV0."""
|
||||
"""Test suite for lambda action wrappers: LambdaAction, ClipAction, RescaleAction."""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.error import InvalidAction
|
||||
from gymnasium.experimental.wrappers import LambdaActionV0
|
||||
from gymnasium.experimental.wrappers import (
|
||||
ClipActionV0,
|
||||
LambdaActionV0,
|
||||
RescaleActionV0,
|
||||
)
|
||||
from gymnasium.spaces import Box
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
NUM_ENVS = 3
|
||||
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
|
||||
SEED = 42
|
||||
|
||||
|
||||
def generic_step_fn(self, action):
|
||||
def _record_action_step_func(self, action):
|
||||
return 0, 0, False, False, {"action": action}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("env", "func", "action", "expected"),
|
||||
[
|
||||
(
|
||||
GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn),
|
||||
lambda action: action + 2,
|
||||
1,
|
||||
3,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_lambda_action_v0(env, func, action, expected):
|
||||
"""Tests lambda action.
|
||||
Tests if function is correctly applied to environment's action.
|
||||
"""
|
||||
wrapped_env = LambdaActionV0(env, func)
|
||||
_, _, _, _, info = wrapped_env.step(action)
|
||||
executed_action = info["action"]
|
||||
def test_lambda_action_wrapper():
|
||||
"""Tests LambdaAction through checking that the action taken is transformed by function."""
|
||||
env = GenericTestEnv(step_func=_record_action_step_func)
|
||||
wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
|
||||
|
||||
assert executed_action == expected
|
||||
sampled_action = wrapped_env.action_space.sample()
|
||||
assert sampled_action not in env.action_space
|
||||
|
||||
_, _, _, _, info = wrapped_env.step(sampled_action)
|
||||
assert info["action"] in env.action_space
|
||||
assert sampled_action - 2 == info["action"]
|
||||
|
||||
|
||||
def test_lambda_action_v0_within_vector():
|
||||
"""Tests lambda action in vectorized environments.
|
||||
Tests if function is correctly applied to environment's action
|
||||
in vectorized environment.
|
||||
"""
|
||||
env = gym.vector.make(
|
||||
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
|
||||
def test_clip_action_wrapper():
|
||||
"""Test that the action is correctly clipped to the base environment action space."""
|
||||
env = GenericTestEnv(
|
||||
action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
|
||||
step_func=_record_action_step_func,
|
||||
)
|
||||
action = np.ones(NUM_ENVS, dtype=np.float64)
|
||||
wrapped_env = ClipActionV0(env)
|
||||
|
||||
wrapped_env = LambdaActionV0(env, lambda action: action.astype(int))
|
||||
wrapped_env.reset()
|
||||
sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
|
||||
assert sampled_action not in env.action_space
|
||||
assert sampled_action in wrapped_env.action_space
|
||||
|
||||
wrapped_env.step(action)
|
||||
_, _, _, _, info = wrapped_env.step(sampled_action)
|
||||
assert np.all(info["action"] in env.action_space)
|
||||
assert np.all(info["action"] == np.array([0, 2, 3.5]))
|
||||
|
||||
# unwrapped env should raise exception because it does not
|
||||
# support float actions
|
||||
with pytest.raises(InvalidAction):
|
||||
env.step(action)
|
||||
|
||||
def test_rescale_action_wrapper():
|
||||
"""Test that the action is rescaled within a min / max bound."""
|
||||
env = GenericTestEnv(
|
||||
step_func=_record_action_step_func,
|
||||
action_space=Box(np.array([0, 1]), np.array([1, 3])),
|
||||
)
|
||||
wrapped_env = RescaleActionV0(
|
||||
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
|
||||
)
|
||||
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
|
||||
|
||||
for sample_action, expected_action in (
|
||||
(
|
||||
np.array([0.0, 0.5], dtype=np.float32),
|
||||
np.array([0.5, 2.0], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([-5.0, 0.0], dtype=np.float32),
|
||||
np.array([0.0, 1.0], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([5.0, 1.0], dtype=np.float32),
|
||||
np.array([1.0, 3.0], dtype=np.float32),
|
||||
),
|
||||
):
|
||||
assert sample_action in wrapped_env.action_space
|
||||
|
||||
_, _, _, _, info = wrapped_env.step(sample_action)
|
||||
assert np.all(info["action"] == expected_action)
|
||||
|
@@ -1,59 +1,250 @@
|
||||
"""Test suite for LambdaObservationV0."""
|
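||||
"""Test suite for lambda observation wrappers: LambdaObservation, FilterObservation, FlattenObservation, GrayscaleObservation, ResizeObservation, ReshapeObservation, RescaleObservation, DtypeObservation."""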
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import LambdaObservationV0
|
||||
from gymnasium.spaces import Box
|
||||
from gymnasium.experimental.wrappers import (
|
||||
DtypeObservationV0,
|
||||
FilterObservationV0,
|
||||
FlattenObservationV0,
|
||||
GrayscaleObservationV0,
|
||||
LambdaObservationV0,
|
||||
RescaleObservationV0,
|
||||
ReshapeObservationV0,
|
||||
ResizeObservationV0,
|
||||
)
|
||||
from gymnasium.spaces import Box, Dict, Tuple
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
NUM_ENVS = 3
|
||||
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
|
||||
|
||||
SEED = 42
|
||||
DISCRETE_ACTION = 1
|
||||
|
||||
|
||||
def test_lambda_observation_v0():
|
||||
"""Tests lambda observation.
|
||||
def _record_random_obs_reset(self: gym.Env, seed=None, options=None):
|
||||
obs = self.observation_space.sample()
|
||||
return obs, {"obs": obs}
|
||||
|
||||
Tests if function is correctly applied to environment's observation.
|
||||
"""
|
||||
env = gym.make("CartPole-v1")
|
||||
env.reset(seed=SEED)
|
||||
obs, _, _, _, _ = env.step(DISCRETE_ACTION)
|
||||
|
||||
observation_shift = 1
|
||||
def _record_random_obs_step(self: gym.Env, action):
|
||||
obs = self.observation_space.sample()
|
||||
return obs, 0, False, False, {"obs": obs}
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = LambdaObservationV0(
|
||||
env, lambda observation: observation + observation_shift, None
|
||||
|
||||
def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}):
|
||||
return options["obs"], {"obs": options["obs"]}
|
||||
|
||||
|
||||
def _record_action_obs_step(self: gym.Env, action):
|
||||
return action, 0, False, False, {"obs": action}
|
||||
|
||||
|
||||
def _check_obs(
|
||||
env: gym.Env,
|
||||
wrapped_env: gym.Wrapper,
|
||||
transformed_obs,
|
||||
original_obs,
|
||||
strict: bool = True,
|
||||
):
|
||||
assert (
|
||||
transformed_obs in wrapped_env.observation_space
|
||||
), f"{transformed_obs}, {wrapped_env.observation_space}"
|
||||
assert (
|
||||
original_obs in env.observation_space
|
||||
), f"{original_obs}, {env.observation_space}"
|
||||
|
||||
if strict:
|
||||
assert (
|
||||
transformed_obs not in env.observation_space
|
||||
), f"{transformed_obs}, {env.observation_space}"
|
||||
assert (
|
||||
original_obs not in wrapped_env.observation_space
|
||||
), f"{original_obs}, {wrapped_env.observation_space}"
|
||||
|
||||
|
||||
def test_lambda_observation_wrapper():
|
||||
"""Tests that ``LambdaObservation`` applies the function to both the reset and step observations."""
|
||||
env = GenericTestEnv(
|
||||
reset_func=_record_action_obs_reset, step_func=_record_action_obs_step
|
||||
)
|
||||
wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION)
|
||||
wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3))
|
||||
|
||||
assert np.alltrue(wrapped_obs == obs + observation_shift)
|
||||
obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)})
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32))
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
|
||||
def test_lambda_observation_v0_within_vector():
|
||||
"""Tests lambda observation in vectorized environments.
|
||||
|
||||
Tests if function is correctly applied to environment's observation
|
||||
in vectorized environment.
|
||||
"""
|
||||
env = gym.vector.make(
|
||||
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
|
||||
)
|
||||
env.reset(seed=SEED)
|
||||
obs, _, _, _, _ = env.step(np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)]))
|
||||
|
||||
observation_shift = 1
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = LambdaObservationV0(
|
||||
env, lambda observation: observation + observation_shift, None
|
||||
)
|
||||
wrapped_obs, _, _, _, _ = wrapped_env.step(
|
||||
np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)])
|
||||
def test_filter_observation_wrapper():
|
||||
"""Tests that ``FilterObservation`` keeps only the requested keys."""
|
||||
dict_env = GenericTestEnv(
|
||||
observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
|
||||
reset_func=_record_random_obs_reset,
|
||||
step_func=_record_random_obs_step,
|
||||
)
|
||||
|
||||
assert np.alltrue(wrapped_obs == obs + observation_shift)
|
||||
wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))
|
||||
obs, info = wrapped_env.reset()
|
||||
assert list(obs.keys()) == ["arm_1", "arm_3"]
|
||||
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
|
||||
_check_obs(dict_env, wrapped_env, obs, info["obs"])
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
assert list(obs.keys()) == ["arm_1", "arm_3"]
|
||||
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
|
||||
_check_obs(dict_env, wrapped_env, obs, info["obs"])
|
||||
|
||||
# Test tuple environments
|
||||
tuple_env = GenericTestEnv(
|
||||
observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
|
||||
reset_func=_record_random_obs_reset,
|
||||
step_func=_record_random_obs_step,
|
||||
)
|
||||
wrapped_env = FilterObservationV0(tuple_env, (2,))
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
assert len(obs) == 1 and len(info["obs"]) == 3
|
||||
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
assert len(obs) == 1 and len(info["obs"]) == 3
|
||||
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
|
||||
|
||||
|
||||
def test_flatten_observation_wrapper():
|
||||
"""Tests that the ``FlattenObservation`` wrapper flattens the observations correctly."""
|
||||
env = GenericTestEnv(
|
||||
observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
|
||||
reset_func=_record_random_obs_reset,
|
||||
step_func=_record_random_obs_step,
|
||||
)
|
||||
print(env.observation_space)
|
||||
wrapped_env = FlattenObservationV0(env)
|
||||
print(wrapped_env.observation_space)
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
|
||||
def test_grayscale_observation_wrapper():
|
||||
"""Tests that ``GrayscaleObservation`` converts the observation to grayscale."""
|
||||
env = GenericTestEnv(
|
||||
observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
|
||||
reset_func=_record_random_obs_reset,
|
||||
step_func=_record_random_obs_step,
|
||||
)
|
||||
wrapped_env = GrayscaleObservationV0(env)
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
assert obs.shape == (25, 25)
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
# Keep_dim
|
||||
wrapped_env = GrayscaleObservationV0(env, keep_dim=True)
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
assert obs.shape == (25, 25, 1)
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
|
||||
def test_resize_observation_wrapper():
|
||||
"""Tests that ``ResizeObservation`` changes the size of the observation."""
|
||||
env = GenericTestEnv(
|
||||
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
|
||||
reset_func=_record_random_obs_reset,
|
||||
step_func=_record_random_obs_step,
|
||||
)
|
||||
wrapped_env = ResizeObservationV0(env, (25, 25))
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
|
||||
def test_reshape_observation_wrapper():
|
||||
"""Test the ``ReshapeObservation`` wrapper."""
|
||||
env = GenericTestEnv(
|
||||
observation_space=Box(0, 1, shape=(2, 3, 2)),
|
||||
reset_func=_record_random_obs_reset,
|
||||
step_func=_record_random_obs_step,
|
||||
)
|
||||
wrapped_env = ReshapeObservationV0(env, (6, 2))
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
assert obs.shape == (6, 2)
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||
assert obs.shape == (6, 2)
|
||||
|
||||
|
||||
def test_rescale_observation():
|
||||
"""Test the ``RescaleObservation`` wrapper"""
|
||||
env = GenericTestEnv(
|
||||
observation_space=Box(
|
||||
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
|
||||
),
|
||||
reset_func=_record_action_obs_reset,
|
||||
step_func=_record_action_obs_step,
|
||||
)
|
||||
wrapped_env = RescaleObservationV0(
|
||||
env,
|
||||
min_obs=np.array([-5, 0], dtype=np.float32),
|
||||
max_obs=np.array([5, 1], dtype=np.float32),
|
||||
)
|
||||
assert wrapped_env.observation_space == Box(
|
||||
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
|
||||
)
|
||||
|
||||
for sample_obs, expected_obs in (
|
||||
(
|
||||
np.array([0.5, 2.0], dtype=np.float32),
|
||||
np.array([0.0, 0.5], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([0.0, 1.0], dtype=np.float32),
|
||||
np.array([-5.0, 0.0], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([1.0, 3.0], dtype=np.float32),
|
||||
np.array([5.0, 1.0], dtype=np.float32),
|
||||
),
|
||||
):
|
||||
assert sample_obs in env.observation_space
|
||||
assert expected_obs in wrapped_env.observation_space
|
||||
|
||||
obs, info = wrapped_env.reset(options={"obs": sample_obs})
|
||||
assert np.all(obs == expected_obs)
|
||||
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(sample_obs)
|
||||
assert np.all(obs == expected_obs)
|
||||
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
||||
|
||||
|
||||
def test_dtype_observation():
|
||||
"""Test that ``DtypeObservation`` casts the observation to the requested dtype."""
|
||||
env = GenericTestEnv(
|
||||
reset_func=_record_random_obs_reset, step_func=_record_random_obs_step
|
||||
)
|
||||
wrapped_env = DtypeObservationV0(env, dtype=np.uint8)
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
assert obs.dtype != info["obs"].dtype
|
||||
assert obs.dtype == np.uint8
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
assert obs.dtype != info["obs"].dtype
|
||||
assert obs.dtype == np.uint8
|
||||
|
@@ -55,7 +55,7 @@ def jax_step_func(self, action):
|
||||
|
||||
|
||||
def test_jax_to_numpy():
|
||||
jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
|
||||
jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)
|
||||
|
||||
# Check that the reset and step for jax environment are as expected
|
||||
obs, info = jax_env.reset()
|
||||
|
@@ -1,52 +0,0 @@
|
||||
"""Test suite for RescaleActionV0."""
|
||||
import jax
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import RescaleActionV0
|
||||
|
||||
|
||||
SEED = 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("env", "low", "high", "action", "scaled_action"),
|
||||
[
|
||||
(
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
-0.5,
|
||||
0.5,
|
||||
np.array([1, 1, 1, 1]),
|
||||
np.array([0.5, 0.5, 0.5, 0.5]),
|
||||
),
|
||||
(
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
-0.5,
|
||||
0.5,
|
||||
jax.numpy.array([1, 1, 1, 1]),
|
||||
jax.numpy.array([0.5, 0.5, 0.5, 0.5]),
|
||||
),
|
||||
(
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
np.array([-0.5, -0.5, -1, -1], dtype=np.float32),
|
||||
np.array([0.5, 0.5, 1, 1], dtype=np.float32),
|
||||
jax.numpy.array([1, 1, 1, 1]),
|
||||
jax.numpy.array([0.5, 0.5, 1, 1]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rescale_actions_v0_box(env, low, high, action, scaled_action):
|
||||
"""Test action rescaling."""
|
||||
env.reset(seed=SEED)
|
||||
obs, _, _, _, _ = env.step(action)
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = RescaleActionV0(env, low, high)
|
||||
|
||||
obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action)
|
||||
|
||||
assert np.alltrue(obs == obs_scaled)
|
@@ -17,7 +17,9 @@ def step_fn(self, action):
|
||||
|
||||
|
||||
def test_sticky_action():
|
||||
env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5)
|
||||
env = StickyActionV0(
|
||||
GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5
|
||||
)
|
||||
env.reset(seed=SEED)
|
||||
env.action_space.seed(SEED)
|
||||
|
||||
@@ -34,7 +36,7 @@ def test_sticky_action():
|
||||
previous_action = input_action
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5])
|
||||
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
|
||||
def test_sticky_action_raise(repeat_action_probability):
|
||||
with pytest.raises(InvalidProbability):
|
||||
StickyActionV0(
|
89
tests/experimental/wrappers/test_stateful_observation.py
Normal file
89
tests/experimental/wrappers/test_stateful_observation.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0
|
||||
from gymnasium.spaces import Box, Dict, Tuple
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
NUM_STEPS = 20
|
||||
SEED = 0
|
||||
|
||||
DELAY = 3
|
||||
|
||||
|
||||
def test_time_aware_observation_wrapper():
|
||||
"""Tests the time aware observation wrapper."""
|
||||
# Test the environment observation space with Dict, Tuple and other
|
||||
env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3)))
|
||||
wrapped_env = TimeAwareObservationV0(env)
|
||||
assert isinstance(wrapped_env.observation_space, Dict)
|
||||
reset_obs, _ = wrapped_env.reset()
|
||||
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||
assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}"
|
||||
|
||||
env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3))))
|
||||
wrapped_env = TimeAwareObservationV0(env)
|
||||
assert isinstance(wrapped_env.observation_space, Tuple)
|
||||
reset_obs, _ = wrapped_env.reset()
|
||||
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||
assert len(reset_obs) == 3 and len(step_obs) == 3
|
||||
|
||||
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||
wrapped_env = TimeAwareObservationV0(env)
|
||||
assert isinstance(wrapped_env.observation_space, Dict)
|
||||
reset_obs, _ = wrapped_env.reset()
|
||||
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||
assert isinstance(reset_obs, dict) and isinstance(step_obs, dict)
|
||||
assert "obs" in reset_obs and "obs" in step_obs
|
||||
assert "time" in reset_obs and "time" in step_obs
|
||||
|
||||
# Tests the flatten parameter
|
||||
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||
wrapped_env = TimeAwareObservationV0(env, flatten=True)
|
||||
assert isinstance(wrapped_env.observation_space, Box)
|
||||
reset_obs, _ = wrapped_env.reset()
|
||||
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||
assert reset_obs.shape == (2,) and step_obs.shape == (2,)
|
||||
|
||||
# Tests the normalize_time parameter
|
||||
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||
wrapped_env = TimeAwareObservationV0(env, normalize_time=False)
|
||||
reset_obs, _ = wrapped_env.reset()
|
||||
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||
assert reset_obs["time"] == 100 and step_obs["time"] == 99
|
||||
|
||||
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||
wrapped_env = TimeAwareObservationV0(env, normalize_time=True)
|
||||
reset_obs, _ = wrapped_env.reset()
|
||||
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||
assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01
|
||||
|
||||
|
||||
def test_delay_observation_wrapper():
|
||||
env = gym.make("CartPole-v1")
|
||||
env.action_space.seed(SEED)
|
||||
env.reset(seed=SEED)
|
||||
|
||||
undelayed_observations = []
|
||||
for _ in range(NUM_STEPS):
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
undelayed_observations.append(obs)
|
||||
|
||||
env = DelayObservationV0(env, delay=DELAY)
|
||||
env.action_space.seed(SEED)
|
||||
env.reset(seed=SEED)
|
||||
|
||||
delayed_observations = []
|
||||
for i in range(NUM_STEPS):
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
delayed_observations.append(obs)
|
||||
if i < DELAY - 1:
|
||||
assert np.all(obs == 0)
|
||||
|
||||
undelayed_observations = np.array(undelayed_observations)
|
||||
delayed_observations = np.array(delayed_observations)
|
||||
|
||||
assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY])
|
@@ -1,99 +0,0 @@
|
||||
"""Test suite for TimeAwareobservationV0."""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
||||
from gymnasium.spaces import Box, Dict
|
||||
|
||||
|
||||
NUM_STEPS = 20
|
||||
SEED = 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env",
|
||||
[
|
||||
gym.make("CartPole-v1", disable_env_checker=True),
|
||||
gym.make("CarRacing-v2", disable_env_checker=True),
|
||||
],
|
||||
)
|
||||
def test_time_aware_observation_creation(env):
|
||||
"""Test TimeAwareObservationV0 wrapper creation.
|
||||
|
||||
This test checks if wrapped env with TimeAwareObservationV0
|
||||
is correctly created.
|
||||
"""
|
||||
wrapped_env = TimeAwareObservationV0(env)
|
||||
obs, _ = wrapped_env.reset()
|
||||
|
||||
assert isinstance(wrapped_env.observation_space, Dict)
|
||||
assert isinstance(obs, OrderedDict)
|
||||
assert np.all(obs["time"] == 0)
|
||||
assert env.observation_space == wrapped_env.observation_space["obs"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("normalize_time", [True, False])
|
||||
@pytest.mark.parametrize("flatten", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"env",
|
||||
[
|
||||
gym.make("CartPole-v1", disable_env_checker=True),
|
||||
gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
|
||||
],
|
||||
)
|
||||
def test_time_aware_observation_step(env, flatten, normalize_time):
|
||||
"""Test TimeAwareObservationV0 step.
|
||||
|
||||
This test checks if wrapped env with TimeAwareObservationV0
|
||||
steps correctly.
|
||||
"""
|
||||
env.action_space.seed(SEED)
|
||||
max_timesteps = env._max_episode_steps
|
||||
|
||||
wrapped_env = TimeAwareObservationV0(
|
||||
env, flatten=flatten, normalize_time=normalize_time
|
||||
)
|
||||
wrapped_env.reset(seed=SEED)
|
||||
|
||||
for timestep in range(1, NUM_STEPS):
|
||||
action = env.action_space.sample()
|
||||
observation, _, terminated, _, _ = wrapped_env.step(action)
|
||||
|
||||
expected_time_obs = (
|
||||
timestep / max_timesteps if normalize_time else max_timesteps - timestep
|
||||
)
|
||||
|
||||
if flatten:
|
||||
assert np.allclose(observation[-1], expected_time_obs)
|
||||
else:
|
||||
assert np.allclose(observation["time"], expected_time_obs)
|
||||
|
||||
if terminated:
|
||||
break
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env",
|
||||
[
|
||||
gym.make("CartPole-v1", disable_env_checker=True),
|
||||
gym.make("CarRacing-v2", disable_env_checker=True),
|
||||
],
|
||||
)
|
||||
def test_time_aware_observation_creation_flatten(env):
|
||||
"""Test TimeAwareObservationV0 wrapper creation with `flatten=True`.
|
||||
|
||||
This test checks if wrapped env with TimeAwareObservationV0
|
||||
is correctly created when the `flatten` parameter is set to `True`.
|
||||
When flattened, the observation space should be a 1 dimension `Box`
|
||||
with time appended to the end.
|
||||
"""
|
||||
wrapped_env = TimeAwareObservationV0(env, flatten=True)
|
||||
obs, _ = wrapped_env.reset()
|
||||
|
||||
assert isinstance(wrapped_env.observation_space, Box)
|
||||
assert isinstance(obs, np.ndarray)
|
||||
assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]
|
@@ -9,6 +9,7 @@ from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
def torch_data_equivalence(data_1, data_2) -> bool:
|
||||
"""Return if two variables are equivalent that might contain ``torch.Tensor``."""
|
||||
if type(data_1) == type(data_2):
|
||||
if isinstance(data_1, dict):
|
||||
return data_1.keys() == data_2.keys() and all(
|
||||
@@ -56,14 +57,15 @@ def torch_data_equivalence(data_1, data_2) -> bool:
|
||||
)
|
||||
def test_roundtripping(value, expected_value):
|
||||
"""We test torch -> jax -> torch as this is the direction in the JaxToTorch wrapper."""
|
||||
assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value)
|
||||
roundtripped_value = jax_to_torch(torch_to_jax(value))
|
||||
assert torch_data_equivalence(roundtripped_value, expected_value)
|
||||
|
||||
|
||||
def jax_reset_func(self, seed=None, options=None):
|
||||
def _jax_reset_func(self, seed=None, options=None):
|
||||
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
|
||||
|
||||
|
||||
def jax_step_func(self, action):
|
||||
def _jax_step_func(self, action):
|
||||
assert isinstance(action, jnp.DeviceArray), type(action)
|
||||
return (
|
||||
jnp.array([1, 2, 3]),
|
||||
@@ -75,7 +77,7 @@ def jax_step_func(self, action):
|
||||
|
||||
|
||||
def test_jax_to_torch():
|
||||
env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
|
||||
env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)
|
||||
|
||||
# Check that the reset and step for jax environment are as expected
|
||||
obs, info = env.reset()
|
||||
|
@@ -278,7 +278,7 @@ def test_wrapper_types():
|
||||
obs, _, _, _, _ = observation_env.step(0)
|
||||
assert obs == np.array([1])
|
||||
|
||||
env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {}))
|
||||
env = GenericTestEnv(step_func=lambda self, action: (action, 0, False, False, {}))
|
||||
action_env = ExampleActionWrapper(env)
|
||||
obs, _, _, _, _ = action_env.step(0)
|
||||
assert obs == np.array([1])
|
||||
|
@@ -8,7 +8,7 @@ from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
|
||||
def basic_reset_fn(
|
||||
def basic_reset_func(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
@@ -20,17 +20,17 @@ def basic_reset_fn(
|
||||
return self.observation_space.sample(), {"options": options}
|
||||
|
||||
|
||||
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||
def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||
"""A step function that follows the new step api that will pass the environment check using random observations sampled from the observation space."""
|
||||
return self.observation_space.sample(), 0, False, False, {}
|
||||
|
||||
|
||||
def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||
def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||
"""A step function that follows the old step api that will pass the environment check using random observations sampled from the observation space."""
|
||||
return self.observation_space.sample(), 0, False, {}
|
||||
|
||||
|
||||
def basic_render_fn(self):
|
||||
def basic_render_func(self):
|
||||
"""Basic render fn that does nothing."""
|
||||
pass
|
||||
|
||||
@@ -43,12 +43,14 @@ class GenericTestEnv(gym.Env):
|
||||
self,
|
||||
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||
reset_fn: callable = basic_reset_fn,
|
||||
step_fn: callable = new_step_fn,
|
||||
render_fn: callable = basic_render_fn,
|
||||
reset_func: callable = basic_reset_func,
|
||||
step_func: callable = new_step_func,
|
||||
render_func: callable = basic_render_func,
|
||||
metadata: Dict[str, Any] = {"render_modes": []},
|
||||
render_mode: Optional[str] = None,
|
||||
spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"),
|
||||
spec: EnvSpec = EnvSpec(
|
||||
"TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100
|
||||
),
|
||||
):
|
||||
self.metadata = metadata
|
||||
self.render_mode = render_mode
|
||||
@@ -59,12 +61,12 @@ class GenericTestEnv(gym.Env):
|
||||
if action_space is not None:
|
||||
self.action_space = action_space
|
||||
|
||||
if reset_fn is not None:
|
||||
self.reset = types.MethodType(reset_fn, self)
|
||||
if step_fn is not None:
|
||||
self.step = types.MethodType(step_fn, self)
|
||||
if render_fn is not None:
|
||||
self.render = types.MethodType(render_fn, self)
|
||||
if reset_func is not None:
|
||||
self.reset = types.MethodType(reset_func, self)
|
||||
if step_func is not None:
|
||||
self.step = types.MethodType(step_func, self)
|
||||
if render_func is not None:
|
||||
self.render = types.MethodType(render_func, self)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
|
@@ -112,10 +112,10 @@ def test_check_reset_seed(test, func: callable, message: str):
|
||||
with pytest.warns(
|
||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||
):
|
||||
check_reset_seed(GenericTestEnv(reset_fn=func))
|
||||
check_reset_seed(GenericTestEnv(reset_func=func))
|
||||
else:
|
||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||
check_reset_seed(GenericTestEnv(reset_fn=func))
|
||||
check_reset_seed(GenericTestEnv(reset_func=func))
|
||||
|
||||
|
||||
def _deprecated_return_info(
|
||||
@@ -179,7 +179,7 @@ def test_check_reset_return_type(test, func: callable, message: str):
|
||||
"""Tests the check `env.reset()` function has a correct return type."""
|
||||
|
||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||
check_reset_return_type(GenericTestEnv(reset_fn=func))
|
||||
check_reset_return_type(GenericTestEnv(reset_func=func))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -198,7 +198,7 @@ def test_check_reset_return_info_deprecation(test, func: callable, message: str)
|
||||
"""Tests that return_info has been correctly deprecated as an argument to `env.reset()`."""
|
||||
|
||||
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
|
||||
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func))
|
||||
check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))
|
||||
|
||||
|
||||
def test_check_seed_deprecation():
|
||||
@@ -236,7 +236,7 @@ def test_check_reset_options():
|
||||
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
|
||||
),
|
||||
):
|
||||
check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {})))
|
||||
check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@@ -303,11 +303,11 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
|
||||
with pytest.warns(
|
||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||
):
|
||||
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
|
||||
env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
|
||||
else:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
|
||||
env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
|
||||
assert len(caught_warnings) == 0
|
||||
|
||||
|
||||
@@ -383,11 +383,11 @@ def test_passive_env_step_checker(
|
||||
with pytest.warns(
|
||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||
):
|
||||
env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
|
||||
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
|
||||
else:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||
env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
|
||||
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
|
||||
assert len(caught_warnings) == 0, caught_warnings
|
||||
|
||||
|
||||
@@ -416,7 +416,7 @@ def test_passive_env_step_checker(
|
||||
GenericTestEnv(
|
||||
metadata={"render_modes": ["Testing mode"], "render_fps": None},
|
||||
render_mode="Testing mode",
|
||||
render_fn=lambda self: 0,
|
||||
render_func=lambda self: 0,
|
||||
),
|
||||
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
|
||||
],
|
||||
|
@@ -21,7 +21,7 @@ IRRELEVANT_KEY = 1
|
||||
PlayableEnv = partial(
|
||||
GenericTestEnv,
|
||||
metadata={"render_modes": ["rgb_array"]},
|
||||
render_fn=lambda self: np.ones((10, 10, 3)),
|
||||
render_func=lambda self: np.ones((10, 10, 3)),
|
||||
)
|
||||
|
||||
|
||||
|
@@ -82,8 +82,8 @@ def test_final_obs_info(vectoriser):
|
||||
return GenericTestEnv(
|
||||
action_space=Discrete(4),
|
||||
observation_space=Discrete(4),
|
||||
reset_fn=reset_fn,
|
||||
step_fn=lambda self, action: (
|
||||
reset_func=reset_fn,
|
||||
step_func=lambda self, action: (
|
||||
action if action < 3 else 0,
|
||||
0,
|
||||
action >= 3,
|
||||
|
@@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
from gymnasium.spaces import Box, Discrete
|
||||
from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility
|
||||
from tests.testing_env import GenericTestEnv, old_step_fn
|
||||
from tests.testing_env import GenericTestEnv, old_step_func
|
||||
|
||||
|
||||
class AleTesting:
|
||||
@@ -34,7 +34,7 @@ class AtariTestingEnv(GenericTestEnv):
|
||||
low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1
|
||||
),
|
||||
action_space=Discrete(3, seed=1),
|
||||
step_fn=old_step_fn,
|
||||
step_func=old_step_func,
|
||||
)
|
||||
self.ale = AleTesting()
|
||||
|
||||
|
@@ -68,8 +68,8 @@ def _step_failure(self, action):
|
||||
|
||||
def test_api_failures():
|
||||
env = GenericTestEnv(
|
||||
reset_fn=_reset_failure,
|
||||
step_fn=_step_failure,
|
||||
reset_func=_reset_failure,
|
||||
step_func=_step_failure,
|
||||
metadata={"render_modes": "error"},
|
||||
)
|
||||
env = PassiveEnvChecker(env)
|
||||
|
Reference in New Issue
Block a user