mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
"""Test suite for TimeAwareObservationV0."""
|
|
|
|
from collections import OrderedDict
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.experimental.wrappers import TimeAwareObservationV0
|
|
from gymnasium.spaces import Box, Dict
|
|
|
|
|
|
# Upper bound on the number of environment steps taken by the step test.
NUM_STEPS: int = 20

# Seed shared by action-space sampling and ``reset`` so runs are reproducible.
SEED: int = 0
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True),
        gym.make("CarRacing-v2", disable_env_checker=True),
    ],
)
def test_time_aware_observation_creation(env):
    """Check the default (non-flattened) TimeAwareObservationV0 construction.

    Wrapping ``env`` with default arguments must produce a ``Dict``
    observation space. The observation returned by ``reset`` must be an
    ``OrderedDict`` whose ``"time"`` entry is zero. The ``"obs"`` sub-space
    must equal the base environment's observation space.
    """
    wrapped_env = TimeAwareObservationV0(env)
    first_obs = wrapped_env.reset()[0]
    time_aware_space = wrapped_env.observation_space

    # Structure of the space and of the first observation.
    assert isinstance(time_aware_space, Dict)
    assert isinstance(first_obs, OrderedDict)

    # No steps have been taken yet, so the time component is zero.
    assert np.all(first_obs["time"] == 0)

    # The original observation space is preserved under the "obs" key.
    assert env.observation_space == time_aware_space["obs"]
|
|
|
|
|
|
@pytest.mark.parametrize("normalize_time", [True, False])
@pytest.mark.parametrize("flatten", [False, True])
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True),
        gym.make("CarRacing-v2", disable_env_checker=True, continuous=False),
    ],
)
def test_time_aware_observation_step(env, flatten, normalize_time):
    """Step the wrapped environment and verify the time component each step.

    With ``normalize_time`` the expected time value after ``k`` steps is
    ``k / max_timesteps``; otherwise it is ``max_timesteps - k``. With
    ``flatten`` the time value is read from the last element of the flat
    observation; otherwise it is read from the ``"time"`` key. The loop
    stops early if the episode terminates.
    """
    env.action_space.seed(SEED)
    max_timesteps = env._max_episode_steps

    wrapped_env = TimeAwareObservationV0(
        env, flatten=flatten, normalize_time=normalize_time
    )
    wrapped_env.reset(seed=SEED)

    for step_count in range(1, NUM_STEPS):
        observation, _, terminated, _, _ = wrapped_env.step(
            env.action_space.sample()
        )

        if normalize_time:
            expected_time = step_count / max_timesteps
        else:
            expected_time = max_timesteps - step_count

        # Flattened observations carry time as their final element.
        time_value = observation[-1] if flatten else observation["time"]
        assert np.allclose(time_value, expected_time)

        if terminated:
            break
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True),
        gym.make("CarRacing-v2", disable_env_checker=True),
    ],
)
def test_time_aware_observation_creation_flatten(env):
    """Test TimeAwareObservationV0 wrapper creation with `flatten=True`.

    This test checks if wrapped env with TimeAwareObservationV0
    is correctly created when the `flatten` parameter is set to `True`.
    When flattened, the observation space should be a 1 dimension `Box`
    with time appended to the end.
    """
    wrapped_env = TimeAwareObservationV0(env, flatten=True)
    obs, _ = wrapped_env.reset()

    assert isinstance(wrapped_env.observation_space, Box)
    assert isinstance(obs, np.ndarray)
    assert env.observation_space == wrapped_env.time_aware_observation_space["obs"]

    # The flattened space (and the sample drawn from it) must be 1-D.
    assert len(wrapped_env.observation_space.shape) == 1
    assert obs.shape == wrapped_env.observation_space.shape

    # Time is appended, so the flat observation is longer than the flattened
    # base observation alone.
    assert obs.shape[0] > gym.spaces.flatdim(env.observation_space)

    # Time sits at the end of the flat vector. It is zero right after reset
    # with the default arguments, matching the non-flattened creation test.
    assert obs[-1] == 0
|