Fix outdated docs for TimeLimit max_episode_steps and add validation. (#1149)

Co-authored-by: troiganto <troiganto@proton.me>
This commit is contained in:
Übertreiber
2024-08-26 18:52:50 +02:00
committed by GitHub
parent 94384bbe93
commit ad8734d89b
2 changed files with 25 additions and 1 deletions

View File

@@ -96,8 +96,11 @@ class TimeLimit(
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
max_episode_steps: the environment step after which the episode is truncated (``elapsed >= max_episode_steps``)
"""
assert (
isinstance(max_episode_steps, int) and max_episode_steps > 0
), f"Expect the `max_episode_steps` to be positive, actually: {max_episode_steps}"
gym.utils.RecordConstructorArgs.__init__(
self, max_episode_steps=max_episode_steps
)

View File

@@ -57,3 +57,24 @@ def test_termination_on_last_step(double_wrap):
_, _, terminated, truncated, _ = env.step(env.action_space.sample())
assert terminated is True
assert truncated is True
def test_max_episode_steps():
env = gym.make("CartPole-v1", disable_env_checker=True)
assert env.spec.max_episode_steps == 500
assert TimeLimit(env, max_episode_steps=10).spec.max_episode_steps == 10
with pytest.raises(
AssertionError,
match="Expect the `max_episode_steps` to be positive, actually: -1",
):
TimeLimit(env, max_episode_steps=-1)
with pytest.raises(
AssertionError,
match="Expect the `max_episode_steps` to be positive, actually: None",
):
TimeLimit(env, max_episode_steps=None)
env.close()