Files
Gymnasium/tests/utils/test_passive_env_checker.py
Mark Towers 1bf58d8eb4 Remove warnings (#435)
Co-authored-by: Aaron Walsman <aaronwalsman@gmail.com>
2023-04-06 14:28:07 +01:00

445 lines
16 KiB
Python

import re
import warnings
from typing import Callable, Dict, Union
import numpy as np
import pytest
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils.passive_env_checker import (
check_action_space,
check_obs,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)
from tests.testing_env import GenericTestEnv
def _modify_space(space: spaces.Space, attribute: str, value):
    """Overwrite ``attribute`` on ``space`` with ``value`` and return the same object.

    Lets the parametrised cases build spaces in states that the checker should
    reject, by patching attributes (e.g. ``low``, ``n``, ``_shape``) after construction.
    """
    patched = space
    setattr(patched, attribute, value)
    return patched
@pytest.mark.parametrize(
    "test,space,message",
    [
        [
            AssertionError,
            "error",
            "observation space does not inherit from `gymnasium.spaces.Space`, actual type: <class 'str'>",
        ],
        # ===== Check box observation space ====
        [
            UserWarning,
            spaces.Box(np.zeros(5), np.zeros(5)),
            "A Box observation space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            UserWarning,
            spaces.Box(np.ones(5), np.zeros(5)),
            "A Box observation space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
            "The Box observation space shape and low shape have different shapes, low shape: (3,), box shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
            # NOTE(review): do not "fix" the repeated "have have" here on its own;
            # the match in the test body is anchored and escaped, so this text must
            # equal the checker's message exactly.
            "The Box observation space shape and high shape have have different shapes, high shape: (3,), box shape: (2,)",
        ],
        # ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
        [
            AssertionError,
            _modify_space(spaces.Discrete(5), "n", -1),
            "Discrete observation space's number of elements must be positive, actual number of elements: -1",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "nvec", np.array([2, -1])),
            "Multi-discrete observation space's all nvec elements must be greater than 0, actual nvec: [ 2 -1]",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, 1, 2)),
            "Multi-discrete observation space's shape must be equal to the nvec shape, space shape: (2, 1, 2), nvec shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
            "Multi-binary observation space's all shape elements must be greater than 0, actual shape: (2, -1)",
        ],
        [
            AssertionError,
            spaces.Tuple([]),
            "An empty Tuple observation space is not allowed.",
        ],
        [
            AssertionError,
            spaces.Dict(),
            "An empty Dict observation space is not allowed.",
        ],
    ],
)
def test_check_observation_space(test, space, message: str):
    """Tests the check observation space.

    ``UserWarning`` cases must emit a warning whose full text is ``message``
    wrapped in ``\\x1b[33mWARN: ... \\x1b[0m``. Every other case must raise
    ``test`` with exactly ``message`` and must not emit any warning.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_observation_space(space)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning regardless of active filters or the per-module
            # warning registry, so the "no warnings" check cannot pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_observation_space(space)
        assert not caught_warnings, [str(w.message) for w in caught_warnings]
@pytest.mark.parametrize(
    "test,space,message",
    [
        [
            AssertionError,
            "error",
            "action space does not inherit from `gymnasium.spaces.Space`, actual type: <class 'str'>",
        ],
        # ===== Check box action space ====
        [
            UserWarning,
            spaces.Box(np.zeros(5), np.zeros(5)),
            "A Box action space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            UserWarning,
            spaces.Box(np.ones(5), np.zeros(5)),
            "A Box action space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
            # NOTE(review): do not "fix" the repeated "have have" here on its own;
            # the match in the test body is anchored and escaped, so this text must
            # equal the checker's message exactly.
            "The Box action space shape and low shape have have different shapes, low shape: (3,), box shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
            "The Box action space shape and high shape have different shapes, high shape: (3,), box shape: (2,)",
        ],
        # ==== Other action spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
        [
            AssertionError,
            _modify_space(spaces.Discrete(5), "n", -1),
            "Discrete action space's number of elements must be positive, actual number of elements: -1",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, -1)),
            "Multi-discrete action space's shape must be equal to the nvec shape, space shape: (2, -1), nvec shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
            "Multi-binary action space's all shape elements must be greater than 0, actual shape: (2, -1)",
        ],
        [
            AssertionError,
            spaces.Tuple([]),
            "An empty Tuple action space is not allowed.",
        ],
        [AssertionError, spaces.Dict(), "An empty Dict action space is not allowed."],
    ],
)
def test_check_action_space(
    test: Union[UserWarning, type], space: spaces.Space, message: str
):
    """Tests the check action space function.

    ``UserWarning`` cases must emit a warning whose full text is ``message``
    wrapped in ``\\x1b[33mWARN: ... \\x1b[0m``. Every other case must raise
    ``test`` with exactly ``message`` and must not emit any warning.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_action_space(space)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning regardless of active filters or the per-module
            # warning registry, so the "no warnings" check cannot pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_action_space(space)
        assert not caught_warnings, [str(w.message) for w in caught_warnings]
@pytest.mark.parametrize(
    "test,obs,obs_space,message",
    [
        [
            UserWarning,
            3,
            spaces.Discrete(2),
            "The obs returned by the `testing()` method is not within the observation space.",
        ],
        [
            UserWarning,
            np.uint8(0),
            spaces.Discrete(1),
            "The obs returned by the `testing()` method should be an int or np.int64, actual type: <class 'numpy.uint8'>",
        ],
        [
            UserWarning,
            [0, 1],
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            "The obs returned by the `testing()` method was expecting a tuple, actual type: <class 'list'>",
        ],
        [
            AssertionError,
            (1, 2, 3),
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            "The obs returned by the `testing()` method length is not same as the observation space length, obs length: 3, space length: 2",
        ],
        [
            AssertionError,
            {1, 2, 3},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            "The obs returned by the `testing()` method must be a dict, actual type: <class 'set'>",
        ],
        [
            AssertionError,
            {"a": 1, "c": 2},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            "The obs returned by the `testing()` method observation keys is not same as the observation space keys, obs keys: ['a', 'c'], space keys: ['a', 'b']",
        ],
    ],
)
def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
    """Tests the check observations function.

    ``check_obs`` is called with the method name ``"testing"``, which is why
    every expected message refers to the `testing()` method. ``UserWarning``
    cases must emit the wrapped warning text. Every other case must raise
    ``test`` with exactly ``message`` and must not emit any warning.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_obs(obs, obs_space, "testing")
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning regardless of active filters or the per-module
            # warning registry, so the "no warnings" check cannot pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_obs(obs, obs_space, "testing")
        assert not caught_warnings, [str(w.message) for w in caught_warnings]
def _reset_no_seed(self, options=None):
return self.observation_space.sample(), {}
def _reset_seed_default(self, seed="error", options=None):
return self.observation_space.sample(), {}
def _reset_no_option(self, seed=None):
return self.observation_space.sample(), {}
def _make_reset_results(results):
def _reset_result(self, seed=None, options=None):
return results
return _reset_result
@pytest.mark.parametrize(
    "test,func,message,kwargs",
    [
        [
            DeprecationWarning,
            _reset_no_seed,
            "Current gymnasium version requires that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.",
            {},
        ],
        [
            UserWarning,
            _reset_seed_default,
            "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
            {},
        ],
        [
            DeprecationWarning,
            _reset_no_option,
            "Current gymnasium version requires that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.",
            {},
        ],
        [
            UserWarning,
            _make_reset_results([0, {}]),
            "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
            {},
        ],
        [
            AssertionError,
            _make_reset_results((np.array([0], dtype=np.float32), {1, 2})),
            "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'set'>",
            {},
        ],
    ],
)
def test_passive_env_reset_checker(test, func: Callable, message: str, kwargs: Dict):
    """Tests the passive env reset check.

    ``func`` is installed as the reset of a ``GenericTestEnv``.

    * Warning cases (``UserWarning`` or ``DeprecationWarning``) must emit a
      warning of that category whose full text is ``message`` wrapped in
      ``\\x1b[33mWARN: ... \\x1b[0m``.
    * Every other case must raise ``test`` with exactly ``message`` and must
      not emit any warning.
    """
    if test in (UserWarning, DeprecationWarning):
        # Both warning categories use the same message format; only the
        # category passed to ``pytest.warns`` differs.
        with pytest.warns(
            test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning regardless of active filters or the per-module
            # warning registry, so the "no warnings" check cannot pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
        assert not caught_warnings, [str(w.message) for w in caught_warnings]
def _modified_step(
self, obs=None, reward=0, terminated=False, truncated=False, info=None
):
if obs is None:
obs = self.observation_space.sample()
if info is None:
info = {}
if truncated is None:
return obs, reward, terminated, info
else:
return obs, reward, terminated, truncated, info
@pytest.mark.parametrize(
    "test,func,message",
    [
        [
            AssertionError,
            lambda self, _: "error",
            "Expects step result to be a tuple, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=None),
            "Expects `done` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=False),
            "Expects `terminated` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, truncated="error"),
            "Expects `truncated` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            gym.error.Error,
            lambda self, _: (1, 2, 3),
            "Expected `Env.step` to return a four or five element tuple, actual number of elements returned: 3.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward="error"),
            "The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.nan),
            "The reward is a NaN value.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.inf),
            "The reward is an inf value.",
        ],
        [
            AssertionError,
            lambda self, _: _modified_step(self, info="error"),
            "The `info` returned by `step()` must be a python dictionary, actual type: <class 'str'>",
        ],
    ],
)
def test_passive_env_step_checker(
    test: Union[UserWarning, type], func: Callable, message: str
):
    """Tests the passive env step checker.

    ``func`` is installed as the step of a ``GenericTestEnv`` and is called
    with action ``0``. ``UserWarning`` cases must emit the wrapped warning
    text. Every other case must raise ``test`` with exactly ``message`` and
    must not emit any warning.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_step_passive_checker(GenericTestEnv(step_func=func), 0)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning regardless of active filters or the per-module
            # warning registry, so the "no warnings" check cannot pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_step_passive_checker(GenericTestEnv(step_func=func), 0)
        # Report the warning texts; WarningMessage objects have an uninformative repr.
        assert not caught_warnings, [str(w.message) for w in caught_warnings]
@pytest.mark.parametrize(
    "test,env,message",
    [
        [
            UserWarning,
            GenericTestEnv(metadata={"render_modes": None}),
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`.",
        ],
        [
            UserWarning,
            GenericTestEnv(metadata={"render_modes": "Testing mode"}),
            "Expects the render_modes to be a sequence (i.e. list, tuple), actual type: <class 'str'>",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode", 1], "render_fps": 1},
            ),
            "Expects all render modes to be strings, actual types: [<class 'str'>, <class 'int'>]",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": None},
                render_mode="Testing mode",
                render_func=lambda self: 0,
            ),
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": "fps"}
            ),
            "Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: <class 'str'>",
        ],
        [
            AssertionError,
            GenericTestEnv(
                metadata={"render_modes": [], "render_fps": 30}, render_mode="Test"
            ),
            "With no render_modes, expects the Env.render_mode to be None, actual value: Test",
        ],
        [
            AssertionError,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": 30},
                render_mode="Non mode",
            ),
            "The environment was initialized successfully however with an unsupported render mode. Render mode: Non mode, modes: ['Testing mode']",
        ],
    ],
)
def test_passive_render_checker(test, env: GenericTestEnv, message: str):
    """Tests the passive render checker.

    ``UserWarning`` cases must emit the wrapped warning text. Every other case
    must raise ``test`` with exactly ``message`` and must not emit any warning.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_render_passive_checker(env)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning regardless of active filters or the per-module
            # warning registry, so the "no warnings" check cannot pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_render_passive_checker(env)
        assert not caught_warnings, [str(w.message) for w in caught_warnings]