2023-11-07 13:27:25 +00:00
|
|
|
"""Test suite for MaxAndSkipObservation wrapper."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2023-06-20 16:14:33 +01:00
|
|
|
import re
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
import gymnasium as gym
|
2023-11-07 13:27:25 +00:00
|
|
|
from gymnasium.wrappers import MaxAndSkipObservation
|
2023-06-20 16:14:33 +01:00
|
|
|
|
|
|
|
|
|
|
|
def test_max_and_skip_obs(skip: int = 4):
    """Test that MaxAndSkipObservation yields observations inside its observation space.

    Wraps ``CartPole-v1`` in ``MaxAndSkipObservation`` and checks that the
    observation from ``reset`` and from each of ten random-action ``step``
    calls is contained in ``env.observation_space``. Whenever an episode
    terminates or is truncated, the environment is reset and the reset
    observation is checked as well.

    Args:
        skip: Value passed as the wrapper's ``skip`` argument (default 4).
    """
    env = gym.make("CartPole-v1")
    env = MaxAndSkipObservation(env, skip=skip)

    obs, _ = env.reset()
    assert obs in env.observation_space

    for _ in range(10):
        obs, _, term, trunc, _ = env.step(env.action_space.sample())
        assert obs in env.observation_space

        # An episode may end within the ten steps; reset so later steps act on
        # a live episode, and validate the fresh observation too.
        if term or trunc:
            obs, _ = env.reset()
            assert obs in env.observation_space

    # Release the wrapped environment's resources once the checks are done.
    env.close()
|
|
|
|
|
2023-06-20 16:14:33 +01:00
|
|
|
|
|
|
|
def test_skip_size_failures():
    """Test the errors MaxAndSkipObservation raises for invalid ``skip`` values.

    A non-integer ``skip`` (``1.0``) must raise ``TypeError``, and ``skip=0``
    must raise ``ValueError``. Each expected message is matched literally:
    ``re.escape`` stops characters such as ``<``, ``>`` and ``'`` from being
    read as regex syntax.
    """
    env = gym.make("CartPole-v1")

    with pytest.raises(
        TypeError,
        match=re.escape(
            "The skip is expected to be an integer, actual type: <class 'float'>"
        ),
    ):
        MaxAndSkipObservation(env, skip=1.0)

    with pytest.raises(
        ValueError,
        match=re.escape(
            "The skip value needs to be equal or greater than two, actual value: 0"
        ),
    ):
        MaxAndSkipObservation(env, skip=0)

    # The base environment was only used to attempt construction; close it.
    env.close()
|