"""Tests that the `env_checker` runs as expected and that all of its errors and warnings can be triggered."""
import re
import warnings
from typing import Callable, Tuple, Union

import numpy as np
import pytest

import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ObsType
from gymnasium.utils.env_checker import (
    check_env,
    check_reset_options,
    check_reset_return_info_deprecation,
    check_reset_return_type,
    check_reset_seed,
    check_seed_deprecation,
)
from tests.testing_env import GenericTestEnv


def _warning_pattern(message: str) -> str:
    """Return an anchored regex that matches `message` as a coloured `WARN:` warning.

    The expected warning text is `message` prefixed with the ANSI escape `ESC[33m`
    plus `WARN: `, and suffixed with the ANSI reset `ESC[0m`. The escape sequences
    are matched literally and `message` is regex-escaped so that backticks,
    parentheses, etc. in it are not treated as regex syntax.
    """
    return f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"


@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True).unwrapped,
        gym.make("MountainCar-v0", disable_env_checker=True).unwrapped,
        GenericTestEnv(
            observation_space=spaces.Dict(
                a=spaces.Discrete(10), b=spaces.Box(np.zeros(2), np.ones(2))
            )
        ),
        GenericTestEnv(
            observation_space=spaces.Tuple(
                [spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
            )
        ),
        GenericTestEnv(
            observation_space=spaces.Dict(
                a=spaces.Tuple(
                    [spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
                ),
                b=spaces.Box(np.zeros(2), np.ones(2)),
            )
        ),
    ],
)
def test_no_error_warnings(env):
    """A full version of this test with all gymnasium envs is run in tests/envs/test_envs.py."""
    with warnings.catch_warnings(record=True) as caught_warnings:
        check_env(env)

    assert len(caught_warnings) == 0, [warning.message for warning in caught_warnings]


def _no_super_reset(self, seed=None, options=None):
    """Reset that never calls `super().reset`, so the passed `seed` is ignored.

    The observation space is always re-seeded with 0, so every reset returns the
    same observation regardless of `seed`.
    """
    self.np_random.random()  # generates a new prng
    # generate seed deterministic result
    self.observation_space.seed(0)
    return self.observation_space.sample(), {}


def _super_reset_fixed(self, seed=None, options=None):
    """Reset that calls `super().reset` but always with seed 1, ignoring `seed`."""
    # Call super that ignores the seed passed, use fixed seed
    super(GenericTestEnv, self).reset(seed=1)
    # deterministic output
    self.observation_space._np_random = self.np_random
    return self.observation_space.sample(), {}


def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
    """Correctly seeded reset whose `seed` default is the string "Error" rather than `None`."""
    super(GenericTestEnv, self).reset(seed=seed)
    self.observation_space._np_random = (  # pyright: ignore [reportPrivateUsage]
        self.np_random
    )
    return self.observation_space.sample(), {}


@pytest.mark.parametrize(
    "test,func,message",
    [
        [
            gym.error.Error,
            lambda self: (self.observation_space.sample(), {}),
            "The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
        ],
        [
            AssertionError,
            lambda self, seed, *_: (self.observation_space.sample(), {}),
            "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
        ],
        [
            AssertionError,
            _no_super_reset,
            "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`.",
        ],
        [
            AssertionError,
            _super_reset_fixed,
            "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`.",
        ],
        [
            UserWarning,
            _reset_default_seed,
            "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. Actual default: Error",
        ],
    ],
)
def test_check_reset_seed(test, func: Callable, message: str):
    """Tests the check reset seed function works as expected."""
    env = GenericTestEnv(reset_func=func)
    if test is UserWarning:
        with pytest.warns(UserWarning, match=_warning_pattern(message)):
            check_reset_seed(env)
    else:
        with pytest.raises(test, match=f"^{re.escape(message)}$"):
            check_reset_seed(env)


def _deprecated_return_info(
    self, return_info: bool = False
) -> Union[Tuple[ObsType, dict], ObsType]:
    """function to simulate the signature and behavior of a `reset` function with the deprecated `return_info` optional argument"""
    if return_info:
        return self.observation_space.sample(), {}
    else:
        return self.observation_space.sample()


# NOTE(review): not referenced by any test in this module, and despite its name it
# takes a positional `kwargs` parameter rather than `**kwargs` — confirm intent.
def _reset_var_keyword_kwargs(self, kwargs):
    return self.observation_space.sample(), {}


