From d0a929c77d7c0c6fb4f3e7ed9c02a932c6905242 Mon Sep 17 00:00:00 2001 From: Gianluca De Cola <42657588+gianlucadecola@users.noreply.github.com> Date: Fri, 2 Dec 2022 01:04:34 +0100 Subject: [PATCH] Stateful wrappers [DelayObservation ,TimeAwareObservation, StickyAction] (#165) --- docs/api/experimental/wrappers.md | 13 ++ gymnasium/error.py | 4 + gymnasium/experimental/__init__.py | 2 +- gymnasium/experimental/wrappers/__init__.py | 8 ++ .../wrappers/delay_observation.py | 35 ++++++ .../experimental/wrappers/sticky_action.py | 40 +++++++ .../wrappers/time_aware_observation.py | 113 ++++++++++++++++++ .../wrappers/test_delay_observation.py | 36 ++++++ .../wrappers/test_sticky_action.py | 41 +++++++ .../wrappers/test_time_aware_observation.py | 98 +++++++++++++++ 10 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 gymnasium/experimental/wrappers/delay_observation.py create mode 100644 gymnasium/experimental/wrappers/sticky_action.py create mode 100644 gymnasium/experimental/wrappers/time_aware_observation.py create mode 100644 tests/experimental/wrappers/test_delay_observation.py create mode 100644 tests/experimental/wrappers/test_sticky_action.py create mode 100644 tests/experimental/wrappers/test_time_aware_observation.py diff --git a/docs/api/experimental/wrappers.md b/docs/api/experimental/wrappers.md index 6ae00794e..b41bfe4ad 100644 --- a/docs/api/experimental/wrappers.md +++ b/docs/api/experimental/wrappers.md @@ -19,6 +19,19 @@ .. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 ``` +## Observation Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0 +.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0 +``` + +## Action Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0 +``` + ## Common Wrappers ```{eval-rst} diff --git a/gymnasium/error.py b/gymnasium/error.py index 23a0e105a..424ebc2c3 100644 --- a/gymnasium/error.py +++ b/gymnasium/error.py @@ -73,6 +73,10 @@ class MissingArgument(Error): """Raised when a required argument in the initializer is missing.""" +class InvalidProbability(Error): + """Raised when given an invalid value for a probability.""" + + class InvalidBound(Error): """Raised when the clipping an array with invalid upper and/or lower bound.""" diff --git a/gymnasium/experimental/__init__.py b/gymnasium/experimental/__init__.py index 70ba0a3cd..8864b92b5 100644 --- a/gymnasium/experimental/__init__.py +++ b/gymnasium/experimental/__init__.py @@ -1,4 +1,4 @@ -"""Root __init__ of the gym dev_wrappers.""" +"""Root __init__ of the gym experimental wrappers.""" from gymnasium.experimental.functional import FuncEnv diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 27e8a8156..06ff71f66 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -12,15 +12,23 @@ from gymnasium.experimental.wrappers.lambda_action import ( ) from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0 from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 +from gymnasium.experimental.wrappers.sticky_action import StickyActionV0 +from gymnasium.experimental.wrappers.time_aware_observation import ( + TimeAwareObservationV0, +) +from gymnasium.experimental.wrappers.delay_observation import DelayObservationV0 __all__ = [ "ArgType", # Lambda Action "LambdaActionV0", + "StickyActionV0", "ClipActionV0", "RescaleActionV0", # Lambda Observation "LambdaObservationV0", + "DelayObservationV0", + "TimeAwareObservationV0", # Lambda Reward "LambdaRewardV0", "ClipRewardV0", diff --git a/gymnasium/experimental/wrappers/delay_observation.py b/gymnasium/experimental/wrappers/delay_observation.py new file mode 100644 index 000000000..2903d1154 --- /dev/null +++ b/gymnasium/experimental/wrappers/delay_observation.py @@ -0,0 +1,35 @@ +"""Wrapper for delaying the returned observation.""" + +from collections import deque + +import jumpy as jp + +import gymnasium as gym +from gymnasium.core import ObsType + + +class DelayObservationV0(gym.ObservationWrapper): + """Wrapper which adds a delay to the returned observation.""" + + def __init__(self, env: gym.Env, delay: int): + """Initialize the DelayObservation wrapper. + + Args: + env (Env): the wrapped environment + delay (int): number of timesteps for delaying the observation. + Before reaching the `delay` number of timesteps, + returned observation is an array of zeros with the + same shape of the observation space. + """ + super().__init__(env) + self.delay = delay + self.observation_queue = deque() + + def observation(self, observation: ObsType) -> ObsType: + """Return the delayed observation.""" + self.observation_queue.append(observation) + + if len(self.observation_queue) > self.delay: + return self.observation_queue.popleft() + + return jp.zeros_like(observation) diff --git a/gymnasium/experimental/wrappers/sticky_action.py b/gymnasium/experimental/wrappers/sticky_action.py new file mode 100644 index 000000000..586e2ff1b --- /dev/null +++ b/gymnasium/experimental/wrappers/sticky_action.py @@ -0,0 +1,40 @@ +"""Wrapper which adds a probability of repeating the previous executed action.""" +from typing import Union + +import gymnasium as gym +from gymnasium.core import ActType +from gymnasium.error import InvalidProbability + + +class StickyActionV0(gym.ActionWrapper): + """Wrapper which adds a probability of repeating the previous action.""" + + def __init__(self, env: gym.Env, repeat_action_probability: Union[int, float]): + """Initialize StickyAction wrapper. + + Args: + env (Env): the wrapped environment + repeat_action_probability (int | float): a proability of repeating the old action. + """ + if not 0 <= repeat_action_probability < 1: + raise InvalidProbability( + f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}" + ) + super().__init__(env) + self.repeat_action_probability = repeat_action_probability + self.old_action = None + + def action(self, action: ActType): + """Execute the action.""" + if ( + self.old_action is not None + and self.np_random.uniform() < self.repeat_action_probability + ): + action = self.old_action + self.old_action = action + return action + + def reset(self, **kwargs): + """Reset the environment.""" + self.old_action = None + return super().reset(**kwargs) diff --git a/gymnasium/experimental/wrappers/time_aware_observation.py b/gymnasium/experimental/wrappers/time_aware_observation.py new file mode 100644 index 000000000..021d55468 --- /dev/null +++ b/gymnasium/experimental/wrappers/time_aware_observation.py @@ -0,0 +1,113 @@ +"""Wrapper for adding time aware observations to environment observation.""" +from collections import OrderedDict + +import gymnasium as gym +import gymnasium.spaces as spaces +from gymnasium.core import ActType, ObsType +from gymnasium.spaces import Box, Dict + + +class TimeAwareObservationV0(gym.ObservationWrapper): + """Augment the observation with time information of the episode. + + Time can be represented as a normalized value between [0,1] + or by the number of timesteps remaining before truncation occurs. + + Example: + >>> import gym + >>> from gym.wrappers import TimeAwareObservationV0 + >>> env = gym.make('CartPole-v1') + >>> env = TimeAwareObservationV0(env) + >>> env.observation_space + Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32)) + >>> _ = env.reset() + >>> env.step(env.action_space.sample())[0] + OrderedDict([('obs', + ... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)), + ... ('time', array([0.002]))]) + + Flatten observation space example: + >>> env = gym.make('CartPole-v1') + >>> env = TimeAwareObservationV0(env, flatten=True) + >>> env.observation_space + Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32) + >>> _ = env.reset() + >>> env.step(env.action_space.sample())[0] + array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32) + """ + + def __init__(self, env: gym.Env, flatten=False, normalize_time=True): + """Initialize :class:`TimeAwareObservationV0`. + + Args: + env: The environment to apply the wrapper + flatten: Flatten the observation to a `Box` of a single dimension + normalize_time: if `True` return time in the range [0,1] + otherwise return time as remaining timesteps before truncation + """ + super().__init__(env) + self.flatten = flatten + self.normalize_time = normalize_time + self.max_timesteps = getattr(env, "_max_episode_steps") + + if self.normalize_time: + self._get_time_observation = lambda: self.timesteps / self.max_timesteps + time_space = Box(0, 1) + else: + self._get_time_observation = lambda: self.max_timesteps - self.timesteps + time_space = Box(0, self.max_timesteps) + + self.time_aware_observation_space = Dict( + obs=env.observation_space, time=time_space + ) + + if self.flatten: + self.observation_space = spaces.flatten_space( + self.time_aware_observation_space + ) + self._observation_postprocess = lambda observation: spaces.flatten( + self.time_aware_observation_space, observation + ) + else: + self.observation_space = self.time_aware_observation_space + self._observation_postprocess = lambda observation: observation + + def observation(self, observation: ObsType): + """Adds to the observation with the current time information. + + Args: + observation: The observation to add the time step to + + Returns: + The observation with the time information appended to + """ + time_observation = self._get_time_observation() + observation = OrderedDict(obs=observation, time=time_observation) + + return self._observation_postprocess(observation) + + def step(self, action: ActType): + """Steps through the environment, incrementing the time step. + + Args: + action: The action to take + + Returns: + The environment's step using the action. + """ + self.timesteps += 1 + observation, reward, terminated, truncated, info = super().step(action) + + return observation, reward, terminated, truncated, info + + def reset(self, **kwargs): + """Reset the environment setting the time to zero. + + Args: + **kwargs: Kwargs to apply to env.reset() + + Returns: + The reset environment + """ + self.timesteps = 0 + return super().reset(**kwargs) diff --git a/tests/experimental/wrappers/test_delay_observation.py b/tests/experimental/wrappers/test_delay_observation.py new file mode 100644 index 000000000..bd1a40ef0 --- /dev/null +++ b/tests/experimental/wrappers/test_delay_observation.py @@ -0,0 +1,36 @@ +import numpy as np + +import gymnasium as gym +from gymnasium.experimental.wrappers import DelayObservationV0 + +SEED = 42 + +DELAY = 3 +NUM_STEPS = 5 + + +def test_delay_observation(): + env = gym.make("CartPole-v1") + env.action_space.seed(SEED) + env.reset(seed=SEED) + + undelayed_observations = [] + for _ in range(NUM_STEPS): + obs, _, _, _, _ = env.step(env.action_space.sample()) + undelayed_observations.append(obs) + + env.action_space.seed(SEED) + env.reset(seed=SEED) + env = DelayObservationV0(env, delay=DELAY) + + delayed_observations = [] + for i in range(NUM_STEPS): + obs, _, _, _, _ = env.step(env.action_space.sample()) + if i < DELAY - 1: + assert np.all(obs == 0) + delayed_observations.append(obs) + + assert np.alltrue( + np.array(delayed_observations[DELAY:]) + == np.array(undelayed_observations[: DELAY - 1]) + ) diff --git a/tests/experimental/wrappers/test_sticky_action.py b/tests/experimental/wrappers/test_sticky_action.py new file mode 100644 index 000000000..c105a948f --- /dev/null +++ b/tests/experimental/wrappers/test_sticky_action.py @@ -0,0 +1,41 @@ +"""Test suite for StickyActionV0.""" +import pytest + +from gymnasium.error import InvalidProbability +from gymnasium.experimental.wrappers import StickyActionV0 +from tests.testing_env import GenericTestEnv + +SEED = 42 + +DELAY = 3 +NUM_STEPS = 10 + + +def step_fn(self, action): + return action + + +def test_sticky_action(): + env = StickyActionV0(GenericTestEnv(step_fn=step_fn), repeat_action_probability=0.5) + env.reset(seed=SEED) + env.action_space.seed(SEED) + + previous_action = None + for _ in range(NUM_STEPS): + input_action = env.action_space.sample() + executed_action = env.step(input_action) + + if executed_action != input_action: + assert executed_action == previous_action + else: + assert executed_action == input_action + + previous_action = input_action + + +@pytest.mark.parametrize(("repeat_action_probability"), [-1, 1, 1.5]) +def test_sticky_action_raise(repeat_action_probability): + with pytest.raises(InvalidProbability): + StickyActionV0( + GenericTestEnv(), repeat_action_probability=repeat_action_probability + ) diff --git a/tests/experimental/wrappers/test_time_aware_observation.py b/tests/experimental/wrappers/test_time_aware_observation.py new file mode 100644 index 000000000..55331c978 --- /dev/null +++ b/tests/experimental/wrappers/test_time_aware_observation.py @@ -0,0 +1,98 @@ +"""Test suite for TimeAwareobservationV0.""" + +from collections import OrderedDict + +import numpy as np +import pytest + +import gymnasium as gym +from gymnasium.experimental.wrappers import TimeAwareObservationV0 +from gymnasium.spaces import Box, Dict + +NUM_STEPS = 20 +SEED = 0 + + +@pytest.mark.parametrize( + "env", + [ + gym.make("CartPole-v1", disable_env_checker=True), + gym.make("CarRacing-v2", disable_env_checker=True), + ], +) +def test_time_aware_observation_creation(env): + """Test TimeAwareObservationV0 wrapper creation. + + This test checks if wrapped env with TimeAwareObservationV0 + is correctly created. + """ + wrapped_env = TimeAwareObservationV0(env) + obs, _ = wrapped_env.reset() + + assert isinstance(wrapped_env.observation_space, Dict) + assert isinstance(obs, OrderedDict) + assert np.all(obs["time"] == 0) + assert env.observation_space == wrapped_env.observation_space["obs"] + + +@pytest.mark.parametrize("normalize_time", [True, False]) +@pytest.mark.parametrize("flatten", [False, True]) +@pytest.mark.parametrize( + "env", + [ + gym.make("CartPole-v1", disable_env_checker=True), + gym.make("CarRacing-v2", disable_env_checker=True, continuous=False), + ], +) +def test_time_aware_observation_step(env, flatten, normalize_time): + """Test TimeAwareObservationV0 step. + + This test checks if wrapped env with TimeAwareObservationV0 + steps correctly. + """ + env.action_space.seed(SEED) + max_timesteps = env._max_episode_steps + + wrapped_env = TimeAwareObservationV0( + env, flatten=flatten, normalize_time=normalize_time + ) + wrapped_env.reset(seed=SEED) + + for timestep in range(1, NUM_STEPS): + action = env.action_space.sample() + observation, _, terminated, _, _ = wrapped_env.step(action) + + expected_time_obs = ( + timestep / max_timesteps if normalize_time else max_timesteps - timestep + ) + + if flatten: + assert np.allclose(observation[-1], expected_time_obs) + else: + assert np.allclose(observation["time"], expected_time_obs) + + if terminated: + break + + +@pytest.mark.parametrize( + "env", + [ + gym.make("CartPole-v1", disable_env_checker=True), + gym.make("CarRacing-v2", disable_env_checker=True), + ], +) +def test_time_aware_observation_creation_flatten(env): + """Test TimeAwareObservationV0 wrapper creation with `flatten=True`. + + This test checks if wrapped env with TimeAwareObservationV0 + is correctly created when the `flatten` parameter is set to `True`. + When flattened, the observation space should be a 1 dimension `Box` + with time appended to the end. + """ + wrapped_env = TimeAwareObservationV0(env, flatten=True) + obs, _ = wrapped_env.reset() + + assert isinstance(wrapped_env.observation_space, Box) + assert isinstance(obs, np.ndarray) + assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]