Add TimeAwareObservation support for environments without a spec (#1289)

This commit is contained in:
Mark Towers
2025-01-08 15:30:34 +00:00
committed by GitHub
parent fc74bb8fc0
commit e6e3521cfe
3 changed files with 39 additions and 7 deletions

View File

@@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
@@ -35,7 +35,7 @@ repos:
- --show-source
- --statistics
- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
rev: v3.19.1
hooks:
- id: pyupgrade
args: ["--py38-plus"]
@@ -44,7 +44,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/python/black
rev: 24.8.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/pydocstyle

View File

@@ -198,9 +198,18 @@ class TimeAwareObservation(
if env.spec is not None and env.spec.max_episode_steps is not None:
self.max_timesteps = env.spec.max_episode_steps
else:
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)
# else we need to loop through the environment stack to check if a `TimeLimit` wrapper exists
wrapped_env = env
while isinstance(wrapped_env, gym.Wrapper):
if isinstance(wrapped_env, gym.wrappers.TimeLimit):
self.max_timesteps = wrapped_env._max_episode_steps
break
wrapped_env = wrapped_env.env
if not isinstance(wrapped_env, gym.wrappers.TimeLimit):
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)
self.timesteps: int = 0

View File

@@ -1,12 +1,16 @@
"""Test suite for TimeAwareObservation wrapper."""
import re
import warnings
import numpy as np
import pytest
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.spaces import Box, Dict, Tuple
from gymnasium.wrappers import TimeAwareObservation
from gymnasium.wrappers import TimeAwareObservation, TimeLimit
from tests.testing_env import GenericTestEnv
@@ -41,6 +45,25 @@ def test_default(env_id):
assert wrapped_obs.shape[0] == obs.shape[0] + 1
def test_no_spec():
env = CartPoleEnv()
with pytest.raises(
ValueError,
match=re.escape(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
),
):
TimeAwareObservation(env)
env = TimeLimit(env, 100)
with warnings.catch_warnings(record=True) as caught_warnings:
env = TimeAwareObservation(env)
assert env.max_timesteps == 100
assert len(caught_warnings) == 0
def test_no_flatten():
"""Test the TimeAwareObservation wrapper without flattening the space."""
env = GenericTestEnv(observation_space=Box(0, 1))