mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-07 08:21:48 +00:00
79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
"""Test suite for DelayObservation wrapper."""
|
|
|
|
import re
|
|
|
|
import pytest
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.utils.env_checker import data_equivalence
|
|
from gymnasium.wrappers import DelayObservation
|
|
from gymnasium.wrappers.utils import create_zero_array
|
|
from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS
|
|
|
|
|
|
@pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS)
def test_env_obs(env, delay: int = 3, extra_steps: int = 4):
    """Check that ``DelayObservation`` shifts the observation stream by ``delay``.

    The environment is rolled out twice with the same seed and the same
    seeded action space: once unwrapped, to record the reference
    observations, and once through the wrapper. The wrapped rollout must
    return the zero observation from ``reset`` and from the first
    ``delay - 1`` steps, and from then on reproduce the reference
    observations ``delay`` steps late.

    Args:
        env: The environment under test, supplied by the parametrization.
        delay: Number of steps by which observations are delayed.
        extra_steps: Steps taken beyond the delay so that real (non-zero
            filled) observations are compared as well.
    """
    total_steps = delay + extra_steps

    # Reference rollout on the unwrapped environment.
    env.action_space.seed(SEED)
    start_obs, _ = env.reset(seed=SEED)
    reference_obs = [start_obs]
    for _ in range(total_steps):
        next_obs, _, _, _, _ = env.step(env.action_space.sample())
        reference_obs.append(next_obs)

    # Same rollout again, this time through the delay wrapper.
    wrapped = DelayObservation(env, delay=delay)
    zero_obs = create_zero_array(wrapped.observation_space)
    wrapped.action_space.seed(SEED)
    reset_obs, _ = wrapped.reset(seed=SEED)
    assert data_equivalence(reset_obs, zero_obs)
    assert reset_obs in wrapped.observation_space

    shifted_obs = [reset_obs]
    for step_index in range(total_steps):
        next_obs, _, _, _, _ = wrapped.step(wrapped.action_space.sample())
        shifted_obs.append(next_obs)
        assert next_obs in wrapped.observation_space

        # Until the delay has elapsed the wrapper can only emit zeros.
        if step_index + 1 < delay:
            assert data_equivalence(next_obs, zero_obs)

    # Dropping the zero-filled prefix must leave the reference stream,
    # minus its last ``delay`` entries that were never emitted.
    assert data_equivalence(shifted_obs[delay:], reference_obs[:-delay])
|
|
|
|
|
|
@pytest.mark.parametrize("delay", [1, 2, 3, 4])
def test_delay_values(delay):
    """Verify the reset observation appears after exactly ``delay`` steps.

    An unwrapped CartPole is reset with a fixed seed to obtain the expected
    observation. A wrapped CartPole reset with the same seed must return the
    zero observation from ``reset`` and from the next ``delay - 1`` steps;
    the step after that must return the expected observation.
    """
    reference_env = gym.make("CartPole-v1")
    expected_obs, _ = reference_env.reset(seed=123)

    delayed_env = DelayObservation(gym.make("CartPole-v1"), delay=delay)
    placeholder_obs = create_zero_array(delayed_env.observation_space)

    current_obs, _ = delayed_env.reset(seed=123)
    assert data_equivalence(current_obs, placeholder_obs)

    # Observations stay zero-filled while the delay has not yet elapsed.
    pending_steps = delay - 1
    while pending_steps > 0:
        current_obs, _, _, _, _ = delayed_env.step(delayed_env.action_space.sample())
        assert data_equivalence(current_obs, placeholder_obs)
        pending_steps -= 1

    # The next step is the first one to surface a real observation.
    current_obs, _, _, _, _ = delayed_env.step(delayed_env.action_space.sample())
    assert data_equivalence(expected_obs, current_obs)
|
|
|
|
|
|
def test_delay_failures():
    """Check the exceptions ``DelayObservation`` raises for invalid delays.

    A float delay must raise ``TypeError`` and a negative delay must raise
    ``ValueError``; both are matched against their literal error message.
    """
    env = gym.make("CartPole-v1")

    # Messages are escaped so they are matched literally, not as regex.
    wrong_type_pattern = re.escape(
        "The delay is expected to be an integer, actual type: <class 'float'>"
    )
    with pytest.raises(TypeError, match=wrong_type_pattern):
        DelayObservation(env, delay=1.0)

    negative_value_pattern = re.escape(
        "The delay needs to be greater than zero, actual value: -1"
    )
    with pytest.raises(ValueError, match=negative_value_pattern):
        DelayObservation(env, delay=-1)
|