mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 13:32:03 +00:00
Add Graph to Spaces (#2869)
This commit is contained in:
@@ -11,6 +11,7 @@ are vectors in the two-dimensional unit cube, the environment code may contain t
|
||||
from gym.spaces.box import Box
|
||||
from gym.spaces.dict import Dict
|
||||
from gym.spaces.discrete import Discrete
|
||||
from gym.spaces.graph import Graph, GraphInstance
|
||||
from gym.spaces.multi_binary import MultiBinary
|
||||
from gym.spaces.multi_discrete import MultiDiscrete
|
||||
from gym.spaces.space import Space
|
||||
@@ -21,6 +22,8 @@ __all__ = [
|
||||
"Space",
|
||||
"Box",
|
||||
"Discrete",
|
||||
"Graph",
|
||||
"GraphInstance",
|
||||
"MultiDiscrete",
|
||||
"MultiBinary",
|
||||
"Tuple",
|
||||
|
211
gym/spaces/graph.py
Normal file
211
gym/spaces/graph.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces.box import Box
|
||||
from gym.spaces.discrete import Discrete
|
||||
from gym.spaces.multi_discrete import MultiDiscrete
|
||||
from gym.spaces.space import Space
|
||||
from gym.utils import seeding
|
||||
|
||||
|
||||
class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"])):
|
||||
r"""Returns a NamedTuple representing a graph object.
|
||||
|
||||
Args:
|
||||
nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes.
|
||||
(...) must adhere to the shape of the node space.
|
||||
|
||||
edges (np.ndarray): an (m x ...) sized array representing the features for m nodes.
|
||||
(...) must adhere to the shape of the edge space.
|
||||
|
||||
edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.
|
||||
|
||||
Returns:
|
||||
A NamedTuple representing a graph with `.nodes`, `.edges`, and `.edge_links`.
|
||||
"""
|
||||
|
||||
|
||||
class Graph(Space):
|
||||
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
||||
|
||||
Example usage::
|
||||
|
||||
self.observation_space = spaces.Graph(node_space=space.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_space: Union[Box, Discrete],
|
||||
edge_space: Union[None, Box, Discrete],
|
||||
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
|
||||
):
|
||||
r"""Constructor of :class:`Graph`.
|
||||
|
||||
The argument ``node_space`` specifies the base space that each node feature will use.
|
||||
This argument must be either a Box or Discrete instance.
|
||||
|
||||
The argument ``edge_space`` specifies the base space that each edge feature will use.
|
||||
This argument must be either a None, Box or Discrete instance.
|
||||
|
||||
Args:
|
||||
node_space (Union[Box, Discrete]): space of the node features.
|
||||
edge_space (Union[None, Box, Discrete]): space of the node features.
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||
"""
|
||||
assert isinstance(
|
||||
node_space, (Box, Discrete)
|
||||
), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
|
||||
if edge_space is not None:
|
||||
assert isinstance(
|
||||
edge_space, (Box, Discrete)
|
||||
), f"Values of the edge_space should be instances of None Box or Discrete, got {type(node_space)}"
|
||||
|
||||
self.node_space = node_space
|
||||
self.edge_space = edge_space
|
||||
|
||||
super().__init__(None, None, seed)
|
||||
|
||||
def _generate_sample_space(
|
||||
self, base_space: Union[None, Box, Discrete], num: int
|
||||
) -> Optional[Union[Box, Discrete]]:
|
||||
# the possibility of this space , got {type(base_space)}aving nothing
|
||||
if num == 0:
|
||||
return None
|
||||
|
||||
if isinstance(base_space, Box):
|
||||
return Box(
|
||||
low=np.array(max(1, num) * [base_space.low]),
|
||||
high=np.array(max(1, num) * [base_space.high]),
|
||||
shape=(num, *base_space.shape),
|
||||
dtype=base_space.dtype,
|
||||
seed=self._np_random,
|
||||
)
|
||||
elif isinstance(base_space, Discrete):
|
||||
return MultiDiscrete(nvec=[base_space.n] * num, seed=self._np_random)
|
||||
elif base_space is None:
|
||||
return None
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"Only Box and Discrete can be accepted as a base_space, got {type(base_space)}, you should not have gotten this error."
|
||||
)
|
||||
|
||||
def _sample_sample_space(self, sample_space) -> Optional[np.ndarray]:
|
||||
if sample_space is not None:
|
||||
return sample_space.sample()
|
||||
else:
|
||||
return None
|
||||
|
||||
def sample(self) -> NamedTuple:
|
||||
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
|
||||
|
||||
Returns:
|
||||
A NamedTuple representing a graph with attributes .nodes, .edges, and .edge_links.
|
||||
"""
|
||||
num_nodes = self.np_random.integers(low=1, high=10)
|
||||
|
||||
# we only have edges when we have at least 2 nodes
|
||||
num_edges = 0
|
||||
if num_nodes > 1:
|
||||
# maximal number of edges is (n*n) allowing self connections and two way is allowed
|
||||
num_edges = self.np_random.integers(num_nodes * num_nodes)
|
||||
|
||||
node_sample_space = self._generate_sample_space(self.node_space, num_nodes)
|
||||
edge_sample_space = self._generate_sample_space(self.edge_space, num_edges)
|
||||
|
||||
sampled_nodes = self._sample_sample_space(node_sample_space)
|
||||
sampled_edges = self._sample_sample_space(edge_sample_space)
|
||||
|
||||
sampled_edge_links = None
|
||||
if sampled_edges is not None and num_edges > 0:
|
||||
sampled_edge_links = self.np_random.integers(
|
||||
low=0, high=num_nodes, size=(num_edges, 2)
|
||||
)
|
||||
|
||||
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
|
||||
|
||||
def contains(self, x: GraphInstance) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space.
|
||||
|
||||
Returns False when:
|
||||
- any node in nodes is not contained in Graph.node_space
|
||||
- edge_links is not of dtype int
|
||||
- len(edge_links) != len(edges)
|
||||
- has edges but Graph.edge_space is None
|
||||
- edge_links has index less than 0
|
||||
- edge_links has index more than number of nodes
|
||||
- any edge in edges is not contained in Graph.edge_space
|
||||
"""
|
||||
if not isinstance(x, GraphInstance):
|
||||
return False
|
||||
if x.edges is not None:
|
||||
if not np.issubdtype(x.edge_links.dtype, np.integer):
|
||||
return False
|
||||
if x.edge_links.shape[-1] != 2:
|
||||
return False
|
||||
if self.edge_space is None:
|
||||
return False
|
||||
if x.edge_links.min() < 0:
|
||||
return False
|
||||
if x.edge_links.max() >= len(x.nodes):
|
||||
return False
|
||||
if len(x.edges) != len(x.edge_links):
|
||||
return False
|
||||
if any(edge not in self.edge_space for edge in x.edges):
|
||||
return False
|
||||
if any(node not in self.node_space for node in x.nodes):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""A string representation of this space.
|
||||
|
||||
The representation will include node_space and edge_space
|
||||
|
||||
Returns:
|
||||
A representation of the space
|
||||
"""
|
||||
return f"Graph({self.node_space}, {self.edge_space})"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""Check whether `other` is equivalent to this instance."""
|
||||
return (
|
||||
isinstance(other, Graph)
|
||||
and (self.node_space == other.node_space)
|
||||
and (self.edge_space == other.edge_space)
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n: NamedTuple) -> list:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# serialize as list of dicts
|
||||
ret_n = []
|
||||
for sample in sample_n:
|
||||
ret = {}
|
||||
ret["nodes"] = sample.nodes.tolist()
|
||||
if sample.edges is not None:
|
||||
ret["edges"] = sample.edges.tolist()
|
||||
ret["edge_links"] = sample.edge_links.tolist()
|
||||
ret_n.append(ret)
|
||||
return ret_n
|
||||
|
||||
def from_jsonable(self, sample_n: Sequence[dict]) -> list:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
ret = []
|
||||
for sample in sample_n:
|
||||
if "edges" in sample:
|
||||
ret_n = GraphInstance(
|
||||
np.asarray(sample["nodes"]),
|
||||
np.asarray(sample["edges"]),
|
||||
np.asarray(sample["edge_links"]),
|
||||
)
|
||||
else:
|
||||
ret_n = GraphInstance(
|
||||
np.asarray(sample["nodes"]),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
ret.append(ret_n)
|
||||
return ret
|
@@ -10,7 +10,17 @@ from typing import TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
||||
from gym.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Space,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -77,7 +87,15 @@ def flatten(space: Space[T], x: T) -> np.ndarray:
|
||||
x: The value to flatten
|
||||
|
||||
Returns:
|
||||
The flattened ``x``, always returns a 1D array.
|
||||
- The flattened ``x``, always returns a 1D array for non-graph spaces.
|
||||
- For graph spaces, returns `GraphInstance` where:
|
||||
- `nodes` are n x k arrays
|
||||
- `edges` are either:
|
||||
- m x k arrays
|
||||
- None
|
||||
- `edge_links` are either:
|
||||
- m x 2 arrays
|
||||
- None
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the space is not defined in ``gym.spaces``.
|
||||
@@ -118,6 +136,26 @@ def _flatten_dict(space, x) -> np.ndarray:
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
|
||||
|
||||
@flatten.register(Graph)
|
||||
def _flatten_graph(space, x) -> np.ndarray:
|
||||
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
||||
|
||||
def _graph_unflatten(space, x):
|
||||
ret = None
|
||||
if space is not None and x is not None:
|
||||
if isinstance(space, Box):
|
||||
ret = x.reshape(x.shape[0], -1)
|
||||
elif isinstance(space, Discrete):
|
||||
ret = np.zeros((x.shape[0], space.n - space.start), dtype=space.dtype)
|
||||
ret[np.arange(x.shape[0]), x - space.start] = 1
|
||||
return ret
|
||||
|
||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def unflatten(space: Space[T], x: np.ndarray) -> T:
|
||||
"""Unflatten a data point from a space.
|
||||
@@ -181,14 +219,40 @@ def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
|
||||
)
|
||||
|
||||
|
||||
@unflatten.register(Graph)
|
||||
def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space.
|
||||
|
||||
The size of the outcome is actually not fixed, but determined based on the number of
|
||||
nodes and edges in the graph.
|
||||
"""
|
||||
|
||||
def _graph_unflatten(space, x):
|
||||
ret = None
|
||||
if space is not None and x is not None:
|
||||
if isinstance(space, Box):
|
||||
ret = x.reshape(-1, *space.shape)
|
||||
elif isinstance(space, Discrete):
|
||||
ret = np.asarray(np.nonzero(x))[-1, :]
|
||||
return ret
|
||||
|
||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space: Space) -> Box:
|
||||
"""Flatten a space into a single ``Box``.
|
||||
|
||||
This is equivalent to :func:`flatten`, but operates on the space itself. The
|
||||
result always is a `Box` with flat boundaries. The box has exactly
|
||||
:func:`flatdim` dimensions. Flattening a sample of the original space
|
||||
has the same effect as taking a sample of the flattenend space.
|
||||
result for non-graph spaces is always a `Box` with flat boundaries. While
|
||||
the result for graph spaces is always a `Graph` with `node_space` being a `Box`
|
||||
with flat boundaries and `edge_space` being a `Box` with flat boundaries or
|
||||
`None`. The box has exactly :func:`flatdim` dimensions. Flattening a sample
|
||||
of the original space has the same effect as taking a sample of the flattenend
|
||||
space.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -216,6 +280,15 @@ def flatten_space(space: Space) -> Box:
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
True
|
||||
|
||||
|
||||
Example that flattens a graph::
|
||||
|
||||
>>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
|
||||
>>> flatten_space(space)
|
||||
Graph(Box(-100.0, 100.0, (12,), float32), Box(0, 1, (5,), int64))
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
True
|
||||
|
||||
Args:
|
||||
space: The space to flatten
|
||||
|
||||
@@ -258,3 +331,13 @@ def _flatten_space_dict(space: Dict) -> Box:
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Graph)
|
||||
def _flatten_space_graph(space: Graph) -> Graph:
|
||||
return Graph(
|
||||
node_space=flatten_space(space.node_space),
|
||||
edge_space=flatten_space(space.edge_space)
|
||||
if space.edge_space is not None
|
||||
else None,
|
||||
)
|
||||
|
@@ -6,7 +6,7 @@ import tempfile
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
|
||||
from gym.spaces import Box, Dict, Discrete, Graph, MultiBinary, MultiDiscrete, Tuple
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -40,6 +40,9 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
|
||||
),
|
||||
}
|
||||
),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
],
|
||||
)
|
||||
def test_roundtripping(space):
|
||||
@@ -94,6 +97,9 @@ def test_roundtripping(space):
|
||||
),
|
||||
}
|
||||
),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
],
|
||||
)
|
||||
def test_equality(space):
|
||||
@@ -130,6 +136,12 @@ def test_equality(space):
|
||||
),
|
||||
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
|
||||
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
|
||||
(
|
||||
Graph(
|
||||
node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)
|
||||
),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_inequality(spaces):
|
||||
@@ -192,6 +204,12 @@ def test_sample(space):
|
||||
Box(low=np.array([-np.inf, 0.0]), high=np.array([0.0, np.inf])),
|
||||
Box(low=np.array([-np.inf, 1.0]), high=np.array([0.0, np.inf])),
|
||||
),
|
||||
(
|
||||
Graph(
|
||||
node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)
|
||||
),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_class_inequality(spaces):
|
||||
@@ -306,6 +324,10 @@ def test_box_dtype_check():
|
||||
),
|
||||
}
|
||||
),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
],
|
||||
)
|
||||
def test_seed_returns_list(space):
|
||||
@@ -365,6 +387,10 @@ def sample_equal(sample1, sample2):
|
||||
),
|
||||
}
|
||||
),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
],
|
||||
)
|
||||
def test_seed_reproducibility(space):
|
||||
@@ -405,10 +431,23 @@ def test_seed_reproducibility(space):
|
||||
),
|
||||
}
|
||||
),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
],
|
||||
)
|
||||
def test_seed_subspace_incorrelated(space):
|
||||
subspaces = space.spaces if isinstance(space, Tuple) else space.spaces.values()
|
||||
subspaces = []
|
||||
if isinstance(space, Tuple):
|
||||
subspaces = space.spaces
|
||||
elif isinstance(space, Dict):
|
||||
subspaces = space.spaces.values()
|
||||
elif isinstance(space, Graph):
|
||||
if space.edge_space is not None:
|
||||
subspaces = [space.node_space, space.edge_space]
|
||||
else:
|
||||
subspaces = [space.node_space]
|
||||
|
||||
space.seed(0)
|
||||
states = [
|
||||
@@ -657,6 +696,10 @@ def test_box_legacy_state_pickling():
|
||||
),
|
||||
}
|
||||
),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
],
|
||||
)
|
||||
def test_pickle(space):
|
||||
|
@@ -3,9 +3,18 @@ from collections import OrderedDict
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, utils
|
||||
from gym.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Tuple,
|
||||
utils,
|
||||
)
|
||||
|
||||
spaces = [
|
||||
homogeneous_spaces = [
|
||||
Discrete(3),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2), dtype=np.float16),
|
||||
@@ -33,14 +42,20 @@ spaces = [
|
||||
|
||||
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
|
||||
|
||||
graph_spaces = [
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
|
||||
def test_flatdim(space, flatdim):
|
||||
dim = utils.flatdim(space)
|
||||
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", spaces)
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces)
|
||||
def test_flatten_space_boxes(space):
|
||||
flat_space = utils.flatten_space(space)
|
||||
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
|
||||
@@ -49,18 +64,18 @@ def test_flatten_space_boxes(space):
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", spaces)
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
|
||||
def test_flat_space_contains_flat_points(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
flat_space = utils.flatten_space(space)
|
||||
for i, flat_sample in enumerate(flattened_samples):
|
||||
assert (
|
||||
flat_sample in flat_space
|
||||
assert flat_space.contains(
|
||||
flat_sample
|
||||
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", spaces)
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces)
|
||||
def test_flatten_dim(space):
|
||||
sample = utils.flatten(space, space.sample())
|
||||
(single_dim,) = sample.shape
|
||||
@@ -68,7 +83,7 @@ def test_flatten_dim(space):
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", spaces)
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
|
||||
def test_flatten_roundtripping(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
@@ -131,7 +146,7 @@ expected_flattened_dtypes = [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["original_space", "expected_flattened_dtype"],
|
||||
zip(spaces, expected_flattened_dtypes),
|
||||
zip(homogeneous_spaces, expected_flattened_dtypes),
|
||||
)
|
||||
def test_dtypes(original_space, expected_flattened_dtype):
|
||||
flattened_space = utils.flatten_space(original_space)
|
||||
@@ -212,7 +227,7 @@ expected_flattened_samples = [
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "sample", "expected_flattened_sample"],
|
||||
zip(spaces, samples, expected_flattened_samples),
|
||||
zip(homogeneous_spaces, samples, expected_flattened_samples),
|
||||
)
|
||||
def test_flatten(space, sample, expected_flattened_sample):
|
||||
assert sample in space
|
||||
@@ -225,7 +240,7 @@ def test_flatten(space, sample, expected_flattened_sample):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "flattened_sample", "expected_sample"],
|
||||
zip(spaces, expected_flattened_samples, samples),
|
||||
zip(homogeneous_spaces, expected_flattened_samples, samples),
|
||||
)
|
||||
def test_unflatten(space, flattened_sample, expected_sample):
|
||||
sample = utils.unflatten(space, flattened_sample)
|
||||
@@ -256,7 +271,8 @@ expected_flattened_spaces = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "expected_flattened_space"], zip(spaces, expected_flattened_spaces)
|
||||
["space", "expected_flattened_space"],
|
||||
zip(homogeneous_spaces, expected_flattened_spaces),
|
||||
)
|
||||
def test_flatten_space(space, expected_flattened_space):
|
||||
flattened_space = utils.flatten_space(space)
|
||||
|
Reference in New Issue
Block a user