2022-07-11 02:45:24 +01:00
import re
2022-08-30 19:47:26 +01:00
import warnings
2022-07-11 02:45:24 +01:00
from typing import Dict , Union
import numpy as np
import pytest
2022-09-16 23:41:27 +01:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium import spaces
from gymnasium . utils . passive_env_checker import (
2022-07-11 02:45:24 +01:00
check_action_space ,
check_obs ,
check_observation_space ,
env_render_passive_checker ,
env_reset_passive_checker ,
env_step_passive_checker ,
)
from tests . testing_env import GenericTestEnv
def _modify_space(space: spaces.Space, attribute: str, value):
    """Set ``attribute`` of ``space`` to ``value`` in place and return that same space.

    The parametrized cases use this to take an otherwise valid space and give it an
    invalid attribute value (e.g. a negative ``n`` or a mismatched ``low`` shape)
    after construction, rather than through the space constructor.
    """
    modified = space
    setattr(modified, attribute, value)
    return modified
@pytest.mark.parametrize(
    "test,space,message",
    [
        [
            AssertionError,
            "error",
            "observation space does not inherit from `gymnasium.spaces.Space`, actual type: <class 'str'>",
        ],
        # ===== Check box observation space ====
        [
            UserWarning,
            spaces.Box(np.zeros((5, 5, 1)), 255 * np.ones((5, 5, 1)), dtype=np.int32),
            "It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: int32. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.",
        ],
        [
            UserWarning,
            spaces.Box(np.ones((2, 2, 1)), 255 * np.ones((2, 2, 1)), dtype=np.uint8),
            "It seems a Box observation space is an image but the lower and upper bounds are not [0, 255]. Actual lower bound: 1, upper bound: 255. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
        ],
        [
            UserWarning,
            spaces.Box(np.zeros((5, 5, 1)), np.ones((5, 5, 1)), dtype=np.uint8),
            "It seems a Box observation space is an image but the lower and upper bounds are not [0, 255]. Actual lower bound: 0, upper bound: 1. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
        ],
        [
            UserWarning,
            spaces.Box(np.zeros((5, 5)), np.ones((5, 5))),
            "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (5, 5)",
        ],
        [
            UserWarning,
            spaces.Box(np.zeros(5), np.zeros(5)),
            "A Box observation space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            UserWarning,
            spaces.Box(np.ones(5), np.zeros(5)),
            "A Box observation space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
            "The Box observation space shape and low shape have different shapes, low shape: (3,), box shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
            # NOTE(review): the doubled "have have" must stay in sync with the checker's
            # message because the match below is anchored (^...$); fix both together.
            "The Box observation space shape and high shape have have different shapes, high shape: (3,), box shape: (2,)",
        ],
        # ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
        [
            AssertionError,
            _modify_space(spaces.Discrete(5), "n", -1),
            "Discrete observation space's number of elements must be positive, actual number of elements: -1",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "nvec", np.array([2, -1])),
            "Multi-discrete observation space's all nvec elements must be greater than 0, actual nvec: [ 2 -1]",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, 1, 2)),
            "Multi-discrete observation space's shape must be equal to the nvec shape, space shape: (2, 1, 2), nvec shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
            "Multi-binary observation space's all shape elements must be greater than 0, actual shape: (2, -1)",
        ],
        [
            AssertionError,
            spaces.Tuple([]),
            "An empty Tuple observation space is not allowed.",
        ],
        [
            AssertionError,
            spaces.Dict(),
            "An empty Dict observation space is not allowed.",
        ],
    ],
)
def test_check_observation_space(test, space, message: str):
    """Check that ``check_observation_space`` reports each problematic ``space``.

    When ``test`` is ``UserWarning``, a warning whose text is exactly ``message``
    wrapped in the ``WARN:`` ANSI colour codes must be emitted. Otherwise ``test`` is an
    exception type that must be raised with exactly ``message``, and no warning may be
    emitted along the way.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_observation_space(space)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # `record=True` keeps the active filters; force "always" so that an
            # "ignore" filter cannot make the no-warnings assertion pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_observation_space(space)
        assert len(caught_warnings) == 0, caught_warnings
2022-07-11 02:45:24 +01:00
@pytest.mark.parametrize(
    "test,space,message",
    [
        [
            AssertionError,
            "error",
            "action space does not inherit from `gymnasium.spaces.Space`, actual type: <class 'str'>",
        ],
        # ===== Check box action space ====
        [
            UserWarning,
            spaces.Box(np.zeros(5), np.zeros(5)),
            "A Box action space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            UserWarning,
            spaces.Box(np.ones(5), np.zeros(5)),
            "A Box action space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
            # NOTE(review): the doubled "have have" must stay in sync with the checker's
            # message because the match below is anchored (^...$); fix both together.
            "The Box action space shape and low shape have have different shapes, low shape: (3,), box shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
            "The Box action space shape and high shape have different shapes, high shape: (3,), box shape: (2,)",
        ],
        # ==== Other action spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
        [
            AssertionError,
            _modify_space(spaces.Discrete(5), "n", -1),
            "Discrete action space's number of elements must be positive, actual number of elements: -1",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, -1)),
            "Multi-discrete action space's shape must be equal to the nvec shape, space shape: (2, -1), nvec shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
            "Multi-binary action space's all shape elements must be greater than 0, actual shape: (2, -1)",
        ],
        [
            AssertionError,
            spaces.Tuple([]),
            "An empty Tuple action space is not allowed.",
        ],
        [AssertionError, spaces.Dict(), "An empty Dict action space is not allowed."],
    ],
)
def test_check_action_space(
    test: Union[UserWarning, type], space: spaces.Space, message: str
):
    """Check that ``check_action_space`` reports each problematic ``space``.

    When ``test`` is ``UserWarning``, a warning whose text is exactly ``message``
    wrapped in the ``WARN:`` ANSI colour codes must be emitted. Otherwise ``test`` is an
    exception type that must be raised with exactly ``message``, and no warning may be
    emitted along the way.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_action_space(space)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # `record=True` keeps the active filters; force "always" so that an
            # "ignore" filter cannot make the no-warnings assertion pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_action_space(space)
        assert len(caught_warnings) == 0, caught_warnings
