mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 09:37:29 +00:00
Experimental wrapper changes (#517)
This commit is contained in:
@@ -54,27 +54,28 @@ _wrapper_to_class = {
|
|||||||
"LambdaActionV0": "lambda_action",
|
"LambdaActionV0": "lambda_action",
|
||||||
"ClipActionV0": "lambda_action",
|
"ClipActionV0": "lambda_action",
|
||||||
"RescaleActionV0": "lambda_action",
|
"RescaleActionV0": "lambda_action",
|
||||||
# lambda_observations.py
|
# lambda_observation.py
|
||||||
"LambdaObservationV0": "lambda_observations",
|
"LambdaObservationV0": "lambda_observation",
|
||||||
"FilterObservationV0": "lambda_observations",
|
"FilterObservationV0": "lambda_observation",
|
||||||
"FlattenObservationV0": "lambda_observations",
|
"FlattenObservationV0": "lambda_observation",
|
||||||
"GrayscaleObservationV0": "lambda_observations",
|
"GrayscaleObservationV0": "lambda_observation",
|
||||||
"ResizeObservationV0": "lambda_observations",
|
"ResizeObservationV0": "lambda_observation",
|
||||||
"ReshapeObservationV0": "lambda_observations",
|
"ReshapeObservationV0": "lambda_observation",
|
||||||
"RescaleObservationV0": "lambda_observations",
|
"RescaleObservationV0": "lambda_observation",
|
||||||
"DtypeObservationV0": "lambda_observations",
|
"DtypeObservationV0": "lambda_observation",
|
||||||
"PixelObservationV0": "lambda_observations",
|
"PixelObservationV0": "lambda_observation",
|
||||||
"NormalizeObservationV0": "lambda_observations",
|
|
||||||
# lambda_reward.py
|
# lambda_reward.py
|
||||||
"ClipRewardV0": "lambda_reward",
|
"ClipRewardV0": "lambda_reward",
|
||||||
"LambdaRewardV0": "lambda_reward",
|
"LambdaRewardV0": "lambda_reward",
|
||||||
"NormalizeRewardV1": "lambda_reward",
|
|
||||||
# stateful_action
|
# stateful_action
|
||||||
"StickyActionV0": "stateful_action",
|
"StickyActionV0": "stateful_action",
|
||||||
# stateful_observation
|
# stateful_observation
|
||||||
"TimeAwareObservationV0": "stateful_observation",
|
"TimeAwareObservationV0": "stateful_observation",
|
||||||
"DelayObservationV0": "stateful_observation",
|
"DelayObservationV0": "stateful_observation",
|
||||||
"FrameStackObservationV0": "stateful_observation",
|
"FrameStackObservationV0": "stateful_observation",
|
||||||
|
"NormalizeObservationV0": "stateful_observation",
|
||||||
|
# stateful_reward
|
||||||
|
"NormalizeRewardV1": "stateful_reward",
|
||||||
# atari_preprocessing
|
# atari_preprocessing
|
||||||
"AtariPreprocessingV0": "atari_preprocessing",
|
"AtariPreprocessingV0": "atari_preprocessing",
|
||||||
# common
|
# common
|
||||||
@@ -86,18 +87,10 @@ _wrapper_to_class = {
|
|||||||
"RenderCollectionV0": "rendering",
|
"RenderCollectionV0": "rendering",
|
||||||
"RecordVideoV0": "rendering",
|
"RecordVideoV0": "rendering",
|
||||||
"HumanRenderingV0": "rendering",
|
"HumanRenderingV0": "rendering",
|
||||||
# jax_to_numpy
|
# data converters
|
||||||
"JaxToNumpyV0": "jax_to_numpy",
|
"JaxToNumpyV0": "jax_to_numpy",
|
||||||
# "jax_to_numpy": "jax_to_numpy",
|
|
||||||
# "numpy_to_jax": "jax_to_numpy",
|
|
||||||
# jax_to_torch
|
|
||||||
"JaxToTorchV0": "jax_to_torch",
|
"JaxToTorchV0": "jax_to_torch",
|
||||||
# "jax_to_torch": "jax_to_torch",
|
|
||||||
# "torch_to_jax": "jax_to_torch",
|
|
||||||
# numpy_to_torch
|
|
||||||
"NumpyToTorchV0": "numpy_to_torch",
|
"NumpyToTorchV0": "numpy_to_torch",
|
||||||
# "torch_to_numpy": "numpy_to_torch",
|
|
||||||
# "numpy_to_torch": "numpy_to_torch",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -99,10 +99,10 @@ class JaxToNumpyV0(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||||
"""Wraps an environment such that the input and outputs are numpy arrays.
|
"""Wraps a jax environment such that the input and outputs are numpy arrays.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: the environment to wrap
|
env: the jax environment to wrap
|
||||||
"""
|
"""
|
||||||
if jnp is None:
|
if jnp is None:
|
||||||
raise DependencyNotInstalled(
|
raise DependencyNotInstalled(
|
||||||
@@ -120,7 +120,7 @@ class JaxToNumpyV0(
|
|||||||
action: the action to perform as a numpy array
|
action: the action to perform as a numpy array
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple containing the next observation, reward, termination, truncation, and extra info.
|
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
|
||||||
"""
|
"""
|
||||||
jax_action = numpy_to_jax(action)
|
jax_action = numpy_to_jax(action)
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||||
|
@@ -39,7 +39,7 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["jax_to_torch", "torch_to_jax", "JaxToTorchV0"]
|
__all__ = ["JaxToTorchV0", "jax_to_torch", "torch_to_jax", "Device"]
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
@@ -114,7 +114,7 @@ def _jax_iterable_to_torch(
|
|||||||
|
|
||||||
|
|
||||||
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
"""Wraps a Jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
|
|
||||||
|
@@ -1,8 +1,8 @@
|
|||||||
"""A collection of wrappers that all use the LambdaAction class.
|
"""A collection of wrappers that all use the LambdaAction class.
|
||||||
|
|
||||||
* ``LambdaAction`` - Transforms the actions based on a function
|
* ``LambdaActionV0`` - Transforms the actions based on a function
|
||||||
* ``ClipAction`` - Clips the action within a bounds
|
* ``ClipActionV0`` - Clips the action within a bounds
|
||||||
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
|
* ``RescaleActionV0`` - Rescales the action within a minimum and maximum actions
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -34,8 +34,8 @@ class LambdaActionV0(
|
|||||||
"""Initialize LambdaAction.
|
"""Initialize LambdaAction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The gymnasium environment
|
env: The environment to wrap
|
||||||
func: Function to apply to ``step`` ``action``
|
func: Function to apply to the :meth:`step`'s ``action``
|
||||||
action_space: The updated action space of the wrapper given the function.
|
action_space: The updated action space of the wrapper given the function.
|
||||||
"""
|
"""
|
||||||
gym.utils.RecordConstructorArgs.__init__(
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
@@ -75,7 +75,7 @@ class ClipActionV0(
|
|||||||
"""A wrapper for clipping continuous actions within the valid bound.
|
"""A wrapper for clipping continuous actions within the valid bound.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to wrap
|
||||||
"""
|
"""
|
||||||
assert isinstance(env.action_space, Box)
|
assert isinstance(env.action_space, Box)
|
||||||
|
|
||||||
@@ -125,10 +125,10 @@ class RescaleActionV0(
|
|||||||
min_action: float | int | np.ndarray,
|
min_action: float | int | np.ndarray,
|
||||||
max_action: float | int | np.ndarray,
|
max_action: float | int | np.ndarray,
|
||||||
):
|
):
|
||||||
"""Initializes the :class:`RescaleAction` wrapper.
|
"""Constructor for the Rescale Action wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to wrap
|
||||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||||
"""
|
"""
|
||||||
|
@@ -9,7 +9,6 @@
|
|||||||
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
|
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
|
||||||
* ``DtypeObservationV0`` - Convert an observation to a dtype
|
* ``DtypeObservationV0`` - Convert an observation to a dtype
|
||||||
* ``PixelObservationV0`` - Allows the observation to the rendered frame
|
* ``PixelObservationV0`` - Allows the observation to the rendered frame
|
||||||
* ``NormalizeObservationV0`` - Normalized the observations to a mean and
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -27,7 +26,6 @@ import gymnasium as gym
|
|||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
|
||||||
|
|
||||||
|
|
||||||
class LambdaObservationV0(
|
class LambdaObservationV0(
|
||||||
@@ -37,7 +35,7 @@ class LambdaObservationV0(
|
|||||||
"""Transforms an observation via a function provided to the wrapper.
|
"""Transforms an observation via a function provided to the wrapper.
|
||||||
|
|
||||||
The function :attr:`func` will be applied to all observations.
|
The function :attr:`func` will be applied to all observations.
|
||||||
If the observations from :attr:`func` are outside the bounds of the `env` spaces, provide a :attr:`observation_space`.
|
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> import gymnasium as gym
|
>>> import gymnasium as gym
|
||||||
@@ -60,8 +58,8 @@ class LambdaObservationV0(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to wrap
|
env: The environment to wrap
|
||||||
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
|
func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`.
|
||||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
|
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
|
||||||
"""
|
"""
|
||||||
gym.utils.RecordConstructorArgs.__init__(
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
self, func=func, observation_space=observation_space
|
self, func=func, observation_space=observation_space
|
||||||
@@ -82,7 +80,7 @@ class FilterObservationV0(
|
|||||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
gym.utils.RecordConstructorArgs,
|
gym.utils.RecordConstructorArgs,
|
||||||
):
|
):
|
||||||
"""Filter Dict observation space by the keys.
|
"""Filters Dict or Tuple observation space by the keys or indexes.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> import gymnasium as gym
|
>>> import gymnasium as gym
|
||||||
@@ -103,7 +101,12 @@ class FilterObservationV0(
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
|
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
|
||||||
):
|
):
|
||||||
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
"""Constructor for the filter observation wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectivesly
|
||||||
|
"""
|
||||||
assert isinstance(filter_keys, Sequence)
|
assert isinstance(filter_keys, Sequence)
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
||||||
|
|
||||||
@@ -177,7 +180,7 @@ class FilterObservationV0(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
|
f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||||
@@ -204,7 +207,11 @@ class FlattenObservationV0(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
"""
|
||||||
gym.utils.RecordConstructorArgs.__init__(self)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
LambdaObservationV0.__init__(
|
LambdaObservationV0.__init__(
|
||||||
self,
|
self,
|
||||||
@@ -237,7 +244,12 @@ class GrayscaleObservationV0(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
|
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
|
||||||
"""Constructor for an RGB image based environments to make the image grayscale."""
|
"""Constructor for an RGB image based environments to make the image grayscale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
keep_dim: If to keep the channel in the observation, if ``True``, ``obs.shape == 3`` else ``obs.shape == 2``
|
||||||
|
"""
|
||||||
assert isinstance(env.observation_space, spaces.Box)
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
assert (
|
assert (
|
||||||
len(env.observation_space.shape) == 3
|
len(env.observation_space.shape) == 3
|
||||||
@@ -301,7 +313,12 @@ class ResizeObservationV0(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
|
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
|
||||||
"""Constructor that requires an image environment observation space with a shape."""
|
"""Constructor that requires an image environment observation space with a shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
shape: The resized observation shape
|
||||||
|
"""
|
||||||
assert isinstance(env.observation_space, spaces.Box)
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
assert len(env.observation_space.shape) in [2, 3]
|
assert len(env.observation_space.shape) in [2, 3]
|
||||||
assert np.all(env.observation_space.low == 0) and np.all(
|
assert np.all(env.observation_space.low == 0) and np.all(
|
||||||
@@ -323,7 +340,10 @@ class ResizeObservationV0(
|
|||||||
self.shape: Final[tuple[int, ...]] = tuple(shape)
|
self.shape: Final[tuple[int, ...]] = tuple(shape)
|
||||||
|
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=self.shape + env.observation_space.shape[2:],
|
||||||
|
dtype=np.uint8,
|
||||||
)
|
)
|
||||||
|
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||||
@@ -353,7 +373,12 @@ class ReshapeObservationV0(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
|
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
|
||||||
"""Constructor for env with Box observation space that has a shape product equal to the new shape product."""
|
"""Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
shape: The reshaped observation space
|
||||||
|
"""
|
||||||
assert isinstance(env.observation_space, spaces.Box)
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
assert np.product(shape) == np.product(env.observation_space.shape)
|
assert np.product(shape) == np.product(env.observation_space.shape)
|
||||||
|
|
||||||
@@ -401,7 +426,13 @@ class RescaleObservationV0(
|
|||||||
min_obs: np.floating | np.integer | np.ndarray,
|
min_obs: np.floating | np.integer | np.ndarray,
|
||||||
max_obs: np.floating | np.integer | np.ndarray,
|
max_obs: np.floating | np.integer | np.ndarray,
|
||||||
):
|
):
|
||||||
"""Constructor that requires the env observation spaces to be a :class:`Box`."""
|
"""Constructor that requires the env observation spaces to be a :class:`Box`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
min_obs: The new minimum observation bound
|
||||||
|
max_obs: The new maximum observation bound
|
||||||
|
"""
|
||||||
assert isinstance(env.observation_space, spaces.Box)
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
assert not np.any(env.observation_space.low == np.inf) and not np.any(
|
assert not np.any(env.observation_space.low == np.inf) and not np.any(
|
||||||
env.observation_space.high == np.inf
|
env.observation_space.high == np.inf
|
||||||
@@ -452,10 +483,19 @@ class DtypeObservationV0(
|
|||||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
gym.utils.RecordConstructorArgs,
|
gym.utils.RecordConstructorArgs,
|
||||||
):
|
):
|
||||||
"""Observation wrapper for transforming the dtype of an observation."""
|
"""Observation wrapper for transforming the dtype of an observation.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
||||||
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
|
"""Constructor for Dtype observation wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
dtype: The new dtype of the observation
|
||||||
|
"""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
env.observation_space,
|
env.observation_space,
|
||||||
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
||||||
@@ -505,7 +545,7 @@ class PixelObservationV0(
|
|||||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
gym.utils.RecordConstructorArgs,
|
gym.utils.RecordConstructorArgs,
|
||||||
):
|
):
|
||||||
"""Augment observations by pixel values.
|
"""Includes the rendered observations to the environment's observations.
|
||||||
|
|
||||||
Observations of this wrapper will be dictionaries of images.
|
Observations of this wrapper will be dictionaries of images.
|
||||||
You can also choose to add the observation of the base environment to this dictionary.
|
You can also choose to add the observation of the base environment to this dictionary.
|
||||||
@@ -522,13 +562,13 @@ class PixelObservationV0(
|
|||||||
pixels_key: str = "pixels",
|
pixels_key: str = "pixels",
|
||||||
obs_key: str = "state",
|
obs_key: str = "state",
|
||||||
):
|
):
|
||||||
"""Initializes a new pixel Wrapper.
|
"""Constructor of the pixel observation wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to wrap.
|
env: The environment to wrap.
|
||||||
pixels_only (bool): If `True` (default), the original observation returned
|
pixels_only (bool): If ``True`` (default), the original observation returned
|
||||||
by the wrapped environment will be discarded, and a dictionary
|
by the wrapped environment will be discarded, and a dictionary
|
||||||
observation will only include pixels. If `False`, the
|
observation will only include pixels. If ``False``, the
|
||||||
observation dictionary will contain both the original
|
observation dictionary will contain both the original
|
||||||
observations and the pixel observations.
|
observations and the pixel observations.
|
||||||
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
|
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
|
||||||
@@ -571,51 +611,3 @@ class PixelObservationV0(
|
|||||||
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
|
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
|
||||||
observation_space=obs_space,
|
observation_space=obs_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NormalizeObservationV0(
|
|
||||||
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
|
||||||
gym.utils.RecordConstructorArgs,
|
|
||||||
):
|
|
||||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
|
||||||
|
|
||||||
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
|
|
||||||
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called.
|
|
||||||
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
|
|
||||||
newly instantiated or the policy was changed recently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
|
|
||||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env (Env): The environment to apply the wrapper
|
|
||||||
epsilon: A stability parameter that is used when scaling the observations.
|
|
||||||
"""
|
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
|
|
||||||
gym.ObservationWrapper.__init__(self, env)
|
|
||||||
|
|
||||||
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
|
|
||||||
self.epsilon = epsilon
|
|
||||||
self._update_running_mean = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def update_running_mean(self) -> bool:
|
|
||||||
"""Property to freeze/continue the running mean calculation of the observation statistics."""
|
|
||||||
return self._update_running_mean
|
|
||||||
|
|
||||||
@update_running_mean.setter
|
|
||||||
def update_running_mean(self, setting: bool):
|
|
||||||
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
|
|
||||||
self._update_running_mean = setting
|
|
||||||
|
|
||||||
def observation(self, observation: ObsType) -> WrapperObsType:
|
|
||||||
"""Normalises the observation using the running mean and variance of the observations."""
|
|
||||||
if self._update_running_mean:
|
|
||||||
self.obs_rms.update(observation)
|
|
||||||
return (observation - self.obs_rms.mean) / np.sqrt(
|
|
||||||
self.obs_rms.var + self.epsilon
|
|
||||||
)
|
|
@@ -1,18 +1,17 @@
|
|||||||
"""A collection of wrappers for modifying the reward.
|
"""A collection of wrappers for modifying the reward.
|
||||||
|
|
||||||
* ``LambdaReward`` - Transforms the reward by a function
|
* ``LambdaRewardV0`` - Transforms the reward by a function
|
||||||
* ``ClipReward`` - Clips the reward between a minimum and maximum value
|
* ``ClipRewardV0`` - Clips the reward between a minimum and maximum value
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, SupportsFloat
|
from typing import Callable, SupportsFloat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
from gymnasium.error import InvalidBound
|
from gymnasium.error import InvalidBound
|
||||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
|
||||||
|
|
||||||
|
|
||||||
class LambdaRewardV0(
|
class LambdaRewardV0(
|
||||||
@@ -39,7 +38,7 @@ class LambdaRewardV0(
|
|||||||
"""Initialize LambdaRewardV0 wrapper.
|
"""Initialize LambdaRewardV0 wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to wrap
|
||||||
func: (Callable): The function to apply to reward
|
func: (Callable): The function to apply to reward
|
||||||
"""
|
"""
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, func=func)
|
gym.utils.RecordConstructorArgs.__init__(self, func=func)
|
||||||
@@ -79,7 +78,7 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
|
|||||||
"""Initialize ClipRewardsV0 wrapper.
|
"""Initialize ClipRewardsV0 wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to wrap
|
||||||
min_reward (Union[float, np.ndarray]): lower bound to apply
|
min_reward (Union[float, np.ndarray]): lower bound to apply
|
||||||
max_reward (Union[float, np.ndarray]): higher bound to apply
|
max_reward (Union[float, np.ndarray]): higher bound to apply
|
||||||
"""
|
"""
|
||||||
@@ -98,72 +97,3 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
|
|||||||
LambdaRewardV0.__init__(
|
LambdaRewardV0.__init__(
|
||||||
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
|
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NormalizeRewardV1(
|
|
||||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
|
||||||
):
|
|
||||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
|
||||||
|
|
||||||
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
|
||||||
|
|
||||||
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the reward
|
|
||||||
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
|
|
||||||
If False, the calculated statistics are used but not updated anymore; this may be used during evaluation.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrect computed in Gym v0.25+.
|
|
||||||
For more detail, read [#3154](https://github.com/openai/gym/pull/3152).
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
|
|
||||||
instantiated or the policy was changed recently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
env: gym.Env[ObsType, ActType],
|
|
||||||
gamma: float = 0.99,
|
|
||||||
epsilon: float = 1e-8,
|
|
||||||
):
|
|
||||||
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env (env): The environment to apply the wrapper
|
|
||||||
epsilon (float): A stability parameter
|
|
||||||
gamma (float): The discount factor that is used in the exponential moving average.
|
|
||||||
"""
|
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
|
|
||||||
gym.Wrapper.__init__(self, env)
|
|
||||||
|
|
||||||
self.rewards_running_means = RunningMeanStd(shape=())
|
|
||||||
self.discounted_reward: np.array = np.array([0.0])
|
|
||||||
self.gamma = gamma
|
|
||||||
self.epsilon = epsilon
|
|
||||||
self._update_running_mean = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def update_running_mean(self) -> bool:
|
|
||||||
"""Property to freeze/continue the running mean calculation of the reward statistics."""
|
|
||||||
return self._update_running_mean
|
|
||||||
|
|
||||||
@update_running_mean.setter
|
|
||||||
def update_running_mean(self, setting: bool):
|
|
||||||
"""Sets the property to freeze/continue the running mean calculation of the reward statistics."""
|
|
||||||
self._update_running_mean = setting
|
|
||||||
|
|
||||||
def step(
|
|
||||||
self, action: ActType
|
|
||||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
||||||
"""Steps through the environment, normalizing the reward returned."""
|
|
||||||
obs, reward, terminated, truncated, info = super().step(action)
|
|
||||||
self.discounted_reward = self.discounted_reward * self.gamma * (
|
|
||||||
1 - terminated
|
|
||||||
) + float(reward)
|
|
||||||
return obs, self.normalize(float(reward)), terminated, truncated, info
|
|
||||||
|
|
||||||
def normalize(self, reward: SupportsFloat):
|
|
||||||
"""Normalizes the rewards with the running mean rewards and their variance."""
|
|
||||||
if self._update_running_mean:
|
|
||||||
self.rewards_running_means.update(self.discounted_reward)
|
|
||||||
return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)
|
|
||||||
|
@@ -111,13 +111,13 @@ class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
def step(
|
def step(
|
||||||
self, action: WrapperActType
|
self, action: WrapperActType
|
||||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||||
"""Performs the given action within the environment.
|
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action: The action to perform as a PyTorch Tensor
|
action: A PyTorch-based action
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The next observation, reward, termination, truncation, and extra info
|
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
||||||
"""
|
"""
|
||||||
jax_action = torch_to_numpy(action)
|
jax_action = torch_to_numpy(action)
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
|
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
|
||||||
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
|
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
|
||||||
* ``FrameStackObservationV0`` - Frame stack the observations
|
* ``FrameStackObservationV0`` - Frame stack the observations
|
||||||
|
* ``NormalizeObservationV0`` - Normalized the observations to a mean and
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ from gymnasium.experimental.vector.utils import (
|
|||||||
concatenate,
|
concatenate,
|
||||||
create_empty_array,
|
create_empty_array,
|
||||||
)
|
)
|
||||||
from gymnasium.experimental.wrappers.utils import create_zero_array
|
from gymnasium.experimental.wrappers.utils import RunningMeanStd, create_zero_array
|
||||||
from gymnasium.spaces import Box, Dict, Tuple
|
from gymnasium.spaces import Box, Dict, Tuple
|
||||||
|
|
||||||
|
|
||||||
@@ -382,3 +383,51 @@ class FrameStackObservationV0(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return updated_obs, info
|
return updated_obs, info
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeObservationV0(
|
||||||
|
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
|
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||||
|
|
||||||
|
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
|
||||||
|
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called.
|
||||||
|
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
|
||||||
|
newly instantiated or the policy was changed recently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
|
||||||
|
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env (Env): The environment to apply the wrapper
|
||||||
|
epsilon: A stability parameter that is used when scaling the observations.
|
||||||
|
"""
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
|
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self._update_running_mean = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def update_running_mean(self) -> bool:
|
||||||
|
"""Property to freeze/continue the running mean calculation of the observation statistics."""
|
||||||
|
return self._update_running_mean
|
||||||
|
|
||||||
|
@update_running_mean.setter
|
||||||
|
def update_running_mean(self, setting: bool):
|
||||||
|
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
|
||||||
|
self._update_running_mean = setting
|
||||||
|
|
||||||
|
def observation(self, observation: ObsType) -> WrapperObsType:
|
||||||
|
"""Normalises the observation using the running mean and variance of the observations."""
|
||||||
|
if self._update_running_mean:
|
||||||
|
self.obs_rms.update(observation)
|
||||||
|
return (observation - self.obs_rms.mean) / np.sqrt(
|
||||||
|
self.obs_rms.var + self.epsilon
|
||||||
|
)
|
||||||
|
82
gymnasium/experimental/wrappers/stateful_reward.py
Normal file
82
gymnasium/experimental/wrappers/stateful_reward.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""A collection of wrappers for modifying the reward with an internal state.
|
||||||
|
|
||||||
|
* ``NormalizeRewardV1`` - Normalizes the rewards to a mean and standard deviation
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, SupportsFloat
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeRewardV1(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
|
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||||
|
|
||||||
|
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
||||||
|
|
||||||
|
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the reward
|
||||||
|
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
|
||||||
|
If False, the calculated statistics are used but not updated anymore; this may be used during evaluation.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrect computed in Gym v0.25+.
|
||||||
|
For more detail, read [#3154](https://github.com/openai/gym/pull/3152).
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
|
||||||
|
instantiated or the policy was changed recently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: gym.Env[ObsType, ActType],
|
||||||
|
gamma: float = 0.99,
|
||||||
|
epsilon: float = 1e-8,
|
||||||
|
):
|
||||||
|
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env (env): The environment to apply the wrapper
|
||||||
|
epsilon (float): A stability parameter
|
||||||
|
gamma (float): The discount factor that is used in the exponential moving average.
|
||||||
|
"""
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
|
self.rewards_running_means = RunningMeanStd(shape=())
|
||||||
|
self.discounted_reward: np.array = np.array([0.0])
|
||||||
|
self.gamma = gamma
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self._update_running_mean = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def update_running_mean(self) -> bool:
|
||||||
|
"""Property to freeze/continue the running mean calculation of the reward statistics."""
|
||||||
|
return self._update_running_mean
|
||||||
|
|
||||||
|
@update_running_mean.setter
|
||||||
|
def update_running_mean(self, setting: bool):
|
||||||
|
"""Sets the property to freeze/continue the running mean calculation of the reward statistics."""
|
||||||
|
self._update_running_mean = setting
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: ActType
|
||||||
|
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||||
|
"""Steps through the environment, normalizing the reward returned."""
|
||||||
|
obs, reward, terminated, truncated, info = super().step(action)
|
||||||
|
self.discounted_reward = self.discounted_reward * self.gamma * (
|
||||||
|
1 - terminated
|
||||||
|
) + float(reward)
|
||||||
|
return obs, self.normalize(float(reward)), terminated, truncated, info
|
||||||
|
|
||||||
|
def normalize(self, reward: SupportsFloat):
|
||||||
|
"""Normalizes the rewards with the running mean rewards and their variance."""
|
||||||
|
if self._update_running_mean:
|
||||||
|
self.rewards_running_means.update(self.discounted_reward)
|
||||||
|
return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)
|
@@ -1,12 +1,44 @@
|
|||||||
"""Wrappers for vector environments."""
|
"""Wrappers for vector environments."""
|
||||||
|
# pyright: reportUnsupportedDunderAll=false
|
||||||
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
|
import importlib
|
||||||
VectorRecordEpisodeStatistics,
|
|
||||||
)
|
|
||||||
from gymnasium.experimental.wrappers.vector.vector_list_info import VectorListInfo
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"VectorRecordEpisodeStatistics",
|
# --- Vector only wrappers
|
||||||
"VectorListInfo",
|
"VectoriseLambdaObservationV0",
|
||||||
|
"VectoriseLambdaActionV0",
|
||||||
|
"VectoriseLambdaRewardV0",
|
||||||
|
"DictInfoToListV0",
|
||||||
|
# --- Observation wrappers ---
|
||||||
|
"LambdaObservationV0",
|
||||||
|
"FilterObservationV0",
|
||||||
|
"FlattenObservationV0",
|
||||||
|
"GrayscaleObservationV0",
|
||||||
|
"ResizeObservationV0",
|
||||||
|
"ReshapeObservationV0",
|
||||||
|
"RescaleObservationV0",
|
||||||
|
"DtypeObservationV0",
|
||||||
|
"PixelObservationV0",
|
||||||
|
"NormalizeObservationV0",
|
||||||
|
# "TimeAwareObservationV0",
|
||||||
|
# "FrameStackObservationV0",
|
||||||
|
# "DelayObservationV0",
|
||||||
|
# --- Action Wrappers ---
|
||||||
|
"LambdaActionV0",
|
||||||
|
"ClipActionV0",
|
||||||
|
"RescaleActionV0",
|
||||||
|
# --- Reward wrappers ---
|
||||||
|
"LambdaRewardV0",
|
||||||
|
"ClipRewardV0",
|
||||||
|
"NormalizeRewardV1",
|
||||||
|
# --- Common ---
|
||||||
|
"RecordEpisodeStatisticsV0",
|
||||||
|
# --- Rendering ---
|
||||||
|
# "RenderCollectionV0",
|
||||||
|
# "RecordVideoV0",
|
||||||
|
# "HumanRenderingV0",
|
||||||
|
# --- Conversion ---
|
||||||
|
"JaxToNumpyV0",
|
||||||
|
"JaxToTorchV0",
|
||||||
|
"NumpyToTorchV0",
|
||||||
]
|
]
|
||||||
|
@@ -1,11 +1,13 @@
|
|||||||
"""Wrapper that converts the info format for vec envs into the list format."""
|
"""Wrapper that converts the info format for vec envs into the list format."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List
|
from typing import Any
|
||||||
|
|
||||||
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
|
||||||
|
|
||||||
|
|
||||||
class VectorListInfo(VectorWrapper):
|
class DictInfoToListV0(VectorWrapper):
|
||||||
"""Converts infos of vectorized environments from dict to List[dict].
|
"""Converts infos of vectorized environments from dict to List[dict].
|
||||||
|
|
||||||
This wrapper converts the info format of a
|
This wrapper converts the info format of a
|
||||||
@@ -15,17 +17,16 @@ class VectorListInfo(VectorWrapper):
|
|||||||
operation on info like `RecordEpisodeStatistics` this
|
operation on info like `RecordEpisodeStatistics` this
|
||||||
need to be the outermost wrapper.
|
need to be the outermost wrapper.
|
||||||
|
|
||||||
i.e. `VectorListInfo(RecordEpisodeStatistics(envs))`
|
i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))``
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> # actual
|
>>> import numpy as np
|
||||||
>>> { # doctest: +SKIP
|
>>> dict_info = {
|
||||||
... "k": np.array([0., 0., 0.5, 0.3]),
|
... "k": np.array([0., 0., 0.5, 0.3]),
|
||||||
... "_k": np.array([False, False, True, True])
|
... "_k": np.array([False, False, True, True])
|
||||||
... }
|
... }
|
||||||
>>> # classic
|
>>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
|
||||||
>>> [{}, {}, {k: 0.5}, {k: 0.3}] # doctest: +SKIP
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: VectorEnv):
|
def __init__(self, env: VectorEnv):
|
||||||
@@ -36,20 +37,28 @@ class VectorListInfo(VectorWrapper):
|
|||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
def step(self, action):
|
def step(
|
||||||
|
self, actions: ActType
|
||||||
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
|
||||||
"""Steps through the environment, convert dict info to list."""
|
"""Steps through the environment, convert dict info to list."""
|
||||||
observation, reward, terminated, truncated, infos = self.env.step(action)
|
observation, reward, terminated, truncated, infos = self.env.step(actions)
|
||||||
list_info = self._convert_info_to_list(infos)
|
list_info = self._convert_info_to_list(infos)
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, list_info
|
return observation, reward, terminated, truncated, list_info
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | list[int] | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[ObsType, list[dict[str, Any]]]:
|
||||||
"""Resets the environment using kwargs."""
|
"""Resets the environment using kwargs."""
|
||||||
obs, infos = self.env.reset(**kwargs)
|
obs, infos = self.env.reset(seed=seed, options=options)
|
||||||
list_info = self._convert_info_to_list(infos)
|
list_info = self._convert_info_to_list(infos)
|
||||||
|
|
||||||
return obs, list_info
|
return obs, list_info
|
||||||
|
|
||||||
def _convert_info_to_list(self, infos: dict) -> List[dict]:
|
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
|
||||||
"""Convert the dict info to list.
|
"""Convert the dict info to list.
|
||||||
|
|
||||||
Convert the dict info of the vectorized environment
|
Convert the dict info of the vectorized environment
|
||||||
@@ -72,36 +81,3 @@ class VectorListInfo(VectorWrapper):
|
|||||||
if has_info:
|
if has_info:
|
||||||
list_info[i][k] = infos[k][i]
|
list_info[i][k] = infos[k][i]
|
||||||
return list_info
|
return list_info
|
||||||
|
|
||||||
def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]:
|
|
||||||
"""Process episode statistics.
|
|
||||||
|
|
||||||
`RecordEpisodeStatistics` wrapper add extra
|
|
||||||
information to the info. This information are in
|
|
||||||
the form of a dict of dict. This method process these
|
|
||||||
information and add them to the info.
|
|
||||||
`RecordEpisodeStatistics` info contains the keys
|
|
||||||
"r", "l", "t" which represents "cumulative reward",
|
|
||||||
"episode length", "elapsed time since instantiation of wrapper".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
infos (dict): infos coming from `RecordEpisodeStatistics`.
|
|
||||||
list_info (list): info of the current vectorized environment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list_info (list): updated info.
|
|
||||||
|
|
||||||
"""
|
|
||||||
episode_statistics = infos.pop("episode", False)
|
|
||||||
if not episode_statistics:
|
|
||||||
return list_info
|
|
||||||
|
|
||||||
episode_statistics_mask = infos.pop("_episode")
|
|
||||||
for i, has_info in enumerate(episode_statistics_mask):
|
|
||||||
if has_info:
|
|
||||||
list_info[i]["episode"] = {}
|
|
||||||
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
|
|
||||||
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
|
|
||||||
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
|
|
||||||
|
|
||||||
return list_info
|
|
76
gymnasium/experimental/wrappers/vector/jax_to_numpy.py
Normal file
76
gymnasium/experimental/wrappers/vector/jax_to_numpy.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""Vector wrapper for converting between NumPy and Jax."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||||
|
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
||||||
|
|
||||||
|
|
||||||
|
class JaxToNumpyV0(VectorWrapper):
|
||||||
|
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
A vectorised version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
|
||||||
|
|
||||||
|
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: VectorEnv):
|
||||||
|
"""Wraps an environment such that the input and outputs are numpy arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: the vector jax environment to wrap
|
||||||
|
"""
|
||||||
|
if jnp is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"jax is not installed, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, actions: ActType
|
||||||
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||||
|
"""Transforms the action to a jax array .
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actions: the action to perform as a numpy array
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
|
||||||
|
"""
|
||||||
|
jax_actions = numpy_to_jax(actions)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(jax_actions)
|
||||||
|
|
||||||
|
return (
|
||||||
|
jax_to_numpy(obs),
|
||||||
|
jax_to_numpy(reward),
|
||||||
|
jax_to_numpy(terminated),
|
||||||
|
jax_to_numpy(truncated),
|
||||||
|
jax_to_numpy(info),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | list[int] | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning numpy-based observation and info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to jax arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Numpy-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = numpy_to_jax(options)
|
||||||
|
|
||||||
|
return jax_to_numpy(self.env.reset(seed=seed, options=options))
|
73
gymnasium/experimental/wrappers/vector/jax_to_torch.py
Normal file
73
gymnasium/experimental/wrappers/vector/jax_to_torch.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""Vector wrapper class for converting between PyTorch and Jax."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||||
|
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_torch import (
|
||||||
|
Device,
|
||||||
|
jax_to_torch,
|
||||||
|
torch_to_jax,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class JaxToTorchV0(VectorWrapper):
|
||||||
|
"""Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
|
Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: VectorEnv, device: Device | None = None):
|
||||||
|
"""Vector wrapper to change inputs and outputs to PyTorch tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The Jax-based vector environment to wrap
|
||||||
|
device: The device the torch Tensors should be moved to
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
self.device: Device | None = device
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, actions: ActType
|
||||||
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||||
|
"""Performs the given action within the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actions: The action to perform as a PyTorch Tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
|
||||||
|
"""
|
||||||
|
jax_action = torch_to_jax(actions)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||||
|
|
||||||
|
return (
|
||||||
|
jax_to_torch(obs, self.device),
|
||||||
|
jax_to_torch(reward, self.device),
|
||||||
|
jax_to_torch(terminated, self.device),
|
||||||
|
jax_to_torch(truncated, self.device),
|
||||||
|
jax_to_torch(info, self.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | list[int] | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning PyTorch-based observation and info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to jax arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTorch-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = torch_to_jax(options)
|
||||||
|
|
||||||
|
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
70
gymnasium/experimental/wrappers/vector/numpy_to_torch.py
Normal file
70
gymnasium/experimental/wrappers/vector/numpy_to_torch.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""Wrapper for converting NumPy environments to PyTorch."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||||
|
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_torch import Device
|
||||||
|
from gymnasium.experimental.wrappers.numpy_to_torch import (
|
||||||
|
numpy_to_torch,
|
||||||
|
torch_to_numpy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyToTorchV0(VectorWrapper):
|
||||||
|
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors."""
|
||||||
|
|
||||||
|
def __init__(self, env: VectorEnv, device: Device | None = None):
|
||||||
|
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The Jax-based vector environment to wrap
|
||||||
|
device: The device the torch Tensors should be moved to
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
self.device: Device | None = device
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, actions: ActType
|
||||||
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||||
|
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: A PyTorch-based action
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
||||||
|
"""
|
||||||
|
jax_action = torch_to_numpy(actions)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||||
|
|
||||||
|
return (
|
||||||
|
numpy_to_torch(obs, self.device),
|
||||||
|
numpy_to_torch(reward, self.device),
|
||||||
|
numpy_to_torch(terminated, self.device),
|
||||||
|
numpy_to_torch(truncated, self.device),
|
||||||
|
numpy_to_torch(info, self.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | list[int] | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning PyTorch-based observation and info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to jax arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTorch-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = torch_to_numpy(options)
|
||||||
|
|
||||||
|
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
@@ -1,14 +1,16 @@
|
|||||||
"""Wrapper that tracks the cumulative rewards and episode lengths."""
|
"""Wrapper that tracks the cumulative rewards and episode lengths."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
|
||||||
|
|
||||||
|
|
||||||
class VectorRecordEpisodeStatistics(VectorWrapper):
|
class RecordEpisodeStatisticsV0(VectorWrapper):
|
||||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||||
@@ -32,9 +34,9 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
|||||||
>>> infos = { # doctest: +SKIP
|
>>> infos = { # doctest: +SKIP
|
||||||
... ...
|
... ...
|
||||||
... "episode": {
|
... "episode": {
|
||||||
... "r": "<array of cumulative reward>",
|
... "r": "<array of cumulative reward for each done sub-environment>",
|
||||||
... "l": "<array of episode length>",
|
... "l": "<array of episode length for each done sub-environment>",
|
||||||
... "t": "<array of elapsed time since beginning of episode>"
|
... "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
|
||||||
... },
|
... },
|
||||||
... "_episode": "<boolean array of length num-envs>"
|
... "_episode": "<boolean array of length num-envs>"
|
||||||
... }
|
... }
|
||||||
@@ -55,30 +57,35 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
|||||||
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
|
||||||
self.episode_count = 0
|
self.episode_count = 0
|
||||||
self.episode_start_times: np.ndarray = None
|
|
||||||
self.episode_returns: Optional[np.ndarray] = None
|
self.episode_start_times: np.ndarray = np.zeros(())
|
||||||
self.episode_lengths: Optional[np.ndarray] = None
|
self.episode_returns: np.ndarray = np.zeros(())
|
||||||
|
self.episode_lengths: np.ndarray = np.zeros(())
|
||||||
|
|
||||||
self.return_queue = deque(maxlen=deque_size)
|
self.return_queue = deque(maxlen=deque_size)
|
||||||
self.length_queue = deque(maxlen=deque_size)
|
self.length_queue = deque(maxlen=deque_size)
|
||||||
self.is_vector_env = True
|
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: int | list[int] | None = None,
|
||||||
options: Optional[dict] = None,
|
options: dict | None = None,
|
||||||
):
|
):
|
||||||
"""Resets the environment using kwargs and resets the episode returns and lengths."""
|
"""Resets the environment using kwargs and resets the episode returns and lengths."""
|
||||||
obs, info = super().reset(seed=seed, options=options)
|
obs, info = super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
self.episode_start_times = np.full(
|
self.episode_start_times = np.full(
|
||||||
self.num_envs, time.perf_counter(), dtype=np.float32
|
self.num_envs, time.perf_counter(), dtype=np.float32
|
||||||
)
|
)
|
||||||
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
|
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
|
||||||
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
|
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
|
||||||
|
|
||||||
return obs, info
|
return obs, info
|
||||||
|
|
||||||
def step(self, action):
|
def step(
|
||||||
|
self, actions: ActType
|
||||||
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||||
"""Steps through the environment, recording the episode statistics."""
|
"""Steps through the environment, recording the episode statistics."""
|
||||||
(
|
(
|
||||||
observations,
|
observations,
|
||||||
@@ -86,14 +93,18 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
|||||||
terminations,
|
terminations,
|
||||||
truncations,
|
truncations,
|
||||||
infos,
|
infos,
|
||||||
) = self.env.step(action)
|
) = self.env.step(actions)
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
infos, dict
|
infos, dict
|
||||||
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
|
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
|
||||||
|
|
||||||
self.episode_returns += rewards
|
self.episode_returns += rewards
|
||||||
self.episode_lengths += 1
|
self.episode_lengths += 1
|
||||||
|
|
||||||
dones = np.logical_or(terminations, truncations)
|
dones = np.logical_or(terminations, truncations)
|
||||||
num_dones = np.sum(dones)
|
num_dones = np.sum(dones)
|
||||||
|
|
||||||
if num_dones:
|
if num_dones:
|
||||||
if "episode" in infos or "_episode" in infos:
|
if "episode" in infos or "_episode" in infos:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -109,14 +120,18 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
|||||||
0.0,
|
0.0,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if self.is_vector_env:
|
infos["_episode"] = dones
|
||||||
infos["_episode"] = np.where(dones, True, False)
|
|
||||||
self.return_queue.extend(self.episode_returns[dones])
|
|
||||||
self.length_queue.extend(self.episode_lengths[dones])
|
|
||||||
self.episode_count += num_dones
|
self.episode_count += num_dones
|
||||||
|
|
||||||
|
for i in np.where(dones):
|
||||||
|
self.return_queue.extend(self.episode_returns[i])
|
||||||
|
self.length_queue.extend(self.episode_lengths[i])
|
||||||
|
|
||||||
self.episode_lengths[dones] = 0
|
self.episode_lengths[dones] = 0
|
||||||
self.episode_returns[dones] = 0
|
self.episode_returns[dones] = 0
|
||||||
self.episode_start_times[dones] = time.perf_counter()
|
self.episode_start_times[dones] = time.perf_counter()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
observations,
|
observations,
|
||||||
rewards,
|
rewards,
|
||||||
|
Reference in New Issue
Block a user