mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Add dtype testing for data_equivalence
and update testing (#515)
This commit is contained in:
@@ -32,15 +32,21 @@ def numpy_to_jax(value: Any) -> Any:
|
||||
|
||||
|
||||
@numpy_to_jax.register(numbers.Number)
|
||||
@numpy_to_jax.register(np.ndarray)
|
||||
def _number_ndarray_numpy_to_jax(
|
||||
value: np.ndarray | numbers.Number,
|
||||
def _number_to_jax(
|
||||
value: numbers.Number,
|
||||
) -> jnp.DeviceArray:
|
||||
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
||||
"""Converts a number (int, float, etc.) to a Jax DeviceArray."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value)
|
||||
|
||||
|
||||
@numpy_to_jax.register(np.ndarray)
|
||||
def _numpy_array_to_jax(value: np.ndarray) -> jnp.DeviceArray:
|
||||
"""Converts a NumPy Array to a Jax DeviceArray with the same dtype (excluding float64 without being enabled)."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value, dtype=value.dtype)
|
||||
|
||||
|
||||
@numpy_to_jax.register(abc.Mapping)
|
||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||
|
@@ -256,7 +256,7 @@ class Box(Space[NDArray[Any]]):
|
||||
|
||||
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
return [np.asarray(sample, dtype=self.dtype) for sample in sample_n]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""A string representation of this space.
|
||||
|
@@ -49,8 +49,7 @@ class Graph(Space[GraphInstance]):
|
||||
[7, 0],
|
||||
[3, 7],
|
||||
[8, 4],
|
||||
[8, 8]]))
|
||||
|
||||
[8, 8]], dtype=int32))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -178,7 +177,7 @@ class Graph(Space[GraphInstance]):
|
||||
sampled_edge_links = None
|
||||
if sampled_edges is not None and num_edges > 0:
|
||||
sampled_edge_links = self.np_random.integers(
|
||||
low=0, high=num_nodes, size=(num_edges, 2)
|
||||
low=0, high=num_nodes, size=(num_edges, 2), dtype=np.int32
|
||||
)
|
||||
|
||||
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
|
||||
@@ -248,14 +247,15 @@ class Graph(Space[GraphInstance]):
|
||||
ret: list[GraphInstance] = []
|
||||
for sample in sample_n:
|
||||
if "edges" in sample:
|
||||
assert self.edge_space is not None
|
||||
ret_n = GraphInstance(
|
||||
np.asarray(sample["nodes"]),
|
||||
np.asarray(sample["edges"]),
|
||||
np.asarray(sample["edge_links"]),
|
||||
np.asarray(sample["nodes"], dtype=self.node_space.dtype),
|
||||
np.asarray(sample["edges"], dtype=self.edge_space.dtype),
|
||||
np.asarray(sample["edge_links"], dtype=np.int32),
|
||||
)
|
||||
else:
|
||||
ret_n = GraphInstance(
|
||||
np.asarray(sample["nodes"]),
|
||||
np.asarray(sample["nodes"], dtype=self.node_space.dtype),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
@@ -50,9 +50,13 @@ def data_equivalence(data_1, data_2) -> bool:
|
||||
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
|
||||
)
|
||||
elif isinstance(data_1, np.ndarray):
|
||||
return data_1.shape == data_2.shape and np.allclose(
|
||||
data_1, data_2, atol=0.00001
|
||||
)
|
||||
if data_1.shape == data_2.shape and data_1.dtype == data_2.dtype:
|
||||
if data_1.dtype == object:
|
||||
return all(data_equivalence(a, b) for a, b in zip(data_1, data_2))
|
||||
else:
|
||||
return np.allclose(data_1, data_2, atol=0.00001)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return data_1 == data_2
|
||||
else:
|
||||
|
@@ -18,32 +18,51 @@ from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
@pytest.mark.parametrize(
|
||||
"value, expected_value",
|
||||
[
|
||||
(1.0, np.array(1.0)),
|
||||
(2, np.array(2)),
|
||||
((3.0, 4), (np.array(3.0), np.array(4))),
|
||||
([3.0, 4], [np.array(3.0), np.array(4)]),
|
||||
(1.0, np.array(1.0, dtype=np.float32)),
|
||||
(2, np.array(2, dtype=np.int32)),
|
||||
((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32))),
|
||||
([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32)]),
|
||||
(
|
||||
{
|
||||
"a": 6.0,
|
||||
"b": 7,
|
||||
},
|
||||
{"a": np.array(6.0), "b": np.array(7)},
|
||||
{"a": np.array(6.0, dtype=np.float32), "b": np.array(7, dtype=np.int32)},
|
||||
),
|
||||
(np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)),
|
||||
(np.array(1.0, dtype=np.uint8), np.array(1.0, dtype=np.uint8)),
|
||||
(np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)),
|
||||
(
|
||||
np.array([[1.0], [2.0]], dtype=np.int32),
|
||||
np.array([[1.0], [2.0]], dtype=np.int32),
|
||||
),
|
||||
(np.array(1.0), np.array(1.0)),
|
||||
(np.array([1, 2]), np.array([1, 2])),
|
||||
(np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
|
||||
(
|
||||
{"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
|
||||
{
|
||||
"a": (np.array(1), np.array(2.0), np.array([3, 4])),
|
||||
"b": {"c": np.array(5)},
|
||||
"a": (
|
||||
1,
|
||||
np.array(2.0, dtype=np.float32),
|
||||
np.array([3, 4], dtype=np.int32),
|
||||
),
|
||||
"b": {"c": 5},
|
||||
},
|
||||
{
|
||||
"a": (
|
||||
np.array(1, dtype=np.int32),
|
||||
np.array(2.0, dtype=np.float32),
|
||||
np.array([3, 4], dtype=np.int32),
|
||||
),
|
||||
"b": {"c": np.array(5, dtype=np.int32)},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_roundtripping(value, expected_value):
|
||||
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
|
||||
assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)
|
||||
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.
|
||||
|
||||
Warning: Jax doesn't support float64 out of the box, therefore, we only test float32 in this test.
|
||||
"""
|
||||
roundtripped_value = jax_to_numpy(numpy_to_jax(value))
|
||||
assert data_equivalence(roundtripped_value, expected_value)
|
||||
|
||||
|
||||
def jax_reset_func(self, seed=None, options=None):
|
||||
|
Reference in New Issue
Block a user