Experimental wrapper changes (#517)

This commit is contained in:
Mark Towers
2023-05-23 15:46:04 +01:00
committed by GitHub
parent 22a00c2a75
commit 5bf6c1e93f
15 changed files with 541 additions and 253 deletions

View File

@@ -54,27 +54,28 @@ _wrapper_to_class = {
"LambdaActionV0": "lambda_action",
"ClipActionV0": "lambda_action",
"RescaleActionV0": "lambda_action",
# lambda_observations.py
"LambdaObservationV0": "lambda_observations",
"FilterObservationV0": "lambda_observations",
"FlattenObservationV0": "lambda_observations",
"GrayscaleObservationV0": "lambda_observations",
"ResizeObservationV0": "lambda_observations",
"ReshapeObservationV0": "lambda_observations",
"RescaleObservationV0": "lambda_observations",
"DtypeObservationV0": "lambda_observations",
"PixelObservationV0": "lambda_observations",
"NormalizeObservationV0": "lambda_observations",
# lambda_observation.py
"LambdaObservationV0": "lambda_observation",
"FilterObservationV0": "lambda_observation",
"FlattenObservationV0": "lambda_observation",
"GrayscaleObservationV0": "lambda_observation",
"ResizeObservationV0": "lambda_observation",
"ReshapeObservationV0": "lambda_observation",
"RescaleObservationV0": "lambda_observation",
"DtypeObservationV0": "lambda_observation",
"PixelObservationV0": "lambda_observation",
# lambda_reward.py
"ClipRewardV0": "lambda_reward",
"LambdaRewardV0": "lambda_reward",
"NormalizeRewardV1": "lambda_reward",
# stateful_action
"StickyActionV0": "stateful_action",
# stateful_observation
"TimeAwareObservationV0": "stateful_observation",
"DelayObservationV0": "stateful_observation",
"FrameStackObservationV0": "stateful_observation",
"NormalizeObservationV0": "stateful_observation",
# stateful_reward
"NormalizeRewardV1": "stateful_reward",
# atari_preprocessing
"AtariPreprocessingV0": "atari_preprocessing",
# common
@@ -86,18 +87,10 @@ _wrapper_to_class = {
"RenderCollectionV0": "rendering",
"RecordVideoV0": "rendering",
"HumanRenderingV0": "rendering",
# jax_to_numpy
# data converters
"JaxToNumpyV0": "jax_to_numpy",
# "jax_to_numpy": "jax_to_numpy",
# "numpy_to_jax": "jax_to_numpy",
# jax_to_torch
"JaxToTorchV0": "jax_to_torch",
# "jax_to_torch": "jax_to_torch",
# "torch_to_jax": "jax_to_torch",
# numpy_to_torch
"NumpyToTorchV0": "numpy_to_torch",
# "torch_to_numpy": "numpy_to_torch",
# "numpy_to_torch": "numpy_to_torch",
}

View File

@@ -99,10 +99,10 @@ class JaxToNumpyV0(
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Wraps an environment such that the input and outputs are numpy arrays.
"""Wraps a jax environment such that the input and outputs are numpy arrays.
Args:
env: the environment to wrap
env: the jax environment to wrap
"""
if jnp is None:
raise DependencyNotInstalled(
@@ -120,7 +120,7 @@ class JaxToNumpyV0(
action: the action to perform as a numpy array
Returns:
A tuple containing the next observation, reward, termination, truncation, and extra info.
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_action = numpy_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)

View File

@@ -39,7 +39,7 @@ except ImportError:
)
__all__ = ["jax_to_torch", "torch_to_jax", "JaxToTorchV0"]
__all__ = ["JaxToTorchV0", "jax_to_torch", "torch_to_jax", "Device"]
@functools.singledispatch
@@ -114,7 +114,7 @@ def _jax_iterable_to_torch(
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
"""Wraps a Jax-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.

View File

@@ -1,8 +1,8 @@
"""A collection of wrappers that all use the LambdaAction class.
* ``LambdaAction`` - Transforms the actions based on a function
* ``ClipAction`` - Clips the action within a bounds
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
* ``LambdaActionV0`` - Transforms the actions based on a function
* ``ClipActionV0`` - Clips the action within a bounds
* ``RescaleActionV0`` - Rescales the action within a minimum and maximum actions
"""
from __future__ import annotations
@@ -34,8 +34,8 @@ class LambdaActionV0(
"""Initialize LambdaAction.
Args:
env: The gymnasium environment
func: Function to apply to ``step`` ``action``
env: The environment to wrap
func: Function to apply to the :meth:`step`'s ``action``
action_space: The updated action space of the wrapper given the function.
"""
gym.utils.RecordConstructorArgs.__init__(
@@ -75,7 +75,7 @@ class ClipActionV0(
"""A wrapper for clipping continuous actions within the valid bound.
Args:
env: The environment to apply the wrapper
env: The environment to wrap
"""
assert isinstance(env.action_space, Box)
@@ -125,10 +125,10 @@ class RescaleActionV0(
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Initializes the :class:`RescaleAction` wrapper.
"""Constructor for the Rescale Action wrapper.
Args:
env (Env): The environment to apply the wrapper
env (Env): The environment to wrap
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""

View File

@@ -9,7 +9,6 @@
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservationV0`` - Convert an observation to a dtype
* ``PixelObservationV0`` - Allows the observation to the rendered frame
* ``NormalizeObservationV0`` - Normalized the observations to a mean and
"""
from __future__ import annotations
@@ -27,7 +26,6 @@ import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.utils import RunningMeanStd
class LambdaObservationV0(
@@ -37,7 +35,7 @@ class LambdaObservationV0(
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all observations.
If the observations from :attr:`func` are outside the bounds of the `env` spaces, provide a :attr:`observation_space`.
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
Example:
>>> import gymnasium as gym
@@ -60,8 +58,8 @@ class LambdaObservationV0(
Args:
env: The environment to wrap
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
gym.utils.RecordConstructorArgs.__init__(
self, func=func, observation_space=observation_space
@@ -82,7 +80,7 @@ class FilterObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Filter Dict observation space by the keys.
"""Filters Dict or Tuple observation space by the keys or indexes.
Example:
>>> import gymnasium as gym
@@ -103,7 +101,12 @@ class FilterObservationV0(
def __init__(
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
):
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
"""Constructor for the filter observation wrapper.
Args:
env: The environment to wrap
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
"""
assert isinstance(filter_keys, Sequence)
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
@@ -177,7 +180,7 @@ class FilterObservationV0(
)
else:
raise ValueError(
f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}"
)
self.filter_keys: Final[Sequence[str | int]] = filter_keys
@@ -204,7 +207,11 @@ class FlattenObservationV0(
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
Args:
env: The environment to wrap
"""
gym.utils.RecordConstructorArgs.__init__(self)
LambdaObservationV0.__init__(
self,
@@ -237,7 +244,12 @@ class GrayscaleObservationV0(
"""
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale."""
"""Constructor for an RGB image based environments to make the image grayscale.
Args:
env: The environment to wrap
keep_dim: Whether to keep the channel dimension in the observation, if ``True``, ``len(obs.shape) == 3`` else ``len(obs.shape) == 2``
"""
assert isinstance(env.observation_space, spaces.Box)
assert (
len(env.observation_space.shape) == 3
@@ -301,7 +313,12 @@ class ResizeObservationV0(
"""
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape."""
"""Constructor that requires an image environment observation space with a shape.
Args:
env: The environment to wrap
shape: The resized observation shape
"""
assert isinstance(env.observation_space, spaces.Box)
assert len(env.observation_space.shape) in [2, 3]
assert np.all(env.observation_space.low == 0) and np.all(
@@ -323,7 +340,10 @@ class ResizeObservationV0(
self.shape: Final[tuple[int, ...]] = tuple(shape)
new_observation_space = spaces.Box(
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
low=0,
high=255,
shape=self.shape + env.observation_space.shape[2:],
dtype=np.uint8,
)
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
@@ -353,7 +373,12 @@ class ReshapeObservationV0(
"""
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
"""Constructor for env with Box observation space that has a shape product equal to the new shape product."""
"""Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.
Args:
env: The environment to wrap
shape: The reshaped observation space
"""
assert isinstance(env.observation_space, spaces.Box)
assert np.product(shape) == np.product(env.observation_space.shape)
@@ -401,7 +426,13 @@ class RescaleObservationV0(
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`."""
"""Constructor that requires the env observation spaces to be a :class:`Box`.
Args:
env: The environment to wrap
min_obs: The new minimum observation bound
max_obs: The new maximum observation bound
"""
assert isinstance(env.observation_space, spaces.Box)
assert not np.any(env.observation_space.low == np.inf) and not np.any(
env.observation_space.high == np.inf
@@ -452,10 +483,19 @@ class DtypeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper for transforming the dtype of an observation."""
"""Observation wrapper for transforming the dtype of an observation.
Note:
This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces
"""
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
"""Constructor for Dtype observation wrapper.
Args:
env: The environment to wrap
dtype: The new dtype of the observation
"""
assert isinstance(
env.observation_space,
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
@@ -505,7 +545,7 @@ class PixelObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Augment observations by pixel values.
"""Includes the rendered observations to the environment's observations.
Observations of this wrapper will be dictionaries of images.
You can also choose to add the observation of the base environment to this dictionary.
@@ -522,13 +562,13 @@ class PixelObservationV0(
pixels_key: str = "pixels",
obs_key: str = "state",
):
"""Initializes a new pixel Wrapper.
"""Constructor of the pixel observation wrapper.
Args:
env: The environment to wrap.
pixels_only (bool): If `True` (default), the original observation returned
pixels_only (bool): If ``True`` (default), the original observation returned
by the wrapped environment will be discarded, and a dictionary
observation will only include pixels. If `False`, the
observation will only include pixels. If ``False``, the
observation dictionary will contain both the original
observations and the pixel observations.
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
@@ -571,51 +611,3 @@ class PixelObservationV0(
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
observation_space=obs_space,
)
class NormalizeObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called.
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
newly instantiated or the policy was changed recently.
"""
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Args:
env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations.
"""
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.ObservationWrapper.__init__(self, env)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon
self._update_running_mean = True
@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the observation statistics."""
return self._update_running_mean
@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
self._update_running_mean = setting
def observation(self, observation: ObsType) -> WrapperObsType:
"""Normalises the observation using the running mean and variance of the observations."""
if self._update_running_mean:
self.obs_rms.update(observation)
return (observation - self.obs_rms.mean) / np.sqrt(
self.obs_rms.var + self.epsilon
)

View File

@@ -1,18 +1,17 @@
"""A collection of wrappers for modifying the reward.
* ``LambdaReward`` - Transforms the reward by a function
* ``ClipReward`` - Clips the reward between a minimum and maximum value
* ``LambdaRewardV0`` - Transforms the reward by a function
* ``ClipRewardV0`` - Clips the reward between a minimum and maximum value
"""
from __future__ import annotations
from typing import Any, Callable, SupportsFloat
from typing import Callable, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers.utils import RunningMeanStd
class LambdaRewardV0(
@@ -39,7 +38,7 @@ class LambdaRewardV0(
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The environment to apply the wrapper
env (Env): The environment to wrap
func: (Callable): The function to apply to reward
"""
gym.utils.RecordConstructorArgs.__init__(self, func=func)
@@ -79,7 +78,7 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
"""Initialize ClipRewardsV0 wrapper.
Args:
env (Env): The environment to apply the wrapper
env (Env): The environment to wrap
min_reward (Union[float, np.ndarray]): lower bound to apply
max_reward (Union[float, np.ndarray]): higher bound to apply
"""
@@ -98,72 +97,3 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
LambdaRewardV0.__init__(
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
)
class NormalizeRewardV1(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the reward
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
If False, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrect computed in Gym v0.25+.
For more detail, read [#3154](https://github.com/openai/gym/pull/3152).
Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently.
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
gamma: float = 0.99,
epsilon: float = 1e-8,
):
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
Args:
env (env): The environment to apply the wrapper
epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average.
"""
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.rewards_running_means = RunningMeanStd(shape=())
self.discounted_reward: np.array = np.array([0.0])
self.gamma = gamma
self.epsilon = epsilon
self._update_running_mean = True
@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the reward statistics."""
return self._update_running_mean
@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the reward statistics."""
self._update_running_mean = setting
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, normalizing the reward returned."""
obs, reward, terminated, truncated, info = super().step(action)
self.discounted_reward = self.discounted_reward * self.gamma * (
1 - terminated
) + float(reward)
return obs, self.normalize(float(reward)), terminated, truncated, info
def normalize(self, reward: SupportsFloat):
"""Normalizes the rewards with the running mean rewards and their variance."""
if self._update_running_mean:
self.rewards_running_means.update(self.discounted_reward)
return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)

View File

@@ -111,13 +111,13 @@ class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Performs the given action within the environment.
"""Steps the environment using a PyTorch-based action that is converted to NumPy before being passed to the environment.
Args:
action: The action to perform as a PyTorch Tensor
action: A PyTorch-based action
Returns:
The next observation, reward, termination, truncation, and extra info
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_numpy(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)

View File

@@ -3,6 +3,7 @@
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
* ``FrameStackObservationV0`` - Frame stack the observations
* ``NormalizeObservationV0`` - Normalizes the observations such that each coordinate is centered with unit variance
"""
from __future__ import annotations
@@ -21,7 +22,7 @@ from gymnasium.experimental.vector.utils import (
concatenate,
create_empty_array,
)
from gymnasium.experimental.wrappers.utils import create_zero_array
from gymnasium.experimental.wrappers.utils import RunningMeanStd, create_zero_array
from gymnasium.spaces import Box, Dict, Tuple
@@ -382,3 +383,51 @@ class FrameStackObservationV0(
)
)
return updated_obs, info
class NormalizeObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Normalizes observations s.t. each coordinate is centered with unit variance.
The property `update_running_mean` allows to freeze/continue the running mean calculation of the observation
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called.
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
newly instantiated or the policy was changed recently.
"""
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
"""Constructor for the normalize observation wrapper.
Args:
env (Env): The environment to wrap
epsilon: A stability parameter added to the running variance before taking the square root when scaling the observations.
"""
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.ObservationWrapper.__init__(self, env)
# Running statistics are tracked with the same shape as the observation space.
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon
# Statistics are updated on every observation until this flag is set to False.
self._update_running_mean = True
@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the observation statistics."""
return self._update_running_mean
@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
self._update_running_mean = setting
def observation(self, observation: ObsType) -> WrapperObsType:
"""Normalises the observation using the running mean and variance of the observations.
Args:
observation: The observation returned by the wrapped environment
Returns:
``(observation - mean) / sqrt(var + epsilon)`` using the running statistics
"""
# NOTE(review): the raw observation is passed to ``RunningMeanStd.update`` without adding a
# batch axis - confirm that ``update`` accepts a single unbatched sample.
if self._update_running_mean:
self.obs_rms.update(observation)
return (observation - self.obs_rms.mean) / np.sqrt(
self.obs_rms.var + self.epsilon
)

View File

@@ -0,0 +1,82 @@
"""A collection of wrappers for modifying the reward with an internal state.
* ``NormalizeRewardV1`` - Normalizes the rewards such that their exponential moving average has a fixed variance
"""
from __future__ import annotations
from typing import Any, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.wrappers.utils import RunningMeanStd
class NormalizeRewardV1(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
The property `update_running_mean` allows to freeze/continue the running mean calculation of the reward
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+.
For more detail, read [#3154](https://github.com/openai/gym/pull/3152).
Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently.
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
gamma: float = 0.99,
epsilon: float = 1e-8,
):
"""Constructor for the normalize reward wrapper.
Args:
env (Env): The environment to wrap
epsilon (float): A stability parameter added to the running variance before taking the square root
gamma (float): The discount factor that is used in the exponential moving average.
"""
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.rewards_running_means = RunningMeanStd(shape=())
# Discounted sum of rewards; this (not the raw reward) feeds the running statistics.
self.discounted_reward: np.ndarray = np.array([0.0])
self.gamma = gamma
self.epsilon = epsilon
# Statistics are updated on every step until this flag is set to False.
self._update_running_mean = True
@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the reward statistics."""
return self._update_running_mean
@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the reward statistics."""
self._update_running_mean = setting
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, normalizing the reward returned.
Args:
action: The action passed to the wrapped environment
Returns:
The observation, normalized reward, terminated, truncated and info of the wrapped environment's step
"""
obs, reward, terminated, truncated, info = super().step(action)
# The accumulator is decayed by gamma and zeroed when ``terminated`` is true
# (``truncated`` does not reset it), then the new reward is added.
self.discounted_reward = self.discounted_reward * self.gamma * (
1 - terminated
) + float(reward)
return obs, self.normalize(float(reward)), terminated, truncated, info
def normalize(self, reward: SupportsFloat):
"""Normalizes the rewards with the running variance of the discounted rewards.
Args:
reward: The immediate reward to scale
Returns:
``reward / sqrt(var + epsilon)`` where ``var`` is the running variance of the discounted reward
"""
if self._update_running_mean:
self.rewards_running_means.update(self.discounted_reward)
return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)

View File

@@ -1,12 +1,44 @@
"""Wrappers for vector environments."""
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
VectorRecordEpisodeStatistics,
)
from gymnasium.experimental.wrappers.vector.vector_list_info import VectorListInfo
# pyright: reportUnsupportedDunderAll=false
import importlib
__all__ = [
"VectorRecordEpisodeStatistics",
"VectorListInfo",
# --- Vector only wrappers
"VectoriseLambdaObservationV0",
"VectoriseLambdaActionV0",
"VectoriseLambdaRewardV0",
"DictInfoToListV0",
# --- Observation wrappers ---
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
"PixelObservationV0",
"NormalizeObservationV0",
# "TimeAwareObservationV0",
# "FrameStackObservationV0",
# "DelayObservationV0",
# --- Action Wrappers ---
"LambdaActionV0",
"ClipActionV0",
"RescaleActionV0",
# --- Reward wrappers ---
"LambdaRewardV0",
"ClipRewardV0",
"NormalizeRewardV1",
# --- Common ---
"RecordEpisodeStatisticsV0",
# --- Rendering ---
# "RenderCollectionV0",
# "RecordVideoV0",
# "HumanRenderingV0",
# --- Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]

View File

@@ -1,11 +1,13 @@
"""Wrapper that converts the info format for vec envs into the list format."""
from __future__ import annotations
from typing import List
from typing import Any
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
class VectorListInfo(VectorWrapper):
class DictInfoToListV0(VectorWrapper):
"""Converts infos of vectorized environments from dict to List[dict].
This wrapper converts the info format of a
@@ -15,17 +17,16 @@ class VectorListInfo(VectorWrapper):
operation on info like `RecordEpisodeStatistics` this
need to be the outermost wrapper.
i.e. `VectorListInfo(RecordEpisodeStatistics(envs))`
i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))``
Example::
>>> # actual
>>> { # doctest: +SKIP
>>> import numpy as np
>>> dict_info = {
... "k": np.array([0., 0., 0.5, 0.3]),
... "_k": np.array([False, False, True, True])
... }
>>> # classic
>>> [{}, {}, {k: 0.5}, {k: 0.3}] # doctest: +SKIP
>>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
"""
def __init__(self, env: VectorEnv):
@@ -36,20 +37,28 @@ class VectorListInfo(VectorWrapper):
"""
super().__init__(env)
def step(self, action):
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
"""Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(action)
observation, reward, terminated, truncated, infos = self.env.step(actions)
list_info = self._convert_info_to_list(infos)
return observation, reward, terminated, truncated, list_info
def reset(self, **kwargs):
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, list[dict[str, Any]]]:
"""Resets the environment using the seed and options, convert dict info to list."""
obs, infos = self.env.reset(**kwargs)
obs, infos = self.env.reset(seed=seed, options=options)
list_info = self._convert_info_to_list(infos)
return obs, list_info
def _convert_info_to_list(self, infos: dict) -> List[dict]:
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
"""Convert the dict info to list.
Convert the dict info of the vectorized environment
@@ -72,36 +81,3 @@ class VectorListInfo(VectorWrapper):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]:
"""Process episode statistics.
`RecordEpisodeStatistics` wrapper add extra
information to the info. This information are in
the form of a dict of dict. This method process these
information and add them to the info.
`RecordEpisodeStatistics` info contains the keys
"r", "l", "t" which represents "cumulative reward",
"episode length", "elapsed time since instantiation of wrapper".
Args:
infos (dict): infos coming from `RecordEpisodeStatistics`.
list_info (list): info of the current vectorized environment.
Returns:
list_info (list): updated info.
"""
episode_statistics = infos.pop("episode", False)
if not episode_statistics:
return list_info
episode_statistics_mask = infos.pop("_episode")
for i, has_info in enumerate(episode_statistics_mask):
if has_info:
list_info[i]["episode"] = {}
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
return list_info

View File

@@ -0,0 +1,76 @@
"""Vector wrapper for converting between NumPy and Jax."""
from __future__ import annotations
from typing import Any
import jax.numpy as jnp
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
class JaxToNumpyV0(VectorWrapper):
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
Notes:
A vectorised version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
"""
def __init__(self, env: VectorEnv):
"""Wraps a jax vector environment such that the input and outputs are numpy arrays.
Args:
env: the vector jax environment to wrap
Raises:
DependencyNotInstalled: if ``jnp`` is ``None``
"""
# NOTE(review): ``jax.numpy`` is imported unconditionally at the top of this module, so a missing
# jax would raise ImportError before this guard is reached - confirm whether the import should be optional.
if jnp is None:
raise DependencyNotInstalled(
"jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Converts the numpy actions to jax arrays, steps the environment and converts the results back to numpy.
Args:
actions: the action to perform as a numpy array
Returns:
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_actions = numpy_to_jax(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_actions)
return (
jax_to_numpy(obs),
jax_to_numpy(reward),
jax_to_numpy(terminated),
jax_to_numpy(truncated),
jax_to_numpy(info),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
Numpy-based observations and info
"""
# ``None`` or empty options are forwarded unchanged; only non-empty options are converted.
if options:
options = numpy_to_jax(options)
# The ``(obs, info)`` result of the wrapped reset is converted in a single call.
return jax_to_numpy(self.env.reset(seed=seed, options=options))

View File

@@ -0,0 +1,73 @@
"""Vector wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import (
Device,
jax_to_torch,
torch_to_jax,
)
class JaxToTorchV0(VectorWrapper):
    """Vector wrapper exposing a Jax-based vector environment through PyTorch Tensors.

    Actions must be provided as PyTorch Tensors; observations, rewards,
    terminations, truncations and info are each passed through
    ``jax_to_torch`` (with the configured device) before being returned.
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Vector wrapper to change inputs and outputs to PyTorch tensors.

        Args:
            env: The Jax-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env)

        # Forwarded to every ``jax_to_torch`` call made by this wrapper.
        self.device: Device | None = device

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Converts the actions with ``torch_to_jax``, steps the wrapped environment and converts the result back.

        Args:
            actions: The action to perform as a PyTorch Tensor

        Returns:
            Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
        """
        step_result = self.env.step(torch_to_jax(actions))
        # Unpacking enforces the five-element step contract before conversion.
        obs, reward, terminated, truncated, info = step_result
        return tuple(
            jax_to_torch(part, self.device)
            for part in (obs, reward, terminated, truncated, info)
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning PyTorch-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to jax arrays.

        Returns:
            PyTorch-based observations and info
        """
        # Falsy options (None or an empty dict) are forwarded unconverted.
        jax_options = torch_to_jax(options) if options else options
        reset_result = self.env.reset(seed=seed, options=jax_options)
        return jax_to_torch(reset_result, self.device)

View File

@@ -0,0 +1,70 @@
"""Wrapper for converting NumPy environments to PyTorch."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import Device
from gymnasium.experimental.wrappers.numpy_to_torch import (
numpy_to_torch,
torch_to_numpy,
)
class NumpyToTorchV0(VectorWrapper):
    """Vector wrapper exposing a NumPy-based vector environment through PyTorch Tensors.

    Actions are converted with ``torch_to_numpy`` before reaching the wrapped
    environment; observations, rewards, terminations, truncations and info
    are each passed through ``numpy_to_torch`` (with the configured device)
    before being returned.
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Wrapper class to change inputs and outputs of environment to PyTorch tensors.

        Args:
            env: The NumPy-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env)

        # Forwarded to every ``numpy_to_torch`` call made by this wrapper.
        self.device: Device | None = device

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Steps the environment with PyTorch-based actions that are first converted to NumPy.

        Args:
            actions: The PyTorch-based actions

        Returns:
            The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
        """
        step_result = self.env.step(torch_to_numpy(actions))
        # Unpacking enforces the five-element step contract before conversion.
        obs, reward, terminated, truncated, info = step_result
        return tuple(
            numpy_to_torch(part, self.device)
            for part in (obs, reward, terminated, truncated, info)
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning PyTorch-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted with ``torch_to_numpy``.

        Returns:
            PyTorch-based observations and info
        """
        # Falsy options (None or an empty dict) are forwarded unconverted.
        numpy_options = torch_to_numpy(options) if options else options
        reset_result = self.env.reset(seed=seed, options=numpy_options)
        return numpy_to_torch(reset_result, self.device)

View File

@@ -1,14 +1,16 @@
"""Wrapper that tracks the cumulative rewards and episode lengths."""
from __future__ import annotations
import time
from collections import deque
from typing import List, Optional, Union
import numpy as np
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
class VectorRecordEpisodeStatistics(VectorWrapper):
class RecordEpisodeStatisticsV0(VectorWrapper):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
@@ -32,9 +34,9 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
>>> infos = { # doctest: +SKIP
... ...
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... "r": "<array of cumulative reward for each done sub-environment>",
... "l": "<array of episode length for each done sub-environment>",
... "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
@@ -55,30 +57,35 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
"""
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_count = 0
self.episode_start_times: np.ndarray = None
self.episode_returns: Optional[np.ndarray] = None
self.episode_lengths: Optional[np.ndarray] = None
self.episode_start_times: np.ndarray = np.zeros(())
self.episode_returns: np.ndarray = np.zeros(())
self.episode_lengths: np.ndarray = np.zeros(())
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
self.is_vector_env = True
def reset(
self,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Resets the environment using kwargs and resets the episode returns and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_times = np.full(
self.num_envs, time.perf_counter(), dtype=np.float32
)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs, info
def step(self, action):
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment, recording the episode statistics."""
(
observations,
@@ -86,14 +93,18 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
terminations,
truncations,
infos,
) = self.env.step(action)
) = self.env.step(actions)
assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1
dones = np.logical_or(terminations, truncations)
num_dones = np.sum(dones)
if num_dones:
if "episode" in infos or "_episode" in infos:
raise ValueError(
@@ -109,14 +120,18 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
0.0,
),
}
if self.is_vector_env:
infos["_episode"] = np.where(dones, True, False)
self.return_queue.extend(self.episode_returns[dones])
self.length_queue.extend(self.episode_lengths[dones])
infos["_episode"] = dones
self.episode_count += num_dones
for i in np.where(dones):
self.return_queue.extend(self.episode_returns[i])
self.length_queue.extend(self.episode_lengths[i])
self.episode_lengths[dones] = 0
self.episode_returns[dones] = 0
self.episode_start_times[dones] = time.perf_counter()
return (
observations,
rewards,