Move dev_wrappers and functional to experimental (#159)

This commit is contained in:
Mark Towers
2022-11-29 23:37:53 +00:00
committed by GitHub
parent ae75ad2e44
commit f24fa1426c
30 changed files with 435 additions and 88 deletions

View File

@@ -15,7 +15,7 @@ repos:
hooks:
- id: flake8
args:
- '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402'
- '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402 gymnasium/experimental/wrappers/__init__.py:E402'
- --ignore=E203,W503,E741
- --max-complexity=30
- --max-line-length=456

docs/api/experimental.md Normal file
View File

@@ -0,0 +1,214 @@
---
title: Experimental
---
# Experimental
```{toctree}
:hidden:
experimental/functional
experimental/wrappers
experimental/vector
experimental/vector_wrappers
```
## Functional Environments
The Gymnasium ``Env`` provides high flexibility for implementing individual environments; however, this can complicate parallelising environments. Therefore, we propose :class:`gymnasium.experimental.FuncEnv`, where each part of the environment has its own function.
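The functional style can be sketched in a few lines; the following is a toy illustration of the idea, not the actual ``FuncEnv`` API:

```python
import numpy as np

# Toy environment in the functional style: every stage of the step loop
# (initial state, transition, observation, reward, termination) is a
# pure function of the state. Illustrative only; the real
# gymnasium.experimental.FuncEnv API may differ.
class ToyFuncEnv:
    def initial(self, rng):
        # Sample a starting state from the rng.
        return rng.uniform(-0.05, 0.05, size=(2,))

    def transition(self, state, action, rng):
        # Move the position by a fixed step depending on the binary action.
        return state + np.array([0.1, 0.0]) * (2 * action - 1)

    def observation(self, state):
        return state.copy()

    def reward(self, state, action, next_state):
        return 1.0

    def terminal(self, state):
        return bool(abs(state[0]) > 1.0)

env = ToyFuncEnv()
rng = np.random.default_rng(0)
state = env.initial(rng)
next_state = env.transition(state, 1, rng)
obs = env.observation(next_state)
rew = env.reward(state, 1, next_state)
done = env.terminal(next_state)
```

Because every stage is a pure function, the whole environment can be batched or compiled by transforms such as ``jax.vmap`` or ``jax.jit``.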
## Wrappers
Gymnasium already contains a large collection of wrappers, but we believe that they can be improved to:
* Support arbitrarily complex observation / action spaces. As RL has advanced, action and observation spaces have become more complex, and the current wrappers were not implemented with these spaces in mind.
* Support numpy, jax and pytorch data. With hardware-accelerated environments, e.g. Brax, written in Jax, and similar pytorch-based programs, numpy is no longer the only game in town. Therefore, these upgrades will use Jumpy to call numpy, jax or torch functions depending on the data.
* Provide more wrappers. Projects like Supersuit aimed to bring more wrappers for RL; however, many of these wrappers can be moved into Gymnasium.
* Add versioning. Like environments, the implementation details of a wrapper can change agent performance. Therefore, we propose adding version numbers to all wrappers.
* In v0.28, we aim to rewrite the VectorEnv so that it does not inherit from Env; as a result, new vectorised versions of the wrappers will be provided.
### Lambda Observation Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
- Tree structure
* - :class:`wrappers.TransformObservation`
- :class:`experimental.wrappers.LambdaObservationV0`
- VectorLambdaObservation
- No
* - :class:`wrappers.FilterObservation`
- FilterObservation
- VectorFilterObservation (*)
- Yes
* - :class:`wrappers.FlattenObservation`
- FlattenObservation
- VectorFlattenObservation (*)
- No
* - :class:`wrappers.GrayScaleObservation`
- GrayscaleObservation
- VectorGrayscaleObservation (*)
- Yes
* - :class:`wrappers.PixelObservationWrapper`
- PixelObservation
- VectorPixelObservation (*)
- No
* - :class:`wrappers.ResizeObservation`
- ResizeObservation
- VectorResizeObservation (*)
- Yes
* - Not Implemented
- ReshapeObservation
- VectorReshapeObservation (*)
- Yes
* - Not Implemented
- RescaleObservation
- VectorRescaleObservation (*)
- Yes
* - Not Implemented
- DtypeObservation
- VectorDtypeObservation (*)
- Yes
* - :class:`wrappers.NormalizeObservation`
- NormalizeObservation
- VectorNormalizeObservation
- No
* - :class:`wrappers.TimeAwareObservation`
- TimeAwareObservation
- VectorTimeAwareObservation
- No
* - :class:`wrappers.FrameStack`
- FrameStackObservation
- VectorFrameStackObservation
- No
* - Not Implemented
- DelayObservation
- VectorDelayObservation
- No
* - :class:`wrappers.AtariPreprocessing`
- AtariPreprocessing
- Not Implemented
- No
```
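All of the lambda observation wrappers in the table share one shape: a function is applied to every observation. A hypothetical minimal stand-in (the wrapper and stub environment are illustrative, not the real classes):

```python
import numpy as np

# Hypothetical minimal stand-in for the lambda-observation pattern: a
# wrapper applies a user-supplied function to every observation. Not
# the real gymnasium.experimental.wrappers classes.
class LambdaObservation:
    def __init__(self, env, func):
        self.env = env
        self.func = func

    def reset(self):
        obs, info = self.env.reset()
        return self.func(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self.func(obs), reward, terminated, truncated, info

# Stub environment used only for demonstration.
class StubEnv:
    def reset(self):
        return np.zeros(3), {}

    def step(self, action):
        return np.ones(3), 0.0, False, False, {}

wrapped = LambdaObservation(StubEnv(), lambda obs: obs + 1)
reset_obs, _ = wrapped.reset()
step_obs, *_ = wrapped.step(0)
```

Wrappers such as a grayscale or resize observation then differ only in the function they pass and the observation space they declare.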
### Lambda Action Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
- Tree structure
* - Not Implemented
- :class:`experimental.wrappers.LambdaActionV0`
- VectorLambdaAction
- No
* - :class:`wrappers.ClipAction`
- ClipAction
- VectorClipAction (*)
- Yes
* - :class:`wrappers.RescaleAction`
- RescaleAction
- VectorRescaleAction (*)
- Yes
* - Not Implemented
- NanAction
- VectorNanAction (*)
- Yes
* - Not Implemented
- StickyAction
- VectorStickyAction
- No
```
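The action wrappers follow the same shape on the action side; for instance, a clip-action wrapper reduces to a lambda-action wrapper with ``np.clip``. A hypothetical sketch (names and stub env are illustrative, not the real classes):

```python
import numpy as np

# Hypothetical sketch of the lambda-action pattern: the action is
# transformed before being passed to the inner environment. Not the
# real gymnasium.experimental.wrappers classes.
class LambdaAction:
    def __init__(self, env, func):
        self.env = env
        self.func = func

    def step(self, action):
        # Transform the action, then delegate to the inner env.
        return self.env.step(self.func(action))

# Stub env that echoes the action it received, for demonstration.
class StubEnv:
    low, high = -1.0, 1.0

    def step(self, action):
        return action, 0.0, False, False, {}

wrapped = LambdaAction(StubEnv(), lambda a: np.clip(a, StubEnv.low, StubEnv.high))
echoed, *_ = wrapped.step(np.array([2.0, -3.0]))
```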
### Lambda Reward Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
* - :class:`wrappers.TransformReward`
- :class:`experimental.wrappers.LambdaRewardV0`
- VectorLambdaReward
* - Not Implemented
- :class:`experimental.wrappers.ClipRewardV0`
- VectorClipReward (*)
* - Not Implemented
- RescaleReward
- VectorRescaleReward (*)
* - :class:`wrappers.NormalizeReward`
- NormalizeReward
- VectorNormalizeReward
```
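Likewise for rewards: a clip-reward wrapper is just a lambda-reward wrapper with a clipping function. A hypothetical sketch (stub names, not the real API):

```python
import numpy as np

# Hypothetical sketch of the lambda-reward pattern: ClipReward is just
# LambdaReward with a clipping function. Not the real
# gymnasium.experimental.wrappers classes.
class LambdaReward:
    def __init__(self, env, func):
        self.env = env
        self.func = func

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs, self.func(reward), terminated, truncated, info

class ClipReward(LambdaReward):
    def __init__(self, env, min_reward, max_reward):
        if min_reward > max_reward:
            raise ValueError("min_reward must not exceed max_reward")
        super().__init__(
            env, lambda r: float(np.clip(r, min_reward, max_reward))
        )

# Stub env returning a fixed reward of 1.0, for demonstration.
class StubEnv:
    def step(self, action):
        return None, 1.0, False, False, {}

wrapped = ClipReward(StubEnv(), 0.0, 0.5)
_, rew, *_ = wrapped.step(0)
```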
### Common Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
- Vector version
* - :class:`wrappers.AutoResetWrapper`
- AutoReset
- VectorAutoReset
* - :class:`wrappers.PassiveEnvChecker`
- PassiveEnvChecker
- VectorPassiveEnvChecker
* - :class:`wrappers.OrderEnforcing`
- OrderEnforcing
- VectorOrderEnforcing (*)
* - :class:`wrappers.EnvCompatibility`
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
- Not Implemented
* - :class:`wrappers.RecordEpisodeStatistics`
- RecordEpisodeStatistics
- VectorRecordEpisodeStatistics
* - :class:`wrappers.RenderCollection`
- RenderCollection
- VectorRenderCollection
* - :class:`wrappers.HumanRendering`
- HumanRendering
- Not Implemented
* - Not Implemented
- JaxToNumpy
- VectorJaxToNumpy
* - Not Implemented
- JaxToTorch
- VectorJaxToTorch
```
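The JaxToNumpy / JaxToTorch rows rest on one idea: recursively convert the array leaves of a nested observation structure between frameworks. A hypothetical sketch using plain Python containers and numpy (the real wrappers operate on jax/torch arrays):

```python
import numpy as np

# Hypothetical sketch of the data-conversion idea behind JaxToNumpy:
# walk a nested structure and convert every leaf to a numpy array,
# preserving the container types.
def to_numpy(element):
    if isinstance(element, dict):
        return {k: to_numpy(v) for k, v in element.items()}
    if isinstance(element, (list, tuple)):
        return type(element)(to_numpy(v) for v in element)
    return np.asarray(element)

converted = to_numpy({"obs": (1.0, 2.0), "reward": 3.0})
```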
### Vector Only Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium
.. list-table::
:header-rows: 1
* - Old name
- New name
* - :class:`wrappers.VectorListInfo`
- VectorListInfo
```
## Vector Environment
These changes will be made in v0.28
## Wrappers for Vector Environments
These changes will be made in v0.28

View File

@@ -0,0 +1,36 @@
---
title: Functional
---
# Functional Environment
## gymnasium.experimental.FuncEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.FuncEnv
.. autofunction:: gymnasium.experimental.FuncEnv.initial
.. autofunction:: gymnasium.experimental.FuncEnv.transition
.. autofunction:: gymnasium.experimental.FuncEnv.observation
.. autofunction:: gymnasium.experimental.FuncEnv.reward
.. autofunction:: gymnasium.experimental.FuncEnv.terminal
.. autofunction:: gymnasium.experimental.FuncEnv.state_info
.. autofunction:: gymnasium.experimental.FuncEnv.step_info
.. autofunction:: gymnasium.experimental.FuncEnv.transform
.. autofunction:: gymnasium.experimental.FuncEnv.render_image
.. autofunction:: gymnasium.experimental.FuncEnv.render_init
.. autofunction:: gymnasium.experimental.FuncEnv.render_close
```
## gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
```

View File

@@ -0,0 +1,15 @@
---
title: Vector
---
# Vectorizing Environment
## gymnasium.experimental.VectorEnv
## gymnasium.experimental.vector.AsyncVectorEnv
## gymnasium.experimental.vector.SyncVectorEnv
## Custom Vector environments
## EnvPool

View File

@@ -0,0 +1,15 @@
---
title: Vector Wrappers
---
# Vector Environment Wrappers
## Vector Lambda Observation Wrappers
## Vector Lambda Action Wrappers
## Vector Lambda Reward Wrappers
## Vector Common Wrappers
## Vector Only Wrappers

View File

@@ -0,0 +1,26 @@
# Wrappers
## Lambda Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
```
## Lambda Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
```
## Lambda Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
```
## Common Wrappers
```{eval-rst}
```

View File

@@ -48,6 +48,7 @@ api/spaces
api/wrappers
api/vector
api/utils
api/experimental
```
```{toctree}

View File

@@ -10,10 +10,8 @@ from gymnasium.core import (
)
from gymnasium.spaces.space import Space
from gymnasium.envs.registration import make, spec, register, registry, pprint_registry
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental
import os
import sys
__all__ = [
# core classes
@@ -37,6 +35,7 @@ __all__ = [
"wrappers",
"error",
"logger",
"experimental",
]
__version__ = "0.26.3"
@@ -45,6 +44,9 @@ __version__ = "0.26.3"
# pygame
# DSP is far more benign (and should probably be the default in SDL anyways)
import os
import sys
if sys.platform.startswith("linux"):
os.environ["SDL_AUDIODRIVER"] = "dsp"

View File

@@ -1,4 +0,0 @@
"""Root __init__ of the gym dev_wrappers."""
from typing import TypeVar
ArgType = TypeVar("ArgType")

View File

@@ -1,2 +1,2 @@
from gymnasium.envs.phys2d.cartpole import CartPoleF
from gymnasium.envs.phys2d.pendulum import PendulumF
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumFunctional

View File

@@ -10,41 +10,42 @@ import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.envs.phys2d.conversion import JaxEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.func_jax_env import FunctionalJaxEnv
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
class CartPoleFunctional(
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
):
"""Cartpole but in jax and functional.
Example usage:
```
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
env = CartPole({"x_init": 0.5})
state = env.initial(key)
print(state)
print(env.step(state, 0))
env.transform(jax.jit)
state = env.initial(key)
print(state)
print(env.step(state, 0))
vkey = jax.random.split(key, 10)
env.transform(jax.vmap)
vstate = env.initial(vkey)
print(vstate)
print(env.step(vstate, jnp.array([0 for _ in range(10)])))
```
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.PRNGKey(0)
>>> env = CartPole({"x_init": 0.5})
>>> state = env.initial(key)
>>> print(state)
>>> print(env.step(state, 0))
>>> env.transform(jax.jit)
>>> state = env.initial(key)
>>> print(state)
>>> print(env.step(state, 0))
>>> vkey = jax.random.split(key, 10)
>>> env.transform(jax.vmap)
>>> vstate = env.initial(vkey)
>>> print(vstate)
>>> print(env.step(vstate, jnp.array([0 for _ in range(10)])))
"""
gravity = 9.8
@@ -232,13 +233,13 @@ class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateT
pygame.quit()
class CartPoleJaxEnv(JaxEnv, EzPickle):
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
def __init__(self, render_mode: Optional[str] = None, **kwargs):
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = CartPoleF(**kwargs)
env = CartPoleFunctional(**kwargs)
env.transform(jax.jit)
action_space = env.action_space
observation_space = env.observation_space

View File

@@ -10,15 +10,17 @@ import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.envs.phys2d.conversion import JaxEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.func_jax_env import FunctionalJaxEnv
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
class PendulumFunctional(
FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
):
"""Pendulum but in jax and functional."""
max_speed = 8
@@ -180,13 +182,13 @@ class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateT
pygame.quit()
class PendulumJaxEnv(JaxEnv, EzPickle):
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
def __init__(self, render_mode: Optional[str] = None, **kwargs):
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = PendulumF(**kwargs)
env = PendulumFunctional(**kwargs)
env.transform(jax.jit)
action_space = env.action_space
observation_space = env.observation_space

View File

@@ -0,0 +1,12 @@
"""Root __init__ of the gymnasium experimental module."""
from gymnasium.experimental.functional import FuncEnv
__all__ = [
# Functional
"FuncEnv",
"functional",
# Wrapper
"wrappers",
]

View File

@@ -1,4 +1,7 @@
from typing import Any, Dict, Optional, Tuple
"""Functional to Environment compatibility."""
from __future__ import annotations
from typing import Any
import jax.numpy as jnp
import jax.random as jrng
@@ -7,14 +10,12 @@ import numpy as np
import gymnasium as gym
from gymnasium import Space
from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
class JaxEnv(gym.Env):
"""
A conversion layer for numpy-based environments.
"""
class FunctionalJaxEnv(gym.Env):
"""A conversion layer for jax-based environments."""
state: StateType
rng: jrng.PRNGKey
@@ -24,20 +25,24 @@ class JaxEnv(gym.Env):
func_env: FuncEnv,
observation_space: Space,
action_space: Space,
metadata: Optional[Dict[str, Any]] = None,
render_mode: Optional[str] = None,
reward_range: Tuple[float, float] = (-float("inf"), float("inf")),
spec: Optional[EnvSpec] = None,
metadata: dict[str, Any] | None = None,
render_mode: str | None = None,
reward_range: tuple[float, float] = (-float("inf"), float("inf")),
spec: EnvSpec | None = None,
):
"""Initialize the environment from a FuncEnv."""
if metadata is None:
metadata = {}
metadata = {"render_modes": []}
self.func_env = func_env
self.observation_space = observation_space
self.action_space = action_space
self.metadata = metadata
self.render_mode = render_mode
self.reward_range = reward_range
self.spec = spec
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
@@ -52,7 +57,8 @@ class JaxEnv(gym.Env):
self.rng = jrng.PRNGKey(seed)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(self, *, seed: int | None = None, options: dict | None = None):
"""Resets the environment using the seed."""
super().reset(seed=seed)
if seed is not None:
self.rng = jrng.PRNGKey(seed)
@@ -68,6 +74,7 @@ class JaxEnv(gym.Env):
return obs, info
def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
@@ -90,6 +97,7 @@ class JaxEnv(gym.Env):
return observation, float(reward), bool(terminated), False, info
def render(self):
"""Returns the render state if `render_mode` is "rgb_array"."""
if self.render_mode == "rgb_array":
self.render_state, image = self.func_env.render_image(
self.state, self.render_state
@@ -99,15 +107,16 @@ class JaxEnv(gym.Env):
raise NotImplementedError
def close(self):
"""Closes the environments and render state if set."""
if self.render_state is not None:
self.func_env.render_close(self.render_state)
self.render_state = None
def _convert_jax_to_numpy(element: Any):
"""
Convert a jax observation/action to a numpy array, or a numpy-based container.
Currently required because all tests assume that stuff is in numpy arrays, hopefully will be removed soon.
"""Convert a jax observation/action to a numpy array, or a numpy-based container.
Required because all tests assume that data is in numpy arrays; to be removed soon.
"""
if isinstance(element, jnp.ndarray):
return np.asarray(element)

View File

@@ -1,6 +1,7 @@
"""Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments."""
from __future__ import annotations
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
from typing import Any, Callable, Generic, TypeVar
import numpy as np
@@ -35,7 +36,7 @@ class FuncEnv(
we intend to flesh it out and officially expose it to end users.
"""
def __init__(self, options: Optional[Dict[str, Any]] = None):
def __init__(self, options: dict[str, Any] | None = None):
"""Initialize the environment constants."""
self.__dict__.update(options or {})
@@ -43,14 +44,14 @@ class FuncEnv(
"""Initial state."""
raise NotImplementedError
def observation(self, state: StateType) -> ObsType:
"""Observation."""
raise NotImplementedError
def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
"""Transition."""
raise NotImplementedError
def observation(self, state: StateType) -> ObsType:
"""Observation."""
raise NotImplementedError
def reward(
self, state: StateType, action: ActType, next_state: StateType
) -> RewardType:
@@ -83,7 +84,7 @@ class FuncEnv(
def render_image(
self, state: StateType, render_state: RenderStateType
) -> Tuple[RenderStateType, np.ndarray]:
) -> tuple[RenderStateType, np.ndarray]:
"""Show the state."""
raise NotImplementedError

View File

@@ -0,0 +1,21 @@
"""Experimental Wrappers."""
# isort: skip_file
from typing import TypeVar
ArgType = TypeVar("ArgType")
from gymnasium.experimental.wrappers.lambda_action import LambdaActionV0
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
__all__ = [
"ArgType",
# Lambda Action
"LambdaActionV0",
# Lambda Observation
"LambdaObservationV0",
# Lambda Reward
"LambdaRewardV0",
"ClipRewardV0",
]

View File

@@ -4,7 +4,7 @@ from typing import Any, Callable
import gymnasium as gym
from gymnasium.core import ActType
from gymnasium.dev_wrappers import ArgType
from gymnasium.experimental.wrappers import ArgType
class LambdaActionV0(gym.ActionWrapper):

View File

@@ -4,10 +4,10 @@ from typing import Any, Callable
import gymnasium as gym
from gymnasium.core import ObsType
from gymnasium.dev_wrappers import ArgType
from gymnasium.experimental.wrappers import ArgType
class LambdaObservationsV0(gym.ObservationWrapper):
class LambdaObservationV0(gym.ObservationWrapper):
"""Lambda observation wrapper where a function is provided that is applied to the observation."""
def __init__(

View File

@@ -5,8 +5,8 @@ from typing import Any, Callable, Optional, Union
import numpy as np
import gymnasium as gym
from gymnasium.dev_wrappers import ArgType
from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers import ArgType
class LambdaRewardV0(gym.RewardWrapper):
@@ -14,7 +14,7 @@ class LambdaRewardV0(gym.RewardWrapper):
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import LambdaRewardV0
>>> from gymnasium.experimental.wrappers import LambdaRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = LambdaRewardV0(env, lambda r: 2 * r + 1)
>>> _ = env.reset()
@@ -47,14 +47,14 @@ class LambdaRewardV0(gym.RewardWrapper):
return self.func(reward)
class ClipRewardsV0(LambdaRewardV0):
class ClipRewardV0(LambdaRewardV0):
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
Example with an upper and lower bound:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import ClipRewardsV0
>>> from gymnasium.experimental.wrappers import ClipRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = ClipRewardsV0(env, 0, 0.5)
>>> env = ClipRewardV0(env, 0, 0.5)
>>> env.reset()
>>> _, rew, _, _, _ = env.step(1)
>>> rew

View File

@@ -1,8 +1,4 @@
"""Module of wrapper classes."""
from gymnasium import error
from gymnasium.dev_wrappers.lambda_action import LambdaActionV0
from gymnasium.dev_wrappers.lambda_observations import LambdaObservationsV0
from gymnasium.dev_wrappers.lambda_reward import ClipRewardsV0, LambdaRewardV0
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from gymnasium.wrappers.autoreset import AutoResetWrapper
from gymnasium.wrappers.clip_action import ClipAction

View File

@@ -2,10 +2,10 @@ from typing import Any, Dict, Optional
import numpy as np
from gymnasium.functional import FuncEnv
from gymnasium.experimental import FuncEnv
class TestEnv(FuncEnv):
class GenericTestFuncEnv(FuncEnv):
def __init__(self, options: Optional[Dict[str, Any]] = None):
super().__init__(options)
@@ -26,7 +26,7 @@ class TestEnv(FuncEnv):
def test_api():
env = TestEnv()
env = GenericTestFuncEnv()
state = env.initial(None)
obs = env.observation(state)
assert state.shape == (2,)

View File

@@ -4,11 +4,11 @@ import jax.random as jrng
import numpy as np
import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleF # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumF # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_normal(env_class):
env = env_class()
rng = jrng.PRNGKey(0)
@@ -40,7 +40,7 @@ def test_normal(env_class):
state = next_state
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_jit(env_class):
env = env_class()
rng = jrng.PRNGKey(0)
@@ -73,7 +73,7 @@ def test_jit(env_class):
state = next_state
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_vmap(env_class):
env = env_class()
num_envs = 10

View File

@@ -4,8 +4,8 @@ import pytest
import gymnasium as gym
from gymnasium.error import InvalidAction
from gymnasium.experimental.wrappers import LambdaActionV0
from gymnasium.spaces import Box
from gymnasium.wrappers import LambdaActionV0
from tests.testing_env import GenericTestEnv
NUM_ENVS = 3

View File

@@ -3,8 +3,8 @@
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaObservationV0
from gymnasium.spaces import Box
from gymnasium.wrappers import LambdaObservationsV0
NUM_ENVS = 3
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
@@ -25,7 +25,7 @@ def test_lambda_observation_v0():
observation_shift = 1
env.reset(seed=SEED)
wrapped_env = LambdaObservationsV0(
wrapped_env = LambdaObservationV0(
env, lambda observation: observation + observation_shift
)
wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION)
@@ -48,7 +48,7 @@ def test_lambda_observation_v0_within_vector():
observation_shift = 1
env.reset(seed=SEED)
wrapped_env = LambdaObservationsV0(
wrapped_env = LambdaObservationV0(
env, lambda observation: observation + observation_shift
)
wrapped_obs, _, _, _, _ = wrapped_env.step(

View File

@@ -5,7 +5,7 @@ import pytest
import gymnasium as gym
from gymnasium.error import InvalidBound
from gymnasium.wrappers import ClipRewardsV0, LambdaRewardV0
from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0
ENV_ID = "CartPole-v1"
DISCRETE_ACTION = 0
@@ -65,7 +65,7 @@ def test_clip_reward(lower_bound, upper_bound, expected_reward):
accordingly to the input args.
"""
env = gym.make(ENV_ID)
env = ClipRewardsV0(env, lower_bound, upper_bound)
env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED)
_, rew, _, _, _ = env.step(DISCRETE_ACTION)
@@ -84,7 +84,7 @@ def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward):
actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]
env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
env = ClipRewardsV0(env, lower_bound, upper_bound)
env = ClipRewardV0(env, lower_bound, upper_bound)
env.reset(seed=SEED)
_, rew, _, _, _ = env.step(actions)
@@ -106,4 +106,4 @@ def test_clip_reward_incorrect_params(lower_bound, upper_bound):
env = gym.make(ENV_ID)
with pytest.raises(InvalidBound):
env = ClipRewardsV0(env, lower_bound, upper_bound)
env = ClipRewardV0(env, lower_bound, upper_bound)