Add dtype testing for data_equivalence and update testing (#515)

This commit is contained in:
Mark Towers
2023-05-23 17:03:25 +01:00
committed by GitHub
parent 5bf6c1e93f
commit ae5d8888aa
5 changed files with 57 additions and 28 deletions

View File

@@ -32,15 +32,21 @@ def numpy_to_jax(value: Any) -> Any:
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(
value: np.ndarray | numbers.Number,
def _number_to_jax(
value: numbers.Number,
) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
"""Converts a number (int, float, etc.) to a Jax DeviceArray."""
assert jnp is not None
return jnp.array(value)
@numpy_to_jax.register(np.ndarray)
def _numpy_array_to_jax(value: np.ndarray) -> jnp.DeviceArray:
"""Converts a NumPy Array to a Jax DeviceArray with the same dtype (excluding float64 without being enabled)."""
assert jnp is not None
return jnp.array(value, dtype=value.dtype)
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""

View File

@@ -256,7 +256,7 @@ class Box(Space[NDArray[Any]]):
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample) for sample in sample_n]
return [np.asarray(sample, dtype=self.dtype) for sample in sample_n]
def __repr__(self) -> str:
"""A string representation of this space.

View File

@@ -49,8 +49,7 @@ class Graph(Space[GraphInstance]):
[7, 0],
[3, 7],
[8, 4],
[8, 8]]))
[8, 8]], dtype=int32))
"""
def __init__(
@@ -178,7 +177,7 @@ class Graph(Space[GraphInstance]):
sampled_edge_links = None
if sampled_edges is not None and num_edges > 0:
sampled_edge_links = self.np_random.integers(
low=0, high=num_nodes, size=(num_edges, 2)
low=0, high=num_nodes, size=(num_edges, 2), dtype=np.int32
)
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
@@ -248,14 +247,15 @@ class Graph(Space[GraphInstance]):
ret: list[GraphInstance] = []
for sample in sample_n:
if "edges" in sample:
assert self.edge_space is not None
ret_n = GraphInstance(
np.asarray(sample["nodes"]),
np.asarray(sample["edges"]),
np.asarray(sample["edge_links"]),
np.asarray(sample["nodes"], dtype=self.node_space.dtype),
np.asarray(sample["edges"], dtype=self.edge_space.dtype),
np.asarray(sample["edge_links"], dtype=np.int32),
)
else:
ret_n = GraphInstance(
np.asarray(sample["nodes"]),
np.asarray(sample["nodes"], dtype=self.node_space.dtype),
None,
None,
)

View File

@@ -50,9 +50,13 @@ def data_equivalence(data_1, data_2) -> bool:
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
elif isinstance(data_1, np.ndarray):
return data_1.shape == data_2.shape and np.allclose(
data_1, data_2, atol=0.00001
)
if data_1.shape == data_2.shape and data_1.dtype == data_2.dtype:
if data_1.dtype == object:
return all(data_equivalence(a, b) for a, b in zip(data_1, data_2))
else:
return np.allclose(data_1, data_2, atol=0.00001)
else:
return False
else:
return data_1 == data_2
else:

View File

@@ -18,32 +18,51 @@ from tests.testing_env import GenericTestEnv # noqa: E402
@pytest.mark.parametrize(
"value, expected_value",
[
(1.0, np.array(1.0)),
(2, np.array(2)),
((3.0, 4), (np.array(3.0), np.array(4))),
([3.0, 4], [np.array(3.0), np.array(4)]),
(1.0, np.array(1.0, dtype=np.float32)),
(2, np.array(2, dtype=np.int32)),
((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32))),
([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32)]),
(
{
"a": 6.0,
"b": 7,
},
{"a": np.array(6.0), "b": np.array(7)},
{"a": np.array(6.0, dtype=np.float32), "b": np.array(7, dtype=np.int32)},
),
(np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)),
(np.array(1.0, dtype=np.uint8), np.array(1.0, dtype=np.uint8)),
(np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)),
(
np.array([[1.0], [2.0]], dtype=np.int32),
np.array([[1.0], [2.0]], dtype=np.int32),
),
(np.array(1.0), np.array(1.0)),
(np.array([1, 2]), np.array([1, 2])),
(np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
(
{"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
{
"a": (np.array(1), np.array(2.0), np.array([3, 4])),
"b": {"c": np.array(5)},
"a": (
1,
np.array(2.0, dtype=np.float32),
np.array([3, 4], dtype=np.int32),
),
"b": {"c": 5},
},
{
"a": (
np.array(1, dtype=np.int32),
np.array(2.0, dtype=np.float32),
np.array([3, 4], dtype=np.int32),
),
"b": {"c": np.array(5, dtype=np.int32)},
},
),
],
)
def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper.
Warning: Jax doesn't support float64 out of the box, therefore, we only test float32 in this test.
"""
roundtripped_value = jax_to_numpy(numpy_to_jax(value))
assert data_equivalence(roundtripped_value, expected_value)
def jax_reset_func(self, seed=None, options=None):