Support only new step API (while retaining compatibility functions) (#3019)

This commit is contained in:
Arjun KG
2022-08-30 19:41:59 +05:30
committed by GitHub
parent 884ba08f19
commit 54b406b799
58 changed files with 378 additions and 559 deletions

View File

@@ -2,7 +2,10 @@ import numpy as np
import pytest
from gym.utils.env_checker import data_equivalence
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
from gym.utils.step_api_compatibility import (
convert_to_done_step_api,
convert_to_terminated_truncated_step_api,
)
@pytest.mark.parametrize(
@@ -54,7 +57,7 @@ from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
def test_to_done_step_api(
is_vector_env, done_returns, expected_terminated, expected_truncated
):
_, _, terminated, truncated, info = step_to_new_api(
_, _, terminated, truncated, info = convert_to_terminated_truncated_step_api(
done_returns, is_vector_env=is_vector_env
)
assert np.all(terminated == expected_terminated)
@@ -67,7 +70,7 @@ def test_to_done_step_api(
else: # isinstance(info, dict)
assert "TimeLimit.truncated" not in info
roundtripped_returns = step_to_old_api(
roundtripped_returns = convert_to_done_step_api(
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
)
assert data_equivalence(done_returns, roundtripped_returns)
@@ -112,7 +115,7 @@ def test_to_done_step_api(
def test_to_terminated_truncated_step_api(
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
):
_, _, done, info = step_to_old_api(
_, _, done, info = convert_to_done_step_api(
terminated_truncated_returns, is_vector_env=is_vector_env
)
assert np.all(done == expected_done)
@@ -136,7 +139,7 @@ def test_to_terminated_truncated_step_api(
else:
assert "TimeLimit.truncated" not in info
roundtripped_returns = step_to_new_api(
roundtripped_returns = convert_to_terminated_truncated_step_api(
(0, 0, done, info), is_vector_env=is_vector_env
)
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)
@@ -146,19 +149,19 @@ def test_edge_case():
# Converting between the two step APIs round-trips losslessly in every case except one:
# terminated=True and truncated=True -> done=True and info={"TimeLimit.truncated": False}
# We cannot test this in test_to_terminated_truncated_step_api, as the round-trip check would fail
_, _, done, info = step_to_old_api((0, 0, True, True, {}))
_, _, done, info = convert_to_done_step_api((0, 0, True, True, {}))
assert done is True
assert info == {"TimeLimit.truncated": False}
# Test with vector dict info
_, _, done, info = step_to_old_api(
_, _, done, info = convert_to_done_step_api(
(0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True
)
assert np.all(done)
assert info == {"TimeLimit.truncated": np.array([False])}
# Test with vector list info
_, _, done, info = step_to_old_api(
_, _, done, info = convert_to_done_step_api(
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]),
is_vector_env=True,
)