Fix the wrapper type hints (#337)

Mark Towers
2023-02-22 13:58:29 +00:00
committed by GitHub
parent 2f42af629b
commit 761bb2e033
11 changed files with 130 additions and 115 deletions

View File

@@ -201,7 +201,7 @@ class Env(Generic[ObsType, ActType]):
Instances of `np.random.Generator`
"""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
self._np_random, _ = seeding.np_random()
return self._np_random
@np_random.setter
@@ -234,7 +234,10 @@ WrapperObsType = TypeVar("WrapperObsType")
WrapperActType = TypeVar("WrapperActType")
class Wrapper(Env[WrapperObsType, WrapperActType]):
class Wrapper(
Env[WrapperObsType, WrapperActType],
Generic[WrapperObsType, WrapperActType, ObsType, ActType],
):
"""Wraps a :class:`gymnasium.Env` to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
This class is the base class of all wrappers to change the behavior of the underlying environment.
@@ -391,7 +394,7 @@ class Wrapper(Env[WrapperObsType, WrapperActType]):
return self.env.unwrapped
class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
"""Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`.
If you would like to apply a function to only the observation before
@@ -434,7 +437,7 @@ class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
raise NotImplementedError
class RewardWrapper(Wrapper[ObsType, ActType]):
class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
"""Superclass of wrappers that can modify the returning reward from a step.
If you would like to apply a function to the reward that is returned by the base environment before
@@ -467,7 +470,7 @@ class RewardWrapper(Wrapper[ObsType, ActType]):
raise NotImplementedError
class ActionWrapper(Wrapper[ObsType, WrapperActType]):
class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
"""Superclass of wrappers that can modify the action before :meth:`env.step`.
If you would like to apply a function to the action before passing it to the base environment,
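
With the change above, `Wrapper` is generic in four parameters (`WrapperObsType`, `WrapperActType`, `ObsType`, `ActType`) and `ObservationWrapper` in three (`[WrapperObsType, ActType, ObsType]`). A minimal sketch of what this enables, written against the changed base classes (`DictToArray` is a hypothetical name, not code from this commit):

    from typing import Any, Dict

    import numpy as np
    import gymnasium as gym
    from gymnasium.core import ActType

    class DictToArray(gym.ObservationWrapper[np.ndarray, ActType, Dict[str, Any]]):
        """Hypothetical wrapper: the env emits dict observations, the agent sees flat arrays."""

        def observation(self, observation: Dict[str, Any]) -> np.ndarray:
            # Flatten every entry of the dict observation into a single vector
            return np.concatenate([np.asarray(v).ravel() for v in observation.values()])

A type checker can now verify that the wrapper's `reset` and `step` return `np.ndarray` observations even though the wrapped env's return dicts, which the old two-parameter `Wrapper` could not express.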

View File

@@ -15,7 +15,7 @@ import numpy as np
import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
@@ -26,10 +26,10 @@ from gymnasium.utils.passive_env_checker import (
)
class AutoresetV0(gym.Wrapper):
class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
def __init__(self, env: gym.Env):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
@@ -40,8 +40,8 @@ class AutoresetV0(gym.Wrapper):
self._reset_options: dict[str, Any] | None = None
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
Args:
@@ -51,7 +51,7 @@ class AutoresetV0(gym.Wrapper):
The autoreset environment :meth:`step`
"""
if self._episode_ended:
obs, info = super().reset(options=self._reset_options)
obs, info = self.env.reset(options=self._reset_options)
self._episode_ended = True
return obs, 0, False, False, info
else:
@@ -61,14 +61,14 @@ class AutoresetV0(gym.Wrapper):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment, saving the options used."""
self._episode_ended = False
self._reset_options = options
return super().reset(seed=seed, options=self._reset_options)
class PassiveEnvCheckerV0(gym.Wrapper):
class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env: Env[ObsType, ActType]):
@@ -89,8 +89,8 @@ class PassiveEnvCheckerV0(gym.Wrapper):
self._checked_render: bool = False
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self._checked_step is False:
self._checked_step = True
@@ -100,7 +100,7 @@ class PassiveEnvCheckerV0(gym.Wrapper):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self._checked_reset is False:
self._checked_reset = True
@@ -117,7 +117,7 @@ class PassiveEnvCheckerV0(gym.Wrapper):
return self.env.render()
class OrderEnforcingV0(gym.Wrapper):
class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
@@ -139,7 +139,11 @@ class OrderEnforcingV0(gym.Wrapper):
>>> env.close()
"""
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
def __init__(
self,
env: gym.Env[ObsType, ActType],
disable_render_order_enforcing: bool = False,
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args:
@@ -150,17 +154,15 @@ class OrderEnforcingV0(gym.Wrapper):
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment with `kwargs`."""
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment with `kwargs`."""
self._has_reset = True
return super().reset(seed=seed, options=options)
@@ -180,7 +182,7 @@ class OrderEnforcingV0(gym.Wrapper):
return self._has_reset
class RecordEpisodeStatisticsV0(gym.Wrapper):
class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
@@ -244,13 +246,13 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
self.episode_reward: float = -1
self.episode_length: int = -1
self.episode_time_length_buffer = deque(maxlen=buffer_length)
self.episode_reward_buffer = deque(maxlen=buffer_length)
self.episode_length_buffer = deque(maxlen=buffer_length)
self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length)
self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
obs, reward, terminated, truncated, info = super().step(action)
@@ -279,7 +281,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
obs, info = super().reset(seed=seed, options=options)
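
The recurring pattern in this file: wrappers that pass observations and actions through unchanged fill all four parameters with the env's own types, `gym.Wrapper[ObsType, ActType, ObsType, ActType]`. A usage sketch (assuming the experimental import path of this release):

    import gymnasium as gym
    from gymnasium.experimental.wrappers import AutoresetV0

    env = gym.make("CartPole-v1")
    wrapped = AutoresetV0(env)
    # The wrapper preserves the env's typed interface: obs below has the env's ObsType
    obs, info = wrapped.reset(seed=0)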

View File

@@ -9,7 +9,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat
import numpy as np
from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
@@ -92,7 +92,7 @@ if jnp is not None:
return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(Wrapper):
class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
@@ -102,7 +102,7 @@ class JaxToNumpyV0(Wrapper):
The reason for this is that jax does not support non-array values, therefore numpy ``np.int32(5) -> DeviceArray([5], dtype=jnp.int32)``
"""
def __init__(self, env: Env):
def __init__(self, env: Env[ObsType, ActType]):
"""Wraps an environment such that the input and outputs are numpy arrays.
Args:
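
`JaxToNumpyV0` is the case where all four parameters genuinely differ: jax arrays on the environment side, numpy arrays on the wrapper side, hence `Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]`. A sketch of the module's `jax_to_numpy` conversion referenced above (assumes jax is installed and this module path):

    import jax.numpy as jnp
    from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy

    # Containers are converted recursively; jax array leaves become numpy arrays
    converted = jax_to_numpy({"pixels": jnp.zeros((2, 2))})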

View File

@@ -16,18 +16,18 @@ except ImportError as e:
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, WrapperActType
from gymnasium.core import ActType, ObsType, WrapperActType
from gymnasium.spaces import Box, Space
class LambdaActionV0(gym.ActionWrapper):
class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
func: Callable[[WrapperActType], ActType],
action_space: Space | None,
action_space: Space[WrapperActType] | None,
):
"""Initialize LambdaAction.
@@ -47,7 +47,7 @@ class LambdaActionV0(gym.ActionWrapper):
return self.func(action)
class ClipActionV0(LambdaActionV0):
class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
"""Clip the continuous action within the valid :class:`Box` observation space bound.
Example:
@@ -63,7 +63,7 @@ class ClipActionV0(LambdaActionV0):
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
"""
def __init__(self, env: gym.Env):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""A wrapper for clipping continuous actions within the valid bound.
Args:
@@ -83,7 +83,7 @@ class ClipActionV0(LambdaActionV0):
)
class RescaleActionV0(LambdaActionV0):
class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
@@ -107,7 +107,7 @@ class RescaleActionV0(LambdaActionV0):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
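
After this change, `func` maps `WrapperActType -> ActType` and `action_space` is a `Space[WrapperActType]`, so the space you pass declares the wrapper-facing action type. A usage sketch (the space and values here are illustrative, not from the commit):

    import numpy as np
    import gymnasium as gym
    from gymnasium.experimental.wrappers import LambdaActionV0
    from gymnasium.spaces import Box

    env = gym.make("MountainCarContinuous-v0")
    # The agent acts in [-2, 2]; the lambda maps that back into the env's valid range
    wrapped = LambdaActionV0(env, lambda a: np.clip(a, -1.0, 1.0), Box(-2.0, 2.0, shape=(1,)))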

View File

@@ -31,7 +31,7 @@ from gymnasium.experimental.wrappers.utils import RunningMeanStd
from gymnasium.spaces import Box, Dict, utils
class LambdaObservationV0(gym.ObservationWrapper):
class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all observations.
@@ -50,9 +50,9 @@ class LambdaObservationV0(gym.ObservationWrapper):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
func: Callable[[ObsType], Any],
observation_space: gym.Space | None,
observation_space: gym.Space[WrapperObsType] | None,
):
"""Constructor for the lambda observation wrapper.
@@ -72,7 +72,7 @@ class LambdaObservationV0(gym.ObservationWrapper):
return self.func(observation)
class FilterObservationV0(LambdaObservationV0):
class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Filter Dict observation space by the keys.
Example:
@@ -91,7 +91,9 @@ class FilterObservationV0(LambdaObservationV0):
({'time': 0}, 1.0, False, False, {})
"""
def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
def __init__(
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
):
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
assert isinstance(filter_keys, Sequence)
@@ -169,7 +171,7 @@ class FilterObservationV0(LambdaObservationV0):
self.filter_keys: Final[Sequence[str | int]] = filter_keys
class FlattenObservationV0(LambdaObservationV0):
class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Observation wrapper that flattens the observation.
Example:
@@ -186,7 +188,7 @@ class FlattenObservationV0(LambdaObservationV0):
(27648,)
"""
def __init__(self, env: gym.Env):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
super().__init__(
env,
@@ -195,7 +197,7 @@ class FlattenObservationV0(LambdaObservationV0):
)
class GrayscaleObservationV0(LambdaObservationV0):
class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Observation wrapper that converts an RGB image to grayscale.
The :attr:`keep_dim` will keep the channel dimension
@@ -214,7 +216,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
(96, 96, 1)
"""
def __init__(self, env: gym.Env, keep_dim: bool = False):
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale."""
assert isinstance(env.observation_space, spaces.Box)
assert (
@@ -258,7 +260,7 @@ class GrayscaleObservationV0(LambdaObservationV0):
)
class ResizeObservationV0(LambdaObservationV0):
class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Resizes image observations using OpenCV to shape.
Example:
@@ -272,7 +274,7 @@ class ResizeObservationV0(LambdaObservationV0):
(32, 32, 3)
"""
def __init__(self, env: gym.Env, shape: tuple[int, ...]):
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape."""
assert isinstance(env.observation_space, spaces.Box)
assert len(env.observation_space.shape) in [2, 3]
@@ -304,7 +306,7 @@ class ResizeObservationV0(LambdaObservationV0):
)
class ReshapeObservationV0(LambdaObservationV0):
class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Reshapes array based observations to shapes.
Example:
@@ -318,7 +320,7 @@ class ReshapeObservationV0(LambdaObservationV0):
(24, 4, 96, 1, 3)
"""
def __init__(self, env: gym.Env, shape: int | tuple[int, ...]):
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
"""Constructor for env with Box observation space that has a shape product equal to the new shape product."""
assert isinstance(env.observation_space, spaces.Box)
assert np.product(shape) == np.product(env.observation_space.shape)
@@ -337,7 +339,7 @@ class ReshapeObservationV0(LambdaObservationV0):
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
class RescaleObservationV0(LambdaObservationV0):
class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Linearly rescales observation to between a minimum and maximum value.
Example:
@@ -353,7 +355,7 @@ class RescaleObservationV0(LambdaObservationV0):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
@@ -402,10 +404,10 @@ class RescaleObservationV0(LambdaObservationV0):
)
class DtypeObservationV0(LambdaObservationV0):
class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: gym.Env, dtype: Any):
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
"""Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces."""
assert isinstance(
env.observation_space,
@@ -446,7 +448,7 @@ class DtypeObservationV0(LambdaObservationV0):
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
class PixelObservationV0(LambdaObservationV0):
class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"""Augment observations by pixel values.
Observations of this wrapper will be dictionaries of images.
@@ -499,7 +501,7 @@ class PixelObservationV0(LambdaObservationV0):
)
class NormalizeObservationV0(ObservationWrapper):
class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
The property `_update_running_mean` allows freezing/continuing the running mean calculation of the observation
@@ -511,7 +513,7 @@ class NormalizeObservationV0(ObservationWrapper):
newly instantiated or the policy was changed recently.
"""
def __init__(self, env: gym.Env, epsilon: float = 1e-8):
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Args:
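
All the subclasses above share `LambdaObservationV0`'s three parameters, `[WrapperObsType, ActType, ObsType]`. A usage sketch of the base wrapper (illustrative values; `observation_space=None` keeps the env's space):

    import numpy as np
    import gymnasium as gym
    from gymnasium.experimental.wrappers import LambdaObservationV0

    env = gym.make("CartPole-v1")
    # Add Gaussian noise to every observation without changing the observation type
    noisy = LambdaObservationV0(env, lambda obs: obs + np.random.normal(0.0, 0.1, obs.shape), None)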

View File

@@ -3,7 +3,6 @@
* ``LambdaReward`` - Transforms the reward by a function
* ``ClipReward`` - Clips the reward between a minimum and maximum value
"""
from __future__ import annotations
from typing import Any, Callable, SupportsFloat
@@ -11,12 +10,12 @@ from typing import Any, Callable, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers.utils import RunningMeanStd
class LambdaRewardV0(gym.RewardWrapper):
class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
"""A reward wrapper that allows a custom function to modify the step reward.
Example:
@@ -32,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
func: Callable[[SupportsFloat], SupportsFloat],
):
"""Initialize LambdaRewardV0 wrapper.
@@ -54,7 +53,7 @@ class LambdaRewardV0(gym.RewardWrapper):
return self.func(reward)
class ClipRewardV0(LambdaRewardV0):
class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
Example:
@@ -70,7 +69,7 @@ class ClipRewardV0(LambdaRewardV0):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
@@ -93,7 +92,7 @@ class ClipRewardV0(LambdaRewardV0):
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
class NormalizeRewardV0(gym.Wrapper):
class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
@@ -109,7 +108,7 @@ class NormalizeRewardV0(gym.Wrapper):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
gamma: float = 0.99,
epsilon: float = 1e-8,
):
@@ -138,8 +137,8 @@ class NormalizeRewardV0(gym.Wrapper):
self._update_running_mean = setting
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, normalizing the reward returned."""
obs, reward, terminated, truncated, info = super().step(action)
self.discounted_reward = self.discounted_reward * self.gamma * (
@@ -147,7 +146,7 @@ class NormalizeRewardV0(gym.Wrapper):
) + float(reward)
return obs, self.normalize(float(reward)), terminated, truncated, info
def normalize(self, reward):
def normalize(self, reward: SupportsFloat):
"""Normalizes the rewards with the running mean rewards and their variance."""
if self._update_running_mean:
self.rewards_running_means.update(self.discounted_reward)
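
Reward wrappers carry no extra type parameter of their own, so `LambdaRewardV0` and `ClipRewardV0` are `[ObsType, ActType]` while `NormalizeRewardV0` fills all four slots with the env's types. A usage sketch (illustrative bounds):

    import gymnasium as gym
    from gymnasium.experimental.wrappers import ClipRewardV0, NormalizeRewardV0

    env = gym.make("CartPole-v1")
    env = ClipRewardV0(env, min_reward=0.0, max_reward=0.5)  # rewards clipped into [0.0, 0.5]
    env = NormalizeRewardV0(env, gamma=0.99, epsilon=1e-8)   # EMA variance fixed at (1 - gamma)^2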

View File

@@ -14,11 +14,11 @@ import numpy as np
import gymnasium as gym
from gymnasium import error, logger
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import DependencyNotInstalled
class RenderCollectionV0(gym.Wrapper):
class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
def __init__(
@@ -52,8 +52,8 @@ class RenderCollectionV0(gym.Wrapper):
return f"{self.env.render_mode}_list"
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Perform a step in the base environment and collect a frame."""
output = super().step(action)
self.frame_list.append(super().render())
@@ -61,7 +61,7 @@ class RenderCollectionV0(gym.Wrapper):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the base environment, eventually clear the frame_list, and collect a frame."""
output = super().reset(seed=seed, options=options)
@@ -71,7 +71,7 @@ class RenderCollectionV0(gym.Wrapper):
return output
def render(self) -> RenderFrame | list[RenderFrame] | None:
def render(self) -> list[RenderFrame]:
"""Returns the collection of frames and, if pop_frames = True, clears it."""
frames = self.frame_list
if self.pop_frames:
@@ -80,7 +80,7 @@ class RenderCollectionV0(gym.Wrapper):
return frames
class RecordVideoV0(gym.Wrapper):
class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode.
@@ -98,10 +98,10 @@ class RecordVideoV0(gym.Wrapper):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
video_folder: str,
episode_trigger: Callable[[int], bool] = None,
step_trigger: Callable[[int], bool] = None,
episode_trigger: Callable[[int], bool] | None = None,
step_trigger: Callable[[int], bool] | None = None,
video_length: int = 0,
name_prefix: str = "rl-video",
disable_logger: bool = False,
@@ -155,13 +155,13 @@ class RecordVideoV0(gym.Wrapper):
)
os.makedirs(self.video_folder, exist_ok=True)
self.name_prefix = name_prefix
self._video_name = None
self.frames_per_sec = self.metadata.get("render_fps", 30)
self.video_length = video_length if video_length != 0 else float("inf")
self.recording = False
self.recorded_frames = []
self.render_history = []
self.name_prefix: str = name_prefix
self._video_name: str | None = None
self.frames_per_sec: int = self.metadata.get("render_fps", 30)
self.video_length: int = video_length if video_length != 0 else float("inf")
self.recording: bool = False
self.recorded_frames: list[RenderFrame] = []
self.render_history: list[RenderFrame] = []
self.step_id = -1
self.episode_id = -1
@@ -187,7 +187,7 @@ class RecordVideoV0(gym.Wrapper):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment and eventually starts a new recording."""
obs, info = super().reset(seed=seed, options=options)
self.episode_id += 1
@@ -205,8 +205,8 @@ class RecordVideoV0(gym.Wrapper):
return obs, info
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
obs, rew, terminated, truncated, info = self.env.step(action)
self.step_id += 1
@@ -221,7 +221,7 @@ class RecordVideoV0(gym.Wrapper):
return obs, rew, terminated, truncated, info
def start_recording(self, video_name):
def start_recording(self, video_name: str):
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
if self.recording:
self.stop_recording()
@@ -252,7 +252,7 @@ class RecordVideoV0(gym.Wrapper):
self.recording = False
self._video_name = None
def render(self):
def render(self) -> RenderFrame | list[RenderFrame]:
"""Compute the render frames as specified by render_mode attribute during initialization of the environment."""
render_out = super().render()
if self.recording and isinstance(render_out, List):
@@ -277,7 +277,7 @@ class RecordVideoV0(gym.Wrapper):
logger.warn("Unable to save last video! Did you call close()?")
class HumanRenderingV0(gym.Wrapper):
class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce
@@ -311,7 +311,7 @@ class HumanRenderingV0(gym.Wrapper):
[]
"""
def __init__(self, env):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialize a :class:`HumanRendering` instance.
Args:
@@ -339,9 +339,7 @@ class HumanRenderingV0(gym.Wrapper):
"""Always returns ``'human'``."""
return "human"
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Perform a step in the base environment and render a frame to the screen."""
result = super().step(action)
self._render_frame()
@@ -349,13 +347,13 @@ class HumanRenderingV0(gym.Wrapper):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the base environment and render a frame to the screen."""
result = super().reset(seed=seed, options=options)
self._render_frame()
return result
def render(self):
def render(self) -> None:
"""This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
return None
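
`episode_trigger` and `step_trigger` are now `Callable[[int], bool] | None`, matching their `None` defaults. A usage sketch with an explicit trigger (folder and cadence are illustrative):

    import gymnasium as gym
    from gymnasium.experimental.wrappers import RecordVideoV0

    env = gym.make("CartPole-v1", render_mode="rgb_array")
    # Record every hundredth episode into ./videos
    env = RecordVideoV0(env, video_folder="videos", episode_trigger=lambda ep: ep % 100 == 0)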

View File

@@ -4,18 +4,20 @@ from __future__ import annotations
from typing import Any
import gymnasium as gym
from gymnasium.core import ActionWrapper, ActType, WrapperActType, WrapperObsType
from gymnasium.core import ActionWrapper, ActType, ObsType
from gymnasium.error import InvalidProbability
class StickyActionV0(ActionWrapper):
class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
"""Wrapper which adds a probability of repeating the previous action.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
in Section 5.2 on page 12.
"""
def __init__(self, env: gym.Env, repeat_action_probability: float):
def __init__(
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float
):
"""Initialize StickyAction wrapper.
Args:
@@ -29,17 +31,17 @@ class StickyActionV0(ActionWrapper):
super().__init__(env)
self.repeat_action_probability = repeat_action_probability
self.last_action: WrapperActType | None = None
self.last_action: ActType | None = None
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment."""
self.last_action = None
return super().reset(seed=seed, options=options)
def action(self, action: WrapperActType) -> ActType:
def action(self, action: ActType) -> ActType:
"""Execute the action."""
if (
self.last_action is not None
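
Since sticky actions leave the action type unchanged, the wrapper subclasses `ActionWrapper[ObsType, ActType, ActType]` rather than introducing a `WrapperActType`. A usage sketch (0.25 is the repeat probability used by Machado et al., 2018):

    import gymnasium as gym
    from gymnasium.experimental.wrappers import StickyActionV0

    env = StickyActionV0(gym.make("CartPole-v1"), repeat_action_probability=0.25)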

View File

@@ -26,10 +26,10 @@ from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
class DelayObservationV0(gym.ObservationWrapper):
class DelayObservationV0(gym.ObservationWrapper[ObsType, ActType, ObsType]):
"""Wrapper which adds a delay to the returned observation."""
def __init__(self, env: gym.Env, delay: int):
def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
"""Initialize the DelayObservation wrapper.
Args:
@@ -45,13 +45,13 @@ class DelayObservationV0(gym.ObservationWrapper):
assert 0 < delay
self.delay: Final[int] = delay
self.observation_queue: Final[deque] = deque()
self.observation_queue: Final[deque[ObsType]] = deque()
super().__init__(env)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment, clearing the observation queue."""
self.observation_queue.clear()
@@ -67,7 +67,7 @@ class DelayObservationV0(gym.ObservationWrapper):
return jp.zeros_like(observation)
class TimeAwareObservationV0(gym.ObservationWrapper):
class TimeAwareObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
"""Augment the observation with time information of the episode.
Time can be represented as a normalized value in [0,1]
@@ -104,7 +104,7 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
flatten: bool = False,
normalize_time: bool = True,
*,
@@ -212,7 +212,7 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
return super().reset(seed=seed, options=options)
class FrameStackObservationV0(gym.Wrapper):
class FrameStackObservationV0(gym.Wrapper[WrapperObsType, ActType, ObsType, ActType]):
"""Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains
@@ -302,7 +302,7 @@ class FrameStackObservationV0(gym.Wrapper):
info,
)
def _init_stacked_obs(self) -> deque:
def _init_stacked_obs(self) -> deque[ObsType]:
return deque(
iterate(
self.observation_space,
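
`DelayObservationV0` keeps `ObsType` on both sides (`ObservationWrapper[ObsType, ActType, ObsType]`), whereas `TimeAwareObservationV0` and `FrameStackObservationV0` introduce a `WrapperObsType` because they change the observation's structure. A sketch of the delay wrapper, whose signature appears in the diff above:

    import gymnasium as gym
    from gymnasium.experimental.wrappers import DelayObservationV0

    env = gym.make("CartPole-v1")
    # Observations arrive two steps late; zeros (jp.zeros_like) are returned until the queue fills
    env = DelayObservationV0(env, delay=2)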

View File

@@ -8,6 +8,13 @@ import numpy as np
from gymnasium import Space, error, logger, spaces
__all__ = [
"env_render_passive_checker",
"env_reset_passive_checker",
"env_step_passive_checker",
]
def _check_box_observation_space(observation_space: spaces.Box):
"""Checks that a :class:`Box` observation space is defined in a sensible way.

View File

@@ -23,6 +23,7 @@ def test_record_video_using_default_trigger():
env.close()
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert env.episode_trigger is not None
assert len(mp4_files) == sum(
env.episode_trigger(i) for i in range(episode_count + 1)
)
@@ -46,6 +47,7 @@ def test_record_video_while_rendering():
env.close()
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert env.episode_trigger is not None
assert len(mp4_files) == sum(
env.episode_trigger(i) for i in range(episode_count + 1)
)
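
The added asserts exist because `episode_trigger` is now `Callable[[int], bool] | None`: the `assert ... is not None` narrows the Optional so type checkers allow the call on the next line. A minimal sketch of the same narrowing (hypothetical helper, not test code):

    from __future__ import annotations

    from typing import Callable

    def count_recorded(trigger: Callable[[int], bool] | None, episodes: int) -> int:
        assert trigger is not None  # mypy/pyright narrow trigger to Callable[[int], bool] here
        return sum(trigger(i) for i in range(episodes))

    assert count_recorded(lambda ep: ep % 2 == 0, episodes=10) == 5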