mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
Move dev_wrappers and functional to experimental (#159)
This commit is contained in:
@@ -15,7 +15,7 @@ repos:
|
||||
hooks:
|
||||
- id: flake8
|
||||
args:
|
||||
- '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402'
|
||||
- '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402 gymnasium/experimental/wrappers/__init__.py:E402'
|
||||
- --ignore=E203,W503,E741
|
||||
- --max-complexity=30
|
||||
- --max-line-length=456
|
||||
|
214
docs/api/experimental.md
Normal file
214
docs/api/experimental.md
Normal file
@@ -0,0 +1,214 @@
|
||||
---
|
||||
title: Experimental
|
||||
---
|
||||
|
||||
# Experimental
|
||||
|
||||
```{toctree}
|
||||
:hidden:
|
||||
experimental/functional
|
||||
experimental/wrappers
|
||||
experimental/vector
|
||||
experimental/vector_wrappers
|
||||
```
|
||||
|
||||
## Functional Environments
|
||||
|
||||
The gymnasium ``Env`` provides high flexibility for the implementation of individual environments however this can complicate parallelism of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv` where each part of environment has its own function related to it.
|
||||
|
||||
## Wrappers
|
||||
|
||||
Gymnasium already contains a large collection of wrappers, but we believe that the wrappers can be improved to
|
||||
|
||||
* Support arbitrarily complex observation / action spaces. As RL has advanced, action and observation spaces are becoming more complex and the current wrappers were not implemented with these spaces in mind.
|
||||
* Support for numpy, jax and pytorch data. With hardware accelerated environments, i.e. Brax, written in Jax and similar pytorch based programs, numpy is not the only game in town anymore. Therefore, these upgrades will use Jumpy for calling numpy, jax and torch depending on the data.
|
||||
* More wrappers. Projects like Supersuit aimed to bring more wrappers for RL however wrappers can be moved into Gymnasium.
|
||||
* Versioning. Like environments, the implementation details of wrapper can cause changes agent performance. Therefore, we propose adding version numbers with all wrappers.
|
||||
|
||||
* In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorised versions of the wrappers will be provided.
|
||||
|
||||
### Lambda Observation Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
- Vector version
|
||||
- Tree structure
|
||||
* - :class:`wrappers.TransformObservation`
|
||||
- :class:`experimental.wrappers.LambdaObservationV0`
|
||||
- VectorLambdaObservation
|
||||
- No
|
||||
* - :class:`wrappers.FilterObservation`
|
||||
- FilterObservation
|
||||
- VectorFilterObservation (*)
|
||||
- Yes
|
||||
* - :class:`wrappers.FlattenObservation`
|
||||
- FlattenObservation
|
||||
- VectorFlattenObservation (*)
|
||||
- No
|
||||
* - :class:`wrappers.GrayScaleObservation`
|
||||
- GrayscaleObservation
|
||||
- VectorGrayscaleObservation (*)
|
||||
- Yes
|
||||
* - :class:`wrappers.PixelObservationWrapper`
|
||||
- PixelObservation
|
||||
- VectorPixelObservation (*)
|
||||
- No
|
||||
* - :class:`wrappers.ResizeObservation`
|
||||
- ResizeObservation
|
||||
- VectorResizeObservation (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- ReshapeObservation
|
||||
- VectorReshapeObservation (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- RescaleObservation
|
||||
- VectorRescaleObservation (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- DtypeObservation
|
||||
- VectorDtypeObservation (*)
|
||||
- Yes
|
||||
* - :class:`NormalizeObservation`
|
||||
- NormalizeObservation
|
||||
- VectorNormalizeObservation
|
||||
- No
|
||||
* - :class:`TimeAwareObservation`
|
||||
- TimeAwareObservation
|
||||
- VectorTimeAwareObservation
|
||||
- No
|
||||
* - :class:`FrameStack`
|
||||
- FrameStackObservation
|
||||
- VectorFrameStackObservation
|
||||
- No
|
||||
* - Not Implemented
|
||||
- DelayObservation
|
||||
- VectorDelayObservation
|
||||
- No
|
||||
* - :class:`AtariPreprocessing`
|
||||
- AtariPreprocessing
|
||||
- Not Implemented
|
||||
- No
|
||||
```
|
||||
|
||||
### Lambda Action Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
- Vector version
|
||||
- Tree structure
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.LambdaActionV0`
|
||||
- VectorLambdaAction
|
||||
- No
|
||||
* - :class:`wrappers.ClipAction`
|
||||
- ClipAction
|
||||
- VectorClipAction (*)
|
||||
- Yes
|
||||
* - :class:`wrappers.RescaleAction`
|
||||
- RescaleAction
|
||||
- VectorRescaleAction (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- NanAction
|
||||
- VectorNanAction (*)
|
||||
- Yes
|
||||
* - Not Implemented
|
||||
- StickyAction
|
||||
- VectorStickyAction
|
||||
- No
|
||||
```
|
||||
|
||||
### Lambda Reward Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
- Vector version
|
||||
* - :class:`wrappers.TransformReward`
|
||||
- :class:`experimental.wrappers.LambdaRewardV0`
|
||||
- VectorLambdaReward
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.ClipRewardV0`
|
||||
- VectorClipReward (*)
|
||||
* - Not Implemented
|
||||
- RescaleReward
|
||||
- VectorRescaleReward (*)
|
||||
* - :class:`wrappers.NormalizeReward`
|
||||
- NormalizeReward
|
||||
- VectorNormalizeReward
|
||||
```
|
||||
|
||||
### Common Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
- Vector version
|
||||
* - :class:`wrappers.AutoResetWrapper`
|
||||
- AutoReset
|
||||
- VectorAutoReset
|
||||
* - :class:`wrappers.PassiveEnvChecker`
|
||||
- PassiveEnvChecker
|
||||
- VectorPassiveEnvChecker
|
||||
* - :class:`wrappers.OrderEnforcing`
|
||||
- OrderEnforcing
|
||||
- VectorOrderEnforcing (*)
|
||||
* - :class:`wrappers.EnvCompatibility`
|
||||
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
||||
- Not Implemented
|
||||
* - :class:`RecordEpisodeStatistics`
|
||||
- RecordEpisodeStatistics
|
||||
- VectorRecordEpisodeStatistics
|
||||
* - :class:`RenderCollection`
|
||||
- RenderCollection
|
||||
- VectorRenderCollection
|
||||
* - :class:`HumanRendering`
|
||||
- HumanRendering
|
||||
- Not Implemented
|
||||
* - Not Implemented
|
||||
- JaxToNumpy
|
||||
- VectorJaxToNumpy
|
||||
* - Not Implemented
|
||||
- JaxToTorch
|
||||
- VectorJaxToTorch
|
||||
```
|
||||
|
||||
### Vector Only Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
* - :class:`wrappers.VectorListInfo`
|
||||
- VectorListInfo
|
||||
```
|
||||
|
||||
## Vector Environment
|
||||
|
||||
These changes will be made in v0.28
|
||||
|
||||
## Wrappers for Vector Environments
|
||||
|
||||
These changes will be made in v0.28
|
36
docs/api/experimental/functional.md
Normal file
36
docs/api/experimental/functional.md
Normal file
@@ -0,0 +1,36 @@
|
||||
---
|
||||
title: Functional
|
||||
---
|
||||
|
||||
# Functional Environment
|
||||
|
||||
## gymnasium.experimental.FuncEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.FuncEnv
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.initial
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.transition
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.observation
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.initial
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.observation
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.reward
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.terminal
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.state_info
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.step_info
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.transform
|
||||
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.render_image
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.render_init
|
||||
.. autofunction:: gymnasium.experimental.FuncEnv.render_close
|
||||
```
|
||||
|
||||
## gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
|
||||
|
||||
```{eval-rst}
|
||||
... autoclass:: gymasnium.experimental.func2env.FunctionalJaxCompatibilityEnv
|
||||
```
|
15
docs/api/experimental/vector.md
Normal file
15
docs/api/experimental/vector.md
Normal file
@@ -0,0 +1,15 @@
|
||||
---
|
||||
title: Vector
|
||||
---
|
||||
|
||||
# Vectorizing Environment
|
||||
|
||||
## gymnasium.experimental.VectorEnv
|
||||
|
||||
## gymnasium.experimental.vector.AsyncVectorEnv
|
||||
|
||||
## gymnasium.experimental.vector.SyncVectorEnv
|
||||
|
||||
## Custom Vector environments
|
||||
|
||||
## EnvPool
|
15
docs/api/experimental/vector_wrappers.md
Normal file
15
docs/api/experimental/vector_wrappers.md
Normal file
@@ -0,0 +1,15 @@
|
||||
---
|
||||
title: Vector Wrappers
|
||||
---
|
||||
|
||||
# Vector Environment Wrappers
|
||||
|
||||
## Vector Lambda Observation Wrappers
|
||||
|
||||
## Vector Lambda Action Wrappers
|
||||
|
||||
## Vector Lambda Reward Wrappers
|
||||
|
||||
## Vector Common Wrappers
|
||||
|
||||
## Vector Only Wrappers
|
26
docs/api/experimental/wrappers.md
Normal file
26
docs/api/experimental/wrappers.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# Wrappers
|
||||
|
||||
## Lambda Observation Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
|
||||
```
|
||||
|
||||
## Lambda Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
|
||||
```
|
||||
|
||||
## Lambda Reward Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||
```
|
||||
|
||||
## Common Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
|
||||
```
|
@@ -48,6 +48,7 @@ api/spaces
|
||||
api/wrappers
|
||||
api/vector
|
||||
api/utils
|
||||
api/experimental
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
|
@@ -10,10 +10,8 @@ from gymnasium.core import (
|
||||
)
|
||||
from gymnasium.spaces.space import Space
|
||||
from gymnasium.envs.registration import make, spec, register, registry, pprint_registry
|
||||
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
|
||||
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__all__ = [
|
||||
# core classes
|
||||
@@ -37,6 +35,7 @@ __all__ = [
|
||||
"wrappers",
|
||||
"error",
|
||||
"logger",
|
||||
"experimental",
|
||||
]
|
||||
__version__ = "0.26.3"
|
||||
|
||||
@@ -45,6 +44,9 @@ __version__ = "0.26.3"
|
||||
# pygame
|
||||
# DSP is far more benign (and should probably be the default in SDL anyways)
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
if sys.platform.startswith("linux"):
|
||||
os.environ["SDL_AUDIODRIVER"] = "dsp"
|
||||
|
||||
|
@@ -1,4 +0,0 @@
|
||||
"""Root __init__ of the gym dev_wrappers."""
|
||||
from typing import TypeVar
|
||||
|
||||
ArgType = TypeVar("ArgType")
|
@@ -1,2 +1,2 @@
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleF
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumF
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
||||
|
@@ -10,41 +10,42 @@ import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.phys2d.conversion import JaxEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.func_jax_env import FunctionalJaxEnv
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle
|
||||
|
||||
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
|
||||
|
||||
|
||||
class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
|
||||
class CartPoleFunctional(
|
||||
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
|
||||
):
|
||||
"""Cartpole but in jax and functional.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
key = jax.random.PRNGKey(0)
|
||||
>>> import jax
|
||||
>>> import jax.numpy as jnp
|
||||
|
||||
env = CartPole({"x_init": 0.5})
|
||||
state = env.initial(key)
|
||||
print(state)
|
||||
print(env.step(state, 0))
|
||||
>>> key = jax.random.PRNGKey(0)
|
||||
|
||||
env.transform(jax.jit)
|
||||
>>> env = CartPole({"x_init": 0.5})
|
||||
>>> state = env.initial(key)
|
||||
>>> print(state)
|
||||
>>> print(env.step(state, 0))
|
||||
|
||||
state = env.initial(key)
|
||||
print(state)
|
||||
print(env.step(state, 0))
|
||||
>>> env.transform(jax.jit)
|
||||
|
||||
vkey = jax.random.split(key, 10)
|
||||
env.transform(jax.vmap)
|
||||
vstate = env.initial(vkey)
|
||||
print(vstate)
|
||||
print(env.step(vstate, jnp.array([0 for _ in range(10)])))
|
||||
```
|
||||
>>> state = env.initial(key)
|
||||
>>> print(state)
|
||||
>>> print(env.step(state, 0))
|
||||
|
||||
>>> vkey = jax.random.split(key, 10)
|
||||
>>> env.transform(jax.vmap)
|
||||
>>> vstate = env.initial(vkey)
|
||||
>>> print(vstate)
|
||||
>>> print(env.step(vstate, jnp.array([0 for _ in range(10)])))
|
||||
"""
|
||||
|
||||
gravity = 9.8
|
||||
@@ -232,13 +233,13 @@ class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateT
|
||||
pygame.quit()
|
||||
|
||||
|
||||
class CartPoleJaxEnv(JaxEnv, EzPickle):
|
||||
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
||||
|
||||
def __init__(self, render_mode: Optional[str] = None, **kwargs):
|
||||
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
|
||||
env = CartPoleF(**kwargs)
|
||||
env = CartPoleFunctional(**kwargs)
|
||||
env.transform(jax.jit)
|
||||
action_space = env.action_space
|
||||
observation_space = env.observation_space
|
||||
|
@@ -10,15 +10,17 @@ import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.phys2d.conversion import JaxEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.func_jax_env import FunctionalJaxEnv
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle
|
||||
|
||||
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
|
||||
|
||||
|
||||
class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
|
||||
class PendulumFunctional(
|
||||
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
|
||||
):
|
||||
"""Pendulum but in jax and functional."""
|
||||
|
||||
max_speed = 8
|
||||
@@ -180,13 +182,13 @@ class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateT
|
||||
pygame.quit()
|
||||
|
||||
|
||||
class PendulumJaxEnv(JaxEnv, EzPickle):
|
||||
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
|
||||
|
||||
def __init__(self, render_mode: Optional[str] = None, **kwargs):
|
||||
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
|
||||
env = PendulumF(**kwargs)
|
||||
env = PendulumFunctional(**kwargs)
|
||||
env.transform(jax.jit)
|
||||
action_space = env.action_space
|
||||
observation_space = env.observation_space
|
||||
|
12
gymnasium/experimental/__init__.py
Normal file
12
gymnasium/experimental/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Root __init__ of the gym dev_wrappers."""
|
||||
|
||||
|
||||
from gymnasium.experimental.functional import FuncEnv
|
||||
|
||||
__all__ = [
|
||||
# Functional
|
||||
"FuncEnv",
|
||||
"functional",
|
||||
# Wrapper
|
||||
"wrappers",
|
||||
]
|
@@ -1,4 +1,7 @@
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
"""Functional to Environment compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
@@ -7,14 +10,12 @@ import numpy as np
|
||||
import gymnasium as gym
|
||||
from gymnasium import Space
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import seeding
|
||||
|
||||
|
||||
class JaxEnv(gym.Env):
|
||||
"""
|
||||
A conversion layer for numpy-based environments.
|
||||
"""
|
||||
class FunctionalJaxEnv(gym.Env):
|
||||
"""A conversion layer for jax-based environments."""
|
||||
|
||||
state: StateType
|
||||
rng: jrng.PRNGKey
|
||||
@@ -24,20 +25,24 @@ class JaxEnv(gym.Env):
|
||||
func_env: FuncEnv,
|
||||
observation_space: Space,
|
||||
action_space: Space,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
render_mode: Optional[str] = None,
|
||||
reward_range: Tuple[float, float] = (-float("inf"), float("inf")),
|
||||
spec: Optional[EnvSpec] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
render_mode: str | None = None,
|
||||
reward_range: tuple[float, float] = (-float("inf"), float("inf")),
|
||||
spec: EnvSpec | None = None,
|
||||
):
|
||||
"""Initialize the environment from a FuncEnv."""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata = {"render_mode": []}
|
||||
|
||||
self.func_env = func_env
|
||||
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
self.metadata = metadata
|
||||
self.render_mode = render_mode
|
||||
self.reward_range = reward_range
|
||||
|
||||
self.spec = spec
|
||||
|
||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
||||
@@ -52,7 +57,8 @@ class JaxEnv(gym.Env):
|
||||
|
||||
self.rng = jrng.PRNGKey(seed)
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(self, *, seed: int | None = None, options: dict | None = None):
|
||||
"""Resets the environment using the seed."""
|
||||
super().reset(seed=seed)
|
||||
if seed is not None:
|
||||
self.rng = jrng.PRNGKey(seed)
|
||||
@@ -68,6 +74,7 @@ class JaxEnv(gym.Env):
|
||||
return obs, info
|
||||
|
||||
def step(self, action: ActType):
|
||||
"""Steps through the environment using the action."""
|
||||
if self._is_box_action_space:
|
||||
assert isinstance(self.action_space, gym.spaces.Box) # For typing
|
||||
action = np.clip(action, self.action_space.low, self.action_space.high)
|
||||
@@ -90,6 +97,7 @@ class JaxEnv(gym.Env):
|
||||
return observation, float(reward), bool(terminated), False, info
|
||||
|
||||
def render(self):
|
||||
"""Returns the render state if `render_mode` is "rgb_array"."""
|
||||
if self.render_mode == "rgb_array":
|
||||
self.render_state, image = self.func_env.render_image(
|
||||
self.state, self.render_state
|
||||
@@ -99,15 +107,16 @@ class JaxEnv(gym.Env):
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
"""Closes the environments and render state if set."""
|
||||
if self.render_state is not None:
|
||||
self.func_env.render_close(self.render_state)
|
||||
self.render_state = None
|
||||
|
||||
|
||||
def _convert_jax_to_numpy(element: Any):
|
||||
"""
|
||||
Convert a jax observation/action to a numpy array, or a numpy-based container.
|
||||
Currently required because all tests assume that stuff is in numpy arrays, hopefully will be removed soon.
|
||||
"""Convert a jax observation/action to a numpy array, or a numpy-based container.
|
||||
|
||||
Requires as all tests assume that data is in numpy arrays, to be removed soon.
|
||||
"""
|
||||
if isinstance(element, jnp.ndarray):
|
||||
return np.asarray(element)
|
@@ -1,6 +1,7 @@
|
||||
"""Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -35,7 +36,7 @@ class FuncEnv(
|
||||
we intend to flesh it out and officially expose it to end users.
|
||||
"""
|
||||
|
||||
def __init__(self, options: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, options: dict[str, Any] | None = None):
|
||||
"""Initialize the environment constants."""
|
||||
self.__dict__.update(options or {})
|
||||
|
||||
@@ -43,14 +44,14 @@ class FuncEnv(
|
||||
"""Initial state."""
|
||||
raise NotImplementedError
|
||||
|
||||
def observation(self, state: StateType) -> ObsType:
|
||||
"""Observation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
|
||||
"""Transition."""
|
||||
raise NotImplementedError
|
||||
|
||||
def observation(self, state: StateType) -> ObsType:
|
||||
"""Observation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def reward(
|
||||
self, state: StateType, action: ActType, next_state: StateType
|
||||
) -> RewardType:
|
||||
@@ -83,7 +84,7 @@ class FuncEnv(
|
||||
|
||||
def render_image(
|
||||
self, state: StateType, render_state: RenderStateType
|
||||
) -> Tuple[RenderStateType, np.ndarray]:
|
||||
) -> tuple[RenderStateType, np.ndarray]:
|
||||
"""Show the state."""
|
||||
raise NotImplementedError
|
||||
|
21
gymnasium/experimental/wrappers/__init__.py
Normal file
21
gymnasium/experimental/wrappers/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Experimental Wrappers."""
|
||||
# isort: skip_file
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
ArgType = TypeVar("ArgType")
|
||||
|
||||
from gymnasium.experimental.wrappers.lambda_action import LambdaActionV0
|
||||
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||
|
||||
__all__ = [
|
||||
"ArgType",
|
||||
# Lambda Action
|
||||
"LambdaActionV0",
|
||||
# Lambda Observation
|
||||
"LambdaObservationV0",
|
||||
# Lambda Reward
|
||||
"LambdaRewardV0",
|
||||
"ClipRewardV0",
|
||||
]
|
@@ -4,7 +4,7 @@ from typing import Any, Callable
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType
|
||||
from gymnasium.dev_wrappers import ArgType
|
||||
from gymnasium.experimental.wrappers import ArgType
|
||||
|
||||
|
||||
class LambdaActionV0(gym.ActionWrapper):
|
@@ -4,10 +4,10 @@ from typing import Any, Callable
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ObsType
|
||||
from gymnasium.dev_wrappers import ArgType
|
||||
from gymnasium.experimental.wrappers import ArgType
|
||||
|
||||
|
||||
class LambdaObservationsV0(gym.ObservationWrapper):
|
||||
class LambdaObservationV0(gym.ObservationWrapper):
|
||||
"""Lambda observation wrapper where a function is provided that is applied to the observation."""
|
||||
|
||||
def __init__(
|
@@ -5,8 +5,8 @@ from typing import Any, Callable, Optional, Union
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.dev_wrappers import ArgType
|
||||
from gymnasium.error import InvalidBound
|
||||
from gymnasium.experimental.wrappers import ArgType
|
||||
|
||||
|
||||
class LambdaRewardV0(gym.RewardWrapper):
|
||||
@@ -14,7 +14,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import LambdaRewardV0
|
||||
>>> from gymnasium.experimental.wrappers import LambdaRewardV0
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = LambdaRewardV0(env, lambda r: 2 * r + 1)
|
||||
>>> _ = env.reset()
|
||||
@@ -47,14 +47,14 @@ class LambdaRewardV0(gym.RewardWrapper):
|
||||
return self.func(reward)
|
||||
|
||||
|
||||
class ClipRewardsV0(LambdaRewardV0):
|
||||
class ClipRewardV0(LambdaRewardV0):
|
||||
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
|
||||
|
||||
Example with an upper and lower bound:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import ClipRewardsV0
|
||||
>>> from gymnasium.experimental.wrappers import ClipRewardV0
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = ClipRewardsV0(env, 0, 0.5)
|
||||
>>> env = ClipRewardV0(env, 0, 0.5)
|
||||
>>> env.reset()
|
||||
>>> _, rew, _, _, _ = env.step(1)
|
||||
>>> rew
|
@@ -1,8 +1,4 @@
|
||||
"""Module of wrapper classes."""
|
||||
from gymnasium import error
|
||||
from gymnasium.dev_wrappers.lambda_action import LambdaActionV0
|
||||
from gymnasium.dev_wrappers.lambda_observations import LambdaObservationsV0
|
||||
from gymnasium.dev_wrappers.lambda_reward import ClipRewardsV0, LambdaRewardV0
|
||||
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
|
||||
from gymnasium.wrappers.autoreset import AutoResetWrapper
|
||||
from gymnasium.wrappers.clip_action import ClipAction
|
||||
|
@@ -2,10 +2,10 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.functional import FuncEnv
|
||||
from gymnasium.experimental import FuncEnv
|
||||
|
||||
|
||||
class TestEnv(FuncEnv):
|
||||
class GenericTestFuncEnv(FuncEnv):
|
||||
def __init__(self, options: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(options)
|
||||
|
||||
@@ -26,7 +26,7 @@ class TestEnv(FuncEnv):
|
||||
|
||||
|
||||
def test_api():
|
||||
env = TestEnv()
|
||||
env = GenericTestFuncEnv()
|
||||
state = env.initial(None)
|
||||
obs = env.observation(state)
|
||||
assert state.shape == (2,)
|
@@ -4,11 +4,11 @@ import jax.random as jrng
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleF # noqa: E402
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumF # noqa: E402
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
|
||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||
def test_normal(env_class):
|
||||
env = env_class()
|
||||
rng = jrng.PRNGKey(0)
|
||||
@@ -40,7 +40,7 @@ def test_normal(env_class):
|
||||
state = next_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
|
||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||
def test_jit(env_class):
|
||||
env = env_class()
|
||||
rng = jrng.PRNGKey(0)
|
||||
@@ -73,7 +73,7 @@ def test_jit(env_class):
|
||||
state = next_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
|
||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||
def test_vmap(env_class):
|
||||
env = env_class()
|
||||
num_envs = 10
|
@@ -4,8 +4,8 @@ import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.error import InvalidAction
|
||||
from gymnasium.experimental.wrappers import LambdaActionV0
|
||||
from gymnasium.spaces import Box
|
||||
from gymnasium.wrappers import LambdaActionV0
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
NUM_ENVS = 3
|
@@ -3,8 +3,8 @@
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import LambdaObservationV0
|
||||
from gymnasium.spaces import Box
|
||||
from gymnasium.wrappers import LambdaObservationsV0
|
||||
|
||||
NUM_ENVS = 3
|
||||
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
|
||||
@@ -25,7 +25,7 @@ def test_lambda_observation_v0():
|
||||
observation_shift = 1
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = LambdaObservationsV0(
|
||||
wrapped_env = LambdaObservationV0(
|
||||
env, lambda observation: observation + observation_shift
|
||||
)
|
||||
wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION)
|
||||
@@ -48,7 +48,7 @@ def test_lambda_observation_v0_within_vector():
|
||||
observation_shift = 1
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = LambdaObservationsV0(
|
||||
wrapped_env = LambdaObservationV0(
|
||||
env, lambda observation: observation + observation_shift
|
||||
)
|
||||
wrapped_obs, _, _, _, _ = wrapped_env.step(
|
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.error import InvalidBound
|
||||
from gymnasium.wrappers import ClipRewardsV0, LambdaRewardV0
|
||||
from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0
|
||||
|
||||
ENV_ID = "CartPole-v1"
|
||||
DISCRETE_ACTION = 0
|
||||
@@ -65,7 +65,7 @@ def test_clip_reward(lower_bound, upper_bound, expected_reward):
|
||||
accordingly to the input args.
|
||||
"""
|
||||
env = gym.make(ENV_ID)
|
||||
env = ClipRewardsV0(env, lower_bound, upper_bound)
|
||||
env = ClipRewardV0(env, lower_bound, upper_bound)
|
||||
env.reset(seed=SEED)
|
||||
_, rew, _, _, _ = env.step(DISCRETE_ACTION)
|
||||
|
||||
@@ -84,7 +84,7 @@ def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward):
|
||||
actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]
|
||||
|
||||
env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
|
||||
env = ClipRewardsV0(env, lower_bound, upper_bound)
|
||||
env = ClipRewardV0(env, lower_bound, upper_bound)
|
||||
env.reset(seed=SEED)
|
||||
|
||||
_, rew, _, _, _ = env.step(actions)
|
||||
@@ -106,4 +106,4 @@ def test_clip_reward_incorrect_params(lower_bound, upper_bound):
|
||||
env = gym.make(ENV_ID)
|
||||
|
||||
with pytest.raises(InvalidBound):
|
||||
env = ClipRewardsV0(env, lower_bound, upper_bound)
|
||||
env = ClipRewardV0(env, lower_bound, upper_bound)
|
Reference in New Issue
Block a user