Files
Gymnasium/tests/wrappers/test_sticky_action.py

37 lines
1.2 KiB
Python
Raw Normal View History

"""Test suite for StickyAction wrapper."""
2024-06-10 17:07:47 +01:00
2022-12-10 22:04:14 +00:00
import numpy as np
import pytest
from gymnasium.error import InvalidProbability
from gymnasium.wrappers import StickyAction
from tests.testing_env import GenericTestEnv
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step
def test_sticky_action():
    """Tests the sticky action wrapper.

    Every executed action must be either the action passed to ``step`` or the
    action executed on the previous step (the "sticky" repeat). The first step
    after ``reset`` has no previous action, so it must execute the input
    action unchanged.

    The environment is reset (and the action space seeded) before stepping so
    the run follows the Gymnasium reset-before-step contract and the random
    sticky decisions are reproducible across test runs.
    """
    env = StickyAction(
        GenericTestEnv(step_func=record_action_as_obs_step),
        repeat_action_probability=0.5,
    )
    # Seed both the wrapper's RNG (via reset) and the action sampler so a
    # failure can be reproduced deterministically.
    env.reset(seed=42)
    env.action_space.seed(42)

    previous_action = None
    for _ in range(NUM_STEPS):
        input_action = env.action_space.sample()
        # The observation is expected to be the action the inner env actually
        # received (that is what record_action_as_obs_step is used for here).
        executed_action, _, _, _, _ = env.step(input_action)

        if previous_action is None:
            # Nothing to repeat yet, so the input action must pass through.
            assert np.all(executed_action == input_action)
        else:
            assert np.all(executed_action == input_action) or np.all(
                executed_action == previous_action
            )
        previous_action = executed_action
@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
    """Tests the sticky action wrapper with probabilities that should raise an error.

    Each parametrized value is one the wrapper is expected to reject, so
    constructing ``StickyAction`` with it must raise ``InvalidProbability``.
    """
    base_env = GenericTestEnv()
    with pytest.raises(InvalidProbability):
        StickyAction(base_env, repeat_action_probability=repeat_action_probability)