2022-07-11 02:45:24 +01:00
""" Tests that the `env_checker` runs as expects and all errors are possible. """
import re
2022-08-23 11:09:54 -04:00
import warnings
from typing import Tuple , Union
2021-12-08 22:14:15 +01:00
2021-08-12 12:35:09 -05:00
import numpy as np
import pytest
2022-09-16 23:41:27 +01:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium import spaces
from gymnasium . core import ObsType
from gymnasium . utils . env_checker import (
2022-07-11 02:45:24 +01:00
check_env ,
check_reset_options ,
2022-08-23 11:09:54 -04:00
check_reset_return_info_deprecation ,
check_reset_return_type ,
2022-07-11 02:45:24 +01:00
check_reset_seed ,
2022-08-23 11:09:54 -04:00
check_seed_deprecation ,
2022-07-11 02:45:24 +01:00
)
from tests . testing_env import GenericTestEnv
2021-08-12 12:35:09 -05:00
2022-07-11 02:45:24 +01:00
@pytest.mark.parametrize(
    "env",
    [
        gym.make("CartPole-v1", disable_env_checker=True).unwrapped,
        gym.make("MountainCar-v0", disable_env_checker=True).unwrapped,
        GenericTestEnv(
            observation_space=spaces.Dict(
                a=spaces.Discrete(10), b=spaces.Box(np.zeros(2), np.ones(2))
            )
        ),
        GenericTestEnv(
            observation_space=spaces.Tuple(
                [spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
            )
        ),
        GenericTestEnv(
            observation_space=spaces.Dict(
                a=spaces.Tuple(
                    [spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
                ),
                b=spaces.Box(np.zeros(2), np.ones(2)),
            )
        ),
    ],
)
def test_no_error_warnings(env):
    """Checks that `check_env` passes without emitting any warnings for valid envs.

    A full version of this test with all gymnasium envs is run in tests/envs/test_envs.py.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        # `record=True` does not reset the active filters, so an inherited
        # "ignore" filter could silently drop warnings and make the assertion
        # below pass vacuously. Record every warning unconditionally.
        warnings.simplefilter("always")
        check_env(env)

    assert len(caught_warnings) == 0, [warning.message for warning in caught_warnings]
2022-06-08 00:20:56 +02:00
2021-08-12 12:35:09 -05:00
2022-08-23 11:09:54 -04:00
def _no_super_reset ( self , seed = None , options = None ) :
2022-07-11 02:45:24 +01:00
self . np_random . random ( ) # generates a new prng
# generate seed deterministic result
self . observation_space . seed ( 0 )
2022-08-23 11:09:54 -04:00
return self . observation_space . sample ( ) , { }
2021-08-12 12:35:09 -05:00
2022-08-23 11:09:54 -04:00
def _super_reset_fixed(self, seed=None, options=None):
    """Reset stub that calls the base ``reset`` but always with ``seed=1``.

    The caller's ``seed`` is discarded, so different seeds produce the same
    random state. The sampled observation plus an empty info dict is returned.
    """
    # Deliberately ignore the requested seed; the base reset always gets 1.
    super(GenericTestEnv, self).reset(seed=1)
    # Make the observation space sample from the env's (fixed-seed) generator.
    space = self.observation_space
    space._np_random = self.np_random
    return space.sample(), {}
2021-08-12 12:35:09 -05:00
2022-08-23 11:09:54 -04:00
def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
    """Reset stub whose ``seed`` parameter defaults to ``"Error"`` instead of ``None``.

    Seeding is forwarded to the base class unchanged. An observation sampled
    from the env's generator is returned together with an empty info dict.
    """
    super(GenericTestEnv, self).reset(seed=seed)
    # Let the observation space draw from the env's freshly seeded generator.
    generator = self.np_random
    self.observation_space._np_random = generator  # pyright: ignore [reportPrivateUsage]
    return self.observation_space.sample(), {}
2022-07-11 02:45:24 +01:00
@pytest.mark.parametrize(
    ("test", "func", "message"),
    [
        (
            gym.error.Error,
            lambda self: (self.observation_space.sample(), {}),
            "The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
        ),
        (
            AssertionError,
            lambda self, seed, *_: (self.observation_space.sample(), {}),
            "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
        ),
        (
            AssertionError,
            _no_super_reset,
            "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`.",
        ),
        (
            AssertionError,
            _super_reset_fixed,
            "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`.",
        ),
        (
            UserWarning,
            _reset_default_seed,
            "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. Actual default: Error",
        ),
    ],
)
def test_check_reset_seed(test, func: callable, message: str):
    """Checks that `check_reset_seed` flags each incorrectly seeded `reset`.

    When ``test`` is ``UserWarning`` the checker must emit a warning whose text
    is ``message``. Otherwise it must raise ``test`` with exactly ``message``.
    """
    if test is not UserWarning:
        expectation = pytest.raises(test, match=f"^{re.escape(message)}$")
    else:
        # The expected warning text is prefixed with "WARN: " and wrapped in
        # ANSI colour escape codes (ESC[33m ... ESC[0m).
        expectation = pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        )
    with expectation:
        check_reset_seed(GenericTestEnv(reset_func=func))
2022-07-11 02:45:24 +01:00
2022-08-23 11:09:54 -04:00
def _deprecated_return_info(
    self, return_info: bool = False
) -> Union[Tuple[ObsType, dict], ObsType]:
    """Mimics a legacy `reset` that still accepts the deprecated ``return_info`` flag.

    Returns ``(obs, {})`` when ``return_info`` is truthy, otherwise the bare ``obs``.
    """
    observation = self.observation_space.sample()
    return (observation, {}) if return_info else observation
2022-08-23 11:09:54 -04:00
def _reset_var_keyword_kwargs ( self , kwargs ) :
return self . observation_space . sample ( ) , { }
2022-07-11 02:45:24 +01:00
2022-08-23 11:09:54 -04:00
def _reset_return_info_type ( self , seed = None , options = None ) :
""" Returns a `list` instead of a `tuple`. This function is used to make sure `env_checker` correctly
checks that the return type of ` env . reset ( ) ` is a ` tuple ` """
return [ self . observation_space . sample ( ) , { } ]
2022-07-11 02:45:24 +01:00
2022-08-23 11:09:54 -04:00
def _reset_return_info_length ( self , seed = None , options = None ) :
return 1 , 2 , 3
def _return_info_obs_outside ( self , seed = None , options = None ) :
return self . observation_space . sample ( ) + self . observation_space . high , { }
def _return_info_not_dict ( self , seed = None , options = None ) :
return self . observation_space . sample ( ) , [ " key " , " value " ]
2022-07-11 02:45:24 +01:00
@pytest.mark.parametrize(
    ("test", "func", "message"),
    [
        (
            AssertionError,
            _reset_return_info_type,
            "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
        ),
        (
            AssertionError,
            _reset_return_info_length,
            "Calling the reset method did not return a 2-tuple, actual length: 3",
        ),
        (
            AssertionError,
            _return_info_obs_outside,
            "The first element returned by `env.reset()` is not within the observation space.",
        ),
        (
            AssertionError,
            _return_info_not_dict,
            "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
        ),
    ],
)
def test_check_reset_return_type(test, func: callable, message: str):
    """Checks that `check_reset_return_type` rejects each malformed `reset` return value.

    The raised error must be of type ``test`` and its text must equal ``message``.
    """
    anchored = f"^{re.escape(message)}$"
    with pytest.raises(test, match=anchored):
        check_reset_return_type(GenericTestEnv(reset_func=func))
2022-08-23 11:09:54 -04:00
@pytest.mark.parametrize(
    ("test", "func", "message"),
    [
        (
            UserWarning,
            _deprecated_return_info,
            "`return_info` is deprecated as an optional argument to `reset`. `reset`"
            "should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
            "containing additional information.",
        ),
    ],
)
def test_check_reset_return_info_deprecation(test, func: callable, message: str):
    """Checks that a `reset` accepting ``return_info`` triggers the deprecation warning.

    The warning text must be ``message`` prefixed with "WARN: " and wrapped in
    ANSI colour escape codes.
    """
    pattern = f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
    with pytest.warns(test, match=pattern):
        check_reset_return_info_deprecation(GenericTestEnv(reset_func=func))
2022-08-23 11:09:54 -04:00
def test_check_seed_deprecation():
    """Tests that `check_seed_deprecation()` throws a warning if `env.seed()` has not been removed.

    A callable ``seed`` attribute must trigger the warning. Non-callable values
    (a list, an int) and a missing attribute must not.
    """
    message = """Official support for the `seed` function is dropped. Standard practice is to reset gymnasium environments using `env.reset(seed=<desired seed>)`"""

    env = GenericTestEnv()

    def seed(seed):
        return

    with pytest.warns(
        UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
    ):
        env.seed = seed
        assert callable(env.seed)
        check_seed_deprecation(env)

    with warnings.catch_warnings(record=True) as caught_warnings:
        # `record=True` keeps the inherited filters; force "always" so an
        # active "ignore" filter cannot hide a spurious warning and make the
        # no-warning assertion below pass vacuously.
        warnings.simplefilter("always")
        env.seed = []
        check_seed_deprecation(env)
        env.seed = 123
        check_seed_deprecation(env)
        del env.seed
        check_seed_deprecation(env)

    assert len(caught_warnings) == 0, [warning.message for warning in caught_warnings]
2022-07-11 02:45:24 +01:00
def test_check_reset_options():
    """Checks that `check_reset_options` rejects a `reset` without ``options`` or ``**kwargs``."""
    expected = re.escape(
        "The `reset` method does not provide an `options` or `**kwargs` keyword argument"
    )
    with pytest.raises(gym.error.Error, match=expected):
        # `reset` here only accepts `self`, so no options can be passed.
        check_reset_options(GenericTestEnv(reset_func=lambda self: (0, {})))
2022-07-11 02:45:24 +01:00
@pytest.mark.parametrize(
    ("env", "message"),
    [
        (
            "Error",
            "The environment must inherit from the gymnasium.Env class. See https://gymnasium.farama.org/content/environment_creation/ for more info.",
        ),
        (
            GenericTestEnv(action_space=None),
            "The environment must specify an action space. See https://gymnasium.farama.org/content/environment_creation/ for more info.",
        ),
        (
            GenericTestEnv(observation_space=None),
            "The environment must specify an observation space. See https://gymnasium.farama.org/content/environment_creation/ for more info.",
        ),
    ],
)
def test_check_env(env: gym.Env, message: str):
    """Checks that `check_env` raises an ``AssertionError`` for invalid environments.

    Covers a non-``Env`` object and envs missing an action or an observation
    space. The error text must equal ``message`` exactly.
    """
    exact_message = f"^{re.escape(message)}$"
    with pytest.raises(AssertionError, match=exact_message):
        check_env(env)