2022-07-11 02:45:24 +01:00
@pytest.mark.parametrize(
    "test,obs,obs_space,message",
    [
        [
            UserWarning,
            3,
            spaces.Discrete(2),
            "The obs returned by the `testing()` method is not within the observation space.",
        ],
        [
            UserWarning,
            np.uint8(0),
            spaces.Discrete(1),
            "The obs returned by the `testing()` method should be an int or np.int64, actual type: <class 'numpy.uint8'>",
        ],
        [
            UserWarning,
            [0, 1],
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            "The obs returned by the `testing()` method was expecting a tuple, actual type: <class 'list'>",
        ],
        [
            AssertionError,
            (1, 2, 3),
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            "The obs returned by the `testing()` method length is not same as the observation space length, obs length: 3, space length: 2",
        ],
        [
            AssertionError,
            {1, 2, 3},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            "The obs returned by the `testing()` method must be a dict, actual type: <class 'set'>",
        ],
        [
            AssertionError,
            {"a": 1, "c": 2},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            "The obs returned by the `testing()` method observation keys is not same as the observation space keys, obs keys: ['a', 'c'], space keys: ['a', 'b']",
        ],
    ],
)
def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
    """Check that ``check_obs`` reports each observation that does not fit ``obs_space``.

    The checker is called with the method name ``"testing"``, which is why every
    expected message mentions ``testing()``. When ``test`` is ``UserWarning``, a
    warning whose text is exactly ``message`` wrapped in the ``WARN:`` ANSI colour
    codes must be emitted. Otherwise ``test`` is an exception type that must be raised
    with exactly ``message``, and no warning may be emitted along the way.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_obs(obs, obs_space, "testing")
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # `record=True` keeps the active filters; force "always" so that an
            # "ignore" filter cannot make the no-warnings assertion pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_obs(obs, obs_space, "testing")
        assert len(caught_warnings) == 0, caught_warnings
2022-07-11 02:45:24 +01:00
2022-08-23 11:09:54 -04:00
def _reset_no_seed(self, options=None):
    """Reset stub whose signature deliberately has no ``seed`` parameter.

    Returns a fresh sample from ``self.observation_space`` together with an empty
    info dict.
    """
    sampled_obs = self.observation_space.sample()
    return sampled_obs, {}
2022-07-11 02:45:24 +01:00
2022-08-23 11:09:54 -04:00
def _reset_seed_default(self, seed="error", options=None):
    """Reset stub whose ``seed`` parameter deliberately defaults to ``"error"``.

    Returns a fresh sample from ``self.observation_space`` together with an empty
    info dict.
    """
    sampled_obs = self.observation_space.sample()
    return sampled_obs, {}
2022-07-11 02:45:24 +01:00
2022-08-23 11:09:54 -04:00
def _reset_no_option(self, seed=None):
    """Reset stub whose signature deliberately has no ``options`` parameter.

    Returns a fresh sample from ``self.observation_space`` together with an empty
    info dict.
    """
    sampled_obs = self.observation_space.sample()
    return sampled_obs, {}
2022-07-11 02:45:24 +01:00
def _make_reset_results(results):
    """Build a reset stub that ignores its arguments and always returns ``results``.

    The returned function accepts ``seed`` and ``options`` keyword arguments, both
    defaulting to ``None``. ``results`` is returned as-is (the same object, not a
    copy), so a test can hand back a deliberately malformed reset result.
    """

    def _reset_result(self, seed=None, options=None):
        # Every argument exists only to give the stub a complete reset signature.
        return results

    return _reset_result
@pytest.mark.parametrize(
    "test,func,message,kwargs",
    [
        [
            UserWarning,
            _reset_no_seed,
            "Future gymnasium versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.",
            {},
        ],
        [
            UserWarning,
            _reset_seed_default,
            "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
            {},
        ],
        [
            UserWarning,
            _reset_no_option,
            "Future gymnasium versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.",
            {},
        ],
        [
            UserWarning,
            _make_reset_results([0, {}]),
            "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
            {},
        ],
        [
            AssertionError,
            _make_reset_results((np.array([0], dtype=np.float32), {1, 2})),
            "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'set'>",
            {},
        ],
    ],
)
def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: Dict):
    """Check ``env_reset_passive_checker`` against deliberately malformed reset stubs.

    ``func`` is passed as ``reset_func`` to a ``GenericTestEnv``, and the checker is
    called with ``kwargs``. When ``test`` is ``UserWarning``, a warning whose text is
    exactly ``message`` wrapped in the ``WARN:`` ANSI colour codes must be emitted.
    Otherwise ``test`` is an exception type that must be raised with exactly
    ``message``, and no warning may be emitted along the way.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # `record=True` keeps the active filters; force "always" so that an
            # "ignore" filter cannot make the no-warnings assertion pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_reset_passive_checker(GenericTestEnv(reset_func=func), **kwargs)
        assert len(caught_warnings) == 0, caught_warnings
