mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
305 lines
11 KiB
Python
305 lines
11 KiB
Python
"""A collection of stateful observation wrappers.
|
|
|
|
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
|
|
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
|
|
* ``FrameStackObservationV0`` - Frame stack the observations
|
|
* ``AtariPreprocessingV0`` - Preprocessing wrapper for atari environments
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from collections import deque
|
|
from typing import Any, SupportsFloat
|
|
from typing_extensions import Final
|
|
|
|
import jumpy as jp
|
|
import numpy as np
|
|
|
|
import gymnasium as gym
|
|
import gymnasium.spaces as spaces
|
|
from gymnasium import Env
|
|
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
|
|
from gymnasium.spaces import Box, Dict, MultiBinary, MultiDiscrete, Tuple
|
|
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
|
|
|
|
|
class DelayObservationV0(gym.ObservationWrapper):
|
|
"""Wrapper which adds a delay to the returned observation."""
|
|
|
|
def __init__(self, env: gym.Env, delay: int):
|
|
"""Initialize the DelayObservation wrapper.
|
|
|
|
Args:
|
|
env (Env): the wrapped environment
|
|
delay (int): number of timesteps for delaying the observation.
|
|
Before reaching the `delay` number of timesteps,
|
|
returned observation is an array of zeros with the
|
|
same shape of the observation space.
|
|
"""
|
|
assert isinstance(
|
|
env.observation_space, (Box, MultiBinary, MultiDiscrete)
|
|
), type(env.observation_space)
|
|
assert 0 < delay
|
|
|
|
self.delay: Final[int] = delay
|
|
self.observation_queue: Final[deque] = deque()
|
|
|
|
super().__init__(env)
|
|
|
|
def reset(
|
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
|
"""Resets the environment, clearing the observation queue."""
|
|
self.observation_queue.clear()
|
|
|
|
return super().reset(seed=seed, options=options)
|
|
|
|
def observation(self, observation: ObsType) -> ObsType:
|
|
"""Return the delayed observation."""
|
|
self.observation_queue.append(observation)
|
|
|
|
if len(self.observation_queue) > self.delay:
|
|
return self.observation_queue.popleft()
|
|
|
|
return jp.zeros_like(observation)
|
|
|
|
|
|
class TimeAwareObservationV0(gym.ObservationWrapper):
|
|
"""Augment the observation with time information of the episode.
|
|
|
|
Time can be represented as a normalized value between [0,1]
|
|
or by the number of timesteps remaining before truncation occurs.
|
|
|
|
For environments with ``Dict`` or ``Tuple`` observation spaces, by default,
|
|
the time information is automatically added in the key `"time"` and
|
|
as the final element in the tuple.
|
|
|
|
Example:
|
|
>>> import gymnasium as gym
|
|
>>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
|
>>> env = gym.make('CartPole-v1')
|
|
>>> env = TimeAwareObservationV0(env)
|
|
>>> env.observation_space
|
|
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
|
|
>>> _ = env.reset()
|
|
>>> env.step(env.action_space.sample())[0]
|
|
OrderedDict([('obs',
|
|
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
|
|
... ('time', array([0.002]))])
|
|
|
|
Flatten observation space example:
|
|
>>> env = gym.make('CartPole-v1')
|
|
>>> env = TimeAwareObservationV0(env, flatten=True)
|
|
>>> env.observation_space
|
|
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
|
|
>>> _ = env.reset()
|
|
>>> env.step(env.action_space.sample())[0]
|
|
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
env: gym.Env,
|
|
flatten: bool = False,
|
|
normalize_time: bool = True,
|
|
*,
|
|
dict_time_key: str = "time",
|
|
):
|
|
"""Initialize :class:`TimeAwareObservationV0`.
|
|
|
|
Args:
|
|
env: The environment to apply the wrapper
|
|
flatten: Flatten the observation to a `Box` of a single dimension
|
|
normalize_time: if `True` return time in the range [0,1]
|
|
otherwise return time as remaining timesteps before truncation
|
|
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
|
|
"""
|
|
super().__init__(env)
|
|
self.flatten: Final[bool] = flatten
|
|
self.normalize_time: Final[bool] = normalize_time
|
|
|
|
if hasattr(env, "_max_episode_steps"):
|
|
self.max_timesteps = getattr(env, "_max_episode_steps")
|
|
elif env.spec is not None and env.spec.max_episode_steps is not None:
|
|
self.max_timesteps = env.spec.max_episode_steps
|
|
else:
|
|
raise ValueError(
|
|
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
|
|
)
|
|
|
|
self.timesteps: int = 0
|
|
|
|
# Find the normalized time space
|
|
if self.normalize_time:
|
|
self._time_preprocess_func = lambda time: time / self.max_timesteps
|
|
time_space = Box(0.0, 1.0)
|
|
else:
|
|
self._time_preprocess_func = lambda time: self.max_timesteps - time
|
|
time_space = Box(0, self.max_timesteps, dtype=np.int32)
|
|
|
|
# Find the observation space
|
|
if isinstance(env.observation_space, Dict):
|
|
assert dict_time_key not in env.observation_space.keys()
|
|
observation_space = Dict(
|
|
{dict_time_key: time_space, **env.observation_space.spaces}
|
|
)
|
|
self._append_data_func = lambda obs, time: {dict_time_key: time, **obs}
|
|
elif isinstance(env.observation_space, Tuple):
|
|
observation_space = Tuple(env.observation_space.spaces + (time_space,))
|
|
self._append_data_func = lambda obs, time: obs + (time,)
|
|
else:
|
|
observation_space = Dict(obs=env.observation_space, time=time_space)
|
|
self._append_data_func = lambda obs, time: {"obs": obs, "time": time}
|
|
|
|
# If to flatten the observation space
|
|
if self.flatten:
|
|
self.observation_space = spaces.flatten_space(observation_space)
|
|
self._obs_postprocess_func = lambda obs: spaces.flatten(
|
|
observation_space, obs
|
|
)
|
|
else:
|
|
self.observation_space = observation_space
|
|
self._obs_postprocess_func = lambda obs: obs
|
|
|
|
def observation(self, observation: ObsType) -> WrapperObsType:
|
|
"""Adds to the observation with the current time information.
|
|
|
|
Args:
|
|
observation: The observation to add the time step to
|
|
|
|
Returns:
|
|
The observation with the time information appended to
|
|
"""
|
|
return self._obs_postprocess_func(
|
|
self._append_data_func(
|
|
observation, self._time_preprocess_func(self.timesteps)
|
|
)
|
|
)
|
|
|
|
def step(
|
|
self, action: ActType
|
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
"""Steps through the environment, incrementing the time step.
|
|
|
|
Args:
|
|
action: The action to take
|
|
|
|
Returns:
|
|
The environment's step using the action.
|
|
"""
|
|
self.timesteps += 1
|
|
return super().step(action)
|
|
|
|
def reset(
|
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
|
"""Reset the environment setting the time to zero.
|
|
|
|
Args:
|
|
seed: The seed to reset the environment
|
|
options: The options used to reset the environment
|
|
|
|
Returns:
|
|
The reset environment
|
|
"""
|
|
self.timesteps = 0
|
|
|
|
return super().reset(seed=seed, options=options)
|
|
|
|
|
|
class FrameStackObservationV0(gym.Wrapper):
|
|
"""Observation wrapper that stacks the observations in a rolling manner.
|
|
|
|
For example, if the number of stacks is 4, then the returned observation contains
|
|
the most recent 4 observations. For environment 'Pendulum-v1', the original observation
|
|
is an array with shape [3], so if we stack 4 observations, the processed observation
|
|
has shape [4, 3].
|
|
|
|
Note:
|
|
- After :meth:`reset` is called, the frame buffer will be filled with the initial observation.
|
|
I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames.
|
|
|
|
Example:
|
|
>>> import gymnasium as gym
|
|
>>> env = gym.make('CarRacing-v1')
|
|
>>> env = FrameStack(env, 4)
|
|
>>> env.observation_space
|
|
Box(4, 96, 96, 3)
|
|
>>> obs = env.reset()
|
|
>>> obs.shape
|
|
(4, 96, 96, 3)
|
|
"""
|
|
|
|
def __init__(self, env: Env[ObsType, ActType], stack_size: int):
|
|
"""Observation wrapper that stacks the observations in a rolling manner.
|
|
|
|
Args:
|
|
env: The environment to apply the wrapper
|
|
stack_size: The number of frames to stack
|
|
"""
|
|
assert np.issubdtype(type(stack_size), np.integer)
|
|
assert stack_size > 0
|
|
|
|
super().__init__(env)
|
|
|
|
self.observation_space = batch_space(env.observation_space, n=stack_size)
|
|
self.stack_size = stack_size
|
|
|
|
self.stacked_obs_array = create_empty_array(env.observation_space, n=stack_size)
|
|
self.stacked_obs = self._init_stacked_obs()
|
|
|
|
def step(
|
|
self, action: WrapperActType
|
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
"""Steps through the environment, appending the observation to the frame buffer.
|
|
|
|
Args:
|
|
action: The action to step through the environment with
|
|
|
|
Returns:
|
|
Stacked observations, reward, terminated, truncated, and info from the environment
|
|
"""
|
|
obs, reward, terminated, truncated, info = super().step(action)
|
|
self.stacked_obs.rotate(1)
|
|
self.stacked_obs[0] = obs
|
|
|
|
return (
|
|
concatenate(
|
|
self.observation_space, self.stacked_obs, self.stacked_obs_array
|
|
),
|
|
reward,
|
|
terminated,
|
|
truncated,
|
|
info,
|
|
)
|
|
|
|
def reset(
|
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
|
"""Reset the environment, returning the stacked observation and info.
|
|
|
|
Args:
|
|
seed: The environment seed
|
|
options: The reset options
|
|
|
|
Returns:
|
|
The stacked observations and info
|
|
"""
|
|
obs, info = super().reset(seed=seed, options=options)
|
|
self.stacked_obs = self._init_stacked_obs()
|
|
self.stacked_obs[0] = obs
|
|
|
|
return (
|
|
concatenate(
|
|
self.observation_space, self.stacked_obs, self.stacked_obs_array
|
|
),
|
|
info,
|
|
)
|
|
|
|
def _init_stacked_obs(self) -> deque:
|
|
return deque(
|
|
iterate(
|
|
self.observation_space,
|
|
create_empty_array(self.env.observation_space, n=self.stack_size),
|
|
)
|
|
)
|