mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
StickyAction wrapper can repeat the old action for more than 1 step (#1240)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
@@ -4,9 +4,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
from gymnasium.error import InvalidProbability
|
from gymnasium.error import InvalidBound, InvalidProbability
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["StickyAction"]
|
__all__ = ["StickyAction"]
|
||||||
@@ -18,7 +20,8 @@ class StickyAction(
|
|||||||
"""Adds a probability that the action is repeated for the same ``step`` function.
|
"""Adds a probability that the action is repeated for the same ``step`` function.
|
||||||
|
|
||||||
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
||||||
in Section 5.2 on page 12.
|
in Section 5.2 on page 12, and adds the possibility to repeat the action for
|
||||||
|
more than one step.
|
||||||
|
|
||||||
No vector version of the wrapper exists.
|
No vector version of the wrapper exists.
|
||||||
|
|
||||||
@@ -39,20 +42,47 @@ class StickyAction(
|
|||||||
|
|
||||||
Change logs:
|
Change logs:
|
||||||
* v1.0.0 - Initially added
|
* v1.0.0 - Initially added
|
||||||
|
* v1.1.0 - Add `repeat_action_duration` argument for dynamic number of sticky actions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float
|
self,
|
||||||
|
env: gym.Env[ObsType, ActType],
|
||||||
|
repeat_action_probability: float,
|
||||||
|
repeat_action_duration: int | tuple[int, int] = 1,
|
||||||
):
|
):
|
||||||
"""Initialize StickyAction wrapper.
|
"""Initialize StickyAction wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): the wrapped environment
|
env (Env): the wrapped environment,
|
||||||
repeat_action_probability (int | float): a probability of repeating the old action.
|
repeat_action_probability (int | float): a probability of repeating the old action,
|
||||||
|
repeat_action_duration (int | tuple[int, int]): the number of steps
|
||||||
|
the action is repeated. It can be either an int (for deterministic
|
||||||
|
repeats) or a tuple[int, int] for a range of stochastic number of repeats.
|
||||||
"""
|
"""
|
||||||
if not 0 <= repeat_action_probability < 1:
|
if not 0 <= repeat_action_probability < 1:
|
||||||
raise InvalidProbability(
|
raise InvalidProbability(
|
||||||
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
f"`repeat_action_probability` should be in the interval [0,1). Received {repeat_action_probability}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(repeat_action_duration, int):
|
||||||
|
repeat_action_duration = (repeat_action_duration, repeat_action_duration)
|
||||||
|
|
||||||
|
if not isinstance(repeat_action_duration, tuple):
|
||||||
|
raise ValueError(
|
||||||
|
f"`repeat_action_duration` should be either an integer or a tuple. Received {repeat_action_duration}"
|
||||||
|
)
|
||||||
|
elif len(repeat_action_duration) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"`repeat_action_duration` should be a tuple or a list of two integers. Received {repeat_action_duration}"
|
||||||
|
)
|
||||||
|
elif repeat_action_duration[0] > repeat_action_duration[1]:
|
||||||
|
raise InvalidBound(
|
||||||
|
f"`repeat_action_duration` is not a valid bound. Received {repeat_action_duration}"
|
||||||
|
)
|
||||||
|
elif np.any(np.array(repeat_action_duration) < 1):
|
||||||
|
raise ValueError(
|
||||||
|
f"`repeat_action_duration` should be larger or equal than 1. Received {repeat_action_duration}"
|
||||||
)
|
)
|
||||||
|
|
||||||
gym.utils.RecordConstructorArgs.__init__(
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
@@ -61,23 +91,46 @@ class StickyAction(
|
|||||||
gym.ActionWrapper.__init__(self, env)
|
gym.ActionWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.repeat_action_probability = repeat_action_probability
|
self.repeat_action_probability = repeat_action_probability
|
||||||
|
self.repeat_action_duration_range = repeat_action_duration
|
||||||
|
|
||||||
self.last_action: ActType | None = None
|
self.last_action: ActType | None = None
|
||||||
|
self.is_sticky_actions: bool = False # if sticky actions are taken
|
||||||
|
self.num_repeats: int = 0 # number of sticky action repeats
|
||||||
|
self.repeats_taken: int = 0 # number of sticky actions taken
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
) -> tuple[ObsType, dict[str, Any]]:
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
"""Reset the environment."""
|
"""Reset the environment."""
|
||||||
self.last_action = None
|
self.last_action = None
|
||||||
|
self.is_sticky_actions = False
|
||||||
|
self.num_repeats = 0
|
||||||
|
self.repeats_taken = 0
|
||||||
|
|
||||||
return super().reset(seed=seed, options=options)
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
def action(self, action: ActType) -> ActType:
|
def action(self, action: ActType) -> ActType:
|
||||||
"""Execute the action."""
|
"""Execute the action."""
|
||||||
if (
|
# either the agent was already "stuck" into repeats, or a new series of repeats is triggered
|
||||||
|
if self.is_sticky_actions or (
|
||||||
self.last_action is not None
|
self.last_action is not None
|
||||||
and self.np_random.uniform() < self.repeat_action_probability
|
and self.np_random.uniform() < self.repeat_action_probability
|
||||||
):
|
):
|
||||||
|
# if a new series starts, randomly sample its duration
|
||||||
|
if self.num_repeats == 0:
|
||||||
|
self.num_repeats = self.np_random.integers(
|
||||||
|
self.repeat_action_duration_range[0],
|
||||||
|
self.repeat_action_duration_range[1] + 1,
|
||||||
|
)
|
||||||
action = self.last_action
|
action = self.last_action
|
||||||
|
self.is_sticky_actions = True
|
||||||
|
self.repeats_taken += 1
|
||||||
|
|
||||||
|
# repeats are done, reset "stuck" status
|
||||||
|
if self.is_sticky_actions and self.num_repeats == self.repeats_taken:
|
||||||
|
self.is_sticky_actions = False
|
||||||
|
self.num_repeats = 0
|
||||||
|
self.repeats_taken = 0
|
||||||
|
|
||||||
self.last_action = action
|
self.last_action = action
|
||||||
return action
|
return action
|
||||||
|
@@ -1,36 +1,63 @@
|
|||||||
"""Test suite for StickyAction wrapper."""
|
"""Test suite for StickyAction wrapper."""
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.error import InvalidProbability
|
from gymnasium.error import InvalidBound, InvalidProbability
|
||||||
|
from gymnasium.spaces import Discrete
|
||||||
from gymnasium.wrappers import StickyAction
|
from gymnasium.wrappers import StickyAction
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step
|
from tests.wrappers.utils import record_action_as_obs_step
|
||||||
|
|
||||||
|
|
||||||
def test_sticky_action():
|
@pytest.mark.parametrize(
|
||||||
|
"repeat_action_probability,repeat_action_duration,actions,expected_action",
|
||||||
|
[
|
||||||
|
(0.25, 1, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 2, 3, 3, 3, 6, 6]),
|
||||||
|
(0.25, 2, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 3, 4, 4, 4, 4]),
|
||||||
|
(0.25, (1, 3), [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 0, 4, 4, 4, 4]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sticky_action(
|
||||||
|
repeat_action_probability, repeat_action_duration, actions, expected_action
|
||||||
|
):
|
||||||
"""Tests the sticky action wrapper."""
|
"""Tests the sticky action wrapper."""
|
||||||
env = StickyAction(
|
env = StickyAction(
|
||||||
GenericTestEnv(step_func=record_action_as_obs_step),
|
GenericTestEnv(
|
||||||
repeat_action_probability=0.5,
|
step_func=record_action_as_obs_step, observation_space=Discrete(7)
|
||||||
|
),
|
||||||
|
repeat_action_probability=repeat_action_probability,
|
||||||
|
repeat_action_duration=repeat_action_duration,
|
||||||
)
|
)
|
||||||
|
env.reset(seed=11)
|
||||||
|
|
||||||
previous_action = None
|
assert len(actions) == len(expected_action)
|
||||||
for _ in range(NUM_STEPS):
|
for action, action_taken in zip(actions, expected_action):
|
||||||
input_action = env.action_space.sample()
|
executed_action, _, _, _, _ = env.step(action)
|
||||||
executed_action, _, _, _, _ = env.step(input_action)
|
assert executed_action == action_taken
|
||||||
|
|
||||||
assert np.all(executed_action == input_action) or np.all(
|
|
||||||
executed_action == previous_action
|
|
||||||
)
|
|
||||||
previous_action = executed_action
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
|
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
|
||||||
def test_sticky_action_raise(repeat_action_probability):
|
def test_sticky_action_raise_probability(repeat_action_probability):
|
||||||
"""Tests the stick action wrapper with probabilities that should raise an error."""
|
"""Tests the stick action wrapper with probabilities that should raise an error."""
|
||||||
with pytest.raises(InvalidProbability):
|
with pytest.raises(InvalidProbability):
|
||||||
StickyAction(
|
StickyAction(
|
||||||
GenericTestEnv(), repeat_action_probability=repeat_action_probability
|
GenericTestEnv(), repeat_action_probability=repeat_action_probability
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"repeat_action_duration",
|
||||||
|
[
|
||||||
|
-4,
|
||||||
|
0,
|
||||||
|
(0, 0),
|
||||||
|
(4, 2),
|
||||||
|
[1, 2],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sticky_action_raise_duration(repeat_action_duration):
|
||||||
|
"""Tests the stick action wrapper with durations that should raise an error."""
|
||||||
|
with pytest.raises((ValueError, InvalidBound)):
|
||||||
|
StickyAction(
|
||||||
|
GenericTestEnv(), 0.5, repeat_action_duration=repeat_action_duration
|
||||||
|
)
|
||||||
|
Reference in New Issue
Block a user