mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-27 00:37:19 +00:00
Support only new step API (while retaining compatibility functions) (#3019)
This commit is contained in:
@@ -2,7 +2,10 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.utils.env_checker import data_equivalence
|
||||
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
|
||||
from gym.utils.step_api_compatibility import (
|
||||
convert_to_done_step_api,
|
||||
convert_to_terminated_truncated_step_api,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -54,7 +57,7 @@ from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
|
||||
def test_to_done_step_api(
|
||||
is_vector_env, done_returns, expected_terminated, expected_truncated
|
||||
):
|
||||
_, _, terminated, truncated, info = step_to_new_api(
|
||||
_, _, terminated, truncated, info = convert_to_terminated_truncated_step_api(
|
||||
done_returns, is_vector_env=is_vector_env
|
||||
)
|
||||
assert np.all(terminated == expected_terminated)
|
||||
@@ -67,7 +70,7 @@ def test_to_done_step_api(
|
||||
else: # isinstance(info, dict)
|
||||
assert "TimeLimit.truncated" not in info
|
||||
|
||||
roundtripped_returns = step_to_old_api(
|
||||
roundtripped_returns = convert_to_done_step_api(
|
||||
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
|
||||
)
|
||||
assert data_equivalence(done_returns, roundtripped_returns)
|
||||
@@ -112,7 +115,7 @@ def test_to_done_step_api(
|
||||
def test_to_terminated_truncated_step_api(
|
||||
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
|
||||
):
|
||||
_, _, done, info = step_to_old_api(
|
||||
_, _, done, info = convert_to_done_step_api(
|
||||
terminated_truncated_returns, is_vector_env=is_vector_env
|
||||
)
|
||||
assert np.all(done == expected_done)
|
||||
@@ -136,7 +139,7 @@ def test_to_terminated_truncated_step_api(
|
||||
else:
|
||||
assert "TimeLimit.truncated" not in info
|
||||
|
||||
roundtripped_returns = step_to_new_api(
|
||||
roundtripped_returns = convert_to_terminated_truncated_step_api(
|
||||
(0, 0, done, info), is_vector_env=is_vector_env
|
||||
)
|
||||
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)
|
||||
@@ -146,19 +149,19 @@ def test_edge_case():
|
||||
# When converting between the two-step APIs this is not possible in a single case
|
||||
# terminated=True and truncated=True -> done=True and info={}
|
||||
# We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail
|
||||
_, _, done, info = step_to_old_api((0, 0, True, True, {}))
|
||||
_, _, done, info = convert_to_done_step_api((0, 0, True, True, {}))
|
||||
assert done is True
|
||||
assert info == {"TimeLimit.truncated": False}
|
||||
|
||||
# Test with vector dict info
|
||||
_, _, done, info = step_to_old_api(
|
||||
_, _, done, info = convert_to_done_step_api(
|
||||
(0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True
|
||||
)
|
||||
assert np.all(done)
|
||||
assert info == {"TimeLimit.truncated": np.array([False])}
|
||||
|
||||
# Test with vector list info
|
||||
_, _, done, info = step_to_old_api(
|
||||
_, _, done, info = convert_to_done_step_api(
|
||||
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]),
|
||||
is_vector_env=True,
|
||||
)
|
||||
|
Reference in New Issue
Block a user