2022-08-16 18:05:36 +02:00
import pickle
2023-03-13 12:10:28 +01:00
import re
2022-08-30 19:47:26 +01:00
import warnings
2022-08-16 18:05:36 +02:00
2022-03-31 12:50:38 -07:00
import pytest
2018-11-29 02:27:27 +01:00
2022-09-16 23:41:27 +01:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium . envs . registration import EnvSpec
from gymnasium . utils . env_checker import check_env , data_equivalence
2024-04-05 18:21:10 +02:00
from tests . envs . utils import all_testing_env_specs , all_testing_initialised_envs
2021-07-29 02:26:34 +02:00
2022-12-04 22:24:02 +08:00
2016-05-31 00:57:31 -07:00
# The tests below run a smoke test on each officially registered testing env. We may
# also want to try running environments which are not officially registered.
# Regex patterns for warnings that ``test_all_env_passive_env_checker`` tolerates.
# The patterns are OR-ed together with "|" and searched for in the text of each
# recorded warning. Currently only the ANSI-coloured "environment is out of date"
# notice is allowed.
PASSIVE_CHECK_IGNORE_WARNING = [
    r"\x1b\[33mWARN: The environment (.*?) is out of date\. "
    r"You should consider upgrading to version `v(\d)`\.\x1b\[0m",
]
2023-03-13 12:10:28 +01:00
2022-07-11 02:45:24 +01:00
# Exact warning strings (each message wrapped in the ANSI yellow "WARN:" prefix and
# the ANSI reset suffix) that ``test_all_env_api`` tolerates from ``check_env``;
# any other recorded warning fails that test.
CHECK_ENV_IGNORE_WARNINGS = [
    f"\x1b[33mWARN: {expected}\x1b[0m"
    for expected in (
        "A Box observation space minimum value is -infinity. This is probably too low.",
        "A Box observation space maximum value is infinity. This is probably too high.",
        (
            "For Box action spaces, we recommend using a symmetric and normalized "
            "space (range=[-1, 1] or [0, 1]). See "
            "https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html "
            "for more information."
        ),
    )
]
2022-06-16 14:29:13 +01:00
2022-05-24 08:47:51 -04:00
@pytest.mark.parametrize(
    "spec",
    all_testing_env_specs,
    ids=[spec.id for spec in all_testing_env_specs],
)
def test_all_env_api(spec):
    """Check that all environments pass the environment checker with no warnings other than the expected.

    ``check_env`` is run on the unwrapped environment (render checks skipped) while
    every warning is recorded; any warning whose text is not listed in
    ``CHECK_ENV_IGNORE_WARNINGS`` fails the test with ``gym.error.Error``.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        env = spec.make().unwrapped
        try:
            check_env(env, skip_render_check=True)
        finally:
            # Release the env's resources even when ``check_env`` raises.
            env.close()

    for warning in caught_warnings:
        # ``str(...)`` equals ``args[0]`` for the usual single-argument warning, but
        # unlike ``args[0]`` it cannot raise IndexError for an argument-less warning
        # (and matches how ``test_all_env_passive_env_checker`` reads warning text).
        if str(warning.message) not in CHECK_ENV_IGNORE_WARNINGS:
            raise gym.error.Error(f"Unexpected warning: {warning.message}")
2016-05-27 12:16:35 -07:00
2021-07-29 02:26:34 +02:00
2023-02-05 00:05:59 +00:00
@pytest.mark.parametrize(
    "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_all_env_passive_env_checker(spec):
    """Check that ``gym.make`` (env checker not disabled), one reset and one random step
    emit no warnings other than those matched by ``PASSIVE_CHECK_IGNORE_WARNING``.

    Raises:
        ValueError: if a recorded warning matches none of the allowed patterns.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        env = gym.make(spec.id)
        try:
            env.reset()
            env.step(env.action_space.sample())
        finally:
            # Close the env even when reset/step fails so resources are not leaked.
            env.close()

    # One alternation regex covering every allowed warning pattern.
    passive_check_pattern = re.compile("|".join(PASSIVE_CHECK_IGNORE_WARNING))
    for warning in caught_warnings:
        if not passive_check_pattern.search(str(warning.message)):
            raise ValueError(f"Unexpected warning: {warning.message}")
