mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
Add TimeAwareObservation support for environments without a spec (#1289)
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
# See https://pre-commit.com/hooks.html for more hooks
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.6.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-symlinks
|
- id: check-symlinks
|
||||||
- id: destroyed-symlinks
|
- id: destroyed-symlinks
|
||||||
@@ -35,7 +35,7 @@ repos:
|
|||||||
- --show-source
|
- --show-source
|
||||||
- --statistics
|
- --statistics
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v3.17.0
|
rev: v3.19.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
args: ["--py38-plus"]
|
args: ["--py38-plus"]
|
||||||
@@ -44,7 +44,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/python/black
|
- repo: https://github.com/python/black
|
||||||
rev: 24.8.0
|
rev: 24.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
- repo: https://github.com/pycqa/pydocstyle
|
- repo: https://github.com/pycqa/pydocstyle
|
||||||
|
@@ -198,9 +198,18 @@ class TimeAwareObservation(
|
|||||||
if env.spec is not None and env.spec.max_episode_steps is not None:
|
if env.spec is not None and env.spec.max_episode_steps is not None:
|
||||||
self.max_timesteps = env.spec.max_episode_steps
|
self.max_timesteps = env.spec.max_episode_steps
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
# else we need to loop through the environment stack to check if a `TimeLimit` wrapper exists
|
||||||
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
|
wrapped_env = env
|
||||||
)
|
while isinstance(wrapped_env, gym.Wrapper):
|
||||||
|
if isinstance(wrapped_env, gym.wrappers.TimeLimit):
|
||||||
|
self.max_timesteps = wrapped_env._max_episode_steps
|
||||||
|
break
|
||||||
|
wrapped_env = wrapped_env.env
|
||||||
|
|
||||||
|
if not isinstance(wrapped_env, gym.wrappers.TimeLimit):
|
||||||
|
raise ValueError(
|
||||||
|
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
|
||||||
|
)
|
||||||
|
|
||||||
self.timesteps: int = 0
|
self.timesteps: int = 0
|
||||||
|
|
||||||
|
@@ -1,12 +1,16 @@
|
|||||||
"""Test suite for TimeAwareObservation wrapper."""
|
"""Test suite for TimeAwareObservation wrapper."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
from gymnasium.envs.classic_control import CartPoleEnv
|
||||||
from gymnasium.spaces import Box, Dict, Tuple
|
from gymnasium.spaces import Box, Dict, Tuple
|
||||||
from gymnasium.wrappers import TimeAwareObservation
|
from gymnasium.wrappers import TimeAwareObservation, TimeLimit
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
@@ -41,6 +45,25 @@ def test_default(env_id):
|
|||||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_spec():
    """Test the TimeAwareObservation wrapper on a directly instantiated environment.

    `CartPoleEnv()` is constructed directly rather than through `gym.make`
    (presumably leaving it without a spec that carries `max_episode_steps`),
    so the wrapper must raise a `ValueError` unless a `TimeLimit` wrapper is
    present in the wrapper stack. Once the environment is wrapped in
    `TimeLimit(env, 100)`, the wrapper must report `max_timesteps == 100`
    and must not emit any warning while being constructed.
    """
    env = CartPoleEnv()

    # No TimeLimit wrapper yet: constructing the wrapper is expected to fail.
    with pytest.raises(
        ValueError,
        match=re.escape(
            "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
        ),
    ):
        TimeAwareObservation(env)

    env = TimeLimit(env, 100)
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Force every warning to be recorded; otherwise a warning already
        # silenced by an inherited filter or the once-per-location registry
        # would be dropped and the emptiness check below could pass vacuously.
        warnings.simplefilter("always")
        env = TimeAwareObservation(env)

    assert env.max_timesteps == 100
    assert len(caught_warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_no_flatten():
|
def test_no_flatten():
|
||||||
"""Test the TimeAwareObservation wrapper without flattening the space."""
|
"""Test the TimeAwareObservation wrapper without flattening the space."""
|
||||||
env = GenericTestEnv(observation_space=Box(0, 1))
|
env = GenericTestEnv(observation_space=Box(0, 1))
|
||||||
|
Reference in New Issue
Block a user