mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Stateful wrappers [DelayObservation ,TimeAwareObservation, StickyAction] (#165)
This commit is contained in:
@@ -19,6 +19,19 @@
|
|||||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Observation Wrappers
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Action Wrappers
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
||||||
|
```
|
||||||
|
|
||||||
## Common Wrappers
|
## Common Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
|
@@ -73,6 +73,10 @@ class MissingArgument(Error):
|
|||||||
"""Raised when a required argument in the initializer is missing."""
|
"""Raised when a required argument in the initializer is missing."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidProbability(Error):
|
||||||
|
"""Raised when given an invalid value for a probability."""
|
||||||
|
|
||||||
|
|
||||||
class InvalidBound(Error):
|
class InvalidBound(Error):
|
||||||
"""Raised when the clipping an array with invalid upper and/or lower bound."""
|
"""Raised when the clipping an array with invalid upper and/or lower bound."""
|
||||||
|
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
"""Root __init__ of the gym dev_wrappers."""
|
"""Root __init__ of the gym experimental wrappers."""
|
||||||
|
|
||||||
|
|
||||||
from gymnasium.experimental.functional import FuncEnv
|
from gymnasium.experimental.functional import FuncEnv
|
||||||
|
@@ -12,15 +12,23 @@ from gymnasium.experimental.wrappers.lambda_action import (
|
|||||||
)
|
)
|
||||||
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
||||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||||
|
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
|
||||||
|
from gymnasium.experimental.wrappers.time_aware_observation import (
|
||||||
|
TimeAwareObservationV0,
|
||||||
|
)
|
||||||
|
from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ArgType",
|
"ArgType",
|
||||||
# Lambda Action
|
# Lambda Action
|
||||||
"LambdaActionV0",
|
"LambdaActionV0",
|
||||||
|
"StickyActionV0",
|
||||||
"ClipActionV0",
|
"ClipActionV0",
|
||||||
"RescaleActionV0",
|
"RescaleActionV0",
|
||||||
# Lambda Observation
|
# Lambda Observation
|
||||||
"LambdaObservationV0",
|
"LambdaObservationV0",
|
||||||
|
"DelayObservationV0",
|
||||||
|
"TimeAwareObservationV0",
|
||||||
# Lambda Reward
|
# Lambda Reward
|
||||||
"LambdaRewardV0",
|
"LambdaRewardV0",
|
||||||
"ClipRewardV0",
|
"ClipRewardV0",
|
||||||
|
35
gymnasium/experimental/wrappers/delay_observation.py
Normal file
35
gymnasium/experimental/wrappers/delay_observation.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Wrapper for delaying the returned observation."""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import jumpy as jp
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import ObsType
|
||||||
|
|
||||||
|
|
||||||
|
class DelayObservationV0(gym.ObservationWrapper):
|
||||||
|
"""Wrapper which adds a delay to the returned observation."""
|
||||||
|
|
||||||
|
def __init__(self, env: gym.Env, delay: int):
|
||||||
|
"""Initialize the DelayObservation wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env (Env): the wrapped environment
|
||||||
|
delay (int): number of timesteps for delaying the observation.
|
||||||
|
Before reaching the `delay` number of timesteps,
|
||||||
|
returned observation is an array of zeros with the
|
||||||
|
same shape of the observation space.
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
self.delay = delay
|
||||||
|
self.observation_queue = deque()
|
||||||
|
|
||||||
|
def observation(self, observation: ObsType) -> ObsType:
|
||||||
|
"""Return the delayed observation."""
|
||||||
|
self.observation_queue.append(observation)
|
||||||
|
|
||||||
|
if len(self.observation_queue) > self.delay:
|
||||||
|
return self.observation_queue.popleft()
|
||||||
|
|
||||||
|
return jp.zeros_like(observation)
|
40
gymnasium/experimental/wrappers/sticky_action.py
Normal file
40
gymnasium/experimental/wrappers/sticky_action.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Wrapper which adds a probability of repeating the previous executed action."""
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import ActType
|
||||||
|
from gymnasium.error import InvalidProbability
|
||||||
|
|
||||||
|
|
||||||
|
class StickyActionV0(gym.ActionWrapper):
|
||||||
|
"""Wrapper which adds a probability of repeating the previous action."""
|
||||||
|
|
||||||
|
def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]):
|
||||||
|
"""Initialize StickyAction wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env (Env): the wrapped environment
|
||||||
|
repeat_action_probability (int | float): a proability of repeating the old action.
|
||||||
|
"""
|
||||||
|
if not 0 <= repeat_action_probability < 1:
|
||||||
|
raise InvalidProbability(
|
||||||
|
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
||||||
|
)
|
||||||
|
super().__init__(env)
|
||||||
|
self.repeat_action_probability = repeat_action_probability
|
||||||
|
self.old_action = None
|
||||||
|
|
||||||
|
def action(self, action: ActType):
|
||||||
|
"""Execute the action."""
|
||||||
|
if (
|
||||||
|
self.old_action is not None
|
||||||
|
and self.np_random.uniform() < self.repeat_action_probability
|
||||||
|
):
|
||||||
|
action = self.old_action
|
||||||
|
self.old_action = action
|
||||||
|
return action
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
"""Reset the environment."""
|
||||||
|
self.old_action = None
|
||||||
|
return super().reset(**kwargs)
|
113
gymnasium/experimental/wrappers/time_aware_observation.py
Normal file
113
gymnasium/experimental/wrappers/time_aware_observation.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Wrapper for adding time aware observations to environment observation."""
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import gymnasium.spaces as spaces
|
||||||
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.spaces import Box, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class TimeAwareObservationV0(gym.ObservationWrapper):
|
||||||
|
"""Augment the observation with time information of the episode.
|
||||||
|
|
||||||
|
Time can be represented as a normalized value between [0,1]
|
||||||
|
or by the number of timesteps remaining before truncation occurs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> import gym
|
||||||
|
>>> from gym.wrappers import TimeAwareObservationV0
|
||||||
|
>>> env = gym.make('CartPole-v1')
|
||||||
|
>>> env = TimeAwareObservationV0(env)
|
||||||
|
>>> env.observation_space
|
||||||
|
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
|
||||||
|
>>> _ = env.reset()
|
||||||
|
>>> env.step(env.action_space.sample())[0]
|
||||||
|
OrderedDict([('obs',
|
||||||
|
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
|
||||||
|
... ('time', array([0.002]))])
|
||||||
|
|
||||||
|
Flatten observation space example:
|
||||||
|
>>> env = gym.make('CartPole-v1')
|
||||||
|
>>> env = TimeAwareObservationV0(env, flatten=True)
|
||||||
|
>>> env.observation_space
|
||||||
|
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
|
||||||
|
>>> _ = env.reset()
|
||||||
|
>>> env.step(env.action_space.sample())[0]
|
||||||
|
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: gym.Env, flatten=False, normalize_time=True):
|
||||||
|
"""Initialize :class:`TimeAwareObservationV0`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to apply the wrapper
|
||||||
|
flatten: Flatten the observation to a `Box` of a single dimension
|
||||||
|
normalize_time: if `True` return time in the range [0,1]
|
||||||
|
otherwise return time as remaining timesteps before truncation
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
self.flatten = flatten
|
||||||
|
self.normalize_time = normalize_time
|
||||||
|
self.max_timesteps = getattr(env, "_max_episode_steps")
|
||||||
|
|
||||||
|
if self.normalize_time:
|
||||||
|
self._get_time_observation = lambda: self.timesteps / self.max_timesteps
|
||||||
|
time_space = Box(0, 1)
|
||||||
|
else:
|
||||||
|
self._get_time_observation = lambda: self.max_timesteps - self.timesteps
|
||||||
|
time_space = Box(0, self.max_timesteps)
|
||||||
|
|
||||||
|
self.time_aware_observation_space = Dict(
|
||||||
|
obs=env.observation_space, time=time_space
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.flatten:
|
||||||
|
self.observation_space = spaces.flatten_space(
|
||||||
|
self.time_aware_observation_space
|
||||||
|
)
|
||||||
|
self._observation_postprocess = lambda observation: spaces.flatten(
|
||||||
|
self.time_aware_observation_space, observation
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.observation_space = self.time_aware_observation_space
|
||||||
|
self._observation_postprocess = lambda observation: observation
|
||||||
|
|
||||||
|
def observation(self, observation: ObsType):
|
||||||
|
"""Adds to the observation with the current time information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation: The observation to add the time step to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The observation with the time information appended to
|
||||||
|
"""
|
||||||
|
time_observation = self._get_time_observation()
|
||||||
|
observation = OrderedDict(obs=observation, time=time_observation)
|
||||||
|
|
||||||
|
return self._observation_postprocess(observation)
|
||||||
|
|
||||||
|
def step(self, action: ActType):
|
||||||
|
"""Steps through the environment, incrementing the time step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action to take
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The environment's step using the action.
|
||||||
|
"""
|
||||||
|
self.timesteps += 1
|
||||||
|
observation, reward, terminated, truncated, info = super().step(action)
|
||||||
|
|
||||||
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
"""Reset the environment setting the time to zero.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Kwargs to apply to env.reset()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The reset environment
|
||||||
|
"""
|
||||||
|
self.timesteps = 0
|
||||||
|
return super().reset(**kwargs)
|
36
tests/experimental/wrappers/test_delay_observation.py
Normal file
36
tests/experimental/wrappers/test_delay_observation.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.experimental.wrappers import DelayObservationV0
|
||||||
|
|
||||||
|
SEED = 42
|
||||||
|
|
||||||
|
DELAY = 3
|
||||||
|
NUM_STEPS = 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_delay_observation():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env.action_space.seed(SEED)
|
||||||
|
env.reset(seed=SEED)
|
||||||
|
|
||||||
|
undelayed_observations = []
|
||||||
|
for _ in range(NUM_STEPS):
|
||||||
|
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||||
|
undelayed_observations.append(obs)
|
||||||
|
|
||||||
|
env.action_space.seed(SEED)
|
||||||
|
env.reset(seed=SEED)
|
||||||
|
env = DelayObservationV0(env, delay=DELAY)
|
||||||
|
|
||||||
|
delayed_observations = []
|
||||||
|
for i in range(NUM_STEPS):
|
||||||
|
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||||
|
if i < DELAY - 1:
|
||||||
|
assert np.all(obs == 0)
|
||||||
|
delayed_observations.append(obs)
|
||||||
|
|
||||||
|
assert np.alltrue(
|
||||||
|
np.array(delayed_observations[DELAY:])
|
||||||
|
== np.array(undelayed_observations[: DELAY - 1])
|
||||||
|
)
|
41
tests/experimental/wrappers/test_sticky_action.py
Normal file
41
tests/experimental/wrappers/test_sticky_action.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Test suite for StickyActionV0."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gymnasium.error import InvalidProbability
|
||||||
|
from gymnasium.experimental.wrappers import StickyActionV0
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
SEED = 42
|
||||||
|
|
||||||
|
DELAY = 3
|
||||||
|
NUM_STEPS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def step_fn(self, action):
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
def test_sticky_action():
|
||||||
|
env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5)
|
||||||
|
env.reset(seed=SEED)
|
||||||
|
env.action_space.seed(SEED)
|
||||||
|
|
||||||
|
previous_action = None
|
||||||
|
for _ in range(NUM_STEPS):
|
||||||
|
input_action = env.action_space.sample()
|
||||||
|
executed_action = env.step(input_action)
|
||||||
|
|
||||||
|
if executed_action != input_action:
|
||||||
|
assert executed_action == previous_action
|
||||||
|
else:
|
||||||
|
assert executed_action == input_action
|
||||||
|
|
||||||
|
previous_action = input_action
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5])
|
||||||
|
def test_sticky_action_raise(repeat_action_probability):
|
||||||
|
with pytest.raises(InvalidProbability):
|
||||||
|
StickyActionV0(
|
||||||
|
GenericTestEnv(), repeat_action_probability=repeat_action_probability
|
||||||
|
)
|
98
tests/experimental/wrappers/test_time_aware_observation.py
Normal file
98
tests/experimental/wrappers/test_time_aware_observation.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Test suite for TimeAwareobservationV0."""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
||||||
|
from gymnasium.spaces import Box, Dict
|
||||||
|
|
||||||
|
NUM_STEPS = 20
|
||||||
|
SEED = 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env",
|
||||||
|
[
|
||||||
|
gym.make("CartPole-v1", disable_env_checker=True),
|
||||||
|
gym.make("CarRacing-v2", disable_env_checker=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_time_aware_observation_creation(env):
|
||||||
|
"""Test TimeAwareObservationV0 wrapper creation.
|
||||||
|
|
||||||
|
This test checks if wrapped env with TimeAwareObservationV0
|
||||||
|
is correctly created.
|
||||||
|
"""
|
||||||
|
wrapped_env = TimeAwareObservationV0(env)
|
||||||
|
obs, _ = wrapped_env.reset()
|
||||||
|
|
||||||
|
assert isinstance(wrapped_env.observation_space, Dict)
|
||||||
|
assert isinstance(obs, OrderedDict)
|
||||||
|
assert np.all(obs["time"] == 0)
|
||||||
|
assert env.observation_space == wrapped_env.observation_space["obs"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("normalize_time", [True, False])
|
||||||
|
@pytest.mark.parametrize("flatten", [False, True])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env",
|
||||||
|
[
|
||||||
|
gym.make("CartPole-v1", disable_env_checker=True),
|
||||||
|
gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_time_aware_observation_step(env, flatten, normalize_time):
|
||||||
|
"""Test TimeAwareObservationV0 step.
|
||||||
|
|
||||||
|
This test checks if wrapped env with TimeAwareObservationV0
|
||||||
|
steps correctly.
|
||||||
|
"""
|
||||||
|
env.action_space.seed(SEED)
|
||||||
|
max_timesteps = env._max_episode_steps
|
||||||
|
|
||||||
|
wrapped_env = TimeAwareObservationV0(
|
||||||
|
env, flatten=flatten, normalize_time=normalize_time
|
||||||
|
)
|
||||||
|
wrapped_env.reset(seed=SEED)
|
||||||
|
|
||||||
|
for timestep in range(1, NUM_STEPS):
|
||||||
|
action = env.action_space.sample()
|
||||||
|
observation, _, terminated, _, _ = wrapped_env.step(action)
|
||||||
|
|
||||||
|
expected_time_obs = (
|
||||||
|
timestep / max_timesteps if normalize_time else max_timesteps - timestep
|
||||||
|
)
|
||||||
|
|
||||||
|
if flatten:
|
||||||
|
assert np.allclose(observation[-1], expected_time_obs)
|
||||||
|
else:
|
||||||
|
assert np.allclose(observation["time"], expected_time_obs)
|
||||||
|
|
||||||
|
if terminated:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env",
|
||||||
|
[
|
||||||
|
gym.make("CartPole-v1", disable_env_checker=True),
|
||||||
|
gym.make("CarRacing-v2", disable_env_checker=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_time_aware_observation_creation_flatten(env):
|
||||||
|
"""Test TimeAwareObservationV0 wrapper creation with `flatten=True`.
|
||||||
|
|
||||||
|
This test checks if wrapped env with TimeAwareObservationV0
|
||||||
|
is correctly created when the `flatten` parameter is set to `True`.
|
||||||
|
When flattened, the observation space should be a 1 dimension `Box`
|
||||||
|
with time appended to the end.
|
||||||
|
"""
|
||||||
|
wrapped_env = TimeAwareObservationV0(env, flatten=True)
|
||||||
|
obs, _ = wrapped_env.reset()
|
||||||
|
|
||||||
|
assert isinstance(wrapped_env.observation_space, Box)
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]
|
Reference in New Issue
Block a user