StickyAction wrapper can repeat the old action for more than 1 step (#1240)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
Simone Parisi
2024-11-14 03:23:06 -07:00
committed by GitHub
parent 90d04f27a7
commit ebe70a1aeb
2 changed files with 103 additions and 23 deletions

View File

@@ -4,9 +4,11 @@ from __future__ import annotations
from typing import Any from typing import Any
import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.core import ActType, ObsType from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidProbability from gymnasium.error import InvalidBound, InvalidProbability
__all__ = ["StickyAction"] __all__ = ["StickyAction"]
@@ -18,7 +20,8 @@ class StickyAction(
"""Adds a probability that the action is repeated for the same ``step`` function. """Adds a probability that the action is repeated for the same ``step`` function.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_ This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
in Section 5.2 on page 12. in Section 5.2 on page 12, and adds the possibility to repeat the action for
more than one step.
No vector version of the wrapper exists. No vector version of the wrapper exists.
@@ -39,20 +42,47 @@ class StickyAction(
Change logs: Change logs:
* v1.0.0 - Initially added * v1.0.0 - Initially added
* v1.1.0 - Add `repeat_action_duration` argument for dynamic number of sticky actions
""" """
def __init__( def __init__(
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float self,
env: gym.Env[ObsType, ActType],
repeat_action_probability: float,
repeat_action_duration: int | tuple[int, int] = 1,
): ):
"""Initialize StickyAction wrapper. """Initialize StickyAction wrapper.
Args: Args:
env (Env): the wrapped environment env (Env): the wrapped environment,
repeat_action_probability (int | float): a probability of repeating the old action. repeat_action_probability (int | float): a probability of repeating the old action,
repeat_action_duration (int | tuple[int, int]): the number of steps
the action is repeated. It can be either an int (for deterministic
repeats) or a tuple[int, int] for a range of stochastic number of repeats.
""" """
if not 0 <= repeat_action_probability < 1: if not 0 <= repeat_action_probability < 1:
raise InvalidProbability( raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}" f"`repeat_action_probability` should be in the interval [0,1). Received {repeat_action_probability}"
)
if isinstance(repeat_action_duration, int):
repeat_action_duration = (repeat_action_duration, repeat_action_duration)
if not isinstance(repeat_action_duration, tuple):
raise ValueError(
f"`repeat_action_duration` should be either an integer or a tuple. Received {repeat_action_duration}"
)
elif len(repeat_action_duration) != 2:
raise ValueError(
f"`repeat_action_duration` should be a tuple or a list of two integers. Received {repeat_action_duration}"
)
elif repeat_action_duration[0] > repeat_action_duration[1]:
raise InvalidBound(
f"`repeat_action_duration` is not a valid bound. Received {repeat_action_duration}"
)
elif np.any(np.array(repeat_action_duration) < 1):
raise ValueError(
f"`repeat_action_duration` should be larger or equal than 1. Received {repeat_action_duration}"
) )
gym.utils.RecordConstructorArgs.__init__( gym.utils.RecordConstructorArgs.__init__(
@@ -61,23 +91,46 @@ class StickyAction(
gym.ActionWrapper.__init__(self, env) gym.ActionWrapper.__init__(self, env)
self.repeat_action_probability = repeat_action_probability self.repeat_action_probability = repeat_action_probability
self.repeat_action_duration_range = repeat_action_duration
self.last_action: ActType | None = None self.last_action: ActType | None = None
self.is_sticky_actions: bool = False # if sticky actions are taken
self.num_repeats: int = 0 # number of sticky action repeats
self.repeats_taken: int = 0 # number of sticky actions taken
def reset( def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]: ) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment.""" """Reset the environment."""
self.last_action = None self.last_action = None
self.is_sticky_actions = False
self.num_repeats = 0
self.repeats_taken = 0
return super().reset(seed=seed, options=options) return super().reset(seed=seed, options=options)
def action(self, action: ActType) -> ActType: def action(self, action: ActType) -> ActType:
"""Execute the action.""" """Execute the action."""
if ( # either the agent was already "stuck" into repeats, or a new series of repeats is triggered
if self.is_sticky_actions or (
self.last_action is not None self.last_action is not None
and self.np_random.uniform() < self.repeat_action_probability and self.np_random.uniform() < self.repeat_action_probability
): ):
# if a new series starts, randomly sample its duration
if self.num_repeats == 0:
self.num_repeats = self.np_random.integers(
self.repeat_action_duration_range[0],
self.repeat_action_duration_range[1] + 1,
)
action = self.last_action action = self.last_action
self.is_sticky_actions = True
self.repeats_taken += 1
# repeats are done, reset "stuck" status
if self.is_sticky_actions and self.num_repeats == self.repeats_taken:
self.is_sticky_actions = False
self.num_repeats = 0
self.repeats_taken = 0
self.last_action = action self.last_action = action
return action return action

View File

@@ -1,36 +1,63 @@
"""Test suite for StickyAction wrapper.""" """Test suite for StickyAction wrapper."""
import numpy as np
import pytest import pytest
from gymnasium.error import InvalidProbability from gymnasium.error import InvalidBound, InvalidProbability
from gymnasium.spaces import Discrete
from gymnasium.wrappers import StickyAction from gymnasium.wrappers import StickyAction
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step from tests.wrappers.utils import record_action_as_obs_step
def test_sticky_action(): @pytest.mark.parametrize(
"repeat_action_probability,repeat_action_duration,actions,expected_action",
[
(0.25, 1, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 2, 3, 3, 3, 6, 6]),
(0.25, 2, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 3, 4, 4, 4, 4]),
(0.25, (1, 3), [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 0, 4, 4, 4, 4]),
],
)
def test_sticky_action(
repeat_action_probability, repeat_action_duration, actions, expected_action
):
"""Tests the sticky action wrapper.""" """Tests the sticky action wrapper."""
env = StickyAction( env = StickyAction(
GenericTestEnv(step_func=record_action_as_obs_step), GenericTestEnv(
repeat_action_probability=0.5, step_func=record_action_as_obs_step, observation_space=Discrete(7)
),
repeat_action_probability=repeat_action_probability,
repeat_action_duration=repeat_action_duration,
) )
env.reset(seed=11)
previous_action = None assert len(actions) == len(expected_action)
for _ in range(NUM_STEPS): for action, action_taken in zip(actions, expected_action):
input_action = env.action_space.sample() executed_action, _, _, _, _ = env.step(action)
executed_action, _, _, _, _ = env.step(input_action) assert executed_action == action_taken
assert np.all(executed_action == input_action) or np.all(
executed_action == previous_action
)
previous_action = executed_action
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5]) @pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability): def test_sticky_action_raise_probability(repeat_action_probability):
"""Tests the stick action wrapper with probabilities that should raise an error.""" """Tests the stick action wrapper with probabilities that should raise an error."""
with pytest.raises(InvalidProbability): with pytest.raises(InvalidProbability):
StickyAction( StickyAction(
GenericTestEnv(), repeat_action_probability=repeat_action_probability GenericTestEnv(), repeat_action_probability=repeat_action_probability
) )
@pytest.mark.parametrize(
"repeat_action_duration",
[
-4,
0,
(0, 0),
(4, 2),
[1, 2],
],
)
def test_sticky_action_raise_duration(repeat_action_duration):
"""Tests the stick action wrapper with durations that should raise an error."""
with pytest.raises((ValueError, InvalidBound)):
StickyAction(
GenericTestEnv(), 0.5, repeat_action_duration=repeat_action_duration
)