Files
Gymnasium/tests/utils/test_env_checker.py

297 lines
11 KiB
Python
Raw Normal View History

"""Tests that the `env_checker` runs as expected and that all of its errors can be triggered."""
2024-06-10 17:07:47 +01:00
import re
import warnings
from collections.abc import Callable
Seeding update (#2422) * Ditch most of the seeding.py and replace np_random with the numpy default_rng. Let's see if tests pass * Updated a bunch of RNG calls from the RandomState API to Generator API * black; didn't expect that, did ya? * Undo a typo * blaaack * More typo fixes * Fixed setting/getting state in multidiscrete spaces * Fix typo, fix a test to work with the new sampling * Correctly (?) pass the randomly generated seed if np_random is called with None as seed * Convert the Discrete sample to a python int (as opposed to np.int64) * Remove some redundant imports * First version of the compatibility layer for old-style RNG. Mainly to trigger tests. * Removed redundant f-strings * Style fixes, removing unused imports * Try to make tests pass by removing atari from the dockerfile * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings. * black; didn't expect that, didya? * Rename the reset parameter in VecEnvs back to `seed` * Updated tests to use the new seeding method * Removed a bunch of old `seed` calls. Fixed a bug in AsyncVectorEnv * Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset) * Add explicit seed to wrappers reset * Remove an accidental return * Re-add some legacy functions with a warning. * Use deprecation instead of regular warnings for the newly deprecated methods/functions
2021-12-08 22:14:15 +01:00
import numpy as np
import pytest
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium import spaces
from gymnasium.core import ObsType
from gymnasium.utils.env_checker import (
check_env,
check_reset_options,
check_reset_return_info_deprecation,
check_reset_return_type,
2024-02-29 17:03:38 +02:00
check_reset_seed_determinism,
check_seed_deprecation,
2024-02-29 17:03:38 +02:00
check_step_determinism,
)
from tests.testing_env import GenericTestEnv
# Warning texts that `test_no_error_warnings` tolerates from `check_env`.
# Every entry is a plain message wrapped in the "\x1b[33mWARN: " prefix and
# "\x1b[0m" suffix so that it can be compared directly with `str(warning.message)`.
CHECK_ENV_IGNORE_WARNINGS = [
    "\x1b[33mWARN: " + text + "\x1b[0m"
    for text in (
        "A Box observation space minimum value is -infinity. This is probably too low.",
        "A Box observation space maximum value is infinity. This is probably too high.",
        "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
    )
]
def _no_error_warnings_envs():
    """Yield the environments used to parametrize `test_no_error_warnings`.

    Yields, in order:
      1. the unwrapped ``CartPole-v1`` and ``MountainCar-v0`` environments,
         created with ``disable_env_checker=True``;
      2. a `GenericTestEnv` with a `Dict` observation space;
      3. a `GenericTestEnv` with a `Tuple` observation space;
      4. a `GenericTestEnv` with a `Dict` observation space that nests a `Tuple`.
    """
    for env_id in ("CartPole-v1", "MountainCar-v0"):
        yield gym.make(env_id, disable_env_checker=True).unwrapped

    # Two leaf spaces reused by all composite observation spaces below.
    discrete = spaces.Discrete(10)
    unit_box = spaces.Box(np.zeros(2, np.float32), np.ones(2, np.float32))

    flat_dict = spaces.Dict(a=discrete, b=unit_box)
    yield GenericTestEnv(observation_space=flat_dict)

    flat_tuple = spaces.Tuple([discrete, unit_box])
    yield GenericTestEnv(observation_space=flat_tuple)

    nested_dict = spaces.Dict(a=spaces.Tuple([discrete, unit_box]), b=unit_box)
    yield GenericTestEnv(observation_space=nested_dict)
@pytest.mark.parametrize("env", _no_error_warnings_envs())
def test_no_error_warnings(env):
    """Check that `check_env` emits no warning outside `CHECK_ENV_IGNORE_WARNINGS`.

    A full version of this test with all gymnasium envs is run in tests/envs/test_envs.py.
    """
    with warnings.catch_warnings(record=True) as recorded:
        check_env(env)

    # Drop the warnings that are explicitly tolerated; whatever is left is a failure
    # and is reported through the assertion message.
    unexpected_messages = [
        entry.message
        for entry in recorded
        if str(entry.message) not in CHECK_ENV_IGNORE_WARNINGS
    ]
    assert len(unexpected_messages) == 0, unexpected_messages
Render API (#2671) * add pygame GUI for frozen_lake.py env * add new line at EOF * pre-commit reformat * improve graphics * new images and dynamic window size * darker tile borders and fix ICC profile * pre-commit hook * adjust elf and stool size * Update frozen_lake.py * reformat * fix #2600 * #2600 * add rgb_array support * reformat * test render api change on FrozenLake * add render support for reset on frozenlake * add clock on pygame render * new render api for blackjack * new render api for cliffwalking * new render api for Env class * update reset method, lunar and Env * fix wrapper * fix reset lunar * new render api for box2d envs * new render api for mujoco envs * fix bug * new render api for classic control envs * fix tests * add render_mode None for CartPole * new render api for test fake envs * pre-commit hook * fix FrozenLake * fix FrozenLake * more render_mode to super - frozenlake * remove kwargs from frozen_lake new * pre-commit hook * add deprecated render method * add backwards compatibility * fix test * add _render * move pygame.init() (avoid pygame dependency on init) * fix pygame dependencies * remove collect_render() maintain multi-behaviours .render() * add type hints * fix renderer * don't call .render() with None * improve docstring * add single_rgb_array to all envs * remove None from metadata["render_modes"] * add type hints to test_env_checkers * fix lint * add comments to renderer * add comments to single_depth_array and single_state_pixels * reformat * add deprecation warnings and env.render_mode declaration * fix lint * reformat * fix tests * add docs * fix car racing determinism * remove warning test envs, customizable modes on renderer * remove commments and add todo for env_checker * fix car racing * replace render mode check with assert * update new mujoco * reformat * reformat * change metaclass definition * fix tests * implement mark suggestions (test, docs, sets) * check_render Co-authored-by: J K Terry <jkterry0@gmail.com>
2022-06-08 00:20:56 +02:00
def _no_super_reset(self, seed=None, options=None):
    """Reset implementation that never forwards `seed` to the parent `reset`.

    Used as a faulty `reset_func` in `test_check_reset_seed_determinism`.
    Both `seed` and `options` are accepted but ignored.

    Returns:
        A ``(observation, info)`` tuple where `info` is an empty dict.
    """
    # Draw once from the env's generator (original note: "generates a new prng").
    self.np_random.random()

    # Pin the observation space's seed so the sampled observation is deterministic.
    obs_space = self.observation_space
    obs_space.seed(0)
    observation = obs_space.sample()
    return observation, {}
def _super_reset_fixed(self, seed=None, options=None):
    """Reset implementation that ignores the caller's seed and always uses ``1``.

    Used as a faulty `reset_func` in `test_check_reset_seed_determinism`.

    Returns:
        A ``(observation, info)`` tuple where `info` is an empty dict.
    """
    # The `seed` argument is deliberately discarded in favour of a constant.
    super(GenericTestEnv, self).reset(seed=1)

    # Make the observation space sample from the env's generator so the output
    # is deterministic.
    self.observation_space._np_random = self.np_random
    observation = self.observation_space.sample()
    return observation, {}
def _reset_default_seed(self: GenericTestEnv, seed=23, options=None):
    """Reset implementation whose `seed` parameter defaults to ``23`` instead of ``None``.

    The seed is forwarded to the parent `reset`; the non-``None`` default is the
    only irregularity. Used as the `UserWarning` case in
    `test_check_reset_seed_determinism`.

    Returns:
        A ``(observation, info)`` tuple where `info` is an empty dict.
    """
    super(GenericTestEnv, self).reset(seed=seed)

    # Share the env's generator with the observation space before sampling.
    env_rng = self.np_random
    self.observation_space._np_random = env_rng  # pyright: ignore [reportPrivateUsage]
    observation = self.observation_space.sample()
    return observation, {}
@pytest.mark.parametrize(
    "test,func,message",
    (
        # `reset` accepts neither `seed` nor `**kwargs`.
        (
            gym.error.Error,
            lambda self: (self.observation_space.sample(), {}),
            "The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
        ),
        # `reset` accepts a seed but does nothing with it.
        (
            AssertionError,
            lambda self, seed, *_: (self.observation_space.sample(), {}),
            "Expects the random number generator to have been generated given a seed was passed to reset. Most likely the environment reset function does not call `super().reset(seed=seed)`.",
        ),
        (
            AssertionError,
            _no_super_reset,
            "Most likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`.",
        ),
        (
            AssertionError,
            _super_reset_fixed,
            "Most likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`.",
        ),
        # Only a warning is expected here, not an exception.
        (
            UserWarning,
            _reset_default_seed,
            "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. Actual default: 23",
        ),
    ),
)
def test_check_reset_seed_determinism(test, func: Callable, message: str):
    """Check that `check_reset_seed_determinism` reports each faulty `reset` function.

    Args:
        test: Expected exception type, or `UserWarning` when a warning is expected
            instead of an exception.
        func: Function installed as the `reset_func` of a `GenericTestEnv`.
        message: Exact text expected in the exception or warning.
    """
    if test is not UserWarning:
        # Exception case: the whole message must match exactly.
        with pytest.raises(test, match=f"^{re.escape(message)}$"):
            check_reset_seed_determinism(GenericTestEnv(reset_func=func))
        return

    # Warning case: the message is expected inside the "\x1b[33mWARN: ...\x1b[0m" envelope.
    with pytest.warns(
        UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
    ):
        check_reset_seed_determinism(GenericTestEnv(reset_func=func))
def _deprecated_return_info(
    self, return_info: bool = False
) -> tuple[ObsType, dict] | ObsType:
    """Simulate a `reset` that still exposes the deprecated `return_info` argument.

    Returns:
        ``(observation, {})`` when `return_info` is truthy, otherwise only the
        sampled observation.
    """
    observation = self.observation_space.sample()
    return (observation, {}) if return_info else observation
def _reset_var_keyword_kwargs(self, kwargs):
    """Reset stand-in that takes one positional parameter literally named `kwargs`.

    Note that `kwargs` is an ordinary parameter, not a ``**kwargs`` var-keyword;
    it is accepted and ignored.

    Returns:
        A ``(observation, info)`` tuple where `info` is an empty dict.
    """
    observation = self.observation_space.sample()
    return observation, {}
def _reset_return_info_type(self, seed=None, options=None):
    """Reset stand-in that returns a `list` rather than a `tuple`.

    Used to make sure `env_checker` verifies that the value returned by
    `env.reset()` is a `tuple`. `seed` and `options` are ignored.
    """
    observation = self.observation_space.sample()
    info = {}
    return [observation, info]
def _reset_return_info_length(self, seed=None, options=None):
return 1, 2, 3
def _return_info_obs_outside(self, seed=None, options=None):
    """Reset stand-in whose observation is a sample shifted by the space's `high`.

    Used by `test_check_reset_return_type` as the case where the first returned
    element is not within the observation space. `seed` and `options` are ignored.

    Returns:
        A ``(observation, info)`` tuple where `info` is an empty dict.
    """
    shifted_observation = (
        self.observation_space.sample() + self.observation_space.high
    )
    return shifted_observation, {}
def _return_info_not_dict(self, seed=None, options=None):
    """Reset stand-in whose second returned element is a `list`, not a `dict`.

    `seed` and `options` are ignored.

    Returns:
        A 2-tuple ``(observation, ["key", "value"])``.
    """
    observation = self.observation_space.sample()
    info_as_list = ["key", "value"]
    return observation, info_as_list
@pytest.mark.parametrize(
    "test,func,message",
    (
        # Result is a list rather than a tuple.
        (
            AssertionError,
            _reset_return_info_type,
            "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
        ),
        # Result is a tuple of the wrong length.
        (
            AssertionError,
            _reset_return_info_length,
            "Calling the reset method did not return a 2-tuple, actual length: 3",
        ),
        # Observation lies outside the observation space.
        (
            AssertionError,
            _return_info_obs_outside,
            "The first element returned by `env.reset()` is not within the observation space.",
        ),
        # Info element is not a dictionary.
        (
            AssertionError,
            _return_info_not_dict,
            "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
        ),
    ),
)
def test_check_reset_return_type(test, func: Callable, message: str):
    """Check that `check_reset_return_type` rejects malformed `env.reset()` results.

    Args:
        test: Expected exception type.
        func: Function installed as the `reset_func` of a `GenericTestEnv`.
        message: Exact text expected in the raised exception.
    """
    # Anchored on both ends so the whole message has to match.
    with pytest.raises(test, match=f"^{re.escape(message)}$"):
        check_reset_return_type(GenericTestEnv(reset_func=func))
@pytest.mark.parametrize(
    "test,func,message",
    (
        (
            UserWarning,
            _deprecated_return_info,
            # NOTE(review): the adjacent literals join without separating spaces
            # ("`reset`should", "dictionarycontaining"). Presumably this mirrors the
            # text emitted by `env_checker` verbatim -- confirm before changing it.
            "`return_info` is deprecated as an optional argument to `reset`. `reset`"
            "should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
            "containing additional information.",
        ),
    ),
)
def test_check_reset_return_info_deprecation(test, func: Callable, message: str):
    """Check that a `reset` exposing `return_info` triggers the deprecation warning.

    Args:
        test: Expected warning category.
        func: Function installed as the `reset_func` of a `GenericTestEnv`.
        message: Exact warning text, without the "WARN" envelope.
    """
    # The message is expected inside the "\x1b[33mWARN: ...\x1b[0m" envelope.
    with pytest.warns(
        test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
    ):
        check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))
