Stateful wrappers [DelayObservation ,TimeAwareObservation, StickyAction] (#165)

This commit is contained in:
Gianluca De Cola
2022-12-02 01:04:34 +01:00
committed by GitHub
parent 678d361e62
commit d0a929c77d
10 changed files with 389 additions and 1 deletions

View File

@@ -19,6 +19,19 @@
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
```
## Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
```
## Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
```
## Common Wrappers
```{eval-rst}

View File

@@ -73,6 +73,10 @@ class MissingArgument(Error):
"""Raised when a required argument in the initializer is missing."""
class InvalidProbability(Error):
"""Raised when given an invalid value for a probability."""
class InvalidBound(Error):
"""Raised when the clipping an array with invalid upper and/or lower bound."""

View File

@@ -1,4 +1,4 @@
"""Root __init__ of the gym dev_wrappers."""
"""Root __init__ of the gym experimental wrappers."""
from gymnasium.experimental.functional import FuncEnv

View File

@@ -12,15 +12,23 @@ from gymnasium.experimental.wrappers.lambda_action import (
)
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
from gymnasium.experimental.wrappers.time_aware_observation import (
TimeAwareObservationV0,
)
from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0
__all__ = [
"ArgType",
# Lambda Action
"LambdaActionV0",
"StickyActionV0",
"ClipActionV0",
"RescaleActionV0",
# Lambda Observation
"LambdaObservationV0",
"DelayObservationV0",
"TimeAwareObservationV0",
# Lambda Reward
"LambdaRewardV0",
"ClipRewardV0",

View File

@@ -0,0 +1,35 @@
"""Wrapper for delaying the returned observation."""
from collections import deque
import jumpy as jp
import gymnasium as gym
from gymnasium.core import ObsType
class DelayObservationV0(gym.ObservationWrapper):
"""Wrapper which adds a delay to the returned observation."""
def __init__(self, env: gym.Env, delay: int):
"""Initialize the DelayObservation wrapper.
Args:
env (Env): the wrapped environment
delay (int): number of timesteps for delaying the observation.
Before reaching the `delay` number of timesteps,
returned observation is an array of zeros with the
same shape of the observation space.
"""
super().__init__(env)
self.delay = delay
self.observation_queue = deque()
def observation(self, observation: ObsType) -> ObsType:
"""Return the delayed observation."""
self.observation_queue.append(observation)
if len(self.observation_queue) > self.delay:
return self.observation_queue.popleft()
return jp.zeros_like(observation)

View File

@@ -0,0 +1,40 @@
"""Wrapper which adds a probability of repeating the previous executed action."""
from typing import Union
import gymnasium as gym
from gymnasium.core import ActType
from gymnasium.error import InvalidProbability
class StickyActionV0(gym.ActionWrapper):
"""Wrapper which adds a probability of repeating the previous action."""
def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a proability of repeating the old action.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)
super().__init__(env)
self.repeat_action_probability = repeat_action_probability
self.old_action = None
def action(self, action: ActType):
"""Execute the action."""
if (
self.old_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
action = self.old_action
self.old_action = action
return action
def reset(self, **kwargs):
"""Reset the environment."""
self.old_action = None
return super().reset(**kwargs)

View File

@@ -0,0 +1,113 @@
"""Wrapper for adding time aware observations to environment observation."""
from collections import OrderedDict
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium.core import ActType, ObsType
from gymnasium.spaces import Box, Dict
class TimeAwareObservationV0(gym.ObservationWrapper):
"""Augment the observation with time information of the episode.
Time can be represented as a normalized value between [0,1]
or by the number of timesteps remaining before truncation occurs.
Example:
>>> import gym
>>> from gym.wrappers import TimeAwareObservationV0
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env)
>>> env.observation_space
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32))
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
OrderedDict([('obs',
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
... ('time', array([0.002]))])
Flatten observation space example:
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32)
>>> _ = env.reset()
>>> env.step(env.action_space.sample())[0]
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32)
"""
def __init__(self, env: gym.Env, flatten=False, normalize_time=True):
"""Initialize :class:`TimeAwareObservationV0`.
Args:
env: The environment to apply the wrapper
flatten: Flatten the observation to a `Box` of a single dimension
normalize_time: if `True` return time in the range [0,1]
otherwise return time as remaining timesteps before truncation
"""
super().__init__(env)
self.flatten = flatten
self.normalize_time = normalize_time
self.max_timesteps = getattr(env, "_max_episode_steps")
if self.normalize_time:
self._get_time_observation = lambda: self.timesteps / self.max_timesteps
time_space = Box(0, 1)
else:
self._get_time_observation = lambda: self.max_timesteps - self.timesteps
time_space = Box(0, self.max_timesteps)
self.time_aware_observation_space = Dict(
obs=env.observation_space, time=time_space
)
if self.flatten:
self.observation_space = spaces.flatten_space(
self.time_aware_observation_space
)
self._observation_postprocess = lambda observation: spaces.flatten(
self.time_aware_observation_space, observation
)
else:
self.observation_space = self.time_aware_observation_space
self._observation_postprocess = lambda observation: observation
def observation(self, observation: ObsType):
"""Adds to the observation with the current time information.
Args:
observation: The observation to add the time step to
Returns:
The observation with the time information appended to
"""
time_observation = self._get_time_observation()
observation = OrderedDict(obs=observation, time=time_observation)
return self._observation_postprocess(observation)
def step(self, action: ActType):
"""Steps through the environment, incrementing the time step.
Args:
action: The action to take
Returns:
The environment's step using the action.
"""
self.timesteps += 1
observation, reward, terminated, truncated, info = super().step(action)
return observation, reward, terminated, truncated, info
def reset(self, **kwargs):
"""Reset the environment setting the time to zero.
Args:
**kwargs: Kwargs to apply to env.reset()
Returns:
The reset environment
"""
self.timesteps = 0
return super().reset(**kwargs)

View File

@@ -0,0 +1,36 @@
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0
SEED = 42
DELAY = 3
NUM_STEPS = 5
def test_delay_observation():
env = gym.make("CartPole-v1")
env.action_space.seed(SEED)
env.reset(seed=SEED)
undelayed_observations = []
for _ in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
undelayed_observations.append(obs)
env.action_space.seed(SEED)
env.reset(seed=SEED)
env = DelayObservationV0(env, delay=DELAY)
delayed_observations = []
for i in range(NUM_STEPS):
obs, _, _, _, _ = env.step(env.action_space.sample())
if i < DELAY - 1:
assert np.all(obs == 0)
delayed_observations.append(obs)
assert np.alltrue(
np.array(delayed_observations[DELAY:])
== np.array(undelayed_observations[: DELAY - 1])
)

View File

@@ -0,0 +1,41 @@
"""Test suite for StickyActionV0."""
import pytest
from gymnasium.error import InvalidProbability
from gymnasium.experimental.wrappers import StickyActionV0
from tests.testing_env import GenericTestEnv
SEED = 42
DELAY = 3
NUM_STEPS = 10
def step_fn(self, action):
return action
def test_sticky_action():
env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5)
env.reset(seed=SEED)
env.action_space.seed(SEED)
previous_action = None
for _ in range(NUM_STEPS):
input_action = env.action_space.sample()
executed_action = env.step(input_action)
if executed_action != input_action:
assert executed_action == previous_action
else:
assert executed_action == input_action
previous_action = input_action
@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
with pytest.raises(InvalidProbability):
StickyActionV0(
GenericTestEnv(), repeat_action_probability=repeat_action_probability
)

View File

@@ -0,0 +1,98 @@
"""Test suite for TimeAwareobservationV0."""
from collections import OrderedDict
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import TimeAwareObservationV0
from gymnasium.spaces import Box, Dict
NUM_STEPS = 20
SEED = 0
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True),
],
)
def test_time_aware_observation_creation(env):
"""Test TimeAwareObservationV0 wrapper creation.
This test checks if wrapped env with TimeAwareObservationV0
is correctly created.
"""
wrapped_env = TimeAwareObservationV0(env)
obs, _ = wrapped_env.reset()
assert isinstance(wrapped_env.observation_space, Dict)
assert isinstance(obs, OrderedDict)
assert np.all(obs["time"] == 0)
assert env.observation_space == wrapped_env.observation_space["obs"]
@pytest.mark.parametrize("normalize_time", [True, False])
@pytest.mark.parametrize("flatten", [False, True])
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
],
)
def test_time_aware_observation_step(env, flatten, normalize_time):
"""Test TimeAwareObservationV0 step.
This test checks if wrapped env with TimeAwareObservationV0
steps correctly.
"""
env.action_space.seed(SEED)
max_timesteps = env._max_episode_steps
wrapped_env = TimeAwareObservationV0(
env, flatten=flatten, normalize_time=normalize_time
)
wrapped_env.reset(seed=SEED)
for timestep in range(1, NUM_STEPS):
action = env.action_space.sample()
observation, _, terminated, _, _ = wrapped_env.step(action)
expected_time_obs = (
timestep / max_timesteps if normalize_time else max_timesteps - timestep
)
if flatten:
assert np.allclose(observation[-1], expected_time_obs)
else:
assert np.allclose(observation["time"], expected_time_obs)
if terminated:
break
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True),
gym.make("CarRacing-v2", disable_env_checker=True),
],
)
def test_time_aware_observation_creation_flatten(env):
"""Test TimeAwareObservationV0 wrapper creation with `flatten=True`.
This test checks if wrapped env with TimeAwareObservationV0
is correctly created when the `flatten` parameter is set to `True`.
When flattened, the observation space should be a 1 dimension `Box`
with time appended to the end.
"""
wrapped_env = TimeAwareObservationV0(env, flatten=True)
obs, _ = wrapped_env.reset()
assert isinstance(wrapped_env.observation_space, Box)
assert isinstance(obs, np.ndarray)
assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]