Files
Gymnasium/tests/wrappers/test_time_limit.py
2024-08-26 17:52:50 +01:00

81 lines
2.3 KiB
Python

"""Test suite for TimeLimit wrapper."""
import pytest
import gymnasium as gym
from gymnasium.envs.classic_control.pendulum import PendulumEnv
from gymnasium.wrappers import TimeLimit
def test_time_limit_reset_info():
    """Check that ``reset`` on a ``TimeLimit``-wrapped env returns ``(obs, info)``.

    The observation must be contained in the wrapper's observation space and
    the info must be a ``dict``.
    """
    env = gym.make("CartPole-v1", disable_env_checker=True)
    env = TimeLimit(env, 100)
    try:
        ob_space = env.observation_space
        obs, info = env.reset()
        assert ob_space.contains(obs)
        assert isinstance(info, dict)
    finally:
        # Release the environment even when one of the assertions fails.
        env.close()
@pytest.mark.parametrize("double_wrap", [False, True])
def test_time_limit_wrapper(double_wrap):
    """Check that ``TimeLimit`` truncates an episode at the configured step limit.

    Args:
        double_wrap: When true, ``TimeLimit`` is applied twice with the same
            limit; the episode must still end after exactly that many steps.
    """
    # The pendulum env does not terminate by default
    # so we are sure termination is only due to timeout
    env = PendulumEnv()
    max_episode_length = 20
    env = TimeLimit(env, max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)

    try:
        env.reset()
        terminated, truncated = False, False
        n_steps = 0
        while not (terminated or truncated):
            n_steps += 1
            # Only the termination flags matter here; obs, reward and info
            # are discarded.
            _, _, terminated, truncated, _ = env.step(env.action_space.sample())

        assert n_steps == max_episode_length
        assert truncated
        # The loop must have stopped because of the time limit alone.
        assert not terminated
    finally:
        # Release the environment even when one of the assertions fails.
        env.close()
@pytest.mark.parametrize("double_wrap", [False, True])
def test_termination_on_last_step(double_wrap):
    """Check that termination and truncation are both reported on the same step.

    The underlying env's ``step`` is replaced by a stub that always reports
    ``terminated=True``, and the time limit is set to a single step, so the
    first step is also the last one allowed by the wrapper.

    Args:
        double_wrap: When true, ``TimeLimit`` is applied twice with the same
            limit.
    """
    # Special case: termination at the last timestep
    # Truncation due to timeout also happens at the same step
    base_env = PendulumEnv()

    def patched_step(_action):
        # Bind to ``base_env`` explicitly so the stub does not depend on
        # whatever the ``env`` name below is later rebound to.
        return base_env.observation_space.sample(), 0.0, True, False, {}

    base_env.step = patched_step

    max_episode_length = 1
    env = TimeLimit(base_env, max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)

    try:
        env.reset()
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        assert terminated is True
        assert truncated is True
    finally:
        # Release the environment even when one of the assertions fails.
        env.close()
def test_max_episode_steps():
    """Verify how ``TimeLimit`` reports and validates ``max_episode_steps``.

    Covers the spec value of the env returned by ``gym.make``, the spec value
    after wrapping with an explicit limit, and the ``AssertionError`` raised
    for a negative or ``None`` limit.
    """
    base_env = gym.make("CartPole-v1", disable_env_checker=True)

    # Spec value of the env as created by ``gym.make``.
    assert base_env.spec.max_episode_steps == 500

    # An explicit limit must be reflected in the wrapper's spec.
    limited_env = TimeLimit(base_env, max_episode_steps=10)
    assert limited_env.spec.max_episode_steps == 10

    # Invalid limits are rejected at construction time.
    negative_msg = "Expect the `max_episode_steps` to be positive, actually: -1"
    with pytest.raises(AssertionError, match=negative_msg):
        TimeLimit(base_env, max_episode_steps=-1)

    none_msg = "Expect the `max_episode_steps` to be positive, actually: None"
    with pytest.raises(AssertionError, match=none_msg):
        TimeLimit(base_env, max_episode_steps=None)

    base_env.close()