Files
Gymnasium/tests/utils/test_env_checker.py
2022-12-05 19:14:56 +00:00

263 lines
9.2 KiB
Python

"""Tests that the `env_checker` runs as expected and that each of its errors and warnings can be triggered."""
import re
import warnings
from typing import Tuple, Union
import numpy as np
import pytest
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ObsType
from gymnasium.utils.env_checker import (
check_env,
check_reset_options,
check_reset_return_info_deprecation,
check_reset_return_type,
check_reset_seed,
check_seed_deprecation,
)
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True).unwrapped,
        gym.make("MountainCar-v0", disable_env_checker=True).unwrapped,
        GenericTestEnv(
            observation_space=spaces.Dict(
                a=spaces.Discrete(10), b=spaces.Box(np.zeros(2), np.ones(2))
            )
        ),
        GenericTestEnv(
            observation_space=spaces.Tuple(
                [spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
            )
        ),
        GenericTestEnv(
            observation_space=spaces.Dict(
                a=spaces.Tuple(
                    [spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
                ),
                b=spaces.Box(np.zeros(2), np.ones(2)),
            )
        ),
    ],
)
def test_no_error_warnings(env):
    """Check that `check_env` emits no warnings for a few well-formed environments.

    A full version of this test with all gymnasium envs is run in tests/envs/test_envs.py.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        # `catch_warnings` keeps whatever filters are already active. Force
        # "always" so an inherited "ignore" filter cannot hide a warning (making
        # this test pass spuriously) and an "error" filter cannot turn one into
        # an exception instead of a recorded entry.
        warnings.simplefilter("always")
        check_env(env)
    assert len(caught_warnings) == 0, [warning.message for warning in caught_warnings]
def _no_super_reset(self, seed=None, options=None):
self.np_random.random() # generates a new prng
# generate seed deterministic result
self.observation_space.seed(0)
return self.observation_space.sample(), {}
def _super_reset_fixed(self, seed=None, options=None):
    """Fake `reset` that calls the parent `reset`, but always with ``seed=1``.

    The `seed` argument is deliberately ignored. The parent `reset` is always
    seeded with the same constant, so different requested seeds presumably
    yield the same generator.

    Returns:
        A ``(observation, {})`` pair sampled from the env's generator.
    """
    # Skip GenericTestEnv.reset and seed its parent with a fixed value instead of `seed`.
    super(GenericTestEnv, self).reset(seed=1)
    generator = self.np_random
    space = self.observation_space
    # Sample through the env generator so the observation is deterministic.
    space._np_random = generator  # pyright: ignore [reportPrivateUsage]
    return space.sample(), {}
def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
    """Fake `reset` whose `seed` parameter defaults to the string ``"Error"`` rather than ``None``.

    Whatever seed is received is forwarded to the parent `reset`. The
    observation space is then pointed at the env's generator before sampling.

    Returns:
        A ``(observation, {})`` pair.
    """
    super(GenericTestEnv, self).reset(seed=seed)
    generator = self.np_random
    space = self.observation_space
    space._np_random = generator  # pyright: ignore [reportPrivateUsage]
    return space.sample(), {}
@pytest.mark.parametrize(
    ("test", "func", "message"),
    [
        (
            gym.error.Error,
            lambda self: (self.observation_space.sample(), {}),
            "The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
        ),
        (
            AssertionError,
            lambda self, seed, *_: (self.observation_space.sample(), {}),
            "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
        ),
        (
            AssertionError,
            _no_super_reset,
            "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`.",
        ),
        (
            AssertionError,
            _super_reset_fixed,
            "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`.",
        ),
        (
            UserWarning,
            _reset_default_seed,
            "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. Actual default: Error",
        ),
    ],
)
def test_check_reset_seed(test, func: callable, message: str):
    """Check that `check_reset_seed` flags each kind of badly seeded `reset`.

    For the `UserWarning` case, `message` is expected as a warning wrapped in
    ``\\x1b[33mWARN: ... \\x1b[0m`` (presumably the colourised gymnasium logger
    output). For every other case, `test` is an exception type that must be
    raised with exactly `message`.
    """
    escaped = re.escape(message)
    if test is UserWarning:
        expectation = pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {escaped}\\x1b\\[0m$"
        )
    else:
        expectation = pytest.raises(test, match=f"^{escaped}$")
    with expectation:
        check_reset_seed(GenericTestEnv(reset_func=func))
def _deprecated_return_info(
    self, return_info: bool = False
) -> Union[Tuple[ObsType, dict], ObsType]:
    """Mimic a legacy `reset` whose signature still has the deprecated `return_info` flag.

    Returns:
        ``(obs, {})`` when `return_info` is truthy, otherwise the bare observation.
    """
    obs = self.observation_space.sample()
    return (obs, {}) if return_info else obs
def _reset_var_keyword_kwargs(self, kwargs):
return self.observation_space.sample(), {}
def _reset_return_info_type(self, seed=None, options=None):
"""Returns a `list` instead of a `tuple`. This function is used to make sure `env_checker` correctly
checks that the return type of `env.reset()` is a `tuple`"""
return [self.observation_space.sample(), {}]
def _reset_return_info_length(self, seed=None, options=None):
return 1, 2, 3
def _return_info_obs_outside(self, seed=None, options=None):
return self.observation_space.sample() + self.observation_space.high, {}
def _return_info_not_dict(self, seed=None, options=None):
return self.observation_space.sample(), ["key", "value"]
@pytest.mark.parametrize(
    ("test", "func", "message"),
    [
        (
            AssertionError,
            _reset_return_info_type,
            "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
        ),
        (
            AssertionError,
            _reset_return_info_length,
            "Calling the reset method did not return a 2-tuple, actual length: 3",
        ),
        (
            AssertionError,
            _return_info_obs_outside,
            "The first element returned by `env.reset()` is not within the observation space.",
        ),
        (
            AssertionError,
            _return_info_not_dict,
            "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
        ),
    ],
)
def test_check_reset_return_type(test, func: callable, message: str):
    """Check that `check_reset_return_type` rejects each malformed `reset` return value.

    Each case must raise `test` with exactly `message`.
    """
    exact = f"^{re.escape(message)}$"
    with pytest.raises(test, match=exact):
        check_reset_return_type(GenericTestEnv(reset_func=func))
@pytest.mark.parametrize(
    ("test", "func", "message"),
    [
        (
            UserWarning,
            _deprecated_return_info,
            # NOTE(review): these adjacent literals concatenate with no spaces
            # ("`reset`should", "dictionarycontaining"). Presumably this mirrors
            # the checker's message verbatim, since the match is anchored, so
            # fix both together if at all.
            "`return_info` is deprecated as an optional argument to `reset`. `reset`"
            "should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
            "containing additional information.",
        ),
    ],
)
def test_check_reset_return_info_deprecation(test, func: callable, message: str):
    """Check that a `reset` still accepting `return_info` triggers the deprecation warning."""
    colourised = f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
    with pytest.warns(test, match=colourised):
        check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))
def test_check_seed_deprecation():
    """Tests that `check_seed_deprecation()` throws a warning if `env.seed()` has not been removed.

    A callable `seed` attribute must produce the deprecation warning. A
    non-callable `seed` attribute (a list or an int), or deleting the instance
    attribute, must produce no warnings at all.
    """
    message = """Official support for the `seed` function is dropped. Standard practice is to reset gymnasium environments using `env.reset(seed=<desired seed>)`"""
    env = GenericTestEnv()

    def seed(seed):
        return

    with pytest.warns(
        UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
    ):
        env.seed = seed
        assert callable(env.seed)
        check_seed_deprecation(env)

    with warnings.catch_warnings(record=True) as caught_warnings:
        # Do not inherit outer filters: an "ignore" filter would let this pass
        # spuriously, and an "error" filter would raise instead of recording.
        warnings.simplefilter("always")
        env.seed = []
        check_seed_deprecation(env)
        env.seed = 123
        check_seed_deprecation(env)
        del env.seed
        check_seed_deprecation(env)
    assert len(caught_warnings) == 0, [warning.message for warning in caught_warnings]
def test_check_reset_options():
    """Check that `check_reset_options` rejects a `reset` with neither `options` nor `**kwargs`.

    The match is an unanchored escaped prefix, so trailing text in the real
    error message is tolerated.
    """
    expected = re.escape(
        "The `reset` method does not provide an `options` or `**kwargs` keyword argument"
    )
    with pytest.raises(gym.error.Error, match=expected):
        check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))
@pytest.mark.parametrize(
    ("env", "message"),
    [
        (
            "Error",
            "The environment must inherit from the gymnasium.Env class. See https://gymnasium.farama.org/content/environment_creation/ for more info.",
        ),
        (
            GenericTestEnv(action_space=None),
            "The environment must specify an action space. See https://gymnasium.farama.org/content/environment_creation/ for more info.",
        ),
        (
            GenericTestEnv(observation_space=None),
            "The environment must specify an observation space. See https://gymnasium.farama.org/content/environment_creation/ for more info.",
        ),
    ],
)
def test_check_env(env: gym.Env, message: str):
    """Check that `check_env` rejects a non-`Env` object and envs missing a required space.

    The first case deliberately passes a plain string instead of an environment.
    Each case must raise `AssertionError` with exactly `message`.
    """
    exact = f"^{re.escape(message)}$"
    with pytest.raises(AssertionError, match=exact):
        check_env(env)