mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Stateful wrappers [DelayObservation ,TimeAwareObservation, StickyAction] (#165)
This commit is contained in:
@@ -19,6 +19,19 @@
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||
```
|
||||
|
||||
## Observation Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
||||
```
|
||||
|
||||
## Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
||||
```
|
||||
|
||||
## Common Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
|
@@ -73,6 +73,10 @@ class MissingArgument(Error):
|
||||
"""Raised when a required argument in the initializer is missing."""
|
||||
|
||||
|
||||
class InvalidProbability(Error):
    """Error signalling that a value supplied as a probability is not valid."""
|
||||
|
||||
|
||||
class InvalidBound(Error):
    """Error signalling that an array is clipped with an invalid upper and/or lower bound."""
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
"""Root __init__ of the gym dev_wrappers."""
|
||||
"""Root __init__ of the gym experimental wrappers."""
|
||||
|
||||
|
||||
from gymnasium.experimental.functional import FuncEnv
|
||||
|
@@ -12,15 +12,23 @@ from gymnasium.experimental.wrappers.lambda_action import (
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
|
||||
from gymnasium.experimental.wrappers.time_aware_observation import (
|
||||
TimeAwareObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0
|
||||
|
||||
__all__ = [
|
||||
"ArgType",
|
||||
# Lambda Action
|
||||
"LambdaActionV0",
|
||||
"StickyActionV0",
|
||||
"ClipActionV0",
|
||||
"RescaleActionV0",
|
||||
# Lambda Observation
|
||||
"LambdaObservationV0",
|
||||
"DelayObservationV0",
|
||||
"TimeAwareObservationV0",
|
||||
# Lambda Reward
|
||||
"LambdaRewardV0",
|
||||
"ClipRewardV0",
|
||||
|
35
gymnasium/experimental/wrappers/delay_observation.py
Normal file
35
gymnasium/experimental/wrappers/delay_observation.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Wrapper for delaying the returned observation."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import jumpy as jp
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
|
||||
class DelayObservationV0(gym.ObservationWrapper):
    """Wrapper which adds a delay to the returned observation.

    Incoming observations are buffered in a FIFO queue. While the queue holds
    no more than ``delay`` entries, a zero-filled stand-in created with
    ``jp.zeros_like`` is returned; once it holds more, the oldest buffered
    observation is returned instead.

    The buffer is emptied on every :meth:`reset`, so observations recorded in
    one episode are never handed out during the next one.
    """

    def __init__(self, env: gym.Env, delay: int):
        """Initialize the DelayObservation wrapper.

        Args:
            env (Env): the wrapped environment
            delay (int): number of timesteps for delaying the observation.
                Before reaching the `delay` number of timesteps,
                returned observation is an array of zeros with the
                same shape of the observation space.
        """
        super().__init__(env)
        self.delay = delay
        # FIFO buffer of observations that have not been released yet.
        self.observation_queue = deque()

    def reset(self, **kwargs):
        """Reset the environment and drop all buffered observations.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            Whatever the parent wrapper's ``reset`` returns.
        """
        # Clear first: without this, leftovers from the previous episode
        # would be returned as "delayed" observations of the new episode.
        self.observation_queue.clear()
        return super().reset(**kwargs)

    def observation(self, observation: ObsType) -> ObsType:
        """Return the delayed observation.

        Args:
            observation: the newest observation produced by the environment

        Returns:
            The oldest buffered observation once more than ``delay``
            observations are buffered, otherwise zeros shaped like ``observation``.
        """
        self.observation_queue.append(observation)

        if len(self.observation_queue) > self.delay:
            return self.observation_queue.popleft()

        return jp.zeros_like(observation)
|
40
gymnasium/experimental/wrappers/sticky_action.py
Normal file
40
gymnasium/experimental/wrappers/sticky_action.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Wrapper which adds a probability of repeating the previous executed action."""
|
||||
from typing import Union
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType
|
||||
from gymnasium.error import InvalidProbability
|
||||
|
||||
|
||||
class StickyActionV0(gym.ActionWrapper):
    """Action wrapper that may replay the previously executed action."""

    def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]):
        """Validate the repeat probability and set up the wrapper.

        Args:
            env (Env): the wrapped environment
            repeat_action_probability (int | float): probability, in [0, 1), that
                the previously executed action replaces the incoming one.

        Raises:
            InvalidProbability: if the probability is not within [0, 1).
        """
        probability_is_valid = 0 <= repeat_action_probability < 1
        if not probability_is_valid:
            raise InvalidProbability(
                f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
            )

        super().__init__(env)
        self.old_action = None
        self.repeat_action_probability = repeat_action_probability

    def action(self, action: ActType):
        """Pick between the incoming action and the previously executed one.

        Args:
            action: the action requested by the caller

        Returns:
            The action that is actually executed; it is also remembered for
            the next call.
        """
        repeat_previous = False
        # The random draw only happens when there is a previous action to repeat.
        if self.old_action is not None:
            repeat_previous = self.np_random.uniform() < self.repeat_action_probability

        if repeat_previous:
            action = self.old_action

        self.old_action = action
        return action

    def reset(self, **kwargs):
        """Forget the remembered action, then reset the environment.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            Whatever the parent wrapper's ``reset`` returns.
        """
        self.old_action = None
        return super().reset(**kwargs)
|
113
gymnasium/experimental/wrappers/time_aware_observation.py
Normal file
113
gymnasium/experimental/wrappers/time_aware_observation.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Wrapper for adding time aware observations to environment observation."""
|
||||
from collections import OrderedDict
|
||||
|
||||
import gymnasium as gym
|
||||
import gymnasium.spaces as spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces import Box, Dict
|
||||
|
||||
|
||||
class TimeAwareObservationV0(gym.ObservationWrapper):
    """Augment the observation with time information of the episode.

    Time can be represented as a normalized value between [0,1]
    or by the number of timesteps remaining before truncation occurs.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservationV0(env)
        >>> env.observation_space
        Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
        >>> _ = env.reset()
        >>> env.step(env.action_space.sample())[0]
        OrderedDict([('obs',
        ...       array([ 0.02866629,  0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
        ...      ('time', array([0.002]))])

    Flatten observation space example:
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservationV0(env, flatten=True)
        >>> env.observation_space
        Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
        >>> _ = env.reset()
        >>> env.step(env.action_space.sample())[0]
        array([-0.01232257,  0.19335455, -0.02244143, -0.32388705,  0.002     ], dtype=float32)
    """

    def __init__(self, env: gym.Env, flatten=False, normalize_time=True):
        """Initialize :class:`TimeAwareObservationV0`.

        Args:
            env: The environment to apply the wrapper
            flatten: Flatten the observation to a `Box` of a single dimension
            normalize_time: if `True` return time in the range [0,1]
                otherwise return time as remaining timesteps before truncation

        Raises:
            AttributeError: if ``env`` does not expose ``_max_episode_steps``.
        """
        super().__init__(env)
        self.flatten = flatten
        self.normalize_time = normalize_time
        self.max_timesteps = getattr(env, "_max_episode_steps")
        # Start the counter here so that `observation` can be called before
        # the first `reset` without raising AttributeError.
        self.timesteps = 0

        if self.normalize_time:
            # Fraction of the episode that has elapsed.
            self._get_time_observation = lambda: self.timesteps / self.max_timesteps
            time_space = Box(0, 1)
        else:
            # Number of steps left before the step limit is reached.
            self._get_time_observation = lambda: self.max_timesteps - self.timesteps
            time_space = Box(0, self.max_timesteps)

        # The unflattened space is always kept: `flatten` needs it to convert
        # the dict observation into a flat array.
        self.time_aware_observation_space = Dict(
            obs=env.observation_space, time=time_space
        )

        if self.flatten:
            self.observation_space = spaces.flatten_space(
                self.time_aware_observation_space
            )
            self._observation_postprocess = lambda observation: spaces.flatten(
                self.time_aware_observation_space, observation
            )
        else:
            self.observation_space = self.time_aware_observation_space
            self._observation_postprocess = lambda observation: observation

    def observation(self, observation: ObsType):
        """Adds to the observation with the current time information.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time information appended to
        """
        # NOTE(review): the time value is a plain Python number here, while the
        # class example shows a shape-(1,) array for `time` - confirm which is intended.
        time_observation = self._get_time_observation()
        observation = OrderedDict(obs=observation, time=time_observation)

        return self._observation_postprocess(observation)

    def step(self, action: ActType):
        """Steps through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step using the action.
        """
        # Increment before stepping so the returned observation already
        # carries the time of the step that was just taken.
        self.timesteps += 1
        observation, reward, terminated, truncated, info = super().step(action)

        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the environment setting the time to zero.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            The reset environment
        """
        self.timesteps = 0
        return super().reset(**kwargs)
|
36
tests/experimental/wrappers/test_delay_observation.py
Normal file
36
tests/experimental/wrappers/test_delay_observation.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import DelayObservationV0
|
||||
|
||||
SEED = 42
|
||||
|
||||
DELAY = 3
|
||||
NUM_STEPS = 5
|
||||
|
||||
|
||||
def test_delay_observation():
    """Check that DelayObservationV0 shifts observations by DELAY steps.

    The same seeded rollout is run twice, once unwrapped and once wrapped.
    The first DELAY wrapped observations must be all zeros; every later one
    must equal the unwrapped observation from DELAY steps earlier.
    """
    env = gym.make("CartPole-v1")
    env.action_space.seed(SEED)
    env.reset(seed=SEED)

    undelayed_observations = []
    for _ in range(NUM_STEPS):
        obs, _, _, _, _ = env.step(env.action_space.sample())
        undelayed_observations.append(obs)

    # Re-seed so that the wrapped rollout replays exactly the same trajectory.
    env.action_space.seed(SEED)
    env.reset(seed=SEED)
    env = DelayObservationV0(env, delay=DELAY)

    delayed_observations = []
    for i in range(NUM_STEPS):
        obs, _, _, _, _ = env.step(env.action_space.sample())
        # Every one of the first DELAY steps returns the zero stand-in.
        if i < DELAY:
            assert np.all(obs == 0)
        delayed_observations.append(obs)

    # Steps DELAY..NUM_STEPS-1 of the wrapped run replay steps
    # 0..NUM_STEPS-DELAY-1 of the unwrapped run.
    num_released = NUM_STEPS - DELAY
    assert np.array_equal(
        np.array(delayed_observations[DELAY:]),
        np.array(undelayed_observations[:num_released]),
    )
|
41
tests/experimental/wrappers/test_sticky_action.py
Normal file
41
tests/experimental/wrappers/test_sticky_action.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Test suite for StickyActionV0."""
|
||||
import pytest
|
||||
|
||||
from gymnasium.error import InvalidProbability
|
||||
from gymnasium.experimental.wrappers import StickyActionV0
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
SEED = 42
|
||||
|
||||
DELAY = 3
|
||||
NUM_STEPS = 10
|
||||
|
||||
|
||||
def step_fn(self, action):
    """Stand-in step function that hands the received action straight back."""
    executed = action
    return executed
|
||||
|
||||
|
||||
def test_sticky_action():
    """Check that StickyActionV0 executes either the input or the previous action.

    The wrapped test environment's step echoes the action it receives, so the
    value returned by ``env.step`` is the action that was actually executed.
    """
    env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5)
    env.reset(seed=SEED)
    env.action_space.seed(SEED)

    previous_action = None
    for _ in range(NUM_STEPS):
        input_action = env.action_space.sample()
        executed_action = env.step(input_action)

        if executed_action != input_action:
            assert executed_action == previous_action
        else:
            assert executed_action == input_action

        # The wrapper remembers the action it executed, not the one requested,
        # so that is what a repeat on the next step must be compared against.
        previous_action = executed_action
|
||||
|
||||
|
||||
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
    """Probabilities outside [0, 1) must be rejected with InvalidProbability."""
    base_env = GenericTestEnv()
    with pytest.raises(InvalidProbability):
        StickyActionV0(base_env, repeat_action_probability=repeat_action_probability)
|
98
tests/experimental/wrappers/test_time_aware_observation.py
Normal file
98
tests/experimental/wrappers/test_time_aware_observation.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Test suite for TimeAwareobservationV0."""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
||||
from gymnasium.spaces import Box, Dict
|
||||
|
||||
NUM_STEPS = 20
|
||||
SEED = 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True),
        gym.make("CarRacing-v2", disable_env_checker=True),
    ],
)
def test_time_aware_observation_creation(env):
    """Verify the default (dict) form of TimeAwareObservationV0.

    After wrapping and resetting, the observation space must be a `Dict`
    whose "obs" entry is the original space, and the first observation must
    be an `OrderedDict` with a time of zero.
    """
    wrapped_env = TimeAwareObservationV0(env)
    first_obs, _ = wrapped_env.reset()
    wrapped_space = wrapped_env.observation_space

    assert isinstance(wrapped_space, Dict)
    assert isinstance(first_obs, OrderedDict)
    assert np.all(first_obs["time"] == 0)
    assert env.observation_space == wrapped_space["obs"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("normalize_time", [True, False])
@pytest.mark.parametrize("flatten", [False, True])
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True),
        gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
    ],
)
def test_time_aware_observation_step(env, flatten, normalize_time):
    """Verify the time entry after each step of a wrapped environment.

    For every step the time component (last element when flattened, the
    "time" key otherwise) must match either the elapsed fraction of the
    episode or the number of remaining steps, depending on `normalize_time`.
    """
    env.action_space.seed(SEED)
    max_timesteps = env._max_episode_steps

    wrapped_env = TimeAwareObservationV0(
        env, flatten=flatten, normalize_time=normalize_time
    )
    wrapped_env.reset(seed=SEED)

    for timestep in range(1, NUM_STEPS):
        action = env.action_space.sample()
        observation, _, terminated, _, _ = wrapped_env.step(action)

        if normalize_time:
            expected_time_obs = timestep / max_timesteps
        else:
            expected_time_obs = max_timesteps - timestep

        # Flattened observations carry the time as their final element.
        time_obs = observation[-1] if flatten else observation["time"]
        assert np.allclose(time_obs, expected_time_obs)

        if terminated:
            break
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True),
        gym.make("CarRacing-v2", disable_env_checker=True),
    ],
)
def test_time_aware_observation_creation_flatten(env):
    """Verify the flattened form of TimeAwareObservationV0.

    With `flatten=True` the wrapper must expose a `Box` observation space and
    return NumPy arrays, while the unflattened dict space kept on the wrapper
    still holds the original space under "obs".
    """
    wrapped_env = TimeAwareObservationV0(env, flatten=True)
    first_obs, _ = wrapped_env.reset()
    unflattened_space = wrapped_env.time_aware_observation_space

    assert isinstance(wrapped_env.observation_space, Box)
    assert isinstance(first_obs, np.ndarray)
    assert env.observation_space == unflattened_space["obs"]
|
Reference in New Issue
Block a user