mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Add wrappers to experimental (#201)
This commit is contained in:
@@ -51,7 +51,7 @@ repos:
|
|||||||
rev: 6.1.1
|
rev: 6.1.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: pydocstyle
|
- id: pydocstyle
|
||||||
exclude: ^(gymnasium/envs/)|(tests/)|(docs/)
|
exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/)
|
||||||
args:
|
args:
|
||||||
- --source
|
- --source
|
||||||
- --explain
|
- --explain
|
||||||
|
@@ -14,7 +14,9 @@ experimental/vector_wrappers
|
|||||||
|
|
||||||
## Functional Environments
|
## Functional Environments
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
The gymnasium ``Env`` provides high flexibility for the implementation of individual environments however this can complicate parallelism of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv` where each part of environment has its own function related to it.
|
The gymnasium ``Env`` provides high flexibility for the implementation of individual environments; however, this can complicate the parallelisation of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv`, where each part of the environment has its own function related to it.
|
||||||
|
```
|
||||||
|
|
||||||
## Wrappers
|
## Wrappers
|
||||||
|
|
||||||
@@ -36,64 +38,32 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
|
|
||||||
* - Old name
|
* - Old name
|
||||||
- New name
|
- New name
|
||||||
- Vector version
|
|
||||||
- Tree structure
|
|
||||||
* - :class:`wrappers.TransformObservation`
|
* - :class:`wrappers.TransformObservation`
|
||||||
- :class:`experimental.wrappers.LambdaObservationV0`
|
- :class:`experimental.wrappers.LambdaObservationV0`
|
||||||
- VectorLambdaObservation
|
|
||||||
- No
|
|
||||||
* - :class:`wrappers.FilterObservation`
|
* - :class:`wrappers.FilterObservation`
|
||||||
- :class:`experimental.wrappers.FilterObservationV0`
|
- :class:`experimental.wrappers.FilterObservationV0`
|
||||||
- VectorFilterObservation (*)
|
|
||||||
- Yes
|
|
||||||
* - :class:`wrappers.FlattenObservation`
|
* - :class:`wrappers.FlattenObservation`
|
||||||
- :class:`experimental.wrappers.FlattenObservationV0`
|
- :class:`experimental.wrappers.FlattenObservationV0`
|
||||||
- VectorFlattenObservation (*)
|
|
||||||
- No
|
|
||||||
* - :class:`wrappers.GrayScaleObservation`
|
* - :class:`wrappers.GrayScaleObservation`
|
||||||
- :class:`experimental.wrappers.GrayscaleObservationV0`
|
- :class:`experimental.wrappers.GrayscaleObservationV0`
|
||||||
- VectorGrayscaleObservation (*)
|
|
||||||
- Yes
|
|
||||||
* - :class:`wrappers.ResizeObservation`
|
* - :class:`wrappers.ResizeObservation`
|
||||||
- :class:`experimental.wrappers.ResizeObservationV0`
|
- :class:`experimental.wrappers.ResizeObservationV0`
|
||||||
- VectorResizeObservation (*)
|
* - ``supersuit.reshape_v0``
|
||||||
- Yes
|
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.ReshapeObservationV0`
|
- :class:`experimental.wrappers.ReshapeObservationV0`
|
||||||
- VectorReshapeObservation (*)
|
|
||||||
- Yes
|
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- :class:`experimental.wrappers.RescaleObservationV0`
|
- :class:`experimental.wrappers.RescaleObservationV0`
|
||||||
- VectorRescaleObservation (*)
|
* - ``supersuit.dtype_v0``
|
||||||
- Yes
|
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.DtypeObservationV0`
|
- :class:`experimental.wrappers.DtypeObservationV0`
|
||||||
- VectorDtypeObservation (*)
|
|
||||||
- Yes
|
|
||||||
* - :class:`wrappers.PixelObservationWrapper`
|
* - :class:`wrappers.PixelObservationWrapper`
|
||||||
- PixelObservation
|
- :class:`experimental.wrappers.PixelObservationV0`
|
||||||
- VectorPixelObservation
|
|
||||||
- No
|
|
||||||
* - :class:`wrappers.NormalizeObservation`
|
* - :class:`wrappers.NormalizeObservation`
|
||||||
- NormalizeObservation
|
- :class:`experimental.wrappers.NormalizeObservationV0`
|
||||||
- VectorNormalizeObservation
|
|
||||||
- No
|
|
||||||
* - :class:`wrappers.TimeAwareObservation`
|
* - :class:`wrappers.TimeAwareObservation`
|
||||||
- :class:`experimental.wrappers.TimeAwareObservationV0`
|
- :class:`experimental.wrappers.TimeAwareObservationV0`
|
||||||
- VectorTimeAwareObservation
|
|
||||||
- No
|
|
||||||
* - :class:`wrappers.FrameStack`
|
* - :class:`wrappers.FrameStack`
|
||||||
- FrameStackObservation
|
- :class:`experimental.wrappers.FrameStackObservationV0`
|
||||||
- VectorFrameStackObservation
|
* - ``supersuit.delay_observations_v0``
|
||||||
- No
|
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.DelayObservationV0`
|
- :class:`experimental.wrappers.DelayObservationV0`
|
||||||
- VectorDelayObservation
|
|
||||||
- No
|
|
||||||
* - :class:`wrappers.AtariPreprocessing`
|
|
||||||
- AtariPreprocessing
|
|
||||||
- Not Implemented
|
|
||||||
- No
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Action Wrappers
|
### Action Wrappers
|
||||||
@@ -105,24 +75,14 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
|
|
||||||
* - Old name
|
* - Old name
|
||||||
- New name
|
- New name
|
||||||
- Vector version
|
* - ``supersuit.action_lambda_v1``
|
||||||
- Tree structure
|
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.LambdaActionV0`
|
- :class:`experimental.wrappers.LambdaActionV0`
|
||||||
- VectorLambdaAction
|
|
||||||
- No
|
|
||||||
* - :class:`wrappers.ClipAction`
|
* - :class:`wrappers.ClipAction`
|
||||||
- :class:`experimental.wrappers.ClipActionV0`
|
- :class:`experimental.wrappers.ClipActionV0`
|
||||||
- VectorClipAction (*)
|
|
||||||
- Yes
|
|
||||||
* - :class:`wrappers.RescaleAction`
|
* - :class:`wrappers.RescaleAction`
|
||||||
- :class:`experimental.wrappers.RescaleActionV0`
|
- :class:`experimental.wrappers.RescaleActionV0`
|
||||||
- VectorRescaleAction (*)
|
* - ``supersuit.sticky_actions_v0``
|
||||||
- Yes
|
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.StickyActionV0`
|
- :class:`experimental.wrappers.StickyActionV0`
|
||||||
- VectorStickyAction
|
|
||||||
- No
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Reward Wrappers
|
### Reward Wrappers
|
||||||
@@ -134,19 +94,12 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
|
|
||||||
* - Old name
|
* - Old name
|
||||||
- New name
|
- New name
|
||||||
- Vector version
|
|
||||||
* - :class:`wrappers.TransformReward`
|
* - :class:`wrappers.TransformReward`
|
||||||
- :class:`experimental.wrappers.LambdaRewardV0`
|
- :class:`experimental.wrappers.LambdaRewardV0`
|
||||||
- VectorLambdaReward
|
* - ``supersuit.clip_reward_v0``
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.ClipRewardV0`
|
- :class:`experimental.wrappers.ClipRewardV0`
|
||||||
- VectorClipReward (*)
|
|
||||||
* - Not Implemented
|
|
||||||
- RescaleReward
|
|
||||||
- VectorRescaleReward (*)
|
|
||||||
* - :class:`wrappers.NormalizeReward`
|
* - :class:`wrappers.NormalizeReward`
|
||||||
- NormalizeReward
|
- :class:`experimental.wrappers.NormalizeRewardV0`
|
||||||
- VectorNormalizeReward
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Common Wrappers
|
### Common Wrappers
|
||||||
@@ -159,37 +112,21 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
|
|
||||||
* - Old name
|
* - Old name
|
||||||
- New name
|
- New name
|
||||||
- Vector version
|
|
||||||
* - :class:`wrappers.AutoResetWrapper`
|
* - :class:`wrappers.AutoResetWrapper`
|
||||||
- AutoReset
|
- :class:`experimental.wrappers.AutoresetV0`
|
||||||
- VectorAutoReset
|
|
||||||
* - :class:`wrappers.PassiveEnvChecker`
|
* - :class:`wrappers.PassiveEnvChecker`
|
||||||
- PassiveEnvChecker
|
- :class:`experimental.wrappers.PassiveEnvCheckerV0`
|
||||||
- VectorPassiveEnvChecker
|
|
||||||
* - :class:`wrappers.OrderEnforcing`
|
* - :class:`wrappers.OrderEnforcing`
|
||||||
- OrderEnforcing
|
- :class:`experimental.wrappers.OrderEnforcingV0`
|
||||||
- VectorOrderEnforcing
|
|
||||||
* - :class:`wrappers.EnvCompatibility`
|
* - :class:`wrappers.EnvCompatibility`
|
||||||
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
||||||
- Not Implemented
|
|
||||||
* - :class:`wrappers.RecordEpisodeStatistics`
|
* - :class:`wrappers.RecordEpisodeStatistics`
|
||||||
- RecordEpisodeStatistics
|
- :class:`experimental.wrappers.RecordEpisodeStatisticsV0`
|
||||||
- VectorRecordEpisodeStatistics
|
* - :class:`wrappers.AtariPreprocessing`
|
||||||
* - :class:`wrappers.RenderCollection`
|
- :class:`experimental.wrappers.AtariPreprocessingV0`
|
||||||
- RenderCollection
|
|
||||||
- VectorRenderCollection
|
|
||||||
* - :class:`wrappers.HumanRendering`
|
|
||||||
- HumanRendering
|
|
||||||
- Not Implemented
|
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.JaxToNumpyV0`
|
|
||||||
- VectorJaxToNumpy (*)
|
|
||||||
* - Not Implemented
|
|
||||||
- :class:`experimental.wrappers.JaxToTorchV0`
|
|
||||||
- VectorJaxToTorch (*)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Vector Only Wrappers
|
### Rendering Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. py:currentmodule:: gymnasium
|
.. py:currentmodule:: gymnasium
|
||||||
@@ -199,8 +136,22 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
|
|
||||||
* - Old name
|
* - Old name
|
||||||
- New name
|
- New name
|
||||||
* - :class:`wrappers.VectorListInfo`
|
* - :class:`wrappers.RecordVideo`
|
||||||
- VectorListInfo
|
- :class:`experimental.wrappers.RecordVideoV0`
|
||||||
|
* - :class:`wrappers.HumanRendering`
|
||||||
|
- :class:`experimental.wrappers.HumanRenderingV0`
|
||||||
|
* - :class:`wrappers.RenderCollection`
|
||||||
|
- :class:`experimental.wrappers.RenderCollectionV0`
|
||||||
|
```
|
||||||
|
|
||||||
|
### Environment data conversion
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. py:currentmodule:: gymnasium
|
||||||
|
|
||||||
|
* :class:`experimental.wrappers.JaxToNumpyV0`
|
||||||
|
* :class:`experimental.wrappers.JaxToTorchV0`
|
||||||
|
* :class:`experimental.wrappers.NumpyToTorchV0`
|
||||||
```
|
```
|
||||||
|
|
||||||
## Vector Environment
|
## Vector Environment
|
||||||
|
@@ -11,8 +11,12 @@
|
|||||||
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.PixelObservationV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.NormalizeObservationV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.FrameStackObservationV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.AtariPreprocessingV0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Action Wrappers
|
## Action Wrappers
|
||||||
@@ -24,16 +28,34 @@
|
|||||||
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
||||||
```
|
```
|
||||||
|
|
||||||
# Reward Wrappers
|
## Reward Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Common Wrappers
|
## Other Wrappers
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.AutoresetV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.PassiveEnvCheckerV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.OrderEnforcingV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.RecordEpisodeStatisticsV0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Rendering Wrappers
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.RecordVideoV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.HumanRenderingV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.RenderCollectionV0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment data conversion
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
|
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
|
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.NumpyToTorchV0
|
||||||
```
|
```
|
||||||
|
@@ -356,7 +356,7 @@ class Wrapper(Env[WrapperObsType, WrapperActType]):
|
|||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: WrapperActType
|
self, action: WrapperActType
|
||||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||||
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
|
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
|
||||||
return self.env.step(action)
|
return self.env.step(action)
|
||||||
|
|
||||||
|
@@ -1,2 +1,3 @@
|
|||||||
|
"""Module for 2d physics environments with functional and environment implementations."""
|
||||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
|
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
|
||||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv
|
from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv
|
||||||
|
@@ -1,8 +1,7 @@
|
|||||||
"""
|
"""Implementation of a Jax-accelerated cartpole environment."""
|
||||||
Implementation of a Jax-accelerated cartpole environment.
|
from __future__ import annotations
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Any, Tuple
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -74,7 +73,7 @@ class CartPoleFunctional(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def transition(
|
def transition(
|
||||||
self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
|
self, state: jnp.ndarray, action: int | jnp.ndarray, rng: None = None
|
||||||
) -> StateType:
|
) -> StateType:
|
||||||
"""Cartpole transition."""
|
"""Cartpole transition."""
|
||||||
x, x_dot, theta, theta_dot = state
|
x, x_dot, theta, theta_dot = state
|
||||||
@@ -106,6 +105,7 @@ class CartPoleFunctional(
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
def terminal(self, state: jnp.ndarray) -> jnp.ndarray:
|
def terminal(self, state: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
"""Checks if the state is terminal."""
|
||||||
x, _, theta, _ = state
|
x, _, theta, _ = state
|
||||||
|
|
||||||
terminated = (
|
terminated = (
|
||||||
@@ -120,6 +120,7 @@ class CartPoleFunctional(
|
|||||||
def reward(
|
def reward(
|
||||||
self, state: StateType, action: ActType, next_state: StateType
|
self, state: StateType, action: ActType, next_state: StateType
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
|
"""Computes the reward for the state transition using the action."""
|
||||||
x, _, theta, _ = state
|
x, _, theta, _ = state
|
||||||
|
|
||||||
terminated = (
|
terminated = (
|
||||||
@@ -136,8 +137,8 @@ class CartPoleFunctional(
|
|||||||
self,
|
self,
|
||||||
state: StateType,
|
state: StateType,
|
||||||
render_state: RenderStateType,
|
render_state: RenderStateType,
|
||||||
) -> Tuple[RenderStateType, np.ndarray]:
|
) -> tuple[RenderStateType, np.ndarray]:
|
||||||
|
"""Renders an image of the state using the render state."""
|
||||||
try:
|
try:
|
||||||
import pygame
|
import pygame
|
||||||
from pygame import gfxdraw
|
from pygame import gfxdraw
|
||||||
@@ -210,6 +211,7 @@ class CartPoleFunctional(
|
|||||||
def render_init(
|
def render_init(
|
||||||
self, screen_width: int = 600, screen_height: int = 400
|
self, screen_width: int = 600, screen_height: int = 400
|
||||||
) -> RenderStateType:
|
) -> RenderStateType:
|
||||||
|
"""Initialises the render state for a screen width and height."""
|
||||||
try:
|
try:
|
||||||
import pygame
|
import pygame
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -224,6 +226,7 @@ class CartPoleFunctional(
|
|||||||
return screen, clock
|
return screen, clock
|
||||||
|
|
||||||
def render_close(self, render_state: RenderStateType) -> None:
|
def render_close(self, render_state: RenderStateType) -> None:
|
||||||
|
"""Closes the render state."""
|
||||||
try:
|
try:
|
||||||
import pygame
|
import pygame
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -235,20 +238,24 @@ class CartPoleFunctional(
|
|||||||
|
|
||||||
|
|
||||||
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||||
|
"""Jax-based implementation of the CartPole environment."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
||||||
|
|
||||||
def __init__(self, render_mode: Optional[str] = None, **kwargs):
|
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
||||||
|
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
|
||||||
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
|
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
|
||||||
|
|
||||||
env = CartPoleFunctional(**kwargs)
|
env = CartPoleFunctional(**kwargs)
|
||||||
env.transform(jax.jit)
|
env.transform(jax.jit)
|
||||||
|
|
||||||
action_space = env.action_space
|
action_space = env.action_space
|
||||||
observation_space = env.observation_space
|
observation_space = env.observation_space
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
env,
|
env,
|
||||||
observation_space=observation_space,
|
observation_space=observation_space,
|
||||||
action_space=action_space,
|
action_space=action_space,
|
||||||
metadata=metadata,
|
metadata=self.metadata,
|
||||||
render_mode=render_mode,
|
render_mode=render_mode,
|
||||||
)
|
)
|
||||||
|
@@ -1,8 +1,8 @@
|
|||||||
"""
|
"""Implementation of a Jax-accelerated pendulum environment."""
|
||||||
Implementation of a Jax-accelerated pendulum environment.
|
from __future__ import annotations
|
||||||
"""
|
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -22,7 +22,7 @@ RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]]
|
|||||||
class PendulumFunctional(
|
class PendulumFunctional(
|
||||||
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
|
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
|
||||||
):
|
):
|
||||||
"""Pendulum but in jax and functional."""
|
"""Pendulum but in jax and functional structure."""
|
||||||
|
|
||||||
max_speed = 8
|
max_speed = 8
|
||||||
max_torque = 2.0
|
max_torque = 2.0
|
||||||
@@ -44,7 +44,7 @@ class PendulumFunctional(
|
|||||||
return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)
|
return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)
|
||||||
|
|
||||||
def transition(
|
def transition(
|
||||||
self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
|
self, state: jnp.ndarray, action: int | jnp.ndarray, rng: None = None
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
"""Pendulum transition."""
|
"""Pendulum transition."""
|
||||||
th, thdot = state # th := theta
|
th, thdot = state # th := theta
|
||||||
@@ -65,10 +65,12 @@ class PendulumFunctional(
|
|||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
def observation(self, state: jnp.ndarray) -> jnp.ndarray:
|
def observation(self, state: jnp.ndarray) -> jnp.ndarray:
|
||||||
|
"""Generates an observation based on the state."""
|
||||||
theta, thetadot = state
|
theta, thetadot = state
|
||||||
return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot])
|
return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot])
|
||||||
|
|
||||||
def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
|
def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
|
||||||
|
"""Generates the reward based on the state, action and next state."""
|
||||||
th, thdot = state # th := theta
|
th, thdot = state # th := theta
|
||||||
u = action
|
u = action
|
||||||
|
|
||||||
@@ -80,13 +82,15 @@ class PendulumFunctional(
|
|||||||
return -costs
|
return -costs
|
||||||
|
|
||||||
def terminal(self, state: StateType) -> bool:
|
def terminal(self, state: StateType) -> bool:
|
||||||
|
"""Determines if the state is a terminal state."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def render_image(
|
def render_image(
|
||||||
self,
|
self,
|
||||||
state: StateType,
|
state: StateType,
|
||||||
render_state: Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]], # type: ignore # noqa: F821
|
render_state: tuple[pygame.Surface, pygame.time.Clock, float | None], # type: ignore # noqa: F821
|
||||||
) -> Tuple[RenderStateType, np.ndarray]:
|
) -> tuple[RenderStateType, np.ndarray]:
|
||||||
|
"""Renders an RGB image."""
|
||||||
try:
|
try:
|
||||||
import pygame
|
import pygame
|
||||||
from pygame import gfxdraw
|
from pygame import gfxdraw
|
||||||
@@ -159,6 +163,7 @@ class PendulumFunctional(
|
|||||||
def render_init(
|
def render_init(
|
||||||
self, screen_width: int = 600, screen_height: int = 400
|
self, screen_width: int = 600, screen_height: int = 400
|
||||||
) -> RenderStateType:
|
) -> RenderStateType:
|
||||||
|
"""Initialises the render state."""
|
||||||
try:
|
try:
|
||||||
import pygame
|
import pygame
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -172,7 +177,8 @@ class PendulumFunctional(
|
|||||||
|
|
||||||
return screen, clock, None
|
return screen, clock, None
|
||||||
|
|
||||||
def render_close(self, render_state: RenderStateType) -> None:
|
def render_close(self, render_state: RenderStateType):
|
||||||
|
"""Closes the render state."""
|
||||||
try:
|
try:
|
||||||
import pygame
|
import pygame
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -184,21 +190,24 @@ class PendulumFunctional(
|
|||||||
|
|
||||||
|
|
||||||
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||||
|
"""Jax-based pendulum environment using the functional version as base."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
|
||||||
|
|
||||||
def __init__(self, render_mode: Optional[str] = None, **kwargs):
|
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
||||||
|
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
|
||||||
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
|
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
|
||||||
|
|
||||||
env = PendulumFunctional(**kwargs)
|
env = PendulumFunctional(**kwargs)
|
||||||
env.transform(jax.jit)
|
env.transform(jax.jit)
|
||||||
|
|
||||||
action_space = env.action_space
|
action_space = env.action_space
|
||||||
observation_space = env.observation_space
|
observation_space = env.observation_space
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
env,
|
env,
|
||||||
observation_space=observation_space,
|
observation_space=observation_space,
|
||||||
action_space=action_space,
|
action_space=action_space,
|
||||||
metadata=metadata,
|
metadata=self.metadata,
|
||||||
render_mode=render_mode,
|
render_mode=render_mode,
|
||||||
)
|
)
|
||||||
|
@@ -1,3 +1,6 @@
|
|||||||
|
"""Functions for registering environments within gymnasium using public functions ``make``, ``register`` and ``spec``."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import difflib
|
import difflib
|
||||||
@@ -8,18 +11,7 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import (
|
from typing import Any, Callable, Iterable, Sequence, SupportsFloat, overload
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
SupportsFloat,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
overload,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -53,7 +45,7 @@ ENV_ID_RE = re.compile(
|
|||||||
|
|
||||||
|
|
||||||
def load(name: str) -> callable:
|
def load(name: str) -> callable:
|
||||||
"""Loads an environment with name and returns an environment creation function
|
"""Loads an environment with name and returns an environment creation function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The environment name
|
name: The environment name
|
||||||
@@ -67,7 +59,7 @@ def load(name: str) -> callable:
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
def parse_env_id(id: str) -> tuple[str | None, str, int | None]:
|
||||||
"""Parse environment ID string format.
|
"""Parse environment ID string format.
|
||||||
|
|
||||||
This format is true today, but it's *not* an official spec.
|
This format is true today, but it's *not* an official spec.
|
||||||
@@ -98,7 +90,7 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
|||||||
return namespace, name, version
|
return namespace, name, version
|
||||||
|
|
||||||
|
|
||||||
def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
|
def get_env_id(ns: str | None, name: str, version: int | None) -> str:
|
||||||
"""Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`.
|
"""Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -109,7 +101,6 @@ def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
The environment id
|
The environment id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
full_name = name
|
full_name = name
|
||||||
if version is not None:
|
if version is not None:
|
||||||
full_name += f"-v{version}"
|
full_name += f"-v{version}"
|
||||||
@@ -134,14 +125,14 @@ class EnvSpec:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
entry_point: Union[Callable, str]
|
entry_point: Callable | str
|
||||||
|
|
||||||
# Environment attributes
|
# Environment attributes
|
||||||
reward_threshold: Optional[float] = field(default=None)
|
reward_threshold: float | None = field(default=None)
|
||||||
nondeterministic: bool = field(default=False)
|
nondeterministic: bool = field(default=False)
|
||||||
|
|
||||||
# Wrappers
|
# Wrappers
|
||||||
max_episode_steps: Optional[int] = field(default=None)
|
max_episode_steps: int | None = field(default=None)
|
||||||
order_enforce: bool = field(default=True)
|
order_enforce: bool = field(default=True)
|
||||||
autoreset: bool = field(default=False)
|
autoreset: bool = field(default=False)
|
||||||
disable_env_checker: bool = field(default=False)
|
disable_env_checker: bool = field(default=False)
|
||||||
@@ -151,20 +142,22 @@ class EnvSpec:
|
|||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
# post-init attributes
|
# post-init attributes
|
||||||
namespace: Optional[str] = field(init=False)
|
namespace: str | None = field(init=False)
|
||||||
name: str = field(init=False)
|
name: str = field(init=False)
|
||||||
version: Optional[int] = field(init=False)
|
version: int | None = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
"""Calls after the spec is created to extract the namespace, name and version from the id."""
|
||||||
# Initialize namespace, name, version
|
# Initialize namespace, name, version
|
||||||
self.namespace, self.name, self.version = parse_env_id(self.id)
|
self.namespace, self.name, self.version = parse_env_id(self.id)
|
||||||
|
|
||||||
def make(self, **kwargs) -> Env:
|
def make(self, **kwargs: Any) -> Env:
|
||||||
|
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
||||||
# For compatibility purposes
|
# For compatibility purposes
|
||||||
return make(self, **kwargs)
|
return make(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _check_namespace_exists(ns: Optional[str]):
|
def _check_namespace_exists(ns: str | None):
|
||||||
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
|
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
|
||||||
if ns is None:
|
if ns is None:
|
||||||
return
|
return
|
||||||
@@ -186,7 +179,7 @@ def _check_namespace_exists(ns: Optional[str]):
|
|||||||
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
|
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
|
||||||
|
|
||||||
|
|
||||||
def _check_name_exists(ns: Optional[str], name: str):
|
def _check_name_exists(ns: str | None, name: str):
|
||||||
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
|
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
|
||||||
_check_namespace_exists(ns)
|
_check_namespace_exists(ns)
|
||||||
names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}
|
names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}
|
||||||
@@ -203,8 +196,9 @@ def _check_name_exists(ns: Optional[str], name: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
|
def _check_version_exists(ns: str | None, name: str, version: int | None):
|
||||||
"""Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
|
"""Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
|
||||||
|
|
||||||
This is a complete test whether an environment identifier is valid, and will provide the best available hints.
|
This is a complete test whether an environment identifier is valid, and will provide the best available hints.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -258,8 +252,9 @@ def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
|
def find_highest_version(ns: str | None, name: str) -> int | None:
|
||||||
version: List[int] = [
|
"""Finds the highest registered version of the environment in the registry."""
|
||||||
|
version: list[int] = [
|
||||||
spec_.version
|
spec_.version
|
||||||
for spec_ in registry.values()
|
for spec_ in registry.values()
|
||||||
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
|
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
|
||||||
@@ -268,6 +263,11 @@ def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
|
|||||||
|
|
||||||
|
|
||||||
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
||||||
|
"""Load modules (plugins) using the gymnasium entry points == to `entry_points`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry_point: The string for the entry point.
|
||||||
|
"""
|
||||||
# Load third-party environments
|
# Load third-party environments
|
||||||
for plugin in metadata.entry_points(group=entry_point):
|
for plugin in metadata.entry_points(group=entry_point):
|
||||||
# Python 3.8 doesn't support plugin.module, plugin.attr
|
# Python 3.8 doesn't support plugin.module, plugin.attr
|
||||||
@@ -323,37 +323,37 @@ def make(id: EnvSpec, **kwargs) -> Env: ...
|
|||||||
# Classic control
|
# Classic control
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
|
|
||||||
|
|
||||||
# Box2d
|
# Box2d
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
|
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||||
|
|
||||||
|
|
||||||
# Toy Text
|
# Toy Text
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
|
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
|
|
||||||
|
|
||||||
# Mujoco
|
# Mujoco
|
||||||
@@ -376,8 +376,8 @@ def make(id: Literal[
|
|||||||
|
|
||||||
|
|
||||||
# Global registry of environments. Meant to be accessed through `register` and `make`
|
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||||
registry: Dict[str, EnvSpec] = {}
|
registry: dict[str, EnvSpec] = {}
|
||||||
current_namespace: Optional[str] = None
|
current_namespace: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def _check_spec_register(spec: EnvSpec):
|
def _check_spec_register(spec: EnvSpec):
|
||||||
@@ -445,6 +445,7 @@ def _check_metadata(metadata_: dict):
|
|||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def namespace(ns: str):
|
def namespace(ns: str):
|
||||||
|
"""Context manager for modifying the current namespace."""
|
||||||
global current_namespace
|
global current_namespace
|
||||||
old_namespace = current_namespace
|
old_namespace = current_namespace
|
||||||
current_namespace = ns
|
current_namespace = ns
|
||||||
@@ -454,10 +455,10 @@ def namespace(ns: str):
|
|||||||
|
|
||||||
def register(
|
def register(
|
||||||
id: str,
|
id: str,
|
||||||
entry_point: Union[Callable, str],
|
entry_point: Callable | str,
|
||||||
reward_threshold: Optional[float] = None,
|
reward_threshold: float | None = None,
|
||||||
nondeterministic: bool = False,
|
nondeterministic: bool = False,
|
||||||
max_episode_steps: Optional[int] = None,
|
max_episode_steps: int | None = None,
|
||||||
order_enforce: bool = True,
|
order_enforce: bool = True,
|
||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
disable_env_checker: bool = False,
|
disable_env_checker: bool = False,
|
||||||
@@ -521,11 +522,11 @@ def register(
|
|||||||
|
|
||||||
|
|
||||||
def make(
|
def make(
|
||||||
id: Union[str, EnvSpec],
|
id: str | EnvSpec,
|
||||||
max_episode_steps: Optional[int] = None,
|
max_episode_steps: int | None = None,
|
||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
apply_api_compatibility: Optional[bool] = None,
|
apply_api_compatibility: bool | None = None,
|
||||||
disable_env_checker: Optional[bool] = None,
|
disable_env_checker: bool | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Env:
|
) -> Env:
|
||||||
"""Create an environment according to the given ID.
|
"""Create an environment according to the given ID.
|
||||||
@@ -706,9 +707,9 @@ def spec(env_id: str) -> EnvSpec:
|
|||||||
def pprint_registry(
|
def pprint_registry(
|
||||||
_registry: dict = registry,
|
_registry: dict = registry,
|
||||||
num_cols: int = 3,
|
num_cols: int = 3,
|
||||||
exclude_namespaces: Optional[List[str]] = None,
|
exclude_namespaces: list[str] | None = None,
|
||||||
disable_print: bool = False,
|
disable_print: bool = False,
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
"""Pretty print the environments in the registry.
|
"""Pretty print the environments in the registry.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -718,7 +719,6 @@ def pprint_registry(
|
|||||||
disable_print: Whether to return a string of all the namespaces and environment IDs
|
disable_print: Whether to return a string of all the namespaces and environment IDs
|
||||||
instead of printing it to console.
|
instead of printing it to console.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Defaultdict to store environment names according to namespace.
|
# Defaultdict to store environment names according to namespace.
|
||||||
namespace_envs = defaultdict(lambda: [])
|
namespace_envs = defaultdict(lambda: [])
|
||||||
max_justify = float("-inf")
|
max_justify = float("-inf")
|
||||||
|
@@ -19,14 +19,34 @@ from gymnasium.experimental.wrappers.lambda_observations import (
|
|||||||
ReshapeObservationV0,
|
ReshapeObservationV0,
|
||||||
RescaleObservationV0,
|
RescaleObservationV0,
|
||||||
DtypeObservationV0,
|
DtypeObservationV0,
|
||||||
|
PixelObservationV0,
|
||||||
|
NormalizeObservationV0,
|
||||||
)
|
)
|
||||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
from gymnasium.experimental.wrappers.lambda_reward import (
|
||||||
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
|
ClipRewardV0,
|
||||||
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
|
LambdaRewardV0,
|
||||||
|
NormalizeRewardV0,
|
||||||
|
)
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_torch import JaxToTorchV0
|
||||||
|
from gymnasium.experimental.wrappers.numpy_to_torch import NumpyToTorchV0
|
||||||
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
||||||
from gymnasium.experimental.wrappers.stateful_observation import (
|
from gymnasium.experimental.wrappers.stateful_observation import (
|
||||||
TimeAwareObservationV0,
|
TimeAwareObservationV0,
|
||||||
DelayObservationV0,
|
DelayObservationV0,
|
||||||
|
FrameStackObservationV0,
|
||||||
|
)
|
||||||
|
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
|
||||||
|
from gymnasium.experimental.wrappers.common import (
|
||||||
|
PassiveEnvCheckerV0,
|
||||||
|
OrderEnforcingV0,
|
||||||
|
AutoresetV0,
|
||||||
|
RecordEpisodeStatisticsV0,
|
||||||
|
)
|
||||||
|
from gymnasium.experimental.wrappers.rendering import (
|
||||||
|
RenderCollectionV0,
|
||||||
|
RecordVideoV0,
|
||||||
|
HumanRenderingV0,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -39,12 +59,12 @@ __all__ = [
|
|||||||
"ReshapeObservationV0",
|
"ReshapeObservationV0",
|
||||||
"RescaleObservationV0",
|
"RescaleObservationV0",
|
||||||
"DtypeObservationV0",
|
"DtypeObservationV0",
|
||||||
# "PixelObservationV0",
|
"PixelObservationV0",
|
||||||
# "NormalizeObservationV0",
|
"NormalizeObservationV0",
|
||||||
"TimeAwareObservationV0",
|
"TimeAwareObservationV0",
|
||||||
# "FrameStackV0",
|
"FrameStackObservationV0",
|
||||||
"DelayObservationV0",
|
"DelayObservationV0",
|
||||||
# "AtariPreprocessingV0"
|
"AtariPreprocessingV0",
|
||||||
# --- Action Wrappers ---
|
# --- Action Wrappers ---
|
||||||
"LambdaActionV0",
|
"LambdaActionV0",
|
||||||
"ClipActionV0",
|
"ClipActionV0",
|
||||||
@@ -54,15 +74,18 @@ __all__ = [
|
|||||||
# --- Reward wrappers ---
|
# --- Reward wrappers ---
|
||||||
"LambdaRewardV0",
|
"LambdaRewardV0",
|
||||||
"ClipRewardV0",
|
"ClipRewardV0",
|
||||||
# "RescaleRewardV0",
|
"NormalizeRewardV0",
|
||||||
# "NormalizeRewardV0",
|
|
||||||
# --- Common ---
|
# --- Common ---
|
||||||
# "AutoReset",
|
"AutoresetV0",
|
||||||
# "PassiveEnvChecker",
|
"PassiveEnvCheckerV0",
|
||||||
# "OrderEnforcing",
|
"OrderEnforcingV0",
|
||||||
# "RecordEpisodeStatistics",
|
"RecordEpisodeStatisticsV0",
|
||||||
# "RenderCollection",
|
# --- Rendering ---
|
||||||
# "HumanRendering",
|
"RenderCollectionV0",
|
||||||
|
"RecordVideoV0",
|
||||||
|
"HumanRenderingV0",
|
||||||
|
# --- Data Conversion ---
|
||||||
"JaxToNumpyV0",
|
"JaxToNumpyV0",
|
||||||
"JaxToTorchV0",
|
"JaxToTorchV0",
|
||||||
|
"NumpyToTorchV0",
|
||||||
]
|
]
|
||||||
|
193
gymnasium/experimental/wrappers/atari_preprocessing.py
Normal file
193
gymnasium/experimental/wrappers/atari_preprocessing.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
except ImportError:
|
||||||
|
cv2 = None
|
||||||
|
|
||||||
|
|
||||||
|
class AtariPreprocessingV0(gym.Wrapper):
    """Preprocessing wrapper for Atari 2600 environments.

    Implements the recommendations of Machado et al. (2018),
    "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".

    The wrapper applies the following stages:

    - Noop reset: on reset, a random number of no-op actions (action ``0``) is taken, at most ``noop_max`` (30 by default).
    - Frame skipping: every :meth:`step` repeats the action ``frame_skip`` times (4 by default) and sums the rewards.
    - Max-pooling: the returned frame is the element-wise maximum of the last two frames of a completed frame skip.
    - Termination on life loss: optionally reports ``terminated=True`` when the life counter decreases.
      Turned off by default. Not recommended by Machado et al. (2018).
    - Resize: the frame is resized to a ``screen_size`` x ``screen_size`` square image (84x84 by default).
    - Grayscale: observations are greyscale by default, RGB otherwise.
    - Scaling: optionally returns ``float32`` observations divided by 255; not scaled by default.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = False,
        grayscale_obs: bool = True,
        grayscale_newaxis: bool = False,
        scale_obs: bool = False,
    ):
        """Set up the preprocessing pipeline around ``env``.

        Args:
            env (Env): The environment to preprocess.
            noop_max (int): Upper bound on the number of no-op actions taken on reset; ``0`` disables no-op resets.
            frame_skip (int): How many times each action is repeated per :meth:`step` call.
            screen_size (int): Side length of the square image the frame is resized to.
            terminal_on_life_loss (bool): If ``True``, :meth:`step` returns ``terminated=True`` whenever a life is lost.
            grayscale_obs (bool): If ``True`` a greyscale observation is returned, otherwise an RGB observation.
            grayscale_newaxis (bool): If ``True`` and ``grayscale_obs=True``, a trailing channel axis is added
                so that greyscale observations are 3-dimensional.
            scale_obs (bool): If ``True`` the observation is returned as ``float32`` divided by 255.

        Raises:
            DependencyNotInstalled: opencv-python package not installed
            ValueError: Disable frame-skipping in the original env
            AssertionError: If an argument is out of range, the first action is not ``"NOOP"``,
                or the observation space is not a ``Box``.
        """
        super().__init__(env)
        if cv2 is None:
            raise gym.error.DependencyNotInstalled(
                "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
            )
        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0

        # The wrapper does its own frame skipping, so the wrapped env must not skip as well.
        if (
            frame_skip > 1
            and env.spec is not None
            and "NoFrameskip" not in env.spec.id
            and getattr(env.unwrapped, "_frameskip", None) != 1
        ):
            raise ValueError(
                "Disable frame-skipping in the original env. Otherwise, more than one "
                "frame-skip will happen as through this wrapper"
            )

        self.noop_max = noop_max
        # The no-op reset sends action 0, so that action has to be the no-op.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # Two raw-frame buffers holding the most recent screens for max pooling.
        assert isinstance(env.observation_space, Box)
        raw_shape = env.observation_space.shape
        buffer_shape = raw_shape[:2] if grayscale_obs else raw_shape
        self.obs_buffer = [np.empty(buffer_shape, dtype=np.uint8) for _ in range(2)]

        self.lives = 0
        self.game_over = False

        if scale_obs:
            low, high, obs_dtype = 0, 1, np.float32
        else:
            low, high, obs_dtype = 0, 255, np.uint8

        if not grayscale_obs:
            obs_shape = (screen_size, screen_size, 3)
        elif grayscale_newaxis:
            obs_shape = (screen_size, screen_size, 1)
        else:
            obs_shape = (screen_size, screen_size)  # no channel axis

        self.observation_space = Box(
            low=low, high=high, shape=obs_shape, dtype=obs_dtype
        )

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def _capture_screen(self, index):
        """Pass ``self.obs_buffer[index]`` to the emulator's screen getter (greyscale or RGB).

        NOTE(review): presumably the emulator fills the passed array in place - confirm against the ALE API.
        """
        target = self.obs_buffer[index]
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(target)
        else:
            self.ale.getScreenRGB(target)

    def step(self, action):
        """Repeat ``action`` for up to ``frame_skip`` frames and return the preprocessed result.

        Rewards are summed over the repeated frames. The loop stops early when the episode
        terminates or truncates; with ``terminal_on_life_loss`` a decrease of the life counter
        also counts as termination.
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}
        last_frame = self.frame_skip - 1

        for frame_idx in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated

            if self.terminal_on_life_loss:
                lives_now = self.ale.lives()
                terminated = terminated or lives_now < self.lives
                self.game_over = terminated
                self.lives = lives_now

            if terminated or truncated:
                break

            # Only the final two frames of the skip are kept for max pooling.
            if frame_idx == last_frame - 1:
                self._capture_screen(1)
            elif frame_idx == last_frame:
                self._capture_screen(0)

        return self._get_obs(), total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the environment, apply the no-op reset and return the preprocessed observation."""
        _, reset_info = self.env.reset(**kwargs)

        # NoopReset: the count is drawn with ``np_random.integers(1, noop_max + 1)``.
        if self.noop_max > 0:
            noops = self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
        else:
            noops = 0

        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                # The episode ended during the no-ops: start over.
                _, reset_info = self.env.reset(**kwargs)

        self.lives = self.ale.lives()
        self._capture_screen(0)
        # Zero the second buffer so the max pooling in ``_get_obs`` leaves the fresh frame unchanged.
        self.obs_buffer[1].fill(0)

        return self._get_obs(), reset_info

    def _get_obs(self):
        """Max-pool the two buffers, resize, and apply the scaling / channel-axis options."""
        newest, previous = self.obs_buffer[0], self.obs_buffer[1]
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(newest, previous, out=newest)

        assert cv2 is not None
        resized = cv2.resize(
            newest,
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(resized, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(resized, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs
|
281
gymnasium/experimental/wrappers/common.py
Normal file
281
gymnasium/experimental/wrappers/common.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""A collection of common wrappers.
|
||||||
|
|
||||||
|
* ``AutoresetV0`` - Auto-resets the environment
|
||||||
|
* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
|
||||||
|
* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
|
||||||
|
* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any, SupportsFloat
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import Env
|
||||||
|
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import ResetNeeded
|
||||||
|
from gymnasium.utils.passive_env_checker import (
|
||||||
|
check_action_space,
|
||||||
|
check_observation_space,
|
||||||
|
env_render_passive_checker,
|
||||||
|
env_reset_passive_checker,
|
||||||
|
env_step_passive_checker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoresetV0(gym.Wrapper):
    """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.

    When a step returns ``terminated`` or ``truncated``, the *next* call to :meth:`step`
    ignores its action, resets the wrapped environment (re-using the ``options`` of the
    last explicit :meth:`reset`) and returns the reset observation with a reward of ``0``.
    """

    def __init__(self, env: gym.Env):
        """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.

        Args:
            env (gym.Env): The environment to apply the wrapper
        """
        super().__init__(env)
        # True once the previous step ended the episode, i.e. the next step must reset.
        self._episode_ended: bool = False
        # Options of the last explicit reset, re-used for every automatic reset.
        self._reset_options: dict[str, Any] | None = None

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.

        Args:
            action: The action to take (ignored on the call that performs the automatic reset)

        Returns:
            The autoreset environment :meth:`step`: either the wrapped environment's step result, or
            ``(obs, 0, False, False, info)`` built from the reset when the previous step ended the episode.
        """
        if self._episode_ended:
            obs, info = super().reset(options=self._reset_options)
            # A fresh episode has started, so clear the flag. Leaving it set would make
            # every later call reset again and the environment would never be stepped.
            self._episode_ended = False
            return obs, 0, False, False, info

        obs, reward, terminated, truncated, info = super().step(action)
        self._episode_ended = terminated or truncated
        return obs, reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment, saving the options used."""
        self._episode_ended = False
        self._reset_options = options
        return super().reset(seed=seed, options=self._reset_options)
|
||||||
|
|
||||||
|
|
||||||
|
class PassiveEnvCheckerV0(gym.Wrapper):
    """Wrapper that validates the first ``step``, ``reset`` and ``render`` call against the gymnasium API.

    Each of the three methods is routed through its passive checker exactly once;
    later calls go straight to the wrapped environment.
    """

    def __init__(self, env: Env[ObsType, ActType]):
        """Wrap ``env`` and immediately check its action and observation spaces."""
        super().__init__(env)

        assert hasattr(
            env, "action_space"
        ), "The environment must specify an action space. https://gymnasium.farama.org/content/environment_creation/"
        check_action_space(env.action_space)

        assert hasattr(
            env, "observation_space"
        ), "The environment must specify an observation space. https://gymnasium.farama.org/content/environment_creation/"
        check_observation_space(env.observation_space)

        # One flag per method: flipped to True the first time that method is checked.
        self._checked_reset: bool = False
        self._checked_step: bool = False
        self._checked_render: bool = False

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the environment; the first call goes through ``env_step_passive_checker``."""
        if self._checked_step is False:
            self._checked_step = True
            return env_step_passive_checker(self.env, action)
        return self.env.step(action)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment; the first call goes through ``env_reset_passive_checker``."""
        if self._checked_reset is False:
            self._checked_reset = True
            return env_reset_passive_checker(self.env, seed=seed, options=options)
        return self.env.reset(seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Render the environment; the first call goes through ``env_render_passive_checker``."""
        if self._checked_render is False:
            self._checked_render = True
            return env_render_passive_checker(self.env)
        return self.env.render()
|
||||||
|
|
||||||
|
|
||||||
|
class OrderEnforcingV0(gym.Wrapper):
    """Wrapper that raises ``ResetNeeded`` when :meth:`step` is called before the first :meth:`reset`.

    Unless ``disable_render_order_enforcing=True``, :meth:`render` is guarded the same way.

    Example:
        >>> from gymnasium.envs.classic_control import CartPoleEnv
        >>> env = CartPoleEnv()
        >>> env = OrderEnforcingV0(env)
        >>> env.step(0)  # raises ResetNeeded, no reset yet
        >>> env.render()  # raises ResetNeeded, no reset yet
        >>> env.reset()
        >>> env.render()
        >>> env.step(0)
    """

    def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
        """Wrap ``env`` and start in the "not yet reset" state.

        Args:
            env: The environment to wrap
            disable_render_order_enforcing: If ``True``, :meth:`render` may be called before :meth:`reset`
        """
        super().__init__(env)
        self._has_reset: bool = False
        self._disable_render_order_enforcing: bool = disable_render_order_enforcing

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Step the environment with ``action``, raising ``ResetNeeded`` if it was never reset."""
        if self._has_reset:
            return super().step(action)
        raise ResetNeeded("Cannot call env.step() before calling env.reset()")

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment with ``seed`` and ``options`` and record that a reset happened."""
        self._has_reset = True
        return super().reset(seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Render the environment, raising ``ResetNeeded`` if enforced and it was never reset."""
        render_allowed = self._disable_render_order_enforcing or self._has_reset
        if not render_allowed:
            raise ResetNeeded(
                "Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, "
                "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
            )
        return super().render()

    @property
    def has_reset(self):
        """Whether :meth:`reset` has been called on this wrapper."""
        return self._has_reset
|
||||||
|
|
||||||
|
|
||||||
|
class RecordEpisodeStatisticsV0(gym.Wrapper):
|
||||||
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
|
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||||
|
using the key ``episode``. If using a vectorized environment also the key
|
||||||
|
``_episode`` is used which indicates whether the env at the respective index has
|
||||||
|
the episode statistics.
|
||||||
|
|
||||||
|
After the completion of an episode, ``info`` will look like this::
|
||||||
|
|
||||||
|
>>> info = {
|
||||||
|
... ...
|
||||||
|
... "episode": {
|
||||||
|
... "r": "<cumulative reward>",
|
||||||
|
... "l": "<episode length>",
|
||||||
|
... "t": "<elapsed time since beginning of episode>"
|
||||||
|
... },
|
||||||
|
... }
|
||||||
|
|
||||||
|
For a vectorized environments the output will be in the form of::
|
||||||
|
|
||||||
|
>>> infos = {
|
||||||
|
... ...
|
||||||
|
... "episode": {
|
||||||
|
... "r": "<array of cumulative reward>",
|
||||||
|
... "l": "<array of episode length>",
|
||||||
|
... "t": "<array of elapsed time since beginning of episode>"
|
||||||
|
... },
|
||||||
|
... "_episode": "<boolean array of length num-envs>"
|
||||||
|
... }
|
||||||
|
|
||||||
|
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
|
||||||
|
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
episode_reward_buffer: The cumulative rewards of the last ``deque_size``-many episodes
|
||||||
|
episode_length_buffer: The lengths of the last ``deque_size``-many episodes
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: Env[ObsType, ActType],
|
||||||
|
buffer_length: int | None = 100,
|
||||||
|
stats_key: str = "episode",
|
||||||
|
):
|
||||||
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env (Env): The environment to apply the wrapper
|
||||||
|
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||||
|
stats_key: The info key for the episode statistics
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
self._stats_key = stats_key
|
||||||
|
|
||||||
|
self.episode_count = 0
|
||||||
|
self.episode_start_time: float = -1
|
||||||
|
self.episode_reward: float = -1
|
||||||
|
self.episode_length: int = -1
|
||||||
|
|
||||||
|
self.episode_time_length_buffer = deque(maxlen=buffer_length)
|
||||||
|
self.episode_reward_buffer = deque(maxlen=buffer_length)
|
||||||
|
self.episode_length_buffer = deque(maxlen=buffer_length)
|
||||||
|
|
||||||
|
def step(
    self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
    """Step the environment and, when the episode ends, publish its statistics.

    The reward and step count of the current episode are accumulated on every
    call. When the step terminates or truncates the episode, a dictionary with
    the keys ``"r"`` (cumulative reward), ``"l"`` (length) and ``"t"`` (elapsed
    time, rounded to 6 decimals) is written into ``info`` under the statistics
    key and the same values are appended to the buffers.
    """
    obs, reward, terminated, truncated, info = super().step(action)

    self.episode_reward += reward
    self.episode_length += 1

    # Nothing more to record while the episode is still running.
    if not (terminated or truncated):
        return obs, reward, terminated, truncated, info

    # The statistics key must not collide with a key set by the environment.
    assert self._stats_key not in info

    elapsed = time.perf_counter() - self.episode_start_time
    episode_time_length = np.round(elapsed, 6)

    info[self._stats_key] = {
        "r": self.episode_reward,
        "l": self.episode_length,
        "t": episode_time_length,
    }

    self.episode_time_length_buffer.append(episode_time_length)
    self.episode_reward_buffer.append(self.episode_reward)
    self.episode_length_buffer.append(self.episode_length)
    self.episode_count += 1

    return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def reset(
    self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
    """Reset the environment and restart the episode timer, reward and length."""
    obs, info = super().reset(seed=seed, options=options)

    # Start timing the new episode and zero the running totals.
    self.episode_start_time = time.perf_counter()
    self.episode_reward, self.episode_length = 0, 0

    return obs, info
|
@@ -17,7 +17,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
|||||||
from gymnasium import Env, Wrapper
|
from gymnasium import Env, Wrapper
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
|
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
@@ -1,13 +1,15 @@
|
|||||||
"""A collection of observation wrappers using a lambda function.
|
"""A collection of observation wrappers using a lambda function.
|
||||||
|
|
||||||
* ``LambdaObservation`` - Transforms the observation with a function
|
* ``LambdaObservationV0`` - Transforms the observation with a function
|
||||||
* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
|
* ``FilterObservationV0`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
|
||||||
* ``FlattenObservation`` - Flattens the observations
|
* ``FlattenObservationV0`` - Flattens the observations
|
||||||
* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation
|
* ``GrayscaleObservationV0`` - Converts a RGB observation to a grayscale observation
|
||||||
* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation)
|
* ``ResizeObservationV0`` - Resizes an array-based observation (normally a RGB observation)
|
||||||
* ``ReshapeObservation`` - Reshapes an array-based observation
|
* ``ReshapeObservationV0`` - Reshapes an array-based observation
|
||||||
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
|
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
|
||||||
* ``DtypeObservation`` - Convert a observation dtype
|
* ``DtypeObservationV0`` - Convert an observation to a dtype
|
||||||
|
* ``PixelObservationV0`` - Adds the rendered frame to the observation
|
||||||
|
* ``NormalizeObservationV0`` - Normalizes the observations s.t. each coordinate is centered with unit variance
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -18,10 +20,11 @@ import jumpy as jp
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
from gymnasium import Env, spaces
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ActType, ObservationWrapper, ObsType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.spaces import Box, utils
|
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||||
|
from gymnasium.spaces import Box, Dict, utils
|
||||||
|
|
||||||
|
|
||||||
class LambdaObservationV0(gym.ObservationWrapper):
|
class LambdaObservationV0(gym.ObservationWrapper):
|
||||||
@@ -407,3 +410,83 @@ class DtypeObservationV0(LambdaObservationV0):
|
|||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
||||||
|
|
||||||
|
|
||||||
|
class PixelObservationV0(LambdaObservationV0):
    """Augment observations by pixel values.

    With ``pixels_only=True`` (the default) the observation is replaced by the
    rendered frame. Otherwise the observation is a dictionary holding both the
    rendered frame and the observation of the base environment:
    if the base environment has an observation space of type :class:`Dict`, the
    rendered frame is added to the base observation under ``pixels_key``. If,
    however, the observation space is of any other type (e.g. :class:`Box`),
    the base environment's observation is stored under ``obs_key``
    ("state" by default) next to the rendered frame.
    """

    def __init__(
        self,
        env: Env[ObsType, ActType],
        pixels_only: bool = True,
        pixels_key: str = "pixels",
        obs_key: str = "state",
    ):
        """Initializes a new pixel Wrapper.

        Note:
            The constructor calls ``env.reset()`` and ``env.render()`` once in
            order to read the shape of the rendered frame.

        Args:
            env: The environment to wrap.
            pixels_only (bool): If `True` (default), the original observation returned
                by the wrapped environment will be discarded, and the
                observation will only include pixels. If `False`, the
                observation dictionary will contain both the original
                observations and the pixel observations.
            pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
            obs_key: Optional custom string specifying the obs key. Defaults to "state"
        """
        assert env.render_mode is not None and env.render_mode != "human"
        env.reset()
        pixels = env.render()
        assert pixels is not None and isinstance(pixels, np.ndarray)
        pixel_space = Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)

        if pixels_only:
            obs_space = pixel_space
            super().__init__(env, lambda _: self.render(), obs_space)
        elif isinstance(env.observation_space, Dict):
            assert pixels_key not in env.observation_space.spaces.keys()

            obs_space = Dict({pixels_key: pixel_space, **env.observation_space.spaces})
            # Merge the rendered frame with the *observation* values; unpacking
            # ``obs_space`` here would put Space objects into the observation.
            super().__init__(
                env, lambda obs: {pixels_key: self.render(), **obs}, obs_space
            )
        else:
            obs_space = Dict({obs_key: env.observation_space, pixels_key: pixel_space})
            super().__init__(
                env, lambda obs: {obs_key: obs, pixels_key: self.render()}, obs_space
            )
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeObservationV0(ObservationWrapper):
    """Normalizes observations with running statistics s.t. each coordinate is centered with unit variance.

    Note:
        The statistics are accumulated from the observations seen so far, so observations will not be
        normalized correctly if the wrapper was newly instantiated or the policy was changed recently.
    """

    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
        """Creates the running-statistics tracker for the wrapped environment.

        Args:
            env (Env): The environment to apply the wrapper
            epsilon: A stability parameter added to the variance before taking the square root.
        """
        super().__init__(env)
        self.epsilon = epsilon
        # Running mean/variance with the same shape as a single observation.
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)

    def observation(self, observation: ObsType) -> WrapperObsType:
        """Updates the running statistics with ``observation`` and returns the normalised observation."""
        # The statistics are updated first, so the current observation already
        # contributes to the mean and variance used below.
        self.obs_rms.update(observation)

        centered = observation - self.obs_rms.mean
        scale = np.sqrt(self.obs_rms.var + self.epsilon)
        return centered / scale
|
||||||
|
@@ -6,12 +6,14 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Callable, SupportsFloat
|
from typing import Any, Callable, SupportsFloat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import InvalidBound
|
from gymnasium.error import InvalidBound
|
||||||
|
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||||
|
|
||||||
|
|
||||||
class LambdaRewardV0(gym.RewardWrapper):
|
class LambdaRewardV0(gym.RewardWrapper):
|
||||||
@@ -89,3 +91,48 @@ class ClipRewardV0(LambdaRewardV0):
|
|||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
|
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeRewardV0(gym.Wrapper):
    r"""Scales immediate rewards s.t. their exponential moving average has a fixed variance.

    The exponential moving average will have variance :math:`(1 - \gamma)^2`.

    Note:
        The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
        instantiated or the policy was changed recently.
    """

    def __init__(
        self,
        env: gym.Env,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """Creates the running statistics used to scale the rewards.

        Args:
            env (env): The environment to apply the wrapper
            gamma (float): The discount factor that is used in the exponential moving average.
            epsilon (float): A stability parameter
        """
        super().__init__(env)
        self.gamma = gamma
        self.epsilon = epsilon
        # Discounted sum of the rewards seen so far in the current episode.
        self.discounted_reward: float = 0.0
        # Scalar running statistics of ``discounted_reward``.
        self.rewards_running_means = RunningMeanStd(shape=())

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment and returns the reward scaled by the running statistics."""
        obs, reward, terminated, truncated, info = super().step(action)

        raw_reward = float(reward)
        # ``(1 - terminated)`` zeroes the carried-over discounted sum when the
        # step terminated the episode.
        carried = self.discounted_reward * self.gamma * (1 - terminated)
        self.discounted_reward = carried + raw_reward

        return obs, self.normalize(raw_reward), terminated, truncated, info

    def normalize(self, reward):
        """Updates the running statistics with the discounted reward and scales ``reward`` by their standard deviation."""
        self.rewards_running_means.update(self.discounted_reward)
        scale = np.sqrt(self.rewards_running_means.var + self.epsilon)
        return reward / scale
|
||||||
|
158
gymnasium/experimental/wrappers/numpy_to_torch.py
Normal file
158
gymnasium/experimental/wrappers/numpy_to_torch.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Helper functions and wrapper class for converting between PyTorch and NumPy."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import numbers
|
||||||
|
from collections import abc
|
||||||
|
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium import Env, Wrapper
|
||||||
|
from gymnasium.core import WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
Device = Union[str, torch.device]
|
||||||
|
except ImportError:
|
||||||
|
torch, Device = None, None
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
def torch_to_numpy(value: Any) -> Any:
    """Converts a PyTorch Tensor (or a container of them) into NumPy arrays.

    This is the fallback of a ``functools.singledispatch`` function: it is only
    reached when no conversion is registered for ``type(value)``.

    Raises:
        DependencyNotInstalled: If torch is not installed.
        Exception: If torch is installed but no conversion is registered for the type.
    """
    if torch is None:
        raise DependencyNotInstalled(
            "Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
        )
    else:
        raise Exception(
            f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
        )


if torch is not None:

    @torch_to_numpy.register(numbers.Number)
    @torch_to_numpy.register(torch.Tensor)
    def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
        """Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
        if isinstance(value, torch.Tensor):
            # ``np.array`` cannot convert a tensor that requires grad or that
            # lives on a non-CPU device, so drop the autograd graph and move
            # the data to host memory first. ``np.array`` still copies, so the
            # result never aliases the tensor's storage.
            value = value.detach().cpu()
        return np.array(value)

    @torch_to_numpy.register(abc.Mapping)
    def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
        """Converts a mapping of PyTorch Tensors into a mapping of NumPy arrays of the same type."""
        return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})

    @torch_to_numpy.register(abc.Iterable)
    def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
        """Converts an iterable of PyTorch Tensors into an iterable of NumPy arrays of the same type."""
        return type(value)(torch_to_numpy(v) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
    """Converts a NumPy array (or a container of them) into PyTorch Tensors.

    This is the fallback of a ``functools.singledispatch`` function: it is only
    reached when no conversion is registered for ``type(value)``.

    Args:
        value: The value to convert
        device: The device the resulting tensors should be moved to

    Raises:
        DependencyNotInstalled: If torch is not installed.
        Exception: If torch is installed but no conversion is registered for the type.
    """
    if torch is None:
        raise DependencyNotInstalled(
            "Torch is not installed therefore cannot call `numpy_to_torch`, run `pip install torch`"
        )
    else:
        raise Exception(
            f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
        )


if torch is not None:

    # Numbers are registered as well so that scalar entries (for example inside
    # an ``info`` dictionary) are converted instead of hitting the fallback.
    @numpy_to_torch.register(numbers.Number)
    @numpy_to_torch.register(np.ndarray)
    def _numpy_to_torch(
        value: np.ndarray | numbers.Number, device: Device | None = None
    ) -> torch.Tensor:
        """Converts a NumPy array or a number into a PyTorch Tensor, optionally moving it to ``device``."""
        assert torch is not None
        tensor = torch.tensor(value)
        if device:
            return tensor.to(device=device)
        return tensor

    @numpy_to_torch.register(abc.Mapping)
    def _numpy_mapping_to_torch(
        value: Mapping[str, Any], device: Device | None = None
    ) -> Mapping[str, Any]:
        """Converts a mapping of NumPy arrays into a mapping of PyTorch Tensors of the same type."""
        return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})

    @numpy_to_torch.register(abc.Iterable)
    def _numpy_iterable_to_torch(
        value: Iterable[Any], device: Device | None = None
    ) -> Iterable[Any]:
        """Converts an iterable of NumPy arrays into an iterable of PyTorch Tensors of the same type."""
        return type(value)(numpy_to_torch(v, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyToTorchV0(Wrapper):
    """Wraps a NumPy-based environment so that it can be interacted with through PyTorch Tensors.

    Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.

    Note:
        For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
    """

    def __init__(self, env: Env, device: Device | None = None):
        """Wrapper class to change inputs and outputs of environment to PyTorch tensors.

        Args:
            env: The NumPy-based environment to wrap
            device: The device the torch Tensors should be moved to

        Raises:
            DependencyNotInstalled: If torch is not installed.
        """
        if torch is None:
            raise DependencyNotInstalled(
                "Torch is not installed, run `pip install torch`"
            )

        super().__init__(env)
        self.device: Device | None = device

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Performs the given action within the environment.

        Args:
            action: The action to perform as a PyTorch Tensor

        Returns:
            The next observation, reward, termination, truncation, and extra info
        """
        numpy_action = torch_to_numpy(action)
        obs, reward, terminated, truncated, info = self.env.step(numpy_action)

        # Observation and info are converted to tensors; reward and the two
        # end-of-episode flags are returned as plain Python scalars.
        torch_obs = numpy_to_torch(obs, self.device)
        torch_info = numpy_to_torch(info, self.device)
        return torch_obs, float(reward), bool(terminated), bool(truncated), torch_info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment returning PyTorch-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to NumPy arrays.

        Returns:
            PyTorch-based observations and info
        """
        if options:
            options = torch_to_numpy(options)

        # The whole result of ``env.reset`` is converted in a single call.
        reset_output = self.env.reset(seed=seed, options=options)
        return numpy_to_torch(reset_output, self.device)
|
219
gymnasium/experimental/wrappers/rendering.py
Normal file
219
gymnasium/experimental/wrappers/rendering.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""A collection of rendering-based wrappers.
|
||||||
|
|
||||||
|
* ``RenderCollectionV0`` - Collects rendered frames into a list
|
||||||
|
* ``RecordVideoV0`` - Records a video of the environments
|
||||||
|
* ``HumanRenderingV0`` - Provides human rendering of environments with ``"rgb_array"``
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, SupportsFloat
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
|
|
||||||
|
class RenderCollectionV0(gym.Wrapper):
    """Collect the rendered frames of an environment such that ``render`` returns a ``list[RenderFrame]``."""

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        pop_frames: bool = True,
        reset_clean: bool = True,
    ):
        """Initialize a :class:`RenderCollection` instance.

        Args:
            env: The environment that is being wrapped
            pop_frames (bool): If true, clear the collected frames after :meth:`render` is called. Default value is ``True``.
            reset_clean (bool): If true, clear the collected frames when :meth:`reset` is called. Default value is ``True``.
        """
        super().__init__(env)
        # The wrapped environment must render, and must not already be
        # returning a list of frames.
        assert env.render_mode is not None
        assert not env.render_mode.endswith("_list")

        self.frame_list: list[RenderFrame] = []
        self.pop_frames = pop_frames
        self.reset_clean = reset_clean

        # Work on a copy of the metadata so the wrapped environment's metadata
        # is left untouched when the ``*_list`` mode is advertised.
        self.metadata = deepcopy(self.env.metadata)
        list_mode = f"{self.env.render_mode}_list"
        if list_mode not in self.metadata["render_modes"]:
            self.metadata["render_modes"].append(list_mode)

    @property
    def render_mode(self):
        """Returns the collection render_mode name."""
        return f"{self.env.render_mode}_list"

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Perform a step in the base environment and collect a frame."""
        step_result = super().step(action)
        self.frame_list.append(super().render())
        return step_result

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the base environment, clear the frame list if ``reset_clean`` is set, and collect a frame."""
        reset_result = super().reset(seed=seed, options=options)

        if self.reset_clean:
            self.frame_list = []
        self.frame_list.append(super().render())

        return reset_result

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the collection of frames and, if pop_frames = True, clears it."""
        collected = self.frame_list
        if self.pop_frames:
            # Rebind instead of clearing in place so the returned list keeps its frames.
            self.frame_list = []
        return collected
|
||||||
|
|
||||||
|
|
||||||
|
class RecordVideoV0(gym.Wrapper):
    """Record a video of an environment.

    This class currently adds no behaviour on top of ``gym.Wrapper``.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class HumanRenderingV0(gym.Wrapper):
    """Performs human rendering for an environment that only supports "rgb_array" rendering.

    This wrapper is particularly useful when you have implemented an environment that can produce
    RGB images but haven't implemented any code to render the images to the screen.
    If you want to use this wrapper with your environments, remember to specify ``"render_fps"``
    in the metadata of your environment.

    The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.

    Example:
        >>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
        >>> wrapped = HumanRenderingV0(env)
        >>> wrapped.reset()     # This will start rendering to the screen

    The wrapper can also be applied directly when the environment is instantiated, simply by passing
    ``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
    implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).

    Example:
        >>> env = gym.make("NoNativeRendering-v2", render_mode="human")      # NoNativeRendering-v2 doesn't implement human-rendering natively
        >>> env.reset()     # This will start rendering to the screen

    Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
        will always return an empty list:

            >>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
            >>> wrapped = HumanRenderingV0(env)
            >>> wrapped.reset()
            >>> env.render()
            []          # env.render() will always return an empty list!

    """

    def __init__(self, env):
        """Initialize a :class:`HumanRendering` instance.

        Args:
            env: The environment that is being wrapped
        """
        super().__init__(env)
        assert env.render_mode in [
            "rgb_array",
            "rgb_array_list",
        ], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'"
        assert (
            "render_fps" in env.metadata
        ), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper"

        # All three are created lazily by the first call to ``_render_frame``.
        self.screen_size = None
        self.window = None
        self.clock = None

        if "human" not in self.metadata["render_modes"]:
            # Copy before appending so the wrapped environment's metadata is not modified.
            self.metadata = deepcopy(self.env.metadata)
            self.metadata["render_modes"].append("human")

    @property
    def render_mode(self):
        """Always returns ``'human'``."""
        return "human"

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Perform a step in the base environment and render a frame to the screen."""
        result = super().step(action)
        self._render_frame()
        return result

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the base environment and render a frame to the screen."""
        result = super().reset(seed=seed, options=options)
        self._render_frame()
        return result

    def render(self):
        """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
        return None

    def _render_frame(self):
        """Fetch the last frame from the base environment and render it to the screen.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
            Exception: If the wrapped environment's render mode is neither
                ``'rgb_array'`` nor ``'rgb_array_list'``.
        """
        try:
            import pygame
        except ImportError as e:
            # Chain the original error so the underlying import failure stays visible.
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[box2d]`"
            ) from e
        if self.env.render_mode == "rgb_array_list":
            last_rgb_array = self.env.render()
            assert isinstance(last_rgb_array, list)
            # Only the most recent frame of the list is displayed.
            last_rgb_array = last_rgb_array[-1]
        elif self.env.render_mode == "rgb_array":
            last_rgb_array = self.env.render()
        else:
            raise Exception(
                f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', actual render mode: {self.env.render_mode}"
            )
        assert isinstance(last_rgb_array, np.ndarray)

        # Swap the first two axes before handing the array to pygame.
        rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))

        if self.screen_size is None:
            self.screen_size = rgb_array.shape[:2]

        assert (
            self.screen_size == rgb_array.shape[:2]
        ), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}"

        if self.window is None:
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(self.screen_size)

        if self.clock is None:
            self.clock = pygame.time.Clock()

        surf = pygame.surfarray.make_surface(rgb_array)
        self.window.blit(surf, (0, 0))
        pygame.event.pump()
        # Limit the display rate to the ``render_fps`` given in the metadata.
        self.clock.tick(self.metadata["render_fps"])
        pygame.display.flip()

    def close(self):
        """Close the rendering window."""
        if self.window is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
        super().close()
|
@@ -1,17 +1,14 @@
|
|||||||
"""A collection of stateful action wrappers.
|
"""``StickyAction`` wrapper - There is a probability that the action is taken again."""
|
||||||
|
|
||||||
* StickyAction - There is a probability that the action is taken again
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, SupportsFloat
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import WrapperActType, WrapperObsType
|
from gymnasium.core import ActionWrapper, ActType, WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import InvalidProbability
|
from gymnasium.error import InvalidProbability
|
||||||
|
|
||||||
|
|
||||||
class StickyActionV0(gym.Wrapper):
|
class StickyActionV0(ActionWrapper):
|
||||||
"""Wrapper which adds a probability of repeating the previous action.
|
"""Wrapper which adds a probability of repeating the previous action.
|
||||||
|
|
||||||
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
||||||
@@ -42,9 +39,7 @@ class StickyActionV0(gym.Wrapper):
|
|||||||
|
|
||||||
return super().reset(seed=seed, options=options)
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
def step(
|
def action(self, action: WrapperActType) -> ActType:
|
||||||
self, action: WrapperActType
|
|
||||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
|
||||||
"""Execute the action."""
|
"""Execute the action."""
|
||||||
if (
|
if (
|
||||||
self.last_action is not None
|
self.last_action is not None
|
||||||
|
@@ -1,7 +1,9 @@
|
|||||||
"""A collection of stateful observation wrappers.
|
"""A collection of stateful observation wrappers.
|
||||||
|
|
||||||
* DelayObservation - A wrapper for delaying the returned observation
|
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
|
||||||
* TimeAwareObservation - A wrapper for adding time aware observations to environment observation
|
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
|
||||||
|
* ``FrameStackObservationV0`` - Frame stack the observations
|
||||||
|
* ``AtariPreprocessingV0`` - Preprocessing wrapper for atari environments
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -14,8 +16,10 @@ import numpy as np
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import gymnasium.spaces as spaces
|
import gymnasium.spaces as spaces
|
||||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
from gymnasium import Env
|
||||||
|
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
|
||||||
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
|
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
|
||||||
|
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||||
|
|
||||||
|
|
||||||
class DelayObservationV0(gym.ObservationWrapper):
|
class DelayObservationV0(gym.ObservationWrapper):
|
||||||
@@ -31,7 +35,9 @@ class DelayObservationV0(gym.ObservationWrapper):
|
|||||||
returned observation is an array of zeros with the
|
returned observation is an array of zeros with the
|
||||||
same shape of the observation space.
|
same shape of the observation space.
|
||||||
"""
|
"""
|
||||||
assert isinstance(env.observation_space, (Box, MultiBinary, MultiDiscrete))
|
assert isinstance(
|
||||||
|
env.observation_space, (Box, MultiBinary, MultiDiscrete)
|
||||||
|
), type(env.observation_space)
|
||||||
assert 0 < delay
|
assert 0 < delay
|
||||||
|
|
||||||
self.delay: Final[int] = delay
|
self.delay: Final[int] = delay
|
||||||
@@ -134,9 +140,9 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
|
|||||||
if isinstance(env.observation_space, Dict):
|
if isinstance(env.observation_space, Dict):
|
||||||
assert dict_time_key not in env.observation_space.keys()
|
assert dict_time_key not in env.observation_space.keys()
|
||||||
observation_space = Dict(
|
observation_space = Dict(
|
||||||
{dict_time_key: time_space}, **env.observation_space.spaces
|
{dict_time_key: time_space, **env.observation_space.spaces}
|
||||||
)
|
)
|
||||||
self._append_data_func = lambda obs, time: {**obs, dict_time_key: time}
|
self._append_data_func = lambda obs, time: {dict_time_key: time, **obs}
|
||||||
elif isinstance(env.observation_space, Tuple):
|
elif isinstance(env.observation_space, Tuple):
|
||||||
observation_space = Tuple(env.observation_space.spaces + (time_space,))
|
observation_space = Tuple(env.observation_space.spaces + (time_space,))
|
||||||
self._append_data_func = lambda obs, time: obs + (time,)
|
self._append_data_func = lambda obs, time: obs + (time,)
|
||||||
@@ -198,3 +204,101 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
|
|||||||
self.timesteps = 0
|
self.timesteps = 0
|
||||||
|
|
||||||
return super().reset(seed=seed, options=options)
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameStackObservationV0(gym.Wrapper):
    """Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation contains
    the most recent 4 observations. For environment 'Pendulum-v1', the original observation
    is an array with shape [3], so if we stack 4 observations, the processed observation
    has shape [4, 3].

    Note:
        - The most recent observation is written to index 0 of the frame buffer; on every
          :meth:`step` the buffer is rotated by one position before the new frame is stored.
        - After :meth:`reset` is called, only index 0 holds the initial observation. The
          remaining ``stack_size - 1`` slots hold the placeholder arrays produced by
          ``create_empty_array`` (presumably zero-filled with its default settings), not
          copies of the initial observation.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make('CarRacing-v1')
        >>> env = FrameStackObservationV0(env, 4)
        >>> env.observation_space
        Box(4, 96, 96, 3)
        >>> obs, info = env.reset()
        >>> obs.shape
        (4, 96, 96, 3)
    """

    def __init__(self, env: Env[ObsType, ActType], stack_size: int):
        """Observation wrapper that stacks the observations in a rolling manner.

        Args:
            env: The environment to apply the wrapper
            stack_size: The number of frames to stack (a strictly positive integer)
        """
        # ``np.issubdtype`` accepts both builtin ``int`` and numpy integer types.
        assert np.issubdtype(type(stack_size), np.integer)
        assert stack_size > 0

        super().__init__(env)

        self.observation_space = batch_space(env.observation_space, n=stack_size)
        self.stack_size = stack_size

        # Pre-allocated output buffer handed to ``concatenate`` on every reset / step.
        # NOTE(review): if ``concatenate`` returns this buffer, successive observations
        # alias the same memory - confirm before storing returned observations.
        self.stacked_obs_array = create_empty_array(env.observation_space, n=stack_size)
        self.stacked_obs = self._init_stacked_obs()

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, appending the observation to the frame buffer.

        Args:
            action: The action to step through the environment with

        Returns:
            Stacked observations, reward, terminated, truncated, and info from the environment
        """
        obs, reward, terminated, truncated, info = super().step(action)
        # Shift every frame one slot towards the end; the oldest frame wraps around to
        # index 0 and is immediately overwritten by the newest observation.
        self.stacked_obs.rotate(1)
        self.stacked_obs[0] = obs

        return (
            concatenate(
                self.observation_space, self.stacked_obs, self.stacked_obs_array
            ),
            reward,
            terminated,
            truncated,
            info,
        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment, returning the stacked observation and info.

        Args:
            seed: The environment seed
            options: The reset options

        Returns:
            The stacked observations and info
        """
        obs, info = super().reset(seed=seed, options=options)
        # Start from a fresh placeholder buffer, then store the initial observation
        # in the "most recent" slot only.
        self.stacked_obs = self._init_stacked_obs()
        self.stacked_obs[0] = obs

        return (
            concatenate(
                self.observation_space, self.stacked_obs, self.stacked_obs_array
            ),
            info,
        )

    def _init_stacked_obs(self) -> deque:
        """Builds a fresh frame buffer with ``stack_size`` placeholder entries."""
        return deque(
            iterate(
                self.observation_space,
                create_empty_array(self.env.observation_space, n=self.stack_size),
            )
        )
|
||||||
|
43
gymnasium/experimental/wrappers/utils.py
Normal file
43
gymnasium/experimental/wrappers/utils.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Utility functions for the wrappers."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class RunningMeanStd:
    """Tracks the mean, variance and count of values."""

    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon: float = 1e-4, shape: tuple = ()):
        """Tracks the mean, variance and count of values.

        Args:
            epsilon: Initial pseudo-count; keeps the first update from dividing by zero.
            shape: Shape of the tracked mean and variance arrays.
        """
        self.mean = np.zeros(shape, "float64")
        self.var = np.ones(shape, "float64")
        self.count = epsilon

    def update(self, x) -> None:
        """Updates the mean, var and count from a batch of samples.

        Args:
            x: Batch of samples stacked along the first axis. Any array-like is
                accepted; it is converted with ``np.asarray`` so that plain
                lists work as well as arrays.
        """
        # Coerce once so the ``.shape`` lookup below cannot fail on list input.
        x = np.asarray(x)
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count) -> None:
        """Updates from batch mean, variance and count moments.

        Args:
            batch_mean: Mean of the new batch.
            batch_var: Variance of the new batch.
            batch_count: Number of samples in the new batch.
        """
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )
|
||||||
|
|
||||||
|
|
||||||
|
def update_mean_var_count_from_moments(
|
||||||
|
mean, var, count, batch_mean, batch_var, batch_count
|
||||||
|
):
|
||||||
|
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
|
||||||
|
delta = batch_mean - mean
|
||||||
|
tot_count = count + batch_count
|
||||||
|
|
||||||
|
new_mean = mean + delta * batch_count / tot_count
|
||||||
|
m_a = var * count
|
||||||
|
m_b = batch_var * batch_count
|
||||||
|
M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
|
||||||
|
new_var = M2 / tot_count
|
||||||
|
new_count = tot_count
|
||||||
|
|
||||||
|
return new_mean, new_var, new_count
|
@@ -93,7 +93,7 @@ class NormalizeObservation(gym.Wrapper):
|
|||||||
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
|
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
|
||||||
|
|
||||||
|
|
||||||
class NormalizeReward(gym.Wrapper):
|
class NormalizeReward(gym.core.Wrapper):
|
||||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||||
|
|
||||||
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
||||||
@@ -129,10 +129,8 @@ class NormalizeReward(gym.Wrapper):
|
|||||||
obs, rews, terminateds, truncateds, infos = self.env.step(action)
|
obs, rews, terminateds, truncateds, infos = self.env.step(action)
|
||||||
if not self.is_vector_env:
|
if not self.is_vector_env:
|
||||||
rews = np.array([rews])
|
rews = np.array([rews])
|
||||||
self.returns = self.returns * self.gamma + rews
|
self.returns = self.returns * self.gamma * (1 - terminateds) + rews
|
||||||
rews = self.normalize(rews)
|
rews = self.normalize(rews)
|
||||||
dones = np.logical_or(terminateds, truncateds)
|
|
||||||
self.returns[dones] = 0.0
|
|
||||||
if not self.is_vector_env:
|
if not self.is_vector_env:
|
||||||
rews = rews[0]
|
rews = rews[0]
|
||||||
return obs, rews, terminateds, truncateds, infos
|
return obs, rews, terminateds, truncateds, infos
|
||||||
|
@@ -0,0 +1 @@
|
|||||||
|
"""Testing for Gymnasium."""
|
||||||
|
@@ -0,0 +1 @@
|
|||||||
|
"""Testing suite for ``gymnasium.experimental``."""
|
||||||
|
@@ -0,0 +1 @@
|
|||||||
|
"""Module for functional environment API."""
|
||||||
|
@@ -1,7 +1,8 @@
|
|||||||
|
"""Test the functional jax environment."""
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax.random as jrng
|
import jax.random as jrng
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
||||||
@@ -9,7 +10,8 @@ from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||||
def test_normal(env_class):
|
def test_without_transform(env_class):
|
||||||
|
"""Tests the environment without transforming the environment."""
|
||||||
env = env_class()
|
env = env_class()
|
||||||
rng = jrng.PRNGKey(0)
|
rng = jrng.PRNGKey(0)
|
||||||
|
|
||||||
@@ -42,6 +44,7 @@ def test_normal(env_class):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||||
def test_jit(env_class):
|
def test_jit(env_class):
|
||||||
|
"""Tests jitting the functional instance functions."""
|
||||||
env = env_class()
|
env = env_class()
|
||||||
rng = jrng.PRNGKey(0)
|
rng = jrng.PRNGKey(0)
|
||||||
|
|
||||||
@@ -75,6 +78,7 @@ def test_jit(env_class):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||||
def test_vmap(env_class):
|
def test_vmap(env_class):
|
||||||
|
"""Tests vmap of functional instance functions with transform."""
|
||||||
env = env_class()
|
env = env_class()
|
||||||
num_envs = 10
|
num_envs = 10
|
||||||
rng = jrng.split(jrng.PRNGKey(0), num_envs)
|
rng = jrng.split(jrng.PRNGKey(0), num_envs)
|
||||||
@@ -98,7 +102,7 @@ def test_vmap(env_class):
|
|||||||
assert reward.shape == (num_envs,)
|
assert reward.shape == (num_envs,)
|
||||||
assert reward.dtype == jnp.float32
|
assert reward.dtype == jnp.float32
|
||||||
assert terminal.shape == (num_envs,)
|
assert terminal.shape == (num_envs,)
|
||||||
assert terminal.dtype == np.bool
|
assert terminal.dtype == bool
|
||||||
assert isinstance(obs, jnp.ndarray)
|
assert isinstance(obs, jnp.ndarray)
|
||||||
assert obs.dtype == jnp.float32
|
assert obs.dtype == jnp.float32
|
||||||
|
|
@@ -1,4 +1,7 @@
|
|||||||
from typing import Any, Dict, Optional
|
"""Tests the functional api."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -6,29 +9,41 @@ from gymnasium.experimental import FuncEnv
|
|||||||
|
|
||||||
|
|
||||||
class GenericTestFuncEnv(FuncEnv):
|
class GenericTestFuncEnv(FuncEnv):
|
||||||
def __init__(self, options: Optional[Dict[str, Any]] = None):
|
"""Generic testing functional environment."""
|
||||||
|
|
||||||
|
def __init__(self, options: dict[str, Any] | None = None):
|
||||||
|
"""Constructor that allows generic options to be set on the environment."""
|
||||||
super().__init__(options)
|
super().__init__(options)
|
||||||
|
|
||||||
def initial(self, rng: Any) -> np.ndarray:
|
def initial(self, rng: Any) -> np.ndarray:
|
||||||
|
"""Testing initial function."""
|
||||||
return np.array([0, 0], dtype=np.float32)
|
return np.array([0, 0], dtype=np.float32)
|
||||||
|
|
||||||
def observation(self, state: np.ndarray) -> np.ndarray:
|
def observation(self, state: np.ndarray) -> np.ndarray:
|
||||||
|
"""Testing observation function."""
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray:
|
def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray:
|
||||||
|
"""Testing transition function."""
|
||||||
return state + np.array([0, action], dtype=np.float32)
|
return state + np.array([0, action], dtype=np.float32)
|
||||||
|
|
||||||
def reward(self, state: np.ndarray, action: int, next_state: np.ndarray) -> float:
|
def reward(self, state: np.ndarray, action: int, next_state: np.ndarray) -> float:
|
||||||
|
"""Testing reward function."""
|
||||||
return 1.0 if next_state[1] > 0 else 0.0
|
return 1.0 if next_state[1] > 0 else 0.0
|
||||||
|
|
||||||
def terminal(self, state: np.ndarray) -> bool:
|
def terminal(self, state: np.ndarray) -> bool:
|
||||||
|
"""Testing terminal function."""
|
||||||
return state[1] > 0
|
return state[1] > 0
|
||||||
|
|
||||||
|
|
||||||
def test_api():
|
def test_functional_api():
|
||||||
|
"""Tests the core functional api specification using a generic testing environment."""
|
||||||
env = GenericTestFuncEnv()
|
env = GenericTestFuncEnv()
|
||||||
|
|
||||||
state = env.initial(None)
|
state = env.initial(None)
|
||||||
|
|
||||||
obs = env.observation(state)
|
obs = env.observation(state)
|
||||||
|
|
||||||
assert state.shape == (2,)
|
assert state.shape == (2,)
|
||||||
assert state.dtype == np.float32
|
assert state.dtype == np.float32
|
||||||
assert obs.shape == (2,)
|
assert obs.shape == (2,)
|
@@ -0,0 +1 @@
|
|||||||
|
"""Experimental wrapper module."""
|
||||||
|
1
tests/experimental/wrappers/human_rendering.py
Normal file
1
tests/experimental/wrappers/human_rendering.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for HumanRenderingV0."""
|
1
tests/experimental/wrappers/test_atari_preprocessing.py
Normal file
1
tests/experimental/wrappers/test_atari_preprocessing.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for AtariPreprocessingV0."""
|
1
tests/experimental/wrappers/test_autoreset.py
Normal file
1
tests/experimental/wrappers/test_autoreset.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for AutoresetV0."""
|
25
tests/experimental/wrappers/test_clip_action.py
Normal file
25
tests/experimental/wrappers/test_clip_action.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Test suite for ClipActionV0."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import ClipActionV0
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from tests.experimental.wrappers.utils import record_action_step
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_clip_action_wrapper():
    """Check that ``ClipActionV0`` clips actions into the base environment's action space."""
    base_env = GenericTestEnv(
        action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
        step_func=record_action_step,
    )
    clipped_env = ClipActionV0(base_env)

    # First two components lie outside the base bounds, the third lies inside.
    out_of_bounds = np.array([-1, 5, 3.5], dtype=np.float32)
    assert out_of_bounds not in base_env.action_space
    assert out_of_bounds in clipped_env.action_space

    _, _, _, _, step_info = clipped_env.step(out_of_bounds)
    received = step_info["action"]
    assert np.all(received in base_env.action_space)
    assert np.all(received == np.array([0, 2, 3.5]))
|
67
tests/experimental/wrappers/test_clip_reward.py
Normal file
67
tests/experimental/wrappers/test_clip_reward.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Test suite for ClipRewardV0."""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.error import InvalidBound
|
||||||
|
from gymnasium.experimental.wrappers import ClipRewardV0
|
||||||
|
from tests.envs.test_envs import SEED
|
||||||
|
from tests.experimental.wrappers.test_lambda_rewards import (
|
||||||
|
DISCRETE_ACTION,
|
||||||
|
ENV_ID,
|
||||||
|
NUM_ENVS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("lower_bound", "upper_bound", "expected_reward"),
    [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
)
def test_clip_reward(lower_bound, upper_bound, expected_reward):
    """Test reward clipping.

    Checks that a single step's reward is clipped according to the given bounds.
    """
    clipped_env = ClipRewardV0(gym.make(ENV_ID), lower_bound, upper_bound)
    clipped_env.reset(seed=SEED)

    _, reward, _, _, _ = clipped_env.step(DISCRETE_ACTION)
    assert reward == expected_reward
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("lower_bound", "upper_bound", "expected_reward"),
    [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
)
def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward):
    """Test reward clipping in vectorized environment.

    Test if reward is correctly clipped according to the input args in a vectorized environment.
    """
    actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]

    env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
    env = ClipRewardV0(env, lower_bound, upper_bound)
    env.reset(seed=SEED)

    _, rew, _, _, _ = env.step(actions)

    # ``np.alltrue`` is a deprecated alias that was removed in NumPy 2.0; ``np.all``
    # is the supported equivalent.
    assert np.all(rew == expected_reward)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("lower_bound", "upper_bound"),
    [(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))],
)
def test_clip_reward_incorrect_params(lower_bound, upper_bound):
    """Test reward clipping with incorrect params.

    Constructing ``ClipRewardV0`` must raise ``InvalidBound`` when both bounds are
    `None` or when the upper bound is lower than the lower bound.
    """
    base_env = gym.make(ENV_ID)

    with pytest.raises(InvalidBound):
        ClipRewardV0(base_env, lower_bound, upper_bound)
|
34
tests/experimental/wrappers/test_delay_observation.py
Normal file
34
tests/experimental/wrappers/test_delay_observation.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Test suite for DelayObservationV0."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.experimental.wrappers import DelayObservationV0
|
||||||
|
from tests.experimental.wrappers.utils import DELAY, NUM_STEPS, SEED
|
||||||
|
|
||||||
|
|
||||||
|
def test_delay_observation_wrapper():
    """Compares a seeded rollout with and without ``DelayObservationV0``."""
    env = gym.make("CartPole-v1")
    env.action_space.seed(SEED)
    env.reset(seed=SEED)

    # Reference rollout without any delay.
    reference = [
        env.step(env.action_space.sample())[0] for _ in range(NUM_STEPS)
    ]

    # Same seeds, same action sequence, but observed through the delay wrapper.
    env = DelayObservationV0(env, delay=DELAY)
    env.action_space.seed(SEED)
    env.reset(seed=SEED)

    delayed = []
    for step_index in range(NUM_STEPS):
        observation, _, _, _, _ = env.step(env.action_space.sample())
        delayed.append(observation)
        if step_index < DELAY - 1:
            assert np.all(observation == 0)

    reference = np.array(reference)
    delayed = np.array(delayed)

    assert np.all(delayed[DELAY:] == reference[:-DELAY])
|
25
tests/experimental/wrappers/test_dtype_observation.py
Normal file
25
tests/experimental/wrappers/test_dtype_observation.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Test suite for DtypeObservationV0."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import DtypeObservationV0
|
||||||
|
from tests.experimental.wrappers.utils import (
|
||||||
|
record_random_obs_reset,
|
||||||
|
record_random_obs_step,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_dtype_observation():
    """Test that ``DtypeObservationV0`` correctly modifies the observation dtype."""
    base_env = GenericTestEnv(
        reset_func=record_random_obs_reset, step_func=record_random_obs_step
    )
    cast_env = DtypeObservationV0(base_env, dtype=np.uint8)

    # ``info["obs"]`` carries the observation recorded before the wrapper is applied.
    reset_obs, reset_info = cast_env.reset()
    assert reset_obs.dtype != reset_info["obs"].dtype
    assert reset_obs.dtype == np.uint8

    step_obs, _, _, _, step_info = cast_env.step(None)
    assert step_obs.dtype != step_info["obs"].dtype
    assert step_obs.dtype == np.uint8
|
45
tests/experimental/wrappers/test_filter_observation.py
Normal file
45
tests/experimental/wrappers/test_filter_observation.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""Test suite for FilterObservationV0."""
|
||||||
|
from gymnasium.experimental.wrappers import FilterObservationV0
|
||||||
|
from gymnasium.spaces import Box, Dict, Tuple
|
||||||
|
from tests.experimental.wrappers.utils import (
|
||||||
|
check_obs,
|
||||||
|
record_random_obs_reset,
|
||||||
|
record_random_obs_step,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_observation_wrapper():
    """Tests that ``FilterObservation`` keeps only the requested keys / indexes."""
    all_keys = ["arm_1", "arm_2", "arm_3"]
    kept_keys = ["arm_1", "arm_3"]

    # Dict observation space: filter by key.
    dict_env = GenericTestEnv(
        observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
        reset_func=record_random_obs_reset,
        step_func=record_random_obs_step,
    )
    filtered_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))

    obs, info = filtered_env.reset()
    assert list(obs.keys()) == kept_keys
    assert list(info["obs"].keys()) == all_keys
    check_obs(dict_env, filtered_env, obs, info["obs"])

    obs, _, _, _, info = filtered_env.step(None)
    assert list(obs.keys()) == kept_keys
    assert list(info["obs"].keys()) == all_keys
    check_obs(dict_env, filtered_env, obs, info["obs"])

    # Tuple observation space: filter by index.
    tuple_env = GenericTestEnv(
        observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
        reset_func=record_random_obs_reset,
        step_func=record_random_obs_step,
    )
    filtered_env = FilterObservationV0(tuple_env, (2,))

    obs, info = filtered_env.reset()
    assert len(obs) == 1 and len(info["obs"]) == 3
    check_obs(tuple_env, filtered_env, obs, info["obs"])

    obs, _, _, _, info = filtered_env.step(None)
    assert len(obs) == 1 and len(info["obs"]) == 3
    check_obs(tuple_env, filtered_env, obs, info["obs"])
|
27
tests/experimental/wrappers/test_flatten_observation.py
Normal file
27
tests/experimental/wrappers/test_flatten_observation.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Test suite for FlattenObservationV0."""
|
||||||
|
from gymnasium.experimental.wrappers import FlattenObservationV0
|
||||||
|
from gymnasium.spaces import Box, Dict
|
||||||
|
from tests.experimental.wrappers.utils import (
|
||||||
|
check_obs,
|
||||||
|
record_random_obs_reset,
|
||||||
|
record_random_obs_step,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_flatten_observation_wrapper():
    """Tests that the ``FlattenObservation`` wrapper flattens the observations correctly."""
    env = GenericTestEnv(
        observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
        reset_func=record_random_obs_reset,
        step_func=record_random_obs_step,
    )
    # Leftover debug ``print`` calls of both observation spaces were dropped here;
    # they only added noise to the test output.
    wrapped_env = FlattenObservationV0(env)

    obs, info = wrapped_env.reset()
    check_obs(env, wrapped_env, obs, info["obs"])

    obs, _, _, _, info = wrapped_env.step(None)
    check_obs(env, wrapped_env, obs, info["obs"])
|
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for FrameStackObservationV0."""
|
38
tests/experimental/wrappers/test_grayscale_observation.py
Normal file
38
tests/experimental/wrappers/test_grayscale_observation.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Test suite for GrayscaleObservationV0."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import GrayscaleObservationV0
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from tests.experimental.wrappers.utils import (
|
||||||
|
check_obs,
|
||||||
|
record_random_obs_reset,
|
||||||
|
record_random_obs_step,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_grayscale_observation_wrapper():
    """Tests that ``GrayscaleObservation`` produces grayscale observations."""
    rgb_env = GenericTestEnv(
        observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
        reset_func=record_random_obs_reset,
        step_func=record_random_obs_step,
    )

    # Default: the colour channel axis is dropped.
    gray_env = GrayscaleObservationV0(rgb_env)

    obs, info = gray_env.reset()
    check_obs(rgb_env, gray_env, obs, info["obs"])
    assert obs.shape == (25, 25)

    obs, _, _, _, info = gray_env.step(None)
    check_obs(rgb_env, gray_env, obs, info["obs"])

    # Keep_dim: a trailing axis of size one is retained.
    gray_env = GrayscaleObservationV0(rgb_env, keep_dim=True)

    obs, info = gray_env.reset()
    check_obs(rgb_env, gray_env, obs, info["obs"])
    assert obs.shape == (25, 25, 1)

    obs, _, _, _, info = gray_env.step(None)
    check_obs(rgb_env, gray_env, obs, info["obs"])
|
@@ -1,9 +1,11 @@
|
|||||||
|
"""Test suite for JaxToNumpyV0."""
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers import JaxToNumpyV0
|
from gymnasium.experimental.wrappers import JaxToNumpyV0
|
||||||
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax
|
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
||||||
from gymnasium.utils.env_checker import data_equivalence
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
@@ -40,10 +42,12 @@ def test_roundtripping(value, expected_value):
|
|||||||
|
|
||||||
|
|
||||||
def jax_reset_func(self, seed=None, options=None):
|
def jax_reset_func(self, seed=None, options=None):
|
||||||
|
"""A jax-based reset function."""
|
||||||
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
|
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
|
||||||
|
|
||||||
|
|
||||||
def jax_step_func(self, action):
|
def jax_step_func(self, action):
|
||||||
|
"""A jax-based step function."""
|
||||||
assert isinstance(action, jnp.DeviceArray), type(action)
|
assert isinstance(action, jnp.DeviceArray), type(action)
|
||||||
return (
|
return (
|
||||||
jnp.array([1, 2, 3]),
|
jnp.array([1, 2, 3]),
|
||||||
@@ -54,7 +58,8 @@ def jax_step_func(self, action):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_jax_to_numpy():
|
def test_jax_to_numpy_wrapper():
|
||||||
|
"""Tests the ``JaxToNumpyV0`` wrapper."""
|
||||||
jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)
|
jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)
|
||||||
|
|
||||||
# Check that the reset and step for jax environment are as expected
|
# Check that the reset and step for jax environment are as expected
|
@@ -1,10 +1,12 @@
|
|||||||
|
"""Test suite for TorchToJaxV0."""
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers import JaxToTorchV0
|
from gymnasium.experimental.wrappers import JaxToTorchV0
|
||||||
from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax
|
from gymnasium.experimental.wrappers.jax_to_torch import jax_to_torch, torch_to_jax
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
@@ -76,7 +78,8 @@ def _jax_step_func(self, action):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_jax_to_torch():
|
def test_jax_to_torch_wrapper():
|
||||||
|
"""Tests the `JaxToTorchV0` wrapper."""
|
||||||
env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)
|
env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)
|
||||||
|
|
||||||
# Check that the reset and step for jax environment are as expected
|
# Check that the reset and step for jax environment are as expected
|
@@ -1,25 +1,14 @@
|
|||||||
"""Test suit for lambda action wrappers: LambdaAction, ClipAction, RescaleAction."""
|
"""Test suite for LambdaActionV0."""
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers import (
|
from gymnasium.experimental.wrappers import LambdaActionV0
|
||||||
ClipActionV0,
|
|
||||||
LambdaActionV0,
|
|
||||||
RescaleActionV0,
|
|
||||||
)
|
|
||||||
from gymnasium.spaces import Box
|
from gymnasium.spaces import Box
|
||||||
|
from tests.experimental.wrappers.utils import record_action_step
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
SEED = 42
|
|
||||||
|
|
||||||
|
|
||||||
def _record_action_step_func(self, action):
|
|
||||||
return 0, 0, False, False, {"action": action}
|
|
||||||
|
|
||||||
|
|
||||||
def test_lambda_action_wrapper():
|
def test_lambda_action_wrapper():
|
||||||
"""Tests LambdaAction through checking that the action taken is transformed by function."""
|
"""Tests LambdaAction through checking that the action taken is transformed by function."""
|
||||||
env = GenericTestEnv(step_func=_record_action_step_func)
|
env = GenericTestEnv(step_func=record_action_step)
|
||||||
wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
|
wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
|
||||||
|
|
||||||
sampled_action = wrapped_env.action_space.sample()
|
sampled_action = wrapped_env.action_space.sample()
|
||||||
@@ -28,51 +17,3 @@ def test_lambda_action_wrapper():
|
|||||||
_, _, _, _, info = wrapped_env.step(sampled_action)
|
_, _, _, _, info = wrapped_env.step(sampled_action)
|
||||||
assert info["action"] in env.action_space
|
assert info["action"] in env.action_space
|
||||||
assert sampled_action - 2 == info["action"]
|
assert sampled_action - 2 == info["action"]
|
||||||
|
|
||||||
|
|
||||||
def test_clip_action_wrapper():
|
|
||||||
"""Test that the action is correctly clipped to the base environment action space."""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
|
|
||||||
step_func=_record_action_step_func,
|
|
||||||
)
|
|
||||||
wrapped_env = ClipActionV0(env)
|
|
||||||
|
|
||||||
sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
|
|
||||||
assert sampled_action not in env.action_space
|
|
||||||
assert sampled_action in wrapped_env.action_space
|
|
||||||
|
|
||||||
_, _, _, _, info = wrapped_env.step(sampled_action)
|
|
||||||
assert np.all(info["action"] in env.action_space)
|
|
||||||
assert np.all(info["action"] == np.array([0, 2, 3.5]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_rescale_action_wrapper():
|
|
||||||
"""Test that the action is rescale within a min / max bound."""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
step_func=_record_action_step_func,
|
|
||||||
action_space=Box(np.array([0, 1]), np.array([1, 3])),
|
|
||||||
)
|
|
||||||
wrapped_env = RescaleActionV0(
|
|
||||||
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
|
|
||||||
)
|
|
||||||
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
|
|
||||||
|
|
||||||
for sample_action, expected_action in (
|
|
||||||
(
|
|
||||||
np.array([0.0, 0.5], dtype=np.float32),
|
|
||||||
np.array([0.5, 2.0], dtype=np.float32),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.array([-5.0, 0.0], dtype=np.float32),
|
|
||||||
np.array([0.0, 1.0], dtype=np.float32),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.array([5.0, 1.0], dtype=np.float32),
|
|
||||||
np.array([1.0, 3.0], dtype=np.float32),
|
|
||||||
),
|
|
||||||
):
|
|
||||||
assert sample_action in wrapped_env.action_space
|
|
||||||
|
|
||||||
_, _, _, _, info = wrapped_env.step(sample_action)
|
|
||||||
assert np.all(info["action"] == expected_action)
|
|
||||||
|
@@ -1,250 +1,26 @@
|
|||||||
"""Test suite for lambda observation wrappers: """
|
"""Test suite for lambda observation wrappers."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
from gymnasium.experimental.wrappers import LambdaObservationV0
|
||||||
from gymnasium.experimental.wrappers import (
|
from gymnasium.spaces import Box
|
||||||
DtypeObservationV0,
|
from tests.experimental.wrappers.utils import (
|
||||||
FilterObservationV0,
|
check_obs,
|
||||||
FlattenObservationV0,
|
record_action_as_obs_step,
|
||||||
GrayscaleObservationV0,
|
record_obs_reset,
|
||||||
LambdaObservationV0,
|
|
||||||
RescaleObservationV0,
|
|
||||||
ReshapeObservationV0,
|
|
||||||
ResizeObservationV0,
|
|
||||||
)
|
)
|
||||||
from gymnasium.spaces import Box, Dict, Tuple
|
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
SEED = 42
|
|
||||||
|
|
||||||
|
|
||||||
def _record_random_obs_reset(self: gym.Env, seed=None, options=None):
|
|
||||||
obs = self.observation_space.sample()
|
|
||||||
return obs, {"obs": obs}
|
|
||||||
|
|
||||||
|
|
||||||
def _record_random_obs_step(self: gym.Env, action):
|
|
||||||
obs = self.observation_space.sample()
|
|
||||||
return obs, 0, False, False, {"obs": obs}
|
|
||||||
|
|
||||||
|
|
||||||
def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}):
|
|
||||||
return options["obs"], {"obs": options["obs"]}
|
|
||||||
|
|
||||||
|
|
||||||
def _record_action_obs_step(self: gym.Env, action):
|
|
||||||
return action, 0, False, False, {"obs": action}
|
|
||||||
|
|
||||||
|
|
||||||
def _check_obs(
|
|
||||||
env: gym.Env,
|
|
||||||
wrapped_env: gym.Wrapper,
|
|
||||||
transformed_obs,
|
|
||||||
original_obs,
|
|
||||||
strict: bool = True,
|
|
||||||
):
|
|
||||||
assert (
|
|
||||||
transformed_obs in wrapped_env.observation_space
|
|
||||||
), f"{transformed_obs}, {wrapped_env.observation_space}"
|
|
||||||
assert (
|
|
||||||
original_obs in env.observation_space
|
|
||||||
), f"{original_obs}, {env.observation_space}"
|
|
||||||
|
|
||||||
if strict:
|
|
||||||
assert (
|
|
||||||
transformed_obs not in env.observation_space
|
|
||||||
), f"{transformed_obs}, {env.observation_space}"
|
|
||||||
assert (
|
|
||||||
original_obs not in wrapped_env.observation_space
|
|
||||||
), f"{original_obs}, {wrapped_env.observation_space}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_lambda_observation_wrapper():
|
def test_lambda_observation_wrapper():
|
||||||
"""Tests lambda observation that the function is applied to both the reset and step observation."""
|
"""Tests lambda observation that the function is applied to both the reset and step observation."""
|
||||||
env = GenericTestEnv(
|
env = GenericTestEnv(
|
||||||
reset_func=_record_action_obs_reset, step_func=_record_action_obs_step
|
reset_func=record_obs_reset, step_func=record_action_as_obs_step
|
||||||
)
|
)
|
||||||
wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3))
|
wrapped_env = LambdaObservationV0(env, lambda _obs: _obs + 2, Box(2, 3))
|
||||||
|
|
||||||
obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)})
|
obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)})
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32))
|
obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32))
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
|
||||||
def test_filter_observation_wrapper():
|
|
||||||
"""Tests ``FilterObservation`` that the right keys are filtered."""
|
|
||||||
dict_env = GenericTestEnv(
|
|
||||||
observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
|
|
||||||
reset_func=_record_random_obs_reset,
|
|
||||||
step_func=_record_random_obs_step,
|
|
||||||
)
|
|
||||||
|
|
||||||
wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
assert list(obs.keys()) == ["arm_1", "arm_3"]
|
|
||||||
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
|
|
||||||
_check_obs(dict_env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
assert list(obs.keys()) == ["arm_1", "arm_3"]
|
|
||||||
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
|
|
||||||
_check_obs(dict_env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
# Test tuple environments
|
|
||||||
tuple_env = GenericTestEnv(
|
|
||||||
observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
|
|
||||||
reset_func=_record_random_obs_reset,
|
|
||||||
step_func=_record_random_obs_step,
|
|
||||||
)
|
|
||||||
wrapped_env = FilterObservationV0(tuple_env, (2,))
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
assert len(obs) == 1 and len(info["obs"]) == 3
|
|
||||||
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
assert len(obs) == 1 and len(info["obs"]) == 3
|
|
||||||
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_flatten_observation_wrapper():
|
|
||||||
"""Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly."""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
|
|
||||||
reset_func=_record_random_obs_reset,
|
|
||||||
step_func=_record_random_obs_step,
|
|
||||||
)
|
|
||||||
print(env.observation_space)
|
|
||||||
wrapped_env = FlattenObservationV0(env)
|
|
||||||
print(wrapped_env.observation_space)
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_grayscale_observation_wrapper():
|
|
||||||
"""Tests the ``GrayscaleObservation`` that the observation is grayscale."""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
|
|
||||||
reset_func=_record_random_obs_reset,
|
|
||||||
step_func=_record_random_obs_step,
|
|
||||||
)
|
|
||||||
wrapped_env = GrayscaleObservationV0(env)
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
assert obs.shape == (25, 25)
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
# Keep_dim
|
|
||||||
wrapped_env = GrayscaleObservationV0(env, keep_dim=True)
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
assert obs.shape == (25, 25, 1)
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_resize_observation_wrapper():
|
|
||||||
"""Test the ``ResizeObservation`` that the observation has changed size"""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
|
|
||||||
reset_func=_record_random_obs_reset,
|
|
||||||
step_func=_record_random_obs_step,
|
|
||||||
)
|
|
||||||
wrapped_env = ResizeObservationV0(env, (25, 25))
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_reshape_observation_wrapper():
|
|
||||||
"""Test the ``ReshapeObservation`` wrapper."""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
observation_space=Box(0, 1, shape=(2, 3, 2)),
|
|
||||||
reset_func=_record_random_obs_reset,
|
|
||||||
step_func=_record_random_obs_step,
|
|
||||||
)
|
|
||||||
wrapped_env = ReshapeObservationV0(env, (6, 2))
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
assert obs.shape == (6, 2)
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"])
|
|
||||||
assert obs.shape == (6, 2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rescale_observation():
|
|
||||||
"""Test the ``RescaleObservation`` wrapper"""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
observation_space=Box(
|
|
||||||
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
|
|
||||||
),
|
|
||||||
reset_func=_record_action_obs_reset,
|
|
||||||
step_func=_record_action_obs_step,
|
|
||||||
)
|
|
||||||
wrapped_env = RescaleObservationV0(
|
|
||||||
env,
|
|
||||||
min_obs=np.array([-5, 0], dtype=np.float32),
|
|
||||||
max_obs=np.array([5, 1], dtype=np.float32),
|
|
||||||
)
|
|
||||||
assert wrapped_env.observation_space == Box(
|
|
||||||
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
|
|
||||||
)
|
|
||||||
|
|
||||||
for sample_obs, expected_obs in (
|
|
||||||
(
|
|
||||||
np.array([0.5, 2.0], dtype=np.float32),
|
|
||||||
np.array([0.0, 0.5], dtype=np.float32),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.array([0.0, 1.0], dtype=np.float32),
|
|
||||||
np.array([-5.0, 0.0], dtype=np.float32),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
np.array([1.0, 3.0], dtype=np.float32),
|
|
||||||
np.array([5.0, 1.0], dtype=np.float32),
|
|
||||||
),
|
|
||||||
):
|
|
||||||
assert sample_obs in env.observation_space
|
|
||||||
assert expected_obs in wrapped_env.observation_space
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset(options={"obs": sample_obs})
|
|
||||||
assert np.all(obs == expected_obs)
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(sample_obs)
|
|
||||||
assert np.all(obs == expected_obs)
|
|
||||||
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dtype_observation():
|
|
||||||
"""Test ``DtypeObservation`` that the"""
|
|
||||||
env = GenericTestEnv(
|
|
||||||
reset_func=_record_random_obs_reset, step_func=_record_random_obs_step
|
|
||||||
)
|
|
||||||
wrapped_env = DtypeObservationV0(env, dtype=np.uint8)
|
|
||||||
|
|
||||||
obs, info = wrapped_env.reset()
|
|
||||||
assert obs.dtype != info["obs"].dtype
|
|
||||||
assert obs.dtype == np.uint8
|
|
||||||
|
|
||||||
obs, _, _, _, info = wrapped_env.step(None)
|
|
||||||
assert obs.dtype != info["obs"].dtype
|
|
||||||
assert obs.dtype == np.uint8
|
|
||||||
|
@@ -4,14 +4,8 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.error import InvalidBound
|
from gymnasium.experimental.wrappers import LambdaRewardV0
|
||||||
from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0
|
from tests.experimental.wrappers.utils import DISCRETE_ACTION, ENV_ID, NUM_ENVS, SEED
|
||||||
|
|
||||||
|
|
||||||
ENV_ID = "CartPole-v1"
|
|
||||||
DISCRETE_ACTION = 0
|
|
||||||
NUM_ENVS = 3
|
|
||||||
SEED = 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -54,57 +48,3 @@ def test_lambda_reward_within_vector(reward_fn, expected_reward):
|
|||||||
_, rew, _, _, _ = env.step(actions)
|
_, rew, _, _, _ = env.step(actions)
|
||||||
|
|
||||||
assert np.alltrue(rew == expected_reward)
|
assert np.alltrue(rew == expected_reward)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("lower_bound", "upper_bound", "expected_reward"),
|
|
||||||
[(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
|
|
||||||
)
|
|
||||||
def test_clip_reward(lower_bound, upper_bound, expected_reward):
|
|
||||||
"""Test reward clipping.
|
|
||||||
Test if reward is correctly clipped
|
|
||||||
accordingly to the input args.
|
|
||||||
"""
|
|
||||||
env = gym.make(ENV_ID)
|
|
||||||
env = ClipRewardV0(env, lower_bound, upper_bound)
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
_, rew, _, _, _ = env.step(DISCRETE_ACTION)
|
|
||||||
|
|
||||||
assert rew == expected_reward
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("lower_bound", "upper_bound", "expected_reward"),
|
|
||||||
[(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)],
|
|
||||||
)
|
|
||||||
def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward):
|
|
||||||
"""Test reward clipping in vectorized environment.
|
|
||||||
Test if reward is correctly clipped
|
|
||||||
accordingly to the input args in a vectorized environment.
|
|
||||||
"""
|
|
||||||
actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]
|
|
||||||
|
|
||||||
env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
|
|
||||||
env = ClipRewardV0(env, lower_bound, upper_bound)
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
|
|
||||||
_, rew, _, _, _ = env.step(actions)
|
|
||||||
|
|
||||||
assert np.alltrue(rew == expected_reward)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("lower_bound", "upper_bound"),
|
|
||||||
[(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))],
|
|
||||||
)
|
|
||||||
def test_clip_reward_incorrect_params(lower_bound, upper_bound):
|
|
||||||
"""Test reward clipping with incorrect params.
|
|
||||||
Test whether passing wrong params to clip_rewards
|
|
||||||
correctly raise an exception.
|
|
||||||
clip_rewards should raise an exception if, both low and upper
|
|
||||||
bound of reward are `None` or if upper bound is lower than lower bound.
|
|
||||||
"""
|
|
||||||
env = gym.make(ENV_ID)
|
|
||||||
|
|
||||||
with pytest.raises(InvalidBound):
|
|
||||||
env = ClipRewardV0(env, lower_bound, upper_bound)
|
|
||||||
|
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for NormalizeObservationV0."""
|
1
tests/experimental/wrappers/test_normalize_reward.py
Normal file
1
tests/experimental/wrappers/test_normalize_reward.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for NormalizeRewardV0."""
|
1
tests/experimental/wrappers/test_numpy_to_torch.py
Normal file
1
tests/experimental/wrappers/test_numpy_to_torch.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for NumpyToTorchV0."""
|
1
tests/experimental/wrappers/test_order_enforcing.py
Normal file
1
tests/experimental/wrappers/test_order_enforcing.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for OrderEnforcingV0."""
|
1
tests/experimental/wrappers/test_passive_env_checker.py
Normal file
1
tests/experimental/wrappers/test_passive_env_checker.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for PassiveEnvCheckerV0."""
|
1
tests/experimental/wrappers/test_pixel_observation.py
Normal file
1
tests/experimental/wrappers/test_pixel_observation.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for PixelObservationV0."""
|
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for RecordEpisodeStatisticsV0."""
|
1
tests/experimental/wrappers/test_record_video.py
Normal file
1
tests/experimental/wrappers/test_record_video.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for RecordVideoV0."""
|
1
tests/experimental/wrappers/test_render_collection.py
Normal file
1
tests/experimental/wrappers/test_render_collection.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test suite for RenderCollectionV0."""
|
38
tests/experimental/wrappers/test_rescale_action.py
Normal file
38
tests/experimental/wrappers/test_rescale_action.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Test suite for RescaleActionV0."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import RescaleActionV0
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from tests.experimental.wrappers.utils import record_action_step
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescale_action_wrapper():
|
||||||
|
"""Test that the action is rescaled within a min / max bound."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
step_func=record_action_step,
|
||||||
|
action_space=Box(np.array([0, 1]), np.array([1, 3])),
|
||||||
|
)
|
||||||
|
wrapped_env = RescaleActionV0(
|
||||||
|
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
|
||||||
|
)
|
||||||
|
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
|
||||||
|
|
||||||
|
for sample_action, expected_action in (
|
||||||
|
(
|
||||||
|
np.array([0.0, 0.5], dtype=np.float32),
|
||||||
|
np.array([0.5, 2.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([-5.0, 0.0], dtype=np.float32),
|
||||||
|
np.array([0.0, 1.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([5.0, 1.0], dtype=np.float32),
|
||||||
|
np.array([1.0, 3.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
assert sample_action in wrapped_env.action_space
|
||||||
|
|
||||||
|
_, _, _, _, info = wrapped_env.step(sample_action)
|
||||||
|
assert np.all(info["action"] == expected_action)
|
55
tests/experimental/wrappers/test_rescale_observation.py
Normal file
55
tests/experimental/wrappers/test_rescale_observation.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Test suite for RescaleObservationV0."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import RescaleObservationV0
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from tests.experimental.wrappers.utils import (
|
||||||
|
check_obs,
|
||||||
|
record_action_as_obs_step,
|
||||||
|
record_obs_reset,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescale_observation():
|
||||||
|
"""Test the ``RescaleObservation`` wrapper."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Box(
|
||||||
|
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
|
||||||
|
),
|
||||||
|
reset_func=record_obs_reset,
|
||||||
|
step_func=record_action_as_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = RescaleObservationV0(
|
||||||
|
env,
|
||||||
|
min_obs=np.array([-5, 0], dtype=np.float32),
|
||||||
|
max_obs=np.array([5, 1], dtype=np.float32),
|
||||||
|
)
|
||||||
|
assert wrapped_env.observation_space == Box(
|
||||||
|
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
for sample_obs, expected_obs in (
|
||||||
|
(
|
||||||
|
np.array([0.5, 2.0], dtype=np.float32),
|
||||||
|
np.array([0.0, 0.5], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.0, 1.0], dtype=np.float32),
|
||||||
|
np.array([-5.0, 0.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([1.0, 3.0], dtype=np.float32),
|
||||||
|
np.array([5.0, 1.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
assert sample_obs in env.observation_space
|
||||||
|
assert expected_obs in wrapped_env.observation_space
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset(options={"obs": sample_obs})
|
||||||
|
assert np.all(obs == expected_obs)
|
||||||
|
check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(sample_obs)
|
||||||
|
assert np.all(obs == expected_obs)
|
||||||
|
check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
27
tests/experimental/wrappers/test_reshape_observation.py
Normal file
27
tests/experimental/wrappers/test_reshape_observation.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Test suite for ReshapeObservationV0."""
|
||||||
|
from gymnasium.experimental.wrappers import ReshapeObservationV0
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from tests.experimental.wrappers.utils import (
|
||||||
|
check_obs,
|
||||||
|
record_random_obs_reset,
|
||||||
|
record_random_obs_step,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_reshape_observation_wrapper():
|
||||||
|
"""Test the ``ReshapeObservation`` wrapper."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Box(0, 1, shape=(2, 3, 2)),
|
||||||
|
reset_func=record_random_obs_reset,
|
||||||
|
step_func=record_random_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = ReshapeObservationV0(env, (6, 2))
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
assert obs.shape == (6, 2)
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
assert obs.shape == (6, 2)
|
27
tests/experimental/wrappers/test_resize_observation.py
Normal file
27
tests/experimental/wrappers/test_resize_observation.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Test suite for ResizeObservationV0."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import ResizeObservationV0
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from tests.experimental.wrappers.utils import (
|
||||||
|
check_obs,
|
||||||
|
record_random_obs_reset,
|
||||||
|
record_random_obs_step,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_observation_wrapper():
|
||||||
|
"""Test the ``ResizeObservation`` that the observation has changed size."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
|
||||||
|
reset_func=record_random_obs_reset,
|
||||||
|
step_func=record_random_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = ResizeObservationV0(env, (25, 25))
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
check_obs(env, wrapped_env, obs, info["obs"])
|
@@ -1,43 +1,34 @@
|
|||||||
"""Test suite for StickyActionV0."""
|
"""Test suite for StickyActionV0."""
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.error import InvalidProbability
|
from gymnasium.error import InvalidProbability
|
||||||
from gymnasium.experimental.wrappers import StickyActionV0
|
from gymnasium.experimental.wrappers import StickyActionV0
|
||||||
|
from tests.experimental.wrappers.utils import NUM_STEPS, record_action_as_obs_step
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
SEED = 42
|
|
||||||
|
|
||||||
DELAY = 3
|
|
||||||
NUM_STEPS = 10
|
|
||||||
|
|
||||||
|
|
||||||
def step_fn(self, action):
|
|
||||||
return action
|
|
||||||
|
|
||||||
|
|
||||||
def test_sticky_action():
|
def test_sticky_action():
|
||||||
|
"""Tests the sticky action wrapper."""
|
||||||
env = StickyActionV0(
|
env = StickyActionV0(
|
||||||
GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5
|
GenericTestEnv(step_func=record_action_as_obs_step),
|
||||||
|
repeat_action_probability=0.5,
|
||||||
)
|
)
|
||||||
env.reset(seed=SEED)
|
|
||||||
env.action_space.seed(SEED)
|
|
||||||
|
|
||||||
previous_action = None
|
previous_action = None
|
||||||
for _ in range(NUM_STEPS):
|
for _ in range(NUM_STEPS):
|
||||||
input_action = env.action_space.sample()
|
input_action = env.action_space.sample()
|
||||||
executed_action = env.step(input_action)
|
executed_action, _, _, _, _ = env.step(input_action)
|
||||||
|
|
||||||
if executed_action != input_action:
|
assert np.all(executed_action == input_action) or np.all(
|
||||||
assert executed_action == previous_action
|
executed_action == previous_action
|
||||||
else:
|
)
|
||||||
assert executed_action == input_action
|
previous_action = executed_action
|
||||||
|
|
||||||
previous_action = input_action
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
|
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
|
||||||
def test_sticky_action_raise(repeat_action_probability):
|
def test_sticky_action_raise(repeat_action_probability):
|
||||||
|
"""Tests the sticky action wrapper with probabilities that should raise an error."""
|
||||||
with pytest.raises(InvalidProbability):
|
with pytest.raises(InvalidProbability):
|
||||||
StickyActionV0(
|
StickyActionV0(
|
||||||
GenericTestEnv(), repeat_action_probability=repeat_action_probability
|
GenericTestEnv(), repeat_action_probability=repeat_action_probability
|
@@ -1,19 +1,10 @@
|
|||||||
"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation."""
|
"""Test suite for TimeAwareObservationV0."""
|
||||||
|
|
||||||
import numpy as np
|
from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0
|
|
||||||
from gymnasium.spaces import Box, Dict, Tuple
|
from gymnasium.spaces import Box, Dict, Tuple
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
NUM_STEPS = 20
|
|
||||||
SEED = 0
|
|
||||||
|
|
||||||
DELAY = 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_time_aware_observation_wrapper():
|
def test_time_aware_observation_wrapper():
|
||||||
"""Tests the time aware observation wrapper."""
|
"""Tests the time aware observation wrapper."""
|
||||||
# Test the environment observation space with Dict, Tuple and other
|
# Test the environment observation space with Dict, Tuple and other
|
||||||
@@ -60,30 +51,3 @@ def test_time_aware_observation_wrapper():
|
|||||||
reset_obs, _ = wrapped_env.reset()
|
reset_obs, _ = wrapped_env.reset()
|
||||||
step_obs, _, _, _, _ = wrapped_env.step(None)
|
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||||
assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01
|
assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01
|
||||||
|
|
||||||
|
|
||||||
def test_delay_observation_wrapper():
|
|
||||||
env = gym.make("CartPole-v1")
|
|
||||||
env.action_space.seed(SEED)
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
|
|
||||||
undelayed_observations = []
|
|
||||||
for _ in range(NUM_STEPS):
|
|
||||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
|
||||||
undelayed_observations.append(obs)
|
|
||||||
|
|
||||||
env = DelayObservationV0(env, delay=DELAY)
|
|
||||||
env.action_space.seed(SEED)
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
|
|
||||||
delayed_observations = []
|
|
||||||
for i in range(NUM_STEPS):
|
|
||||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
|
||||||
delayed_observations.append(obs)
|
|
||||||
if i < DELAY - 1:
|
|
||||||
assert np.all(obs == 0)
|
|
||||||
|
|
||||||
undelayed_observations = np.array(undelayed_observations)
|
|
||||||
delayed_observations = np.array(delayed_observations)
|
|
||||||
|
|
||||||
assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY])
|
|
69
tests/experimental/wrappers/utils.py
Normal file
69
tests/experimental/wrappers/utils.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Utility functions for testing the experimental wrappers."""
|
||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
SEED = 42
|
||||||
|
ENV_ID = "CartPole-v1"
|
||||||
|
DISCRETE_ACTION = 0
|
||||||
|
NUM_ENVS = 3
|
||||||
|
NUM_STEPS = 20
|
||||||
|
DELAY = 3
|
||||||
|
|
||||||
|
|
||||||
|
def record_obs_reset(self: gym.Env, seed=None, options: dict = None):
|
||||||
|
"""Records and uses an observation passed through options."""
|
||||||
|
return options["obs"], {"obs": options["obs"]}
|
||||||
|
|
||||||
|
|
||||||
|
def record_random_obs_reset(self: gym.Env, seed=None, options=None):
|
||||||
|
"""Records random observation generated by the environment."""
|
||||||
|
obs = self.observation_space.sample()
|
||||||
|
return obs, {"obs": obs}
|
||||||
|
|
||||||
|
|
||||||
|
def record_action_step(self: gym.Env, action):
|
||||||
|
"""Records the actions passed to the environment."""
|
||||||
|
return 0, 0, False, False, {"action": action}
|
||||||
|
|
||||||
|
|
||||||
|
def record_random_obs_step(self: gym.Env, action):
|
||||||
|
"""Records the observation generated by the environment."""
|
||||||
|
obs = self.observation_space.sample()
|
||||||
|
return obs, 0, False, False, {"obs": obs}
|
||||||
|
|
||||||
|
|
||||||
|
def record_action_as_obs_step(self: gym.Env, action):
|
||||||
|
"""Uses the action as the observation."""
|
||||||
|
return action, 0, False, False, {"obs": action}
|
||||||
|
|
||||||
|
|
||||||
|
def check_obs(
|
||||||
|
env: gym.Env,
|
||||||
|
wrapped_env: gym.Wrapper,
|
||||||
|
transformed_obs,
|
||||||
|
original_obs,
|
||||||
|
strict: bool = True,
|
||||||
|
):
|
||||||
|
"""Checks the original and transformed observations against the observation spaces of the environment and wrapped environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The base environment
|
||||||
|
wrapped_env: The wrapped environment
|
||||||
|
transformed_obs: The transformed observation by the wrapped environment
|
||||||
|
original_obs: The original observation by the base environment.
|
||||||
|
strict: Whether to also check that each observation is not contained in the other environment's observation space.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
transformed_obs in wrapped_env.observation_space
|
||||||
|
), f"{transformed_obs}, {wrapped_env.observation_space}"
|
||||||
|
assert (
|
||||||
|
original_obs in env.observation_space
|
||||||
|
), f"{original_obs}, {env.observation_space}"
|
||||||
|
|
||||||
|
if strict:
|
||||||
|
assert (
|
||||||
|
transformed_obs not in env.observation_space
|
||||||
|
), f"{transformed_obs}, {env.observation_space}"
|
||||||
|
assert (
|
||||||
|
original_obs not in wrapped_env.observation_space
|
||||||
|
), f"{original_obs}, {wrapped_env.observation_space}"
|
@@ -23,36 +23,44 @@ from tests.testing_env import GenericTestEnv
|
|||||||
|
|
||||||
|
|
||||||
class ArgumentEnv(Env):
|
class ArgumentEnv(Env):
|
||||||
|
"""Testing environment that records the number of times the environment is created."""
|
||||||
|
|
||||||
observation_space = spaces.Box(low=0, high=1, shape=(1,))
|
observation_space = spaces.Box(low=0, high=1, shape=(1,))
|
||||||
action_space = spaces.Box(low=0, high=1, shape=(1,))
|
action_space = spaces.Box(low=0, high=1, shape=(1,))
|
||||||
calls = 0
|
calls = 0
|
||||||
|
|
||||||
def __init__(self, arg):
|
def __init__(self, arg: Any):
|
||||||
|
"""Constructor."""
|
||||||
self.calls += 1
|
self.calls += 1
|
||||||
self.arg = arg
|
self.arg = arg
|
||||||
|
|
||||||
|
|
||||||
class UnittestEnv(Env):
|
class UnittestEnv(Env):
|
||||||
|
"""Example testing environment."""
|
||||||
|
|
||||||
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
|
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
|
||||||
action_space = spaces.Discrete(3)
|
action_space = spaces.Discrete(3)
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
|
"""Resets the environment."""
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
return self.observation_space.sample(), {"info": "dummy"}
|
return self.observation_space.sample(), {"info": "dummy"}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
"""Steps through the environment."""
|
||||||
observation = self.observation_space.sample() # Dummy observation
|
observation = self.observation_space.sample() # Dummy observation
|
||||||
return (observation, 0.0, False, {})
|
return observation, 0.0, False, {}
|
||||||
|
|
||||||
|
|
||||||
class UnknownSpacesEnv(Env):
|
class UnknownSpacesEnv(Env):
|
||||||
"""This environment defines its observation & action spaces only
|
"""This environment defines its observation & action spaces only after the first call to reset.
|
||||||
after the first call to reset. Although this pattern is sometimes
|
|
||||||
necessary when implementing a new environment (e.g. if it depends
|
Although this pattern is sometimes necessary when implementing a new environment (e.g. if it depends
|
||||||
on external resources), it is not encouraged.
|
on external resources), it is not encouraged.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
|
"""Resets the environment."""
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
||||||
@@ -61,25 +69,27 @@ class UnknownSpacesEnv(Env):
|
|||||||
return self.observation_space.sample(), {} # Dummy observation with info
|
return self.observation_space.sample(), {} # Dummy observation with info
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
"""Steps through the environment."""
|
||||||
observation = self.observation_space.sample() # Dummy observation
|
observation = self.observation_space.sample() # Dummy observation
|
||||||
return (observation, 0.0, False, {})
|
return observation, 0.0, False, {}
|
||||||
|
|
||||||
|
|
||||||
class OldStyleEnv(Env):
|
class OldStyleEnv(Env):
|
||||||
"""This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)"""
|
"""This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)."""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
"""Resets the environment."""
|
||||||
super().reset()
|
super().reset()
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
"""Steps through the environment."""
|
||||||
return 0, 0, False, {}
|
return 0, 0, False, {}
|
||||||
|
|
||||||
|
|
||||||
class NewPropertyWrapper(Wrapper):
|
class NewPropertyWrapper(Wrapper):
|
||||||
|
"""Wrapper that tests setting a property."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env,
|
env,
|
||||||
@@ -88,6 +98,15 @@ class NewPropertyWrapper(Wrapper):
|
|||||||
reward_range=None,
|
reward_range=None,
|
||||||
metadata=None,
|
metadata=None,
|
||||||
):
|
):
|
||||||
|
"""New property wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
observation_space: The observation space
|
||||||
|
action_space: The action space
|
||||||
|
reward_range: The reward range
|
||||||
|
metadata: The environment metadata
|
||||||
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
if observation_space is not None:
|
if observation_space is not None:
|
||||||
# Only set the observation space if not None to test property forwarding
|
# Only set the observation space if not None to test property forwarding
|
||||||
@@ -101,6 +120,7 @@ class NewPropertyWrapper(Wrapper):
|
|||||||
|
|
||||||
|
|
||||||
def test_env_instantiation():
|
def test_env_instantiation():
|
||||||
|
"""Tests the environment instantiation using ArgumentEnv."""
|
||||||
# This looks like a pretty trivial, but given our usage of
|
# This looks like a pretty trivial, but given our usage of
|
||||||
# __new__, it's worth having.
|
# __new__, it's worth having.
|
||||||
env = ArgumentEnv("arg")
|
env = ArgumentEnv("arg")
|
||||||
@@ -129,6 +149,7 @@ properties = [
|
|||||||
@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv])
|
@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv])
|
||||||
@pytest.mark.parametrize("props", properties)
|
@pytest.mark.parametrize("props", properties)
|
||||||
def test_wrapper_property_forwarding(class_, props):
|
def test_wrapper_property_forwarding(class_, props):
|
||||||
|
"""Tests wrapper property forwarding."""
|
||||||
env = class_()
|
env = class_()
|
||||||
env = NewPropertyWrapper(env, **props)
|
env = NewPropertyWrapper(env, **props)
|
||||||
|
|
||||||
@@ -147,6 +168,7 @@ def test_wrapper_property_forwarding(class_, props):
|
|||||||
|
|
||||||
|
|
||||||
def test_compatibility_with_old_style_env():
|
def test_compatibility_with_old_style_env():
|
||||||
|
"""Test compatibility with old style environment."""
|
||||||
env = OldStyleEnv()
|
env = OldStyleEnv()
|
||||||
env = OrderEnforcing(env)
|
env = OrderEnforcing(env)
|
||||||
env = TimeLimit(env)
|
env = TimeLimit(env)
|
||||||
@@ -158,13 +180,17 @@ def test_compatibility_with_old_style_env():
|
|||||||
|
|
||||||
|
|
||||||
class ExampleEnv(Env):
|
class ExampleEnv(Env):
|
||||||
|
"""Example testing environment."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
"""Constructor for example environment."""
|
||||||
self.observation_space = Box(0, 1)
|
self.observation_space = Box(0, 1)
|
||||||
self.action_space = Box(0, 1)
|
self.action_space = Box(0, 1)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: ActType
|
self, action: ActType
|
||||||
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
|
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
|
||||||
|
"""Steps through the environment."""
|
||||||
return 0, 0, False, False, {}
|
return 0, 0, False, False, {}
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
@@ -173,10 +199,12 @@ class ExampleEnv(Env):
|
|||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Tuple[ObsType, dict]:
|
) -> Tuple[ObsType, dict]:
|
||||||
|
"""Resets the environment."""
|
||||||
return 0, {}
|
return 0, {}
|
||||||
|
|
||||||
|
|
||||||
def test_gymnasium_env():
|
def test_gymnasium_env():
|
||||||
|
"""Tests a gymnasium environment."""
|
||||||
env = ExampleEnv()
|
env = ExampleEnv()
|
||||||
|
|
||||||
assert env.metadata == {"render_modes": []}
|
assert env.metadata == {"render_modes": []}
|
||||||
@@ -187,7 +215,10 @@ def test_gymnasium_env():
|
|||||||
|
|
||||||
|
|
||||||
class ExampleWrapper(Wrapper):
|
class ExampleWrapper(Wrapper):
|
||||||
|
"""An example testing wrapper."""
|
||||||
|
|
||||||
def __init__(self, env: Env[ObsType, ActType]):
|
def __init__(self, env: Env[ObsType, ActType]):
|
||||||
|
"""Constructor that sets the reward."""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
self.new_reward = 3
|
self.new_reward = 3
|
||||||
@@ -195,11 +226,13 @@ class ExampleWrapper(Wrapper):
|
|||||||
def reset(
|
def reset(
|
||||||
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
|
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
|
||||||
) -> Tuple[WrapperObsType, Dict[str, Any]]:
|
) -> Tuple[WrapperObsType, Dict[str, Any]]:
|
||||||
|
"""Resets the environment."""
|
||||||
return super().reset(seed=seed, options=options)
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: WrapperActType
|
self, action: WrapperActType
|
||||||
) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]:
|
) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]:
|
||||||
|
"""Steps through the environment."""
|
||||||
obs, reward, termination, truncation, info = self.env.step(action)
|
obs, reward, termination, truncation, info = self.env.step(action)
|
||||||
return obs, self.new_reward, termination, truncation, info
|
return obs, self.new_reward, termination, truncation, info
|
||||||
|
|
||||||
@@ -209,6 +242,7 @@ class ExampleWrapper(Wrapper):
|
|||||||
|
|
||||||
|
|
||||||
def test_gymnasium_wrapper():
|
def test_gymnasium_wrapper():
|
||||||
|
"""Tests the gymnasium wrapper works as expected."""
|
||||||
env = ExampleEnv()
|
env = ExampleEnv()
|
||||||
wrapper_env = ExampleWrapper(env)
|
wrapper_env = ExampleWrapper(env)
|
||||||
|
|
||||||
@@ -250,21 +284,31 @@ def test_gymnasium_wrapper():
|
|||||||
|
|
||||||
|
|
||||||
class ExampleRewardWrapper(RewardWrapper):
|
class ExampleRewardWrapper(RewardWrapper):
|
||||||
|
"""Example reward wrapper for testing."""
|
||||||
|
|
||||||
def reward(self, reward: SupportsFloat) -> SupportsFloat:
|
def reward(self, reward: SupportsFloat) -> SupportsFloat:
|
||||||
|
"""Reward function."""
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
class ExampleObservationWrapper(ObservationWrapper):
|
class ExampleObservationWrapper(ObservationWrapper):
|
||||||
|
"""Example observation wrapper for testing."""
|
||||||
|
|
||||||
def observation(self, observation: ObsType) -> ObsType:
|
def observation(self, observation: ObsType) -> ObsType:
|
||||||
|
"""Observation function."""
|
||||||
return np.array([1])
|
return np.array([1])
|
||||||
|
|
||||||
|
|
||||||
class ExampleActionWrapper(ActionWrapper):
|
class ExampleActionWrapper(ActionWrapper):
|
||||||
|
"""Example action wrapper for testing."""
|
||||||
|
|
||||||
def action(self, action: ActType) -> ActType:
|
def action(self, action: ActType) -> ActType:
|
||||||
|
"""Action function."""
|
||||||
return np.array([1])
|
return np.array([1])
|
||||||
|
|
||||||
|
|
||||||
def test_wrapper_types():
|
def test_wrapper_types():
|
||||||
|
"""Tests the observation, action and reward wrapper examples."""
|
||||||
env = GenericTestEnv()
|
env = GenericTestEnv()
|
||||||
|
|
||||||
reward_env = ExampleRewardWrapper(env)
|
reward_env = ExampleRewardWrapper(env)
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
"""Provides a generic testing environment for use in tests with custom reset, step and render functions."""
|
"""Provides a generic testing environment for use in tests with custom reset, step and render functions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import types
|
import types
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
@@ -11,21 +13,21 @@ from gymnasium.envs.registration import EnvSpec
|
|||||||
def basic_reset_func(
|
def basic_reset_func(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
options: Optional[dict] = None,
|
options: dict | None = None,
|
||||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
) -> ObsType | tuple[ObsType, dict]:
|
||||||
"""A basic reset function that will pass the environment check using random actions from the observation space."""
|
"""A basic reset function that will pass the environment check using random actions from the observation space."""
|
||||||
super(GenericTestEnv, self).reset(seed=seed)
|
super(GenericTestEnv, self).reset(seed=seed)
|
||||||
self.observation_space.seed(seed)
|
self.observation_space.seed(seed)
|
||||||
return self.observation_space.sample(), {"options": options}
|
return self.observation_space.sample(), {"options": options}
|
||||||
|
|
||||||
|
|
||||||
def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
def new_step_func(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
|
||||||
"""A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
|
"""A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
|
||||||
return self.observation_space.sample(), 0, False, False, {}
|
return self.observation_space.sample(), 0, False, False, {}
|
||||||
|
|
||||||
|
|
||||||
def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
def old_step_func(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
||||||
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
|
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
|
||||||
return self.observation_space.sample(), 0, False, {}
|
return self.observation_space.sample(), 0, False, {}
|
||||||
|
|
||||||
@@ -46,12 +48,24 @@ class GenericTestEnv(gym.Env):
|
|||||||
reset_func: callable = basic_reset_func,
|
reset_func: callable = basic_reset_func,
|
||||||
step_func: callable = new_step_func,
|
step_func: callable = new_step_func,
|
||||||
render_func: callable = basic_render_func,
|
render_func: callable = basic_render_func,
|
||||||
metadata: Dict[str, Any] = {"render_modes": []},
|
metadata: dict[str, Any] = {"render_modes": []},
|
||||||
render_mode: Optional[str] = None,
|
render_mode: str | None = None,
|
||||||
spec: EnvSpec = EnvSpec(
|
spec: EnvSpec = EnvSpec(
|
||||||
"TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100
|
"TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
"""Generic testing environment constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_space: The environment action space
|
||||||
|
observation_space: The environment observation space
|
||||||
|
reset_func: The environment reset function
|
||||||
|
step_func: The environment step function
|
||||||
|
render_func: The environment render function
|
||||||
|
metadata: The environment metadata
|
||||||
|
render_mode: The render mode of the environment
|
||||||
|
spec: The environment spec
|
||||||
|
"""
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
self.spec = spec
|
self.spec = spec
|
||||||
@@ -71,14 +85,17 @@ class GenericTestEnv(gym.Env):
|
|||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
options: Optional[dict] = None,
|
options: dict | None = None,
|
||||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
) -> ObsType | tuple[ObsType, dict]:
|
||||||
|
"""Resets the environment."""
|
||||||
# If you need a default working reset function, use `basic_reset_fn` above
|
# If you need a default working reset function, use `basic_reset_fn` above
|
||||||
raise NotImplementedError("TestingEnv reset_fn is not set.")
|
raise NotImplementedError("TestingEnv reset_fn is not set.")
|
||||||
|
|
||||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict[str, Any]]:
|
||||||
|
"""Steps through the environment."""
|
||||||
raise NotImplementedError("TestingEnv step_fn is not set.")
|
raise NotImplementedError("TestingEnv step_fn is not set.")
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
"""Renders the environment."""
|
||||||
raise NotImplementedError("testingEnv render_fn is not set.")
|
raise NotImplementedError("testingEnv render_fn is not set.")
|
||||||
|
Reference in New Issue
Block a user