mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Fix the wrapper type hints (#337)
This commit is contained in:
@@ -201,7 +201,7 @@ class Env(Generic[ObsType, ActType]):
|
||||
Instances of `np.random.Generator`
|
||||
"""
|
||||
if self._np_random is None:
|
||||
self._np_random, seed = seeding.np_random()
|
||||
self._np_random, _ = seeding.np_random()
|
||||
return self._np_random
|
||||
|
||||
@np_random.setter
|
||||
@@ -234,7 +234,10 @@ WrapperObsType = TypeVar("WrapperObsType")
|
||||
WrapperActType = TypeVar("WrapperActType")
|
||||
|
||||
|
||||
class Wrapper(Env[WrapperObsType, WrapperActType]):
|
||||
class Wrapper(
|
||||
Env[WrapperObsType, WrapperActType],
|
||||
Generic[WrapperObsType, WrapperActType, ObsType, ActType],
|
||||
):
|
||||
"""Wraps a :class:`gymnasium.Env` to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
||||
|
||||
This class is the base class of all wrappers to change the behavior of the underlying environment.
|
||||
@@ -391,7 +394,7 @@ class Wrapper(Env[WrapperObsType, WrapperActType]):
|
||||
return self.env.unwrapped
|
||||
|
||||
|
||||
class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
|
||||
class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
|
||||
"""Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`.
|
||||
|
||||
If you would like to apply a function to only the observation before
|
||||
@@ -434,7 +437,7 @@ class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RewardWrapper(Wrapper[ObsType, ActType]):
|
||||
class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""Superclass of wrappers that can modify the returning reward from a step.
|
||||
|
||||
If you would like to apply a function to the reward that is returned by the base environment before
|
||||
@@ -467,7 +470,7 @@ class RewardWrapper(Wrapper[ObsType, ActType]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ActionWrapper(Wrapper[ObsType, WrapperActType]):
|
||||
class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
|
||||
"""Superclass of wrappers that can modify the action before :meth:`env.step`.
|
||||
|
||||
If you would like to apply a function to the action before passing it to the base environment,
|
||||
|
@@ -15,7 +15,7 @@ import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import Env
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.error import ResetNeeded
|
||||
from gymnasium.utils.passive_env_checker import (
|
||||
check_action_space,
|
||||
@@ -26,10 +26,10 @@ from gymnasium.utils.passive_env_checker import (
|
||||
)
|
||||
|
||||
|
||||
class AutoresetV0(gym.Wrapper):
|
||||
class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||
|
||||
Args:
|
||||
@@ -40,8 +40,8 @@ class AutoresetV0(gym.Wrapper):
|
||||
self._reset_options: dict[str, Any] | None = None
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
|
||||
|
||||
Args:
|
||||
@@ -51,7 +51,7 @@ class AutoresetV0(gym.Wrapper):
|
||||
The autoreset environment :meth:`step`
|
||||
"""
|
||||
if self._episode_ended:
|
||||
obs, info = super().reset(options=self._reset_options)
|
||||
obs, info = self.env.reset(options=self._reset_options)
|
||||
self._episode_ended = True
|
||||
return obs, 0, False, False, info
|
||||
else:
|
||||
@@ -61,14 +61,14 @@ class AutoresetV0(gym.Wrapper):
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment, saving the options used."""
|
||||
self._episode_ended = False
|
||||
self._reset_options = options
|
||||
return super().reset(seed=seed, options=self._reset_options)
|
||||
|
||||
|
||||
class PassiveEnvCheckerV0(gym.Wrapper):
|
||||
class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
@@ -89,8 +89,8 @@ class PassiveEnvCheckerV0(gym.Wrapper):
|
||||
self._checked_render: bool = False
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
|
||||
if self._checked_step is False:
|
||||
self._checked_step = True
|
||||
@@ -100,7 +100,7 @@ class PassiveEnvCheckerV0(gym.Wrapper):
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
|
||||
if self._checked_reset is False:
|
||||
self._checked_reset = True
|
||||
@@ -117,7 +117,7 @@ class PassiveEnvCheckerV0(gym.Wrapper):
|
||||
return self.env.render()
|
||||
|
||||
|
||||
class OrderEnforcingV0(gym.Wrapper):
|
||||
class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||
|
||||
Example:
|
||||
@@ -139,7 +139,11 @@ class OrderEnforcingV0(gym.Wrapper):
|
||||
>>> env.close()
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
disable_render_order_enforcing: bool = False,
|
||||
):
|
||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||
|
||||
Args:
|
||||
@@ -150,17 +154,15 @@ class OrderEnforcingV0(gym.Wrapper):
|
||||
self._has_reset: bool = False
|
||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Steps through the environment with `kwargs`."""
|
||||
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Steps through the environment."""
|
||||
if not self._has_reset:
|
||||
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
|
||||
return super().step(action)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment with `kwargs`."""
|
||||
self._has_reset = True
|
||||
return super().reset(seed=seed, options=options)
|
||||
@@ -180,7 +182,7 @@ class OrderEnforcingV0(gym.Wrapper):
|
||||
return self._has_reset
|
||||
|
||||
|
||||
class RecordEpisodeStatisticsV0(gym.Wrapper):
|
||||
class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||
@@ -244,13 +246,13 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
|
||||
self.episode_reward: float = -1
|
||||
self.episode_length: int = -1
|
||||
|
||||
self.episode_time_length_buffer = deque(maxlen=buffer_length)
|
||||
self.episode_reward_buffer = deque(maxlen=buffer_length)
|
||||
self.episode_length_buffer = deque(maxlen=buffer_length)
|
||||
self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length)
|
||||
self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
|
||||
self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment, recording the episode statistics."""
|
||||
obs, reward, terminated, truncated, info = super().step(action)
|
||||
|
||||
@@ -279,7 +281,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
|
||||
obs, info = super().reset(seed=seed, options=options)
|
||||
|
||||
|
@@ -9,7 +9,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env, Wrapper
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ if jnp is not None:
|
||||
return type(value)(jax_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
class JaxToNumpyV0(Wrapper):
|
||||
class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
|
||||
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
|
||||
|
||||
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
||||
@@ -102,7 +102,7 @@ class JaxToNumpyV0(Wrapper):
|
||||
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)``
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env):
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
"""Wraps an environment such that the input and outputs are numpy arrays.
|
||||
|
||||
Args:
|
||||
|
@@ -16,18 +16,18 @@ except ImportError as e:
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, WrapperActType
|
||||
from gymnasium.core import ActType, ObsType, WrapperActType
|
||||
from gymnasium.spaces import Box, Space
|
||||
|
||||
|
||||
class LambdaActionV0(gym.ActionWrapper):
|
||||
class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
|
||||
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
func: Callable[[WrapperActType], ActType],
|
||||
action_space: Space | None,
|
||||
action_space: Space[WrapperActType] | None,
|
||||
):
|
||||
"""Initialize LambdaAction.
|
||||
|
||||
@@ -47,7 +47,7 @@ class LambdaActionV0(gym.ActionWrapper):
|
||||
return self.func(action)
|
||||
|
||||
|
||||
class ClipActionV0(LambdaActionV0):
|
||||
class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
||||
|
||||
Example:
|
||||
@@ -63,7 +63,7 @@ class ClipActionV0(LambdaActionV0):
|
||||
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""A wrapper for clipping continuous actions within the valid bound.
|
||||
|
||||
Args:
|
||||
@@ -83,7 +83,7 @@ class ClipActionV0(LambdaActionV0):
|
||||
)
|
||||
|
||||
|
||||
class RescaleActionV0(LambdaActionV0):
|
||||
class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||
|
||||
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
||||
@@ -107,7 +107,7 @@ class RescaleActionV0(LambdaActionV0):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
min_action: float | int | np.ndarray,
|
||||
max_action: float | int | np.ndarray,
|
||||
):
|
||||
|
@@ -31,7 +31,7 @@ from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||
from gymnasium.spaces import Box, Dict, utils
|
||||
|
||||
|
||||
class LambdaObservationV0(gym.ObservationWrapper):
|
||||
class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
||||
"""Transforms an observation via a function provided to the wrapper.
|
||||
|
||||
The function :attr:`func` will be applied to all observations.
|
||||
@@ -50,9 +50,9 @@ class LambdaObservationV0(gym.ObservationWrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
func: Callable[[ObsType], Any],
|
||||
observation_space: gym.Space | None,
|
||||
observation_space: gym.Space[WrapperObsType] | None,
|
||||
):
|
||||
"""Constructor for the lambda observation wrapper.
|
||||
|
||||
@@ -72,7 +72,7 @@ class LambdaObservationV0(gym.ObservationWrapper):
|
||||
return self.func(observation)
|
||||
|
||||
|
||||
class FilterObservationV0(LambdaObservationV0):
|
||||
class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Filter Dict observation space by the keys.
|
||||
|
||||
Example:
|
||||
@@ -91,7 +91,9 @@ class FilterObservationV0(LambdaObservationV0):
|
||||
({'time': 0}, 1.0, False, False, {})
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
|
||||
def __init__(
|
||||
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
|
||||
):
|
||||
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
||||
assert isinstance(filter_keys, Sequence)
|
||||
|
||||
@@ -169,7 +171,7 @@ class FilterObservationV0(LambdaObservationV0):
|
||||
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||
|
||||
|
||||
class FlattenObservationV0(LambdaObservationV0):
|
||||
class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Observation wrapper that flattens the observation.
|
||||
|
||||
Example:
|
||||
@@ -186,7 +188,7 @@ class FlattenObservationV0(LambdaObservationV0):
|
||||
(27648,)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
||||
super().__init__(
|
||||
env,
|
||||
@@ -195,7 +197,7 @@ class FlattenObservationV0(LambdaObservationV0):
|
||||
)
|
||||
|
||||
|
||||
class GrayscaleObservationV0(LambdaObservationV0):
|
||||
class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Observation wrapper that converts an RGB image to grayscale.
|
||||
|
||||
The :attr:`keep_dim` will keep the channel dimension
|
||||
@@ -214,7 +216,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
||||
(96, 96, 1)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, keep_dim: bool = False):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
|
||||
"""Constructor for an RGB image based environments to make the image grayscale."""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert (
|
||||
@@ -258,7 +260,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
|
||||
)
|
||||
|
||||
|
||||
class ResizeObservationV0(LambdaObservationV0):
|
||||
class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Resizes image observations using OpenCV to shape.
|
||||
|
||||
Example:
|
||||
@@ -272,7 +274,7 @@ class ResizeObservationV0(LambdaObservationV0):
|
||||
(32, 32, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, shape: tuple[int, ...]):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
|
||||
"""Constructor that requires an image environment observation space with a shape."""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert len(env.observation_space.shape) in [2, 3]
|
||||
@@ -304,7 +306,7 @@ class ResizeObservationV0(LambdaObservationV0):
|
||||
)
|
||||
|
||||
|
||||
class ReshapeObservationV0(LambdaObservationV0):
|
||||
class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Reshapes array based observations to shapes.
|
||||
|
||||
Example:
|
||||
@@ -318,7 +320,7 @@ class ReshapeObservationV0(LambdaObservationV0):
|
||||
(24, 4, 96, 1, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, shape: int | tuple[int, ...]):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
|
||||
"""Constructor for env with Box observation space that has a shape product equal to the new shape product."""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert np.product(shape) == np.product(env.observation_space.shape)
|
||||
@@ -337,7 +339,7 @@ class ReshapeObservationV0(LambdaObservationV0):
|
||||
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
|
||||
|
||||
|
||||
class RescaleObservationV0(LambdaObservationV0):
|
||||
class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Linearly rescales observation to between a minimum and maximum value.
|
||||
|
||||
Example:
|
||||
@@ -353,7 +355,7 @@ class RescaleObservationV0(LambdaObservationV0):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
min_obs: np.floating | np.integer | np.ndarray,
|
||||
max_obs: np.floating | np.integer | np.ndarray,
|
||||
):
|
||||
@@ -402,10 +404,10 @@ class RescaleObservationV0(LambdaObservationV0):
|
||||
)
|
||||
|
||||
|
||||
class DtypeObservationV0(LambdaObservationV0):
|
||||
class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Observation wrapper for transforming the dtype of an observation."""
|
||||
|
||||
def __init__(self, env: gym.Env, dtype: Any):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
||||
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
|
||||
assert isinstance(
|
||||
env.observation_space,
|
||||
@@ -446,7 +448,7 @@ class DtypeObservationV0(LambdaObservationV0):
|
||||
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
||||
|
||||
|
||||
class PixelObservationV0(LambdaObservationV0):
|
||||
class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"""Augment observations by pixel values.
|
||||
|
||||
Observations of this wrapper will be dictionaries of images.
|
||||
@@ -499,7 +501,7 @@ class PixelObservationV0(LambdaObservationV0):
|
||||
)
|
||||
|
||||
|
||||
class NormalizeObservationV0(ObservationWrapper):
|
||||
class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||
|
||||
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
|
||||
@@ -511,7 +513,7 @@ class NormalizeObservationV0(ObservationWrapper):
|
||||
newly instantiated or the policy was changed recently.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, epsilon: float = 1e-8):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
|
||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||
|
||||
Args:
|
||||
|
@@ -3,7 +3,6 @@
|
||||
* ``LambdaReward`` - Transforms the reward by a function
|
||||
* ``ClipReward`` - Clips the reward between a minimum and maximum value
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, SupportsFloat
|
||||
@@ -11,12 +10,12 @@ from typing import Any, Callable, SupportsFloat
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import WrapperActType, WrapperObsType
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import InvalidBound
|
||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||
|
||||
|
||||
class LambdaRewardV0(gym.RewardWrapper):
|
||||
class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
|
||||
"""A reward wrapper that allows a custom function to modify the step reward.
|
||||
|
||||
Example:
|
||||
@@ -32,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
func: Callable[[SupportsFloat], SupportsFloat],
|
||||
):
|
||||
"""Initialize LambdaRewardV0 wrapper.
|
||||
@@ -54,7 +53,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
||||
return self.func(reward)
|
||||
|
||||
|
||||
class ClipRewardV0(LambdaRewardV0):
|
||||
class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
|
||||
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
|
||||
|
||||
Example:
|
||||
@@ -70,7 +69,7 @@ class ClipRewardV0(LambdaRewardV0):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
min_reward: float | np.ndarray | None = None,
|
||||
max_reward: float | np.ndarray | None = None,
|
||||
):
|
||||
@@ -93,7 +92,7 @@ class ClipRewardV0(LambdaRewardV0):
|
||||
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
|
||||
|
||||
|
||||
class NormalizeRewardV0(gym.Wrapper):
|
||||
class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||
|
||||
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
||||
@@ -109,7 +108,7 @@ class NormalizeRewardV0(gym.Wrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
gamma: float = 0.99,
|
||||
epsilon: float = 1e-8,
|
||||
):
|
||||
@@ -138,8 +137,8 @@ class NormalizeRewardV0(gym.Wrapper):
|
||||
self._update_running_mean = setting
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment, normalizing the reward returned."""
|
||||
obs, reward, terminated, truncated, info = super().step(action)
|
||||
self.discounted_reward = self.discounted_reward * self.gamma * (
|
||||
@@ -147,7 +146,7 @@ class NormalizeRewardV0(gym.Wrapper):
|
||||
) + float(reward)
|
||||
return obs, self.normalize(float(reward)), terminated, truncated, info
|
||||
|
||||
def normalize(self, reward):
|
||||
def normalize(self, reward: SupportsFloat):
|
||||
"""Normalizes the rewards with the running mean rewards and their variance."""
|
||||
if self._update_running_mean:
|
||||
self.rewards_running_means.update(self.discounted_reward)
|
||||
|
@@ -14,11 +14,11 @@ import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import error, logger
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
|
||||
class RenderCollectionV0(gym.Wrapper):
|
||||
class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
|
||||
|
||||
def __init__(
|
||||
@@ -52,8 +52,8 @@ class RenderCollectionV0(gym.Wrapper):
|
||||
return f"{self.env.render_mode}_list"
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Perform a step in the base environment and collect a frame."""
|
||||
output = super().step(action)
|
||||
self.frame_list.append(super().render())
|
||||
@@ -61,7 +61,7 @@ class RenderCollectionV0(gym.Wrapper):
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset the base environment, eventually clear the frame_list, and collect a frame."""
|
||||
output = super().reset(seed=seed, options=options)
|
||||
|
||||
@@ -71,7 +71,7 @@ class RenderCollectionV0(gym.Wrapper):
|
||||
|
||||
return output
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
def render(self) -> list[RenderFrame]:
|
||||
"""Returns the collection of frames and, if pop_frames = True, clears it."""
|
||||
frames = self.frame_list
|
||||
if self.pop_frames:
|
||||
@@ -80,7 +80,7 @@ class RenderCollectionV0(gym.Wrapper):
|
||||
return frames
|
||||
|
||||
|
||||
class RecordVideoV0(gym.Wrapper):
|
||||
class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""This wrapper records videos of rollouts.
|
||||
|
||||
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
||||
@@ -98,10 +98,10 @@ class RecordVideoV0(gym.Wrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
video_folder: str,
|
||||
episode_trigger: Callable[[int], bool] = None,
|
||||
step_trigger: Callable[[int], bool] = None,
|
||||
episode_trigger: Callable[[int], bool] | None = None,
|
||||
step_trigger: Callable[[int], bool] | None = None,
|
||||
video_length: int = 0,
|
||||
name_prefix: str = "rl-video",
|
||||
disable_logger: bool = False,
|
||||
@@ -155,13 +155,13 @@ class RecordVideoV0(gym.Wrapper):
|
||||
)
|
||||
os.makedirs(self.video_folder, exist_ok=True)
|
||||
|
||||
self.name_prefix = name_prefix
|
||||
self._video_name = None
|
||||
self.frames_per_sec = self.metadata.get("render_fps", 30)
|
||||
self.video_length = video_length if video_length != 0 else float("inf")
|
||||
self.recording = False
|
||||
self.recorded_frames = []
|
||||
self.render_history = []
|
||||
self.name_prefix: str = name_prefix
|
||||
self._video_name: str | None = None
|
||||
self.frames_per_sec: int = self.metadata.get("render_fps", 30)
|
||||
self.video_length: int = video_length if video_length != 0 else float("inf")
|
||||
self.recording: bool = False
|
||||
self.recorded_frames: list[RenderFrame] = []
|
||||
self.render_history: list[RenderFrame] = []
|
||||
|
||||
self.step_id = -1
|
||||
self.episode_id = -1
|
||||
@@ -187,7 +187,7 @@ class RecordVideoV0(gym.Wrapper):
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset the environment and eventually starts a new recording."""
|
||||
obs, info = super().reset(seed=seed, options=options)
|
||||
self.episode_id += 1
|
||||
@@ -205,8 +205,8 @@ class RecordVideoV0(gym.Wrapper):
|
||||
return obs, info
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
|
||||
obs, rew, terminated, truncated, info = self.env.step(action)
|
||||
self.step_id += 1
|
||||
@@ -221,7 +221,7 @@ class RecordVideoV0(gym.Wrapper):
|
||||
|
||||
return obs, rew, terminated, truncated, info
|
||||
|
||||
def start_recording(self, video_name):
|
||||
def start_recording(self, video_name: str):
|
||||
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
|
||||
if self.recording:
|
||||
self.stop_recording()
|
||||
@@ -252,7 +252,7 @@ class RecordVideoV0(gym.Wrapper):
|
||||
self.recording = False
|
||||
self._video_name = None
|
||||
|
||||
def render(self):
|
||||
def render(self) -> RenderFrame | list[RenderFrame]:
|
||||
"""Compute the render frames as specified by render_mode attribute during initialization of the environment."""
|
||||
render_out = super().render()
|
||||
if self.recording and isinstance(render_out, List):
|
||||
@@ -277,7 +277,7 @@ class RecordVideoV0(gym.Wrapper):
|
||||
logger.warn("Unable to save last video! Did you call close()?")
|
||||
|
||||
|
||||
class HumanRenderingV0(gym.Wrapper):
|
||||
class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
|
||||
|
||||
This wrapper is particularly useful when you have implemented an environment that can produce
|
||||
@@ -311,7 +311,7 @@ class HumanRenderingV0(gym.Wrapper):
|
||||
[]
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Initialize a :class:`HumanRendering` instance.
|
||||
|
||||
Args:
|
||||
@@ -339,9 +339,7 @@ class HumanRenderingV0(gym.Wrapper):
|
||||
"""Always returns ``'human'``."""
|
||||
return "human"
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Perform a step in the base environment and render a frame to the screen."""
|
||||
result = super().step(action)
|
||||
self._render_frame()
|
||||
@@ -349,13 +347,13 @@ class HumanRenderingV0(gym.Wrapper):
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset the base environment and render a frame to the screen."""
|
||||
result = super().reset(seed=seed, options=options)
|
||||
self._render_frame()
|
||||
return result
|
||||
|
||||
def render(self):
|
||||
def render(self) -> None:
|
||||
"""This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
|
||||
return None
|
||||
|
||||
|
@@ -4,18 +4,20 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActionWrapper, ActType, WrapperActType, WrapperObsType
|
||||
from gymnasium.core import ActionWrapper, ActType, ObsType
|
||||
from gymnasium.error import InvalidProbability
|
||||
|
||||
|
||||
class StickyActionV0(ActionWrapper):
|
||||
class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
|
||||
"""Wrapper which adds a probability of repeating the previous action.
|
||||
|
||||
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
||||
in Section 5.2 on page 12.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, repeat_action_probability: float):
|
||||
def __init__(
|
||||
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float
|
||||
):
|
||||
"""Initialize StickyAction wrapper.
|
||||
|
||||
Args:
|
||||
@@ -29,17 +31,17 @@ class StickyActionV0(ActionWrapper):
|
||||
|
||||
super().__init__(env)
|
||||
self.repeat_action_probability = repeat_action_probability
|
||||
self.last_action: WrapperActType | None = None
|
||||
self.last_action: ActType | None = None
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.last_action = None
|
||||
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
def action(self, action: WrapperActType) -> ActType:
|
||||
def action(self, action: ActType) -> ActType:
|
||||
"""Execute the action."""
|
||||
if (
|
||||
self.last_action is not None
|
||||
|
@@ -26,10 +26,10 @@ from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
|
||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||
|
||||
|
||||
class DelayObservationV0(gym.ObservationWrapper):
|
||||
class DelayObservationV0(gym.ObservationWrapper[ObsType, ActType, ObsType]):
|
||||
"""Wrapper which adds a delay to the returned observation."""
|
||||
|
||||
def __init__(self, env: gym.Env, delay: int):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
|
||||
"""Initialize the DelayObservation wrapper.
|
||||
|
||||
Args:
|
||||
@@ -45,13 +45,13 @@ class DelayObservationV0(gym.ObservationWrapper):
|
||||
assert 0 < delay
|
||||
|
||||
self.delay: Final[int] = delay
|
||||
self.observation_queue: Final[deque] = deque()
|
||||
self.observation_queue: Final[deque[ObsType]] = deque()
|
||||
|
||||
super().__init__(env)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment, clearing the observation queue."""
|
||||
self.observation_queue.clear()
|
||||
|
||||
@@ -67,7 +67,7 @@ class DelayObservationV0(gym.ObservationWrapper):
|
||||
return jp.zeros_like(observation)
|
||||
|
||||
|
||||
class TimeAwareObservationV0(gym.ObservationWrapper):
|
||||
class TimeAwareObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
||||
"""Augment the observation with time information of the episode.
|
||||
|
||||
Time can be represented as a normalized value between [0,1]
|
||||
@@ -104,7 +104,7 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
flatten: bool = False,
|
||||
normalize_time: bool = True,
|
||||
*,
|
||||
@@ -212,7 +212,7 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class FrameStackObservationV0(gym.Wrapper):
|
||||
class FrameStackObservationV0(gym.Wrapper[WrapperObsType, ActType, ObsType, ActType]):
|
||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||
|
||||
For example, if the number of stacks is 4, then the returned observation contains
|
||||
@@ -302,7 +302,7 @@ class FrameStackObservationV0(gym.Wrapper):
|
||||
info,
|
||||
)
|
||||
|
||||
def _init_stacked_obs(self) -> deque:
|
||||
def _init_stacked_obs(self) -> deque[ObsType]:
|
||||
return deque(
|
||||
iterate(
|
||||
self.observation_space,
|
||||
|
@@ -8,6 +8,13 @@ import numpy as np
|
||||
from gymnasium import Space, error, logger, spaces
|
||||
|
||||
|
||||
__all__ = [
|
||||
"env_render_passive_checker",
|
||||
"env_reset_passive_checker",
|
||||
"env_step_passive_checker",
|
||||
]
|
||||
|
||||
|
||||
def _check_box_observation_space(observation_space: spaces.Box):
|
||||
"""Checks that a :class:`Box` observation space is defined in a sensible way.
|
||||
|
||||
|
@@ -23,6 +23,7 @@ def test_record_video_using_default_trigger():
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
assert env.episode_trigger is not None
|
||||
assert len(mp4_files) == sum(
|
||||
env.episode_trigger(i) for i in range(episode_count + 1)
|
||||
)
|
||||
@@ -46,6 +47,7 @@ def test_record_video_while_rendering():
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
assert env.episode_trigger is not None
|
||||
assert len(mp4_files) == sum(
|
||||
env.episode_trigger(i) for i in range(episode_count + 1)
|
||||
)
|
||||
|
Reference in New Issue
Block a user