def test_check_seed_deprecation():
    """Check that `check_seed_deprecation()` warns only when `env.seed` is callable.

    A callable `seed` attribute must produce the deprecation warning; a
    non-callable `seed` attribute, or no `seed` attribute at all, must not
    produce any warning.
    """
    message = (
        "Official support for the `seed` function is dropped. Standard practice is to reset gymnasium environments using `env.reset(seed=<desired seed>)`"
    )
    env = GenericTestEnv()

    def seed(seed):
        return

    # Callable attribute: the warning is expected inside the "WARN" envelope.
    with pytest.warns(
        UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
    ):
        env.seed = seed
        assert callable(env.seed)
        check_seed_deprecation(env)

    with warnings.catch_warnings(record=True) as caught_warnings:
        # Non-callable values assigned to `seed` must stay silent ...
        for non_callable in ([], 123):
            env.seed = non_callable
            check_seed_deprecation(env)

        # ... and so must an env without the attribute.
        del env.seed
        check_seed_deprecation(env)

    assert len(caught_warnings) == 0
def test_check_reset_options():
    """Check that `check_reset_options` rejects a `reset` without `options` or `**kwargs`."""
    # Not anchored: the expected text only has to occur somewhere in the error message.
    expected_error = re.escape(
        "The `reset` method does not provide an `options` or `**kwargs` keyword argument"
    )
    with pytest.raises(gym.error.Error, match=expected_error):
        # The lambda takes only `self`, so neither `options` nor `**kwargs` is available.
        check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))
