mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Add vector wrappers for lambda observation, action and reward wrappers (#444)
This commit is contained in:
@@ -8,32 +8,46 @@ title: Vector Wrappers
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorWrapper
|
||||
```
|
||||
|
||||
## Vector Lambda Observation Wrappers
|
||||
## Vector Observation Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorObservationWrapper
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.FilterObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.FlattenObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.GrayscaleObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ResizeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ReshapeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.DtypeObservationV0
|
||||
```
|
||||
|
||||
## Vector Lambda Action Wrappers
|
||||
## Vector Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorActionWrapper
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleActionV0
|
||||
```
|
||||
|
||||
## Vector Lambda Reward Wrappers
|
||||
## Vector Reward Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorRewardWrapper
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipRewardV0
|
||||
```
|
||||
|
||||
## Vector Common Wrappers
|
||||
## More Vector Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorRecordEpisodeStatistics
|
||||
```
|
||||
|
||||
## Vector Only Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorListInfo
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.RecordEpisodeStatisticsV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.DictInfoToListV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToNumpyV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToTorchV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.NumpyToTorchV0
|
||||
```
|
||||
|
@@ -20,7 +20,8 @@ from gymnasium.envs.registration import (
|
||||
|
||||
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
|
||||
from gymnasium import envs
|
||||
from gymnasium import experimental, spaces, utils, vector, wrappers, error, logger
|
||||
from gymnasium import spaces, utils, vector, wrappers, error, logger
|
||||
from gymnasium import experimental
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@@ -1,22 +1,25 @@
|
||||
"""Root __init__ of the gym experimental wrappers."""
|
||||
|
||||
|
||||
from gymnasium.experimental import functional, wrappers
|
||||
from gymnasium.experimental.functional import FuncEnv
|
||||
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental import functional, vector, wrappers
|
||||
|
||||
|
||||
# from gymnasium.experimental.functional import FuncEnv
|
||||
# from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||
# from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||
# from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Functional
|
||||
"FuncEnv",
|
||||
# "FuncEnv",
|
||||
"functional",
|
||||
# Vector
|
||||
"VectorEnv",
|
||||
"VectorWrapper",
|
||||
"SyncVectorEnv",
|
||||
"AsyncVectorEnv",
|
||||
# "VectorEnv",
|
||||
# "VectorWrapper",
|
||||
# "SyncVectorEnv",
|
||||
# "AsyncVectorEnv",
|
||||
# wrappers
|
||||
"wrappers",
|
||||
"vector",
|
||||
]
|
||||
|
@@ -397,7 +397,7 @@ class VectorObservationWrapper(VectorWrapper):
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Modifies the observation returned from the environment ``reset`` using the :meth:`vector_observation`."""
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
return self.observation(obs), info
|
||||
return self.vector_observation(obs), info
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
@@ -405,24 +405,43 @@ class VectorObservationWrapper(VectorWrapper):
|
||||
"""Modifies the observation returned from the environment ``step`` using the :meth:`vector_observation`."""
|
||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
||||
return (
|
||||
self.observation(observation),
|
||||
self.vector_observation(observation),
|
||||
reward,
|
||||
termination,
|
||||
truncation,
|
||||
info,
|
||||
self.update_final_obs(info),
|
||||
)
|
||||
|
||||
def observation(self, observation: ObsType) -> ObsType:
|
||||
"""Defines the observation transformation.
|
||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Defines the vector observation transformation.
|
||||
|
||||
Args:
|
||||
observation (object): the observation from the environment
|
||||
observation: A vector observation from the environment
|
||||
|
||||
Returns:
|
||||
observation (object): the transformed observation
|
||||
the transformed observation
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def single_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Defines the single observation transformation.
|
||||
|
||||
Args:
|
||||
observation: A single observation from the environment
|
||||
|
||||
Returns:
|
||||
The transformed observation
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Updates the `final_observation` entries in the info using `single_observation`."""
|
||||
if "final_observation" in info:
|
||||
for i, obs in enumerate(info["final_observation"]):
|
||||
if obs is not None:
|
||||
info["final_observation"][i] = self.single_observation(obs)
|
||||
return info
|
||||
|
||||
|
||||
class VectorActionWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
|
||||
|
@@ -8,35 +8,57 @@ from gymnasium.experimental.wrappers.vector.dict_info_to_list import DictInfoToL
|
||||
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
|
||||
RecordEpisodeStatisticsV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.vector.vectorize_action import (
|
||||
ClipActionV0,
|
||||
LambdaActionV0,
|
||||
RescaleActionV0,
|
||||
VectorizeLambdaActionV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.vector.vectorize_observation import (
|
||||
DtypeObservationV0,
|
||||
FilterObservationV0,
|
||||
FlattenObservationV0,
|
||||
GrayscaleObservationV0,
|
||||
LambdaObservationV0,
|
||||
RescaleObservationV0,
|
||||
ReshapeObservationV0,
|
||||
ResizeObservationV0,
|
||||
VectorizeLambdaObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.vector.vectorize_reward import (
|
||||
ClipRewardV0,
|
||||
LambdaRewardV0,
|
||||
VectorizeLambdaRewardV0,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# --- Vector only wrappers
|
||||
# "VectoriseLambdaObservationV0",
|
||||
# "VectoriseLambdaActionV0",
|
||||
# "VectoriseLambdaRewardV0",
|
||||
"VectorizeLambdaObservationV0",
|
||||
"VectorizeLambdaActionV0",
|
||||
"VectorizeLambdaRewardV0",
|
||||
"DictInfoToListV0",
|
||||
# --- Observation wrappers ---
|
||||
# "LambdaObservationV0",
|
||||
# "FilterObservationV0",
|
||||
# "FlattenObservationV0",
|
||||
# "GrayscaleObservationV0",
|
||||
# "ResizeObservationV0",
|
||||
# "ReshapeObservationV0",
|
||||
# "RescaleObservationV0",
|
||||
# "DtypeObservationV0",
|
||||
"LambdaObservationV0",
|
||||
"FilterObservationV0",
|
||||
"FlattenObservationV0",
|
||||
"GrayscaleObservationV0",
|
||||
"ResizeObservationV0",
|
||||
"ReshapeObservationV0",
|
||||
"RescaleObservationV0",
|
||||
"DtypeObservationV0",
|
||||
# "PixelObservationV0",
|
||||
# "NormalizeObservationV0",
|
||||
# "TimeAwareObservationV0",
|
||||
# "FrameStackObservationV0",
|
||||
# "DelayObservationV0",
|
||||
# --- Action Wrappers ---
|
||||
# "LambdaActionV0",
|
||||
# "ClipActionV0",
|
||||
# "RescaleActionV0",
|
||||
"LambdaActionV0",
|
||||
"ClipActionV0",
|
||||
"RescaleActionV0",
|
||||
# --- Reward wrappers ---
|
||||
# "LambdaRewardV0",
|
||||
# "ClipRewardV0",
|
||||
"LambdaRewardV0",
|
||||
"ClipRewardV0",
|
||||
# "NormalizeRewardV1",
|
||||
# --- Common ---
|
||||
"RecordEpisodeStatisticsV0",
|
||||
|
@@ -7,7 +7,7 @@ import jax.numpy as jnp
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
||||
|
||||
@@ -19,7 +19,7 @@ class JaxToNumpyV0(VectorWrapper):
|
||||
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
||||
|
||||
Notes:
|
||||
A vectorised version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
|
||||
A vectorized version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
|
||||
|
||||
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
|
||||
"""
|
||||
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import (
|
||||
Device,
|
||||
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import Device
|
||||
from gymnasium.experimental.wrappers.numpy_to_torch import (
|
||||
|
143
gymnasium/experimental/wrappers/vector/vectorize_action.py
Normal file
143
gymnasium/experimental/wrappers/vector/vectorize_action.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Vectorizes action wrappers to work for `VectorEnv`."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Space
|
||||
from gymnasium.core import ActType, Env
|
||||
from gymnasium.experimental.vector import VectorActionWrapper, VectorEnv
|
||||
from gymnasium.experimental.wrappers import lambda_action
|
||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||
|
||||
|
||||
class LambdaActionV0(VectorActionWrapper):
|
||||
"""Transforms an action via a function provided to the wrapper.
|
||||
|
||||
The function :attr:`func` will be applied to all vector actions.
|
||||
If the actions returned by :attr:`func` are outside the bounds of the ``env``'s action space, provide an :attr:`action_space`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VectorEnv,
|
||||
func: Callable[[ActType], Any],
|
||||
action_space: Space | None = None,
|
||||
):
|
||||
"""Constructor for the lambda action wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
|
||||
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
if action_space is not None:
|
||||
self.action_space = action_space
|
||||
|
||||
self.func = func
|
||||
|
||||
def actions(self, actions: ActType) -> ActType:
|
||||
"""Applies the :attr:`func` to the actions."""
|
||||
return self.func(actions)
|
||||
|
||||
|
||||
class VectorizeLambdaActionV0(VectorActionWrapper):
|
||||
"""Vectorizes a single-agent lambda action wrapper for vector environments."""
|
||||
|
||||
class VectorizedEnv(Env):
|
||||
"""Fake single-agent environment used for the single-agent wrapper."""
|
||||
|
||||
def __init__(self, action_space: Space):
|
||||
"""Constructor for the fake environment."""
|
||||
self.action_space = action_space
|
||||
|
||||
def __init__(
|
||||
self, env: VectorEnv, wrapper: type[lambda_action.LambdaActionV0], **kwargs: Any
|
||||
):
|
||||
"""Constructor for the vectorized lambda action wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
wrapper: The wrapper to vectorize
|
||||
**kwargs: Arguments for the LambdaActionV0 wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
self.wrapper = wrapper(
|
||||
self.VectorizedEnv(self.env.single_action_space), **kwargs
|
||||
)
|
||||
self.single_action_space = self.wrapper.action_space
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
|
||||
self.same_out = self.action_space == self.env.action_space
|
||||
self.out = create_empty_array(self.single_action_space, self.num_envs)
|
||||
|
||||
def actions(self, actions: ActType) -> ActType:
|
||||
"""Applies the wrapper to each of the actions.
|
||||
|
||||
Args:
|
||||
actions: The actions to apply the function to
|
||||
|
||||
Returns:
|
||||
The updated actions using the wrapper func
|
||||
"""
|
||||
if self.same_out:
|
||||
return concatenate(
|
||||
self.single_action_space,
|
||||
tuple(
|
||||
self.wrapper.func(action)
|
||||
for action in iterate(self.action_space, actions)
|
||||
),
|
||||
actions,
|
||||
)
|
||||
else:
|
||||
return deepcopy(
|
||||
concatenate(
|
||||
self.single_action_space,
|
||||
tuple(
|
||||
self.wrapper.func(action)
|
||||
for action in iterate(self.action_space, actions)
|
||||
),
|
||||
self.out,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ClipActionV0(VectorizeLambdaActionV0):
|
||||
"""Clip the continuous action within the valid :class:`Box` action space bound."""
|
||||
|
||||
def __init__(self, env: VectorEnv):
|
||||
"""Constructor for the Clip Action wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
"""
|
||||
super().__init__(env, lambda_action.ClipActionV0)
|
||||
|
||||
|
||||
class RescaleActionV0(VectorizeLambdaActionV0):
|
||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VectorEnv,
|
||||
min_action: float | int | np.ndarray,
|
||||
max_action: float | int | np.ndarray,
|
||||
):
|
||||
"""Initializes the :class:`RescaleAction` wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): The vector environment to wrap
|
||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||
"""
|
||||
super().__init__(
|
||||
env,
|
||||
lambda_action.RescaleActionV0,
|
||||
min_action=min_action,
|
||||
max_action=max_action,
|
||||
)
|
222
gymnasium/experimental/wrappers/vector/vectorize_observation.py
Normal file
222
gymnasium/experimental/wrappers/vector/vectorize_observation.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Vectorizes observation wrappers to work for `VectorEnv`."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Space
|
||||
from gymnasium.core import Env, ObsType
|
||||
from gymnasium.experimental.vector import VectorEnv, VectorObservationWrapper
|
||||
from gymnasium.experimental.vector.utils import batch_space, concatenate, iterate
|
||||
from gymnasium.experimental.wrappers import lambda_observation
|
||||
from gymnasium.vector.utils import create_empty_array
|
||||
|
||||
|
||||
class LambdaObservationV0(VectorObservationWrapper):
|
||||
"""Transforms an observation via a function provided to the wrapper.
|
||||
|
||||
The function :attr:`vector_func` will be applied to all vector observations, and :attr:`single_func` to individual observations.
|
||||
If the observations from :attr:`vector_func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VectorEnv,
|
||||
vector_func: Callable[[ObsType], Any],
|
||||
single_func: Callable[[ObsType], Any],
|
||||
observation_space: Space | None = None,
|
||||
):
|
||||
"""Constructor for the lambda observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
|
||||
single_func: A function that will transform an individual observation.
|
||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
if observation_space is not None:
|
||||
self.observation_space = observation_space
|
||||
|
||||
self.vector_func = vector_func
|
||||
self.single_func = single_func
|
||||
|
||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Apply function to the vector observation."""
|
||||
return self.vector_func(observation)
|
||||
|
||||
def single_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Apply function to the single observation."""
|
||||
return self.single_func(observation)
|
||||
|
||||
|
||||
class VectorizeLambdaObservationV0(VectorObservationWrapper):
|
||||
"""Vectorizes a single-agent lambda observation wrapper for vector environments."""
|
||||
|
||||
class VectorizedEnv(Env):
|
||||
"""Fake single-agent environment used for the single-agent wrapper."""
|
||||
|
||||
def __init__(self, observation_space: Space):
|
||||
"""Constructor for the fake environment."""
|
||||
self.observation_space = observation_space
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VectorEnv,
|
||||
wrapper: type[lambda_observation.LambdaObservationV0],
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Constructor for the vectorized lambda observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap.
|
||||
wrapper: The wrapper to vectorize
|
||||
**kwargs: Keyword arguments for the wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
self.wrapper = wrapper(
|
||||
self.VectorizedEnv(self.env.single_observation_space), **kwargs
|
||||
)
|
||||
self.single_observation_space = self.wrapper.observation_space
|
||||
self.observation_space = batch_space(
|
||||
self.single_observation_space, self.num_envs
|
||||
)
|
||||
|
||||
self.same_out = self.observation_space == self.env.observation_space
|
||||
self.out = create_empty_array(self.single_observation_space, self.num_envs)
|
||||
|
||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
|
||||
if self.same_out:
|
||||
return concatenate(
|
||||
self.single_observation_space,
|
||||
tuple(
|
||||
self.wrapper.func(obs)
|
||||
for obs in iterate(self.observation_space, observation)
|
||||
),
|
||||
observation,
|
||||
)
|
||||
else:
|
||||
return deepcopy(
|
||||
concatenate(
|
||||
self.single_observation_space,
|
||||
tuple(
|
||||
self.wrapper.func(obs)
|
||||
for obs in iterate(self.observation_space, observation)
|
||||
),
|
||||
self.out,
|
||||
)
|
||||
)
|
||||
|
||||
def single_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Transforms a single observation using the wrapper transformation function."""
|
||||
return self.wrapper.func(observation)
|
||||
|
||||
|
||||
class FilterObservationV0(VectorizeLambdaObservationV0):
|
||||
"""Vector wrapper for filtering dict or tuple observation spaces."""
|
||||
|
||||
def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
|
||||
"""Constructor for the filter observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
|
||||
"""
|
||||
super().__init__(
|
||||
env, lambda_observation.FilterObservationV0, filter_keys=filter_keys
|
||||
)
|
||||
|
||||
|
||||
class FlattenObservationV0(VectorizeLambdaObservationV0):
|
||||
"""Observation wrapper that flattens the observation."""
|
||||
|
||||
def __init__(self, env: VectorEnv):
|
||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
"""
|
||||
super().__init__(env, lambda_observation.FlattenObservationV0)
|
||||
|
||||
|
||||
class GrayscaleObservationV0(VectorizeLambdaObservationV0):
|
||||
"""Observation wrapper that converts an RGB image to grayscale."""
|
||||
|
||||
def __init__(self, env: VectorEnv, keep_dim: bool = False):
|
||||
"""Constructor for RGB image based environments to make the image grayscale.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
keep_dim: Whether to keep the channel dimension in the observation, if ``True``, ``len(obs.shape) == 3`` else ``len(obs.shape) == 2``
|
||||
"""
|
||||
super().__init__(
|
||||
env, lambda_observation.GrayscaleObservationV0, keep_dim=keep_dim
|
||||
)
|
||||
|
||||
|
||||
class ResizeObservationV0(VectorizeLambdaObservationV0):
|
||||
"""Resizes image observations using OpenCV to shape."""
|
||||
|
||||
def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
|
||||
"""Constructor that requires an image environment observation space with a shape.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
shape: The resized observation shape
|
||||
"""
|
||||
super().__init__(env, lambda_observation.ResizeObservationV0, shape=shape)
|
||||
|
||||
|
||||
class ReshapeObservationV0(VectorizeLambdaObservationV0):
|
||||
"""Reshapes array based observations to shapes."""
|
||||
|
||||
def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
|
||||
"""Constructor for env with Box observation space that has a shape product equal to the new shape product.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
shape: The reshaped observation space
|
||||
"""
|
||||
super().__init__(env, lambda_observation.ReshapeObservationV0, shape=shape)
|
||||
|
||||
|
||||
class RescaleObservationV0(VectorizeLambdaObservationV0):
|
||||
"""Linearly rescales observation to between a minimum and maximum value."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VectorEnv,
|
||||
min_obs: np.floating | np.integer | np.ndarray,
|
||||
max_obs: np.floating | np.integer | np.ndarray,
|
||||
):
|
||||
"""Constructor that requires the env observation spaces to be a :class:`Box`.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
min_obs: The new minimum observation bound
|
||||
max_obs: The new maximum observation bound
|
||||
"""
|
||||
super().__init__(
|
||||
env,
|
||||
lambda_observation.RescaleObservationV0,
|
||||
min_obs=min_obs,
|
||||
max_obs=max_obs,
|
||||
)
|
||||
|
||||
|
||||
class DtypeObservationV0(VectorizeLambdaObservationV0):
|
||||
"""Observation wrapper for transforming the dtype of an observation."""
|
||||
|
||||
def __init__(self, env: VectorEnv, dtype: Any):
|
||||
"""Constructor for Dtype observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
dtype: The new dtype of the observation
|
||||
"""
|
||||
super().__init__(env, lambda_observation.DtypeObservationV0, dtype=dtype)
|
78
gymnasium/experimental/wrappers/vector/vectorize_reward.py
Normal file
78
gymnasium/experimental/wrappers/vector/vectorize_reward.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Vectorizes reward function to work with `VectorEnv`."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env
|
||||
from gymnasium.experimental.vector import VectorEnv, VectorRewardWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers import lambda_reward
|
||||
|
||||
|
||||
class LambdaRewardV0(VectorRewardWrapper):
|
||||
"""A reward wrapper that allows a custom function to modify the step reward."""
|
||||
|
||||
def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
|
||||
"""Initialize LambdaRewardV0 wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): The vector environment to wrap
|
||||
func (Callable): The function to apply to reward
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
self.func = func
|
||||
|
||||
def reward(self, reward: ArrayType) -> ArrayType:
|
||||
"""Apply function to reward."""
|
||||
return self.func(reward)
|
||||
|
||||
|
||||
class VectorizeLambdaRewardV0(VectorRewardWrapper):
|
||||
"""Vectorizes a single-agent lambda reward wrapper for vector environments."""
|
||||
|
||||
def __init__(
|
||||
self, env: VectorEnv, wrapper: type[lambda_reward.LambdaRewardV0], **kwargs: Any
|
||||
):
|
||||
"""Constructor for the vectorized lambda reward wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap.
|
||||
wrapper: The wrapper to vectorize
|
||||
**kwargs: Keyword arguments for the wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
self.wrapper = wrapper(Env(), **kwargs)
|
||||
|
||||
def reward(self, reward: ArrayType) -> ArrayType:
|
||||
"""Iterates over the reward updating each with the wrapper func."""
|
||||
for i, r in enumerate(reward):
|
||||
reward[i] = self.wrapper.func(r)
|
||||
return reward
|
||||
|
||||
|
||||
class ClipRewardV0(VectorizeLambdaRewardV0):
|
||||
"""A wrapper that clips the rewards for an environment between an upper and lower bound."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VectorEnv,
|
||||
min_reward: float | np.ndarray | None = None,
|
||||
max_reward: float | np.ndarray | None = None,
|
||||
):
|
||||
"""Constructor for ClipReward wrapper.
|
||||
|
||||
Args:
|
||||
env: The vector environment to wrap
|
||||
min_reward: The min reward for each step
|
||||
max_reward: the max reward for each step
|
||||
"""
|
||||
super().__init__(
|
||||
env,
|
||||
lambda_reward.ClipRewardV0,
|
||||
min_reward=min_reward,
|
||||
max_reward=max_reward,
|
||||
)
|
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental import AsyncVectorEnv, SyncVectorEnv
|
||||
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv
|
||||
from gymnasium.wrappers import TimeLimit, TransformObservation
|
||||
from tests.wrappers.utils import has_wrapper
|
||||
|
||||
|
@@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.experimental import FuncEnv
|
||||
from gymnasium.experimental.functional import FuncEnv
|
||||
|
||||
|
||||
class GenericTestFuncEnv(FuncEnv):
|
||||
|
1
tests/experimental/wrappers/vector/__init__.py
Normal file
1
tests/experimental/wrappers/vector/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Testing suite for `gymnasium.experimental.wrappers.vector`."""
|
123
tests/experimental/wrappers/vector/test_vector_wrappers.py
Normal file
123
tests/experimental/wrappers/vector/test_vector_wrappers.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tests that the vectorised wrappers operate identically in `VectorEnv(Wrapper)` and `VectorWrapper(VectorEnv)`.
|
||||
|
||||
The exception is the data converter wrappers (`JaxToTorch`, `JaxToNumpy` and `NumpyToTorch`).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental import wrappers
|
||||
from gymnasium.experimental.vector import VectorEnv
|
||||
from gymnasium.spaces import Box, Dict, Discrete
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_environments():
|
||||
gym.register(
|
||||
"CustomDictEnv-v0",
|
||||
lambda: GenericTestEnv(
|
||||
observation_space=Dict({"a": Box(0, 1), "b": Discrete(5)})
|
||||
),
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
del gym.registry["CustomDictEnv-v0"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_envs", (1, 3))
|
||||
@pytest.mark.parametrize(
|
||||
"env_id, wrapper_name, kwargs",
|
||||
(
|
||||
("CustomDictEnv-v0", "FilterObservationV0", {"filter_keys": ["a"]}),
|
||||
("CartPole-v1", "FlattenObservationV0", {}),
|
||||
("CarRacing-v2", "GrayscaleObservationV0", {}),
|
||||
# ("CarRacing-v2", "ResizeObservationV0", {"shape": (35, 45)}),
|
||||
("CarRacing-v2", "ReshapeObservationV0", {"shape": (96, 48, 6)}),
|
||||
("CartPole-v1", "RescaleObservationV0", {"min_obs": 0, "max_obs": 1}),
|
||||
("CartPole-v1", "DtypeObservationV0", {"dtype": np.int32}),
|
||||
# ("CartPole-v1", "PixelObservationV0", {}),
|
||||
# ("CartPole-v1", "NormalizeObservationV0", {}),
|
||||
# ("CartPole-v1", "TimeAwareObservationV0", {}),
|
||||
# ("CartPole-v1", "FrameStackObservationV0", {}),
|
||||
# ("CartPole-v1", "DelayObservationV0", {}),
|
||||
("MountainCarContinuous-v0", "ClipActionV0", {}),
|
||||
(
|
||||
"MountainCarContinuous-v0",
|
||||
"RescaleActionV0",
|
||||
{"min_action": 1, "max_action": 2},
|
||||
),
|
||||
# ("CartPole-v1", "StickyActionV0", {}),
|
||||
("CartPole-v1", "ClipRewardV0", {"min_reward": 0.25, "max_reward": 0.75}),
|
||||
# ("CartPole-v1", "NormalizeRewardV1", {}),
|
||||
),
|
||||
)
|
||||
def test_vector_wrapper_equivalence(
|
||||
env_id: str,
|
||||
wrapper_name: str,
|
||||
kwargs: dict[str, Any],
|
||||
num_envs: int,
|
||||
custom_environments,
|
||||
vectorization_mode: str = "sync",
|
||||
num_steps: int = 50,
|
||||
):
|
||||
vector_wrapper = getattr(wrappers.vector, wrapper_name)
|
||||
wrapper_vector_env: VectorEnv = vector_wrapper(
|
||||
gym.make_vec(
|
||||
id=env_id, num_envs=num_envs, vectorization_mode=vectorization_mode
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
env_wrapper = getattr(wrappers, wrapper_name)
|
||||
vector_wrapper_env = gym.make_vec(
|
||||
id=env_id,
|
||||
num_envs=num_envs,
|
||||
vectorization_mode=vectorization_mode,
|
||||
wrappers=(lambda env: env_wrapper(env, **kwargs),),
|
||||
)
|
||||
|
||||
assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
|
||||
assert wrapper_vector_env.observation_space == vector_wrapper_env.observation_space
|
||||
assert (
|
||||
wrapper_vector_env.single_action_space == vector_wrapper_env.single_action_space
|
||||
)
|
||||
assert (
|
||||
wrapper_vector_env.single_observation_space
|
||||
== vector_wrapper_env.single_observation_space
|
||||
)
|
||||
|
||||
assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs
|
||||
|
||||
wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123)
|
||||
vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123)
|
||||
|
||||
assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
|
||||
assert data_equivalence(wrapper_vector_info, vector_wrapper_info)
|
||||
|
||||
for _ in range(num_steps):
|
||||
action = wrapper_vector_env.action_space.sample()
|
||||
wrapper_vector_step_returns = wrapper_vector_env.step(action)
|
||||
vector_wrapper_step_returns = vector_wrapper_env.step(action)
|
||||
|
||||
for wrapper_vector_return, vector_wrapper_return in zip(
|
||||
wrapper_vector_step_returns, vector_wrapper_step_returns
|
||||
):
|
||||
assert data_equivalence(wrapper_vector_return, vector_wrapper_return)
|
||||
|
||||
wrapper_vector_env.close()
|
||||
vector_wrapper_env.close()
|
||||
|
||||
|
||||
# ("CartPole-v1", "LambdaObservationV0", {"func": lambda obs: obs + 1}),
|
||||
# ("CartPole-v1", "LambdaActionV0", {"func": lambda action: action + 1}),
|
||||
# ("CartPole-v1", "LambdaRewardV0", {"func": lambda reward: reward + 1}),
|
||||
# (vector.JaxToNumpyV0, {}, {}),
|
||||
# (vector.JaxToTorchV0, {}, {}),
|
||||
# (vector.NumpyToTorchV0, {}, {}),
|
||||
# ("CartPole-v1", "RecordEpisodeStatisticsV0", {}), # for the time taken in info, this is not equivalent for two instances
|
Reference in New Issue
Block a user