import numpy as np import pytest from gym.utils.env_checker import data_equivalence from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api @pytest.mark.parametrize( "is_vector_env, done_returns, expected_terminated, expected_truncated", ( # Test each of the permutations for single environments with and without the old info (False, (0, 0, False, {"Test-info": True}), False, False), (False, (0, 0, False, {"TimeLimit.truncated": False}), False, False), (False, (0, 0, True, {}), True, False), (False, (0, 0, True, {"TimeLimit.truncated": True}), False, True), (False, (0, 0, True, {"Test-info": True}), True, False), # Test vectorise versions with both list and dict infos testing each permutation for sub-environments ( True, ( 0, 0, np.array([False, True, True]), [{}, {}, {"TimeLimit.truncated": True}], ), np.array([False, True, False]), np.array([False, False, True]), ), ( True, ( 0, 0, np.array([False, True, True]), {"TimeLimit.truncated": np.array([False, False, True])}, ), np.array([False, True, False]), np.array([False, False, True]), ), # empty truncated info ( True, ( 0, 0, np.array([False, True]), {}, ), np.array([False, True]), np.array([False, False]), ), ), ) def test_to_done_step_api( is_vector_env, done_returns, expected_terminated, expected_truncated ): _, _, terminated, truncated, info = step_to_new_api( done_returns, is_vector_env=is_vector_env ) assert np.all(terminated == expected_terminated) assert np.all(truncated == expected_truncated) if is_vector_env is False: assert "TimeLimit.truncated" not in info elif isinstance(info, list): assert all("TimeLimit.truncated" not in sub_info for sub_info in info) else: # isinstance(info, dict) assert "TimeLimit.truncated" not in info roundtripped_returns = step_to_old_api( (0, 0, terminated, truncated, info), is_vector_env=is_vector_env ) assert data_equivalence(done_returns, roundtripped_returns) @pytest.mark.parametrize( "is_vector_env, terminated_truncated_returns, expected_done, expected_truncated", ( (False, (0, 0, False, False, {"Test-info": True}), False, False), (False, (0, 0, True, False, {}), True, False), (False, (0, 0, False, True, {}), True, True), # (False, (), True, True), # Not possible to encode in the old step api # Test vector dict info ( True, (0, 0, np.array([False, True, False]), np.array([False, False, True]), {}), np.array([False, True, True]), np.array([False, False, True]), ), # Test vector dict info with no truncation ( True, (0, 0, np.array([False, True]), np.array([False, False]), {}), np.array([False, True]), np.array([False, False]), ), # Test vector list info ( True, ( 0, 0, np.array([False, True, False]), np.array([False, False, True]), [{"Test-Info": True}, {}, {}], ), np.array([False, True, True]), np.array([False, False, True]), ), ), ) def test_to_terminated_truncated_step_api( is_vector_env, terminated_truncated_returns, expected_done, expected_truncated ): _, _, done, info = step_to_old_api( terminated_truncated_returns, is_vector_env=is_vector_env ) assert np.all(done == expected_done) if is_vector_env is False: if expected_done: assert info["TimeLimit.truncated"] == expected_truncated else: assert "TimeLimit.truncated" not in info elif isinstance(info, list): for sub_info, env_done, env_truncated in zip( info, expected_done, expected_truncated ): if env_done: assert sub_info["TimeLimit.truncated"] == env_truncated else: assert "TimeLimit.truncated" not in sub_info else: # isinstance(info, dict) if np.any(expected_done): assert np.all(info["TimeLimit.truncated"] == expected_truncated) else: assert "TimeLimit.truncated" not in info roundtripped_returns = step_to_new_api( (0, 0, done, info), is_vector_env=is_vector_env ) assert data_equivalence(terminated_truncated_returns, roundtripped_returns) def test_edge_case(): # When converting between the two-step APIs this is not possible in a single case # terminated=True and truncated=True -> done=True and info={} # We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail _, _, done, info = step_to_old_api((0, 0, True, True, {})) assert done is True assert info == {"TimeLimit.truncated": False} # Test with vector dict info _, _, done, info = step_to_old_api( (0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True ) assert np.all(done) assert info == {"TimeLimit.truncated": np.array([False])} # Test with vector list info _, _, done, info = step_to_old_api( (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]), is_vector_env=True, ) assert np.all(done) assert info == [{"Test-Info": True, "TimeLimit.truncated": False}]