2024-02-29 17:03:38 +02:00
@pytest.mark.parametrize(
    "test,step_func,message",
    (
        # Observation drawn from numpy's global random state.
        (
            AssertionError,
            lambda self, action: (np.random.normal(), 0, False, False, {}),
            "Deterministic step observations are not equivalent for the same seed and action",
        ),
        # Reward drawn from numpy's global random state.
        (
            AssertionError,
            lambda self, action: (0, np.random.normal(), False, False, {}),
            "Deterministic step rewards are not equivalent for the same seed and action",
        ),
        # Info value drawn from numpy's global random state.
        (
            AssertionError,
            lambda self, action: (0, 0, False, False, {"value": np.random.normal()}),
            "Deterministic step info are not equivalent for the same seed and action",
        ),
    ),
)
def test_check_step_determinism(test, step_func, message: str):
    """Check that `check_step_determinism` detects a non-deterministic `step`.

    Args:
        test: Expected exception type.
        step_func: Function installed as the `step_func` of a `GenericTestEnv`; each
            case makes exactly one returned component come from `np.random.normal()`.
        message: Exact text expected in the raised exception.
    """
    # Anchored on both ends so the whole message has to match.
    with pytest.raises(test, match=f"^{re.escape(message)}$"):
        check_step_determinism(GenericTestEnv(step_func=step_func))
