diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 58954ef88..c02a0632b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -35,7 +35,7 @@ repos: - --show-source - --statistics - repo: https://github.com/asottile/pyupgrade - rev: v3.17.0 + rev: v3.19.1 hooks: - id: pyupgrade args: ["--py38-plus"] @@ -44,7 +44,7 @@ repos: hooks: - id: isort - repo: https://github.com/python/black - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black - repo: https://github.com/pycqa/pydocstyle diff --git a/gymnasium/wrappers/stateful_observation.py b/gymnasium/wrappers/stateful_observation.py index 17dec84e8..230ac9b49 100644 --- a/gymnasium/wrappers/stateful_observation.py +++ b/gymnasium/wrappers/stateful_observation.py @@ -198,9 +198,18 @@ class TimeAwareObservation( if env.spec is not None and env.spec.max_episode_steps is not None: self.max_timesteps = env.spec.max_episode_steps else: - raise ValueError( - "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`." - ) + # else we need to loop through the environment stack to check if a `TimeLimit` wrapper exists + wrapped_env = env + while isinstance(wrapped_env, gym.Wrapper): + if isinstance(wrapped_env, gym.wrappers.TimeLimit): + self.max_timesteps = wrapped_env._max_episode_steps + break + wrapped_env = wrapped_env.env + + if not isinstance(wrapped_env, gym.wrappers.TimeLimit): + raise ValueError( + "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`." + ) self.timesteps: int = 0 diff --git a/tests/wrappers/test_time_aware_observation.py b/tests/wrappers/test_time_aware_observation.py index 0fd0ae23e..33a9ed379 100644 --- a/tests/wrappers/test_time_aware_observation.py +++ b/tests/wrappers/test_time_aware_observation.py @@ -1,12 +1,16 @@ """Test suite for TimeAwareObservation wrapper.""" +import re +import warnings + import numpy as np import pytest import gymnasium as gym from gymnasium import spaces +from gymnasium.envs.classic_control import CartPoleEnv from gymnasium.spaces import Box, Dict, Tuple -from gymnasium.wrappers import TimeAwareObservation +from gymnasium.wrappers import TimeAwareObservation, TimeLimit from tests.testing_env import GenericTestEnv @@ -41,6 +45,25 @@ def test_default(env_id): assert wrapped_obs.shape[0] == obs.shape[0] + 1 +def test_no_spec(): + env = CartPoleEnv() + + with pytest.raises( + ValueError, + match=re.escape( + "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`." + ), + ): + TimeAwareObservation(env) + + env = TimeLimit(env, 100) + with warnings.catch_warnings(record=True) as caught_warnings: + env = TimeAwareObservation(env) + + assert env.max_timesteps == 100 + assert len(caught_warnings) == 0 + + def test_no_flatten(): """Test the TimeAwareObservation wrapper without flattening the space.""" env = GenericTestEnv(observation_space=Box(0, 1))