mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-07 00:11:46 +00:00
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
"""Test suite for StickyAction wrapper."""
|
|
|
|
import pytest
|
|
|
|
from gymnasium.error import InvalidBound, InvalidProbability
|
|
from gymnasium.spaces import Discrete
|
|
from gymnasium.wrappers import StickyAction
|
|
from tests.testing_env import GenericTestEnv
|
|
from tests.wrappers.utils import record_action_as_obs_step
|
|
|
|
|
|
@pytest.mark.parametrize(
    "repeat_action_probability,repeat_action_duration,actions,expected_action",
    [
        (0.25, 1, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 2, 3, 3, 3, 6, 6]),
        (0.25, 2, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 3, 4, 4, 4, 4]),
        (0.25, (1, 3), [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 0, 4, 4, 4, 4]),
    ],
)
def test_sticky_action(
    repeat_action_probability, repeat_action_duration, actions, expected_action
):
    """Step a seeded ``StickyAction`` env and compare each result with the expected action.

    The wrapped ``GenericTestEnv`` is built with ``record_action_as_obs_step`` as
    its step function (presumably it echoes the executed action back as the
    observation -- confirm in ``tests.wrappers.utils``), so the first element of
    every ``step`` result is checked against the matching ``expected_action`` entry.

    Args:
        repeat_action_probability: Forwarded to ``StickyAction``.
        repeat_action_duration: Forwarded to ``StickyAction``; the cases cover both
            an int and a 2-tuple.
        actions: Actions passed to ``env.step`` one after another.
        expected_action: Value expected as the first element of each step result.
    """
    inner_env = GenericTestEnv(
        step_func=record_action_as_obs_step, observation_space=Discrete(7)
    )
    env = StickyAction(
        inner_env,
        repeat_action_probability=repeat_action_probability,
        repeat_action_duration=repeat_action_duration,
    )
    # The expected sequences are checked against a run that starts from this seed.
    env.reset(seed=11)

    # Guard against a malformed parametrization: zip() would silently truncate.
    assert len(actions) == len(expected_action)

    for requested, anticipated in zip(actions, expected_action):
        observed, _reward, _terminated, _truncated, _info = env.step(requested)
        assert observed == anticipated
|
|
|
|
|
|
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise_probability(repeat_action_probability):
    """Check that constructing ``StickyAction`` with a rejected probability fails.

    Each parametrized value (-1, 1 and 1.5) is expected to make the wrapper's
    constructor raise ``InvalidProbability``.
    """
    wrapper_kwargs = {"repeat_action_probability": repeat_action_probability}
    with pytest.raises(InvalidProbability):
        StickyAction(GenericTestEnv(), **wrapper_kwargs)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "repeat_action_duration",
    [
        -4,
        0,
        (0, 0),
        (4, 2),
        [1, 2],
    ],
)
def test_sticky_action_raise_duration(repeat_action_duration):
    """Check that constructing ``StickyAction`` with a rejected duration fails.

    Every parametrized duration is expected to make the wrapper's constructor
    raise either ``ValueError`` or ``InvalidBound``; the test accepts both.
    """
    rejected_with = (ValueError, InvalidBound)
    with pytest.raises(rejected_with):
        StickyAction(
            GenericTestEnv(),
            repeat_action_probability=0.5,
            repeat_action_duration=repeat_action_duration,
        )
|