"""Test suite for TimeAwareObservation wrapper.""" import numpy as np import pytest import gymnasium as gym from gymnasium import spaces from gymnasium.spaces import Box, Dict, Tuple from gymnasium.wrappers import TimeAwareObservation from tests.testing_env import GenericTestEnv @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_default(env_id): env = gym.make(env_id, disable_env_checker=True) wrapped_env = TimeAwareObservation(env) assert isinstance(env.observation_space, spaces.Box) assert isinstance(wrapped_env.observation_space, spaces.Box) assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1 obs, info = env.reset() wrapped_obs, wrapped_obs_info = wrapped_env.reset() assert wrapped_env.timesteps == 0.0 assert wrapped_obs[-1] == 0.0, wrapped_obs assert wrapped_obs.shape[0] == obs.shape[0] + 1 wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.timesteps == 1.0 assert wrapped_obs[-1] == 1.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.timesteps == 2.0 assert wrapped_obs[-1] == 2.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 wrapped_obs, wrapped_obs_info = wrapped_env.reset() assert wrapped_env.timesteps == 0.0 assert wrapped_obs[-1] == 0.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 def test_no_flatten(): """Test the TimeAwareObservation wrapper without flattening the space.""" env = GenericTestEnv(observation_space=Box(0, 1)) wrapped_env = TimeAwareObservation(env) assert isinstance(wrapped_env.observation_space, Box) reset_obs, _ = wrapped_env.reset() step_obs, _, _, _, _ = wrapped_env.step(None) assert reset_obs.shape == (2,) and step_obs.shape == (2,) assert reset_obs in wrapped_env.observation_space assert step_obs in wrapped_env.observation_space def test_with_flatten(): """Test the flatten parameter for the TimeAwareObservation wrapper on three types of observation spaces.""" env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3))) wrapped_env = TimeAwareObservation(env, flatten=False) assert isinstance(wrapped_env.observation_space, Dict) reset_obs, _ = wrapped_env.reset() step_obs, _, _, _, _ = wrapped_env.step(None) assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}" assert reset_obs in wrapped_env.observation_space assert step_obs in wrapped_env.observation_space env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3)))) wrapped_env = TimeAwareObservation(env, flatten=False) assert isinstance(wrapped_env.observation_space, Tuple) reset_obs, _ = wrapped_env.reset() step_obs, _, _, _, _ = wrapped_env.step(None) assert len(reset_obs) == 3 and len(step_obs) == 3 assert reset_obs in wrapped_env.observation_space assert step_obs in wrapped_env.observation_space env = GenericTestEnv(observation_space=Box(0, 1)) wrapped_env = TimeAwareObservation(env, flatten=False) assert isinstance(wrapped_env.observation_space, Dict) reset_obs, _ = wrapped_env.reset() step_obs, _, _, _, _ = wrapped_env.step(None) assert isinstance(reset_obs, dict) and isinstance(step_obs, dict) assert "obs" in reset_obs and "obs" in step_obs assert "time" in reset_obs and "time" in step_obs assert reset_obs in wrapped_env.observation_space assert step_obs in wrapped_env.observation_space def test_normalize_time(): """Test the normalize time parameter for DelayObservation wrappers.""" env = GenericTestEnv(observation_space=Box(0, 1)) wrapped_env = TimeAwareObservation(env, flatten=False, normalize_time=False) reset_obs, _ = wrapped_env.reset() step_obs, _, _, _, _ = wrapped_env.step(None) assert reset_obs["time"] == np.array([0], dtype=np.int32) and step_obs[ "time" ] == np.array([1], dtype=np.int32) assert reset_obs in wrapped_env.observation_space assert step_obs in wrapped_env.observation_space env = GenericTestEnv(observation_space=Box(0, 1)) wrapped_env = TimeAwareObservation(env, flatten=False, normalize_time=True) reset_obs, _ = wrapped_env.reset() step_obs, _, _, _, _ = wrapped_env.step(None) assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01 assert reset_obs in wrapped_env.observation_space assert step_obs in wrapped_env.observation_space