def _reset_return_info_type(self, seed=None, options=None):
    """Returns a `list` instead of a `tuple`. This function is used to make sure `env_checker` correctly checks that the return type of `env.reset()` is a `tuple`"""
    return [self.observation_space.sample(), {}]


def _reset_return_info_length(self, seed=None, options=None):
    """Returns a 3-tuple instead of the expected `(obs, info)` 2-tuple."""
    return 1, 2, 3


def _return_info_obs_outside(self, seed=None, options=None):
    """Returns an observation shifted by the space's `high` bound, so it lies outside the space."""
    return self.observation_space.sample() + self.observation_space.high, {}


def _return_info_not_dict(self, seed=None, options=None):
    """Returns `info` as a `list` instead of a `dict`."""
    return self.observation_space.sample(), ["key", "value"]


@pytest.mark.parametrize(
    "test,func,message",
    [
        [
            AssertionError,
            _reset_return_info_type,
            # `_reset_return_info_type` returns a list, so the reported type is `<class 'list'>`.
            "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
        ],
        [
            AssertionError,
            _reset_return_info_length,
            "Calling the reset method did not return a 2-tuple, actual length: 3",
        ],
        [
            AssertionError,
            _return_info_obs_outside,
            "The first element returned by `env.reset()` is not within the observation space.",
        ],
        [
            AssertionError,
            _return_info_not_dict,
            # `_return_info_not_dict` returns `info` as a list.
            "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
        ],
    ],
)
def test_check_reset_return_type(test, func: Callable, message: str):
    """Tests the check `env.reset()` function has a correct return type."""
    with pytest.raises(test, match=f"^{re.escape(message)}$"):
        check_reset_return_type(GenericTestEnv(reset_func=func))


@pytest.mark.parametrize(
    "test,func,message",
    [
        [
            UserWarning,
            _deprecated_return_info,
            # NOTE: adjacent literals concatenate without spaces (e.g. "`reset`should");
            # presumably this mirrors the exact text gymnasium logs — confirm before changing.
            "`return_info` is deprecated as an optional argument to `reset`. `reset`"
            "should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
            "containing additional information.",
        ],
    ],
)
def test_check_reset_return_info_deprecation(test, func: Callable, message: str):
    """Tests that return_info has been correct deprecated as an argument to `env.reset()`."""
    with pytest.warns(test, match=_warning_pattern(message)):
        check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))


def test_check_seed_deprecation():
    """Tests that `check_seed_deprecation()` throws a warning if `env.seed()` has not been removed."""
    message = (
        "Official support for the `seed` function is dropped. "
        "Standard practice is to reset gymnasium environments using `env.reset(seed=<int>)`"
    )

    env = GenericTestEnv()

    def seed(seed):
        return

    # A callable `seed` attribute must trigger the deprecation warning.
    with pytest.warns(UserWarning, match=_warning_pattern(message)):
        env.seed = seed
        assert callable(env.seed)
        check_seed_deprecation(env)

    # Non-callable or missing `seed` attributes must not warn.
    with warnings.catch_warnings(record=True) as caught_warnings:
        env.seed = []
        check_seed_deprecation(env)
        env.seed = 123
        check_seed_deprecation(env)
        del env.seed
        check_seed_deprecation(env)
        assert len(caught_warnings) == 0


def test_check_reset_options():
    """Tests the check_reset_options function."""
    with pytest.raises(
        gym.error.Error,
        match=re.escape(
            "The `reset` method does not provide an `options` or `**kwargs` keyword argument"
        ),
    ):
        check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))


@pytest.mark.parametrize(
    "env,message",
    [
        [
            "Error",
            "The environment must inherit from the gymnasium.Env class. See https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/ for more info.",
        ],
        [
            GenericTestEnv(action_space=None),
            "The environment must specify an action space. See https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/ for more info.",
        ],
        [
            GenericTestEnv(observation_space=None),
            "The environment must specify an observation space. See https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/ for more info.",
        ],
    ],
)
def test_check_env(env: gym.Env, message: str):
    """Tests the check_env function works as expected."""
    with pytest.raises(AssertionError, match=f"^{re.escape(message)}$"):
        check_env(env)