Files
Gymnasium/tests/experimental/wrappers/test_time_aware_observation.py

100 lines
2.9 KiB
Python

"""Test suite for TimeAwareObservationV0."""
from collections import OrderedDict
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import TimeAwareObservationV0
from gymnasium.spaces import Box, Dict
# Bound on how many steps the step test drives the wrapped environment.
NUM_STEPS: int = 20
# Shared seed for `reset` and the action space so sampled actions are reproducible.
SEED: int = 0
@pytest.mark.parametrize("env_id", ["CartPole-v1", "CarRacing-v2"])
def test_time_aware_observation_creation(env_id):
    """Test TimeAwareObservationV0 wrapper creation.

    This test checks that an env wrapped with TimeAwareObservationV0 (default
    arguments) exposes a `Dict` observation space that keeps the original
    observation space under ``"obs"``. It also checks that the first
    observation after `reset` is an `OrderedDict` whose ``"time"`` entry
    is zero.

    The environment is built inside the test rather than in the parametrize
    list. This keeps construction out of module import / test collection,
    gives every test case its own env, and lets the env be closed afterwards.
    """
    env = gym.make(env_id, disable_env_checker=True)
    try:
        wrapped_env = TimeAwareObservationV0(env)
        obs, _ = wrapped_env.reset()

        assert isinstance(wrapped_env.observation_space, Dict)
        assert isinstance(obs, OrderedDict)
        assert np.all(obs["time"] == 0)
        assert env.observation_space == wrapped_env.observation_space["obs"]
    finally:
        # Close the underlying env so its resources are released even on failure.
        env.close()
@pytest.mark.parametrize("normalize_time", [True, False])
@pytest.mark.parametrize("flatten", [False, True])
@pytest.mark.parametrize(
    "env_id, env_kwargs",
    [
        ("CartPole-v1", {}),
        ("CarRacing-v2", {"continuous": False}),
    ],
)
def test_time_aware_observation_step(env_id, env_kwargs, flatten, normalize_time):
    """Test TimeAwareObservationV0 step.

    This test checks that an env wrapped with TimeAwareObservationV0 reports
    the expected time value after each `step`, for every combination of
    `flatten` and `normalize_time`:

    * ``normalize_time=True``: time is ``timestep / max_episode_steps``.
    * ``normalize_time=False``: time is ``max_episode_steps - timestep``.

    A fresh environment is created for each parametrized case. This avoids one
    env instance being shared across cases and created at collection time.
    The env is always closed afterwards.
    """
    env = gym.make(env_id, disable_env_checker=True, **env_kwargs)
    try:
        env.action_space.seed(SEED)
        max_timesteps = env._max_episode_steps
        wrapped_env = TimeAwareObservationV0(
            env, flatten=flatten, normalize_time=normalize_time
        )
        wrapped_env.reset(seed=SEED)

        for timestep in range(1, NUM_STEPS):
            action = env.action_space.sample()
            observation, _, terminated, truncated, _ = wrapped_env.step(action)

            if normalize_time:
                expected_time_obs = timestep / max_timesteps
            else:
                # Un-normalized time is expected to be the number of remaining steps.
                expected_time_obs = max_timesteps - timestep

            # A flattened observation carries the time value as its last element.
            time_obs = observation[-1] if flatten else observation["time"]
            assert np.allclose(time_obs, expected_time_obs)

            # Stop at the end of the episode, whether terminated or truncated.
            if terminated or truncated:
                break
    finally:
        env.close()
@pytest.mark.parametrize("env_id", ["CartPole-v1", "CarRacing-v2"])
def test_time_aware_observation_creation_flatten(env_id):
    """Test TimeAwareObservationV0 wrapper creation with `flatten=True`.

    This test checks that an env wrapped with TimeAwareObservationV0 is
    correctly created when the `flatten` parameter is set to `True`.
    When flattened, the observation space should be a 1-dimensional `Box`
    with time appended to the end, and `reset` should return a 1-D array.

    The environment is built inside the test, not at collection time, and is
    always closed afterwards.
    """
    env = gym.make(env_id, disable_env_checker=True)
    try:
        wrapped_env = TimeAwareObservationV0(env, flatten=True)
        obs, _ = wrapped_env.reset()

        assert isinstance(wrapped_env.observation_space, Box)
        # The docstring promises a 1-D Box; check it explicitly.
        assert len(wrapped_env.observation_space.shape) == 1
        assert isinstance(obs, np.ndarray)
        assert obs.ndim == 1
        assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]
    finally:
        env.close()