2022-05-20 14:49:30 +01:00
""" A set of functions for checking an environment details.
2021-08-12 12:35:09 -05:00
This file is originally from the Stable Baselines3 repository hosted on GitHub
(https://github.com/DLR-RM/stable-baselines3/)
Original Author: Antonin Raffin

It also uses some warnings/assertions from the PettingZoo repository hosted on GitHub
(https://github.com/PettingZoo-Team/PettingZoo)
Original Author: J K Terry

This was rewritten and split into "env_checker.py" and "passive_env_checker.py" for invasive and passive environment checking
Original Author: Mark Towers

These projects are covered by the MIT License.
"""
2022-01-19 23:28:59 +01:00
import inspect
from copy import deepcopy
from typing import Optional

import numpy as np

import gym
from gym import logger, spaces
from gym.utils.passive_env_checker import (
    check_action_space,
    check_observation_space,
    env_render_passive_checker,
    env_reset_passive_checker,
    env_step_passive_checker,
)
2021-08-12 12:35:09 -05:00
2022-06-06 16:21:45 +01:00
def data_equivalence(data_1, data_2) -> bool:
    """Assert equality between data 1 and 2, i.e observations, actions, info.

    The two values must have exactly the same type. Dictionaries are compared
    key by key and tuples element by element (both recursively), NumPy arrays
    must have the same shape and equal elements, and anything else falls back
    to ``==``.

    Args:
        data_1: data structure 1
        data_2: data structure 2

    Returns:
        If data 1 and 2 are equivalent
    """
    # Exact type match is required (a subclass is not considered equivalent).
    if type(data_1) is not type(data_2):
        return False

    if isinstance(data_1, dict):
        return data_1.keys() == data_2.keys() and all(
            data_equivalence(data_1[k], data_2[k]) for k in data_1
        )
    if isinstance(data_1, tuple):
        return len(data_1) == len(data_2) and all(
            data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
        )
    if isinstance(data_1, np.ndarray):
        # Compare shapes first: `==` broadcasts, so e.g. shapes (1,) and (3,)
        # would otherwise be reported as equal, and non-broadcastable shapes
        # can raise instead of returning False.
        return data_1.shape == data_2.shape and bool(np.all(data_1 == data_2))
    return data_1 == data_2
