From ae5d8888aae23cdb0bafc9dc0648481d9e51ebaa Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 23 May 2023 17:03:25 +0100 Subject: [PATCH] Add dtype testing for `data_equivalence` and update testing (#515) --- .../experimental/wrappers/jax_to_numpy.py | 14 ++++-- gymnasium/spaces/box.py | 2 +- gymnasium/spaces/graph.py | 14 +++--- gymnasium/utils/env_checker.py | 10 +++-- .../wrappers/test_jax_to_numpy.py | 45 +++++++++++++------ 5 files changed, 57 insertions(+), 28 deletions(-) diff --git a/gymnasium/experimental/wrappers/jax_to_numpy.py b/gymnasium/experimental/wrappers/jax_to_numpy.py index baf48a246..5abc4b05c 100644 --- a/gymnasium/experimental/wrappers/jax_to_numpy.py +++ b/gymnasium/experimental/wrappers/jax_to_numpy.py @@ -32,15 +32,21 @@ def numpy_to_jax(value: Any) -> Any: @numpy_to_jax.register(numbers.Number) -@numpy_to_jax.register(np.ndarray) -def _number_ndarray_numpy_to_jax( - value: np.ndarray | numbers.Number, +def _number_to_jax( + value: numbers.Number, ) -> jnp.DeviceArray: - """Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray.""" + """Converts a number (int, float, etc.) to a Jax DeviceArray.""" assert jnp is not None return jnp.array(value) +@numpy_to_jax.register(np.ndarray) +def _numpy_array_to_jax(value: np.ndarray) -> jnp.DeviceArray: + """Converts a NumPy Array to a Jax DeviceArray with the same dtype (excluding float64 without being enabled).""" + assert jnp is not None + return jnp.array(value, dtype=value.dtype) + + @numpy_to_jax.register(abc.Mapping) def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: """Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays.""" diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index db0d32d42..3b75ab4fc 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -256,7 +256,7 @@ class Box(Space[NDArray[Any]]): def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]: """Convert a JSONable data type to a batch of samples from this space.""" - return [np.asarray(sample) for sample in sample_n] + return [np.asarray(sample, dtype=self.dtype) for sample in sample_n] def __repr__(self) -> str: """A string representation of this space. diff --git a/gymnasium/spaces/graph.py b/gymnasium/spaces/graph.py index bfeb1bed6..d2be47e8a 100644 --- a/gymnasium/spaces/graph.py +++ b/gymnasium/spaces/graph.py @@ -49,8 +49,7 @@ class Graph(Space[GraphInstance]): [7, 0], [3, 7], [8, 4], - [8, 8]])) - + [8, 8]], dtype=int32)) """ def __init__( @@ -178,7 +177,7 @@ class Graph(Space[GraphInstance]): sampled_edge_links = None if sampled_edges is not None and num_edges > 0: sampled_edge_links = self.np_random.integers( - low=0, high=num_nodes, size=(num_edges, 2) + low=0, high=num_nodes, size=(num_edges, 2), dtype=np.int32 ) return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links) @@ -248,14 +247,15 @@ class Graph(Space[GraphInstance]): ret: list[GraphInstance] = [] for sample in sample_n: if "edges" in sample: + assert self.edge_space is not None ret_n = GraphInstance( - np.asarray(sample["nodes"]), - np.asarray(sample["edges"]), - np.asarray(sample["edge_links"]), + np.asarray(sample["nodes"], dtype=self.node_space.dtype), + np.asarray(sample["edges"], dtype=self.edge_space.dtype), + np.asarray(sample["edge_links"], dtype=np.int32), ) else: ret_n = GraphInstance( - np.asarray(sample["nodes"]), + np.asarray(sample["nodes"], dtype=self.node_space.dtype), None, None, ) diff --git a/gymnasium/utils/env_checker.py b/gymnasium/utils/env_checker.py index 7dc2569c7..e32e22127 100644 --- a/gymnasium/utils/env_checker.py +++ b/gymnasium/utils/env_checker.py @@ -50,9 +50,13 @@ def data_equivalence(data_1, data_2) -> bool: data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2) ) elif isinstance(data_1, np.ndarray): - return data_1.shape == data_2.shape and np.allclose( - data_1, data_2, atol=0.00001 - ) + if data_1.shape == data_2.shape and data_1.dtype == data_2.dtype: + if data_1.dtype == object: + return all(data_equivalence(a, b) for a, b in zip(data_1, data_2)) + else: + return np.allclose(data_1, data_2, atol=0.00001) + else: + return False else: return data_1 == data_2 else: diff --git a/tests/experimental/wrappers/test_jax_to_numpy.py b/tests/experimental/wrappers/test_jax_to_numpy.py index e0936ee31..bc0951ecf 100644 --- a/tests/experimental/wrappers/test_jax_to_numpy.py +++ b/tests/experimental/wrappers/test_jax_to_numpy.py @@ -18,32 +18,51 @@ from tests.testing_env import GenericTestEnv # noqa: E402 @pytest.mark.parametrize( "value, expected_value", [ - (1.0, np.array(1.0)), - (2, np.array(2)), - ((3.0, 4), (np.array(3.0), np.array(4))), - ([3.0, 4], [np.array(3.0), np.array(4)]), + (1.0, np.array(1.0, dtype=np.float32)), + (2, np.array(2, dtype=np.int32)), + ((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32))), + ([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32)]), ( { "a": 6.0, "b": 7, }, - {"a": np.array(6.0), "b": np.array(7)}, + {"a": np.array(6.0, dtype=np.float32), "b": np.array(7, dtype=np.int32)}, + ), + (np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)), + (np.array(1.0, dtype=np.uint8), np.array(1.0, dtype=np.uint8)), + (np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)), + ( + np.array([[1.0], [2.0]], dtype=np.int32), + np.array([[1.0], [2.0]], dtype=np.int32), ), - (np.array(1.0), np.array(1.0)), - (np.array([1, 2]), np.array([1, 2])), - (np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])), ( - {"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}}, { - "a": (np.array(1), np.array(2.0), np.array([3, 4])), - "b": {"c": np.array(5)}, + "a": ( + 1, + np.array(2.0, dtype=np.float32), + np.array([3, 4], dtype=np.int32), + ), + "b": {"c": 5}, + }, + { + "a": ( + np.array(1, dtype=np.int32), + np.array(2.0, dtype=np.float32), + np.array([3, 4], dtype=np.int32), + ), + "b": {"c": np.array(5, dtype=np.int32)}, }, ), ], ) def test_roundtripping(value, expected_value): - """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.""" - assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value) + """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper. + + Warning: Jax doesn't support float64 out of the box, therefore, we only test float32 in this test. + """ + roundtripped_value = jax_to_numpy(numpy_to_jax(value)) + assert data_equivalence(roundtripped_value, expected_value) def jax_reset_func(self, seed=None, options=None):