Add dtype testing for data_equivalence and update testing (#515)

This commit is contained in:
Mark Towers
2023-05-23 17:03:25 +01:00
committed by GitHub
parent 5bf6c1e93f
commit ae5d8888aa
5 changed files with 57 additions and 28 deletions

View File

@@ -18,32 +18,51 @@ from tests.testing_env import GenericTestEnv # noqa: E402
@pytest.mark.parametrize(
"value, expected_value",
[
(1.0, np.array(1.0)),
(2, np.array(2)),
((3.0, 4), (np.array(3.0), np.array(4))),
([3.0, 4], [np.array(3.0), np.array(4)]),
(1.0, np.array(1.0, dtype=np.float32)),
(2, np.array(2, dtype=np.int32)),
((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32))),
([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32)]),
(
{
"a": 6.0,
"b": 7,
},
{"a": np.array(6.0), "b": np.array(7)},
{"a": np.array(6.0, dtype=np.float32), "b": np.array(7, dtype=np.int32)},
),
(np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)),
(np.array(1.0, dtype=np.uint8), np.array(1.0, dtype=np.uint8)),
(np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)),
(
np.array([[1.0], [2.0]], dtype=np.int32),
np.array([[1.0], [2.0]], dtype=np.int32),
),
(np.array(1.0), np.array(1.0)),
(np.array([1, 2]), np.array([1, 2])),
(np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
(
{"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
{
"a": (np.array(1), np.array(2.0), np.array([3, 4])),
"b": {"c": np.array(5)},
"a": (
1,
np.array(2.0, dtype=np.float32),
np.array([3, 4], dtype=np.int32),
),
"b": {"c": 5},
},
{
"a": (
np.array(1, dtype=np.int32),
np.array(2.0, dtype=np.float32),
np.array([3, 4], dtype=np.int32),
),
"b": {"c": np.array(5, dtype=np.int32)},
},
),
],
)
def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.
Warning: Jax doesn't support float64 out of the box, therefore, we only test float32 in this test.
"""
roundtripped_value = jax_to_numpy(numpy_to_jax(value))
assert data_equivalence(roundtripped_value, expected_value)
def jax_reset_func(self, seed=None, options=None):