2021-08-12 12:35:09 -05:00
2022-06-06 16:21:45 +01:00
def check_reset_seed(env: gym.Env):
    """Check that the environment can be reset with a seed.

    The environment is reset three times (seeds 123, 123 and 456). Each
    returned observation must lie in ``env.observation_space``, the same seed
    must leave ``env.unwrapped._np_random`` in the same bit-generator state,
    and a different seed must leave it in a different state. When ``env.spec``
    declares the environment deterministic, the two seed-123 observations must
    also be equivalent.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with a random seed,
            even though `seed` or `kwargs` appear in the signature, or one of
            the checks described above fails.
        gym.error.Error: The `reset` signature has neither a `seed` parameter
            nor a `**kwargs` parameter.
    """
    signature = inspect.signature(env.reset)
    accepts_seed = "seed" in signature.parameters or (
        "kwargs" in signature.parameters
        and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
    )
    if not accepts_seed:
        raise gym.error.Error(
            "The `reset` method does not provide a `seed` or `**kwargs` keyword argument."
        )

    try:
        obs_1 = env.reset(seed=123)
        assert (
            obs_1 in env.observation_space
        ), "The observation returned by `env.reset(seed=123)` is not within the observation space."
        assert (
            env.unwrapped._np_random  # pyright: ignore [reportPrivateUsage]
            is not None
        ), "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`."
        # Snapshot the generator so later resets can be compared against it.
        seed_123_rng = deepcopy(
            env.unwrapped._np_random  # pyright: ignore [reportPrivateUsage]
        )

        obs_2 = env.reset(seed=123)
        assert (
            obs_2 in env.observation_space
        ), "The observation returned by `env.reset(seed=123)` is not within the observation space."
        if env.spec is not None and env.spec.nondeterministic is False:
            assert data_equivalence(
                obs_1, obs_2
            ), "Using `env.reset(seed=123)` is non-deterministic as the observations are not equivalent."
        assert (
            env.unwrapped._np_random.bit_generator.state  # pyright: ignore [reportPrivateUsage]
            == seed_123_rng.bit_generator.state
        ), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."

        obs_3 = env.reset(seed=456)
        assert (
            obs_3 in env.observation_space
        ), "The observation returned by `env.reset(seed=456)` is not within the observation space."
        assert (
            env.unwrapped._np_random.bit_generator.state  # pyright: ignore [reportPrivateUsage]
            != seed_123_rng.bit_generator.state
        ), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`."
    except TypeError as e:
        # Chain the original error so its traceback is kept as the direct cause.
        raise AssertionError(
            "The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
            f"This should never happen, please report this issue. The error was: {e}"
        ) from e

    seed_param = signature.parameters.get("seed")
    # Check the default value is None
    if seed_param is not None and seed_param.default is not None:
        logger.warn(
            "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. "
            f"Actual default: {seed_param.default}"
        )
2022-06-06 16:21:45 +01:00
def check_reset_info(env: gym.Env):
    """Checks that :meth:`reset` supports the ``return_info`` keyword.

    With ``return_info=False`` the reset must return an observation inside
    ``env.observation_space``; with ``return_info=True`` it must return a
    2-tuple ``(obs, info)`` where ``obs`` is inside the observation space and
    ``info`` is a dictionary.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with `return_info=True`,
            even though `return_info` or `kwargs` appear in the signature, or
            one of the checks described above fails.
        gym.error.Error: The `reset` signature has neither a `return_info`
            parameter nor a `**kwargs` parameter.
    """
    signature = inspect.signature(env.reset)
    accepts_return_info = "return_info" in signature.parameters or (
        "kwargs" in signature.parameters
        and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
    )
    if not accepts_return_info:
        raise gym.error.Error(
            "The `reset` method does not provide a `return_info` or `**kwargs` keyword argument."
        )

    try:
        obs = env.reset(return_info=False)
        # The message names the call that was actually made (return_info=False).
        assert (
            obs in env.observation_space
        ), "The value returned by `env.reset(return_info=False)` is not within the observation space."

        result = env.reset(return_info=True)
        assert isinstance(
            result, tuple
        ), f"Calling the reset method with `return_info=True` did not return a tuple, actual type: {type(result)}"
        assert (
            len(result) == 2
        ), f"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: {len(result)}"

        obs, info = result
        assert (
            obs in env.observation_space
        ), "The first element returned by `env.reset(return_info=True)` is not within the observation space."
        assert isinstance(
            info, dict
        ), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
    except TypeError as e:
        # Chain the original error so its traceback is kept as the direct cause.
        raise AssertionError(
            "The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` appear in the signature. "
            f"This should never happen, please report this issue. The error was: {e}"
        ) from e
2022-06-06 16:21:45 +01:00
def check_reset_options(env: gym.Env):
    """Check that the environment can be reset with options.

    Calls ``env.reset(options={})`` when the `reset` signature has an
    `options` parameter or a `**kwargs` parameter.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with options,
            even though `options` or `kwargs` appear in the signature.
        gym.error.Error: The `reset` signature has neither an `options`
            parameter nor a `**kwargs` parameter.
    """
    signature = inspect.signature(env.reset)
    accepts_options = "options" in signature.parameters or (
        "kwargs" in signature.parameters
        and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
    )
    if not accepts_options:
        raise gym.error.Error(
            "The `reset` method does not provide an `options` or `**kwargs` keyword argument."
        )

    try:
        env.reset(options={})
    except TypeError as e:
        # Chain the original error so its traceback is kept as the direct cause.
        raise AssertionError(
            "The environment cannot be reset with options, even though `options` or `**kwargs` appear in the signature. "
            f"This should never happen, please report this issue. The error was: {e}"
        ) from e
