Add TimeAwareObservation support for environments without a spec (#1289)

This commit is contained in:
Mark Towers
2025-01-08 15:30:34 +00:00
committed by GitHub
parent fc74bb8fc0
commit e6e3521cfe
3 changed files with 39 additions and 7 deletions

View File

@@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks # See https://pre-commit.com/hooks.html for more hooks
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0 rev: v5.0.0
hooks: hooks:
- id: check-symlinks - id: check-symlinks
- id: destroyed-symlinks - id: destroyed-symlinks
@@ -35,7 +35,7 @@ repos:
- --show-source - --show-source
- --statistics - --statistics
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.17.0 rev: v3.19.1
hooks: hooks:
- id: pyupgrade - id: pyupgrade
args: ["--py38-plus"] args: ["--py38-plus"]
@@ -44,7 +44,7 @@ repos:
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/python/black - repo: https://github.com/python/black
rev: 24.8.0 rev: 24.10.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/pydocstyle - repo: https://github.com/pycqa/pydocstyle

View File

@@ -198,9 +198,18 @@ class TimeAwareObservation(
if env.spec is not None and env.spec.max_episode_steps is not None: if env.spec is not None and env.spec.max_episode_steps is not None:
self.max_timesteps = env.spec.max_episode_steps self.max_timesteps = env.spec.max_episode_steps
else: else:
raise ValueError( # else we need to loop through the environment stack to check if a `TimeLimit` wrapper exists
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`." wrapped_env = env
) while isinstance(wrapped_env, gym.Wrapper):
if isinstance(wrapped_env, gym.wrappers.TimeLimit):
self.max_timesteps = wrapped_env._max_episode_steps
break
wrapped_env = wrapped_env.env
if not isinstance(wrapped_env, gym.wrappers.TimeLimit):
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)
self.timesteps: int = 0 self.timesteps: int = 0

View File

@@ -1,12 +1,16 @@
"""Test suite for TimeAwareObservation wrapper.""" """Test suite for TimeAwareObservation wrapper."""
import re
import warnings
import numpy as np import numpy as np
import pytest import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium import spaces from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.spaces import Box, Dict, Tuple from gymnasium.spaces import Box, Dict, Tuple
from gymnasium.wrappers import TimeAwareObservation from gymnasium.wrappers import TimeAwareObservation, TimeLimit
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
@@ -41,6 +45,25 @@ def test_default(env_id):
assert wrapped_obs.shape[0] == obs.shape[0] + 1 assert wrapped_obs.shape[0] == obs.shape[0] + 1
def test_no_spec():
    """Test the TimeAwareObservation wrapper on an environment that has no spec.

    ``CartPoleEnv`` is instantiated directly rather than through ``gym.make``,
    so the wrapper cannot read ``max_episode_steps`` from a spec and has to
    find a ``TimeLimit`` wrapper in the environment stack instead.
    """
    env = CartPoleEnv()

    # With neither a spec nor a TimeLimit wrapper there is no way to learn the
    # episode length, so constructing the wrapper must fail.
    with pytest.raises(
        ValueError,
        match=re.escape(
            "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
        ),
    ):
        TimeAwareObservation(env)

    # Once a TimeLimit is present, its step limit is picked up silently.
    env = TimeLimit(env, 100)
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Record every warning; without this, a warning that an active filter
        # hides or that was already emitted once from the same location would
        # be dropped and the "no warnings" assertion would pass vacuously.
        warnings.simplefilter("always")
        env = TimeAwareObservation(env)

    assert env.max_timesteps == 100
    assert len(caught_warnings) == 0
def test_no_flatten(): def test_no_flatten():
"""Test the TimeAwareObservation wrapper without flattening the space.""" """Test the TimeAwareObservation wrapper without flattening the space."""
env = GenericTestEnv(observation_space=Box(0, 1)) env = GenericTestEnv(observation_space=Box(0, 1))