import unittest

import numpy as np

from gym import envs
from tests.envs.spec_list import SKIP_MUJOCO_V3_WARNING_MESSAGE, skip_mujoco_v3


def _compare_rollouts(old_environment, new_environment, seed, num_actions, eps=1e-6):
    """Step two already-created environments in lockstep and assert they agree."""
    old_reset_observation = old_environment.reset(seed=seed)
    new_reset_observation = new_environment.reset(seed=seed)
    np.testing.assert_allclose(old_reset_observation, new_reset_observation)

    # Seed the action sampler too, so a failing rollout replays the exact same
    # action sequence on the next run instead of a fresh random one.
    old_environment.action_space.seed(seed)

    for _ in range(num_actions):
        action = old_environment.action_space.sample()

        old_observation, old_reward, old_done, old_info = old_environment.step(action)
        new_observation, new_reward, new_done, new_info = new_environment.step(action)

        np.testing.assert_allclose(old_observation, new_observation, atol=eps)
        np.testing.assert_allclose(old_reward, new_reward, atol=eps)
        # `done` is a flag: compare exactly rather than with a float tolerance.
        np.testing.assert_equal(old_done, new_done)

        # Only the old env's keys are checked; a key missing from the new
        # env's info raises KeyError (relied upon by the reverse-direction test).
        for key in old_info:
            np.testing.assert_allclose(old_info[key], new_info[key], atol=eps)


def verify_environments_match(
    old_environment_id, new_environment_id, seed=1, num_actions=1000
):
    """Assert that two registered environments produce identical rollouts.

    Both environments are reset with ``seed`` and stepped ``num_actions``
    times with the same sampled actions (the sampler is seeded with ``seed``
    as well). After each step the observation and reward must match within an
    absolute tolerance of 1e-6, the done flags must be equal, and every key in
    the old environment's ``info`` must have a matching value in the new one's.
    Both environments are closed afterwards, even if an assertion fails.

    Args:
        old_environment_id: Registry id of the reference environment.
        new_environment_id: Registry id of the environment under test.
        seed: Seed passed to both ``reset`` calls and to the action sampler.
        num_actions: Number of steps to compare.

    Raises:
        AssertionError: If any compared quantity differs.
        KeyError: If the old environment's ``info`` contains a key that the
            new environment's ``info`` lacks.
    """
    old_environment = envs.make(old_environment_id)
    try:
        new_environment = envs.make(new_environment_id)
        try:
            _compare_rollouts(old_environment, new_environment, seed, num_actions)
        finally:
            new_environment.close()
    finally:
        old_environment.close()


@unittest.skipIf(skip_mujoco_v3, SKIP_MUJOCO_V3_WARNING_MESSAGE)
class Mujocov2Tov3ConversionTest(unittest.TestCase):
    """Check that MuJoCo v3 environments reproduce their v2 counterparts."""

    def test_environments_match(self):
        test_cases = (
            {"old_id": "Swimmer-v2", "new_id": "Swimmer-v3"},
            {"old_id": "Hopper-v2", "new_id": "Hopper-v3"},
            {"old_id": "Walker2d-v2", "new_id": "Walker2d-v3"},
            {"old_id": "HalfCheetah-v2", "new_id": "HalfCheetah-v3"},
            {"old_id": "Ant-v2", "new_id": "Ant-v3"},
            {"old_id": "Humanoid-v2", "new_id": "Humanoid-v3"},
        )

        for test_case in test_cases:
            # subTest reports which env pair failed and keeps checking the rest.
            with self.subTest(**test_case):
                verify_environments_match(test_case["old_id"], test_case["new_id"])

        # Comparing in the reverse direction raises KeyError because the new
        # envs have extra info keys that the old envs do not provide.
        reverse_cases = (
            ("Swimmer-v3", "Swimmer-v2"),
            ("Humanoid-v3", "Humanoid-v2"),
        )
        for old_id, new_id in reverse_cases:
            with self.subTest(old_id=old_id, new_id=new_id):
                with self.assertRaises(KeyError):
                    verify_environments_match(old_id, new_id)


if __name__ == "__main__":
    unittest.main()