2023-02-05 00:05:59 +00:00
2022-06-16 14:29:13 +01:00
# Fixed reset/action-space seed and rollout length used by
# ``test_env_determinism_rollout``.
# Note that this precludes running this test in multiple threads.
# However, we probably already can't do multithreading due to some environments.
SEED, NUM_STEPS = 0, 50
2022-02-06 17:28:27 -06:00
2022-06-16 14:29:13 +01:00
@pytest.mark.parametrize(
    "env_spec",
    all_testing_env_specs,
    ids=[env.id for env in all_testing_env_specs],
)
def test_env_determinism_rollout(env_spec: EnvSpec):
    """Run a rollout with two environments and assert equality.

    This test runs a rollout of NUM_STEPS steps with two environments
    initialized with the same seed and asserts that:
     - the observations and infos after the first reset are the same
     - applying the same action (sampled from the first env's seeded action
       space) to both envs yields the same step results
     - observations are contained in the observation space
     - obs, rew, terminated, truncated and info are equal between the two envs

    Environments whose spec is marked ``nondeterministic`` are skipped. Both envs
    are reset with ``SEED`` again whenever an episode terminates or truncates.
    """
    # Don't check rollout equality if it's a nondeterministic environment.
    if env_spec.nondeterministic is True:
        pytest.skip(f"Skipping {env_spec.id} as it is non-deterministic")

    env_1 = env_spec.make(disable_env_checker=True)
    env_2 = env_spec.make(disable_env_checker=True)
    if env_1.metadata.get("jax", False):
        env_1 = gym.wrappers.JaxToNumpy(env_1)
        env_2 = gym.wrappers.JaxToNumpy(env_2)

    try:
        initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
        initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
        assert data_equivalence(initial_obs_1, initial_obs_2, exact=True)
        # The reset infos were previously unpacked but never compared.
        assert data_equivalence(
            initial_info_1, initial_info_2, exact=True
        ), f"initial_info_1={initial_info_1}, initial_info_2={initial_info_2}"

        env_1.action_space.seed(SEED)
        for time_step in range(NUM_STEPS):
            # We don't evaluate the determinism of actions: one action is sampled
            # from env_1 and applied to both envs.
            action = env_1.action_space.sample()
            obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
            obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)

            assert data_equivalence(
                obs_1, obs_2, exact=True
            ), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}"
            # obs_2 is verified by the previous assertion.
            assert env_1.observation_space.contains(obs_1)
            assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
            assert (
                terminated_1 == terminated_2
            ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
            assert (
                truncated_1 == truncated_2
            ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
            assert data_equivalence(
                info_1, info_2, exact=True
            ), f"[{time_step}] info_1={info_1}, info_2={info_2}"

            # terminated_2 / truncated_2 are verified by the previous assertions.
            if terminated_1 or truncated_1:
                env_1.reset(seed=SEED)
                env_2.reset(seed=SEED)
    finally:
        # Close both envs even when an assertion above fails.
        env_1.close()
        env_2.close()
2022-02-06 17:28:27 -06:00
2022-08-16 18:05:36 +02:00
@pytest.mark.parametrize(
    "env",
    all_testing_initialised_envs,
    # ``ids`` must have exactly one entry per parameter. Filtering out envs without
    # a spec made the list shorter than ``all_testing_initialised_envs``, which
    # pytest rejects; ``None`` instead lets pytest auto-generate that env's id.
    ids=[
        env.spec.id if env.spec is not None else None
        for env in all_testing_initialised_envs
    ],
)
def test_pickle_env(env: gym.Env):
    """Check that a pickle round-trip of an env reproduces its behaviour.

    The env is reset with seed 123 and stepped with a sampled action, then
    pickled and unpickled; the copy is reset with the same seed and stepped with
    the same action, and both reset and step outputs must be equivalent.
    """
    if env.metadata.get("jax", False):
        env = gym.wrappers.JaxToNumpy(env)

    pickled_env = None
    try:
        action = env.action_space.sample()
        env_reset = env.reset(seed=123)
        env_step = env.step(action)

        pickled_env = pickle.loads(pickle.dumps(env))
        pickle_reset = pickled_env.reset(seed=123)
        pickle_step = pickled_env.step(action)

        assert data_equivalence(env_reset, pickle_reset)
        assert data_equivalence(env_step, pickle_step)
    finally:
        # Close both envs even when pickling or an assertion fails.
        env.close()
        if pickled_env is not None:
            pickled_env.close()