Files
Gymnasium/tests/experimental/wrappers/test_delay_observation.py

82 lines
2.6 KiB
Python
Raw Normal View History

2022-12-10 22:04:14 +00:00
"""Test suite for DelayObservationV0."""
import re
import pytest
2022-12-10 22:04:14 +00:00
import gymnasium as gym
from gymnasium.experimental.wrappers import DelayObservationV0
from gymnasium.experimental.wrappers.utils import create_zero_array
from gymnasium.utils.env_checker import data_equivalence
from tests.experimental.wrappers.utils import (
SEED,
TESTING_OBS_ENVS,
TESTING_OBS_ENVS_IDS,
)
2022-12-10 22:04:14 +00:00
@pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS)
def test_env_obs(env, delay: int = 3, extra_steps: int = 4):
    """Check that wrapped observations trail the unwrapped ones by ``delay`` steps.

    The same seeded rollout is performed twice on ``env``: once unwrapped, to
    record reference observations, and once through ``DelayObservationV0``.

    Args:
        env: The environment under test, supplied by the parametrization.
        delay: Number of steps the wrapper delays observations by.
        extra_steps: Steps taken beyond ``delay`` so real observations appear.
    """
    total_steps = delay + extra_steps

    # Reference rollout on the bare environment.
    env.action_space.seed(SEED)
    initial_obs, _ = env.reset(seed=SEED)
    undelayed_obs = [initial_obs]
    for _ in range(total_steps):
        raw_obs, _, _, _, _ = env.step(env.action_space.sample())
        undelayed_obs.append(raw_obs)

    # Identical rollout (same seeds), this time through the wrapper.
    delayed_env = DelayObservationV0(env, delay=delay)
    zero_obs = create_zero_array(delayed_env.observation_space)

    delayed_env.action_space.seed(SEED)
    reset_obs, _ = delayed_env.reset(seed=SEED)
    assert data_equivalence(reset_obs, zero_obs)
    assert reset_obs in delayed_env.observation_space

    delayed_obs = [reset_obs]
    for step_index in range(total_steps):
        lagged_obs, _, _, _, _ = delayed_env.step(delayed_env.action_space.sample())
        delayed_obs.append(lagged_obs)
        assert lagged_obs in delayed_env.observation_space

        # Until ``delay`` observations have been produced (the reset counts
        # as the first), the wrapper is expected to emit zero observations.
        if step_index + 1 < delay:
            assert data_equivalence(lagged_obs, zero_obs)

    # Dropping the first ``delay`` wrapped observations must line up with the
    # reference rollout minus its last ``delay`` entries.
    observed_after_delay = delayed_obs[delay:]
    expected_before_cutoff = undelayed_obs[:-delay]
    assert data_equivalence(observed_after_delay, expected_before_cutoff)
@pytest.mark.parametrize("delay", [1, 2, 3, 4])
def test_delay_values(delay):
    """Check the DelayObservation wrapper on CartPole for several delay lengths.

    The reset and the following ``delay - 1`` steps must return zero
    observations; the step after that must return the initial observation of
    an unwrapped environment reset with the same seed.
    """
    reference_env = gym.make("CartPole-v1")
    initial_obs, _ = reference_env.reset(seed=123)

    delayed_env = DelayObservationV0(gym.make("CartPole-v1"), delay=delay)
    blank_obs = create_zero_array(delayed_env.observation_space)

    current_obs, _ = delayed_env.reset(seed=123)
    assert data_equivalence(current_obs, blank_obs)

    # ``range(1, delay)`` yields exactly ``delay - 1`` iterations.
    for _ in range(1, delay):
        current_obs, _, _, _, _ = delayed_env.step(delayed_env.action_space.sample())
        assert data_equivalence(current_obs, blank_obs)

    # One more step: the reset observation should finally surface.
    current_obs, _, _, _, _ = delayed_env.step(delayed_env.action_space.sample())
    assert data_equivalence(initial_obs, current_obs)
def test_delay_failures():
    """Check that invalid ``delay`` arguments are rejected by DelayObservation."""
    env = gym.make("CartPole-v1")

    # A float delay is expected to raise a TypeError naming the offending type.
    type_error_pattern = re.escape(
        "The delay is expected to be an integer, actual type: <class 'float'>"
    )
    with pytest.raises(TypeError, match=type_error_pattern):
        DelayObservationV0(env, delay=1.0)

    # A negative delay is expected to raise a ValueError naming the value.
    value_error_pattern = re.escape(
        "The delay needs to be greater than zero, actual value: -1"
    )
    with pytest.raises(ValueError, match=value_error_pattern):
        DelayObservationV0(env, delay=-1)