@pytest.mark.parametrize(
    "env, error_type, message",
    (
        # Not an environment at all: a plain string.
        (
            "Error",
            TypeError,
            "The environment must inherit from the gymnasium.Env class, actual class: <class 'str'>. See https://gymnasium.farama.org/introduction/create_custom_env/ for more info.",
        ),
        # Environment built with `action_space=None`.
        (
            GenericTestEnv(action_space=None),
            AttributeError,
            "The environment must specify an action space. See https://gymnasium.farama.org/introduction/create_custom_env/ for more info.",
        ),
        # Environment built with `observation_space=None`.
        (
            GenericTestEnv(observation_space=None),
            AttributeError,
            "The environment must specify an observation space. See https://gymnasium.farama.org/introduction/create_custom_env/ for more info.",
        ),
    ),
)
def test_check_env(env: gym.Env, error_type, message: str):
    """Check that `check_env` raises the expected error for invalid environments.

    Args:
        env: Object handed to `check_env`; the first case passes a `str` on purpose.
        error_type: Expected exception type.
        message: Exact text expected in the raised exception.
    """
    # Anchored on both ends so the whole message has to match.
    anchored_pattern = f"^{re.escape(message)}$"
    with pytest.raises(error_type, match=anchored_pattern):
        check_env(env)