mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 01:27:29 +00:00
Experimental wrapper changes (#517)
This commit is contained in:
@@ -54,27 +54,28 @@ _wrapper_to_class = {
|
||||
"LambdaActionV0": "lambda_action",
|
||||
"ClipActionV0": "lambda_action",
|
||||
"RescaleActionV0": "lambda_action",
|
||||
# lambda_observations.py
|
||||
"LambdaObservationV0": "lambda_observations",
|
||||
"FilterObservationV0": "lambda_observations",
|
||||
"FlattenObservationV0": "lambda_observations",
|
||||
"GrayscaleObservationV0": "lambda_observations",
|
||||
"ResizeObservationV0": "lambda_observations",
|
||||
"ReshapeObservationV0": "lambda_observations",
|
||||
"RescaleObservationV0": "lambda_observations",
|
||||
"DtypeObservationV0": "lambda_observations",
|
||||
"PixelObservationV0": "lambda_observations",
|
||||
"NormalizeObservationV0": "lambda_observations",
|
||||
# lambda_observation.py
|
||||
"LambdaObservationV0": "lambda_observation",
|
||||
"FilterObservationV0": "lambda_observation",
|
||||
"FlattenObservationV0": "lambda_observation",
|
||||
"GrayscaleObservationV0": "lambda_observation",
|
||||
"ResizeObservationV0": "lambda_observation",
|
||||
"ReshapeObservationV0": "lambda_observation",
|
||||
"RescaleObservationV0": "lambda_observation",
|
||||
"DtypeObservationV0": "lambda_observation",
|
||||
"PixelObservationV0": "lambda_observation",
|
||||
# lambda_reward.py
|
||||
"ClipRewardV0": "lambda_reward",
|
||||
"LambdaRewardV0": "lambda_reward",
|
||||
"NormalizeRewardV1": "lambda_reward",
|
||||
# stateful_action
|
||||
"StickyActionV0": "stateful_action",
|
||||
# stateful_observation
|
||||
"TimeAwareObservationV0": "stateful_observation",
|
||||
"DelayObservationV0": "stateful_observation",
|
||||
"FrameStackObservationV0": "stateful_observation",
|
||||
"NormalizeObservationV0": "stateful_observation",
|
||||
# stateful_reward
|
||||
"NormalizeRewardV1": "stateful_reward",
|
||||
# atari_preprocessing
|
||||
"AtariPreprocessingV0": "atari_preprocessing",
|
||||
# common
|
||||
@@ -86,18 +87,10 @@ _wrapper_to_class = {
|
||||
"RenderCollectionV0": "rendering",
|
||||
"RecordVideoV0": "rendering",
|
||||
"HumanRenderingV0": "rendering",
|
||||
# jax_to_numpy
|
||||
# data converters
|
||||
"JaxToNumpyV0": "jax_to_numpy",
|
||||
# "jax_to_numpy": "jax_to_numpy",
|
||||
# "numpy_to_jax": "jax_to_numpy",
|
||||
# jax_to_torch
|
||||
"JaxToTorchV0": "jax_to_torch",
|
||||
# "jax_to_torch": "jax_to_torch",
|
||||
# "torch_to_jax": "jax_to_torch",
|
||||
# numpy_to_torch
|
||||
"NumpyToTorchV0": "numpy_to_torch",
|
||||
# "torch_to_numpy": "numpy_to_torch",
|
||||
# "numpy_to_torch": "numpy_to_torch",
|
||||
}
|
||||
|
||||
|
||||
|
@@ -99,10 +99,10 @@ class JaxToNumpyV0(
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Wraps an environment such that the input and outputs are numpy arrays.
|
||||
"""Wraps a jax environment such that the input and outputs are numpy arrays.
|
||||
|
||||
Args:
|
||||
env: the environment to wrap
|
||||
env: the jax environment to wrap
|
||||
"""
|
||||
if jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
@@ -120,7 +120,7 @@ class JaxToNumpyV0(
|
||||
action: the action to perform as a numpy array
|
||||
|
||||
Returns:
|
||||
A tuple containing the next observation, reward, termination, truncation, and extra info.
|
||||
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
|
||||
"""
|
||||
jax_action = numpy_to_jax(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
|
@@ -39,7 +39,7 @@ except ImportError:
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["jax_to_torch", "torch_to_jax", "JaxToTorchV0"]
|
||||
__all__ = ["JaxToTorchV0", "jax_to_torch", "torch_to_jax", "Device"]
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
@@ -114,7 +114,7 @@ def _jax_iterable_to_torch(
|
||||
|
||||
|
||||
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||
"""Wraps a Jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||
|
||||
|
@@ -1,8 +1,8 @@
|
||||
"""A collection of wrappers that all use the LambdaAction class.
|
||||
|
||||
* ``LambdaAction`` - Transforms the actions based on a function
|
||||
* ``ClipAction`` - Clips the action within a bounds
|
||||
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
|
||||
* ``LambdaActionV0`` - Transforms the actions based on a function
|
||||
* ``ClipActionV0`` - Clips the action within the bounds of the action space
|
||||
* ``RescaleActionV0`` - Rescales the action within a minimum and maximum actions
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -34,8 +34,8 @@ class LambdaActionV0(
|
||||
"""Initialize LambdaAction.
|
||||
|
||||
Args:
|
||||
env: The gymnasium environment
|
||||
func: Function to apply to ``step`` ``action``
|
||||
env: The environment to wrap
|
||||
func: Function to apply to the :meth:`step`'s ``action``
|
||||
action_space: The updated action space of the wrapper given the function.
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
@@ -75,7 +75,7 @@ class ClipActionV0(
|
||||
"""A wrapper for clipping continuous actions within the valid bound.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
env: The environment to wrap
|
||||
"""
|
||||
assert isinstance(env.action_space, Box)
|
||||
|
||||
@@ -125,10 +125,10 @@ class RescaleActionV0(
|
||||
min_action: float | int | np.ndarray,
|
||||
max_action: float | int | np.ndarray,
|
||||
):
|
||||
"""Initializes the :class:`RescaleAction` wrapper.
|
||||
"""Constructor for the Rescale Action wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
env (Env): The environment to wrap
|
||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||
"""
|
||||
|
@@ -9,7 +9,6 @@
|
||||
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
|
||||
* ``DtypeObservationV0`` - Convert an observation to a dtype
|
||||
* ``PixelObservationV0`` - Allows the observation to include the rendered frame
|
||||
* ``NormalizeObservationV0`` - Normalized the observations to a mean and
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -27,7 +26,6 @@ import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||
|
||||
|
||||
class LambdaObservationV0(
|
||||
@@ -37,7 +35,7 @@ class LambdaObservationV0(
|
||||
"""Transforms an observation via a function provided to the wrapper.
|
||||
|
||||
The function :attr:`func` will be applied to all observations.
|
||||
If the observations from :attr:`func` are outside the bounds of the `env` spaces, provide a :attr:`observation_space`.
|
||||
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
@@ -60,8 +58,8 @@ class LambdaObservationV0(
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
|
||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
|
||||
func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`.
|
||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, func=func, observation_space=observation_space
|
||||
@@ -82,7 +80,7 @@ class FilterObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Filter Dict observation space by the keys.
|
||||
"""Filters Dict or Tuple observation space by the keys or indexes.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
@@ -103,7 +101,12 @@ class FilterObservationV0(
|
||||
def __init__(
|
||||
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
|
||||
):
|
||||
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
||||
"""Constructor for the filter observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
|
||||
"""
|
||||
assert isinstance(filter_keys, Sequence)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
||||
|
||||
@@ -177,7 +180,7 @@ class FilterObservationV0(
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
|
||||
f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}"
|
||||
)
|
||||
|
||||
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||
@@ -204,7 +207,11 @@ class FlattenObservationV0(
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
@@ -237,7 +244,12 @@ class GrayscaleObservationV0(
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
|
||||
"""Constructor for an RGB image based environments to make the image grayscale."""
|
||||
"""Constructor for an RGB image based environments to make the image grayscale.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
keep_dim: Whether to keep the channel dimension in the observation; if ``True``, ``len(obs.shape) == 3``, otherwise ``len(obs.shape) == 2``
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert (
|
||||
len(env.observation_space.shape) == 3
|
||||
@@ -301,7 +313,12 @@ class ResizeObservationV0(
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
|
||||
"""Constructor that requires an image environment observation space with a shape."""
|
||||
"""Constructor that requires an image environment observation space with a shape.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
shape: The resized observation shape
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert len(env.observation_space.shape) in [2, 3]
|
||||
assert np.all(env.observation_space.low == 0) and np.all(
|
||||
@@ -323,7 +340,10 @@ class ResizeObservationV0(
|
||||
self.shape: Final[tuple[int, ...]] = tuple(shape)
|
||||
|
||||
new_observation_space = spaces.Box(
|
||||
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.shape + env.observation_space.shape[2:],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||
@@ -353,7 +373,12 @@ class ReshapeObservationV0(
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
|
||||
"""Constructor for env with Box observation space that has a shape product equal to the new shape product."""
|
||||
"""Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
shape: The reshaped observation space
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert np.product(shape) == np.product(env.observation_space.shape)
|
||||
|
||||
@@ -401,7 +426,13 @@ class RescaleObservationV0(
|
||||
min_obs: np.floating | np.integer | np.ndarray,
|
||||
max_obs: np.floating | np.integer | np.ndarray,
|
||||
):
|
||||
"""Constructor that requires the env observation spaces to be a :class:`Box`."""
|
||||
"""Constructor that requires the env observation spaces to be a :class:`Box`.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
min_obs: The new minimum observation bound
|
||||
max_obs: The new maximum observation bound
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert not np.any(env.observation_space.low == np.inf) and not np.any(
|
||||
env.observation_space.high == np.inf
|
||||
@@ -452,10 +483,19 @@ class DtypeObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper for transforming the dtype of an observation."""
|
||||
"""Observation wrapper for transforming the dtype of an observation.
|
||||
|
||||
Note:
|
||||
This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
||||
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
|
||||
"""Constructor for Dtype observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
dtype: The new dtype of the observation
|
||||
"""
|
||||
assert isinstance(
|
||||
env.observation_space,
|
||||
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
||||
@@ -505,7 +545,7 @@ class PixelObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Augment observations by pixel values.
|
||||
"""Includes the rendered observations to the environment's observations.
|
||||
|
||||
Observations of this wrapper will be dictionaries of images.
|
||||
You can also choose to add the observation of the base environment to this dictionary.
|
||||
@@ -522,13 +562,13 @@ class PixelObservationV0(
|
||||
pixels_key: str = "pixels",
|
||||
obs_key: str = "state",
|
||||
):
|
||||
"""Initializes a new pixel Wrapper.
|
||||
"""Constructor of the pixel observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap.
|
||||
pixels_only (bool): If `True` (default), the original observation returned
|
||||
pixels_only (bool): If ``True`` (default), the original observation returned
|
||||
by the wrapped environment will be discarded, and a dictionary
|
||||
observation will only include pixels. If `False`, the
|
||||
observation will only include pixels. If ``False``, the
|
||||
observation dictionary will contain both the original
|
||||
observations and the pixel observations.
|
||||
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
|
||||
@@ -571,51 +611,3 @@ class PixelObservationV0(
|
||||
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
|
||||
observation_space=obs_space,
|
||||
)
|
||||
|
||||
|
||||
class NormalizeObservationV0(
    gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Normalizes observations using running estimates of their mean and variance.

    Every observation has the running mean subtracted and is then divided by the
    square root of the running variance plus ``epsilon``.

    The :attr:`update_running_mean` property freezes or resumes the update of the
    observation statistics. While it is ``True`` (default), the ``RunningMeanStd``
    instance is updated every time :meth:`observation` is called. When it is ``False``,
    the stored statistics are still applied but no longer updated; this may be used
    during evaluation.

    Note:
        The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
        newly instantiated or the policy was changed recently.
    """

    def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
        """Constructor for the normalize observation wrapper.

        Args:
            env (Env): The environment to wrap
            epsilon: A stability parameter added to the variance before taking the square root.
        """
        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
        gym.ObservationWrapper.__init__(self, env)

        self.epsilon = epsilon
        # Statistics are updated by default; toggle through ``update_running_mean``.
        self._update_running_mean = True
        # Running statistics are created with the shape of the observation space.
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)

    @property
    def update_running_mean(self) -> bool:
        """Whether the running observation statistics are updated on each observation."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Freezes (``False``) or resumes (``True``) the update of the running observation statistics."""
        self._update_running_mean = setting

    def observation(self, observation: ObsType) -> WrapperObsType:
        """Normalizes the observation with the running mean and variance of the observations."""
        if self._update_running_mean:
            self.obs_rms.update(observation)

        centred = observation - self.obs_rms.mean
        scale = np.sqrt(self.obs_rms.var + self.epsilon)
        return centred / scale
|
@@ -1,18 +1,17 @@
|
||||
"""A collection of wrappers for modifying the reward.
|
||||
|
||||
* ``LambdaReward`` - Transforms the reward by a function
|
||||
* ``ClipReward`` - Clips the reward between a minimum and maximum value
|
||||
* ``LambdaRewardV0`` - Transforms the reward by a function
|
||||
* ``ClipRewardV0`` - Clips the reward between a minimum and maximum value
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, SupportsFloat
|
||||
from typing import Callable, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import InvalidBound
|
||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||
|
||||
|
||||
class LambdaRewardV0(
|
||||
@@ -39,7 +38,7 @@ class LambdaRewardV0(
|
||||
"""Initialize LambdaRewardV0 wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
env (Env): The environment to wrap
|
||||
func: (Callable): The function to apply to reward
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self, func=func)
|
||||
@@ -79,7 +78,7 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
|
||||
"""Initialize ClipRewardsV0 wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
env (Env): The environment to wrap
|
||||
min_reward (Union[float, np.ndarray]): lower bound to apply
|
||||
max_reward (Union[float, np.ndarray]): higher bound to apply
|
||||
"""
|
||||
@@ -98,72 +97,3 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
|
||||
LambdaRewardV0.__init__(
|
||||
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
|
||||
)
|
||||
|
||||
|
||||
class NormalizeRewardV1(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.

    The exponential moving average will have variance :math:`(1 - \gamma)^2`.

    The :attr:`update_running_mean` property freezes or resumes the update of the reward
    statistics. While it is ``True`` (default), the ``RunningMeanStd`` instance is updated every
    time :meth:`normalize` is called. When it is ``False``, the stored statistics are still
    applied but no longer updated; this may be used during evaluation.

    Note:
        In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+.
        For more detail, read [#3154](https://github.com/openai/gym/pull/3152).

    Note:
        The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
        instantiated or the policy was changed recently.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """Constructor for the normalize reward wrapper.

        Args:
            env (Env): The environment to wrap
            gamma (float): The discount factor that is used in the exponential moving average.
            epsilon (float): A stability parameter added to the variance before taking the square root.
        """
        gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
        gym.Wrapper.__init__(self, env)

        self.rewards_running_means = RunningMeanStd(shape=())
        # Discounted sum of the rewards seen so far; cleared when an episode terminates.
        self.discounted_reward: np.ndarray = np.array([0.0])
        self.gamma = gamma
        self.epsilon = epsilon
        self._update_running_mean = True

    @property
    def update_running_mean(self) -> bool:
        """Whether the running reward statistics are updated on each step."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Freezes (``False``) or resumes (``True``) the update of the running reward statistics."""
        self._update_running_mean = setting

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, normalizing the reward returned."""
        obs, reward, terminated, truncated, info = super().step(action)
        reward_value = float(reward)
        # ``1 - terminated`` zeroes the accumulated return once the episode has terminated.
        self.discounted_reward = (
            self.discounted_reward * self.gamma * (1 - terminated) + reward_value
        )
        return obs, self.normalize(reward_value), terminated, truncated, info

    def normalize(self, reward: SupportsFloat):
        """Scales the reward by the square root of the running variance of the discounted reward plus ``epsilon``."""
        if self._update_running_mean:
            self.rewards_running_means.update(self.discounted_reward)
        return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)
|
||||
|
@@ -111,13 +111,13 @@ class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Performs the given action within the environment.
|
||||
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
|
||||
|
||||
Args:
|
||||
action: The action to perform as a PyTorch Tensor
|
||||
action: A PyTorch-based action
|
||||
|
||||
Returns:
|
||||
The next observation, reward, termination, truncation, and extra info
|
||||
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
||||
"""
|
||||
jax_action = torch_to_numpy(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
|
@@ -3,6 +3,7 @@
|
||||
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
|
||||
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
|
||||
* ``FrameStackObservationV0`` - Frame stack the observations
|
||||
* ``NormalizeObservationV0`` - Normalizes the observations so that each coordinate is centered with unit variance
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -21,7 +22,7 @@ from gymnasium.experimental.vector.utils import (
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.utils import create_zero_array
|
||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd, create_zero_array
|
||||
from gymnasium.spaces import Box, Dict, Tuple
|
||||
|
||||
|
||||
@@ -382,3 +383,51 @@ class FrameStackObservationV0(
|
||||
)
|
||||
)
|
||||
return updated_obs, info
|
||||
|
||||
|
||||
class NormalizeObservationV0(
    gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Normalizes observations using running estimates of their mean and variance.

    Every observation has the running mean subtracted and is then divided by the
    square root of the running variance plus ``epsilon``.

    The :attr:`update_running_mean` property freezes or resumes the update of the
    observation statistics. While it is ``True`` (default), the ``RunningMeanStd``
    instance is updated every time :meth:`observation` is called. When it is ``False``,
    the stored statistics are still applied but no longer updated; this may be used
    during evaluation.

    Note:
        The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
        newly instantiated or the policy was changed recently.
    """

    def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
        """Constructor for the normalize observation wrapper.

        Args:
            env (Env): The environment to wrap
            epsilon: A stability parameter added to the variance before taking the square root.
        """
        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
        gym.ObservationWrapper.__init__(self, env)

        self.epsilon = epsilon
        # Statistics are updated by default; toggle through ``update_running_mean``.
        self._update_running_mean = True
        # Running statistics are created with the shape of the observation space.
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)

    @property
    def update_running_mean(self) -> bool:
        """Whether the running observation statistics are updated on each observation."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Freezes (``False``) or resumes (``True``) the update of the running observation statistics."""
        self._update_running_mean = setting

    def observation(self, observation: ObsType) -> WrapperObsType:
        """Normalizes the observation with the running mean and variance of the observations."""
        if self._update_running_mean:
            self.obs_rms.update(observation)

        centred = observation - self.obs_rms.mean
        scale = np.sqrt(self.obs_rms.var + self.epsilon)
        return centred / scale
|
||||
|
82
gymnasium/experimental/wrappers/stateful_reward.py
Normal file
82
gymnasium/experimental/wrappers/stateful_reward.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""A collection of wrappers for modifying the reward with an internal state.
|
||||
|
||||
* ``NormalizeRewardV1`` - Normalizes the rewards to a mean and standard deviation
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||
|
||||
|
||||
class NormalizeRewardV1(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.

    The exponential moving average will have variance :math:`(1 - \gamma)^2`.

    The :attr:`update_running_mean` property freezes or resumes the update of the reward
    statistics. While it is ``True`` (default), the ``RunningMeanStd`` instance is updated every
    time :meth:`normalize` is called. When it is ``False``, the stored statistics are still
    applied but no longer updated; this may be used during evaluation.

    Note:
        In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+.
        For more detail, read [#3154](https://github.com/openai/gym/pull/3152).

    Note:
        The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
        instantiated or the policy was changed recently.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """Constructor for the normalize reward wrapper.

        Args:
            env (Env): The environment to wrap
            gamma (float): The discount factor that is used in the exponential moving average.
            epsilon (float): A stability parameter added to the variance before taking the square root.
        """
        gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
        gym.Wrapper.__init__(self, env)

        self.rewards_running_means = RunningMeanStd(shape=())
        # Discounted sum of the rewards seen so far; cleared when an episode terminates.
        self.discounted_reward: np.ndarray = np.array([0.0])
        self.gamma = gamma
        self.epsilon = epsilon
        self._update_running_mean = True

    @property
    def update_running_mean(self) -> bool:
        """Whether the running reward statistics are updated on each step."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Freezes (``False``) or resumes (``True``) the update of the running reward statistics."""
        self._update_running_mean = setting

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, normalizing the reward returned."""
        obs, reward, terminated, truncated, info = super().step(action)
        reward_value = float(reward)
        # ``1 - terminated`` zeroes the accumulated return once the episode has terminated.
        self.discounted_reward = (
            self.discounted_reward * self.gamma * (1 - terminated) + reward_value
        )
        return obs, self.normalize(reward_value), terminated, truncated, info

    def normalize(self, reward: SupportsFloat):
        """Scales the reward by the square root of the running variance of the discounted reward plus ``epsilon``."""
        if self._update_running_mean:
            self.rewards_running_means.update(self.discounted_reward)
        return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)
|
@@ -1,12 +1,44 @@
|
||||
"""Wrappers for vector environments."""
|
||||
|
||||
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
|
||||
VectorRecordEpisodeStatistics,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.vector.vector_list_info import VectorListInfo
|
||||
# pyright: reportUnsupportedDunderAll=false
|
||||
import importlib
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VectorRecordEpisodeStatistics",
|
||||
"VectorListInfo",
|
||||
# --- Vector only wrappers
|
||||
"VectoriseLambdaObservationV0",
|
||||
"VectoriseLambdaActionV0",
|
||||
"VectoriseLambdaRewardV0",
|
||||
"DictInfoToListV0",
|
||||
# --- Observation wrappers ---
|
||||
"LambdaObservationV0",
|
||||
"FilterObservationV0",
|
||||
"FlattenObservationV0",
|
||||
"GrayscaleObservationV0",
|
||||
"ResizeObservationV0",
|
||||
"ReshapeObservationV0",
|
||||
"RescaleObservationV0",
|
||||
"DtypeObservationV0",
|
||||
"PixelObservationV0",
|
||||
"NormalizeObservationV0",
|
||||
# "TimeAwareObservationV0",
|
||||
# "FrameStackObservationV0",
|
||||
# "DelayObservationV0",
|
||||
# --- Action Wrappers ---
|
||||
"LambdaActionV0",
|
||||
"ClipActionV0",
|
||||
"RescaleActionV0",
|
||||
# --- Reward wrappers ---
|
||||
"LambdaRewardV0",
|
||||
"ClipRewardV0",
|
||||
"NormalizeRewardV1",
|
||||
# --- Common ---
|
||||
"RecordEpisodeStatisticsV0",
|
||||
# --- Rendering ---
|
||||
# "RenderCollectionV0",
|
||||
# "RecordVideoV0",
|
||||
# "HumanRenderingV0",
|
||||
# --- Conversion ---
|
||||
"JaxToNumpyV0",
|
||||
"JaxToTorchV0",
|
||||
"NumpyToTorchV0",
|
||||
]
|
||||
|
@@ -1,11 +1,13 @@
|
||||
"""Wrapper that converts the info format for vec envs into the list format."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
|
||||
|
||||
|
||||
class VectorListInfo(VectorWrapper):
|
||||
class DictInfoToListV0(VectorWrapper):
|
||||
"""Converts infos of vectorized environments from dict to List[dict].
|
||||
|
||||
This wrapper converts the info format of a
|
||||
@@ -15,17 +17,16 @@ class VectorListInfo(VectorWrapper):
|
||||
operation on info like `RecordEpisodeStatistics` this
|
||||
need to be the outermost wrapper.
|
||||
|
||||
i.e. `VectorListInfo(RecordEpisodeStatistics(envs))`
|
||||
i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))``
|
||||
|
||||
Example::
|
||||
|
||||
>>> # actual
|
||||
>>> { # doctest: +SKIP
|
||||
>>> import numpy as np
|
||||
>>> dict_info = {
|
||||
... "k": np.array([0., 0., 0.5, 0.3]),
|
||||
... "_k": np.array([False, False, True, True])
|
||||
... }
|
||||
>>> # classic
|
||||
>>> [{}, {}, {k: 0.5}, {k: 0.3}] # doctest: +SKIP
|
||||
>>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
|
||||
"""
|
||||
|
||||
def __init__(self, env: VectorEnv):
|
||||
@@ -36,20 +37,28 @@ class VectorListInfo(VectorWrapper):
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
def step(self, action):
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
|
||||
"""Steps through the environment, convert dict info to list."""
|
||||
observation, reward, terminated, truncated, infos = self.env.step(action)
|
||||
observation, reward, terminated, truncated, infos = self.env.step(actions)
|
||||
list_info = self._convert_info_to_list(infos)
|
||||
|
||||
return observation, reward, terminated, truncated, list_info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, list[dict[str, Any]]]:
|
||||
"""Resets the environment using kwargs."""
|
||||
obs, infos = self.env.reset(**kwargs)
|
||||
obs, infos = self.env.reset(seed=seed, options=options)
|
||||
list_info = self._convert_info_to_list(infos)
|
||||
|
||||
return obs, list_info
|
||||
|
||||
def _convert_info_to_list(self, infos: dict) -> List[dict]:
|
||||
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
|
||||
"""Convert the dict info to list.
|
||||
|
||||
Convert the dict info of the vectorized environment
|
||||
@@ -72,36 +81,3 @@ class VectorListInfo(VectorWrapper):
|
||||
if has_info:
|
||||
list_info[i][k] = infos[k][i]
|
||||
return list_info
|
||||
|
||||
def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]:
|
||||
"""Process episode statistics.
|
||||
|
||||
`RecordEpisodeStatistics` wrapper add extra
|
||||
information to the info. This information are in
|
||||
the form of a dict of dict. This method process these
|
||||
information and add them to the info.
|
||||
`RecordEpisodeStatistics` info contains the keys
|
||||
"r", "l", "t" which represents "cumulative reward",
|
||||
"episode length", "elapsed time since instantiation of wrapper".
|
||||
|
||||
Args:
|
||||
infos (dict): infos coming from `RecordEpisodeStatistics`.
|
||||
list_info (list): info of the current vectorized environment.
|
||||
|
||||
Returns:
|
||||
list_info (list): updated info.
|
||||
|
||||
"""
|
||||
episode_statistics = infos.pop("episode", False)
|
||||
if not episode_statistics:
|
||||
return list_info
|
||||
|
||||
episode_statistics_mask = infos.pop("_episode")
|
||||
for i, has_info in enumerate(episode_statistics_mask):
|
||||
if has_info:
|
||||
list_info[i]["episode"] = {}
|
||||
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
|
||||
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
|
||||
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
|
||||
|
||||
return list_info
|
76
gymnasium/experimental/wrappers/vector/jax_to_numpy.py
Normal file
76
gymnasium/experimental/wrappers/vector/jax_to_numpy.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Vector wrapper for converting between NumPy and Jax."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
||||
|
||||
|
||||
class JaxToNumpyV0(VectorWrapper):
    """Vector wrapper exposing a jax-based vector environment through a NumPy interface.

    Notes:
        A vectorised version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``

    Actions are passed through ``numpy_to_jax`` before reaching the wrapped environment;
    observations, rewards, terminations, truncations and info are passed through
    ``jax_to_numpy`` before being returned.
    """

    def __init__(self, env: VectorEnv):
        """Wraps a jax vector environment so that its inputs and outputs are numpy-based.

        Args:
            env: the vector jax environment to wrap

        Raises:
            DependencyNotInstalled: if ``jnp`` is ``None``
        """
        # NOTE(review): this module imports ``jax.numpy`` unconditionally, so a
        # missing jax install would presumably fail at import time before this
        # guard can run - confirm whether the guard is still reachable.
        jax_missing = jnp is None
        if jax_missing:
            raise DependencyNotInstalled(
                "jax is not installed, run `pip install gymnasium[jax]`"
            )
        super().__init__(env)

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Converts the actions with ``numpy_to_jax`` and steps the wrapped environment.

        Args:
            actions: the action to perform as a numpy array

        Returns:
            A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
        """
        obs, reward, terminated, truncated, info = self.env.step(
            numpy_to_jax(actions)
        )

        # Every element of the step result goes through the same conversion.
        return tuple(
            jax_to_numpy(item) for item in (obs, reward, terminated, truncated, info)
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning numpy-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to jax arrays.

        Returns:
            Numpy-based observations and info
        """
        # Falsy options (``None`` or an empty dict) are forwarded unconverted.
        jax_options = numpy_to_jax(options) if options else options
        reset_result = self.env.reset(seed=seed, options=jax_options)
        return jax_to_numpy(reset_result)
|
73
gymnasium/experimental/wrappers/vector/jax_to_torch.py
Normal file
73
gymnasium/experimental/wrappers/vector/jax_to_torch.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Vector wrapper class for converting between PyTorch and Jax."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import (
|
||||
Device,
|
||||
jax_to_torch,
|
||||
torch_to_jax,
|
||||
)
|
||||
|
||||
|
||||
class JaxToTorchV0(VectorWrapper):
    """Vector wrapper exposing a Jax-based vector environment through PyTorch Tensors.

    Actions are passed through ``torch_to_jax`` before reaching the wrapped environment;
    observations, rewards, terminations, truncations and info are passed through
    ``jax_to_torch`` (with the configured device) before being returned.
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Vector wrapper to change inputs and outputs to PyTorch tensors.

        Args:
            env: The Jax-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env)

        # Forwarded to every ``jax_to_torch`` call made by ``step`` and ``reset``.
        self.device: Device | None = device

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Performs the given action within the environment.

        Args:
            actions: The action to perform as a PyTorch Tensor

        Returns:
            Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
        """
        obs, reward, terminated, truncated, info = self.env.step(
            torch_to_jax(actions)
        )

        # Every element of the step result goes through the same conversion.
        return tuple(
            jax_to_torch(item, self.device)
            for item in (obs, reward, terminated, truncated, info)
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning PyTorch-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to jax arrays.

        Returns:
            PyTorch-based observations and info
        """
        # Falsy options (``None`` or an empty dict) are forwarded unconverted.
        jax_options = torch_to_jax(options) if options else options
        reset_result = self.env.reset(seed=seed, options=jax_options)
        return jax_to_torch(reset_result, self.device)
|
70
gymnasium/experimental/wrappers/vector/numpy_to_torch.py
Normal file
70
gymnasium/experimental/wrappers/vector/numpy_to_torch.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Wrapper for converting NumPy environments to PyTorch."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental import VectorEnv, VectorWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import Device
|
||||
from gymnasium.experimental.wrappers.numpy_to_torch import (
|
||||
numpy_to_torch,
|
||||
torch_to_numpy,
|
||||
)
|
||||
|
||||
|
||||
class NumpyToTorchV0(VectorWrapper):
    """Wraps a numpy-based vector environment so that it can be interacted with through PyTorch Tensors.

    Actions are passed through ``torch_to_numpy`` before reaching the wrapped environment;
    observations, rewards, terminations, truncations and info are passed through
    ``numpy_to_torch`` (with the configured device) before being returned.
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Wrapper class to change inputs and outputs of environment to PyTorch tensors.

        Args:
            env: The NumPy-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env)

        # Forwarded to every ``numpy_to_torch`` call made by ``step`` and ``reset``.
        self.device: Device | None = device

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Using a PyTorch based action that is converted to NumPy to be used by the environment.

        Args:
            actions: A PyTorch-based action

        Returns:
            The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
        """
        # The wrapped environment is numpy-based, so the local is named for
        # what it actually holds (the original called it ``jax_action``).
        numpy_actions = torch_to_numpy(actions)
        obs, reward, terminated, truncated, info = self.env.step(numpy_actions)

        return (
            numpy_to_torch(obs, self.device),
            numpy_to_torch(reward, self.device),
            numpy_to_torch(terminated, self.device),
            numpy_to_torch(truncated, self.device),
            numpy_to_torch(info, self.device),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning PyTorch-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to NumPy arrays.

        Returns:
            PyTorch-based observations and info
        """
        # Falsy options (``None`` or an empty dict) are forwarded unconverted.
        if options:
            options = torch_to_numpy(options)

        return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
@@ -1,14 +1,16 @@
|
||||
"""Wrapper that tracks the cumulative rewards and episode lengths."""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
|
||||
|
||||
|
||||
class VectorRecordEpisodeStatistics(VectorWrapper):
|
||||
class RecordEpisodeStatisticsV0(VectorWrapper):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||
@@ -32,9 +34,9 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
||||
>>> infos = { # doctest: +SKIP
|
||||
... ...
|
||||
... "episode": {
|
||||
... "r": "<array of cumulative reward>",
|
||||
... "l": "<array of episode length>",
|
||||
... "t": "<array of elapsed time since beginning of episode>"
|
||||
... "r": "<array of cumulative reward for each done sub-environment>",
|
||||
... "l": "<array of episode length for each done sub-environment>",
|
||||
... "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
|
||||
... },
|
||||
... "_episode": "<boolean array of length num-envs>"
|
||||
... }
|
||||
@@ -55,30 +57,35 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
||||
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||
"""
|
||||
super().__init__(env)
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
|
||||
self.episode_count = 0
|
||||
self.episode_start_times: np.ndarray = None
|
||||
self.episode_returns: Optional[np.ndarray] = None
|
||||
self.episode_lengths: Optional[np.ndarray] = None
|
||||
|
||||
self.episode_start_times: np.ndarray = np.zeros(())
|
||||
self.episode_returns: np.ndarray = np.zeros(())
|
||||
self.episode_lengths: np.ndarray = np.zeros(())
|
||||
|
||||
self.return_queue = deque(maxlen=deque_size)
|
||||
self.length_queue = deque(maxlen=deque_size)
|
||||
self.is_vector_env = True
|
||||
|
||||
def reset(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
options: Optional[dict] = None,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict | None = None,
|
||||
):
|
||||
"""Resets the environment using kwargs and resets the episode returns and lengths."""
|
||||
obs, info = super().reset(seed=seed, options=options)
|
||||
|
||||
self.episode_start_times = np.full(
|
||||
self.num_envs, time.perf_counter(), dtype=np.float32
|
||||
)
|
||||
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
|
||||
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
|
||||
|
||||
return obs, info
|
||||
|
||||
def step(self, action):
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Steps through the environment, recording the episode statistics."""
|
||||
(
|
||||
observations,
|
||||
@@ -86,14 +93,18 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
||||
terminations,
|
||||
truncations,
|
||||
infos,
|
||||
) = self.env.step(action)
|
||||
) = self.env.step(actions)
|
||||
|
||||
assert isinstance(
|
||||
infos, dict
|
||||
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
|
||||
|
||||
self.episode_returns += rewards
|
||||
self.episode_lengths += 1
|
||||
|
||||
dones = np.logical_or(terminations, truncations)
|
||||
num_dones = np.sum(dones)
|
||||
|
||||
if num_dones:
|
||||
if "episode" in infos or "_episode" in infos:
|
||||
raise ValueError(
|
||||
@@ -109,14 +120,18 @@ class VectorRecordEpisodeStatistics(VectorWrapper):
|
||||
0.0,
|
||||
),
|
||||
}
|
||||
if self.is_vector_env:
|
||||
infos["_episode"] = np.where(dones, True, False)
|
||||
self.return_queue.extend(self.episode_returns[dones])
|
||||
self.length_queue.extend(self.episode_lengths[dones])
|
||||
infos["_episode"] = dones
|
||||
|
||||
self.episode_count += num_dones
|
||||
|
||||
for i in np.where(dones):
|
||||
self.return_queue.extend(self.episode_returns[i])
|
||||
self.length_queue.extend(self.episode_lengths[i])
|
||||
|
||||
self.episode_lengths[dones] = 0
|
||||
self.episode_returns[dones] = 0
|
||||
self.episode_start_times[dones] = time.perf_counter()
|
||||
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
|
Reference in New Issue
Block a user