Add wrappers to experimental (#201)

Mark Towers
2022-12-10 22:04:14 +00:00
committed by GitHub
parent 93ee100987
commit f208f874a0
61 changed files with 1987 additions and 659 deletions

View File

@@ -51,7 +51,7 @@ repos:
rev: 6.1.1
hooks:
- id: pydocstyle
exclude: ^(gymnasium/envs/)|(tests/)|(docs/)
exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/)
args:
- --source
- --explain

View File

@@ -14,7 +14,9 @@ experimental/vector_wrappers
## Functional Environments
```{eval-rst}
The gymnasium ``Env`` provides high flexibility for implementing individual environments; however, this can complicate parallelising environments. Therefore, we propose :class:`gymnasium.experimental.FuncEnv`, where each part of the environment has its own function.
```
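To make the proposal concrete, below is a minimal sketch of a functional environment. The method names (``transition``, ``observation``, ``reward``, ``terminal``) follow the CartPole/Pendulum functional environments later in this PR; the ``CounterFunctional`` class, its dynamics, and the ``initial`` hook name are hypothetical.
```python
import numpy as np

from gymnasium.experimental import FuncEnv  # import path as documented above


class CounterFunctional(FuncEnv):
    """Hypothetical toy env: the state is a counter, the episode ends at 10."""

    def initial(self, rng):
        # assumed hook for producing the initial state, mirroring the jax envs
        return np.zeros(())

    def transition(self, state, action, rng=None):
        return state + action  # a pure function of (state, action)

    def observation(self, state):
        return state

    def reward(self, state, action, next_state):
        return 1.0

    def terminal(self, state):
        return bool(state >= 10)
```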
## Wrappers
@@ -36,64 +38,32 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* - Old name
- New name
- Vector version
- Tree structure
* - :class:`wrappers.TransformObservation`
- :class:`experimental.wrappers.LambdaObservationV0`
- VectorLambdaObservation
- No
* - :class:`wrappers.FilterObservation`
- :class:`experimental.wrappers.FilterObservationV0`
- VectorFilterObservation (*)
- Yes
* - :class:`wrappers.FlattenObservation`
- :class:`experimental.wrappers.FlattenObservationV0`
- VectorFlattenObservation (*)
- No
* - :class:`wrappers.GrayScaleObservation`
- :class:`experimental.wrappers.GrayscaleObservationV0`
- VectorGrayscaleObservation (*)
- Yes
* - :class:`wrappers.ResizeObservation`
- :class:`experimental.wrappers.ResizeObservationV0`
- VectorResizeObservation (*)
- Yes
* - Not Implemented
* - ``supersuit.reshape_v0``
- :class:`experimental.wrappers.ReshapeObservationV0`
- VectorReshapeObservation (*)
- Yes
* - Not Implemented
- :class:`experimental.wrappers.RescaleObservationV0`
- VectorRescaleObservation (*)
- Yes
* - Not Implemented
* - ``supersuit.dtype_v0``
- :class:`experimental.wrappers.DtypeObservationV0`
- VectorDtypeObservation (*)
- Yes
* - :class:`wrappers.PixelObservationWrapper`
- PixelObservation
- VectorPixelObservation
- No
- :class:`experimental.wrappers.PixelObservationV0`
* - :class:`wrappers.NormalizeObservation`
- NormalizeObservation
- VectorNormalizeObservation
- No
- :class:`experimental.wrappers.NormalizeObservationV0`
* - :class:`wrappers.TimeAwareObservation`
- :class:`experimental.wrappers.TimeAwareObservationV0`
- VectorTimeAwareObservation
- No
* - :class:`wrappers.FrameStack`
- FrameStackObservation
- VectorFrameStackObservation
- No
* - Not Implemented
- :class:`experimental.wrappers.FrameStackObservationV0`
* - ``supersuit.delay_observations_v0``
- :class:`experimental.wrappers.DelayObservationV0`
- VectorDelayObservation
- No
* - :class:`wrappers.AtariPreprocessing`
- AtariPreprocessing
- Not Implemented
- No
```
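For example, migrating from ``wrappers.TransformObservation`` to the lambda-style replacement might look like the sketch below; the ``(env, func, observation_space)`` constructor signature is inferred from the ``DtypeObservationV0`` implementation later in this diff.
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaObservationV0

env = gym.make("CartPole-v1")
# double every observation (reusing the original space here for brevity)
env = LambdaObservationV0(env, lambda obs: 2 * obs, env.observation_space)
obs, info = env.reset(seed=0)
```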
### Action Wrappers
@@ -105,24 +75,14 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* - Old name
- New name
- Vector version
- Tree structure
* - Not Implemented
* - ``supersuit.action_lambda_v1``
- :class:`experimental.wrappers.LambdaActionV0`
- VectorLambdaAction
- No
* - :class:`wrappers.ClipAction`
- :class:`experimental.wrappers.ClipActionV0`
- VectorClipAction (*)
- Yes
* - :class:`wrappers.RescaleAction`
- :class:`experimental.wrappers.RescaleActionV0`
- VectorRescaleAction (*)
- Yes
* - Not Implemented
* - ``supersuit.sticky_actions_v0``
- :class:`experimental.wrappers.StickyActionV0`
- VectorStickyAction
- No
```
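A usage sketch, assuming ``ClipActionV0`` keeps the single-argument signature of the old ``wrappers.ClipAction``:
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import ClipActionV0

env = gym.make("MountainCarContinuous-v0")
env = ClipActionV0(env)  # out-of-bounds actions are clipped to the action space
```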
### Reward Wrappers
@@ -134,19 +94,12 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* - Old name
- New name
- Vector version
* - :class:`wrappers.TransformReward`
- :class:`experimental.wrappers.LambdaRewardV0`
- VectorLambdaReward
* - Not Implemented
* - ``supersuit.clip_reward_v0``
- :class:`experimental.wrappers.ClipRewardV0`
- VectorClipReward (*)
* - Not Implemented
- RescaleReward
- VectorRescaleReward (*)
* - :class:`wrappers.NormalizeReward`
- NormalizeReward
- VectorNormalizeReward
- :class:`experimental.wrappers.NormalizeRewardV0`
```
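For example, clipping rewards with the new wrapper; the ``min_reward``/``max_reward`` parameter names are taken from the ``ClipRewardV0`` implementation later in this diff.
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import ClipRewardV0

env = gym.make("CartPole-v1")
env = ClipRewardV0(env, min_reward=-1.0, max_reward=1.0)
```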
### Common Wrappers
@@ -159,37 +112,21 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* - Old name
- New name
- Vector version
* - :class:`wrappers.AutoResetWrapper`
- AutoReset
- VectorAutoReset
- :class:`experimental.wrappers.AutoresetV0`
* - :class:`wrappers.PassiveEnvChecker`
- PassiveEnvChecker
- VectorPassiveEnvChecker
- :class:`experimental.wrappers.PassiveEnvCheckerV0`
* - :class:`wrappers.OrderEnforcing`
- OrderEnforcing
- VectorOrderEnforcing
- :class:`experimental.wrappers.OrderEnforcingV0`
* - :class:`wrappers.EnvCompatibility`
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
- Not Implemented
* - :class:`wrappers.RecordEpisodeStatistics`
- RecordEpisodeStatistics
- VectorRecordEpisodeStatistics
* - :class:`wrappers.RenderCollection`
- RenderCollection
- VectorRenderCollection
* - :class:`wrappers.HumanRendering`
- HumanRendering
- Not Implemented
* - Not Implemented
- :class:`experimental.wrappers.JaxToNumpyV0`
- VectorJaxToNumpy (*)
* - Not Implemented
- :class:`experimental.wrappers.JaxToTorchV0`
- VectorJaxToTorch (*)
- :class:`experimental.wrappers.RecordEpisodeStatisticsV0`
* - :class:`wrappers.AtariPreprocessing`
- :class:`experimental.wrappers.AtariPreprocessingV0`
```
### Vector Only Wrappers
### Rendering Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
@@ -199,8 +136,22 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* - Old name
- New name
* - :class:`wrappers.VectorListInfo`
- VectorListInfo
* - :class:`wrappers.RecordVideo`
- :class:`experimental.wrappers.RecordVideoV0`
* - :class:`wrappers.HumanRendering`
- :class:`experimental.wrappers.HumanRenderingV0`
* - :class:`wrappers.RenderCollection`
- :class:`experimental.wrappers.RenderCollectionV0`
```
### Environment data conversion
```{eval-rst}
.. py:currentmodule:: gymnasium
* :class:`experimental.wrappers.JaxToNumpyV0`
* :class:`experimental.wrappers.JaxToTorchV0`
* :class:`experimental.wrappers.NumpyToTorchV0`
```
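A sketch of interacting with a Jax-based environment through NumPy arrays, assuming the ``CartPoleJaxEnv`` added elsewhere in this PR:
```python
from gymnasium.envs.phys2d.cartpole import CartPoleJaxEnv
from gymnasium.experimental.wrappers import JaxToNumpyV0

env = JaxToNumpyV0(CartPoleJaxEnv())
obs, info = env.reset(seed=0)  # obs is now a NumPy array rather than a jax array
```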
## Vector Environment

View File

@@ -11,8 +11,12 @@
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.PixelObservationV0
.. autoclass:: gymnasium.experimental.wrappers.NormalizeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
.. autoclass:: gymnasium.experimental.wrappers.FrameStackObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
.. autoclass:: gymnasium.experimental.wrappers.AtariPreprocessingV0
```
## Action Wrappers
@@ -24,16 +28,34 @@
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
```
# Reward Wrappers
## Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV0
```
## Common Wrappers
## Other Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.AutoresetV0
.. autoclass:: gymnasium.experimental.wrappers.PassiveEnvCheckerV0
.. autoclass:: gymnasium.experimental.wrappers.OrderEnforcingV0
.. autoclass:: gymnasium.experimental.wrappers.RecordEpisodeStatisticsV0
```
## Rendering Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.RecordVideoV0
.. autoclass:: gymnasium.experimental.wrappers.HumanRenderingV0
.. autoclass:: gymnasium.experimental.wrappers.RenderCollectionV0
```
## Environment data conversion
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
.. autoclass:: gymnasium.experimental.wrappers.NumpyToTorchV0
```

View File

@@ -356,7 +356,7 @@ class Wrapper(Env[WrapperObsType, WrapperActType]):
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.step(action)

View File

@@ -1,2 +1,3 @@
"""Module for 2d physics environments with functional and environment implementations."""
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv

View File

@@ -1,8 +1,7 @@
"""
Implementation of a Jax-accelerated cartpole environment.
"""
"""Implementation of a Jax-accelerated cartpole environment."""
from __future__ import annotations
from typing import Optional, Tuple, Union
from typing import Any, Tuple
import jax
import jax.numpy as jnp
@@ -74,7 +73,7 @@ class CartPoleFunctional(
)
def transition(
self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
self, state: jnp.ndarray, action: int | jnp.ndarray, rng: None = None
) -> StateType:
"""Cartpole transition."""
x, x_dot, theta, theta_dot = state
@@ -106,6 +105,7 @@ class CartPoleFunctional(
return state
def terminal(self, state: jnp.ndarray) -> jnp.ndarray:
"""Checks if the state is terminal."""
x, _, theta, _ = state
terminated = (
@@ -120,6 +120,7 @@ class CartPoleFunctional(
def reward(
self, state: StateType, action: ActType, next_state: StateType
) -> jnp.ndarray:
"""Computes the reward for the state transition using the action."""
x, _, theta, _ = state
terminated = (
@@ -136,8 +137,8 @@ class CartPoleFunctional(
self,
state: StateType,
render_state: RenderStateType,
) -> Tuple[RenderStateType, np.ndarray]:
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an image of the state using the render state."""
try:
import pygame
from pygame import gfxdraw
@@ -210,6 +211,7 @@ class CartPoleFunctional(
def render_init(
self, screen_width: int = 600, screen_height: int = 400
) -> RenderStateType:
"""Initialises the render state for a screen width and height."""
try:
import pygame
except ImportError as e:
@@ -224,6 +226,7 @@ class CartPoleFunctional(
return screen, clock
def render_close(self, render_state: RenderStateType) -> None:
"""Closes the render state."""
try:
import pygame
except ImportError as e:
@@ -235,20 +238,24 @@ class CartPoleFunctional(
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
def __init__(self, render_mode: Optional[str] = None, **kwargs):
def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = CartPoleFunctional(**kwargs)
env.transform(jax.jit)
action_space = env.action_space
observation_space = env.observation_space
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
super().__init__(
env,
observation_space=observation_space,
action_space=action_space,
metadata=metadata,
metadata=self.metadata,
render_mode=render_mode,
)
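A minimal instantiation sketch, assuming ``jax`` is installed and that ``FunctionalJaxEnv`` exposes the standard ``reset``/``step`` API:
```python
from gymnasium.envs.phys2d.cartpole import CartPoleJaxEnv

env = CartPoleJaxEnv(render_mode="rgb_array")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(1)  # push cart to the right
```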

View File

@@ -1,8 +1,8 @@
"""
Implementation of a Jax-accelerated pendulum environment.
"""
"""Implementation of a Jax-accelerated pendulum environment."""
from __future__ import annotations
from os import path
from typing import Optional, Tuple, Union
from typing import Any, Optional, Tuple
import jax
import jax.numpy as jnp
@@ -22,7 +22,7 @@ RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]]
class PendulumFunctional(
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
):
"""Pendulum but in jax and functional."""
"""Pendulum but in jax and functional structure."""
max_speed = 8
max_torque = 2.0
@@ -44,7 +44,7 @@ class PendulumFunctional(
return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)
def transition(
self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
self, state: jnp.ndarray, action: int | jnp.ndarray, rng: None = None
) -> jnp.ndarray:
"""Pendulum transition."""
th, thdot = state # th := theta
@@ -65,10 +65,12 @@ class PendulumFunctional(
return new_state
def observation(self, state: jnp.ndarray) -> jnp.ndarray:
"""Generates an observation based on the state."""
theta, thetadot = state
return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot])
def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
"""Generates the reward based on the state, action and next state."""
th, thdot = state # th := theta
u = action
@@ -80,13 +82,15 @@ class PendulumFunctional(
return -costs
def terminal(self, state: StateType) -> bool:
"""Determines if the state is a terminal state."""
return False
def render_image(
self,
state: StateType,
render_state: Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]], # type: ignore # noqa: F821
) -> Tuple[RenderStateType, np.ndarray]:
render_state: tuple[pygame.Surface, pygame.time.Clock, float | None], # type: ignore # noqa: F821
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an RGB image."""
try:
import pygame
from pygame import gfxdraw
@@ -159,6 +163,7 @@ class PendulumFunctional(
def render_init(
self, screen_width: int = 600, screen_height: int = 400
) -> RenderStateType:
"""Initialises the render state."""
try:
import pygame
except ImportError as e:
@@ -172,7 +177,8 @@ class PendulumFunctional(
return screen, clock, None
def render_close(self, render_state: RenderStateType) -> None:
def render_close(self, render_state: RenderStateType):
"""Closes the render state."""
try:
import pygame
except ImportError as e:
@@ -184,21 +190,24 @@ class PendulumFunctional(
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
def __init__(self, render_mode: Optional[str] = None, **kwargs):
def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = PendulumFunctional(**kwargs)
env.transform(jax.jit)
action_space = env.action_space
observation_space = env.observation_space
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
super().__init__(
env,
observation_space=observation_space,
action_space=action_space,
metadata=metadata,
metadata=self.metadata,
render_mode=render_mode,
)

View File

@@ -1,3 +1,6 @@
"""Functions for registering environments within gymnasium using public functions ``make``, ``register`` and ``spec``."""
from __future__ import annotations
import contextlib
import copy
import difflib
@@ -8,18 +11,7 @@ import sys
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
SupportsFloat,
Tuple,
Union,
overload,
)
from typing import Any, Callable, Iterable, Sequence, SupportsFloat, overload
import numpy as np
@@ -53,7 +45,7 @@ ENV_ID_RE = re.compile(
def load(name: str) -> callable:
"""Loads an environment with name and returns an environment creation function
"""Loads an environment with name and returns an environment creation function.
Args:
name: The environment name
@@ -67,7 +59,7 @@ def load(name: str) -> callable:
return fn
def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
def parse_env_id(id: str) -> tuple[str | None, str, int | None]:
"""Parse environment ID string format.
This format is true today, but it's *not* an official spec.
@@ -98,7 +90,7 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
return namespace, name, version
def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
def get_env_id(ns: str | None, name: str, version: int | None) -> str:
"""Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`.
Args:
@@ -109,7 +101,6 @@ def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
Returns:
The environment id
"""
full_name = name
if version is not None:
full_name += f"-v{version}"
@@ -134,14 +125,14 @@ class EnvSpec:
"""
id: str
entry_point: Union[Callable, str]
entry_point: Callable | str
# Environment attributes
reward_threshold: Optional[float] = field(default=None)
reward_threshold: float | None = field(default=None)
nondeterministic: bool = field(default=False)
# Wrappers
max_episode_steps: Optional[int] = field(default=None)
max_episode_steps: int | None = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
@@ -151,20 +142,22 @@ class EnvSpec:
kwargs: dict = field(default_factory=dict)
# post-init attributes
namespace: Optional[str] = field(init=False)
namespace: str | None = field(init=False)
name: str = field(init=False)
version: Optional[int] = field(init=False)
version: int | None = field(init=False)
def __post_init__(self):
"""Calls after the spec is created to extract the namespace, name and version from the id."""
# Initialize namespace, name, version
self.namespace, self.name, self.version = parse_env_id(self.id)
def make(self, **kwargs) -> Env:
def make(self, **kwargs: Any) -> Env:
"""Calls ``make`` using the environment spec and any keyword arguments."""
# For compatibility purposes
return make(self, **kwargs)
def _check_namespace_exists(ns: Optional[str]):
def _check_namespace_exists(ns: str | None):
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
if ns is None:
return
@@ -186,7 +179,7 @@ def _check_namespace_exists(ns: Optional[str]):
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
def _check_name_exists(ns: Optional[str], name: str):
def _check_name_exists(ns: str | None, name: str):
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
_check_namespace_exists(ns)
names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}
@@ -203,8 +196,9 @@ def _check_name_exists(ns: Optional[str], name: str):
)
def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
def _check_version_exists(ns: str | None, name: str, version: int | None):
"""Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
This is a complete test whether an environment identifier is valid, and will provide the best available hints.
Args:
@@ -258,8 +252,9 @@ def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
)
def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
version: List[int] = [
def find_highest_version(ns: str | None, name: str) -> int | None:
"""Finds the highest registered version of the environment in the registry."""
version: list[int] = [
spec_.version
for spec_ in registry.values()
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
@@ -268,6 +263,11 @@ def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
"""Load modules (plugins) using the gymnasium entry points == to `entry_points`.
Args:
entry_point: The string for the entry point.
"""
# Load third-party environments
for plugin in metadata.entry_points(group=entry_point):
# Python 3.8 doesn't support plugin.module, plugin.attr
@@ -323,37 +323,37 @@ def make(id: EnvSpec, **kwargs) -> Env: ...
# Classic control
# ----------------------------------------
@overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Box2d
# ----------------------------------------
@overload
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
# Toy Text
# ----------------------------------------
@overload
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Mujoco
@@ -376,8 +376,8 @@ def make(id: Literal[
# Global registry of environments. Meant to be accessed through `register` and `make`
registry: Dict[str, EnvSpec] = {}
current_namespace: Optional[str] = None
registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
def _check_spec_register(spec: EnvSpec):
@@ -445,6 +445,7 @@ def _check_metadata(metadata_: dict):
@contextlib.contextmanager
def namespace(ns: str):
"""Context manager for modifying the current namespace."""
global current_namespace
old_namespace = current_namespace
current_namespace = ns
@@ -454,10 +455,10 @@ def namespace(ns: str):
def register(
id: str,
entry_point: Union[Callable, str],
reward_threshold: Optional[float] = None,
entry_point: Callable | str,
reward_threshold: float | None = None,
nondeterministic: bool = False,
max_episode_steps: Optional[int] = None,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
disable_env_checker: bool = False,
@@ -521,11 +522,11 @@ def register(
def make(
id: Union[str, EnvSpec],
max_episode_steps: Optional[int] = None,
id: str | EnvSpec,
max_episode_steps: int | None = None,
autoreset: bool = False,
apply_api_compatibility: Optional[bool] = None,
disable_env_checker: Optional[bool] = None,
apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None,
**kwargs,
) -> Env:
"""Create an environment according to the given ID.
@@ -706,9 +707,9 @@ def spec(env_id: str) -> EnvSpec:
def pprint_registry(
_registry: dict = registry,
num_cols: int = 3,
exclude_namespaces: Optional[List[str]] = None,
exclude_namespaces: list[str] | None = None,
disable_print: bool = False,
) -> Optional[str]:
) -> str | None:
"""Pretty print the environments in the registry.
Args:
@@ -718,7 +719,6 @@ def pprint_registry(
disable_print: Whether to return a string of all the namespaces and environment IDs
instead of printing it to console.
"""
# Defaultdict to store environment names according to namespace.
namespace_envs = defaultdict(lambda: [])
max_justify = float("-inf")

View File

@@ -19,14 +19,34 @@ from gymnasium.experimental.wrappers.lambda_observations import (
ReshapeObservationV0,
RescaleObservationV0,
DtypeObservationV0,
PixelObservationV0,
NormalizeObservationV0,
)
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
from gymnasium.experimental.wrappers.lambda_reward import (
ClipRewardV0,
LambdaRewardV0,
NormalizeRewardV0,
)
from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0
from gymnasium.experimental.wrappers.jax_to_torch import JaxToTorchV0
from gymnasium.experimental.wrappers.numpy_to_torch import NumpyToTorchV0
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
from gymnasium.experimental.wrappers.stateful_observation import (
TimeAwareObservationV0,
DelayObservationV0,
FrameStackObservationV0,
)
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
from gymnasium.experimental.wrappers.common import (
PassiveEnvCheckerV0,
OrderEnforcingV0,
AutoresetV0,
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.rendering import (
RenderCollectionV0,
RecordVideoV0,
HumanRenderingV0,
)
__all__ = [
@@ -39,12 +59,12 @@ __all__ = [
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
# "PixelObservationV0",
# "NormalizeObservationV0",
"PixelObservationV0",
"NormalizeObservationV0",
"TimeAwareObservationV0",
# "FrameStackV0",
"FrameStackObservationV0",
"DelayObservationV0",
# "AtariPreprocessingV0"
"AtariPreprocessingV0",
# --- Action Wrappers ---
"LambdaActionV0",
"ClipActionV0",
@@ -54,15 +74,18 @@ __all__ = [
# --- Reward wrappers ---
"LambdaRewardV0",
"ClipRewardV0",
# "RescaleRewardV0",
# "NormalizeRewardV0",
"NormalizeRewardV0",
# --- Common ---
# "AutoReset",
# "PassiveEnvChecker",
# "OrderEnforcing",
# "RecordEpisodeStatistics",
# "RenderCollection",
# "HumanRendering",
"AutoresetV0",
"PassiveEnvCheckerV0",
"OrderEnforcingV0",
"RecordEpisodeStatisticsV0",
# --- Rendering ---
"RenderCollectionV0",
"RecordVideoV0",
"HumanRenderingV0",
# --- Data Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]

View File

@@ -0,0 +1,193 @@
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
try:
import cv2
except ImportError:
cv2 = None
class AtariPreprocessingV0(gym.Wrapper):
"""Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018),
"Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".
Specifically, the following preprocessing stages apply to the Atari environment:
- Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
- Frame skipping: The number of frames skipped between steps, 4 by default
- Max-pooling: Pools over the most recent two observations from the frame skips
- Termination signal when a life is lost: When the agent loses a life during the episode, the environment terminates.
Turned off by default. Not recommended by Machado et al. (2018).
- Resize to a square image: Resizes the Atari environment's original observation shape from 210x160 to 84x84 by default
- Grayscale observation: Whether the observation is grayscale or colour; grayscale by default.
- Scale observation: Whether to scale the observation to the range [0, 1); by default, observations are not scaled and remain in [0, 255].
"""
def __init__(
self,
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
):
"""Wrapper for Atari 2600 preprocessing.
Args:
env (Env): The environment to apply the preprocessing
noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
frame_skip (int): The number of frames between new observations; this affects the frequency at which the agent experiences the game.
screen_size (int): The size to which the Atari frame is resized (screen_size x screen_size)
terminal_on_life_loss (bool): if True, then :meth:`step()` returns ``terminated=True`` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
is returned.
grayscale_newaxis (bool): if True and ``grayscale_obs=True``, then a channel axis is added to
grayscale observations to make them 3-dimensional.
scale_obs (bool): if True, then observations normalized to the range [0, 1) are returned. This also limits the memory
optimization benefits of the FrameStack wrapper.
Raises:
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
"""
super().__init__(env)
if cv2 is None:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
)
assert frame_skip > 0
assert screen_size > 0
assert noop_max >= 0
if frame_skip > 1:
if (
env.spec is not None
and "NoFrameskip" not in env.spec.id
and getattr(env.unwrapped, "_frameskip", None) != 1
):
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one "
"frame-skip will happen as through this wrapper"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
self.frame_skip = frame_skip
self.screen_size = screen_size
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
self.scale_obs = scale_obs
# buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
if grayscale_obs:
self.obs_buffer = [
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
]
else:
self.obs_buffer = [
np.empty(env.observation_space.shape, dtype=np.uint8),
np.empty(env.observation_space.shape, dtype=np.uint8),
]
self.lives = 0
self.game_over = False
_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
@property
def ale(self):
"""Make ale as a class property to avoid serialization error."""
return self.env.unwrapped.ale
def step(self, action):
"""Applies the preprocessing for an :meth:`env.step`."""
total_reward, terminated, truncated, info = 0.0, False, False, {}
for t in range(self.frame_skip):
_, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
self.game_over = terminated
if self.terminal_on_life_loss:
new_lives = self.ale.lives()
terminated = terminated or new_lives < self.lives
self.game_over = terminated
self.lives = new_lives
if terminated or truncated:
break
if t == self.frame_skip - 2:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[1])
else:
self.ale.getScreenRGB(self.obs_buffer[1])
elif t == self.frame_skip - 1:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), total_reward, terminated, truncated, info
def reset(self, **kwargs):
"""Resets the environment using preprocessing."""
# NoopReset
_, reset_info = self.env.reset(**kwargs)
noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
_, _, terminated, truncated, step_info = self.env.step(0)
reset_info.update(step_info)
if terminated or truncated:
_, reset_info = self.env.reset(**kwargs)
self.lives = self.ale.lives()
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0)
return self._get_obs(), reset_info
def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling
np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
assert cv2 is not None
obs = cv2.resize(
self.obs_buffer[0],
(self.screen_size, self.screen_size),
interpolation=cv2.INTER_AREA,
)
if self.scale_obs:
obs = np.asarray(obs, dtype=np.float32) / 255.0
else:
obs = np.asarray(obs, dtype=np.uint8)
if self.grayscale_obs and self.grayscale_newaxis:
obs = np.expand_dims(obs, axis=-1) # Add a channel axis
return obs
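A usage sketch; it assumes ``ale-py`` and the Atari ROMs are installed, and creates the base environment with ``frameskip=1`` so that this wrapper's frame-skipping is the only one applied, per the ``ValueError`` above:
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import AtariPreprocessingV0

env = gym.make("ALE/Breakout-v5", frameskip=1)
env = AtariPreprocessingV0(env, frame_skip=4, screen_size=84, grayscale_obs=True)
obs, info = env.reset(seed=0)
assert obs.shape == (84, 84)  # grayscale with no channel axis by default
```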

View File

@@ -0,0 +1,281 @@
"""A collection of common wrappers.
* ``AutoresetV0`` - Auto-resets the environment
* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
"""
from __future__ import annotations
import time
from collections import deque
from typing import Any, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)
class AutoresetV0(gym.Wrapper):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
def __init__(self, env: gym.Env):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
"""
super().__init__(env)
self._episode_ended: bool = False
self._reset_options: dict[str, Any] | None = None
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
Args:
action: The action to take
Returns:
The autoreset environment :meth:`step`
"""
if self._episode_ended:
obs, info = super().reset(options=self._reset_options)
self._episode_ended = False  # the freshly reset episode has not ended yet
return obs, 0, False, False, info
else:
obs, reward, terminated, truncated, info = super().step(action)
self._episode_ended = terminated or truncated
return obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment, saving the options used."""
self._episode_ended = False
self._reset_options = options
return super().reset(seed=seed, options=self._reset_options)
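A usage sketch of the autoreset behaviour implemented above: the step after an episode ends performs the reset and returns the reset observation with zero reward.
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import AutoresetV0

env = AutoresetV0(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)
for _ in range(1000):  # no manual reset calls needed inside the loop
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```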
class PassiveEnvCheckerV0(gym.Wrapper):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env: Env[ObsType, ActType]):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
super().__init__(env)
assert hasattr(
env, "action_space"
), "The environment must specify an action space. https://gymnasium.farama.org/content/environment_creation/"
check_action_space(env.action_space)
assert hasattr(
env, "observation_space"
), "The environment must specify an observation space. https://gymnasium.farama.org/content/environment_creation/"
check_observation_space(env.observation_space)
self._checked_reset: bool = False
self._checked_step: bool = False
self._checked_render: bool = False
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self._checked_step is False:
self._checked_step = True
return env_step_passive_checker(self.env, action)
else:
return self.env.step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self._checked_reset is False:
self._checked_reset = True
return env_reset_passive_checker(self.env, seed=seed, options=options)
else:
return self.env.reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
if self._checked_render is False:
self._checked_render = True
return env_render_passive_checker(self.env)
else:
return self.env.render()
class OrderEnforcingV0(gym.Wrapper):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
>>> from gymnasium.envs.classic_control import CartPoleEnv
>>> env = CartPoleEnv()
>>> env = OrderEnforcingV0(env)
>>> env.step(0)
ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render()
ResetNeeded: Cannot call env.render() before calling env.reset()
>>> env.reset()
>>> env.render()
>>> env.step(0)
"""
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args:
env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing
"""
super().__init__(env)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment with `kwargs`."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment with `kwargs`."""
self._has_reset = True
return super().reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment with `kwargs`."""
if not self._disable_render_order_enforcing and not self._has_reset:
raise ResetNeeded(
"Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, "
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
)
return super().render()
@property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset
class RecordEpisodeStatisticsV0(gym.Wrapper):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. If using a vectorized environment also the key
``_episode`` is used which indicates whether the env at the respective index has
the episode statistics.
After the completion of an episode, ``info`` will look like this::
>>> info = {
... ...
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }
For a vectorized environments the output will be in the form of::
>>> infos = {
... ...
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.episode_reward_buffer` and :attr:`wrapped_env.episode_length_buffer` respectively.
Attributes:
episode_reward_buffer: The cumulative rewards of the last ``buffer_length``-many episodes
episode_length_buffer: The lengths of the last ``buffer_length``-many episodes
"""
def __init__(
self,
env: Env[ObsType, ActType],
buffer_length: int | None = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
buffer_length: The size of the buffers :attr:`episode_reward_buffer` and :attr:`episode_length_buffer`
stats_key: The info key for the episode statistics
"""
super().__init__(env)
self._stats_key = stats_key
self.episode_count = 0
self.episode_start_time: float = -1
self.episode_reward: float = -1
self.episode_length: int = -1
self.episode_time_length_buffer = deque(maxlen=buffer_length)
self.episode_reward_buffer = deque(maxlen=buffer_length)
self.episode_length_buffer = deque(maxlen=buffer_length)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
obs, reward, terminated, truncated, info = super().step(action)
self.episode_reward += reward
self.episode_length += 1
if terminated or truncated:
assert self._stats_key not in info
episode_time_length = np.round(
time.perf_counter() - self.episode_start_time, 6
)
info[self._stats_key] = {
"r": self.episode_reward,
"l": self.episode_length,
"t": episode_time_length,
}
self.episode_time_length_buffer.append(episode_time_length)
self.episode_reward_buffer.append(self.episode_reward)
self.episode_length_buffer.append(self.episode_length)
self.episode_count += 1
return obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_time = time.perf_counter()
self.episode_reward = 0
self.episode_length = 0
return obs, info
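A usage sketch; the ``info["episode"]`` dictionary carries the ``r``/``l``/``t`` keys populated in :meth:`step` above.
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import RecordEpisodeStatisticsV0

env = RecordEpisodeStatisticsV0(gym.make("CartPole-v1"), buffer_length=100)
obs, info = env.reset(seed=0)
terminated = truncated = False
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(info["episode"])  # e.g. {"r": 22.0, "l": 22, "t": 0.0123}
```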

View File

@@ -17,7 +17,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
try:

View File

@@ -1,13 +1,15 @@
"""A collection of observation wrappers using a lambda function.
* ``LambdaObservation`` - Transforms the observation with a function
* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
* ``FlattenObservation`` - Flattens the observations
* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation
* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation)
* ``ReshapeObservation`` - Reshapes an array-based observation
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservation`` - Convert a observation dtype
* ``LambdaObservationV0`` - Transforms the observation with a function
* ``FilterObservationV0`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
* ``FlattenObservationV0`` - Flattens the observations
* ``GrayscaleObservationV0`` - Converts a RGB observation to a grayscale observation
* ``ResizeObservationV0`` - Resizes an array-based observation (normally a RGB observation)
* ``ReshapeObservationV0`` - Reshapes an array-based observation
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservationV0`` - Convert an observation to a dtype
* ``PixelObservationV0`` - Augments the observation with the rendered frame
* ``NormalizeObservationV0`` - Normalizes the observations so that they are centered with unit variance
"""
from __future__ import annotations
@@ -18,10 +20,11 @@ import jumpy as jp
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ObsType
from gymnasium import Env, spaces
from gymnasium.core import ActType, ObservationWrapper, ObsType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box, utils
from gymnasium.experimental.wrappers.utils import RunningMeanStd
from gymnasium.spaces import Box, Dict, utils
class LambdaObservationV0(gym.ObservationWrapper):
@@ -407,3 +410,83 @@ class DtypeObservationV0(LambdaObservationV0):
)
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
class PixelObservationV0(LambdaObservationV0):
"""Augment observations by pixel values.
By default (``pixels_only=True``), the observation of this wrapper is the rendered frame.
You can instead choose to keep the observation of the base environment alongside the pixels in a dictionary.
In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary
of rendered images will be updated with the base environment's observation. If, however, the observation
space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box`
space) will be added to the dictionary under the key "state".
"""
def __init__(
self,
env: Env[ObsType, ActType],
pixels_only: bool = True,
pixels_key: str = "pixels",
obs_key: str = "state",
):
"""Initializes a new pixel Wrapper.
Args:
env: The environment to wrap.
pixels_only (bool): If `True` (default), the original observation returned
by the wrapped environment will be discarded, and a dictionary
observation will only include pixels. If `False`, the
observation dictionary will contain both the original
observations and the pixel observations.
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
obs_key: Optional custom string specifying the obs key. Defaults to "state"
"""
assert env.render_mode is not None and env.render_mode != "human"
env.reset()
pixels = env.render()
assert pixels is not None and isinstance(pixels, np.ndarray)
pixel_space = Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
if pixels_only:
obs_space = pixel_space
super().__init__(env, lambda _: self.render(), obs_space)
elif isinstance(env.observation_space, Dict):
assert pixels_key not in env.observation_space.spaces.keys()
obs_space = Dict({pixels_key: pixel_space, **env.observation_space.spaces})
super().__init__(
env, lambda obs: {pixels_key: self.render(), **obs}, obs_space
)
else:
obs_space = Dict({obs_key: env.observation_space, pixels_key: pixel_space})
super().__init__(
env, lambda obs: {obs_key: obs, pixels_key: self.render()}, obs_space
)
class NormalizeObservationV0(ObservationWrapper):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Note:
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
newly instantiated or the policy was changed recently.
"""
def __init__(self, env: gym.Env, epsilon: float = 1e-8):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Args:
env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations.
"""
super().__init__(env)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon
def observation(self, observation: ObsType) -> WrapperObsType:
"""Normalises the observation using the running mean and variance of the observations."""
self.obs_rms.update(observation)
return (observation - self.obs_rms.mean) / np.sqrt(
self.obs_rms.var + self.epsilon
)
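A usage sketch for the two wrappers added above; ``PixelObservationV0`` requires a render-capable environment, and with the default ``pixels_only=True`` the observation is the rendered frame itself.
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import NormalizeObservationV0, PixelObservationV0

pixel_env = PixelObservationV0(gym.make("CartPole-v1", render_mode="rgb_array"))
obs, info = pixel_env.reset(seed=0)  # obs is the rendered frame (uint8 array)

norm_env = NormalizeObservationV0(gym.make("CartPole-v1"))
obs, info = norm_env.reset(seed=0)  # normalized with a running mean/variance
```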

View File

@@ -6,12 +6,14 @@
from __future__ import annotations
from typing import Callable, SupportsFloat
from typing import Any, Callable, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers.utils import RunningMeanStd
class LambdaRewardV0(gym.RewardWrapper):
@@ -89,3 +91,48 @@ class ClipRewardV0(LambdaRewardV0):
)
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
class NormalizeRewardV0(gym.Wrapper):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently.
"""
def __init__(
self,
env: gym.Env,
gamma: float = 0.99,
epsilon: float = 1e-8,
):
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
Args:
env (Env): The environment to apply the wrapper
epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average.
"""
super().__init__(env)
self.rewards_running_means = RunningMeanStd(shape=())
self.discounted_reward: float = 0.0
self.gamma = gamma
self.epsilon = epsilon
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, normalizing the reward returned."""
obs, reward, terminated, truncated, info = super().step(action)
self.discounted_reward = self.discounted_reward * self.gamma * (
1 - terminated
) + float(reward)
return obs, self.normalize(float(reward)), terminated, truncated, info
def normalize(self, reward):
"""Normalizes the rewards with the running mean rewards and their variance."""
self.rewards_running_means.update(self.discounted_reward)
return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)
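A usage sketch; as implemented above, each raw reward is divided by the standard deviation of the running discounted-return estimate.
```python
import gymnasium as gym
from gymnasium.experimental.wrappers import NormalizeRewardV0

env = NormalizeRewardV0(gym.make("CartPole-v1"), gamma=0.99, epsilon=1e-8)
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```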

View File

@@ -0,0 +1,158 @@
"""Helper functions and wrapper class for converting between PyTorch and NumPy."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import numpy as np
from gymnasium import Env, Wrapper
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
try:
import torch
Device = Union[str, torch.device]
except ImportError:
torch, Device = None, None
@functools.singledispatch
def torch_to_numpy(value: Any) -> Any:
"""Converts a PyTorch Tensor into a NumPy Array."""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
)
else:
raise Exception(
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
)
if torch is not None:
@torch_to_numpy.register(numbers.Number)
@torch_to_numpy.register(torch.Tensor)
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
return np.array(value)
@torch_to_numpy.register(abc.Mapping)
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_numpy(v) for v in value)
@functools.singledispatch
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `numpy_to_torch`, run `pip install torch`"
)
else:
raise Exception(
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
)
if torch is not None:
@numpy_to_torch.register(np.ndarray)
def _numpy_to_torch(
value: np.ndarray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
assert torch is not None
tensor = torch.tensor(value)
if device:
return tensor.to(device=device)
return tensor
@numpy_to_torch.register(abc.Mapping)
def _numpy_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
@numpy_to_torch.register(abc.Iterable)
def _numpy_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
return type(value)(numpy_to_torch(v, device) for v in value)
class NumpyToTorchV0(Wrapper):
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
Note:
Rendered frames are returned as NumPy arrays, not PyTorch Tensors.
"""
def __init__(self, env: Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The NumPy-based environment to wrap
device: The device the torch Tensors should be moved to
"""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed, run `pip install torch`"
)
super().__init__(env)
self.device: Device | None = device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Performs the given action within the environment.
Args:
action: The action to perform as a PyTorch Tensor
Returns:
The next observation, reward, termination, truncation, and extra info
"""
numpy_action = torch_to_numpy(action)
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
return (
numpy_to_torch(obs, self.device),
float(reward),
bool(terminated),
bool(truncated),
numpy_to_torch(info, self.device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment; these are converted to NumPy arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_numpy(options)
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
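A usage sketch, assuming ``torch`` is installed; actions go in as Tensors and observations come back as Tensors, per :meth:`step` above.
```python
import torch

import gymnasium as gym
from gymnasium.experimental.wrappers import NumpyToTorchV0

env = NumpyToTorchV0(gym.make("CartPole-v1"), device="cpu")
obs, info = env.reset(seed=0)
assert isinstance(obs, torch.Tensor)
obs, reward, terminated, truncated, info = env.step(torch.tensor(1))
```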

View File

@@ -0,0 +1,219 @@
"""A collections of rendering-based wrappers.
* ``RenderCollectionV0`` - Collects rendered frames into a list
* ``RecordVideoV0`` - Records a video of the environments
* ``HumanRenderingV0`` - Provides human rendering of environments with ``"rgb_array"``
"""
from __future__ import annotations
from copy import deepcopy
from typing import Any, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
class RenderCollectionV0(gym.Wrapper):
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
def __init__(
self,
env: gym.Env[ObsType, ActType],
pop_frames: bool = True,
reset_clean: bool = True,
):
"""Initialize a :class:`RenderCollection` instance.
Args:
env: The environment that is being wrapped
pop_frames (bool): If true, clear the collected frames after :meth:`render` is called. Default value is ``True``.
reset_clean (bool): If true, clear the collected frames when :meth:`reset` is called. Default value is ``True``.
"""
super().__init__(env)
assert env.render_mode is not None
assert not env.render_mode.endswith("_list")
self.frame_list: list[RenderFrame] = []
self.pop_frames = pop_frames
self.reset_clean = reset_clean
self.metadata = deepcopy(self.env.metadata)
if f"{self.env.render_mode}_list" not in self.metadata["render_modes"]:
self.metadata["render_modes"].append(f"{self.env.render_mode}_list")
@property
def render_mode(self):
"""Returns the collection render_mode name."""
return f"{self.env.render_mode}_list"
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Perform a step in the base environment and collect a frame."""
output = super().step(action)
self.frame_list.append(super().render())
return output
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the base environment, eventually clear the frame_list, and collect a frame."""
output = super().reset(seed=seed, options=options)
if self.reset_clean:
self.frame_list = []
self.frame_list.append(super().render())
return output
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the collection of frames and, if pop_frames = True, clears it."""
frames = self.frame_list
if self.pop_frames:
self.frame_list = []
return frames
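# A minimal usage sketch for ``RenderCollectionV0`` (assuming CartPole-v1
# with an ``rgb_array`` render mode):
#
#     env = RenderCollectionV0(gym.make("CartPole-v1", render_mode="rgb_array"))
#     env.reset(seed=0)
#     for _ in range(10):
#         env.step(env.action_space.sample())
#     frames = env.render()      # 11 frames: one from reset, one per step
#     assert env.render() == []  # pop_frames=True emptied the buffer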
class RecordVideoV0(gym.Wrapper):
"""Record a video of an environment."""
pass
class HumanRenderingV0(gym.Wrapper):
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce
RGB images but haven't implemented any code to render the images to the screen.
If you want to use this wrapper with your environments, remember to specify ``"render_fps"``
in the metadata of your environment.
The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.
Example:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> wrapped = HumanRenderingV0(env)
>>> wrapped.reset() # This will start rendering to the screen
The wrapper can also be applied directly when the environment is instantiated, simply by passing
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
Example:
>>> env = gym.make("NoNativeRendering-v2", render_mode="human") # NoNativeRendering-v0 doesn't implement human-rendering natively
>>> env.reset() # This will start rendering to the screen
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
will always return an empty list:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
>>> wrapped = HumanRenderingV0(env)
>>> wrapped.reset()
>>> env.render()
[] # env.render() will always return an empty list!
"""
def __init__(self, env):
"""Initialize a :class:`HumanRendering` instance.
Args:
env: The environment that is being wrapped
"""
super().__init__(env)
assert env.render_mode in [
"rgb_array",
"rgb_array_list",
], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'"
assert (
"render_fps" in env.metadata
), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper"
self.screen_size = None
self.window = None
self.clock = None
if "human" not in self.metadata["render_modes"]:
self.metadata = deepcopy(self.env.metadata)
self.metadata["render_modes"].append("human")
@property
def render_mode(self):
"""Always returns ``'human'``."""
return "human"
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Perform a step in the base environment and render a frame to the screen."""
result = super().step(action)
self._render_frame()
return result
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the base environment and render a frame to the screen."""
result = super().reset(seed=seed, options=options)
self._render_frame()
return result
def render(self):
"""This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
return None
def _render_frame(self):
"""Fetch the last frame from the base environment and render it to the screen."""
try:
import pygame
except ImportError:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[box2d]`"
)
if self.env.render_mode == "rgb_array_list":
last_rgb_array = self.env.render()
assert isinstance(last_rgb_array, list)
last_rgb_array = last_rgb_array[-1]
elif self.env.render_mode == "rgb_array":
last_rgb_array = self.env.render()
else:
raise Exception(
f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', actual render mode: {self.env.render_mode}"
)
assert isinstance(last_rgb_array, np.ndarray)
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))
if self.screen_size is None:
self.screen_size = rgb_array.shape[:2]
assert (
self.screen_size == rgb_array.shape[:2]
), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}"
if self.window is None:
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode(self.screen_size)
if self.clock is None:
self.clock = pygame.time.Clock()
surf = pygame.surfarray.make_surface(rgb_array)
self.window.blit(surf, (0, 0))
pygame.event.pump()
self.clock.tick(self.metadata["render_fps"])
pygame.display.flip()
def close(self):
"""Close the rendering window."""
if self.window is not None:
import pygame
pygame.display.quit()
pygame.quit()
super().close()

View File

@@ -1,17 +1,14 @@
"""A collection of stateful action wrappers.
* StickyAction - There is a probability that the action is taken again
"""
"""``StickyAction`` wrapper - There is a probability that the action is taken again."""
from __future__ import annotations
from typing import Any, SupportsFloat
from typing import Any
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.core import ActionWrapper, ActType, WrapperActType, WrapperObsType
from gymnasium.error import InvalidProbability
class StickyActionV0(gym.Wrapper):
class StickyActionV0(ActionWrapper):
"""Wrapper which adds a probability of repeating the previous action.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
@@ -42,9 +39,7 @@ class StickyActionV0(gym.Wrapper):
return super().reset(seed=seed, options=options)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
def action(self, action: WrapperActType) -> ActType:
"""Execute the action."""
if (
self.last_action is not None

View File

@@ -1,7 +1,9 @@
"""A collection of stateful observation wrappers.
* DelayObservation - A wrapper for delaying the returned observation
* TimeAwareObservation - A wrapper for adding time aware observations to environment observation
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
* ``FrameStackObservationV0`` - Frame stack the observations
* ``AtariPreprocessingV0`` - Preprocessing wrapper for atari environments
"""
from __future__ import annotations
@@ -14,8 +16,10 @@ import numpy as np
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium import Env
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
class DelayObservationV0(gym.ObservationWrapper):
@@ -31,7 +35,9 @@ class DelayObservationV0(gym.ObservationWrapper):
returned observation is an array of zeros with the
same shape as the observation space.
"""
assert isinstance(env.observation_space, (Box, MultiBinary, MultiDiscrete))
assert isinstance(
env.observation_space, (Box, MultiBinary, MultiDiscrete)
), type(env.observation_space)
assert 0 < delay
self.delay: Final[int] = delay
@@ -134,9 +140,9 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
if isinstance(env.observation_space, Dict):
assert dict_time_key not in env.observation_space.keys()
observation_space = Dict(
{dict_time_key: time_space}, **env.observation_space.spaces
{dict_time_key: time_space, **env.observation_space.spaces}
)
self._append_data_func = lambda obs, time: {**obs, dict_time_key: time}
self._append_data_func = lambda obs, time: {dict_time_key: time, **obs}
elif isinstance(env.observation_space, Tuple):
observation_space = Tuple(env.observation_space.spaces + (time_space,))
self._append_data_func = lambda obs, time: obs + (time,)
@@ -198,3 +204,101 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
self.timesteps = 0
return super().reset(seed=seed, options=options)
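# Sketch of the reordering's effect on ``Dict`` observation spaces: the time
# key now comes first (cf. the TimeAwareObservationV0 test in this commit):
#
#     reset_obs, _ = wrapped_env.reset()       # e.g. {"time": 0.0, **original_obs}
#     step_obs, *_ = wrapped_env.step(action)  # {"time": 0.01, **original_obs}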
class FrameStackObservationV0(gym.Wrapper):
"""Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains
the most recent 4 observations. For environment 'Pendulum-v1', the original observation
is an array with shape [3], so if we stack 4 observations, the processed observation
has shape [4, 3].
Note:
- After :meth:`reset` is called, the frame buffer is initialized with empty (zeroed) frames, and the initial observation fills only the first slot.
I.e. the observation returned by :meth:`reset` contains ``stack_size - 1`` zeroed frames.
Example:
>>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1')
>>> env = FrameStackObservationV0(env, 4)
>>> env.observation_space
Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs, info = env.reset()
>>> obs.shape
(4, 96, 96, 3)
"""
def __init__(self, env: Env[ObsType, ActType], stack_size: int):
"""Observation wrapper that stacks the observations in a rolling manner.
Args:
env: The environment to apply the wrapper
stack_size: The number of frames to stack
"""
assert np.issubdtype(type(stack_size), np.integer)
assert stack_size > 0
super().__init__(env)
self.observation_space = batch_space(env.observation_space, n=stack_size)
self.stack_size = stack_size
self.stacked_obs_array = create_empty_array(env.observation_space, n=stack_size)
self.stacked_obs = self._init_stacked_obs()
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, appending the observation to the frame buffer.
Args:
action: The action to step through the environment with
Returns:
Stacked observations, reward, terminated, truncated, and info from the environment
"""
obs, reward, terminated, truncated, info = super().step(action)
self.stacked_obs.rotate(1)
self.stacked_obs[0] = obs
return (
concatenate(
self.observation_space, self.stacked_obs, self.stacked_obs_array
),
reward,
terminated,
truncated,
info,
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the environment, returning the stacked observation and info.
Args:
seed: The environment seed
options: The reset options
Returns:
The stacked observations and info
"""
obs, info = super().reset(seed=seed, options=options)
self.stacked_obs = self._init_stacked_obs()
self.stacked_obs[0] = obs
return (
concatenate(
self.observation_space, self.stacked_obs, self.stacked_obs_array
),
info,
)
def _init_stacked_obs(self) -> deque:
return deque(
iterate(
self.observation_space,
create_empty_array(self.env.observation_space, n=self.stack_size),
)
)
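# A minimal usage sketch (assuming CartPole-v1, whose base observations have
# shape (4,)); note that after ``reset`` only the first slot of the stack
# holds a real observation, the rest start zeroed:
env = FrameStackObservationV0(gym.make("CartPole-v1"), stack_size=3)
obs, info = env.reset(seed=0)
assert obs.shape == (3, 4)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
assert obs.shape == (3, 4)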

View File

@@ -0,0 +1,43 @@
"""Utility functions for the wrappers."""
import numpy as np
class RunningMeanStd:
"""Tracks the mean, variance and count of values."""
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, epsilon=1e-4, shape=()):
"""Tracks the mean, variance and count of values."""
self.mean = np.zeros(shape, "float64")
self.var = np.ones(shape, "float64")
self.count = epsilon
def update(self, x):
"""Updates the mean, var and count from a batch of samples."""
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean, batch_var, batch_count):
"""Updates from batch mean, variance and count moments."""
self.mean, self.var, self.count = update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var, batch_count
)
def update_mean_var_count_from_moments(
mean, var, count, batch_mean, batch_var, batch_count
):
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
delta = batch_mean - mean
tot_count = count + batch_count
new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count
return new_mean, new_var, new_count
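# A quick numerical check of the parallel update (a sketch): feeding the data
# in two batches should reproduce the statistics of the full array, up to the
# tiny ``epsilon`` pseudo-count.
rng = np.random.default_rng(0)
data = rng.normal(size=(64, 3))
rms = RunningMeanStd(epsilon=1e-8, shape=(3,))
rms.update(data[:32])
rms.update(data[32:])
assert np.allclose(rms.mean, data.mean(axis=0), atol=1e-6)
assert np.allclose(rms.var, data.var(axis=0), atol=1e-6)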

View File

@@ -93,7 +93,7 @@ class NormalizeObservation(gym.Wrapper):
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
class NormalizeReward(gym.Wrapper):
class NormalizeReward(gym.core.Wrapper):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
@@ -129,10 +129,8 @@ class NormalizeReward(gym.Wrapper):
obs, rews, terminateds, truncateds, infos = self.env.step(action)
if not self.is_vector_env:
rews = np.array([rews])
self.returns = self.returns * self.gamma + rews
self.returns = self.returns * self.gamma * (1 - terminateds) + rews
rews = self.normalize(rews)
dones = np.logical_or(terminateds, truncateds)
self.returns[dones] = 0.0
if not self.is_vector_env:
rews = rews[0]
return obs, rews, terminateds, truncateds, infos

View File

@@ -0,0 +1 @@
"""Testing for Gymnasium."""

View File

@@ -0,0 +1 @@
"""Testing suite for ``gymnasium.experimental``."""

View File

@@ -0,0 +1 @@
"""Module for functional environment API."""

View File

@@ -1,7 +1,8 @@
"""Test the functional jax environment."""
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np
import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
@@ -9,7 +10,8 @@ from gymnasium.envs.phys2d.pendulum import PendulumFunctional
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_normal(env_class):
def test_without_transform(env_class):
"""Tests the environment without transforming the environment."""
env = env_class()
rng = jrng.PRNGKey(0)
@@ -42,6 +44,7 @@ def test_normal(env_class):
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_jit(env_class):
"""Tests jitting the functional instance functions."""
env = env_class()
rng = jrng.PRNGKey(0)
@@ -75,6 +78,7 @@ def test_jit(env_class):
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_vmap(env_class):
"""Tests vmap of functional instance functions with transform."""
env = env_class()
num_envs = 10
rng = jrng.split(jrng.PRNGKey(0), num_envs)
@@ -98,7 +102,7 @@ def test_vmap(env_class):
assert reward.shape == (num_envs,)
assert reward.dtype == jnp.float32
assert terminal.shape == (num_envs,)
assert terminal.dtype == np.bool
assert terminal.dtype == bool
assert isinstance(obs, jnp.ndarray)
assert obs.dtype == jnp.float32

View File

@@ -1,4 +1,7 @@
from typing import Any, Dict, Optional
"""Tests the functional api."""
from __future__ import annotations
from typing import Any
import numpy as np
@@ -6,29 +9,41 @@ from gymnasium.experimental import FuncEnv
class GenericTestFuncEnv(FuncEnv):
def __init__(self, options: Optional[Dict[str, Any]] = None):
"""Generic testing functional environment."""
def __init__(self, options: dict[str, Any] | None = None):
"""Constructor that allows generic options to be set on the environment."""
super().__init__(options)
def initial(self, rng: Any) -> np.ndarray:
"""Testing initial function."""
return np.array([0, 0], dtype=np.float32)
def observation(self, state: np.ndarray) -> np.ndarray:
"""Testing observation function."""
return state
def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray:
"""Testing transition function."""
return state + np.array([0, action], dtype=np.float32)
def reward(self, state: np.ndarray, action: int, next_state: np.ndarray) -> float:
"""Testing reward function."""
return 1.0 if next_state[1] > 0 else 0.0
def terminal(self, state: np.ndarray) -> bool:
"""Testing terminal function."""
return state[1] > 0
def test_api():
def test_functional_api():
"""Tests the core functional api specification using a generic testing environment."""
env = GenericTestFuncEnv()
state = env.initial(None)
obs = env.observation(state)
assert state.shape == (2,)
assert state.dtype == np.float32
assert obs.shape == (2,)
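# A short worked example of the API above (a sketch following directly from
# the definitions of ``initial``, ``transition``, ``reward`` and ``terminal``):
#
#     env = GenericTestFuncEnv()
#     state = env.initial(None)                    # array([0., 0.])
#     next_state = env.transition(state, 1, None)  # array([0., 1.])
#     assert env.reward(state, 1, next_state) == 1.0
#     assert env.terminal(next_state)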

View File

@@ -0,0 +1 @@
"""Experimental wrapper module."""

View File

@@ -0,0 +1 @@
"""Test suite for HumanRenderingV0."""

View File

@@ -0,0 +1 @@
"""Test suite for AtariPreprocessingV0."""

View File

@@ -0,0 +1 @@
"""Test suite for AutoresetV0."""

View File

@@ -0,0 +1,25 @@
"""Test suite for ClipActionV0."""
import numpy as np
from gymnasium.experimental.wrappers import ClipActionV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import record_action_step
from tests.testing_env import GenericTestEnv
def test_clip_action_wrapper():
"""Test that the action is correctly clipped to the base environment action space."""
env = GenericTestEnv(
action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
step_func=record_action_step,
)
wrapped_env = ClipActionV0(env)
sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
assert sampled_action not in env.action_space
assert sampled_action in wrapped_env.action_space
_, _, _, _, info = wrapped_env.step(sampled_action)
assert np.all(info["action"] in env.action_space)
assert np.all(info["action"] == np.array([0, 2, 3.5]))

View File

@@ -0,0 +1,67 @@
"""Test suite for ClipRewardV0."""
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers import ClipRewardV0
from tests.envs.test_envs import SEED
from tests.experimental.wrappers.test_lambda_rewards import (
DISCRETE_ACTION,
ENV_ID,
NUM_ENVS,
)
@pytest.mark.parametrize(
("lower_bound", "upper_bound", "expected_reward"),
[(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
)
def test_clip_reward(lower_bound, upper_bound, expected_reward):
"""Test reward clipping.
Test if the reward is correctly clipped according to the input args.
"""
env = gym.make(ENV_ID)
env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED)
_, rew, _, _, _ = env.step(DISCRETE_ACTION)
assert rew == expected_reward
@pytest.mark.parametrize(
("lower_bound", "upper_bound", "expected_reward"),
[(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
)
def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward):
"""Test reward clipping in vectorized environment.
Test if the reward is correctly clipped according to the input args in a vectorized environment.
"""
actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]
env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED)
_, rew, _, _, _ = env.step(actions)
assert np.alltrue(rew == expected_reward)
@pytest.mark.parametrize(
("lower_bound", "upper_bound"),
[(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))],
)
def test_clip_reward_incorrect_params(lower_bound, upper_bound):
"""Test reward clipping with incorrect params.
Test whether passing wrong params to clip_rewards correctly raises an exception.
clip_rewards should raise an exception if both the lower and upper bounds of the reward are `None`,
or if the upper bound is lower than the lower bound.
"""
env = gym.make(ENV_ID)
with pytest.raises(InvalidBound):
ClipRewardV0(env, lower_bound, upper_bound)

View File

@@ -0,0 +1,34 @@
"""Test suite for DelayObservationV0."""
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0
from tests.experimental.wrappers.utils import DELAY, NUM_STEPS, SEED
def test_delay_observation_wrapper():
"""Tests the delay observation wrapper."""
env = gym.make("CartPole-v1")
env.action_space.seed(SEED)
env.reset(seed=SEED)
undelayed_observations = []
for _ in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
undelayed_observations.append(obs)
env = DelayObservationV0(env, delay=DELAY)
env.action_space.seed(SEED)
env.reset(seed=SEED)
delayed_observations = []
for i in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
delayed_observations.append(obs)
if i < DELAY - 1:
assert np.all(obs == 0)
undelayed_observations = np.array(undelayed_observations)
delayed_observations = np.array(delayed_observations)
assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY])

View File

@@ -0,0 +1,25 @@
"""Test suite for DtypeObservationV0."""
import numpy as np
from gymnasium.experimental.wrappers import DtypeObservationV0
from tests.experimental.wrappers.utils import (
record_random_obs_reset,
record_random_obs_step,
)
from tests.testing_env import GenericTestEnv
def test_dtype_observation():
"""Test ``DtypeObservation`` that the dtype is corrected modified."""
env = GenericTestEnv(
reset_func=record_random_obs_reset, step_func=record_random_obs_step
)
wrapped_env = DtypeObservationV0(env, dtype=np.uint8)
obs, info = wrapped_env.reset()
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8
obs, _, _, _, info = wrapped_env.step(None)
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8

View File

@@ -0,0 +1,45 @@
"""Test suite for FilterObservationV0."""
from gymnasium.experimental.wrappers import FilterObservationV0
from gymnasium.spaces import Box, Dict, Tuple
from tests.experimental.wrappers.utils import (
check_obs,
record_random_obs_reset,
record_random_obs_step,
)
from tests.testing_env import GenericTestEnv
def test_filter_observation_wrapper():
"""Tests ``FilterObservation`` that the right keys are filtered."""
dict_env = GenericTestEnv(
observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))
obs, info = wrapped_env.reset()
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
check_obs(dict_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
check_obs(dict_env, wrapped_env, obs, info["obs"])
# Test tuple environments
tuple_env = GenericTestEnv(
observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
wrapped_env = FilterObservationV0(tuple_env, (2,))
obs, info = wrapped_env.reset()
assert len(obs) == 1 and len(info["obs"]) == 3
check_obs(tuple_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert len(obs) == 1 and len(info["obs"]) == 3
check_obs(tuple_env, wrapped_env, obs, info["obs"])

View File

@@ -0,0 +1,27 @@
"""Test suite for FlattenObservationV0."""
from gymnasium.experimental.wrappers import FlattenObservationV0
from gymnasium.spaces import Box, Dict
from tests.experimental.wrappers.utils import (
check_obs,
record_random_obs_reset,
record_random_obs_step,
)
from tests.testing_env import GenericTestEnv
def test_flatten_observation_wrapper():
"""Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly."""
env = GenericTestEnv(
observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
wrapped_env = FlattenObservationV0(env)
obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
check_obs(env, wrapped_env, obs, info["obs"])

View File

@@ -0,0 +1 @@
"""Test suite for FrameStackObservationV0."""

View File

@@ -0,0 +1,38 @@
"""Test suite for GrayscaleObservationV0."""
import numpy as np
from gymnasium.experimental.wrappers import GrayscaleObservationV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import (
check_obs,
record_random_obs_reset,
record_random_obs_step,
)
from tests.testing_env import GenericTestEnv
def test_grayscale_observation_wrapper():
"""Tests the ``GrayscaleObservation`` that the observation is grayscale."""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
wrapped_env = GrayscaleObservationV0(env)
obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25)
obs, _, _, _, info = wrapped_env.step(None)
check_obs(env, wrapped_env, obs, info["obs"])
# Test with keep_dim=True
wrapped_env = GrayscaleObservationV0(env, keep_dim=True)
obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25, 1)
obs, _, _, _, info = wrapped_env.step(None)
check_obs(env, wrapped_env, obs, info["obs"])

View File

@@ -1,9 +1,11 @@
"""Test suite for JaxToNumpyV0."""
import jax.numpy as jnp
import numpy as np
import pytest
from gymnasium.experimental.wrappers import JaxToNumpyV0
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
from gymnasium.utils.env_checker import data_equivalence
from tests.testing_env import GenericTestEnv
@@ -40,10 +42,12 @@ def test_roundtripping(value, expected_value):
def jax_reset_func(self, seed=None, options=None):
"""A jax-based reset function."""
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
def jax_step_func(self, action):
"""A jax-based step function."""
assert isinstance(action, jnp.DeviceArray), type(action)
return (
jnp.array([1, 2, 3]),
@@ -54,7 +58,8 @@ def jax_step_func(self, action):
)
def test_jax_to_numpy():
def test_jax_to_numpy_wrapper():
"""Tests the ``JaxToNumpyV0`` wrapper."""
jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)
# Check that the reset and step for jax environment are as expected

View File

@@ -1,10 +1,12 @@
"""Test suite for TorchToJaxV0."""
import jax.numpy as jnp
import numpy as np
import pytest
import torch
from gymnasium.experimental.wrappers import JaxToTorchV0
from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax
from gymnasium.experimental.wrappers.jax_to_torch import jax_to_torch, torch_to_jax
from tests.testing_env import GenericTestEnv
@@ -76,7 +78,8 @@ def _jax_step_func(self, action):
)
def test_jax_to_torch():
def test_jax_to_torch_wrapper():
"""Tests the `JaxToTorchV0` wrapper."""
env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)
# Check that the reset and step for jax environment are as expected

View File

@@ -1,25 +1,14 @@
"""Test suit for lambda action wrappers: LambdaAction, ClipAction, RescaleAction."""
import numpy as np
"""Test suite for LambdaActionV0."""
from gymnasium.experimental.wrappers import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
)
from gymnasium.experimental.wrappers import LambdaActionV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import record_action_step
from tests.testing_env import GenericTestEnv
SEED = 42
def _record_action_step_func(self, action):
return 0, 0, False, False, {"action": action}
def test_lambda_action_wrapper():
"""Tests LambdaAction through checking that the action taken is transformed by function."""
env = GenericTestEnv(step_func=_record_action_step_func)
env = GenericTestEnv(step_func=record_action_step)
wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
sampled_action = wrapped_env.action_space.sample()
@@ -28,51 +17,3 @@ def test_lambda_action_wrapper():
_, _, _, _, info = wrapped_env.step(sampled_action)
assert info["action"] in env.action_space
assert sampled_action - 2 == info["action"]
def test_clip_action_wrapper():
"""Test that the action is correctly clipped to the base environment action space."""
env = GenericTestEnv(
action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
step_func=_record_action_step_func,
)
wrapped_env = ClipActionV0(env)
sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
assert sampled_action not in env.action_space
assert sampled_action in wrapped_env.action_space
_, _, _, _, info = wrapped_env.step(sampled_action)
assert np.all(info["action"] in env.action_space)
assert np.all(info["action"] == np.array([0, 2, 3.5]))
def test_rescale_action_wrapper():
"""Test that the action is rescale within a min / max bound."""
env = GenericTestEnv(
step_func=_record_action_step_func,
action_space=Box(np.array([0, 1]), np.array([1, 3])),
)
wrapped_env = RescaleActionV0(
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
)
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
for sample_action, expected_action in (
(
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0], dtype=np.float32),
),
(
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0], dtype=np.float32),
),
(
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0], dtype=np.float32),
),
):
assert sample_action in wrapped_env.action_space
_, _, _, _, info = wrapped_env.step(sample_action)
assert np.all(info["action"] == expected_action)

View File

@@ -1,250 +1,26 @@
"""Test suite for lambda observation wrappers: """
"""Test suite for lambda observation wrappers."""
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
from gymnasium.experimental.wrappers import LambdaObservationV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import (
check_obs,
record_action_as_obs_step,
record_obs_reset,
)
from gymnasium.spaces import Box, Dict, Tuple
from tests.testing_env import GenericTestEnv
SEED = 42
def _record_random_obs_reset(self: gym.Env, seed=None, options=None):
obs = self.observation_space.sample()
return obs, {"obs": obs}
def _record_random_obs_step(self: gym.Env, action):
obs = self.observation_space.sample()
return obs, 0, False, False, {"obs": obs}
def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}):
return options["obs"], {"obs": options["obs"]}
def _record_action_obs_step(self: gym.Env, action):
return action, 0, False, False, {"obs": action}
def _check_obs(
env: gym.Env,
wrapped_env: gym.Wrapper,
transformed_obs,
original_obs,
strict: bool = True,
):
assert (
transformed_obs in wrapped_env.observation_space
), f"{transformed_obs}, {wrapped_env.observation_space}"
assert (
original_obs in env.observation_space
), f"{original_obs}, {env.observation_space}"
if strict:
assert (
transformed_obs not in env.observation_space
), f"{transformed_obs}, {env.observation_space}"
assert (
original_obs not in wrapped_env.observation_space
), f"{original_obs}, {wrapped_env.observation_space}"
def test_lambda_observation_wrapper():
"""Tests lambda observation that the function is applied to both the reset and step observation."""
env = GenericTestEnv(
reset_func=_record_action_obs_reset, step_func=_record_action_obs_step
reset_func=record_obs_reset, step_func=record_action_as_obs_step
)
wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3))
wrapped_env = LambdaObservationV0(env, lambda _obs: _obs + 2, Box(2, 3))
obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)})
_check_obs(env, wrapped_env, obs, info["obs"])
check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32))
_check_obs(env, wrapped_env, obs, info["obs"])
def test_filter_observation_wrapper():
"""Tests ``FilterObservation`` that the right keys are filtered."""
dict_env = GenericTestEnv(
observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))
obs, info = wrapped_env.reset()
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
_check_obs(dict_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert list(obs.keys()) == ["arm_1", "arm_3"]
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
_check_obs(dict_env, wrapped_env, obs, info["obs"])
# Test tuple environments
tuple_env = GenericTestEnv(
observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = FilterObservationV0(tuple_env, (2,))
obs, info = wrapped_env.reset()
assert len(obs) == 1 and len(info["obs"]) == 3
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
assert len(obs) == 1 and len(info["obs"]) == 3
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
def test_flatten_observation_wrapper():
"""Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly."""
env = GenericTestEnv(
observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
print(env.observation_space)
wrapped_env = FlattenObservationV0(env)
print(wrapped_env.observation_space)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_grayscale_observation_wrapper():
"""Tests the ``GrayscaleObservation`` that the observation is grayscale."""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = GrayscaleObservationV0(env)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
# Keep_dim
wrapped_env = GrayscaleObservationV0(env, keep_dim=True)
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (25, 25, 1)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_resize_observation_wrapper():
"""Test the ``ResizeObservation`` that the observation has changed size"""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = ResizeObservationV0(env, (25, 25))
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
def test_reshape_observation_wrapper():
"""Test the ``ReshapeObservation`` wrapper."""
env = GenericTestEnv(
observation_space=Box(0, 1, shape=(2, 3, 2)),
reset_func=_record_random_obs_reset,
step_func=_record_random_obs_step,
)
wrapped_env = ReshapeObservationV0(env, (6, 2))
obs, info = wrapped_env.reset()
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)
obs, _, _, _, info = wrapped_env.step(None)
_check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)
def test_rescale_observation():
"""Test the ``RescaleObservation`` wrapper"""
env = GenericTestEnv(
observation_space=Box(
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
),
reset_func=_record_action_obs_reset,
step_func=_record_action_obs_step,
)
wrapped_env = RescaleObservationV0(
env,
min_obs=np.array([-5, 0], dtype=np.float32),
max_obs=np.array([5, 1], dtype=np.float32),
)
assert wrapped_env.observation_space == Box(
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
)
for sample_obs, expected_obs in (
(
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5], dtype=np.float32),
),
(
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0], dtype=np.float32),
),
(
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0], dtype=np.float32),
),
):
assert sample_obs in env.observation_space
assert expected_obs in wrapped_env.observation_space
obs, info = wrapped_env.reset(options={"obs": sample_obs})
assert np.all(obs == expected_obs)
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
obs, _, _, _, info = wrapped_env.step(sample_obs)
assert np.all(obs == expected_obs)
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
def test_dtype_observation():
"""Test ``DtypeObservation`` that the"""
env = GenericTestEnv(
reset_func=_record_random_obs_reset, step_func=_record_random_obs_step
)
wrapped_env = DtypeObservationV0(env, dtype=np.uint8)
obs, info = wrapped_env.reset()
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8
obs, _, _, _, info = wrapped_env.step(None)
assert obs.dtype != info["obs"].dtype
assert obs.dtype == np.uint8
check_obs(env, wrapped_env, obs, info["obs"])

View File

@@ -4,14 +4,8 @@ import numpy as np
import pytest
import gymnasium as gym
from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0
ENV_ID = "CartPole-v1"
DISCRETE_ACTION = 0
NUM_ENVS = 3
SEED = 0
from gymnasium.experimental.wrappers import LambdaRewardV0
from tests.experimental.wrappers.utils import DISCRETE_ACTION, ENV_ID, NUM_ENVS, SEED
@pytest.mark.parametrize(
@@ -54,57 +48,3 @@ def test_lambda_reward_within_vector(reward_fn, expected_reward):
_, rew, _, _, _ = env.step(actions)
assert np.alltrue(rew == expected_reward)
@pytest.mark.parametrize(
("lower_bound", "upper_bound", "expected_reward"),
[(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
)
def test_clip_reward(lower_bound, upper_bound, expected_reward):
"""Test reward clipping.
Test if reward is correctly clipped
accordingly to the input args.
"""
env = gym.make(ENV_ID)
env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED)
_, rew, _, _, _ = env.step(DISCRETE_ACTION)
assert rew == expected_reward
@pytest.mark.parametrize(
("lower_bound", "upper_bound", "expected_reward"),
[(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
)
def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward):
"""Test reward clipping in vectorized environment.
Test if reward is correctly clipped
accordingly to the input args in a vectorized environment.
"""
actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]
env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED)
_, rew, _, _, _ = env.step(actions)
assert np.alltrue(rew == expected_reward)
@pytest.mark.parametrize(
("lower_bound", "upper_bound"),
[(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))],
)
def test_clip_reward_incorrect_params(lower_bound, upper_bound):
"""Test reward clipping with incorrect params.
Test whether passing wrong params to clip_rewards
correctly raise an exception.
clip_rewards should raise an exception if, both low and upper
bound of reward are `None` or if upper bound is lower than lower bound.
"""
env = gym.make(ENV_ID)
with pytest.raises(InvalidBound):
env = ClipRewardV0(env, lower_bound, upper_bound)

View File

@@ -0,0 +1 @@
"""Test suite for NormalizeObservationV0."""

View File

@@ -0,0 +1 @@
"""Test suite for NormalizeRewardV0."""

View File

@@ -0,0 +1 @@
"""Test suite for NumpyToTorchV0."""

View File

@@ -0,0 +1 @@
"""Test suite for OrderEnforcingV0."""

View File

@@ -0,0 +1 @@
"""Test suite for PassiveEnvCheckerV0."""

View File

@@ -0,0 +1 @@
"""Test suite for PixelObservationV0."""

View File

@@ -0,0 +1 @@
"""Test suite for RecordEpisodeStatisticsV0."""

View File

@@ -0,0 +1 @@
"""Test suite for RecordVideoV0."""

View File

@@ -0,0 +1 @@
"""Test suite for RenderCollectionV0."""

View File

@@ -0,0 +1,38 @@
"""Test suite for RescaleActionV0."""
import numpy as np
from gymnasium.experimental.wrappers import RescaleActionV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import record_action_step
from tests.testing_env import GenericTestEnv
def test_rescale_action_wrapper():
"""Test that the action is rescale within a min / max bound."""
env = GenericTestEnv(
step_func=record_action_step,
action_space=Box(np.array([0, 1]), np.array([1, 3])),
)
wrapped_env = RescaleActionV0(
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
)
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
for sample_action, expected_action in (
(
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0], dtype=np.float32),
),
(
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0], dtype=np.float32),
),
(
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0], dtype=np.float32),
),
):
assert sample_action in wrapped_env.action_space
_, _, _, _, info = wrapped_env.step(sample_action)
assert np.all(info["action"] == expected_action)

View File

@@ -0,0 +1,55 @@
"""Test suite for RescaleObservationV0."""
import numpy as np
from gymnasium.experimental.wrappers import RescaleObservationV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import (
check_obs,
record_action_as_obs_step,
record_obs_reset,
)
from tests.testing_env import GenericTestEnv
def test_rescale_observation():
"""Test the ``RescaleObservation`` wrapper."""
env = GenericTestEnv(
observation_space=Box(
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
),
reset_func=record_obs_reset,
step_func=record_action_as_obs_step,
)
wrapped_env = RescaleObservationV0(
env,
min_obs=np.array([-5, 0], dtype=np.float32),
max_obs=np.array([5, 1], dtype=np.float32),
)
assert wrapped_env.observation_space == Box(
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
)
for sample_obs, expected_obs in (
(
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5], dtype=np.float32),
),
(
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0], dtype=np.float32),
),
(
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0], dtype=np.float32),
),
):
assert sample_obs in env.observation_space
assert expected_obs in wrapped_env.observation_space
obs, info = wrapped_env.reset(options={"obs": sample_obs})
assert np.all(obs == expected_obs)
check_obs(env, wrapped_env, obs, info["obs"], strict=False)
obs, _, _, _, info = wrapped_env.step(sample_obs)
assert np.all(obs == expected_obs)
check_obs(env, wrapped_env, obs, info["obs"], strict=False)

View File

@@ -0,0 +1,27 @@
"""Test suite for ReshapeObservationv0."""
from gymnasium.experimental.wrappers import ReshapeObservationV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import (
check_obs,
record_random_obs_reset,
record_random_obs_step,
)
from tests.testing_env import GenericTestEnv
def test_reshape_observation_wrapper():
"""Test the ``ReshapeObservation`` wrapper."""
env = GenericTestEnv(
observation_space=Box(0, 1, shape=(2, 3, 2)),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
wrapped_env = ReshapeObservationV0(env, (6, 2))
obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)
obs, _, _, _, info = wrapped_env.step(None)
check_obs(env, wrapped_env, obs, info["obs"])
assert obs.shape == (6, 2)

View File

@@ -0,0 +1,27 @@
"""Test suite for ResizeObservationV0."""
import numpy as np
from gymnasium.experimental.wrappers import ResizeObservationV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import (
check_obs,
record_random_obs_reset,
record_random_obs_step,
)
from tests.testing_env import GenericTestEnv
def test_resize_observation_wrapper():
"""Test the ``ResizeObservation`` that the observation has changed size."""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
wrapped_env = ResizeObservationV0(env, (25, 25))
obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
check_obs(env, wrapped_env, obs, info["obs"])

View File

@@ -1,43 +1,34 @@
"""Test suite for StickyActionV0."""
import numpy as np
import pytest
from gymnasium.error import InvalidProbability
from gymnasium.experimental.wrappers import StickyActionV0
from tests.experimental.wrappers.utils import NUM_STEPS, record_action_as_obs_step
from tests.testing_env import GenericTestEnv
SEED = 42
DELAY = 3
NUM_STEPS = 10
def step_fn(self, action):
return action
def test_sticky_action():
"""Tests the sticky action wrapper."""
env = StickyActionV0(
GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5
GenericTestEnv(step_func=record_action_as_obs_step),
repeat_action_probability=0.5,
)
env.reset(seed=SEED)
env.action_space.seed(SEED)
previous_action = None
for _ in range(NUM_STEPS):
input_action = env.action_space.sample()
executed_action = env.step(input_action)
executed_action, _, _, _, _ = env.step(input_action)
if executed_action != input_action:
assert executed_action == previous_action
else:
assert executed_action == input_action
previous_action = input_action
assert np.all(executed_action == input_action) or np.all(
executed_action == previous_action
)
previous_action = executed_action
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
"""Tests the stick action wrapper with probabilities that should raise an error."""
with pytest.raises(InvalidProbability):
StickyActionV0(
GenericTestEnv(), repeat_action_probability=repeat_action_probability

View File

@@ -1,19 +1,10 @@
"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation."""
"""Test suite for TimeAwareObservationV0."""
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0
from gymnasium.experimental.wrappers import TimeAwareObservationV0
from gymnasium.spaces import Box, Dict, Tuple
from tests.testing_env import GenericTestEnv
NUM_STEPS = 20
SEED = 0
DELAY = 3
def test_time_aware_observation_wrapper():
"""Tests the time aware observation wrapper."""
# Test the environment observation space with Dict, Tuple and other
@@ -60,30 +51,3 @@ def test_time_aware_observation_wrapper():
reset_obs, _ = wrapped_env.reset()
step_obs, _, _, _, _ = wrapped_env.step(None)
assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01
def test_delay_observation_wrapper():
env = gym.make("CartPole-v1")
env.action_space.seed(SEED)
env.reset(seed=SEED)
undelayed_observations = []
for _ in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
undelayed_observations.append(obs)
env = DelayObservationV0(env, delay=DELAY)
env.action_space.seed(SEED)
env.reset(seed=SEED)
delayed_observations = []
for i in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
delayed_observations.append(obs)
if i < DELAY - 1:
assert np.all(obs == 0)
undelayed_observations = np.array(undelayed_observations)
delayed_observations = np.array(delayed_observations)
assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY])

View File

@@ -0,0 +1,69 @@
"""Utility functions for testing the experimental wrappers."""
import gymnasium as gym
SEED = 42
ENV_ID = "CartPole-v1"
DISCRETE_ACTION = 0
NUM_ENVS = 3
NUM_STEPS = 20
DELAY = 3
def record_obs_reset(self: gym.Env, seed=None, options: dict = None):
"""Records and uses an observation passed through options."""
return options["obs"], {"obs": options["obs"]}
def record_random_obs_reset(self: gym.Env, seed=None, options=None):
"""Records random observation generated by the environment."""
obs = self.observation_space.sample()
return obs, {"obs": obs}
def record_action_step(self: gym.Env, action):
"""Records the actions passed to the environment."""
return 0, 0, False, False, {"action": action}
def record_random_obs_step(self: gym.Env, action):
"""Records the observation generated by the environment."""
obs = self.observation_space.sample()
return obs, 0, False, False, {"obs": obs}
def record_action_as_obs_step(self: gym.Env, action):
"""Uses the action as the observation."""
return action, 0, False, False, {"obs": action}
def check_obs(
env: gym.Env,
wrapped_env: gym.Wrapper,
transformed_obs,
original_obs,
strict: bool = True,
):
"""Checks that the original and transformed observations using the environment and wrapped environment.
Args:
env: The base environment
wrapped_env: The wrapped environment
transformed_obs: The transformed observation by the wrapped environment
original_obs: The original observation by the base environment.
strict: Whether to also check that each observation is not contained in the other environment's observation space.
"""
assert (
transformed_obs in wrapped_env.observation_space
), f"{transformed_obs}, {wrapped_env.observation_space}"
assert (
original_obs in env.observation_space
), f"{original_obs}, {env.observation_space}"
if strict:
assert (
transformed_obs not in env.observation_space
), f"{transformed_obs}, {env.observation_space}"
assert (
original_obs not in wrapped_env.observation_space
), f"{original_obs}, {wrapped_env.observation_space}"

View File

@@ -23,36 +23,44 @@ from tests.testing_env import GenericTestEnv
class ArgumentEnv(Env):
"""Testing environment that records the number of times the environment is created."""
observation_space = spaces.Box(low=0, high=1, shape=(1,))
action_space = spaces.Box(low=0, high=1, shape=(1,))
calls = 0
def __init__(self, arg):
def __init__(self, arg: Any):
"""Constructor."""
self.calls += 1
self.arg = arg
class UnittestEnv(Env):
"""Example testing environment."""
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
action_space = spaces.Discrete(3)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
"""Resets the environment."""
super().reset(seed=seed)
return self.observation_space.sample(), {"info": "dummy"}
def step(self, action):
"""Steps through the environment."""
observation = self.observation_space.sample() # Dummy observation
return (observation, 0.0, False, {})
return observation, 0.0, False, {}
class UnknownSpacesEnv(Env):
"""This environment defines its observation & action spaces only
after the first call to reset. Although this pattern is sometimes
necessary when implementing a new environment (e.g. if it depends
"""This environment defines its observation & action spaces only after the first call to reset.
Although this pattern is sometimes necessary when implementing a new environment (e.g. if it depends
on external resources), it is not encouraged.
"""
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
"""Resets the environment."""
super().reset(seed=seed)
self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
@@ -61,25 +69,27 @@ class UnknownSpacesEnv(Env):
return self.observation_space.sample(), {} # Dummy observation with info
def step(self, action):
"""Steps through the environment."""
observation = self.observation_space.sample() # Dummy observation
return (observation, 0.0, False, {})
return observation, 0.0, False, {}
class OldStyleEnv(Env):
"""This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)"""
def __init__(self):
pass
"""This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)."""
def reset(self):
"""Resets the environment."""
super().reset()
return 0
def step(self, action):
"""Steps through the environment."""
return 0, 0, False, {}
class NewPropertyWrapper(Wrapper):
"""Wrapper that tests setting a property."""
def __init__(
self,
env,
@@ -88,6 +98,15 @@ class NewPropertyWrapper(Wrapper):
reward_range=None,
metadata=None,
):
"""New property wrapper.
Args:
env: The environment to wrap
observation_space: The observation space
action_space: The action space
reward_range: The reward range
metadata: The environment metadata
"""
super().__init__(env)
if observation_space is not None:
# Only set the observation space if not None to test property forwarding
@@ -101,6 +120,7 @@ class NewPropertyWrapper(Wrapper):
def test_env_instantiation():
"""Tests the environment instantiation using ArgumentEnv."""
# This looks like a pretty trivial test, but given our usage of
# __new__, it's worth having.
env = ArgumentEnv("arg")
@@ -129,6 +149,7 @@ properties = [
@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv])
@pytest.mark.parametrize("props", properties)
def test_wrapper_property_forwarding(class_, props):
"""Tests wrapper property forwarding."""
env = class_()
env = NewPropertyWrapper(env, **props)
@@ -147,6 +168,7 @@ def test_wrapper_property_forwarding(class_, props):
def test_compatibility_with_old_style_env():
"""Test compatibility with old style environment."""
env = OldStyleEnv()
env = OrderEnforcing(env)
env = TimeLimit(env)
@@ -158,13 +180,17 @@ def test_compatibility_with_old_style_env():
class ExampleEnv(Env):
"""Example testing environment."""
def __init__(self):
"""Constructor for example environment."""
self.observation_space = Box(0, 1)
self.action_space = Box(0, 1)
def step(
self, action: ActType
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
"""Steps through the environment."""
return 0, 0, False, False, {}
def reset(
@@ -173,10 +199,12 @@ class ExampleEnv(Env):
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]:
"""Resets the environment."""
return 0, {}
def test_gymnasium_env():
"""Tests a gymnasium environment."""
env = ExampleEnv()
assert env.metadata == {"render_modes": []}
@@ -187,7 +215,10 @@ def test_gymnasium_env():
class ExampleWrapper(Wrapper):
"""An example testing wrapper."""
def __init__(self, env: Env[ObsType, ActType]):
"""Constructor that sets the reward."""
super().__init__(env)
self.new_reward = 3
@@ -195,11 +226,13 @@ class ExampleWrapper(Wrapper):
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[WrapperObsType, Dict[str, Any]]:
"""Resets the environment ."""
return super().reset(seed=seed, options=options)
def step(
self, action: WrapperActType
) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]:
"""Steps through the environment."""
obs, reward, termination, truncation, info = self.env.step(action)
return obs, self.new_reward, termination, truncation, info
@@ -209,6 +242,7 @@ class ExampleWrapper(Wrapper):
def test_gymnasium_wrapper():
"""Tests the gymnasium wrapper works as expected."""
env = ExampleEnv()
wrapper_env = ExampleWrapper(env)
@@ -250,21 +284,31 @@ def test_gymnasium_wrapper():
class ExampleRewardWrapper(RewardWrapper):
"""Example reward wrapper for testing."""
def reward(self, reward: SupportsFloat) -> SupportsFloat:
"""Reward function."""
return 1
class ExampleObservationWrapper(ObservationWrapper):
"""Example observation wrapper for testing."""
def observation(self, observation: ObsType) -> ObsType:
"""Observation function."""
return np.array([1])
class ExampleActionWrapper(ActionWrapper):
"""Example action wrapper for testing."""
def action(self, action: ActType) -> ActType:
"""Action function."""
return np.array([1])
def test_wrapper_types():
"""Tests the observation, action and reward wrapper examples."""
env = GenericTestEnv()
reward_env = ExampleRewardWrapper(env)

View File

@@ -1,6 +1,8 @@
"""Provides a generic testing environment for use in tests with custom reset, step and render functions."""
from __future__ import annotations
import types
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any
import gymnasium as gym
from gymnasium import spaces
@@ -11,21 +13,21 @@ from gymnasium.envs.registration import EnvSpec
def basic_reset_func(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
seed: int | None = None,
options: dict | None = None,
) -> ObsType | tuple[ObsType, dict]:
"""A basic reset function that will pass the environment check using random actions from the observation space."""
super(GenericTestEnv, self).reset(seed=seed)
self.observation_space.seed(seed)
return self.observation_space.sample(), {"options": options}
def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
def new_step_func(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
"""A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, False, {}
def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def old_step_func(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, {}
@@ -46,12 +48,24 @@ class GenericTestEnv(gym.Env):
reset_func: callable = basic_reset_func,
step_func: callable = new_step_func,
render_func: callable = basic_render_func,
metadata: Dict[str, Any] = {"render_modes": []},
render_mode: Optional[str] = None,
metadata: dict[str, Any] = {"render_modes": []},
render_mode: str | None = None,
spec: EnvSpec = EnvSpec(
"TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100
),
):
"""Generic testing environment constructor.
Args:
action_space: The environment action space
observation_space: The environment observation space
reset_func: The environment reset function
step_func: The environment step function
render_func: The environment render function
metadata: The environment metadata
render_mode: The render mode of the environment
spec: The environment spec
"""
self.metadata = metadata
self.render_mode = render_mode
self.spec = spec
@@ -71,14 +85,17 @@ class GenericTestEnv(gym.Env):
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
seed: int | None = None,
options: dict | None = None,
) -> ObsType | tuple[ObsType, dict]:
"""Resets the environment."""
# If you need a default working reset function, use `basic_reset_func` above
raise NotImplementedError("TestingEnv reset_fn is not set.")
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict[str, Any]]:
"""Steps through the environment."""
raise NotImplementedError("TestingEnv step_fn is not set.")
def render(self):
"""Renders the environment."""
raise NotImplementedError("testingEnv render_fn is not set.")