2021-08-12 12:35:09 -05:00
2022-07-11 02:45:24 +01:00
def check_space_limit(space, space_type: str):
    """Check the space limit for only the Box space as a test that only runs as part of `check_env`.

    Box spaces trigger a warning when any lower bound is ``-inf`` or any upper
    bound is ``+inf``. One-dimensional Box *action* spaces additionally
    trigger a warning when they are not symmetric/normalized. Tuple and Dict
    spaces are traversed recursively; other space types are ignored.

    Args:
        space: The space to check
        space_type: Label used in the warning messages; the normalization
            check only runs when this is ``"action"``
    """
    if isinstance(space, spaces.Box):
        if np.any(np.equal(space.low, -np.inf)):
            logger.warn(
                f"A Box {space_type} space minimum value is -infinity. This is probably too low."
            )
        if np.any(np.equal(space.high, np.inf)):
            # The upper bound is +infinity here, so the message says "infinity".
            logger.warn(
                f"A Box {space_type} space maximum value is infinity. This is probably too high."
            )

        # Check that the Box space is normalized (only for vector action boxes)
        if space_type == "action" and len(space.shape) == 1:
            # Asymmetric: some dimension has a non-zero low whose magnitude
            # differs from the magnitude of its high.
            asymmetric = np.any(
                np.logical_and(
                    space.low != np.zeros_like(space.low),
                    np.abs(space.low) != np.abs(space.high),
                )
            )
            if asymmetric or np.any(space.low < -1) or np.any(space.high > 1):
                # todo - Add to gymlibrary.ml?
                logger.warn(
                    "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). "
                    "See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information."
                )
    elif isinstance(space, spaces.Tuple):
        for subspace in space.spaces:
            check_space_limit(subspace, space_type)
    elif isinstance(space, spaces.Dict):
        for subspace in space.values():
            check_space_limit(subspace, space_type)
def check_env(
    env: gym.Env, warn: Optional[bool] = None, skip_render_check: bool = False
):
    """Check that an environment follows Gym API.

    This is an invasive function that calls the environment's reset and step.

    This is particularly useful when using a custom environment.
    Please take a look at https://www.gymlibrary.ml/content/environment_creation/
    for more information about the API.

    Args:
        env: The Gym environment that will be checked
        warn: Ignored; a warning is logged if any value other than ``None`` is passed
        skip_render_check: Whether to skip the checks for the render method. False by default (set to True to skip, useful for the CI)

    Raises:
        AssertionError: ``env`` is not a ``gym.Env`` instance or lacks an
            ``action_space`` / ``observation_space`` attribute.
    """
    if warn is not None:
        logger.warn("`check_env(warn=...)` parameter is now ignored.")

    assert isinstance(
        env, gym.Env
    ), "The environment must inherit from the gym.Env class. See https://www.gymlibrary.ml/content/environment_creation/ for more info."

    if env.unwrapped is not env:
        logger.warn(
            f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
        )

    # ============= Check the spaces (observation and action) ================
    assert hasattr(
        env, "action_space"
    ), "The environment must specify an action space. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
    check_action_space(env.action_space)
    check_space_limit(env.action_space, "action")

    assert hasattr(
        env, "observation_space"
    ), "The environment must specify an observation space. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
    check_observation_space(env.observation_space)
    check_space_limit(env.observation_space, "observation")

    # ==== Check the reset method ====
    check_reset_seed(env)
    check_reset_options(env)
    check_reset_info(env)

    # ============ Check the returned values ===============
    env_reset_passive_checker(env)
    env_step_passive_checker(env, env.action_space.sample())

    # ==== Check the render method and the declared render modes ====
    if not skip_render_check:
        if env.render_mode is not None:
            env_render_passive_checker(env)

        # todo: recreate the environment with a different render_mode for check that each work