2022-06-16 14:29:13 +01:00
import numpy as np
import pytest
2022-09-16 23:41:27 +01:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium import envs
from gymnasium . envs . registration import EnvSpec
2022-06-16 14:29:13 +01:00
from tests . envs . utils import mujoco_testing_env_specs
2022-12-04 22:24:02 +08:00
2022-06-16 14:29:13 +01:00
# Absolute tolerance passed as `atol` to the np.testing.assert_allclose
# comparisons of observations, rewards and info values in this module.
EPS = 1e-6
def _assert_rollouts_match(old_env, new_env, seed: int, num_actions: int) -> None:
    """Drive both environments with the same actions and assert their outputs agree."""
    old_reset_obs, _ = old_env.reset(seed=seed)
    new_reset_obs, _ = new_env.reset(seed=seed)
    np.testing.assert_allclose(old_reset_obs, new_reset_obs)

    # Actions are sampled from the action space, which keeps its own random
    # generator; seed it explicitly so the action sequence is reproducible.
    old_env.action_space.seed(seed)

    for _ in range(num_actions):
        action = old_env.action_space.sample()

        old_obs, old_reward, old_terminated, old_truncated, old_info = old_env.step(
            action
        )
        new_obs, new_reward, new_terminated, new_truncated, new_info = new_env.step(
            action
        )

        np.testing.assert_allclose(old_obs, new_obs, atol=EPS)
        np.testing.assert_allclose(old_reward, new_reward, atol=EPS)
        np.testing.assert_equal(old_terminated, new_terminated)
        np.testing.assert_equal(old_truncated, new_truncated)

        # Only the old environment's keys are checked: the new environment may
        # report extra keys, but a key missing from `new_info` raises KeyError.
        for key in old_info:
            np.testing.assert_allclose(old_info[key], new_info[key], atol=EPS)

        if old_terminated or old_truncated:
            break


def verify_environments_match(
    old_env_id: str, new_env_id: str, seed: int = 1, num_actions: int = 1000
):
    """Verify that two environment ids produce matching rollouts.

    Both environments are reset with ``seed`` and stepped with the same sampled
    actions. Observations, rewards, ``terminated`` and ``truncated`` must agree,
    and every key of the old environment's ``info`` must be present in the new
    environment's ``info`` with a matching value.

    Args:
        old_env_id: Id of the reference environment.
        new_env_id: Id of the environment compared against the reference.
        seed: Seed used for both resets and for action sampling.
        num_actions: Maximum number of steps; the rollout stops early once the
            old environment terminates or truncates.

    Raises:
        AssertionError: If any compared value differs.
        KeyError: If the new environment's info lacks a key of the old one.
    """
    old_env = envs.make(old_env_id, disable_env_checker=True)
    try:
        new_env = envs.make(new_env_id, disable_env_checker=True)
        try:
            _assert_rollouts_match(old_env, new_env, seed, num_actions)
        finally:
            new_env.close()
    finally:
        # Always release the environments, including when an assertion fails.
        old_env.close()
2022-06-30 10:59:59 -04:00
# Environment names for which the test below also builds the environment with
# `exclude_current_positions_from_observation=False`. Entries are compared
# against `env_spec.name`, so they must not carry surrounding whitespace.
EXCLUDE_POS_FROM_OBS = [
    "Ant",
    "HalfCheetah",
    "Hopper",
    "Humanoid",
    "Swimmer",
    "Walker2d",
]
def _assert_obs_in_space(env_spec: EnvSpec, suffix: str = "", **make_kwargs) -> None:
    """Build the environment and assert reset/step observations lie in its space.

    Args:
        env_spec: Spec of the environment to build.
        suffix: Text appended to the failure messages to describe the
            configuration under test.
        **make_kwargs: Extra keyword arguments forwarded to ``env_spec.make``.
    """
    env = env_spec.make(disable_env_checker=True, **make_kwargs)
    try:
        reset_obs, _ = env.reset()
        assert env.observation_space.contains(
            reset_obs
        ), f"Observation returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}{suffix}."

        step_obs, _, _, _, _ = env.step(env.action_space.sample())
        assert env.observation_space.contains(
            step_obs
        ), f"Observation returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space}{suffix}."
    finally:
        # Release the environment even when an assertion fails.
        env.close()


@pytest.mark.parametrize(
    "env_spec",
    mujoco_testing_env_specs,
    ids=[env_spec.id for env_spec in mujoco_testing_env_specs],
)
def test_obs_space_mujoco_environments(env_spec: EnvSpec):
    """Check that reset and step observations are contained in the observation space.

    The default configuration is always checked. Version 3 and 4 environments
    listed in ``EXCLUDE_POS_FROM_OBS`` are checked again with
    ``exclude_current_positions_from_observation=False``, and Ant version 4 is
    checked again with ``use_contact_forces=True``.
    """
    _assert_obs_in_space(env_spec)

    if env_spec.name in EXCLUDE_POS_FROM_OBS and env_spec.version in (3, 4):
        # The flag is set to False here, i.e. positions are *included*.
        _assert_obs_in_space(
            env_spec,
            " when including current position in observation",
            exclude_current_positions_from_observation=False,
        )

    # Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument
    if env_spec.name == "Ant" and env_spec.version == 4:
        _assert_obs_in_space(
            env_spec, " when using contact forces", use_contact_forces=True
        )
2022-06-16 14:29:13 +01:00
# Names of the version-2 MuJoCo test specs for which a "-v3" id is also
# present in the registry; used to parametrize the v2/v3 comparison tests.
MUJOCO_V2_V3_ENVS = [
    spec.name
    for spec in mujoco_testing_env_specs
    if spec.version == 2 and f"{spec.name}-v3" in gym.envs.registry
]
@pytest.mark.parametrize("env_name", MUJOCO_V2_V3_ENVS)
def test_mujoco_v2_to_v3_conversion(env_name: str):
    """Check that the v2 environment matches its v3 counterpart.

    The v2 environment is the reference, so every v2 info key must also be
    reported by v3 with a matching value.
    """
    verify_environments_match(f"{env_name}-v2", f"{env_name}-v3")
@pytest.mark.parametrize("env_name", MUJOCO_V2_V3_ENVS)
def test_mujoco_incompatible_v3_to_v2(env_name: str):
    """Check that v3 is not a subset of v2.

    With v3 as the reference, the comparison is expected to raise ``KeyError``
    because v3 reports info keys that v2 does not.
    """
    with pytest.raises(KeyError):
        verify_environments_match(f"{env_name}-v3", f"{env_name}-v2")