StickyAction wrapper can repeat the old action for more than 1 step (#1240)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
Simone Parisi
2024-11-14 03:23:06 -07:00
committed by GitHub
parent 90d04f27a7
commit ebe70a1aeb
2 changed files with 103 additions and 23 deletions

View File

@@ -4,9 +4,11 @@ from __future__ import annotations
from typing import Any
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidProbability
from gymnasium.error import InvalidBound, InvalidProbability
__all__ = ["StickyAction"]
@@ -18,7 +20,8 @@ class StickyAction(
"""Adds a probability that the action is repeated for the same ``step`` function.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
in Section 5.2 on page 12.
in Section 5.2 on page 12, and adds the possibility to repeat the action for
more than one step.
No vector version of the wrapper exists.
@@ -39,20 +42,47 @@ class StickyAction(
Change logs:
* v1.0.0 - Initially added
* v1.1.0 - Add `repeat_action_duration` argument for dynamic number of sticky actions
"""
def __init__(
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float
self,
env: gym.Env[ObsType, ActType],
repeat_action_probability: float,
repeat_action_duration: int | tuple[int, int] = 1,
):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a probability of repeating the old action.
env (Env): the wrapped environment,
repeat_action_probability (int | float): a probability of repeating the old action,
repeat_action_duration (int | tuple[int, int]): the number of steps
the action is repeated. It can be either an int (for deterministic
repeats) or a tuple[int, int] for a range of stochastic number of repeats.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
f"`repeat_action_probability` should be in the interval [0,1). Received {repeat_action_probability}"
)
if isinstance(repeat_action_duration, int):
repeat_action_duration = (repeat_action_duration, repeat_action_duration)
if not isinstance(repeat_action_duration, tuple):
raise ValueError(
f"`repeat_action_duration` should be either an integer or a tuple. Received {repeat_action_duration}"
)
elif len(repeat_action_duration) != 2:
raise ValueError(
f"`repeat_action_duration` should be a tuple or a list of two integers. Received {repeat_action_duration}"
)
elif repeat_action_duration[0] > repeat_action_duration[1]:
raise InvalidBound(
f"`repeat_action_duration` is not a valid bound. Received {repeat_action_duration}"
)
elif np.any(np.array(repeat_action_duration) < 1):
raise ValueError(
f"`repeat_action_duration` should be larger or equal than 1. Received {repeat_action_duration}"
)
gym.utils.RecordConstructorArgs.__init__(
@@ -61,23 +91,46 @@ class StickyAction(
gym.ActionWrapper.__init__(self, env)
self.repeat_action_probability = repeat_action_probability
self.repeat_action_duration_range = repeat_action_duration
self.last_action: ActType | None = None
self.is_sticky_actions: bool = False # if sticky actions are taken
self.num_repeats: int = 0 # number of sticky action repeats
self.repeats_taken: int = 0 # number of sticky actions taken
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment."""
self.last_action = None
self.is_sticky_actions = False
self.num_repeats = 0
self.repeats_taken = 0
return super().reset(seed=seed, options=options)
def action(self, action: ActType) -> ActType:
"""Execute the action."""
if (
# either the agent was already "stuck" into repeats, or a new series of repeats is triggered
if self.is_sticky_actions or (
self.last_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
# if a new series starts, randomly sample its duration
if self.num_repeats == 0:
self.num_repeats = self.np_random.integers(
self.repeat_action_duration_range[0],
self.repeat_action_duration_range[1] + 1,
)
action = self.last_action
self.is_sticky_actions = True
self.repeats_taken += 1
# repeats are done, reset "stuck" status
if self.is_sticky_actions and self.num_repeats == self.repeats_taken:
self.is_sticky_actions = False
self.num_repeats = 0
self.repeats_taken = 0
self.last_action = action
return action

View File

@@ -1,36 +1,63 @@
"""Test suite for StickyAction wrapper."""
import numpy as np
import pytest
from gymnasium.error import InvalidProbability
from gymnasium.error import InvalidBound, InvalidProbability
from gymnasium.spaces import Discrete
from gymnasium.wrappers import StickyAction
from tests.testing_env import GenericTestEnv
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step
from tests.wrappers.utils import record_action_as_obs_step
def test_sticky_action():
@pytest.mark.parametrize(
"repeat_action_probability,repeat_action_duration,actions,expected_action",
[
(0.25, 1, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 2, 3, 3, 3, 6, 6]),
(0.25, 2, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 3, 4, 4, 4, 4]),
(0.25, (1, 3), [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 0, 4, 4, 4, 4]),
],
)
def test_sticky_action(
repeat_action_probability, repeat_action_duration, actions, expected_action
):
"""Tests the sticky action wrapper."""
env = StickyAction(
GenericTestEnv(step_func=record_action_as_obs_step),
repeat_action_probability=0.5,
GenericTestEnv(
step_func=record_action_as_obs_step, observation_space=Discrete(7)
),
repeat_action_probability=repeat_action_probability,
repeat_action_duration=repeat_action_duration,
)
env.reset(seed=11)
previous_action = None
for _ in range(NUM_STEPS):
input_action = env.action_space.sample()
executed_action, _, _, _, _ = env.step(input_action)
assert np.all(executed_action == input_action) or np.all(
executed_action == previous_action
)
previous_action = executed_action
assert len(actions) == len(expected_action)
for action, action_taken in zip(actions, expected_action):
executed_action, _, _, _, _ = env.step(action)
assert executed_action == action_taken
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
def test_sticky_action_raise_probability(repeat_action_probability):
"""Tests the stick action wrapper with probabilities that should raise an error."""
with pytest.raises(InvalidProbability):
StickyAction(
GenericTestEnv(), repeat_action_probability=repeat_action_probability
)
@pytest.mark.parametrize(
"repeat_action_duration",
[
-4,
0,
(0, 0),
(4, 2),
[1, 2],
],
)
def test_sticky_action_raise_duration(repeat_action_duration):
"""Tests the stick action wrapper with durations that should raise an error."""
with pytest.raises((ValueError, InvalidBound)):
StickyAction(
GenericTestEnv(), 0.5, repeat_action_duration=repeat_action_duration
)