Move dev_wrappers and functional to experimental (#159)

This commit is contained in:
Mark Towers
2022-11-29 23:37:53 +00:00
committed by GitHub
parent ae75ad2e44
commit f24fa1426c
30 changed files with 435 additions and 88 deletions


@@ -15,7 +15,7 @@ repos:
hooks: hooks:
- id: flake8 - id: flake8
args: args:
- '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402' - '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402 gymnasium/experimental/wrappers/__init__.py:E402'
- --ignore=E203,W503,E741 - --ignore=E203,W503,E741
- --max-complexity=30 - --max-complexity=30
- --max-line-length=456 - --max-line-length=456

docs/api/experimental.md Normal file

@@ -0,0 +1,214 @@
---
title: Experimental
---
# Experimental
```{toctree}
:hidden:
experimental/functional
experimental/wrappers
experimental/vector
experimental/vector_wrappers
```
## Functional Environments
The gymnasium ``Env`` provides high flexibility for implementing individual environments; however, this can complicate the parallelisation of environments. Therefore, we propose :class:`gymnasium.experimental.FuncEnv`, where each part of the environment has its own function.
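The idea can be sketched in plain Python. The sketch below only mirrors the shape of ``FuncEnv`` (its ``initial``, ``transition``, ``observation``, ``reward`` and ``terminal`` functions); the toy class and its countdown dynamics are illustrative, not the actual experimental API:

```python
# Toy pure-functional environment: every part of the environment is a pure
# function of the state, so steps can be composed, jit-compiled or vmapped.
# The method names mirror gymnasium.experimental.FuncEnv; the countdown
# dynamics are purely illustrative.
class CountdownFuncEnv:
    def initial(self, rng):
        """Initial state (rng unused in this deterministic toy)."""
        return 10

    def transition(self, state, action, rng):
        """Next state is a pure function of (state, action)."""
        return state - action

    def observation(self, state):
        return float(state)

    def reward(self, state, action, next_state):
        return 1.0 if next_state == 0 else 0.0

    def terminal(self, state):
        return state <= 0


# A rollout is then an explicit loop over pure functions, with the state
# carried by the caller rather than hidden inside the environment.
env = CountdownFuncEnv()
state = env.initial(rng=None)
total_reward = 0.0
while not env.terminal(state):
    next_state = env.transition(state, 1, rng=None)
    total_reward += env.reward(state, 1, next_state)
    state = next_state
print(state, total_reward)  # 0 1.0
```

Because the caller owns the state, running many such environments in parallel is a matter of mapping the same pure functions over a batch of states.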
## Wrappers
Gymnasium already contains a large collection of wrappers, but we believe that the wrappers can be improved to
* Support arbitrarily complex observation / action spaces. As RL has advanced, action and observation spaces are becoming more complex, and the current wrappers were not implemented with these spaces in mind.
* Support for numpy, jax and pytorch data. With hardware-accelerated environments, e.g. Brax, written in Jax, and similar pytorch-based programs, numpy is not the only game in town anymore. Therefore, these upgrades will use Jumpy for calling numpy, jax or torch depending on the data.
* More wrappers. Projects like SuperSuit aimed to bring more wrappers for RL; many of these can now be moved into Gymnasium.
* Versioning. Like environments, the implementation details of a wrapper can change agent performance. Therefore, we propose adding version numbers to all wrappers.
* In v0.28, we aim to rewrite `VectorEnv` so that it does not inherit from `Env`; as a result, new vectorised versions of the wrappers will be provided.
### Lambda Observation Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
- Tree structure
* - :class:`wrappers.TransformObservation`
- :class:`experimental.wrappers.LambdaObservationV0`
- VectorLambdaObservation
- No
* - :class:`wrappers.FilterObservation`
- FilterObservation
- VectorFilterObservation (*)
- Yes
* - :class:`wrappers.FlattenObservation`
- FlattenObservation
- VectorFlattenObservation (*)
- No
* - :class:`wrappers.GrayScaleObservation`
- GrayscaleObservation
- VectorGrayscaleObservation (*)
- Yes
* - :class:`wrappers.PixelObservationWrapper`
- PixelObservation
- VectorPixelObservation (*)
- No
* - :class:`wrappers.ResizeObservation`
- ResizeObservation
- VectorResizeObservation (*)
- Yes
* - Not Implemented
- ReshapeObservation
- VectorReshapeObservation (*)
- Yes
* - Not Implemented
- RescaleObservation
- VectorRescaleObservation (*)
- Yes
* - Not Implemented
- DtypeObservation
- VectorDtypeObservation (*)
- Yes
* - :class:`NormalizeObservation`
- NormalizeObservation
- VectorNormalizeObservation
- No
* - :class:`TimeAwareObservation`
- TimeAwareObservation
- VectorTimeAwareObservation
- No
* - :class:`FrameStack`
- FrameStackObservation
- VectorFrameStackObservation
- No
* - Not Implemented
- DelayObservation
- VectorDelayObservation
- No
* - :class:`AtariPreprocessing`
- AtariPreprocessing
- Not Implemented
- No
```
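Many entries in the table above are thin specialisations of ``LambdaObservationV0``, which applies a user-provided function to every observation. A minimal sketch of the pattern (the helper functions below are hypothetical illustrations, not the wrapper implementations):

```python
import numpy as np

# Hypothetical observation transforms of the kind LambdaObservationV0 accepts:
# each wrapper in the table corresponds to a particular choice of function.
def reshape_fn(shape):
    """What a ReshapeObservation-style wrapper would apply."""
    return lambda obs: np.reshape(obs, shape)

def dtype_fn(new_dtype):
    """What a DtypeObservation-style wrapper would apply."""
    return lambda obs: obs.astype(new_dtype)

obs = np.arange(6, dtype=np.float64)
print(reshape_fn((2, 3))(obs).shape)    # (2, 3)
print(dtype_fn(np.float32)(obs).dtype)  # float32
```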
### Lambda Action Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
- Tree structure
* - Not Implemented
- :class:`experimental.wrappers.LambdaActionV0`
- VectorLambdaAction
- No
* - :class:`wrappers.ClipAction`
- ClipAction
- VectorClipAction (*)
- Yes
* - :class:`wrappers.RescaleAction`
- RescaleAction
- VectorRescaleAction (*)
- Yes
* - Not Implemented
- NanAction
- VectorNanAction (*)
- Yes
* - Not Implemented
- StickyAction
- VectorStickyAction
- No
```
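``StickyAction`` is listed as not implemented, but the idea (familiar from the Atari literature: repeat the previous action with some probability) can be sketched as follows. The class, its signature and the toy environment are assumptions for illustration, not the eventual API:

```python
import random

class StickyAction:
    """Sketch: with probability `repeat_prob`, replay the previous action."""

    def __init__(self, env, repeat_prob, seed=None):
        self.env = env
        self.repeat_prob = repeat_prob
        self.last_action = None
        self.rng = random.Random(seed)

    def step(self, action):
        if self.last_action is not None and self.rng.random() < self.repeat_prob:
            action = self.last_action  # ignore the new action, repeat the old one
        self.last_action = action
        return self.env.step(action)


class _EchoEnv:
    """Toy environment that returns the action it actually received."""

    def step(self, action):
        return action


env = StickyAction(_EchoEnv(), repeat_prob=1.0)
first = env.step("left")
second = env.step("right")
print(first, second)  # left left  (with repeat_prob=1, the old action always wins)
```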
### Lambda Reward Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
* - :class:`wrappers.TransformReward`
- :class:`experimental.wrappers.LambdaRewardV0`
- VectorLambdaReward
* - Not Implemented
- :class:`experimental.wrappers.ClipRewardV0`
- VectorClipReward (*)
* - Not Implemented
- RescaleReward
- VectorRescaleReward (*)
* - :class:`wrappers.NormalizeReward`
- NormalizeReward
- VectorNormalizeReward
```
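The table suggests that ``ClipRewardV0`` is essentially ``LambdaRewardV0`` specialised with a clipping function. A sketch of that relationship (the classes below are illustrative stand-ins, not the experimental implementations):

```python
class LambdaReward:
    """Sketch of the LambdaRewardV0 idea: apply `func` to every reward."""

    def __init__(self, env, func):
        self.env = env
        self.func = func

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs, self.func(reward), terminated, truncated, info


class ClipReward(LambdaReward):
    """Sketch of ClipRewardV0 as a LambdaReward with a clipping function."""

    def __init__(self, env, min_reward, max_reward):
        super().__init__(env, lambda r: min(max(r, min_reward), max_reward))


class _ConstRewardEnv:
    """Toy environment that always emits a reward of 2.0."""

    def step(self, action):
        return None, 2.0, False, False, {}


env = ClipReward(_ConstRewardEnv(), 0.0, 0.5)
_, rew, _, _, _ = env.step(0)
print(rew)  # 0.5
```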
### Common Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
* - :class:`wrappers.AutoResetWrapper`
- AutoReset
- VectorAutoReset
* - :class:`wrappers.PassiveEnvChecker`
- PassiveEnvChecker
- VectorPassiveEnvChecker
* - :class:`wrappers.OrderEnforcing`
- OrderEnforcing
- VectorOrderEnforcing (*)
* - :class:`wrappers.EnvCompatibility`
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
- Not Implemented
* - :class:`RecordEpisodeStatistics`
- RecordEpisodeStatistics
- VectorRecordEpisodeStatistics
* - :class:`RenderCollection`
- RenderCollection
- VectorRenderCollection
* - :class:`HumanRendering`
- HumanRendering
- Not Implemented
* - Not Implemented
- JaxToNumpy
- VectorJaxToNumpy
* - Not Implemented
- JaxToTorch
- VectorJaxToTorch
```
### Vector Only Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
* - :class:`wrappers.VectorListInfo`
- VectorListInfo
```
## Vector Environment
These changes will be made in v0.28
## Wrappers for Vector Environments
These changes will be made in v0.28


@@ -0,0 +1,36 @@
---
title: Functional
---
# Functional Environment
## gymnasium.experimental.FuncEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.FuncEnv
.. autofunction:: gymnasium.experimental.FuncEnv.initial
.. autofunction:: gymnasium.experimental.FuncEnv.transition
.. autofunction:: gymnasium.experimental.FuncEnv.observation
.. autofunction:: gymnasium.experimental.FuncEnv.reward
.. autofunction:: gymnasium.experimental.FuncEnv.terminal
.. autofunction:: gymnasium.experimental.FuncEnv.state_info
.. autofunction:: gymnasium.experimental.FuncEnv.step_info
.. autofunction:: gymnasium.experimental.FuncEnv.transform
.. autofunction:: gymnasium.experimental.FuncEnv.render_image
.. autofunction:: gymnasium.experimental.FuncEnv.render_init
.. autofunction:: gymnasium.experimental.FuncEnv.render_close
```
## gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
```
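The key feature of the functional API is ``transform``, which applies a function transformer (such as ``jax.jit`` or ``jax.vmap``) uniformly to the environment's pure functions. A plain-Python stand-in, with a logging transformer in place of jax (the class and transformer are illustrative only):

```python
class TinyFuncEnv:
    """Illustrative stand-in for a FuncEnv with two pure functions."""

    def initial(self, rng):
        return 0

    def transition(self, state, action, rng):
        return state + action

    def transform(self, func):
        # Wrap each pure function with `func`; env.transform(jax.jit)
        # would jit-compile them all in the real API.
        for name in ("initial", "transition"):
            setattr(self, name, func(getattr(self, name)))


calls = []

def logged(f):
    """Stand-in transformer that records calls (where jax.jit would compile)."""
    def wrapper(*args):
        calls.append(f.__name__)
        return f(*args)
    return wrapper


env = TinyFuncEnv()
env.transform(logged)
state = env.initial(None)
state = env.transition(state, 3, None)
print(state, calls)  # 3 ['initial', 'transition']
```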


@@ -0,0 +1,15 @@
---
title: Vector
---
# Vectorizing Environments
## gymnasium.experimental.VectorEnv
## gymnasium.experimental.vector.AsyncVectorEnv
## gymnasium.experimental.vector.SyncVectorEnv
## Custom Vector environments
## EnvPool


@@ -0,0 +1,15 @@
---
title: Vector Wrappers
---
# Vector Environment Wrappers
## Vector Lambda Observation Wrappers
## Vector Lambda Action Wrappers
## Vector Lambda Reward Wrappers
## Vector Common Wrappers
## Vector Only Wrappers


@@ -0,0 +1,26 @@
# Wrappers
## Lambda Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
```
## Lambda Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
```
## Lambda Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
```
## Common Wrappers
```{eval-rst}
```


@@ -48,6 +48,7 @@ api/spaces
api/wrappers api/wrappers
api/vector api/vector
api/utils api/utils
api/experimental
``` ```
```{toctree} ```{toctree}


@@ -10,10 +10,8 @@ from gymnasium.core import (
) )
from gymnasium.spaces.space import Space from gymnasium.spaces.space import Space
from gymnasium.envs.registration import make, spec, register, registry, pprint_registry from gymnasium.envs.registration import make, spec, register, registry, pprint_registry
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental
import os
import sys
__all__ = [ __all__ = [
# core classes # core classes
@@ -37,6 +35,7 @@ __all__ = [
"wrappers", "wrappers",
"error", "error",
"logger", "logger",
"experimental",
] ]
__version__ = "0.26.3" __version__ = "0.26.3"
@@ -45,6 +44,9 @@ __version__ = "0.26.3"
# pygame # pygame
# DSP is far more benign (and should probably be the default in SDL anyways) # DSP is far more benign (and should probably be the default in SDL anyways)
import os
import sys
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
os.environ["SDL_AUDIODRIVER"] = "dsp" os.environ["SDL_AUDIODRIVER"] = "dsp"


@@ -1,4 +0,0 @@
"""Root __init__ of the gym dev_wrappers."""
from typing import TypeVar
ArgType = TypeVar("ArgType")


@@ -1,2 +1,2 @@
from gymnasium.envs.phys2d.cartpole import CartPoleF from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumF from gymnasium.envs.phys2d.pendulum import PendulumFunctional


@@ -10,41 +10,42 @@ import numpy as np
from jax.random import PRNGKey from jax.random import PRNGKey
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.phys2d.conversion import JaxEnv
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.experimental.func_jax_env import FunctionalJaxEnv
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle from gymnasium.utils import EzPickle
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821 RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]): class CartPoleFunctional(
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
):
"""Cartpole but in jax and functional. """Cartpole but in jax and functional.
Example usage: Example usage:
```
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0) >>> import jax
>>> import jax.numpy as jnp
env = CartPole({"x_init": 0.5}) >>> key = jax.random.PRNGKey(0)
state = env.initial(key)
print(state)
print(env.step(state, 0))
env.transform(jax.jit) >>> env = CartPole({"x_init": 0.5})
>>> state = env.initial(key)
>>> print(state)
>>> print(env.step(state, 0))
state = env.initial(key) >>> env.transform(jax.jit)
print(state)
print(env.step(state, 0))
vkey = jax.random.split(key, 10) >>> state = env.initial(key)
env.transform(jax.vmap) >>> print(state)
vstate = env.initial(vkey) >>> print(env.step(state, 0))
print(vstate)
print(env.step(vstate, jnp.array([0 for _ in range(10)]))) >>> vkey = jax.random.split(key, 10)
``` >>> env.transform(jax.vmap)
>>> vstate = env.initial(vkey)
>>> print(vstate)
>>> print(env.step(vstate, jnp.array([0 for _ in range(10)])))
""" """
gravity = 9.8 gravity = 9.8
@@ -232,13 +233,13 @@ class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateT
pygame.quit() pygame.quit()
class CartPoleJaxEnv(JaxEnv, EzPickle): class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
metadata = {"render_modes": ["rgb_array"], "render_fps": 50} metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
def __init__(self, render_mode: Optional[str] = None, **kwargs): def __init__(self, render_mode: Optional[str] = None, **kwargs):
EzPickle.__init__(self, render_mode=render_mode, **kwargs) EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = CartPoleF(**kwargs) env = CartPoleFunctional(**kwargs)
env.transform(jax.jit) env.transform(jax.jit)
action_space = env.action_space action_space = env.action_space
observation_space = env.observation_space observation_space = env.observation_space


@@ -10,15 +10,17 @@ import numpy as np
from jax.random import PRNGKey from jax.random import PRNGKey
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.phys2d.conversion import JaxEnv
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.experimental.func_jax_env import FunctionalJaxEnv
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle from gymnasium.utils import EzPickle
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821 RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]): class PendulumFunctional(
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
):
"""Pendulum but in jax and functional.""" """Pendulum but in jax and functional."""
max_speed = 8 max_speed = 8
@@ -180,13 +182,13 @@ class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateT
pygame.quit() pygame.quit()
class PendulumJaxEnv(JaxEnv, EzPickle): class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
metadata = {"render_modes": ["rgb_array"], "render_fps": 30} metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
def __init__(self, render_mode: Optional[str] = None, **kwargs): def __init__(self, render_mode: Optional[str] = None, **kwargs):
EzPickle.__init__(self, render_mode=render_mode, **kwargs) EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = PendulumF(**kwargs) env = PendulumFunctional(**kwargs)
env.transform(jax.jit) env.transform(jax.jit)
action_space = env.action_space action_space = env.action_space
observation_space = env.observation_space observation_space = env.observation_space


@@ -0,0 +1,12 @@
"""Root __init__ of the gym dev_wrappers."""
from gymnasium.experimental.functional import FuncEnv
__all__ = [
# Functional
"FuncEnv",
"functional",
# Wrapper
"wrappers",
]


@@ -1,4 +1,7 @@
from typing import Any, Dict, Optional, Tuple """Functional to Environment compatibility."""
from __future__ import annotations
from typing import Any
import jax.numpy as jnp import jax.numpy as jnp
import jax.random as jrng import jax.random as jrng
@@ -7,14 +10,12 @@ import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium import Space from gymnasium import Space
from gymnasium.envs.registration import EnvSpec from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding from gymnasium.utils import seeding
class JaxEnv(gym.Env): class FunctionalJaxEnv(gym.Env):
""" """A conversion layer for jax-based environments."""
A conversion layer for numpy-based environments.
"""
state: StateType state: StateType
rng: jrng.PRNGKey rng: jrng.PRNGKey
@@ -24,20 +25,24 @@ class JaxEnv(gym.Env):
func_env: FuncEnv, func_env: FuncEnv,
observation_space: Space, observation_space: Space,
action_space: Space, action_space: Space,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
render_mode: Optional[str] = None, render_mode: str | None = None,
reward_range: Tuple[float, float] = (-float("inf"), float("inf")), reward_range: tuple[float, float] = (-float("inf"), float("inf")),
spec: Optional[EnvSpec] = None, spec: EnvSpec | None = None,
): ):
"""Initialize the environment from a FuncEnv.""" """Initialize the environment from a FuncEnv."""
if metadata is None: if metadata is None:
metadata = {} metadata = {"render_mode": []}
self.func_env = func_env self.func_env = func_env
self.observation_space = observation_space self.observation_space = observation_space
self.action_space = action_space self.action_space = action_space
self.metadata = metadata self.metadata = metadata
self.render_mode = render_mode self.render_mode = render_mode
self.reward_range = reward_range self.reward_range = reward_range
self.spec = spec self.spec = spec
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
@@ -52,7 +57,8 @@ class JaxEnv(gym.Env):
self.rng = jrng.PRNGKey(seed) self.rng = jrng.PRNGKey(seed)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: int | None = None, options: dict | None = None):
"""Resets the environment using the seed."""
super().reset(seed=seed) super().reset(seed=seed)
if seed is not None: if seed is not None:
self.rng = jrng.PRNGKey(seed) self.rng = jrng.PRNGKey(seed)
@@ -68,6 +74,7 @@ class JaxEnv(gym.Env):
return obs, info return obs, info
def step(self, action: ActType): def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space: if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high) action = np.clip(action, self.action_space.low, self.action_space.high)
@@ -90,6 +97,7 @@ class JaxEnv(gym.Env):
return observation, float(reward), bool(terminated), False, info return observation, float(reward), bool(terminated), False, info
def render(self): def render(self):
"""Returns the render state if `render_mode` is "rgb_array"."""
if self.render_mode == "rgb_array": if self.render_mode == "rgb_array":
self.render_state, image = self.func_env.render_image( self.render_state, image = self.func_env.render_image(
self.state, self.render_state self.state, self.render_state
@@ -99,15 +107,16 @@ class JaxEnv(gym.Env):
raise NotImplementedError raise NotImplementedError
def close(self): def close(self):
"""Closes the environments and render state if set."""
if self.render_state is not None: if self.render_state is not None:
self.func_env.render_close(self.render_state) self.func_env.render_close(self.render_state)
self.render_state = None self.render_state = None
def _convert_jax_to_numpy(element: Any): def _convert_jax_to_numpy(element: Any):
""" """Convert a jax observation/action to a numpy array, or a numpy-based container.
Convert a jax observation/action to a numpy array, or a numpy-based container.
Currently required because all tests assume that stuff is in numpy arrays, hopefully will be removed soon. Required as all tests assume that data is in numpy arrays, to be removed soon.
""" """
if isinstance(element, jnp.ndarray): if isinstance(element, jnp.ndarray):
return np.asarray(element) return np.asarray(element)


@@ -1,6 +1,7 @@
"""Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments.""" """Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments."""
from __future__ import annotations
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar from typing import Any, Callable, Generic, TypeVar
import numpy as np import numpy as np
@@ -35,7 +36,7 @@ class FuncEnv(
we intend to flesh it out and officially expose it to end users. we intend to flesh it out and officially expose it to end users.
""" """
def __init__(self, options: Optional[Dict[str, Any]] = None): def __init__(self, options: dict[str, Any] | None = None):
"""Initialize the environment constants.""" """Initialize the environment constants."""
self.__dict__.update(options or {}) self.__dict__.update(options or {})
@@ -43,14 +44,14 @@ class FuncEnv(
"""Initial state.""" """Initial state."""
raise NotImplementedError raise NotImplementedError
def observation(self, state: StateType) -> ObsType:
"""Observation."""
raise NotImplementedError
def transition(self, state: StateType, action: ActType, rng: Any) -> StateType: def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
"""Transition.""" """Transition."""
raise NotImplementedError raise NotImplementedError
def observation(self, state: StateType) -> ObsType:
"""Observation."""
raise NotImplementedError
def reward( def reward(
self, state: StateType, action: ActType, next_state: StateType self, state: StateType, action: ActType, next_state: StateType
) -> RewardType: ) -> RewardType:
@@ -83,7 +84,7 @@ class FuncEnv(
def render_image( def render_image(
self, state: StateType, render_state: RenderStateType self, state: StateType, render_state: RenderStateType
) -> Tuple[RenderStateType, np.ndarray]: ) -> tuple[RenderStateType, np.ndarray]:
"""Show the state.""" """Show the state."""
raise NotImplementedError raise NotImplementedError


@@ -0,0 +1,21 @@
"""Experimental Wrappers."""
# isort: skip_file
from typing import TypeVar
ArgType = TypeVar("ArgType")
from gymnasium.experimental.wrappers.lambda_action import LambdaActionV0
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
__all__ = [
"ArgType",
# Lambda Action
"LambdaActionV0",
# Lambda Observation
"LambdaObservationV0",
# Lambda Reward
"LambdaRewardV0",
"ClipRewardV0",
]


@@ -4,7 +4,7 @@ from typing import Any, Callable
import gymnasium as gym import gymnasium as gym
from gymnasium.core import ActType from gymnasium.core import ActType
from gymnasium.dev_wrappers import ArgType from gymnasium.experimental.wrappers import ArgType
class LambdaActionV0(gym.ActionWrapper): class LambdaActionV0(gym.ActionWrapper):


@@ -4,10 +4,10 @@ from typing import Any, Callable
import gymnasium as gym import gymnasium as gym
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium.dev_wrappers import ArgType from gymnasium.experimental.wrappers import ArgType
class LambdaObservationsV0(gym.ObservationWrapper): class LambdaObservationV0(gym.ObservationWrapper):
"""Lambda observation wrapper where a function is provided that is applied to the observation.""" """Lambda observation wrapper where a function is provided that is applied to the observation."""
def __init__( def __init__(


@@ -5,8 +5,8 @@ from typing import Any, Callable, Optional, Union
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.dev_wrappers import ArgType
from gymnasium.error import InvalidBound from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers import ArgType
class LambdaRewardV0(gym.RewardWrapper): class LambdaRewardV0(gym.RewardWrapper):
@@ -14,7 +14,7 @@ class LambdaRewardV0(gym.RewardWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.wrappers import LambdaRewardV0 >>> from gymnasium.experimental.wrappers import LambdaRewardV0
>>> env = gym.make("CartPole-v1") >>> env = gym.make("CartPole-v1")
>>> env = LambdaRewardV0(env, lambda r: 2 * r + 1) >>> env = LambdaRewardV0(env, lambda r: 2 * r + 1)
>>> _ = env.reset() >>> _ = env.reset()
@@ -47,14 +47,14 @@ class LambdaRewardV0(gym.RewardWrapper):
return self.func(reward) return self.func(reward)
class ClipRewardsV0(LambdaRewardV0): class ClipRewardV0(LambdaRewardV0):
"""A wrapper that clips the rewards for an environment between an upper and lower bound. """A wrapper that clips the rewards for an environment between an upper and lower bound.
Example with an upper and lower bound: Example with an upper and lower bound:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.wrappers import ClipRewardsV0 >>> from gymnasium.experimental.wrappers import ClipRewardV0
>>> env = gym.make("CartPole-v1") >>> env = gym.make("CartPole-v1")
>>> env = ClipRewardsV0(env, 0, 0.5) >>> env = ClipRewardV0(env, 0, 0.5)
>>> env.reset() >>> env.reset()
>>> _, rew, _, _, _ = env.step(1) >>> _, rew, _, _, _ = env.step(1)
>>> rew >>> rew


@@ -1,8 +1,4 @@
"""Module of wrapper classes.""" """Module of wrapper classes."""
from gymnasium import error
from gymnasium.dev_wrappers.lambda_action import LambdaActionV0
from gymnasium.dev_wrappers.lambda_observations import LambdaObservationsV0
from gymnasium.dev_wrappers.lambda_reward import ClipRewardsV0, LambdaRewardV0
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from gymnasium.wrappers.autoreset import AutoResetWrapper from gymnasium.wrappers.autoreset import AutoResetWrapper
from gymnasium.wrappers.clip_action import ClipAction from gymnasium.wrappers.clip_action import ClipAction


@@ -2,10 +2,10 @@ from typing import Any, Dict, Optional
import numpy as np import numpy as np
from gymnasium.functional import FuncEnv from gymnasium.experimental import FuncEnv
class TestEnv(FuncEnv): class GenericTestFuncEnv(FuncEnv):
def __init__(self, options: Optional[Dict[str, Any]] = None): def __init__(self, options: Optional[Dict[str, Any]] = None):
super().__init__(options) super().__init__(options)
@@ -26,7 +26,7 @@ class TestEnv(FuncEnv):
def test_api(): def test_api():
env = TestEnv() env = GenericTestFuncEnv()
state = env.initial(None) state = env.initial(None)
obs = env.observation(state) obs = env.observation(state)
assert state.shape == (2,) assert state.shape == (2,)


@@ -4,11 +4,11 @@ import jax.random as jrng
import numpy as np import numpy as np
import pytest import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleF # noqa: E402 from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumF # noqa: E402 from gymnasium.envs.phys2d.pendulum import PendulumFunctional
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF]) @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_normal(env_class): def test_normal(env_class):
env = env_class() env = env_class()
rng = jrng.PRNGKey(0) rng = jrng.PRNGKey(0)
@@ -40,7 +40,7 @@ def test_normal(env_class):
state = next_state state = next_state
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF]) @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_jit(env_class): def test_jit(env_class):
env = env_class() env = env_class()
rng = jrng.PRNGKey(0) rng = jrng.PRNGKey(0)
@@ -73,7 +73,7 @@ def test_jit(env_class):
state = next_state state = next_state
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF]) @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_vmap(env_class): def test_vmap(env_class):
env = env_class() env = env_class()
num_envs = 10 num_envs = 10


@@ -4,8 +4,8 @@ import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium.error import InvalidAction from gymnasium.error import InvalidAction
from gymnasium.experimental.wrappers import LambdaActionV0
from gymnasium.spaces import Box from gymnasium.spaces import Box
from gymnasium.wrappers import LambdaActionV0
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
NUM_ENVS = 3 NUM_ENVS = 3


@@ -3,8 +3,8 @@
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaObservationV0
from gymnasium.spaces import Box from gymnasium.spaces import Box
from gymnasium.wrappers import LambdaObservationsV0
NUM_ENVS = 3 NUM_ENVS = 3
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64) BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
@@ -25,7 +25,7 @@ def test_lambda_observation_v0():
observation_shift = 1 observation_shift = 1
env.reset(seed=SEED) env.reset(seed=SEED)
wrapped_env = LambdaObservationsV0( wrapped_env = LambdaObservationV0(
env, lambda observation: observation + observation_shift env, lambda observation: observation + observation_shift
) )
wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION) wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION)
@@ -48,7 +48,7 @@ def test_lambda_observation_v0_within_vector():
observation_shift = 1 observation_shift = 1
env.reset(seed=SEED) env.reset(seed=SEED)
wrapped_env = LambdaObservationsV0( wrapped_env = LambdaObservationV0(
env, lambda observation: observation + observation_shift env, lambda observation: observation + observation_shift
) )
wrapped_obs, _, _, _, _ = wrapped_env.step( wrapped_obs, _, _, _, _ = wrapped_env.step(


@@ -5,7 +5,7 @@ import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium.error import InvalidBound from gymnasium.error import InvalidBound
from gymnasium.wrappers import ClipRewardsV0, LambdaRewardV0 from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0
ENV_ID = "CartPole-v1" ENV_ID = "CartPole-v1"
DISCRETE_ACTION = 0 DISCRETE_ACTION = 0
@@ -65,7 +65,7 @@ def test_clip_reward(lower_bound, upper_bound, expected_reward):
accordingly to the input args. accordingly to the input args.
""" """
env = gym.make(ENV_ID) env = gym.make(ENV_ID)
env = ClipRewardsV0(env, lower_bound, upper_bound) env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED) env.reset(seed=SEED)
_, rew, _, _, _ = env.step(DISCRETE_ACTION) _, rew, _, _, _ = env.step(DISCRETE_ACTION)
@@ -84,7 +84,7 @@ def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward):
actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)] actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]
env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS) env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
env = ClipRewardsV0(env, lower_bound, upper_bound) env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED) env.reset(seed=SEED)
_, rew, _, _, _ = env.step(actions) _, rew, _, _, _ = env.step(actions)
@@ -106,4 +106,4 @@ def test_clip_reward_incorrect_params(lower_bound, upper_bound):
env = gym.make(ENV_ID) env = gym.make(ENV_ID)
with pytest.raises(InvalidBound): with pytest.raises(InvalidBound):
env = ClipRewardsV0(env, lower_bound, upper_bound) env = ClipRewardV0(env, lower_bound, upper_bound)