2022-07-11 02:45:24 +01:00
def _modified_step(
    self, obs=None, reward=0, terminated=False, truncated=False, info=None
):
    """Step stub that returns a step tuple built from the given overrides.

    ``obs`` falls back to a fresh sample from ``self.observation_space``, and ``info``
    falls back to a new empty dict. Passing ``truncated=None`` produces the
    four-element ``(obs, reward, terminated, info)`` form. Any other value produces the
    five-element ``(obs, reward, terminated, truncated, info)`` form.
    """
    observation = self.observation_space.sample() if obs is None else obs
    step_info = {} if info is None else info
    if truncated is not None:
        return observation, reward, terminated, truncated, step_info
    return observation, reward, terminated, step_info
@pytest.mark.parametrize(
    "test,func,message",
    [
        [
            AssertionError,
            lambda self, _: "error",
            "Expects step result to be a tuple, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=None),
            "Expects `done` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=False),
            "Expects `terminated` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, truncated="error"),
            "Expects `truncated` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            gym.error.Error,
            lambda self, _: (1, 2, 3),
            "Expected `Env.step` to return a four or five element tuple, actual number of elements returned: 3.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward="error"),
            "The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.nan),
            "The reward is a NaN value.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.inf),
            "The reward is an inf value.",
        ],
        [
            AssertionError,
            lambda self, _: _modified_step(self, info="error"),
            "The `info` returned by `step()` must be a python dictionary, actual type: <class 'str'>",
        ],
    ],
)
def test_passive_env_step_checker(
    test: Union[UserWarning, type], func: callable, message: str
):
    """Check ``env_step_passive_checker`` against deliberately malformed step stubs.

    ``func`` is passed as ``step_func`` to a ``GenericTestEnv``, and the checker steps
    it with action ``0``. When ``test`` is ``UserWarning``, a warning whose text is
    exactly ``message`` wrapped in the ``WARN:`` ANSI colour codes must be emitted.
    Otherwise ``test`` is an exception type that must be raised with exactly
    ``message``, and no warning may be emitted along the way.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_step_passive_checker(GenericTestEnv(step_func=func), 0)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # `record=True` keeps the active filters; force "always" so that an
            # "ignore" filter cannot make the no-warnings assertion pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_step_passive_checker(GenericTestEnv(step_func=func), 0)
        assert len(caught_warnings) == 0, caught_warnings
2022-07-11 02:45:24 +01:00
@pytest.mark.parametrize(
    "test,env,message",
    [
        [
            UserWarning,
            GenericTestEnv(metadata={"render_modes": None}),
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`.",
        ],
        [
            UserWarning,
            GenericTestEnv(metadata={"render_modes": "Testing mode"}),
            "Expects the render_modes to be a sequence (i.e. list, tuple), actual type: <class 'str'>",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode", 1], "render_fps": 1},
            ),
            "Expects all render modes to be strings, actual types: [<class 'str'>, <class 'int'>]",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": None},
                render_mode="Testing mode",
                render_func=lambda self: 0,
            ),
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
        ],
        [
            UserWarning,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": "fps"}
            ),
            "Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: <class 'str'>",
        ],
        [
            AssertionError,
            GenericTestEnv(
                metadata={"render_modes": [], "render_fps": 30}, render_mode="Test"
            ),
            "With no render_modes, expects the Env.render_mode to be None, actual value: Test",
        ],
        [
            AssertionError,
            GenericTestEnv(
                metadata={"render_modes": ["Testing mode"], "render_fps": 30},
                render_mode="Non mode",
            ),
            "The environment was initialized successfully however with an unsupported render mode. Render mode: Non mode, modes: ['Testing mode']",
        ],
    ],
)
def test_passive_render_checker(test, env: GenericTestEnv, message: str):
    """Check ``env_render_passive_checker`` against environments with bad render settings.

    Each ``env`` is a pre-built ``GenericTestEnv`` configured with the metadata and
    render mode under test. When ``test`` is ``UserWarning``, a warning whose text is
    exactly ``message`` wrapped in the ``WARN:`` ANSI colour codes must be emitted.
    Otherwise ``test`` is an exception type that must be raised with exactly
    ``message``, and no warning may be emitted along the way.
    """
    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_render_passive_checker(env)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # `record=True` keeps the active filters; force "always" so that an
            # "ignore" filter cannot make the no-warnings assertion pass vacuously.
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_render_passive_checker(env)
        assert len(caught_warnings) == 0, caught_warnings