# Mirror of https://github.com/Farama-Foundation/Gymnasium.git
# Synced 2025-08-07 16:31:46 +00:00
# 81 lines, 2.3 KiB, Python
"""Test suite for TimeLimit wrapper."""
|
|
|
|
import pytest
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.envs.classic_control.pendulum import PendulumEnv
|
|
from gymnasium.wrappers import TimeLimit
|
|
|
|
|
|
def test_time_limit_reset_info():
    """Check that ``reset`` on a ``TimeLimit``-wrapped env returns ``(obs, info)``.

    The returned observation must be contained in the wrapped environment's
    observation space and ``info`` must be a ``dict``.
    """
    env = gym.make("CartPole-v1", disable_env_checker=True)
    env = TimeLimit(env, 100)
    try:
        ob_space = env.observation_space
        obs, info = env.reset()
        assert ob_space.contains(obs)
        assert isinstance(info, dict)
    finally:
        # Close the environment even when an assertion fails, matching the
        # explicit ``env.close()`` done in ``test_max_episode_steps``.
        env.close()
|
|
|
|
|
|
@pytest.mark.parametrize("double_wrap", [False, True])
def test_time_limit_wrapper(double_wrap):
    """Check that ``TimeLimit`` truncates an episode after the configured step count.

    Args:
        double_wrap: When true, the environment is wrapped in ``TimeLimit``
            twice with the same limit; the episode must still end after
            exactly ``max_episode_length`` steps.
    """
    # The pendulum env does not terminate by default
    # so we are sure termination is only due to timeout
    env = PendulumEnv()
    max_episode_length = 20
    env = TimeLimit(env, max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)
    env.reset()

    terminated, truncated = False, False
    n_steps = 0
    while not (terminated or truncated):
        n_steps += 1
        # Fail fast instead of looping forever if the wrapper ever stops
        # truncating: the episode must be over within the configured limit.
        assert (
            n_steps <= max_episode_length
        ), f"Episode still running after {max_episode_length} steps"
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())

    assert n_steps == max_episode_length
    assert truncated

    env.close()
|
|
|
|
|
|
@pytest.mark.parametrize("double_wrap", [False, True])
def test_termination_on_last_step(double_wrap):
    """Check that termination and truncation are both reported on the final step.

    Special case: the underlying environment terminates at the very timestep
    at which the time limit is reached, so the wrapped ``step`` must return
    ``terminated=True`` and ``truncated=True`` together.

    Args:
        double_wrap: When true, the environment is wrapped in ``TimeLimit``
            twice with the same limit.
    """
    # Keep the unwrapped env under its own name: ``patched_step`` closes over
    # it, and reusing ``env`` would make the closure resolve to the TimeLimit
    # wrapper once that name is rebound below.
    base_env = PendulumEnv()

    def patched_step(_action):
        # Stub step that always signals termination (and never truncation).
        return base_env.observation_space.sample(), 0.0, True, False, {}

    base_env.step = patched_step

    max_episode_length = 1
    env = TimeLimit(base_env, max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)
    env.reset()

    _, _, terminated, truncated, _ = env.step(env.action_space.sample())
    assert terminated is True
    assert truncated is True

    env.close()
|
|
|
|
|
|
def test_max_episode_steps():
    """Check ``max_episode_steps`` handling of the ``TimeLimit`` wrapper.

    Verifies that the spec of a made ``CartPole-v1`` reports 500 steps, that
    wrapping with ``max_episode_steps=10`` is reflected in the wrapper's spec,
    and that non-positive or ``None`` limits are rejected with an
    ``AssertionError`` carrying the expected message.
    """
    cartpole = gym.make("CartPole-v1", disable_env_checker=True)

    registered_limit = cartpole.spec.max_episode_steps
    assert registered_limit == 500

    shortened = TimeLimit(cartpole, max_episode_steps=10)
    assert shortened.spec.max_episode_steps == 10

    # Each invalid limit paired with the exact message the wrapper must raise.
    rejected_limits = (
        (-1, "Expect the `max_episode_steps` to be positive, actually: -1"),
        (None, "Expect the `max_episode_steps` to be positive, actually: None"),
    )
    for bad_limit, expected_message in rejected_limits:
        with pytest.raises(AssertionError, match=expected_message):
            TimeLimit(cartpole, max_episode_steps=bad_limit)

    cartpole.close()
|