Add vector wrappers for lambda observation, action and reward wrappers (#444)

Mark Towers
2023-07-13 12:33:16 +01:00
committed by GitHub
parent ee067c721b
commit aa6d0c8787
15 changed files with 677 additions and 51 deletions

View File

@@ -8,32 +8,46 @@ title: Vector Wrappers
.. autoclass:: gymnasium.experimental.vector.VectorWrapper
```
## Vector Lambda Observation Wrappers
## Vector Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorObservationWrapper
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.FilterObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.FlattenObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.GrayscaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ResizeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ReshapeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.DtypeObservationV0
```
## Vector Lambda Action Wrappers
## Vector Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorActionWrapper
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaActionV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipActionV0
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleActionV0
```
## Vector Lambda Reward Wrappers
## Vector Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorRewardWrapper
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipRewardV0
```
## Vector Common Wrappers
## More Vector Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorRecordEpisodeStatistics
```
## Vector Only Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorListInfo
.. autoclass:: gymnasium.experimental.wrappers.vector.RecordEpisodeStatisticsV0
.. autoclass:: gymnasium.experimental.wrappers.vector.DictInfoToListV0
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaActionV0
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToNumpyV0
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToTorchV0
.. autoclass:: gymnasium.experimental.wrappers.vector.NumpyToTorchV0
```

View File

@@ -20,7 +20,8 @@ from gymnasium.envs.registration import (
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
from gymnasium import envs
from gymnasium import experimental, spaces, utils, vector, wrappers, error, logger
from gymnasium import spaces, utils, vector, wrappers, error, logger
from gymnasium import experimental
__all__ = [

View File

@@ -1,22 +1,25 @@
"""Root __init__ of the gym experimental wrappers."""
from gymnasium.experimental import functional, wrappers
from gymnasium.experimental.functional import FuncEnv
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
from gymnasium.experimental import functional, vector, wrappers
# from gymnasium.experimental.functional import FuncEnv
# from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
# from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
# from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
__all__ = [
# Functional
"FuncEnv",
# "FuncEnv",
"functional",
# Vector
"VectorEnv",
"VectorWrapper",
"SyncVectorEnv",
"AsyncVectorEnv",
# "VectorEnv",
# "VectorWrapper",
# "SyncVectorEnv",
# "AsyncVectorEnv",
# wrappers
"wrappers",
"vector",
]
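
After this change, the vector classes live under the ``vector`` submodule rather than the experimental package root. A minimal sketch of the updated import style (a sketch against this experimental API, which may change between releases):

```python
# Old import, no longer exported from the package root:
# from gymnasium.experimental import SyncVectorEnv

import gymnasium as gym
from gymnasium.experimental.vector import SyncVectorEnv

# Build a small synchronous vector environment through the new path.
envs = SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)])
obs, info = envs.reset(seed=42)
envs.close()
```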

View File

@@ -397,7 +397,7 @@ class VectorObservationWrapper(VectorWrapper):
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options)
return self.observation(obs), info
return self.vector_observation(obs), info
def step(
self, actions: ActType
@@ -405,24 +405,43 @@ class VectorObservationWrapper(VectorWrapper):
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return (
self.observation(observation),
self.vector_observation(observation),
reward,
termination,
truncation,
info,
self.update_final_obs(info),
)
def observation(self, observation: ObsType) -> ObsType:
"""Defines the observation transformation.
def vector_observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation.
Args:
observation (object): the observation from the environment
observation: A vector observation from the environment
Returns:
observation (object): the transformed observation
The transformed observation
"""
raise NotImplementedError
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info
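
With the rename, subclasses now provide both a batched and a per-env transform so that ``final_observation`` entries stay consistent with the batched observations. A minimal sketch of a hypothetical subclass (``ShiftObservation`` and the ``+1`` transform are illustrative, not part of this commit):

```python
import gymnasium as gym
from gymnasium.experimental.vector import VectorObservationWrapper


class ShiftObservation(VectorObservationWrapper):
    """Hypothetical wrapper that shifts every observation by +1."""

    def vector_observation(self, observation):
        # Applied to the batched observations from ``reset`` and ``step``.
        return observation + 1.0

    def single_observation(self, observation):
        # Applied to each per-env entry in ``info["final_observation"]``.
        return observation + 1.0


envs = ShiftObservation(gym.make_vec("CartPole-v1", num_envs=2))
obs, info = envs.reset(seed=0)
envs.close()
```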
class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""

View File

@@ -8,35 +8,57 @@ from gymnasium.experimental.wrappers.vector.dict_info_to_list import DictInfoToL
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_action import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
VectorizeLambdaActionV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_observation import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
VectorizeLambdaObservationV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_reward import (
ClipRewardV0,
LambdaRewardV0,
VectorizeLambdaRewardV0,
)
__all__ = [
# --- Vector only wrappers
# "VectoriseLambdaObservationV0",
# "VectoriseLambdaActionV0",
# "VectoriseLambdaRewardV0",
"VectorizeLambdaObservationV0",
"VectorizeLambdaActionV0",
"VectorizeLambdaRewardV0",
"DictInfoToListV0",
# --- Observation wrappers ---
# "LambdaObservationV0",
# "FilterObservationV0",
# "FlattenObservationV0",
# "GrayscaleObservationV0",
# "ResizeObservationV0",
# "ReshapeObservationV0",
# "RescaleObservationV0",
# "DtypeObservationV0",
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
# "PixelObservationV0",
# "NormalizeObservationV0",
# "TimeAwareObservationV0",
# "FrameStackObservationV0",
# "DelayObservationV0",
# --- Action Wrappers ---
# "LambdaActionV0",
# "ClipActionV0",
# "RescaleActionV0",
"LambdaActionV0",
"ClipActionV0",
"RescaleActionV0",
# --- Reward wrappers ---
# "LambdaRewardV0",
# "ClipRewardV0",
"LambdaRewardV0",
"ClipRewardV0",
# "NormalizeRewardV1",
# --- Common ---
"RecordEpisodeStatisticsV0",

View File

@@ -7,7 +7,7 @@ import jax.numpy as jnp
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental import VectorEnv, VectorWrapper
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
@@ -19,7 +19,7 @@ class JaxToNumpyV0(VectorWrapper):
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
Notes:
A vectorised version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
A vectorized version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
"""

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental import VectorEnv, VectorWrapper
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import (
Device,

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental import VectorEnv, VectorWrapper
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import Device
from gymnasium.experimental.wrappers.numpy_to_torch import (

View File

@@ -0,0 +1,143 @@
"""Vectorizes action wrappers to work for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable
import numpy as np
from gymnasium import Space
from gymnasium.core import ActType, Env
from gymnasium.experimental.vector import VectorActionWrapper, VectorEnv
from gymnasium.experimental.wrappers import lambda_action
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
class LambdaActionV0(VectorActionWrapper):
"""Transforms an action via a function provided to the wrapper.
The function :attr:`func` will be applied to all vector actions.
If the actions from :attr:`func` are outside the bounds of the ``env``'s action space, provide an :attr:`action_space`.
"""
def __init__(
self,
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform an action. If the transformed action is outside of ``env.action_space``, then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
"""
super().__init__(env)
if action_space is not None:
self.action_space = action_space
self.func = func
def actions(self, actions: ActType) -> ActType:
"""Applies the :attr:`func` to the actions."""
return self.func(actions)
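
A hedged usage sketch of the vector ``LambdaActionV0`` (the halving transform is illustrative; assumes a ``Box``-action environment):

```python
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import LambdaActionV0

envs = gym.make_vec("MountainCarContinuous-v0", num_envs=2)
# Halve every batched action before it reaches the underlying envs.
envs = LambdaActionV0(envs, func=lambda actions: 0.5 * actions)

obs, info = envs.reset(seed=0)
obs, reward, terminated, truncated, info = envs.step(
    np.ones((2, 1), dtype=np.float32)
)
envs.close()
```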
class VectorizeLambdaActionV0(VectorActionWrapper):
"""Vectorizes a single-agent lambda action wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, action_space: Space):
"""Constructor for the fake environment."""
self.action_space = action_space
def __init__(
self, env: VectorEnv, wrapper: type[lambda_action.LambdaActionV0], **kwargs: Any
):
"""Constructor for the vectorized lambda action wrapper.
Args:
env: The vector environment to wrap
wrapper: The wrapper to vectorize
**kwargs: Arguments for the LambdaActionV0 wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_action_space), **kwargs
)
self.single_action_space = self.wrapper.action_space
self.action_space = batch_space(self.single_action_space, self.num_envs)
self.same_out = self.action_space == self.env.action_space
self.out = create_empty_array(self.single_action_space, self.num_envs)
def actions(self, actions: ActType) -> ActType:
"""Applies the wrapper to each of the action.
Args:
actions: The actions to apply the function to
Returns:
The updated actions using the wrapper func
"""
if self.same_out:
return concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
actions,
)
else:
return deepcopy(
concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
self.out,
)
)
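
The two branches trade allocation for safety: when the transformed action space matches the env's, the incoming batch is reused as the ``concatenate`` output; otherwise the preallocated ``self.out`` buffer is filled and deep-copied so successive calls cannot alias each other. Any single-agent lambda action wrapper can be lifted this way; a sketch (equivalent to the ``ClipActionV0`` convenience class defined below):

```python
import gymnasium as gym
from gymnasium.experimental.wrappers import lambda_action
from gymnasium.experimental.wrappers.vector import VectorizeLambdaActionV0

envs = gym.make_vec("MountainCarContinuous-v0", num_envs=2)
# Lift the single-agent clip wrapper to operate on batched actions.
envs = VectorizeLambdaActionV0(envs, lambda_action.ClipActionV0)
envs.close()
```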
class ClipActionV0(VectorizeLambdaActionV0):
"""Clip the continuous action within the valid :class:`Box` observation space bound."""
def __init__(self, env: VectorEnv):
"""Constructor for the Clip Action wrapper.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_action.ClipActionV0)
class RescaleActionV0(VectorizeLambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]."""
def __init__(
self,
env: VectorEnv,
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Initializes the :class:`RescaleAction` wrapper.
Args:
env (Env): The vector environment to wrap
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
super().__init__(
env,
lambda_action.RescaleActionV0,
min_action=min_action,
max_action=max_action,
)
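
A usage sketch for the convenience subclasses (environment and bounds are illustrative):

```python
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import RescaleActionV0

envs = RescaleActionV0(
    gym.make_vec("MountainCarContinuous-v0", num_envs=2),
    min_action=0.0,
    max_action=1.0,
)
# Actions are now supplied in [0, 1] and affinely mapped to the env bounds.
obs, info = envs.reset(seed=0)
obs, *_ = envs.step(np.full((2, 1), 0.5, dtype=np.float32))
envs.close()
```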

View File

@@ -0,0 +1,222 @@
"""Vectorizes observation wrappers to works for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable, Sequence
import numpy as np
from gymnasium import Space
from gymnasium.core import Env, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorObservationWrapper
from gymnasium.experimental.vector.utils import batch_space, concatenate, iterate
from gymnasium.experimental.wrappers import lambda_observation
from gymnasium.vector.utils import create_empty_array
class LambdaObservationV0(VectorObservationWrapper):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`vector_func` will be applied to the batched vector observations and :attr:`single_func` to the individual observations in ``info["final_observation"]``.
If the observations from these functions are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
"""
def __init__(
self,
env: VectorEnv,
vector_func: Callable[[ObsType], Any],
single_func: Callable[[ObsType], Any],
observation_space: Space | None = None,
):
"""Constructor for the lambda observation wrapper.
Args:
env: The vector environment to wrap
vector_func: A function that will transform the vector observation. If the transformed observation is outside of ``env.observation_space``, then provide an ``observation_space``.
single_func: A function that will transform an individual observation.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
super().__init__(env)
if observation_space is not None:
self.observation_space = observation_space
self.vector_func = vector_func
self.single_func = single_func
def vector_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the vector observation."""
return self.vector_func(observation)
def single_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the single observation."""
return self.single_func(observation)
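
A hedged usage sketch (the dtype cast is illustrative; strictly, an updated ``observation_space`` should also be supplied when the transform leaves the original space):

```python
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import LambdaObservationV0

envs = gym.make_vec("CartPole-v1", num_envs=2)
# vector_func transforms the batched observation; single_func transforms
# the per-env observations stored in info["final_observation"].
envs = LambdaObservationV0(
    envs,
    vector_func=lambda obs: obs.astype(np.float64),
    single_func=lambda obs: obs.astype(np.float64),
)
obs, info = envs.reset(seed=0)
envs.close()
```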
class VectorizeLambdaObservationV0(VectorObservationWrapper):
"""Vectori`es a single-agent lambda observation wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, observation_space: Space):
"""Constructor for the fake environment."""
self.observation_space = observation_space
def __init__(
self,
env: VectorEnv,
wrapper: type[lambda_observation.LambdaObservationV0],
**kwargs: Any,
):
"""Constructor for the vectorized lambda observation wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_observation_space), **kwargs
)
self.single_observation_space = self.wrapper.observation_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.same_out = self.observation_space == self.env.observation_space
self.out = create_empty_array(self.single_observation_space, self.num_envs)
def vector_observation(self, observation: ObsType) -> ObsType:
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
if self.same_out:
return concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
observation,
)
else:
return deepcopy(
concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
self.out,
)
)
def single_observation(self, observation: ObsType) -> ObsType:
"""Transforms a single observation using the wrapper transformation function."""
return self.wrapper.func(observation)
class FilterObservationV0(VectorizeLambdaObservationV0):
"""Vector wrapper for filtering dict or tuple observation spaces."""
def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
"""Constructor for the filter observation wrapper.
Args:
env: The vector environment to wrap
filter_keys: The subspaces to be included; use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
"""
super().__init__(
env, lambda_observation.FilterObservationV0, filter_keys=filter_keys
)
class FlattenObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that flattens the observation."""
def __init__(self, env: VectorEnv):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_observation.FlattenObservationV0)
class GrayscaleObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that converts an RGB image to grayscale."""
def __init__(self, env: VectorEnv, keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale.
Args:
env: The vector environment to wrap
keep_dim: Whether to keep the channel dimension in the observation; if ``True``, ``len(obs.shape) == 3``, else ``len(obs.shape) == 2``
"""
super().__init__(
env, lambda_observation.GrayscaleObservationV0, keep_dim=keep_dim
)
class ResizeObservationV0(VectorizeLambdaObservationV0):
"""Resizes image observations using OpenCV to shape."""
def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape.
Args:
env: The vector environment to wrap
shape: The resized observation shape
"""
super().__init__(env, lambda_observation.ResizeObservationV0, shape=shape)
class ReshapeObservationV0(VectorizeLambdaObservationV0):
"""Reshapes array based observations to shapes."""
def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
"""Constructor for env with Box observation space that has a shape product equal to the new shape product.
Args:
env: The vector environment to wrap
shape: The reshaped observation space
"""
super().__init__(env, lambda_observation.ReshapeObservationV0, shape=shape)
class RescaleObservationV0(VectorizeLambdaObservationV0):
"""Linearly rescales observation to between a minimum and maximum value."""
def __init__(
self,
env: VectorEnv,
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`.
Args:
env: The vector environment to wrap
min_obs: The new minimum observation bound
max_obs: The new maximum observation bound
"""
super().__init__(
env,
lambda_observation.RescaleObservationV0,
min_obs=min_obs,
max_obs=max_obs,
)
class DtypeObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: VectorEnv, dtype: Any):
"""Constructor for Dtype observation wrapper.
Args:
env: The vector environment to wrap
dtype: The new dtype of the observation
"""
super().__init__(env, lambda_observation.DtypeObservationV0, dtype=dtype)
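
The convenience subclasses compose like any other wrappers. A sketch chaining two of them (assumes the Box2D extra for ``CarRacing-v2`` is installed):

```python
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import (
    DtypeObservationV0,
    GrayscaleObservationV0,
)

envs = gym.make_vec("CarRacing-v2", num_envs=2)
# Convert the RGB frames to grayscale (keeping the channel axis),
# then cast the uint8 pixels to float32.
envs = GrayscaleObservationV0(envs, keep_dim=True)
envs = DtypeObservationV0(envs, dtype=np.float32)

obs, info = envs.reset(seed=0)  # obs.shape == (2, 96, 96, 1)
envs.close()
```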

View File

@@ -0,0 +1,78 @@
"""Vectorizes reward function to work with `VectorEnv`."""
from __future__ import annotations
from typing import Any, Callable
import numpy as np
from gymnasium import Env
from gymnasium.experimental.vector import VectorEnv, VectorRewardWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers import lambda_reward
class LambdaRewardV0(VectorRewardWrapper):
"""A reward wrapper that allows a custom function to modify the step reward."""
def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The vector environment to wrap
func (Callable): The function to apply to the reward
"""
super().__init__(env)
self.func = func
def reward(self, reward: ArrayType) -> ArrayType:
"""Apply function to reward."""
return self.func(reward)
class VectorizeLambdaRewardV0(VectorRewardWrapper):
"""Vectorizes a single-agent lambda reward wrapper for vector environments."""
def __init__(
self, env: VectorEnv, wrapper: type[lambda_reward.LambdaRewardV0], **kwargs: Any
):
"""Constructor for the vectorized lambda reward wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(Env(), **kwargs)
def reward(self, reward: ArrayType) -> ArrayType:
"""Iterates over the reward updating each with the wrapper func."""
for i, r in enumerate(reward):
reward[i] = self.wrapper.func(r)
return reward
class ClipRewardV0(VectorizeLambdaRewardV0):
"""A wrapper that clips the rewards for an environment between an upper and lower bound."""
def __init__(
self,
env: VectorEnv,
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Constructor for ClipReward wrapper.
Args:
env: The vector environment to wrap
min_reward: The min reward for each step
max_reward: The max reward for each step
"""
super().__init__(
env,
lambda_reward.ClipRewardV0,
min_reward=min_reward,
max_reward=max_reward,
)
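
A usage sketch for the vectorized clip wrapper (bounds are illustrative):

```python
import gymnasium as gym
from gymnasium.experimental.wrappers.vector import ClipRewardV0

envs = ClipRewardV0(
    gym.make_vec("CartPole-v1", num_envs=2),
    min_reward=0.0,
    max_reward=0.5,
)
obs, info = envs.reset(seed=0)
obs, reward, terminated, truncated, info = envs.step(envs.action_space.sample())
# Each per-env reward has been clipped into [0.0, 0.5].
envs.close()
```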

View File

@@ -3,7 +3,7 @@
import pytest
import gymnasium as gym
from gymnasium.experimental import AsyncVectorEnv, SyncVectorEnv
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv
from gymnasium.wrappers import TimeLimit, TransformObservation
from tests.wrappers.utils import has_wrapper

View File

@@ -5,7 +5,7 @@ from typing import Any
import numpy as np
from gymnasium.experimental import FuncEnv
from gymnasium.experimental.functional import FuncEnv
class GenericTestFuncEnv(FuncEnv):

View File

@@ -0,0 +1 @@
"""Testing suite for `gymnasium.experimental.wrappers.vector`."""

View File

@@ -0,0 +1,123 @@
"""Tests that the vectorised wrappers operate identically in `VectorEnv(Wrapper)` and `VectorWrapper(VectorEnv)`.
The exception is the data converter wrappers (`JaxToTorch`, `JaxToNumpy` and `NumpyToJax`)
"""
from __future__ import annotations
from typing import Any
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental import wrappers
from gymnasium.experimental.vector import VectorEnv
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils.env_checker import data_equivalence
from tests.testing_env import GenericTestEnv
@pytest.fixture
def custom_environments():
gym.register(
"CustomDictEnv-v0",
lambda: GenericTestEnv(
observation_space=Dict({"a": Box(0, 1), "b": Discrete(5)})
),
)
yield
del gym.registry["CustomDictEnv-v0"]
@pytest.mark.parametrize("num_envs", (1, 3))
@pytest.mark.parametrize(
"env_id, wrapper_name, kwargs",
(
("CustomDictEnv-v0", "FilterObservationV0", {"filter_keys": ["a"]}),
("CartPole-v1", "FlattenObservationV0", {}),
("CarRacing-v2", "GrayscaleObservationV0", {}),
# ("CarRacing-v2", "ResizeObservationV0", {"shape": (35, 45)}),
("CarRacing-v2", "ReshapeObservationV0", {"shape": (96, 48, 6)}),
("CartPole-v1", "RescaleObservationV0", {"min_obs": 0, "max_obs": 1}),
("CartPole-v1", "DtypeObservationV0", {"dtype": np.int32}),
# ("CartPole-v1", "PixelObservationV0", {}),
# ("CartPole-v1", "NormalizeObservationV0", {}),
# ("CartPole-v1", "TimeAwareObservationV0", {}),
# ("CartPole-v1", "FrameStackObservationV0", {}),
# ("CartPole-v1", "DelayObservationV0", {}),
("MountainCarContinuous-v0", "ClipActionV0", {}),
(
"MountainCarContinuous-v0",
"RescaleActionV0",
{"min_action": 1, "max_action": 2},
),
# ("CartPole-v1", "StickyActionV0", {}),
("CartPole-v1", "ClipRewardV0", {"min_reward": 0.25, "max_reward": 0.75}),
# ("CartPole-v1", "NormalizeRewardV1", {}),
),
)
def test_vector_wrapper_equivalence(
env_id: str,
wrapper_name: str,
kwargs: dict[str, Any],
num_envs: int,
custom_environments,
vectorization_mode: str = "sync",
num_steps: int = 50,
):
vector_wrapper = getattr(wrappers.vector, wrapper_name)
wrapper_vector_env: VectorEnv = vector_wrapper(
gym.make_vec(
id=env_id, num_envs=num_envs, vectorization_mode=vectorization_mode
),
**kwargs,
)
env_wrapper = getattr(wrappers, wrapper_name)
vector_wrapper_env = gym.make_vec(
id=env_id,
num_envs=num_envs,
vectorization_mode=vectorization_mode,
wrappers=(lambda env: env_wrapper(env, **kwargs),),
)
assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
assert wrapper_vector_env.observation_space == vector_wrapper_env.observation_space
assert (
wrapper_vector_env.single_action_space == vector_wrapper_env.single_action_space
)
assert (
wrapper_vector_env.single_observation_space
== vector_wrapper_env.single_observation_space
)
assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs
wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123)
vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123)
assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
assert data_equivalence(wrapper_vector_info, vector_wrapper_info)
for _ in range(num_steps):
action = wrapper_vector_env.action_space.sample()
wrapper_vector_step_returns = wrapper_vector_env.step(action)
vector_wrapper_step_returns = vector_wrapper_env.step(action)
for wrapper_vector_return, vector_wrapper_return in zip(
wrapper_vector_step_returns, vector_wrapper_step_returns
):
assert data_equivalence(wrapper_vector_return, vector_wrapper_return)
wrapper_vector_env.close()
vector_wrapper_env.close()
# ("CartPole-v1", "LambdaObservationV0", {"func": lambda obs: obs + 1}),
# ("CartPole-v1", "LambdaActionV0", {"func": lambda action: action + 1}),
# ("CartPole-v1", "LambdaRewardV0", {"func": lambda reward: reward + 1}),
# (vector.JaxToNumpyV0, {}, {}),
# (vector.JaxToTorchV0, {}, {}),
# (vector.NumpyToTorchV0, {}, {}),
# ("CartPole-v1", "RecordEpisodeStatisticsV0", {}), # for the time taken in info, this is not equivalent for two instances