mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Update experimental wrappers (#176)
This commit is contained in:
@@ -22,9 +22,9 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: codespell
|
- id: codespell
|
||||||
args:
|
args:
|
||||||
- --ignore-words-list=nd,reacher,thist,ths, ure, referenc,wile
|
- --ignore-words-list=nd,reacher,thist,ths,ure,referenc,wile
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 6.0.0
|
rev: 5.0.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
args:
|
args:
|
||||||
|
@@ -27,8 +27,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
|
|
||||||
* In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorised versions of the wrappers will be provided.
|
* In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorised versions of the wrappers will be provided.
|
||||||
|
|
||||||
### Lambda Observation Wrappers
|
### Observation Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. py:currentmodule:: gymnasium
|
.. py:currentmodule:: gymnasium
|
||||||
|
|
||||||
@@ -44,61 +43,60 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
- VectorLambdaObservation
|
- VectorLambdaObservation
|
||||||
- No
|
- No
|
||||||
* - :class:`wrappers.FilterObservation`
|
* - :class:`wrappers.FilterObservation`
|
||||||
- :class:`experimental.wrappers.FilterObservation`
|
- :class:`experimental.wrappers.FilterObservationV0`
|
||||||
- VectorFilterObservation (*)
|
- VectorFilterObservation (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - :class:`wrappers.FlattenObservation`
|
* - :class:`wrappers.FlattenObservation`
|
||||||
- `:class:`experimental.wrappers.FlattenObservation`
|
- :class:`experimental.wrappers.FlattenObservationV0`
|
||||||
- VectorFlattenObservation (*)
|
- VectorFlattenObservation (*)
|
||||||
- No
|
- No
|
||||||
* - :class:`wrappers.GrayScaleObservation`
|
* - :class:`wrappers.GrayScaleObservation`
|
||||||
- `:class:`experimental.wrappers.GrayscaleObservation`
|
- :class:`experimental.wrappers.GrayscaleObservationV0`
|
||||||
- VectorGrayscaleObservation (*)
|
- VectorGrayscaleObservation (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - :class:`wrappers.ResizeObservation`
|
* - :class:`wrappers.ResizeObservation`
|
||||||
- :class:`experimental.wrappers.ResizeObservation`
|
- :class:`experimental.wrappers.ResizeObservationV0`
|
||||||
- VectorResizeObservation (*)
|
- VectorResizeObservation (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- :class:`experimental.wrappers.ReshapeObservation`
|
- :class:`experimental.wrappers.ReshapeObservationV0`
|
||||||
- VectorReshapeObservation (*)
|
- VectorReshapeObservation (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- :class:`experimental.wrappers.RescaleObservation`
|
- :class:`experimental.wrappers.RescaleObservationV0`
|
||||||
- VectorRescaleObservation (*)
|
- VectorRescaleObservation (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- :class:`experimental.wrappers.DtypeObservation`
|
- :class:`experimental.wrappers.DtypeObservationV0`
|
||||||
- VectorDtypeObservation (*)
|
- VectorDtypeObservation (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - :class:`wrappers.PixelObservationWrapper`
|
* - :class:`wrappers.PixelObservationWrapper`
|
||||||
- PixelObservation
|
- PixelObservation
|
||||||
- VectorPixelObservation
|
- VectorPixelObservation
|
||||||
- No
|
- No
|
||||||
* - :class:`NormalizeObservation`
|
* - :class:`wrappers.NormalizeObservation`
|
||||||
- NormalizeObservation
|
- NormalizeObservation
|
||||||
- VectorNormalizeObservation
|
- VectorNormalizeObservation
|
||||||
- No
|
- No
|
||||||
* - :class:`TimeAwareObservation`
|
* - :class:`wrappers.TimeAwareObservation`
|
||||||
- TimeAwareObservation
|
- :class:`experimental.wrappers.TimeAwareObservationV0`
|
||||||
- VectorTimeAwareObservation
|
- VectorTimeAwareObservation
|
||||||
- No
|
- No
|
||||||
* - :class:`FrameStack`
|
* - :class:`wrappers.FrameStack`
|
||||||
- FrameStackObservation
|
- FrameStackObservation
|
||||||
- VectorFrameStackObservation
|
- VectorFrameStackObservation
|
||||||
- No
|
- No
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- DelayObservation
|
- :class:`experimental.wrappers.DelayObservationV0`
|
||||||
- VectorDelayObservation
|
- VectorDelayObservation
|
||||||
- No
|
- No
|
||||||
* - :class:`AtariPreprocessing`
|
* - :class:`wrappers.AtariPreprocessing`
|
||||||
- AtariPreprocessing
|
- AtariPreprocessing
|
||||||
- Not Implemented
|
- Not Implemented
|
||||||
- No
|
- No
|
||||||
```
|
```
|
||||||
|
|
||||||
### Lambda Action Wrappers
|
### Action Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. py:currentmodule:: gymnasium
|
.. py:currentmodule:: gymnasium
|
||||||
|
|
||||||
@@ -114,25 +112,20 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
- VectorLambdaAction
|
- VectorLambdaAction
|
||||||
- No
|
- No
|
||||||
* - :class:`wrappers.ClipAction`
|
* - :class:`wrappers.ClipAction`
|
||||||
- ClipAction
|
- :class:`experimental.wrappers.ClipActionV0`
|
||||||
- VectorClipAction (*)
|
- VectorClipAction (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - :class:`wrappers.RescaleAction`
|
* - :class:`wrappers.RescaleAction`
|
||||||
- RescaleAction
|
- :class:`experimental.wrappers.RescaleActionV0`
|
||||||
- VectorRescaleAction (*)
|
- VectorRescaleAction (*)
|
||||||
- Yes
|
- Yes
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- NanAction
|
- :class:`experimental.wrappers.StickyActionV0`
|
||||||
- VectorNanAction (*)
|
|
||||||
- Yes
|
|
||||||
* - Not Implemented
|
|
||||||
- StickyAction
|
|
||||||
- VectorStickyAction
|
- VectorStickyAction
|
||||||
- No
|
- No
|
||||||
```
|
```
|
||||||
|
|
||||||
### Lambda Reward Wrappers
|
### Reward Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. py:currentmodule:: gymnasium
|
.. py:currentmodule:: gymnasium
|
||||||
|
|
||||||
@@ -175,7 +168,7 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
- VectorPassiveEnvChecker
|
- VectorPassiveEnvChecker
|
||||||
* - :class:`wrappers.OrderEnforcing`
|
* - :class:`wrappers.OrderEnforcing`
|
||||||
- OrderEnforcing
|
- OrderEnforcing
|
||||||
- VectorOrderEnforcing (*)
|
- VectorOrderEnforcing
|
||||||
* - :class:`wrappers.EnvCompatibility`
|
* - :class:`wrappers.EnvCompatibility`
|
||||||
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
||||||
- Not Implemented
|
- Not Implemented
|
||||||
@@ -189,10 +182,10 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
- HumanRendering
|
- HumanRendering
|
||||||
- Not Implemented
|
- Not Implemented
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- :class:`experimental.wrappers.JaxToNumpy`
|
- :class:`experimental.wrappers.JaxToNumpyV0`
|
||||||
- VectorJaxToNumpy (*)
|
- VectorJaxToNumpy (*)
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- :class:`experimental.wrappers.JaxToTorch`
|
- :class:`experimental.wrappers.JaxToTorchV0`
|
||||||
- VectorJaxToTorch (*)
|
- VectorJaxToTorch (*)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@@ -12,9 +12,6 @@ title: Functional
|
|||||||
.. autofunction:: gymnasium.experimental.FuncEnv.initial
|
.. autofunction:: gymnasium.experimental.FuncEnv.initial
|
||||||
.. autofunction:: gymnasium.experimental.FuncEnv.transition
|
.. autofunction:: gymnasium.experimental.FuncEnv.transition
|
||||||
|
|
||||||
.. autofunction:: gymnasium.experimental.FuncEnv.observation
|
|
||||||
.. autofunction:: gymnasium.experimental.FuncEnv.initial
|
|
||||||
|
|
||||||
.. autofunction:: gymnasium.experimental.FuncEnv.observation
|
.. autofunction:: gymnasium.experimental.FuncEnv.observation
|
||||||
.. autofunction:: gymnasium.experimental.FuncEnv.reward
|
.. autofunction:: gymnasium.experimental.FuncEnv.reward
|
||||||
.. autofunction:: gymnasium.experimental.FuncEnv.terminal
|
.. autofunction:: gymnasium.experimental.FuncEnv.terminal
|
||||||
@@ -33,4 +30,8 @@ title: Functional
|
|||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
... autoclass:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv
|
... autoclass:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv
|
||||||
|
|
||||||
|
... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.reset
|
||||||
|
... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.step
|
||||||
|
... autofunction:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv.render
|
||||||
```
|
```
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
# Wrappers
|
# Wrappers
|
||||||
|
|
||||||
## Lambda Observation Wrappers
|
## Observation Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
|
||||||
@@ -11,24 +11,6 @@
|
|||||||
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
|
||||||
```
|
|
||||||
|
|
||||||
## Lambda Action Wrappers
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
|
|
||||||
```
|
|
||||||
|
|
||||||
## Lambda Reward Wrappers
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
|
||||||
```
|
|
||||||
|
|
||||||
## Observation Wrappers
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
||||||
```
|
```
|
||||||
@@ -36,11 +18,22 @@
|
|||||||
## Action Wrappers
|
## Action Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0
|
||||||
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Reward Wrappers
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||||
|
```
|
||||||
|
|
||||||
## Common Wrappers
|
## Common Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
|
||||||
```
|
```
|
||||||
|
@@ -27,7 +27,7 @@ __all__ = [
|
|||||||
"register",
|
"register",
|
||||||
"registry",
|
"registry",
|
||||||
"pprint_registry",
|
"pprint_registry",
|
||||||
# root files
|
# module folders
|
||||||
"envs",
|
"envs",
|
||||||
"spaces",
|
"spaces",
|
||||||
"utils",
|
"utils",
|
||||||
|
@@ -1,2 +1,2 @@
|
|||||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
|
||||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
"""Root __init__ of the gym experimental wrappers."""
|
"""Root __init__ of the gym experimental wrappers."""
|
||||||
|
|
||||||
|
|
||||||
|
from gymnasium.experimental import functional, wrappers
|
||||||
from gymnasium.experimental.functional import FuncEnv
|
from gymnasium.experimental.functional import FuncEnv
|
||||||
|
|
||||||
|
|
||||||
@@ -8,6 +9,8 @@ __all__ = [
|
|||||||
# Functional
|
# Functional
|
||||||
"FuncEnv",
|
"FuncEnv",
|
||||||
"functional",
|
"functional",
|
||||||
# Wrapper
|
# Wrappers
|
||||||
"wrappers",
|
"wrappers",
|
||||||
|
# Vector
|
||||||
|
# "vector",
|
||||||
]
|
]
|
||||||
|
@@ -10,31 +10,59 @@ from gymnasium.experimental.wrappers.lambda_action import (
|
|||||||
ClipActionV0,
|
ClipActionV0,
|
||||||
RescaleActionV0,
|
RescaleActionV0,
|
||||||
)
|
)
|
||||||
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
from gymnasium.experimental.wrappers.lambda_observations import (
|
||||||
|
LambdaObservationV0,
|
||||||
|
FilterObservationV0,
|
||||||
|
FlattenObservationV0,
|
||||||
|
GrayscaleObservationV0,
|
||||||
|
ResizeObservationV0,
|
||||||
|
ReshapeObservationV0,
|
||||||
|
RescaleObservationV0,
|
||||||
|
DtypeObservationV0,
|
||||||
|
)
|
||||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||||
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
|
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
|
||||||
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
|
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
|
||||||
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
|
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
||||||
from gymnasium.experimental.wrappers.time_aware_observation import (
|
from gymnasium.experimental.wrappers.stateful_observation import (
|
||||||
TimeAwareObservationV0,
|
TimeAwareObservationV0,
|
||||||
|
DelayObservationV0,
|
||||||
)
|
)
|
||||||
from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ArgType",
|
# --- Observation wrappers ---
|
||||||
# Lambda Action
|
"LambdaObservationV0",
|
||||||
|
"FilterObservationV0",
|
||||||
|
"FlattenObservationV0",
|
||||||
|
"GrayscaleObservationV0",
|
||||||
|
"ResizeObservationV0",
|
||||||
|
"ReshapeObservationV0",
|
||||||
|
"RescaleObservationV0",
|
||||||
|
"DtypeObservationV0",
|
||||||
|
# "PixelObservationV0",
|
||||||
|
# "NormalizeObservationV0",
|
||||||
|
"TimeAwareObservationV0",
|
||||||
|
# "FrameStackV0",
|
||||||
|
"DelayObservationV0",
|
||||||
|
# "AtariPreprocessingV0"
|
||||||
|
# --- Action Wrappers ---
|
||||||
"LambdaActionV0",
|
"LambdaActionV0",
|
||||||
"StickyActionV0",
|
|
||||||
"ClipActionV0",
|
"ClipActionV0",
|
||||||
"RescaleActionV0",
|
"RescaleActionV0",
|
||||||
# Lambda Observation
|
# "NanAction",
|
||||||
"LambdaObservationV0",
|
"StickyActionV0",
|
||||||
"DelayObservationV0",
|
# --- Reward wrappers ---
|
||||||
"TimeAwareObservationV0",
|
|
||||||
# Lambda Reward
|
|
||||||
"LambdaRewardV0",
|
"LambdaRewardV0",
|
||||||
"ClipRewardV0",
|
"ClipRewardV0",
|
||||||
# Jax conversion wrappers
|
# "RescaleRewardV0",
|
||||||
|
# "NormalizeRewardV0",
|
||||||
|
# --- Common ---
|
||||||
|
# "AutoReset",
|
||||||
|
# "PassiveEnvChecker",
|
||||||
|
# "OrderEnforcing",
|
||||||
|
# "RecordEpisodeStatistics",
|
||||||
|
# "RenderCollection",
|
||||||
|
# "HumanRendering",
|
||||||
"JaxToNumpyV0",
|
"JaxToNumpyV0",
|
||||||
"JaxToTorchV0",
|
"JaxToTorchV0",
|
||||||
]
|
]
|
||||||
|
@@ -1,35 +0,0 @@
|
|||||||
"""Wrapper for delaying the returned observation."""
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import jumpy as jp
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.core import ObsType
|
|
||||||
|
|
||||||
|
|
||||||
class DelayObservationV0(gym.ObservationWrapper):
|
|
||||||
"""Wrapper which adds a delay to the returned observation."""
|
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, delay: int):
|
|
||||||
"""Initialize the DelayObservation wrapper.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env (Env): the wrapped environment
|
|
||||||
delay (int): number of timesteps for delaying the observation.
|
|
||||||
Before reaching the `delay` number of timesteps,
|
|
||||||
returned observation is an array of zeros with the
|
|
||||||
same shape of the observation space.
|
|
||||||
"""
|
|
||||||
super().__init__(env)
|
|
||||||
self.delay = delay
|
|
||||||
self.observation_queue = deque()
|
|
||||||
|
|
||||||
def observation(self, observation: ObsType) -> ObsType:
|
|
||||||
"""Return the delayed observation."""
|
|
||||||
self.observation_queue.append(observation)
|
|
||||||
|
|
||||||
if len(self.observation_queue) > self.delay:
|
|
||||||
return self.observation_queue.popleft()
|
|
||||||
|
|
||||||
return jp.zeros_like(observation)
|
|
@@ -1,13 +1,19 @@
|
|||||||
"""Lambda action wrapper which apply a function to the provided action."""
|
"""A collection of wrappers that all use the LambdaAction class.
|
||||||
from typing import Any, Callable, Union
|
|
||||||
|
* ``LambdaAction`` - Transforms the actions based on a function
|
||||||
|
* ``ClipAction`` - Clips the action within a bounds
|
||||||
|
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import jumpy as jp
|
import jumpy as jp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
from gymnasium.core import ActType, WrapperActType
|
||||||
from gymnasium.core import ActType
|
from gymnasium.spaces import Box, Space
|
||||||
from gymnasium.experimental.wrappers import ArgType
|
|
||||||
|
|
||||||
|
|
||||||
class LambdaActionV0(gym.ActionWrapper):
|
class LambdaActionV0(gym.ActionWrapper):
|
||||||
@@ -16,19 +22,23 @@ class LambdaActionV0(gym.ActionWrapper):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
func: Callable[[ArgType], Any],
|
func: Callable[[WrapperActType], ActType],
|
||||||
|
action_space: Space | None,
|
||||||
):
|
):
|
||||||
"""Initialize LambdaAction.
|
"""Initialize LambdaAction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The gymnasium environment
|
env: The gymnasium environment
|
||||||
func (Callable): function to apply to action
|
func: Function to apply to ``step`` ``action``
|
||||||
|
action_space: The updated action space of the wrapper given the function.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
if action_space is not None:
|
||||||
|
self.action_space = action_space
|
||||||
|
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
def action(self, action: ActType) -> Any:
|
def action(self, action: WrapperActType) -> ActType:
|
||||||
"""Apply function to action."""
|
"""Apply function to action."""
|
||||||
return self.func(action)
|
return self.func(action)
|
||||||
|
|
||||||
@@ -53,14 +63,19 @@ class ClipActionV0(LambdaActionV0):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
assert isinstance(env.action_space, spaces.Box)
|
assert isinstance(env.action_space, Box)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
env,
|
env,
|
||||||
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
|
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
|
||||||
|
Box(
|
||||||
|
-np.inf,
|
||||||
|
np.inf,
|
||||||
|
shape=env.action_space.shape,
|
||||||
|
dtype=env.action_space.dtype,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape)
|
|
||||||
|
|
||||||
|
|
||||||
class RescaleActionV0(LambdaActionV0):
|
class RescaleActionV0(LambdaActionV0):
|
||||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||||
@@ -86,8 +101,8 @@ class RescaleActionV0(LambdaActionV0):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
min_action: Union[float, int, np.ndarray],
|
min_action: float | int | np.ndarray,
|
||||||
max_action: Union[float, int, np.ndarray],
|
max_action: float | int | np.ndarray,
|
||||||
):
|
):
|
||||||
"""Initializes the :class:`RescaleAction` wrapper.
|
"""Initializes the :class:`RescaleAction` wrapper.
|
||||||
|
|
||||||
@@ -96,28 +111,44 @@ class RescaleActionV0(LambdaActionV0):
|
|||||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||||
"""
|
"""
|
||||||
assert isinstance(
|
assert isinstance(env.action_space, Box)
|
||||||
env.action_space, spaces.Box
|
assert not np.any(env.action_space.low == np.inf) and not np.any(
|
||||||
), f"expected Box action space, got {type(env.action_space)}"
|
env.action_space.high == np.inf
|
||||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
|
||||||
|
|
||||||
low = env.action_space.low
|
|
||||||
high = env.action_space.high
|
|
||||||
|
|
||||||
self.min_action = np.full(
|
|
||||||
env.action_space.shape, min_action, dtype=env.action_space.dtype
|
|
||||||
)
|
)
|
||||||
self.max_action = np.full(
|
|
||||||
env.action_space.shape, max_action, dtype=env.action_space.dtype
|
if not isinstance(min_action, np.ndarray):
|
||||||
|
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
|
||||||
|
type(max_action), np.floating
|
||||||
|
)
|
||||||
|
min_action = np.full(env.action_space.shape, min_action)
|
||||||
|
|
||||||
|
assert min_action.shape == env.action_space.shape
|
||||||
|
assert not np.any(min_action == np.inf)
|
||||||
|
|
||||||
|
if not isinstance(max_action, np.ndarray):
|
||||||
|
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
|
||||||
|
type(max_action), np.floating
|
||||||
|
)
|
||||||
|
max_action = np.full(env.action_space.shape, max_action)
|
||||||
|
assert max_action.shape == env.action_space.shape
|
||||||
|
assert not np.any(max_action == np.inf)
|
||||||
|
|
||||||
|
assert isinstance(env.action_space, Box)
|
||||||
|
assert np.all(np.less_equal(min_action, max_action))
|
||||||
|
|
||||||
|
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||||
|
gradient = (env.action_space.high - env.action_space.low) / (
|
||||||
|
max_action - min_action
|
||||||
)
|
)
|
||||||
|
intercept = gradient * -min_action + env.action_space.low
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
env,
|
env,
|
||||||
lambda action: jp.clip(
|
lambda action: gradient * action + intercept,
|
||||||
low
|
Box(
|
||||||
+ (high - low)
|
low=min_action,
|
||||||
* ((action - self.min_action) / (self.max_action - self.min_action)),
|
high=max_action,
|
||||||
low,
|
shape=env.action_space.shape,
|
||||||
high,
|
dtype=env.action_space.dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@@ -1,17 +1,27 @@
|
|||||||
"""Lambda observation wrappers which apply a function to the observation."""
|
"""A collection of observation wrappers using a lambda function.
|
||||||
|
|
||||||
|
* ``LambdaObservation`` - Transforms the observation with a function
|
||||||
|
* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
|
||||||
|
* ``FlattenObservation`` - Flattens the observations
|
||||||
|
* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation
|
||||||
|
* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation)
|
||||||
|
* ``ReshapeObservation`` - Reshapes an array-based observation
|
||||||
|
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
|
||||||
|
* ``DtypeObservation`` - Convert a observation dtype
|
||||||
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, Sequence
|
from typing import Any, Callable, Sequence
|
||||||
|
from typing_extensions import Final
|
||||||
|
|
||||||
import jumpy as jp
|
import jumpy as jp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.spaces import utils
|
from gymnasium.spaces import Box, utils
|
||||||
|
|
||||||
|
|
||||||
class LambdaObservationV0(gym.ObservationWrapper):
|
class LambdaObservationV0(gym.ObservationWrapper):
|
||||||
@@ -71,32 +81,82 @@ class FilterObservationV0(LambdaObservationV0):
|
|||||||
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {})
|
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {})
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, filter_keys: Sequence[str]):
|
def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
|
||||||
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
||||||
if not isinstance(env.observation_space, spaces.Dict):
|
assert isinstance(filter_keys, Sequence)
|
||||||
|
|
||||||
|
# Filters for dictionary space
|
||||||
|
if isinstance(env.observation_space, spaces.Dict):
|
||||||
|
assert all(isinstance(key, str) for key in filter_keys)
|
||||||
|
|
||||||
|
if any(
|
||||||
|
key not in env.observation_space.spaces.keys() for key in filter_keys
|
||||||
|
):
|
||||||
|
missing_keys = [
|
||||||
|
key
|
||||||
|
for key in filter_keys
|
||||||
|
if key not in env.observation_space.spaces.keys()
|
||||||
|
]
|
||||||
|
raise ValueError(
|
||||||
|
"All the `filter_keys` must be included in the observation space.\n"
|
||||||
|
f"Filter keys: {filter_keys}\n"
|
||||||
|
f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
|
||||||
|
f"Missing keys: {missing_keys}"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_observation_space = spaces.Dict(
|
||||||
|
{key: env.observation_space[key] for key in filter_keys}
|
||||||
|
)
|
||||||
|
if len(new_observation_space) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The observation space is empty due to filtering all keys."
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
env,
|
||||||
|
lambda obs: {key: obs[key] for key in filter_keys},
|
||||||
|
new_observation_space,
|
||||||
|
)
|
||||||
|
# Filter for tuple observation
|
||||||
|
elif isinstance(env.observation_space, spaces.Tuple):
|
||||||
|
assert all(isinstance(key, int) for key in filter_keys)
|
||||||
|
assert len(set(filter_keys)) == len(
|
||||||
|
filter_keys
|
||||||
|
), f"Duplicate keys exist, filter_keys: {filter_keys}"
|
||||||
|
|
||||||
|
if any(
|
||||||
|
0 < key and key >= len(env.observation_space) for key in filter_keys
|
||||||
|
):
|
||||||
|
missing_index = [
|
||||||
|
key
|
||||||
|
for key in filter_keys
|
||||||
|
if 0 < key and key >= len(env.observation_space)
|
||||||
|
]
|
||||||
|
raise ValueError(
|
||||||
|
"All the `filter_keys` must be included in the length of the observation space.\n"
|
||||||
|
f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
|
||||||
|
f"missing indexes: {missing_index}"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_observation_spaces = spaces.Tuple(
|
||||||
|
env.observation_space[key] for key in filter_keys
|
||||||
|
)
|
||||||
|
if len(new_observation_spaces) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The observation space is empty due to filtering all keys."
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
env,
|
||||||
|
lambda obs: tuple(obs[key] for key in filter_keys),
|
||||||
|
new_observation_spaces,
|
||||||
|
)
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"FilterObservation wrapper is only usable with dict observations, actual type: {type(env.observation_space)}"
|
f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if any(key not in env.observation_space.keys() for key in filter_keys):
|
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||||
missing_keys = [
|
|
||||||
key for key in filter_keys if key not in env.observation_space.keys()
|
|
||||||
]
|
|
||||||
raise ValueError(
|
|
||||||
"All the filter_keys must be included in the original observation space.\n"
|
|
||||||
f"Filter keys: {filter_keys}\n"
|
|
||||||
f"Observation keys: {list(env.observation_space.keys())}\n"
|
|
||||||
f"Missing keys: {missing_keys}"
|
|
||||||
)
|
|
||||||
|
|
||||||
new_observation_space = spaces.Dict(
|
|
||||||
{key: env.observation_space[key] for key in filter_keys}
|
|
||||||
)
|
|
||||||
super().__init__(
|
|
||||||
env,
|
|
||||||
lambda obs: {key: obs[key] for key in filter_keys},
|
|
||||||
new_observation_space,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FlattenObservationV0(LambdaObservationV0):
|
class FlattenObservationV0(LambdaObservationV0):
|
||||||
@@ -117,9 +177,10 @@ class FlattenObservationV0(LambdaObservationV0):
|
|||||||
|
|
||||||
def __init__(self, env: gym.Env):
|
def __init__(self, env: gym.Env):
|
||||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
||||||
flattened_space = utils.flatten_space(env.observation_space)
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
env, lambda obs: utils.flatten(flattened_space, obs), flattened_space
|
env,
|
||||||
|
lambda obs: utils.flatten(env.observation_space, obs),
|
||||||
|
utils.flatten_space(env.observation_space),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -154,7 +215,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
|||||||
and env.observation_space.dtype == np.uint8
|
and env.observation_space.dtype == np.uint8
|
||||||
)
|
)
|
||||||
|
|
||||||
self.keep_dim = keep_dim
|
self.keep_dim: Final[bool] = keep_dim
|
||||||
if keep_dim:
|
if keep_dim:
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=0,
|
low=0,
|
||||||
@@ -167,7 +228,8 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
|||||||
lambda obs: jp.expand_dims(
|
lambda obs: jp.expand_dims(
|
||||||
jp.sum(
|
jp.sum(
|
||||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||||
)
|
).astype(np.uint8),
|
||||||
|
axis=-1,
|
||||||
),
|
),
|
||||||
new_observation_space,
|
new_observation_space,
|
||||||
)
|
)
|
||||||
@@ -179,7 +241,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
|||||||
env,
|
env,
|
||||||
lambda obs: jp.sum(
|
lambda obs: jp.sum(
|
||||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||||
),
|
).astype(np.uint8),
|
||||||
new_observation_space,
|
new_observation_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -215,7 +277,7 @@ class ResizeObservationV0(LambdaObservationV0):
|
|||||||
"opencv is not install, run `pip install gymnasium[other]`"
|
"opencv is not install, run `pip install gymnasium[other]`"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.shape = tuple(shape)
|
self.shape: Final[tuple[int, ...]] = tuple(shape)
|
||||||
|
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
||||||
@@ -237,7 +299,7 @@ class ReshapeObservationV0(LambdaObservationV0):
|
|||||||
|
|
||||||
assert isinstance(shape, tuple)
|
assert isinstance(shape, tuple)
|
||||||
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
|
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
|
||||||
assert all(x > 0 for x in shape)
|
assert all(x > 0 or x == -1 for x in shape)
|
||||||
|
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=np.reshape(np.ravel(env.observation_space.low), shape),
|
low=np.reshape(np.ravel(env.observation_space.low), shape),
|
||||||
@@ -245,9 +307,8 @@ class ReshapeObservationV0(LambdaObservationV0):
|
|||||||
shape=shape,
|
shape=shape,
|
||||||
dtype=env.observation_space.dtype,
|
dtype=env.observation_space.dtype,
|
||||||
)
|
)
|
||||||
super().__init__(
|
self.shape = shape
|
||||||
env, lambda obs: jp.reshape(obs, self.shape), new_observation_space
|
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RescaleObservationV0(LambdaObservationV0):
|
class RescaleObservationV0(LambdaObservationV0):
|
||||||
@@ -256,18 +317,23 @@ class RescaleObservationV0(LambdaObservationV0):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
min_obs: tuple[np.floating, np.integer, np.ndarray],
|
min_obs: np.floating | np.integer | np.ndarray,
|
||||||
max_obs: tuple[np.floating, np.integer, np.ndarray],
|
max_obs: np.floating | np.integer | np.ndarray,
|
||||||
):
|
):
|
||||||
"""Constructor that requires the env observation spaces to be a :class:`Box`."""
|
"""Constructor that requires the env observation spaces to be a :class:`Box`."""
|
||||||
assert isinstance(env.observation_space, spaces.Box)
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
|
assert not np.any(env.observation_space.low == np.inf) and not np.any(
|
||||||
|
env.observation_space.high == np.inf
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(min_obs, np.ndarray):
|
if not isinstance(min_obs, np.ndarray):
|
||||||
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
|
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
|
||||||
type(max_obs), np.floating
|
type(max_obs), np.floating
|
||||||
)
|
)
|
||||||
min_obs = np.full(env.observation_space.shape, min_obs)
|
min_obs = np.full(env.observation_space.shape, min_obs)
|
||||||
assert min_obs.shape == env.observation_space.shape
|
assert (
|
||||||
|
min_obs.shape == env.observation_space.shape
|
||||||
|
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
|
||||||
assert not np.any(min_obs == np.inf)
|
assert not np.any(min_obs == np.inf)
|
||||||
|
|
||||||
if not isinstance(max_obs, np.ndarray):
|
if not isinstance(max_obs, np.ndarray):
|
||||||
@@ -278,52 +344,66 @@ class RescaleObservationV0(LambdaObservationV0):
|
|||||||
assert max_obs.shape == env.observation_space.shape
|
assert max_obs.shape == env.observation_space.shape
|
||||||
assert not np.any(max_obs == np.inf)
|
assert not np.any(max_obs == np.inf)
|
||||||
|
|
||||||
env_low = env.observation_space.low
|
self.min_obs = min_obs
|
||||||
env_high = env.observation_space.high
|
self.max_obs = max_obs
|
||||||
|
|
||||||
|
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||||
|
gradient = (max_obs - min_obs) / (
|
||||||
|
env.observation_space.high - env.observation_space.low
|
||||||
|
)
|
||||||
|
intercept = gradient * -env.observation_space.low + min_obs
|
||||||
|
|
||||||
new_observation_space = spaces.Box(low=min_obs, high=max_obs)
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
env,
|
env,
|
||||||
lambda obs: env_low
|
lambda obs: gradient * obs + intercept,
|
||||||
+ (env_high - env_low) * ((obs - min_obs) / (max_obs - min_obs)),
|
Box(
|
||||||
new_observation_space,
|
low=min_obs,
|
||||||
|
high=max_obs,
|
||||||
|
shape=env.observation_space.shape,
|
||||||
|
dtype=env.observation_space.dtype,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DtypeObservationV0(LambdaObservationV0):
|
class DtypeObservationV0(LambdaObservationV0):
|
||||||
"""Observation wrapper for transforming the dtype of an observation."""
|
"""Observation wrapper for transforming the dtype of an observation."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, dtype: npt.DTypeLike):
|
def __init__(self, env: gym.Env, dtype: Any):
|
||||||
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
|
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
env.observation_space,
|
env.observation_space,
|
||||||
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
||||||
)
|
)
|
||||||
|
|
||||||
dtype = np.dtype(dtype)
|
self.dtype = dtype
|
||||||
if isinstance(env.observation_space, spaces.Box):
|
if isinstance(env.observation_space, spaces.Box):
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=env.observation_space.low,
|
low=env.observation_space.low,
|
||||||
high=env.observation_space.high,
|
high=env.observation_space.high,
|
||||||
shape=env.observation_space.shape,
|
shape=env.observation_space.shape,
|
||||||
dtype=dtype.__name__,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
elif isinstance(env.observation_space, spaces.Discrete):
|
elif isinstance(env.observation_space, spaces.Discrete):
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=env.observation_space.start,
|
low=env.observation_space.start,
|
||||||
high=env.observation_space.start + env.observation_space.n,
|
high=env.observation_space.start + env.observation_space.n,
|
||||||
shape=(),
|
shape=(),
|
||||||
dtype=dtype.__name__,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
elif isinstance(env.observation_space, spaces.MultiDiscrete):
|
elif isinstance(env.observation_space, spaces.MultiDiscrete):
|
||||||
new_observation_space = spaces.MultiDiscrete(
|
new_observation_space = spaces.MultiDiscrete(
|
||||||
env.observation_space.nvec, dtype=dtype.__name__
|
env.observation_space.nvec, dtype=dtype
|
||||||
)
|
)
|
||||||
elif isinstance(env.observation_space, spaces.MultiBinary):
|
elif isinstance(env.observation_space, spaces.MultiBinary):
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=0, high=1, shape=env.observation_space.shape, dtype=dtype.__name__
|
low=0,
|
||||||
|
high=1,
|
||||||
|
shape=env.observation_space.shape,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError(
|
||||||
|
"DtypeObservation is only compatible with value / array-based observations."
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
||||||
|
@@ -1,12 +1,17 @@
|
|||||||
"""Lambda reward wrappers which apply a function to the reward."""
|
"""A collection of wrappers for modifying the reward.
|
||||||
|
|
||||||
from typing import Any, Callable, Optional, Union
|
* ``LambdaReward`` - Transforms the reward by a function
|
||||||
|
* ``ClipReward`` - Clips the reward between a minimum and maximum value
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Callable, SupportsFloat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.error import InvalidBound
|
from gymnasium.error import InvalidBound
|
||||||
from gymnasium.experimental.wrappers import ArgType
|
|
||||||
|
|
||||||
|
|
||||||
class LambdaRewardV0(gym.RewardWrapper):
|
class LambdaRewardV0(gym.RewardWrapper):
|
||||||
@@ -26,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
func: Callable[[ArgType], Any],
|
func: Callable[[SupportsFloat], SupportsFloat],
|
||||||
):
|
):
|
||||||
"""Initialize LambdaRewardV0 wrapper.
|
"""Initialize LambdaRewardV0 wrapper.
|
||||||
|
|
||||||
@@ -38,7 +43,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
|||||||
|
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
def reward(self, reward: Union[float, int, np.ndarray]) -> Any:
|
def reward(self, reward: SupportsFloat) -> SupportsFloat:
|
||||||
"""Apply function to reward.
|
"""Apply function to reward.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -64,8 +69,8 @@ class ClipRewardV0(LambdaRewardV0):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
min_reward: Optional[Union[float, np.ndarray]] = None,
|
min_reward: float | np.ndarray | None = None,
|
||||||
max_reward: Optional[Union[float, np.ndarray]] = None,
|
max_reward: float | np.ndarray | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize ClipRewardsV0 wrapper.
|
"""Initialize ClipRewardsV0 wrapper.
|
||||||
|
|
||||||
|
@@ -6,70 +6,90 @@ import numbers
|
|||||||
from collections import abc
|
from collections import abc
|
||||||
from typing import Any, Iterable, Mapping, SupportsFloat
|
from typing import Any, Iterable, Mapping, SupportsFloat
|
||||||
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium import Env, Wrapper
|
from gymnasium import Env, Wrapper
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax.numpy as jnp
|
||||||
|
except ImportError:
|
||||||
|
# We handle the error internal to the relative functions
|
||||||
|
jnp = None
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def numpy_to_jax(value: Any) -> Any:
|
def numpy_to_jax(value: Any) -> Any:
|
||||||
"""Converts a value to a Jax DeviceArray."""
|
"""Converts a value to a Jax DeviceArray."""
|
||||||
raise Exception(
|
if jnp is None:
|
||||||
f"No conversion for Numpy to Jax registered for type: {type(value)}"
|
raise DependencyNotInstalled(
|
||||||
)
|
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(numbers.Number)
|
if jnp is not None:
|
||||||
@numpy_to_jax.register(np.ndarray)
|
|
||||||
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
|
|
||||||
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
|
||||||
return jnp.array(value)
|
|
||||||
|
|
||||||
|
@numpy_to_jax.register(numbers.Number)
|
||||||
|
@numpy_to_jax.register(np.ndarray)
|
||||||
|
def _number_ndarray_numpy_to_jax(
|
||||||
|
value: np.ndarray | numbers.Number,
|
||||||
|
) -> jnp.DeviceArray:
|
||||||
|
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
||||||
|
assert jnp is not None
|
||||||
|
return jnp.array(value)
|
||||||
|
|
||||||
@numpy_to_jax.register(abc.Mapping)
|
@numpy_to_jax.register(abc.Mapping)
|
||||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
@numpy_to_jax.register(abc.Iterable)
|
||||||
@numpy_to_jax.register(abc.Iterable)
|
def _iterable_numpy_to_jax(
|
||||||
def _iterable_numpy_to_jax(
|
value: Iterable[np.ndarray | Any],
|
||||||
value: Iterable[np.ndarray | Any],
|
) -> Iterable[jnp.DeviceArray | Any]:
|
||||||
) -> Iterable[jnp.DeviceArray | Any]:
|
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
|
||||||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
|
return type(value)(numpy_to_jax(v) for v in value)
|
||||||
return type(value)(numpy_to_jax(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def jax_to_numpy(value: Any) -> Any:
|
def jax_to_numpy(value: Any) -> Any:
|
||||||
"""Converts a value to a numpy array."""
|
"""Converts a value to a numpy array."""
|
||||||
raise Exception(
|
if jnp is None:
|
||||||
f"No conversion for Jax to Numpy registered for type: {type(value)}"
|
raise DependencyNotInstalled(
|
||||||
)
|
"Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@jax_to_numpy.register(jnp.DeviceArray)
|
if jnp is not None:
|
||||||
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
|
||||||
"""Converts a Jax DeviceArray to a numpy array."""
|
|
||||||
return np.array(value)
|
|
||||||
|
|
||||||
|
@jax_to_numpy.register(jnp.DeviceArray)
|
||||||
|
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
||||||
|
"""Converts a Jax DeviceArray to a numpy array."""
|
||||||
|
return np.array(value)
|
||||||
|
|
||||||
@jax_to_numpy.register(abc.Mapping)
|
@jax_to_numpy.register(abc.Mapping)
|
||||||
def _mapping_jax_to_numpy(
|
def _mapping_jax_to_numpy(
|
||||||
value: Mapping[str, jnp.DeviceArray | Any]
|
value: Mapping[str, jnp.DeviceArray | Any]
|
||||||
) -> Mapping[str, np.ndarray | Any]:
|
) -> Mapping[str, np.ndarray | Any]:
|
||||||
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
||||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
@jax_to_numpy.register(abc.Iterable)
|
||||||
@jax_to_numpy.register(abc.Iterable)
|
def _iterable_jax_to_numpy(
|
||||||
def _iterable_jax_to_numpy(
|
value: Iterable[np.ndarray | Any],
|
||||||
value: Iterable[np.ndarray | Any],
|
) -> Iterable[jnp.DeviceArray | Any]:
|
||||||
) -> Iterable[jnp.DeviceArray | Any]:
|
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
|
||||||
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
|
return type(value)(jax_to_numpy(v) for v in value)
|
||||||
return type(value)(jax_to_numpy(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
class JaxToNumpyV0(Wrapper):
|
class JaxToNumpyV0(Wrapper):
|
||||||
@@ -88,6 +108,10 @@ class JaxToNumpyV0(Wrapper):
|
|||||||
Args:
|
Args:
|
||||||
env: the environment to wrap
|
env: the environment to wrap
|
||||||
"""
|
"""
|
||||||
|
if jnp is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"Jax is not installed, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
|
56
gymnasium/experimental/wrappers/stateful_action.py
Normal file
56
gymnasium/experimental/wrappers/stateful_action.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""A collection of stateful action wrappers.
|
||||||
|
|
||||||
|
* StickyAction - There is a probability that the action is taken again
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, SupportsFloat
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import InvalidProbability
|
||||||
|
|
||||||
|
|
||||||
|
class StickyActionV0(gym.Wrapper):
|
||||||
|
"""Wrapper which adds a probability of repeating the previous action.
|
||||||
|
|
||||||
|
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
||||||
|
in Section 5.2 on page 12.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: gym.Env, repeat_action_probability: float):
|
||||||
|
"""Initialize StickyAction wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env (Env): the wrapped environment
|
||||||
|
repeat_action_probability (int | float): a probability of repeating the old action.
|
||||||
|
"""
|
||||||
|
if not 0 <= repeat_action_probability < 1:
|
||||||
|
raise InvalidProbability(
|
||||||
|
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(env)
|
||||||
|
self.repeat_action_probability = repeat_action_probability
|
||||||
|
self.last_action: WrapperActType | None = None
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Reset the environment."""
|
||||||
|
self.last_action = None
|
||||||
|
|
||||||
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: WrapperActType
|
||||||
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||||
|
"""Execute the action."""
|
||||||
|
if (
|
||||||
|
self.last_action is not None
|
||||||
|
and self.np_random.uniform() < self.repeat_action_probability
|
||||||
|
):
|
||||||
|
action = self.last_action
|
||||||
|
|
||||||
|
self.last_action = action
|
||||||
|
return action
|
200
gymnasium/experimental/wrappers/stateful_observation.py
Normal file
200
gymnasium/experimental/wrappers/stateful_observation.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""A collection of stateful observation wrappers.
|
||||||
|
|
||||||
|
* DelayObservation - A wrapper for delaying the returned observation
|
||||||
|
* TimeAwareObservation - A wrapper for adding time aware observations to environment observation
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any, SupportsFloat
|
||||||
|
from typing_extensions import Final
|
||||||
|
|
||||||
|
import jumpy as jp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import gymnasium.spaces as spaces
|
||||||
|
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||||
|
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class DelayObservationV0(gym.ObservationWrapper):
|
||||||
|
"""Wrapper which adds a delay to the returned observation."""
|
||||||
|
|
||||||
|
def __init__(self, env: gym.Env, delay: int):
|
||||||
|
"""Initialize the DelayObservation wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env (Env): the wrapped environment
|
||||||
|
delay (int): number of timesteps for delaying the observation.
|
||||||
|
Before reaching the `delay` number of timesteps,
|
||||||
|
returned observation is an array of zeros with the
|
||||||
|
same shape of the observation space.
|
||||||
|
"""
|
||||||
|
assert isinstance(env.observation_space, (Box, MultiBinary, MultiDiscrete))
|
||||||
|
assert 0 < delay
|
||||||
|
|
||||||
|
self.delay: Final[int] = delay
|
||||||
|
self.observation_queue: Final[deque] = deque()
|
||||||
|
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment, clearing the observation queue."""
|
||||||
|
self.observation_queue.clear()
|
||||||
|
|
||||||
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
|
def observation(self, observation: ObsType) -> ObsType:
|
||||||
|
"""Return the delayed observation."""
|
||||||
|
self.observation_queue.append(observation)
|
||||||
|
|
||||||
|
if len(self.observation_queue) > self.delay:
|
||||||
|
return self.observation_queue.popleft()
|
||||||
|
|
||||||
|
return jp.zeros_like(observation)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeAwareObservationV0(gym.ObservationWrapper):
|
||||||
|
"""Augment the observation with time information of the episode.
|
||||||
|
|
||||||
|
Time can be represented as a normalized value between [0,1]
|
||||||
|
or by the number of timesteps remaining before truncation occurs.
|
||||||
|
|
||||||
|
For environments with ``Dict`` or ``Tuple`` observation spaces, by default,
|
||||||
|
the time information is automatically added in the key `"time"` and
|
||||||
|
as the final element in the tuple.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> import gymnasium as gym
|
||||||
|
>>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
||||||
|
>>> env = gym.make('CartPole-v1')
|
||||||
|
>>> env = TimeAwareObservationV0(env)
|
||||||
|
>>> env.observation_space
|
||||||
|
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
|
||||||
|
>>> _ = env.reset()
|
||||||
|
>>> env.step(env.action_space.sample())[0]
|
||||||
|
OrderedDict([('obs',
|
||||||
|
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
|
||||||
|
... ('time', array([0.002]))])
|
||||||
|
|
||||||
|
Flatten observation space example:
|
||||||
|
>>> env = gym.make('CartPole-v1')
|
||||||
|
>>> env = TimeAwareObservationV0(env, flatten=True)
|
||||||
|
>>> env.observation_space
|
||||||
|
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
|
||||||
|
>>> _ = env.reset()
|
||||||
|
>>> env.step(env.action_space.sample())[0]
|
||||||
|
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: gym.Env,
|
||||||
|
flatten: bool = False,
|
||||||
|
normalize_time: bool = True,
|
||||||
|
*,
|
||||||
|
dict_time_key: str = "time",
|
||||||
|
):
|
||||||
|
"""Initialize :class:`TimeAwareObservationV0`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to apply the wrapper
|
||||||
|
flatten: Flatten the observation to a `Box` of a single dimension
|
||||||
|
normalize_time: if `True` return time in the range [0,1]
|
||||||
|
otherwise return time as remaining timesteps before truncation
|
||||||
|
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
self.flatten: Final[bool] = flatten
|
||||||
|
self.normalize_time: Final[bool] = normalize_time
|
||||||
|
|
||||||
|
if hasattr(env, "_max_episode_steps"):
|
||||||
|
self.max_timesteps = getattr(env, "_max_episode_steps")
|
||||||
|
elif env.spec is not None and env.spec.max_episode_steps is not None:
|
||||||
|
self.max_timesteps = env.spec.max_episode_steps
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.timesteps: int = 0
|
||||||
|
|
||||||
|
# Find the normalized time space
|
||||||
|
if self.normalize_time:
|
||||||
|
self._time_preprocess_func = lambda time: time / self.max_timesteps
|
||||||
|
time_space = Box(0.0, 1.0)
|
||||||
|
else:
|
||||||
|
self._time_preprocess_func = lambda time: self.max_timesteps - time
|
||||||
|
time_space = Box(0, self.max_timesteps, dtype=np.int32)
|
||||||
|
|
||||||
|
# Find the observation space
|
||||||
|
if isinstance(env.observation_space, Dict):
|
||||||
|
assert dict_time_key not in env.observation_space.keys()
|
||||||
|
observation_space = Dict(
|
||||||
|
{dict_time_key: time_space}, **env.observation_space.spaces
|
||||||
|
)
|
||||||
|
self._append_data_func = lambda obs, time: {**obs, dict_time_key: time}
|
||||||
|
elif isinstance(env.observation_space, Tuple):
|
||||||
|
observation_space = Tuple(env.observation_space.spaces + (time_space,))
|
||||||
|
self._append_data_func = lambda obs, time: obs + (time,)
|
||||||
|
else:
|
||||||
|
observation_space = Dict(obs=env.observation_space, time=time_space)
|
||||||
|
self._append_data_func = lambda obs, time: {"obs": obs, "time": time}
|
||||||
|
|
||||||
|
# If to flatten the observation space
|
||||||
|
if self.flatten:
|
||||||
|
self.observation_space = spaces.flatten_space(observation_space)
|
||||||
|
self._obs_postprocess_func = lambda obs: spaces.flatten(
|
||||||
|
observation_space, obs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.observation_space = observation_space
|
||||||
|
self._obs_postprocess_func = lambda obs: obs
|
||||||
|
|
||||||
|
def observation(self, observation: ObsType) -> WrapperObsType:
|
||||||
|
"""Adds to the observation with the current time information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation: The observation to add the time step to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The observation with the time information appended to
|
||||||
|
"""
|
||||||
|
return self._obs_postprocess_func(
|
||||||
|
self._append_data_func(
|
||||||
|
observation, self._time_preprocess_func(self.timesteps)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: ActType
|
||||||
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||||
|
"""Steps through the environment, incrementing the time step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action to take
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The environment's step using the action.
|
||||||
|
"""
|
||||||
|
self.timesteps += 1
|
||||||
|
return super().step(action)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Reset the environment setting the time to zero.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed to reset the environment
|
||||||
|
options: The options used to reset the environment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The reset environment
|
||||||
|
"""
|
||||||
|
self.timesteps = 0
|
||||||
|
|
||||||
|
return super().reset(seed=seed, options=options)
|
@@ -1,40 +0,0 @@
|
|||||||
"""Wrapper which adds a probability of repeating the previous executed action."""
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.core import ActType
|
|
||||||
from gymnasium.error import InvalidProbability
|
|
||||||
|
|
||||||
|
|
||||||
class StickyActionV0(gym.ActionWrapper):
|
|
||||||
"""Wrapper which adds a probability of repeating the previous action."""
|
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]):
|
|
||||||
"""Initialize StickyAction wrapper.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env (Env): the wrapped environment
|
|
||||||
repeat_action_probability (int | float): a proability of repeating the old action.
|
|
||||||
"""
|
|
||||||
if not 0 <= repeat_action_probability < 1:
|
|
||||||
raise InvalidProbability(
|
|
||||||
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
|
||||||
)
|
|
||||||
super().__init__(env)
|
|
||||||
self.repeat_action_probability = repeat_action_probability
|
|
||||||
self.old_action = None
|
|
||||||
|
|
||||||
def action(self, action: ActType):
|
|
||||||
"""Execute the action."""
|
|
||||||
if (
|
|
||||||
self.old_action is not None
|
|
||||||
and self.np_random.uniform() < self.repeat_action_probability
|
|
||||||
):
|
|
||||||
action = self.old_action
|
|
||||||
self.old_action = action
|
|
||||||
return action
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
|
||||||
"""Reset the environment."""
|
|
||||||
self.old_action = None
|
|
||||||
return super().reset(**kwargs)
|
|
@@ -1,113 +0,0 @@
|
|||||||
"""Wrapper for adding time aware observations to environment observation."""
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
import gymnasium.spaces as spaces
|
|
||||||
from gymnasium.core import ActType, ObsType
|
|
||||||
from gymnasium.spaces import Box, Dict
|
|
||||||
|
|
||||||
|
|
||||||
class TimeAwareObservationV0(gym.ObservationWrapper):
|
|
||||||
"""Augment the observation with time information of the episode.
|
|
||||||
|
|
||||||
Time can be represented as a normalized value between [0,1]
|
|
||||||
or by the number of timesteps remaining before truncation occurs.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> import gym
|
|
||||||
>>> from gym.wrappers import TimeAwareObservationV0
|
|
||||||
>>> env = gym.make('CartPole-v1')
|
|
||||||
>>> env = TimeAwareObservationV0(env)
|
|
||||||
>>> env.observation_space
|
|
||||||
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
|
|
||||||
>>> _ = env.reset()
|
|
||||||
>>> env.step(env.action_space.sample())[0]
|
|
||||||
OrderedDict([('obs',
|
|
||||||
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
|
|
||||||
... ('time', array([0.002]))])
|
|
||||||
|
|
||||||
Flatten observation space example:
|
|
||||||
>>> env = gym.make('CartPole-v1')
|
|
||||||
>>> env = TimeAwareObservationV0(env, flatten=True)
|
|
||||||
>>> env.observation_space
|
|
||||||
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
|
|
||||||
>>> _ = env.reset()
|
|
||||||
>>> env.step(env.action_space.sample())[0]
|
|
||||||
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, flatten=False, normalize_time=True):
|
|
||||||
"""Initialize :class:`TimeAwareObservationV0`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env: The environment to apply the wrapper
|
|
||||||
flatten: Flatten the observation to a `Box` of a single dimension
|
|
||||||
normalize_time: if `True` return time in the range [0,1]
|
|
||||||
otherwise return time as remaining timesteps before truncation
|
|
||||||
"""
|
|
||||||
super().__init__(env)
|
|
||||||
self.flatten = flatten
|
|
||||||
self.normalize_time = normalize_time
|
|
||||||
self.max_timesteps = getattr(env, "_max_episode_steps")
|
|
||||||
|
|
||||||
if self.normalize_time:
|
|
||||||
self._get_time_observation = lambda: self.timesteps / self.max_timesteps
|
|
||||||
time_space = Box(0, 1)
|
|
||||||
else:
|
|
||||||
self._get_time_observation = lambda: self.max_timesteps - self.timesteps
|
|
||||||
time_space = Box(0, self.max_timesteps)
|
|
||||||
|
|
||||||
self.time_aware_observation_space = Dict(
|
|
||||||
obs=env.observation_space, time=time_space
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.flatten:
|
|
||||||
self.observation_space = spaces.flatten_space(
|
|
||||||
self.time_aware_observation_space
|
|
||||||
)
|
|
||||||
self._observation_postprocess = lambda observation: spaces.flatten(
|
|
||||||
self.time_aware_observation_space, observation
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.observation_space = self.time_aware_observation_space
|
|
||||||
self._observation_postprocess = lambda observation: observation
|
|
||||||
|
|
||||||
def observation(self, observation: ObsType):
|
|
||||||
"""Adds to the observation with the current time information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
observation: The observation to add the time step to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The observation with the time information appended to
|
|
||||||
"""
|
|
||||||
time_observation = self._get_time_observation()
|
|
||||||
observation = OrderedDict(obs=observation, time=time_observation)
|
|
||||||
|
|
||||||
return self._observation_postprocess(observation)
|
|
||||||
|
|
||||||
def step(self, action: ActType):
|
|
||||||
"""Steps through the environment, incrementing the time step.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: The action to take
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The environment's step using the action.
|
|
||||||
"""
|
|
||||||
self.timesteps += 1
|
|
||||||
observation, reward, terminated, truncated, info = super().step(action)
|
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
|
||||||
"""Reset the environment setting the time to zero.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: Kwargs to apply to env.reset()
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The reset environment
|
|
||||||
"""
|
|
||||||
self.timesteps = 0
|
|
||||||
return super().reset(**kwargs)
|
|
@@ -14,92 +14,121 @@ import numbers
|
|||||||
from collections import abc
|
from collections import abc
|
||||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||||
|
|
||||||
import jax.numpy as jnp
|
|
||||||
from jax import dlpack as jax_dlpack
|
|
||||||
|
|
||||||
from gymnasium import Env, Wrapper
|
from gymnasium import Env, Wrapper
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
|
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import dlpack as jax_dlpack
|
||||||
|
except ImportError:
|
||||||
|
jnp, jax_dlpack = None, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
from torch.utils import dlpack as torch_dlpack
|
from torch.utils import dlpack as torch_dlpack
|
||||||
|
|
||||||
|
Device = Union[str, torch.device]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise DependencyNotInstalled("torch is not installed, run `pip install torch`")
|
torch, torch_dlpack, Device = None, None, None
|
||||||
|
|
||||||
|
|
||||||
Device = Union[str, torch.device]
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def torch_to_jax(value: Any) -> Any:
|
def torch_to_jax(value: Any) -> Any:
|
||||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||||
raise Exception(
|
if torch is None:
|
||||||
f"No conversion for PyTorch to Jax registered for type: {type(value)}"
|
raise DependencyNotInstalled(
|
||||||
)
|
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
|
||||||
|
)
|
||||||
|
elif jnp is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch_to_jax.register(numbers.Number)
|
if torch is not None and jnp is not None:
|
||||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
|
||||||
return jnp.array(value)
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(numbers.Number)
|
||||||
|
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||||
|
"""Convert a python number (int, float, complex) to a jax array."""
|
||||||
|
assert jnp is not None
|
||||||
|
return jnp.array(value)
|
||||||
|
|
||||||
@torch_to_jax.register(torch.Tensor)
|
@torch_to_jax.register(torch.Tensor)
|
||||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
||||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||||
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
assert torch_dlpack is not None and jax_dlpack is not None
|
||||||
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
|
tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||||
return tensor
|
value
|
||||||
|
)
|
||||||
|
tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
tensor
|
||||||
|
)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
@torch_to_jax.register(abc.Mapping)
|
||||||
|
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||||
|
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||||
|
|
||||||
@torch_to_jax.register(abc.Mapping)
|
@torch_to_jax.register(abc.Iterable)
|
||||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
return type(value)(torch_to_jax(v) for v in value)
|
||||||
|
|
||||||
|
|
||||||
@torch_to_jax.register(abc.Iterable)
|
|
||||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
|
||||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
|
||||||
return type(value)(torch_to_jax(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
raise Exception(
|
if torch is None:
|
||||||
f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}"
|
raise DependencyNotInstalled(
|
||||||
)
|
"Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`"
|
||||||
|
)
|
||||||
|
elif jnp is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@jax_to_torch.register(jnp.DeviceArray)
|
if torch is not None and jnp is not None:
|
||||||
def _devicearray_jax_to_torch(
|
|
||||||
value: jnp.DeviceArray, device: Device | None = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
|
||||||
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
|
||||||
tensor = torch_dlpack.from_dlpack(dlpack)
|
|
||||||
if device:
|
|
||||||
return tensor.to(device=device)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
@jax_to_torch.register(jnp.DeviceArray)
|
||||||
|
def _devicearray_jax_to_torch(
|
||||||
|
value: jnp.DeviceArray, device: Device | None = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
|
assert jax_dlpack is not None and torch_dlpack is not None
|
||||||
|
dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
value
|
||||||
|
)
|
||||||
|
tensor = torch_dlpack.from_dlpack(dlpack)
|
||||||
|
if device:
|
||||||
|
return tensor.to(device=device)
|
||||||
|
return tensor
|
||||||
|
|
||||||
@jax_to_torch.register(abc.Mapping)
|
@jax_to_torch.register(abc.Mapping)
|
||||||
def _jax_mapping_to_torch(
|
def _jax_mapping_to_torch(
|
||||||
value: Mapping[str, Any], device: Device | None = None
|
value: Mapping[str, Any], device: Device | None = None
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||||
|
|
||||||
|
@jax_to_torch.register(abc.Iterable)
|
||||||
@jax_to_torch.register(abc.Iterable)
|
def _jax_iterable_to_torch(
|
||||||
def _jax_iterable_to_torch(
|
value: Iterable[Any], device: Device | None = None
|
||||||
value: Iterable[Any], device: Device | None = None
|
) -> Iterable[Any]:
|
||||||
) -> Iterable[Any]:
|
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
||||||
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
return type(value)(jax_to_torch(v, device) for v in value)
|
||||||
return type(value)(jax_to_torch(v, device) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
class JaxToTorchV0(Wrapper):
|
class JaxToTorchV0(Wrapper):
|
||||||
@@ -107,7 +136,8 @@ class JaxToTorchV0(Wrapper):
|
|||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
|
|
||||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
Note:
|
||||||
|
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env, device: Device | None = None):
|
def __init__(self, env: Env, device: Device | None = None):
|
||||||
@@ -117,6 +147,15 @@ class JaxToTorchV0(Wrapper):
|
|||||||
env: The Jax-based environment to wrap
|
env: The Jax-based environment to wrap
|
||||||
device: The device the torch Tensors should be moved to
|
device: The device the torch Tensors should be moved to
|
||||||
"""
|
"""
|
||||||
|
if torch is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"Torch is not installed, run `pip install torch`"
|
||||||
|
)
|
||||||
|
elif jnp is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"Jax is not installed, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.device: Device | None = device
|
self.device: Device | None = device
|
||||||
|
|
||||||
|
@@ -76,6 +76,7 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
|
|||||||
* ``None`` The length will be randomly drawn from a geometric distribution
|
* ``None`` The length will be randomly drawn from a geometric distribution
|
||||||
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
|
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
|
||||||
* ``int`` for a fixed length sample
|
* ``int`` for a fixed length sample
|
||||||
|
|
||||||
The second element of the mask tuple `sample` mask specifies a mask that is applied when
|
The second element of the mask tuple `sample` mask specifies a mask that is applied when
|
||||||
sampling elements from the base space. The mask is applied for each feature space sample.
|
sampling elements from the base space. The mask is applied for each feature space sample.
|
||||||
|
|
||||||
|
@@ -29,7 +29,7 @@ dependencies = [
|
|||||||
"jax-jumpy >=0.2.0",
|
"jax-jumpy >=0.2.0",
|
||||||
"cloudpickle >=1.2.0",
|
"cloudpickle >=1.2.0",
|
||||||
"importlib-metadata >=4.8.0; python_version < '3.10'",
|
"importlib-metadata >=4.8.0; python_version < '3.10'",
|
||||||
"typing-extensions >=4.3.0; python_version == '3.7'",
|
"typing-extensions >=4.3.0",
|
||||||
"gymnasium-notices >=0.0.1",
|
"gymnasium-notices >=0.0.1",
|
||||||
"shimmy >=0.1.0,<1.0",
|
"shimmy >=0.1.0,<1.0",
|
||||||
]
|
]
|
||||||
|
@@ -19,7 +19,7 @@ from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
|||||||
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
|
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
|
||||||
from tests.envs.utils import all_testing_env_specs
|
from tests.envs.utils import all_testing_env_specs
|
||||||
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
||||||
from tests.testing_env import GenericTestEnv, old_step_fn
|
from tests.testing_env import GenericTestEnv, old_step_func
|
||||||
from tests.wrappers.utils import has_wrapper
|
from tests.wrappers.utils import has_wrapper
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +155,7 @@ def test_make_disable_env_checker():
|
|||||||
def test_apply_api_compatibility():
|
def test_apply_api_compatibility():
|
||||||
gym.register(
|
gym.register(
|
||||||
"testing-old-env",
|
"testing-old-env",
|
||||||
lambda: GenericTestEnv(step_fn=old_step_fn),
|
lambda: GenericTestEnv(step_func=old_step_func),
|
||||||
apply_api_compatibility=True,
|
apply_api_compatibility=True,
|
||||||
max_episode_steps=3,
|
max_episode_steps=3,
|
||||||
)
|
)
|
||||||
|
@@ -1,48 +0,0 @@
|
|||||||
"""Test suite for LambdaActionV0."""
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.experimental.wrappers import ClipActionV0
|
|
||||||
|
|
||||||
|
|
||||||
SEED = 42
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("env", "action_unclipped_env", "action_clipped_env"),
|
|
||||||
(
|
|
||||||
[
|
|
||||||
# MountainCar action space: Box(-1.0, 1.0, (1,), float32)
|
|
||||||
gym.make("MountainCarContinuous-v0"),
|
|
||||||
np.array([1]),
|
|
||||||
np.array([1.5]),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
|
||||||
gym.make("BipedalWalker-v3"),
|
|
||||||
np.array([1, 1, 1, 1]),
|
|
||||||
np.array([10, 10, 10, 10]),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
|
||||||
gym.make("BipedalWalker-v3"),
|
|
||||||
np.array([0.5, 0.5, 1, 1]),
|
|
||||||
np.array([0.5, 0.5, 10, 10]),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env):
|
|
||||||
"""Tests if actions out of bound are correctly clipped.
|
|
||||||
|
|
||||||
Tests whether out of bound actions for the wrapped
|
|
||||||
environments are correctly clipped.
|
|
||||||
"""
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
obs, _, _, _, _ = env.step(action_unclipped_env)
|
|
||||||
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
wrapped_env = ClipActionV0(env)
|
|
||||||
wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env)
|
|
||||||
|
|
||||||
assert np.alltrue(obs == wrapped_obs)
|
|
@@ -1,37 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.experimental.wrappers import DelayObservationV0
|
|
||||||
|
|
||||||
|
|
||||||
SEED = 42
|
|
||||||
|
|
||||||
DELAY = 3
|
|
||||||
NUM_STEPS = 5
|
|
||||||
|
|
||||||
|
|
||||||
def test_delay_observation():
|
|
||||||
env = gym.make("CartPole-v1")
|
|
||||||
env.action_space.seed(SEED)
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
|
|
||||||
undelayed_observations = []
|
|
||||||
for _ in range(NUM_STEPS):
|
|
||||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
|
||||||
undelayed_observations.append(obs)
|
|
||||||
|
|
||||||
env.action_space.seed(SEED)
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
env = DelayObservationV0(env, delay=DELAY)
|
|
||||||
|
|
||||||
delayed_observations = []
|
|
||||||
for i in range(NUM_STEPS):
|
|
||||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
|
||||||
if i < DELAY - 1:
|
|
||||||
assert np.all(obs == 0)
|
|
||||||
delayed_observations.append(obs)
|
|
||||||
|
|
||||||
assert np.alltrue(
|
|
||||||
np.array(delayed_observations[DELAY:])
|
|
||||||
== np.array(undelayed_observations[: DELAY - 1])
|
|
||||||
)
|
|
@@ -1,60 +1,78 @@
|
|||||||
"""Test suite for LambdaActionV0."""
|
"""Test suit for lambda action wrappers: LambdaAction, ClipAction, RescaleAction."""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
import gymnasium as gym
|
from gymnasium.experimental.wrappers import (
|
||||||
from gymnasium.error import InvalidAction
|
ClipActionV0,
|
||||||
from gymnasium.experimental.wrappers import LambdaActionV0
|
LambdaActionV0,
|
||||||
|
RescaleActionV0,
|
||||||
|
)
|
||||||
from gymnasium.spaces import Box
|
from gymnasium.spaces import Box
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
NUM_ENVS = 3
|
SEED = 42
|
||||||
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
|
|
||||||
|
|
||||||
|
|
||||||
def generic_step_fn(self, action):
|
def _record_action_step_func(self, action):
|
||||||
return 0, 0, False, False, {"action": action}
|
return 0, 0, False, False, {"action": action}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def test_lambda_action_wrapper():
|
||||||
("env", "func", "action", "expected"),
|
"""Tests LambdaAction through checking that the action taken is transformed by function."""
|
||||||
[
|
env = GenericTestEnv(step_func=_record_action_step_func)
|
||||||
(
|
wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
|
||||||
GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn),
|
|
||||||
lambda action: action + 2,
|
|
||||||
1,
|
|
||||||
3,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_lambda_action_v0(env, func, action, expected):
|
|
||||||
"""Tests lambda action.
|
|
||||||
Tests if function is correctly applied to environment's action.
|
|
||||||
"""
|
|
||||||
wrapped_env = LambdaActionV0(env, func)
|
|
||||||
_, _, _, _, info = wrapped_env.step(action)
|
|
||||||
executed_action = info["action"]
|
|
||||||
|
|
||||||
assert executed_action == expected
|
sampled_action = wrapped_env.action_space.sample()
|
||||||
|
assert sampled_action not in env.action_space
|
||||||
|
|
||||||
|
_, _, _, _, info = wrapped_env.step(sampled_action)
|
||||||
|
assert info["action"] in env.action_space
|
||||||
|
assert sampled_action - 2 == info["action"]
|
||||||
|
|
||||||
|
|
||||||
def test_lambda_action_v0_within_vector():
|
def test_clip_action_wrapper():
|
||||||
"""Tests lambda action in vectorized environments.
|
"""Test that the action is correctly clipped to the base environment action space."""
|
||||||
Tests if function is correctly applied to environment's action
|
env = GenericTestEnv(
|
||||||
in vectorized environment.
|
action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])),
|
||||||
"""
|
step_func=_record_action_step_func,
|
||||||
env = gym.vector.make(
|
|
||||||
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
|
|
||||||
)
|
)
|
||||||
action = np.ones(NUM_ENVS, dtype=np.float64)
|
wrapped_env = ClipActionV0(env)
|
||||||
|
|
||||||
wrapped_env = LambdaActionV0(env, lambda action: action.astype(int))
|
sampled_action = np.array([-1, 5, 3.5], dtype=np.float32)
|
||||||
wrapped_env.reset()
|
assert sampled_action not in env.action_space
|
||||||
|
assert sampled_action in wrapped_env.action_space
|
||||||
|
|
||||||
wrapped_env.step(action)
|
_, _, _, _, info = wrapped_env.step(sampled_action)
|
||||||
|
assert np.all(info["action"] in env.action_space)
|
||||||
|
assert np.all(info["action"] == np.array([0, 2, 3.5]))
|
||||||
|
|
||||||
# unwrapped env should raise exception because it does not
|
|
||||||
# support float actions
|
def test_rescale_action_wrapper():
|
||||||
with pytest.raises(InvalidAction):
|
"""Test that the action is rescale within a min / max bound."""
|
||||||
env.step(action)
|
env = GenericTestEnv(
|
||||||
|
step_func=_record_action_step_func,
|
||||||
|
action_space=Box(np.array([0, 1]), np.array([1, 3])),
|
||||||
|
)
|
||||||
|
wrapped_env = RescaleActionV0(
|
||||||
|
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
|
||||||
|
)
|
||||||
|
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
|
||||||
|
|
||||||
|
for sample_action, expected_action in (
|
||||||
|
(
|
||||||
|
np.array([0.0, 0.5], dtype=np.float32),
|
||||||
|
np.array([0.5, 2.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([-5.0, 0.0], dtype=np.float32),
|
||||||
|
np.array([0.0, 1.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([5.0, 1.0], dtype=np.float32),
|
||||||
|
np.array([1.0, 3.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
assert sample_action in wrapped_env.action_space
|
||||||
|
|
||||||
|
_, _, _, _, info = wrapped_env.step(sample_action)
|
||||||
|
assert np.all(info["action"] == expected_action)
|
||||||
|
@@ -1,59 +1,250 @@
|
|||||||
"""Test suite for LambdaObservationV0."""
|
"""Test suite for lambda observation wrappers: """
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.experimental.wrappers import LambdaObservationV0
|
from gymnasium.experimental.wrappers import (
|
||||||
from gymnasium.spaces import Box
|
DtypeObservationV0,
|
||||||
|
FilterObservationV0,
|
||||||
|
FlattenObservationV0,
|
||||||
|
GrayscaleObservationV0,
|
||||||
|
LambdaObservationV0,
|
||||||
|
RescaleObservationV0,
|
||||||
|
ReshapeObservationV0,
|
||||||
|
ResizeObservationV0,
|
||||||
|
)
|
||||||
|
from gymnasium.spaces import Box, Dict, Tuple
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
NUM_ENVS = 3
|
|
||||||
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
|
|
||||||
|
|
||||||
SEED = 42
|
SEED = 42
|
||||||
DISCRETE_ACTION = 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_lambda_observation_v0():
|
def _record_random_obs_reset(self: gym.Env, seed=None, options=None):
|
||||||
"""Tests lambda observation.
|
obs = self.observation_space.sample()
|
||||||
|
return obs, {"obs": obs}
|
||||||
|
|
||||||
Tests if function is correctly applied to environment's observation.
|
|
||||||
"""
|
|
||||||
env = gym.make("CartPole-v1")
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
obs, _, _, _, _ = env.step(DISCRETE_ACTION)
|
|
||||||
|
|
||||||
observation_shift = 1
|
def _record_random_obs_step(self: gym.Env, action):
|
||||||
|
obs = self.observation_space.sample()
|
||||||
|
return obs, 0, False, False, {"obs": obs}
|
||||||
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
wrapped_env = LambdaObservationV0(
|
def _record_action_obs_reset(self: gym.Env, seed=None, options: dict = {}):
|
||||||
env, lambda observation: observation + observation_shift, None
|
return options["obs"], {"obs": options["obs"]}
|
||||||
|
|
||||||
|
|
||||||
|
def _record_action_obs_step(self: gym.Env, action):
|
||||||
|
return action, 0, False, False, {"obs": action}
|
||||||
|
|
||||||
|
|
||||||
|
def _check_obs(
|
||||||
|
env: gym.Env,
|
||||||
|
wrapped_env: gym.Wrapper,
|
||||||
|
transformed_obs,
|
||||||
|
original_obs,
|
||||||
|
strict: bool = True,
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
transformed_obs in wrapped_env.observation_space
|
||||||
|
), f"{transformed_obs}, {wrapped_env.observation_space}"
|
||||||
|
assert (
|
||||||
|
original_obs in env.observation_space
|
||||||
|
), f"{original_obs}, {env.observation_space}"
|
||||||
|
|
||||||
|
if strict:
|
||||||
|
assert (
|
||||||
|
transformed_obs not in env.observation_space
|
||||||
|
), f"{transformed_obs}, {env.observation_space}"
|
||||||
|
assert (
|
||||||
|
original_obs not in wrapped_env.observation_space
|
||||||
|
), f"{original_obs}, {wrapped_env.observation_space}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lambda_observation_wrapper():
|
||||||
|
"""Tests lambda observation that the function is applied to both the reset and step observation."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
reset_func=_record_action_obs_reset, step_func=_record_action_obs_step
|
||||||
)
|
)
|
||||||
wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION)
|
wrapped_env = LambdaObservationV0(env, lambda obs: obs + 2, Box(2, 3))
|
||||||
|
|
||||||
assert np.alltrue(wrapped_obs == obs + observation_shift)
|
obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)})
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(np.array([1], dtype=np.float32))
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
|
||||||
def test_lambda_observation_v0_within_vector():
|
def test_filter_observation_wrapper():
|
||||||
"""Tests lambda observation in vectorized environments.
|
"""Tests ``FilterObservation`` that the right keys are filtered."""
|
||||||
|
dict_env = GenericTestEnv(
|
||||||
Tests if function is correctly applied to environment's observation
|
observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)),
|
||||||
in vectorized environment.
|
reset_func=_record_random_obs_reset,
|
||||||
"""
|
step_func=_record_random_obs_step,
|
||||||
env = gym.vector.make(
|
|
||||||
"CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
|
|
||||||
)
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
obs, _, _, _, _ = env.step(np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)]))
|
|
||||||
|
|
||||||
observation_shift = 1
|
|
||||||
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
wrapped_env = LambdaObservationV0(
|
|
||||||
env, lambda observation: observation + observation_shift, None
|
|
||||||
)
|
|
||||||
wrapped_obs, _, _, _, _ = wrapped_env.step(
|
|
||||||
np.array([DISCRETE_ACTION for _ in range(NUM_ENVS)])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.alltrue(wrapped_obs == obs + observation_shift)
|
wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3"))
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
assert list(obs.keys()) == ["arm_1", "arm_3"]
|
||||||
|
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
|
||||||
|
_check_obs(dict_env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
assert list(obs.keys()) == ["arm_1", "arm_3"]
|
||||||
|
assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"]
|
||||||
|
_check_obs(dict_env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
# Test tuple environments
|
||||||
|
tuple_env = GenericTestEnv(
|
||||||
|
observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))),
|
||||||
|
reset_func=_record_random_obs_reset,
|
||||||
|
step_func=_record_random_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = FilterObservationV0(tuple_env, (2,))
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
assert len(obs) == 1 and len(info["obs"]) == 3
|
||||||
|
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
assert len(obs) == 1 and len(info["obs"]) == 3
|
||||||
|
_check_obs(tuple_env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_flatten_observation_wrapper():
|
||||||
|
"""Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)),
|
||||||
|
reset_func=_record_random_obs_reset,
|
||||||
|
step_func=_record_random_obs_step,
|
||||||
|
)
|
||||||
|
print(env.observation_space)
|
||||||
|
wrapped_env = FlattenObservationV0(env)
|
||||||
|
print(wrapped_env.observation_space)
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_grayscale_observation_wrapper():
|
||||||
|
"""Tests the ``GrayscaleObservation`` that the observation is grayscale."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
|
||||||
|
reset_func=_record_random_obs_reset,
|
||||||
|
step_func=_record_random_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = GrayscaleObservationV0(env)
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
assert obs.shape == (25, 25)
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
# Keep_dim
|
||||||
|
wrapped_env = GrayscaleObservationV0(env, keep_dim=True)
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
assert obs.shape == (25, 25, 1)
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_observation_wrapper():
|
||||||
|
"""Test the ``ResizeObservation`` that the observation has changed size"""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
|
||||||
|
reset_func=_record_random_obs_reset,
|
||||||
|
step_func=_record_random_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = ResizeObservationV0(env, (25, 25))
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_reshape_observation_wrapper():
|
||||||
|
"""Test the ``ReshapeObservation`` wrapper."""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Box(0, 1, shape=(2, 3, 2)),
|
||||||
|
reset_func=_record_random_obs_reset,
|
||||||
|
step_func=_record_random_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = ReshapeObservationV0(env, (6, 2))
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
assert obs.shape == (6, 2)
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"])
|
||||||
|
assert obs.shape == (6, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescale_observation():
|
||||||
|
"""Test the ``RescaleObservation`` wrapper"""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
observation_space=Box(
|
||||||
|
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
|
||||||
|
),
|
||||||
|
reset_func=_record_action_obs_reset,
|
||||||
|
step_func=_record_action_obs_step,
|
||||||
|
)
|
||||||
|
wrapped_env = RescaleObservationV0(
|
||||||
|
env,
|
||||||
|
min_obs=np.array([-5, 0], dtype=np.float32),
|
||||||
|
max_obs=np.array([5, 1], dtype=np.float32),
|
||||||
|
)
|
||||||
|
assert wrapped_env.observation_space == Box(
|
||||||
|
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
for sample_obs, expected_obs in (
|
||||||
|
(
|
||||||
|
np.array([0.5, 2.0], dtype=np.float32),
|
||||||
|
np.array([0.0, 0.5], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([0.0, 1.0], dtype=np.float32),
|
||||||
|
np.array([-5.0, 0.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
np.array([1.0, 3.0], dtype=np.float32),
|
||||||
|
np.array([5.0, 1.0], dtype=np.float32),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
assert sample_obs in env.observation_space
|
||||||
|
assert expected_obs in wrapped_env.observation_space
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset(options={"obs": sample_obs})
|
||||||
|
assert np.all(obs == expected_obs)
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(sample_obs)
|
||||||
|
assert np.all(obs == expected_obs)
|
||||||
|
_check_obs(env, wrapped_env, obs, info["obs"], strict=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dtype_observation():
|
||||||
|
"""Test ``DtypeObservation`` that the"""
|
||||||
|
env = GenericTestEnv(
|
||||||
|
reset_func=_record_random_obs_reset, step_func=_record_random_obs_step
|
||||||
|
)
|
||||||
|
wrapped_env = DtypeObservationV0(env, dtype=np.uint8)
|
||||||
|
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
assert obs.dtype != info["obs"].dtype
|
||||||
|
assert obs.dtype == np.uint8
|
||||||
|
|
||||||
|
obs, _, _, _, info = wrapped_env.step(None)
|
||||||
|
assert obs.dtype != info["obs"].dtype
|
||||||
|
assert obs.dtype == np.uint8
|
||||||
|
@@ -55,7 +55,7 @@ def jax_step_func(self, action):
|
|||||||
|
|
||||||
|
|
||||||
def test_jax_to_numpy():
|
def test_jax_to_numpy():
|
||||||
jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
|
jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)
|
||||||
|
|
||||||
# Check that the reset and step for jax environment are as expected
|
# Check that the reset and step for jax environment are as expected
|
||||||
obs, info = jax_env.reset()
|
obs, info = jax_env.reset()
|
||||||
|
@@ -1,52 +0,0 @@
|
|||||||
"""Test suite for RescaleActionV0."""
|
|
||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.experimental.wrappers import RescaleActionV0
|
|
||||||
|
|
||||||
|
|
||||||
SEED = 42
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("env", "low", "high", "action", "scaled_action"),
|
|
||||||
[
|
|
||||||
(
|
|
||||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
|
||||||
gym.make("BipedalWalker-v3"),
|
|
||||||
-0.5,
|
|
||||||
0.5,
|
|
||||||
np.array([1, 1, 1, 1]),
|
|
||||||
np.array([0.5, 0.5, 0.5, 0.5]),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
|
||||||
gym.make("BipedalWalker-v3"),
|
|
||||||
-0.5,
|
|
||||||
0.5,
|
|
||||||
jax.numpy.array([1, 1, 1, 1]),
|
|
||||||
jax.numpy.array([0.5, 0.5, 0.5, 0.5]),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
|
||||||
gym.make("BipedalWalker-v3"),
|
|
||||||
np.array([-0.5, -0.5, -1, -1], dtype=np.float32),
|
|
||||||
np.array([0.5, 0.5, 1, 1], dtype=np.float32),
|
|
||||||
jax.numpy.array([1, 1, 1, 1]),
|
|
||||||
jax.numpy.array([0.5, 0.5, 1, 1]),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_rescale_actions_v0_box(env, low, high, action, scaled_action):
|
|
||||||
"""Test action rescaling."""
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
obs, _, _, _, _ = env.step(action)
|
|
||||||
|
|
||||||
env.reset(seed=SEED)
|
|
||||||
wrapped_env = RescaleActionV0(env, low, high)
|
|
||||||
|
|
||||||
obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action)
|
|
||||||
|
|
||||||
assert np.alltrue(obs == obs_scaled)
|
|
@@ -17,7 +17,9 @@ def step_fn(self, action):
|
|||||||
|
|
||||||
|
|
||||||
def test_sticky_action():
|
def test_sticky_action():
|
||||||
env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5)
|
env = StickyActionV0(
|
||||||
|
GenericTestEnv(step_func=step_fn), repeat_action_probability=0.5
|
||||||
|
)
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
env.action_space.seed(SEED)
|
env.action_space.seed(SEED)
|
||||||
|
|
||||||
@@ -34,7 +36,7 @@ def test_sticky_action():
|
|||||||
previous_action = input_action
|
previous_action = input_action
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5])
|
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
|
||||||
def test_sticky_action_raise(repeat_action_probability):
|
def test_sticky_action_raise(repeat_action_probability):
|
||||||
with pytest.raises(InvalidProbability):
|
with pytest.raises(InvalidProbability):
|
||||||
StickyActionV0(
|
StickyActionV0(
|
89
tests/experimental/wrappers/test_stateful_observation.py
Normal file
89
tests/experimental/wrappers/test_stateful_observation.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""Test suite for stateful observation wrappers: TimeAwareObservation, DelayObservation."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.experimental.wrappers import DelayObservationV0, TimeAwareObservationV0
|
||||||
|
from gymnasium.spaces import Box, Dict, Tuple
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
NUM_STEPS = 20
|
||||||
|
SEED = 0
|
||||||
|
|
||||||
|
DELAY = 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_time_aware_observation_wrapper():
|
||||||
|
"""Tests the time aware observation wrapper."""
|
||||||
|
# Test the environment observation space with Dict, Tuple and other
|
||||||
|
env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3)))
|
||||||
|
wrapped_env = TimeAwareObservationV0(env)
|
||||||
|
assert isinstance(wrapped_env.observation_space, Dict)
|
||||||
|
reset_obs, _ = wrapped_env.reset()
|
||||||
|
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||||
|
assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}"
|
||||||
|
|
||||||
|
env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3))))
|
||||||
|
wrapped_env = TimeAwareObservationV0(env)
|
||||||
|
assert isinstance(wrapped_env.observation_space, Tuple)
|
||||||
|
reset_obs, _ = wrapped_env.reset()
|
||||||
|
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||||
|
assert len(reset_obs) == 3 and len(step_obs) == 3
|
||||||
|
|
||||||
|
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||||
|
wrapped_env = TimeAwareObservationV0(env)
|
||||||
|
assert isinstance(wrapped_env.observation_space, Dict)
|
||||||
|
reset_obs, _ = wrapped_env.reset()
|
||||||
|
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||||
|
assert isinstance(reset_obs, dict) and isinstance(step_obs, dict)
|
||||||
|
assert "obs" in reset_obs and "obs" in step_obs
|
||||||
|
assert "time" in reset_obs and "time" in step_obs
|
||||||
|
|
||||||
|
# Tests the flatten parameter
|
||||||
|
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||||
|
wrapped_env = TimeAwareObservationV0(env, flatten=True)
|
||||||
|
assert isinstance(wrapped_env.observation_space, Box)
|
||||||
|
reset_obs, _ = wrapped_env.reset()
|
||||||
|
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||||
|
assert reset_obs.shape == (2,) and step_obs.shape == (2,)
|
||||||
|
|
||||||
|
# Tests the normalize_time parameter
|
||||||
|
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||||
|
wrapped_env = TimeAwareObservationV0(env, normalize_time=False)
|
||||||
|
reset_obs, _ = wrapped_env.reset()
|
||||||
|
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||||
|
assert reset_obs["time"] == 100 and step_obs["time"] == 99
|
||||||
|
|
||||||
|
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||||
|
wrapped_env = TimeAwareObservationV0(env, normalize_time=True)
|
||||||
|
reset_obs, _ = wrapped_env.reset()
|
||||||
|
step_obs, _, _, _, _ = wrapped_env.step(None)
|
||||||
|
assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_delay_observation_wrapper():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env.action_space.seed(SEED)
|
||||||
|
env.reset(seed=SEED)
|
||||||
|
|
||||||
|
undelayed_observations = []
|
||||||
|
for _ in range(NUM_STEPS):
|
||||||
|
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||||
|
undelayed_observations.append(obs)
|
||||||
|
|
||||||
|
env = DelayObservationV0(env, delay=DELAY)
|
||||||
|
env.action_space.seed(SEED)
|
||||||
|
env.reset(seed=SEED)
|
||||||
|
|
||||||
|
delayed_observations = []
|
||||||
|
for i in range(NUM_STEPS):
|
||||||
|
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||||
|
delayed_observations.append(obs)
|
||||||
|
if i < DELAY - 1:
|
||||||
|
assert np.all(obs == 0)
|
||||||
|
|
||||||
|
undelayed_observations = np.array(undelayed_observations)
|
||||||
|
delayed_observations = np.array(delayed_observations)
|
||||||
|
|
||||||
|
assert np.all(delayed_observations[DELAY:] == undelayed_observations[:-DELAY])
|
@@ -1,99 +0,0 @@
|
|||||||
"""Test suite for TimeAwareobservationV0."""
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
|
||||||
from gymnasium.spaces import Box, Dict
|
|
||||||
|
|
||||||
|
|
||||||
NUM_STEPS = 20
|
|
||||||
SEED = 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"env",
|
|
||||||
[
|
|
||||||
gym.make("CartPole-v1", disable_env_checker=True),
|
|
||||||
gym.make("CarRacing-v2", disable_env_checker=True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_time_aware_observation_creation(env):
|
|
||||||
"""Test TimeAwareObservationV0 wrapper creation.
|
|
||||||
|
|
||||||
This test checks if wrapped env with TimeAwareObservationV0
|
|
||||||
is correctly created.
|
|
||||||
"""
|
|
||||||
wrapped_env = TimeAwareObservationV0(env)
|
|
||||||
obs, _ = wrapped_env.reset()
|
|
||||||
|
|
||||||
assert isinstance(wrapped_env.observation_space, Dict)
|
|
||||||
assert isinstance(obs, OrderedDict)
|
|
||||||
assert np.all(obs["time"] == 0)
|
|
||||||
assert env.observation_space == wrapped_env.observation_space["obs"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("normalize_time", [True, False])
|
|
||||||
@pytest.mark.parametrize("flatten", [False, True])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"env",
|
|
||||||
[
|
|
||||||
gym.make("CartPole-v1", disable_env_checker=True),
|
|
||||||
gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_time_aware_observation_step(env, flatten, normalize_time):
|
|
||||||
"""Test TimeAwareObservationV0 step.
|
|
||||||
|
|
||||||
This test checks if wrapped env with TimeAwareObservationV0
|
|
||||||
steps correctly.
|
|
||||||
"""
|
|
||||||
env.action_space.seed(SEED)
|
|
||||||
max_timesteps = env._max_episode_steps
|
|
||||||
|
|
||||||
wrapped_env = TimeAwareObservationV0(
|
|
||||||
env, flatten=flatten, normalize_time=normalize_time
|
|
||||||
)
|
|
||||||
wrapped_env.reset(seed=SEED)
|
|
||||||
|
|
||||||
for timestep in range(1, NUM_STEPS):
|
|
||||||
action = env.action_space.sample()
|
|
||||||
observation, _, terminated, _, _ = wrapped_env.step(action)
|
|
||||||
|
|
||||||
expected_time_obs = (
|
|
||||||
timestep / max_timesteps if normalize_time else max_timesteps - timestep
|
|
||||||
)
|
|
||||||
|
|
||||||
if flatten:
|
|
||||||
assert np.allclose(observation[-1], expected_time_obs)
|
|
||||||
else:
|
|
||||||
assert np.allclose(observation["time"], expected_time_obs)
|
|
||||||
|
|
||||||
if terminated:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"env",
|
|
||||||
[
|
|
||||||
gym.make("CartPole-v1", disable_env_checker=True),
|
|
||||||
gym.make("CarRacing-v2", disable_env_checker=True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_time_aware_observation_creation_flatten(env):
|
|
||||||
"""Test TimeAwareObservationV0 wrapper creation with `flatten=True`.
|
|
||||||
|
|
||||||
This test checks if wrapped env with TimeAwareObservationV0
|
|
||||||
is correctly created when the `flatten` parameter is set to `True`.
|
|
||||||
When flattened, the observation space should be a 1 dimension `Box`
|
|
||||||
with time appended to the end.
|
|
||||||
"""
|
|
||||||
wrapped_env = TimeAwareObservationV0(env, flatten=True)
|
|
||||||
obs, _ = wrapped_env.reset()
|
|
||||||
|
|
||||||
assert isinstance(wrapped_env.observation_space, Box)
|
|
||||||
assert isinstance(obs, np.ndarray)
|
|
||||||
assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]
|
|
@@ -9,6 +9,7 @@ from tests.testing_env import GenericTestEnv
|
|||||||
|
|
||||||
|
|
||||||
def torch_data_equivalence(data_1, data_2) -> bool:
|
def torch_data_equivalence(data_1, data_2) -> bool:
|
||||||
|
"""Return if two variables are equivalent that might contain ``torch.Tensor``."""
|
||||||
if type(data_1) == type(data_2):
|
if type(data_1) == type(data_2):
|
||||||
if isinstance(data_1, dict):
|
if isinstance(data_1, dict):
|
||||||
return data_1.keys() == data_2.keys() and all(
|
return data_1.keys() == data_2.keys() and all(
|
||||||
@@ -56,14 +57,15 @@ def torch_data_equivalence(data_1, data_2) -> bool:
|
|||||||
)
|
)
|
||||||
def test_roundtripping(value, expected_value):
|
def test_roundtripping(value, expected_value):
|
||||||
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
|
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
|
||||||
assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value)
|
roundtripped_value = jax_to_torch(torch_to_jax(value))
|
||||||
|
assert torch_data_equivalence(roundtripped_value, expected_value)
|
||||||
|
|
||||||
|
|
||||||
def jax_reset_func(self, seed=None, options=None):
|
def _jax_reset_func(self, seed=None, options=None):
|
||||||
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
|
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
|
||||||
|
|
||||||
|
|
||||||
def jax_step_func(self, action):
|
def _jax_step_func(self, action):
|
||||||
assert isinstance(action, jnp.DeviceArray), type(action)
|
assert isinstance(action, jnp.DeviceArray), type(action)
|
||||||
return (
|
return (
|
||||||
jnp.array([1, 2, 3]),
|
jnp.array([1, 2, 3]),
|
||||||
@@ -75,7 +77,7 @@ def jax_step_func(self, action):
|
|||||||
|
|
||||||
|
|
||||||
def test_jax_to_torch():
|
def test_jax_to_torch():
|
||||||
env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
|
env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)
|
||||||
|
|
||||||
# Check that the reset and step for jax environment are as expected
|
# Check that the reset and step for jax environment are as expected
|
||||||
obs, info = env.reset()
|
obs, info = env.reset()
|
||||||
|
@@ -278,7 +278,7 @@ def test_wrapper_types():
|
|||||||
obs, _, _, _, _ = observation_env.step(0)
|
obs, _, _, _, _ = observation_env.step(0)
|
||||||
assert obs == np.array([1])
|
assert obs == np.array([1])
|
||||||
|
|
||||||
env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {}))
|
env = GenericTestEnv(step_func=lambda self, action: (action, 0, False, False, {}))
|
||||||
action_env = ExampleActionWrapper(env)
|
action_env = ExampleActionWrapper(env)
|
||||||
obs, _, _, _, _ = action_env.step(0)
|
obs, _, _, _, _ = action_env.step(0)
|
||||||
assert obs == np.array([1])
|
assert obs == np.array([1])
|
||||||
|
@@ -8,7 +8,7 @@ from gymnasium.core import ActType, ObsType
|
|||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
def basic_reset_fn(
|
def basic_reset_func(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
@@ -20,17 +20,17 @@ def basic_reset_fn(
|
|||||||
return self.observation_space.sample(), {"options": options}
|
return self.observation_space.sample(), {"options": options}
|
||||||
|
|
||||||
|
|
||||||
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
def new_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||||
"""A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
|
"""A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
|
||||||
return self.observation_space.sample(), 0, False, False, {}
|
return self.observation_space.sample(), 0, False, False, {}
|
||||||
|
|
||||||
|
|
||||||
def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
def old_step_func(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||||
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
|
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
|
||||||
return self.observation_space.sample(), 0, False, {}
|
return self.observation_space.sample(), 0, False, {}
|
||||||
|
|
||||||
|
|
||||||
def basic_render_fn(self):
|
def basic_render_func(self):
|
||||||
"""Basic render fn that does nothing."""
|
"""Basic render fn that does nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -43,12 +43,14 @@ class GenericTestEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||||
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||||
reset_fn: callable = basic_reset_fn,
|
reset_func: callable = basic_reset_func,
|
||||||
step_fn: callable = new_step_fn,
|
step_func: callable = new_step_func,
|
||||||
render_fn: callable = basic_render_fn,
|
render_func: callable = basic_render_func,
|
||||||
metadata: Dict[str, Any] = {"render_modes": []},
|
metadata: Dict[str, Any] = {"render_modes": []},
|
||||||
render_mode: Optional[str] = None,
|
render_mode: Optional[str] = None,
|
||||||
spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"),
|
spec: EnvSpec = EnvSpec(
|
||||||
|
"TestingEnv-v0", "testing-env-no-entry-point", max_episode_steps=100
|
||||||
|
),
|
||||||
):
|
):
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
@@ -59,12 +61,12 @@ class GenericTestEnv(gym.Env):
|
|||||||
if action_space is not None:
|
if action_space is not None:
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
|
|
||||||
if reset_fn is not None:
|
if reset_func is not None:
|
||||||
self.reset = types.MethodType(reset_fn, self)
|
self.reset = types.MethodType(reset_func, self)
|
||||||
if step_fn is not None:
|
if step_func is not None:
|
||||||
self.step = types.MethodType(step_fn, self)
|
self.step = types.MethodType(step_func, self)
|
||||||
if render_fn is not None:
|
if render_func is not None:
|
||||||
self.render = types.MethodType(render_fn, self)
|
self.render = types.MethodType(render_func, self)
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
|
@@ -112,10 +112,10 @@ def test_check_reset_seed(test, func: callable, message: str):
|
|||||||
with pytest.warns(
|
with pytest.warns(
|
||||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||||
):
|
):
|
||||||
check_reset_seed(GenericTestEnv(reset_fn=func))
|
check_reset_seed(GenericTestEnv(reset_func=func))
|
||||||
else:
|
else:
|
||||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
check_reset_seed(GenericTestEnv(reset_fn=func))
|
check_reset_seed(GenericTestEnv(reset_func=func))
|
||||||
|
|
||||||
|
|
||||||
def _deprecated_return_info(
|
def _deprecated_return_info(
|
||||||
@@ -179,7 +179,7 @@ def test_check_reset_return_type(test, func: callable, message: str):
|
|||||||
"""Tests the check `env.reset()` function has a correct return type."""
|
"""Tests the check `env.reset()` function has a correct return type."""
|
||||||
|
|
||||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
check_reset_return_type(GenericTestEnv(reset_fn=func))
|
check_reset_return_type(GenericTestEnv(reset_func=func))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -198,7 +198,7 @@ def test_check_reset_return_info_deprecation(test, func: callable, message: str)
|
|||||||
"""Tests that return_info has been correct deprecated as an argument to `env.reset()`."""
|
"""Tests that return_info has been correct deprecated as an argument to `env.reset()`."""
|
||||||
|
|
||||||
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
|
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
|
||||||
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func))
|
check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))
|
||||||
|
|
||||||
|
|
||||||
def test_check_seed_deprecation():
|
def test_check_seed_deprecation():
|
||||||
@@ -236,7 +236,7 @@ def test_check_reset_options():
|
|||||||
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
|
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {})))
|
check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@@ -303,11 +303,11 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
|
|||||||
with pytest.warns(
|
with pytest.warns(
|
||||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||||
):
|
):
|
||||||
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
|
env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
|
||||||
else:
|
else:
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
|
env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
|
||||||
assert len(caught_warnings) == 0
|
assert len(caught_warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
@@ -383,11 +383,11 @@ def test_passive_env_step_checker(
|
|||||||
with pytest.warns(
|
with pytest.warns(
|
||||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||||
):
|
):
|
||||||
env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
|
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
|
||||||
else:
|
else:
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
|
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
|
||||||
assert len(caught_warnings) == 0, caught_warnings
|
assert len(caught_warnings) == 0, caught_warnings
|
||||||
|
|
||||||
|
|
||||||
@@ -416,7 +416,7 @@ def test_passive_env_step_checker(
|
|||||||
GenericTestEnv(
|
GenericTestEnv(
|
||||||
metadata={"render_modes": ["Testing mode"], "render_fps": None},
|
metadata={"render_modes": ["Testing mode"], "render_fps": None},
|
||||||
render_mode="Testing mode",
|
render_mode="Testing mode",
|
||||||
render_fn=lambda self: 0,
|
render_func=lambda self: 0,
|
||||||
),
|
),
|
||||||
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
|
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
|
||||||
],
|
],
|
||||||
|
@@ -21,7 +21,7 @@ IRRELEVANT_KEY = 1
|
|||||||
PlayableEnv = partial(
|
PlayableEnv = partial(
|
||||||
GenericTestEnv,
|
GenericTestEnv,
|
||||||
metadata={"render_modes": ["rgb_array"]},
|
metadata={"render_modes": ["rgb_array"]},
|
||||||
render_fn=lambda self: np.ones((10, 10, 3)),
|
render_func=lambda self: np.ones((10, 10, 3)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -82,8 +82,8 @@ def test_final_obs_info(vectoriser):
|
|||||||
return GenericTestEnv(
|
return GenericTestEnv(
|
||||||
action_space=Discrete(4),
|
action_space=Discrete(4),
|
||||||
observation_space=Discrete(4),
|
observation_space=Discrete(4),
|
||||||
reset_fn=reset_fn,
|
reset_func=reset_fn,
|
||||||
step_fn=lambda self, action: (
|
step_func=lambda self, action: (
|
||||||
action if action < 3 else 0,
|
action if action < 3 else 0,
|
||||||
0,
|
0,
|
||||||
action >= 3,
|
action >= 3,
|
||||||
|
@@ -3,7 +3,7 @@ import pytest
|
|||||||
|
|
||||||
from gymnasium.spaces import Box, Discrete
|
from gymnasium.spaces import Box, Discrete
|
||||||
from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility
|
from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility
|
||||||
from tests.testing_env import GenericTestEnv, old_step_fn
|
from tests.testing_env import GenericTestEnv, old_step_func
|
||||||
|
|
||||||
|
|
||||||
class AleTesting:
|
class AleTesting:
|
||||||
@@ -34,7 +34,7 @@ class AtariTestingEnv(GenericTestEnv):
|
|||||||
low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1
|
low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1
|
||||||
),
|
),
|
||||||
action_space=Discrete(3, seed=1),
|
action_space=Discrete(3, seed=1),
|
||||||
step_fn=old_step_fn,
|
step_func=old_step_func,
|
||||||
)
|
)
|
||||||
self.ale = AleTesting()
|
self.ale = AleTesting()
|
||||||
|
|
||||||
|
@@ -68,8 +68,8 @@ def _step_failure(self, action):
|
|||||||
|
|
||||||
def test_api_failures():
|
def test_api_failures():
|
||||||
env = GenericTestEnv(
|
env = GenericTestEnv(
|
||||||
reset_fn=_reset_failure,
|
reset_func=_reset_failure,
|
||||||
step_fn=_step_failure,
|
step_func=_step_failure,
|
||||||
metadata={"render_modes": "error"},
|
metadata={"render_modes": "error"},
|
||||||
)
|
)
|
||||||
env = PassiveEnvChecker(env)
|
env = PassiveEnvChecker(env)
|
||||||
|
Reference in New Issue
Block a user