2023-11-07 13:27:25 +00:00
|
|
|
"""Test suite for StickyAction wrapper."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2022-12-10 22:04:14 +00:00
|
|
|
import numpy as np
|
2022-12-02 01:04:34 +01:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from gymnasium.error import InvalidProbability
|
2023-11-07 13:27:25 +00:00
|
|
|
from gymnasium.wrappers import StickyAction
|
2022-12-02 01:04:34 +01:00
|
|
|
from tests.testing_env import GenericTestEnv
|
2023-11-07 13:27:25 +00:00
|
|
|
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step
|
2022-12-02 01:04:34 +01:00
|
|
|
|
2022-12-04 22:24:02 +08:00
|
|
|
|
2022-12-02 01:04:34 +01:00
|
|
|
def test_sticky_action():
    """Tests the sticky action wrapper.

    Every executed action must be either the freshly sampled input action or,
    from the second step onwards, a repeat of the previously executed action.
    With ``repeat_action_probability=0.5`` over ``NUM_STEPS`` seeded steps, at
    least one repeat is also expected, so a wrapper that never sticks fails.
    """
    env = StickyAction(
        GenericTestEnv(step_func=record_action_as_obs_step),
        repeat_action_probability=0.5,
    )
    # Reset before stepping (required by the Env API) and seed both the
    # wrapper's RNG (via the inner env) and the action space so the sequence
    # of sampled and sticky actions is reproducible across runs.
    env.reset(seed=42)
    env.action_space.seed(42)

    previous_action = None
    num_repeats = 0
    for _ in range(NUM_STEPS):
        input_action = env.action_space.sample()
        # `record_action_as_obs_step` echoes the action received by the inner
        # env as the observation, i.e. the action that was actually executed.
        executed_action, _, _, _, _ = env.step(input_action)

        is_input = bool(np.all(executed_action == input_action))
        # On the first step there is no previous action, so it cannot repeat.
        is_repeat = previous_action is not None and bool(
            np.all(executed_action == previous_action)
        )
        assert is_input or is_repeat

        if is_repeat and not is_input:
            num_repeats += 1
        previous_action = executed_action

    # NOTE(review): assumes NUM_STEPS is large enough (e.g. ~20) that a 0.5
    # repeat probability yields at least one repeat; deterministic given the seed.
    assert num_repeats > 0
|
2022-12-02 01:04:34 +01:00
|
|
|
|
|
|
|
|
2022-12-05 19:14:56 +00:00
|
|
|
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
    """Check that the sticky action wrapper rejects invalid repeat probabilities.

    A negative probability, a probability of exactly 1, and a probability
    greater than 1 must each raise ``InvalidProbability`` when the wrapper is
    constructed.
    """
    base_env = GenericTestEnv()
    with pytest.raises(InvalidProbability):
        StickyAction(base_env, repeat_action_probability=repeat_action_probability)
|