import re
import warnings
from typing import Callable, Dict, Type, Union

import numpy as np
import pytest

import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils.passive_env_checker import (
    check_action_space,
    check_obs,
    check_observation_space,
    env_render_passive_checker,
    env_reset_passive_checker,
    env_step_passive_checker,
)
from tests.testing_env import GenericTestEnv

# Expected messages that report a Python type are written as f-strings such as
# f"... actual type: {str}", which render the type's repr (e.g. "<class 'str'>").
# Every match below is anchored with ^ and $, so the complete message, including
# that repr, has to be spelled out.


def _modify_space(space: spaces.Space, attribute: str, value):
    """Set ``space.<attribute> = value`` in place and return the space.

    Used to build deliberately inconsistent spaces (e.g. a ``low`` array whose
    shape differs from the space shape) for the checker test cases below.
    """
    setattr(space, attribute, value)
    return space


def _assert_checker_outcome(
    test: Type[Exception], message: str, run: Callable[[], object]
) -> None:
    """Call ``run()`` and assert that it produces exactly the expected outcome.

    Args:
        test: ``UserWarning`` if ``run`` is expected to warn, otherwise the
            exception type it is expected to raise.
        message: The complete expected warning / exception text.
        run: Zero-argument callable that invokes the checker under test. It is
            a callable (not a value) so that any setup it does, such as building
            an environment, also happens inside the assertion context.

    For warnings, the warning text must be ``message`` wrapped in the coloured
    ``WARN:`` escape-sequence prefix and reset suffix. For errors, the exception
    text must equal ``message`` and no warning may be emitted along the way.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            run()
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning regardless of the active filters, so the
            # "no warnings" assertion cannot pass just because a filter hid one.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                run()
        assert len(caught_warnings) == 0, caught_warnings


@pytest.mark.parametrize(
    "test,space,message",
    [
        [
            AssertionError,
            "error",
            f"observation space does not inherit from `gymnasium.spaces.Space`, actual type: {str}",
        ],
        # ===== Check box observation space ====
        [
            UserWarning,
            spaces.Box(np.zeros((5, 5, 1)), 255 * np.ones((5, 5, 1)), dtype=np.int32),
            "It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: int32. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.",
        ],
        [
            UserWarning,
            spaces.Box(np.ones((2, 2, 1)), 255 * np.ones((2, 2, 1)), dtype=np.uint8),
            "It seems a Box observation space is an image but the lower and upper bounds are not [0, 255]. Actual lower bound: 1, upper bound: 255. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
        ],
        [
            UserWarning,
            spaces.Box(np.zeros((5, 5, 1)), np.ones((5, 5, 1)), dtype=np.uint8),
            "It seems a Box observation space is an image but the lower and upper bounds are not [0, 255]. Actual lower bound: 0, upper bound: 1. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
        ],
        [
            UserWarning,
            spaces.Box(np.zeros((5, 5)), np.ones((5, 5))),
            "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (5, 5)",
        ],
        [
            UserWarning,
            spaces.Box(np.zeros(5), np.zeros(5)),
            "A Box observation space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            UserWarning,
            spaces.Box(np.ones(5), np.zeros(5)),
            "A Box observation space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
            "The Box observation space shape and low shape have different shapes, low shape: (3,), box shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
            # NOTE(review): "have have" presumably mirrors the checker's own
            # message text; fix both together if it is corrected.
            "The Box observation space shape and high shape have have different shapes, high shape: (3,), box shape: (2,)",
        ],
        # ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
        [
            AssertionError,
            _modify_space(spaces.Discrete(5), "n", -1),
            "Discrete observation space's number of elements must be positive, actual number of elements: -1",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "nvec", np.array([2, -1])),
            "Multi-discrete observation space's all nvec elements must be greater than 0, actual nvec: [ 2 -1]",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, 1, 2)),
            "Multi-discrete observation space's shape must be equal to the nvec shape, space shape: (2, 1, 2), nvec shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
            "Multi-binary observation space's all shape elements must be greater than 0, actual shape: (2, -1)",
        ],
        [
            AssertionError,
            spaces.Tuple([]),
            "An empty Tuple observation space is not allowed.",
        ],
        [
            AssertionError,
            spaces.Dict(),
            "An empty Dict observation space is not allowed.",
        ],
    ],
)
def test_check_observation_space(
    test: Type[Exception], space: spaces.Space, message: str
):
    """Tests that `check_observation_space` warns or raises with the expected message."""
    _assert_checker_outcome(test, message, lambda: check_observation_space(space))


@pytest.mark.parametrize(
    "test,space,message",
    [
        [
            AssertionError,
            "error",
            f"action space does not inherit from `gymnasium.spaces.Space`, actual type: {str}",
        ],
        # ===== Check box action space ====
        [
            UserWarning,
            spaces.Box(np.zeros(5), np.zeros(5)),
            "A Box action space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            UserWarning,
            spaces.Box(np.ones(5), np.zeros(5)),
            "A Box action space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
            # NOTE(review): "have have" presumably mirrors the checker's own
            # message text; fix both together if it is corrected.
            "The Box action space shape and low shape have have different shapes, low shape: (3,), box shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
            "The Box action space shape and high shape have different shapes, high shape: (3,), box shape: (2,)",
        ],
        # ==== Other action spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
        [
            AssertionError,
            _modify_space(spaces.Discrete(5), "n", -1),
            "Discrete action space's number of elements must be positive, actual number of elements: -1",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, -1)),
            "Multi-discrete action space's shape must be equal to the nvec shape, space shape: (2, -1), nvec shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
            "Multi-binary action space's all shape elements must be greater than 0, actual shape: (2, -1)",
        ],
        [
            AssertionError,
            spaces.Tuple([]),
            "An empty Tuple action space is not allowed.",
        ],
        [AssertionError, spaces.Dict(), "An empty Dict action space is not allowed."],
    ],
)
def test_check_action_space(
    test: Type[Exception], space: spaces.Space, message: str
):
    """Tests that `check_action_space` warns or raises with the expected message."""
    _assert_checker_outcome(test, message, lambda: check_action_space(space))


@pytest.mark.parametrize(
    "test,obs,obs_space,message",
    [
        [
            UserWarning,
            3,
            spaces.Discrete(2),
            "The obs returned by the `testing()` method is not within the observation space.",
        ],
        [
            UserWarning,
            np.uint8(0),
            spaces.Discrete(1),
            f"The obs returned by the `testing()` method should be an int or np.int64, actual type: {np.uint8}",
        ],
        [
            UserWarning,
            [0, 1],
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            f"The obs returned by the `testing()` method was expecting a tuple, actual type: {list}",
        ],
        [
            AssertionError,
            (1, 2, 3),
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            "The obs returned by the `testing()` method length is not same as the observation space length, obs length: 3, space length: 2",
        ],
        [
            AssertionError,
            {1, 2, 3},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            f"The obs returned by the `testing()` method must be a dict, actual type: {set}",
        ],
        [
            AssertionError,
            {"a": 1, "c": 2},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            "The obs returned by the `testing()` method observation keys is not same as the observation space keys, obs keys: ['a', 'c'], space keys: ['a', 'b']",
        ],
    ],
)
def test_check_obs(test: Type[Exception], obs, obs_space: spaces.Space, message: str):
    """Tests that `check_obs` warns or raises with the expected message."""
    _assert_checker_outcome(
        test, message, lambda: check_obs(obs, obs_space, "testing")
    )


def _reset_no_seed(self, options=None):
    """Reset function whose signature has no `seed` parameter."""
    return self.observation_space.sample(), {}


def _reset_seed_default(self, seed="error", options=None):
    """Reset function whose `seed` parameter defaults to a non-None value."""
    return self.observation_space.sample(), {}


def _reset_no_option(self, seed=None):
    """Reset function whose signature has no `options` parameter."""
    return self.observation_space.sample(), {}


def _make_reset_results(results):
    """Return a reset function that ignores its arguments and returns `results`."""

    def _reset_result(self, seed=None, options=None):
        return results

    return _reset_result


@pytest.mark.parametrize(
    "test,func,message,kwargs",
    [
        [
            UserWarning,
            _reset_no_seed,
            "Future gymnasium versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.",
            {},
        ],
        [
            UserWarning,
            _reset_seed_default,
            "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
            {},
        ],
        [
            UserWarning,
            _reset_no_option,
            "Future gymnasium versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.",
            {},
        ],
        [
            UserWarning,
            _make_reset_results([0, {}]),
            f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{list}`",
            {},
        ],
        [
            AssertionError,
            _make_reset_results((np.array([0], dtype=np.float32), {1, 2})),
            f"The second element returned by `env.reset()` was not a dictionary, actual type: {set}",
            {},
        ],
    ],
)
def test_passive_env_reset_checker(
    test: Type[Exception], func: Callable, message: str, kwargs: Dict
):
    """Tests that `env_reset_passive_checker` warns or raises with the expected message."""
    # The environment is built inside the lambda so its construction also runs
    # within the warning / exception assertion context.
    _assert_checker_outcome(
        test,
        message,
        lambda: env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs),
    )


def _modified_step(
    self, obs=None, reward=0, terminated=False, truncated=False, info=None
):
    """Build a step result from the given fields.

    `obs` defaults to a sample of the observation space and `info` to `{}`.
    Passing `truncated=None` returns the four element tuple
    `(obs, reward, terminated, info)` (used by the `done` test case); otherwise
    the five element tuple `(obs, reward, terminated, truncated, info)` is returned.
    """
    if obs is None:
        obs = self.observation_space.sample()
    if info is None:
        info = {}

    if truncated is None:
        return obs, reward, terminated, info
    else:
        return obs, reward, terminated, truncated, info


@pytest.mark.parametrize(
    "test,func,message",
    [
        [
            AssertionError,
            lambda self, _: "error",
            f"Expects step result to be a tuple, actual type: {str}",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=None),
            f"Expects `done` signal to be a boolean, actual type: {str}",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=False),
            f"Expects `terminated` signal to be a boolean, actual type: {str}",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, truncated="error"),
            f"Expects `truncated` signal to be a boolean, actual type: {str}",
        ],
        [
            gym.error.Error,
            lambda self, _: (1, 2, 3),
            "Expected `Env.step` to return a four or five element tuple, actual number of elements returned: 3.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward="error"),
            f"The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: {str}",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.nan),
            "The reward is a NaN value.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.inf),
            "The reward is an inf value.",
        ],
        [
            AssertionError,
            lambda self, _: _modified_step(self, info="error"),
            f"The `info` returned by `step()` must be a python dictionary, actual type: {str}",
        ],
    ],
)
def test_passive_env_step_checker(
    test: Type[Exception], func: Callable, message: str
):
    """Tests that `env_step_passive_checker` warns or raises with the expected message."""
    _assert_checker_outcome(
        test,
        message,
        lambda: env_step_passive_checker(GenericTestEnv(step_func=func), 0),
    )


@pytest.mark.parametrize(
    "test,env,message",
    [
        [
            UserWarning,
            GenericTestEnv(metadata={"render_modes": None}),
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`.",
        ],
        [
            UserWarning,
            GenericTestEnv(metadata={"render_modes": "Testing mode"}),
            f"Expects the render_modes to be a sequence (i.e. list, tuple), actual type: {str}",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode", 1], "render_fps": 1},
            ),
            f"Expects all render modes to be strings, actual types: [{str}, {int}]",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": None},
                render_mode="Testing mode",
                render_func=lambda self: 0,
            ),
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": "fps"}
            ),
            f"Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: {str}",
        ],
        [
            AssertionError,
            GenericTestEnv(
                metadata={"render_modes": [], "render_fps": 30}, render_mode="Test"
            ),
            "With no render_modes, expects the Env.render_mode to be None, actual value: Test",
        ],
        [
            AssertionError,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": 30},
                render_mode="Non mode",
            ),
            "The environment was initialized successfully however with an unsupported render mode. Render mode: Non mode, modes: ['Testing mode']",
        ],
    ],
)
def test_passive_render_checker(
    test: Type[Exception], env: GenericTestEnv, message: str
):
    """Tests that `env_render_passive_checker` warns or raises with the expected message."""
    _assert_checker_outcome(test, message, lambda: env_render_passive_checker(env))