mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 22:11:25 +00:00
StickyAction wrapper can repeat the old action for more than 1 step (#1240)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
@@ -4,9 +4,11 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import InvalidProbability
|
||||
from gymnasium.error import InvalidBound, InvalidProbability
|
||||
|
||||
|
||||
__all__ = ["StickyAction"]
|
||||
@@ -18,7 +20,8 @@ class StickyAction(
|
||||
"""Adds a probability that the action is repeated for the same ``step`` function.
|
||||
|
||||
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
||||
in Section 5.2 on page 12.
|
||||
in Section 5.2 on page 12, and adds the possibility to repeat the action for
|
||||
more than one step.
|
||||
|
||||
No vector version of the wrapper exists.
|
||||
|
||||
@@ -39,20 +42,47 @@ class StickyAction(
|
||||
|
||||
Change logs:
|
||||
* v1.0.0 - Initially added
|
||||
* v1.1.0 - Add `repeat_action_duration` argument for dynamic number of sticky actions
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float
|
||||
self,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
repeat_action_probability: float,
|
||||
repeat_action_duration: int | tuple[int, int] = 1,
|
||||
):
|
||||
"""Initialize StickyAction wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): the wrapped environment
|
||||
repeat_action_probability (int | float): a probability of repeating the old action.
|
||||
env (Env): the wrapped environment,
|
||||
repeat_action_probability (int | float): a probability of repeating the old action,
|
||||
repeat_action_duration (int | tuple[int, int]): the number of steps
|
||||
the action is repeated. It can be either an int (for deterministic
|
||||
repeats) or a tuple[int, int] for a range of stochastic number of repeats.
|
||||
"""
|
||||
if not 0 <= repeat_action_probability < 1:
|
||||
raise InvalidProbability(
|
||||
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
||||
f"`repeat_action_probability` should be in the interval [0,1). Received {repeat_action_probability}"
|
||||
)
|
||||
|
||||
if isinstance(repeat_action_duration, int):
|
||||
repeat_action_duration = (repeat_action_duration, repeat_action_duration)
|
||||
|
||||
if not isinstance(repeat_action_duration, tuple):
|
||||
raise ValueError(
|
||||
f"`repeat_action_duration` should be either an integer or a tuple. Received {repeat_action_duration}"
|
||||
)
|
||||
elif len(repeat_action_duration) != 2:
|
||||
raise ValueError(
|
||||
f"`repeat_action_duration` should be a tuple or a list of two integers. Received {repeat_action_duration}"
|
||||
)
|
||||
elif repeat_action_duration[0] > repeat_action_duration[1]:
|
||||
raise InvalidBound(
|
||||
f"`repeat_action_duration` is not a valid bound. Received {repeat_action_duration}"
|
||||
)
|
||||
elif np.any(np.array(repeat_action_duration) < 1):
|
||||
raise ValueError(
|
||||
f"`repeat_action_duration` should be larger or equal than 1. Received {repeat_action_duration}"
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
@@ -61,23 +91,46 @@ class StickyAction(
|
||||
gym.ActionWrapper.__init__(self, env)
|
||||
|
||||
self.repeat_action_probability = repeat_action_probability
|
||||
self.repeat_action_duration_range = repeat_action_duration
|
||||
|
||||
self.last_action: ActType | None = None
|
||||
self.is_sticky_actions: bool = False # if sticky actions are taken
|
||||
self.num_repeats: int = 0 # number of sticky action repeats
|
||||
self.repeats_taken: int = 0 # number of sticky actions taken
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.last_action = None
|
||||
self.is_sticky_actions = False
|
||||
self.num_repeats = 0
|
||||
self.repeats_taken = 0
|
||||
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
def action(self, action: ActType) -> ActType:
|
||||
"""Execute the action."""
|
||||
if (
|
||||
# either the agent was already "stuck" into repeats, or a new series of repeats is triggered
|
||||
if self.is_sticky_actions or (
|
||||
self.last_action is not None
|
||||
and self.np_random.uniform() < self.repeat_action_probability
|
||||
):
|
||||
# if a new series starts, randomly sample its duration
|
||||
if self.num_repeats == 0:
|
||||
self.num_repeats = self.np_random.integers(
|
||||
self.repeat_action_duration_range[0],
|
||||
self.repeat_action_duration_range[1] + 1,
|
||||
)
|
||||
action = self.last_action
|
||||
self.is_sticky_actions = True
|
||||
self.repeats_taken += 1
|
||||
|
||||
# repeats are done, reset "stuck" status
|
||||
if self.is_sticky_actions and self.num_repeats == self.repeats_taken:
|
||||
self.is_sticky_actions = False
|
||||
self.num_repeats = 0
|
||||
self.repeats_taken = 0
|
||||
|
||||
self.last_action = action
|
||||
return action
|
||||
|
@@ -1,36 +1,63 @@
|
||||
"""Test suite for StickyAction wrapper."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.error import InvalidProbability
|
||||
from gymnasium.error import InvalidBound, InvalidProbability
|
||||
from gymnasium.spaces import Discrete
|
||||
from gymnasium.wrappers import StickyAction
|
||||
from tests.testing_env import GenericTestEnv
|
||||
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step
|
||||
from tests.wrappers.utils import record_action_as_obs_step
|
||||
|
||||
|
||||
def test_sticky_action():
|
||||
@pytest.mark.parametrize(
|
||||
"repeat_action_probability,repeat_action_duration,actions,expected_action",
|
||||
[
|
||||
(0.25, 1, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 2, 3, 3, 3, 6, 6]),
|
||||
(0.25, 2, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 3, 4, 4, 4, 4]),
|
||||
(0.25, (1, 3), [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 0, 4, 4, 4, 4]),
|
||||
],
|
||||
)
|
||||
def test_sticky_action(
|
||||
repeat_action_probability, repeat_action_duration, actions, expected_action
|
||||
):
|
||||
"""Tests the sticky action wrapper."""
|
||||
env = StickyAction(
|
||||
GenericTestEnv(step_func=record_action_as_obs_step),
|
||||
repeat_action_probability=0.5,
|
||||
GenericTestEnv(
|
||||
step_func=record_action_as_obs_step, observation_space=Discrete(7)
|
||||
),
|
||||
repeat_action_probability=repeat_action_probability,
|
||||
repeat_action_duration=repeat_action_duration,
|
||||
)
|
||||
env.reset(seed=11)
|
||||
|
||||
previous_action = None
|
||||
for _ in range(NUM_STEPS):
|
||||
input_action = env.action_space.sample()
|
||||
executed_action, _, _, _, _ = env.step(input_action)
|
||||
|
||||
assert np.all(executed_action == input_action) or np.all(
|
||||
executed_action == previous_action
|
||||
)
|
||||
previous_action = executed_action
|
||||
assert len(actions) == len(expected_action)
|
||||
for action, action_taken in zip(actions, expected_action):
|
||||
executed_action, _, _, _, _ = env.step(action)
|
||||
assert executed_action == action_taken
|
||||
|
||||
|
||||
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
|
||||
def test_sticky_action_raise(repeat_action_probability):
|
||||
def test_sticky_action_raise_probability(repeat_action_probability):
|
||||
"""Tests the stick action wrapper with probabilities that should raise an error."""
|
||||
with pytest.raises(InvalidProbability):
|
||||
StickyAction(
|
||||
GenericTestEnv(), repeat_action_probability=repeat_action_probability
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repeat_action_duration",
|
||||
[
|
||||
-4,
|
||||
0,
|
||||
(0, 0),
|
||||
(4, 2),
|
||||
[1, 2],
|
||||
],
|
||||
)
|
||||
def test_sticky_action_raise_duration(repeat_action_duration):
|
||||
"""Tests the stick action wrapper with durations that should raise an error."""
|
||||
with pytest.raises((ValueError, InvalidBound)):
|
||||
StickyAction(
|
||||
GenericTestEnv(), 0.5, repeat_action_duration=repeat_action_duration
|
||||
)
|
||||
|
Reference in New Issue
Block a user