Mirror of https://github.com/Farama-Foundation/Gymnasium.git (synced 2025-08-01 06:07:08 +00:00)
Add TimeAwareObservation support for environments without a spec (#1289)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.6.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-symlinks
|
||||
- id: destroyed-symlinks
|
||||
@@ -35,7 +35,7 @@ repos:
|
||||
- --show-source
|
||||
- --statistics
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.17.0
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: ["--py38-plus"]
|
||||
@@ -44,7 +44,7 @@ repos:
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/python/black
|
||||
rev: 24.8.0
|
||||
rev: 24.10.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/pydocstyle
|
||||
|
@@ -198,9 +198,18 @@ class TimeAwareObservation(
|
||||
if env.spec is not None and env.spec.max_episode_steps is not None:
|
||||
self.max_timesteps = env.spec.max_episode_steps
|
||||
else:
|
||||
raise ValueError(
|
||||
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
|
||||
)
|
||||
# else we need to loop through the environment stack to check if a `TimeLimit` wrapper exists
|
||||
wrapped_env = env
|
||||
while isinstance(wrapped_env, gym.Wrapper):
|
||||
if isinstance(wrapped_env, gym.wrappers.TimeLimit):
|
||||
self.max_timesteps = wrapped_env._max_episode_steps
|
||||
break
|
||||
wrapped_env = wrapped_env.env
|
||||
|
||||
if not isinstance(wrapped_env, gym.wrappers.TimeLimit):
|
||||
raise ValueError(
|
||||
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
|
||||
)
|
||||
|
||||
self.timesteps: int = 0
|
||||
|
||||
|
@@ -1,12 +1,16 @@
|
||||
"""Test suite for TimeAwareObservation wrapper."""
|
||||
|
||||
import re
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.envs.classic_control import CartPoleEnv
|
||||
from gymnasium.spaces import Box, Dict, Tuple
|
||||
from gymnasium.wrappers import TimeAwareObservation
|
||||
from gymnasium.wrappers import TimeAwareObservation, TimeLimit
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
@@ -41,6 +45,25 @@ def test_default(env_id):
|
||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||
|
||||
|
||||
def test_no_spec():
|
||||
env = CartPoleEnv()
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
|
||||
),
|
||||
):
|
||||
TimeAwareObservation(env)
|
||||
|
||||
env = TimeLimit(env, 100)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env = TimeAwareObservation(env)
|
||||
|
||||
assert env.max_timesteps == 100
|
||||
assert len(caught_warnings) == 0
|
||||
|
||||
|
||||
def test_no_flatten():
|
||||
"""Test the TimeAwareObservation wrapper without flattening the space."""
|
||||
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||
|
Reference in New Issue
Block a user