mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-14 14:08:17 +00:00
Add dtype testing for data_equivalence
and update testing (#515)
This commit is contained in:
@@ -18,32 +18,51 @@ from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
@pytest.mark.parametrize(
|
||||
"value, expected_value",
|
||||
[
|
||||
(1.0, np.array(1.0)),
|
||||
(2, np.array(2)),
|
||||
((3.0, 4), (np.array(3.0), np.array(4))),
|
||||
([3.0, 4], [np.array(3.0), np.array(4)]),
|
||||
(1.0, np.array(1.0, dtype=np.float32)),
|
||||
(2, np.array(2, dtype=np.int32)),
|
||||
((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32))),
|
||||
([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32)]),
|
||||
(
|
||||
{
|
||||
"a": 6.0,
|
||||
"b": 7,
|
||||
},
|
||||
{"a": np.array(6.0), "b": np.array(7)},
|
||||
{"a": np.array(6.0, dtype=np.float32), "b": np.array(7, dtype=np.int32)},
|
||||
),
|
||||
(np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)),
|
||||
(np.array(1.0, dtype=np.uint8), np.array(1.0, dtype=np.uint8)),
|
||||
(np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)),
|
||||
(
|
||||
np.array([[1.0], [2.0]], dtype=np.int32),
|
||||
np.array([[1.0], [2.0]], dtype=np.int32),
|
||||
),
|
||||
(np.array(1.0), np.array(1.0)),
|
||||
(np.array([1, 2]), np.array([1, 2])),
|
||||
(np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
|
||||
(
|
||||
{"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
|
||||
{
|
||||
"a": (np.array(1), np.array(2.0), np.array([3, 4])),
|
||||
"b": {"c": np.array(5)},
|
||||
"a": (
|
||||
1,
|
||||
np.array(2.0, dtype=np.float32),
|
||||
np.array([3, 4], dtype=np.int32),
|
||||
),
|
||||
"b": {"c": 5},
|
||||
},
|
||||
{
|
||||
"a": (
|
||||
np.array(1, dtype=np.int32),
|
||||
np.array(2.0, dtype=np.float32),
|
||||
np.array([3, 4], dtype=np.int32),
|
||||
),
|
||||
"b": {"c": np.array(5, dtype=np.int32)},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_roundtripping(value, expected_value):
|
||||
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
|
||||
assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)
|
||||
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.
|
||||
|
||||
Warning: Jax doesn't support float64 out of the box, therefore, we only test float32 in this test.
|
||||
"""
|
||||
roundtripped_value = jax_to_numpy(numpy_to_jax(value))
|
||||
assert data_equivalence(roundtripped_value, expected_value)
|
||||
|
||||
|
||||
def jax_reset_func(self, seed=None, options=None):
|
||||
|
Reference in New Issue
Block a user