Add Graph to Spaces (#2869)

This commit is contained in:
Jet
2022-06-09 15:42:58 +01:00
committed by GitHub
parent f9e2b92e00
commit a6274a55f0
5 changed files with 376 additions and 20 deletions

View File

@@ -11,6 +11,7 @@ are vectors in the two-dimensional unit cube, the environment code may contain t
from gym.spaces.box import Box
from gym.spaces.dict import Dict
from gym.spaces.discrete import Discrete
from gym.spaces.graph import Graph, GraphInstance
from gym.spaces.multi_binary import MultiBinary
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.space import Space
@@ -21,6 +22,8 @@ __all__ = [
"Space",
"Box",
"Discrete",
"Graph",
"GraphInstance",
"MultiDiscrete",
"MultiBinary",
"Tuple",

211
gym/spaces/graph.py Normal file
View File

@@ -0,0 +1,211 @@
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from collections import namedtuple
from typing import NamedTuple, Optional, Sequence, Union
import numpy as np
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.space import Space
from gym.utils import seeding
class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"])):
r"""Returns a NamedTuple representing a graph object.
Args:
nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes.
(...) must adhere to the shape of the node space.
edges (np.ndarray): an (m x ...) sized array representing the features for m nodes.
(...) must adhere to the shape of the edge space.
edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.
Returns:
A NamedTuple representing a graph with `.nodes`, `.edges`, and `.edge_links`.
"""
class Graph(Space):
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
Example usage::
self.observation_space = spaces.Graph(node_space=space.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
"""
def __init__(
self,
node_space: Union[Box, Discrete],
edge_space: Union[None, Box, Discrete],
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
):
r"""Constructor of :class:`Graph`.
The argument ``node_space`` specifies the base space that each node feature will use.
This argument must be either a Box or Discrete instance.
The argument ``edge_space`` specifies the base space that each edge feature will use.
This argument must be either a None, Box or Discrete instance.
Args:
node_space (Union[Box, Discrete]): space of the node features.
edge_space (Union[None, Box, Discrete]): space of the node features.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
"""
assert isinstance(
node_space, (Box, Discrete)
), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
if edge_space is not None:
assert isinstance(
edge_space, (Box, Discrete)
), f"Values of the edge_space should be instances of None Box or Discrete, got {type(node_space)}"
self.node_space = node_space
self.edge_space = edge_space
super().__init__(None, None, seed)
def _generate_sample_space(
self, base_space: Union[None, Box, Discrete], num: int
) -> Optional[Union[Box, Discrete]]:
# the possibility of this space , got {type(base_space)}aving nothing
if num == 0:
return None
if isinstance(base_space, Box):
return Box(
low=np.array(max(1, num) * [base_space.low]),
high=np.array(max(1, num) * [base_space.high]),
shape=(num, *base_space.shape),
dtype=base_space.dtype,
seed=self._np_random,
)
elif isinstance(base_space, Discrete):
return MultiDiscrete(nvec=[base_space.n] * num, seed=self._np_random)
elif base_space is None:
return None
else:
raise AssertionError(
f"Only Box and Discrete can be accepted as a base_space, got {type(base_space)}, you should not have gotten this error."
)
def _sample_sample_space(self, sample_space) -> Optional[np.ndarray]:
if sample_space is not None:
return sample_space.sample()
else:
return None
def sample(self) -> NamedTuple:
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
Returns:
A NamedTuple representing a graph with attributes .nodes, .edges, and .edge_links.
"""
num_nodes = self.np_random.integers(low=1, high=10)
# we only have edges when we have at least 2 nodes
num_edges = 0
if num_nodes > 1:
# maximal number of edges is (n*n) allowing self connections and two way is allowed
num_edges = self.np_random.integers(num_nodes * num_nodes)
node_sample_space = self._generate_sample_space(self.node_space, num_nodes)
edge_sample_space = self._generate_sample_space(self.edge_space, num_edges)
sampled_nodes = self._sample_sample_space(node_sample_space)
sampled_edges = self._sample_sample_space(edge_sample_space)
sampled_edge_links = None
if sampled_edges is not None and num_edges > 0:
sampled_edge_links = self.np_random.integers(
low=0, high=num_nodes, size=(num_edges, 2)
)
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
def contains(self, x: GraphInstance) -> bool:
"""Return boolean specifying if x is a valid member of this space.
Returns False when:
- any node in nodes is not contained in Graph.node_space
- edge_links is not of dtype int
- len(edge_links) != len(edges)
- has edges but Graph.edge_space is None
- edge_links has index less than 0
- edge_links has index more than number of nodes
- any edge in edges is not contained in Graph.edge_space
"""
if not isinstance(x, GraphInstance):
return False
if x.edges is not None:
if not np.issubdtype(x.edge_links.dtype, np.integer):
return False
if x.edge_links.shape[-1] != 2:
return False
if self.edge_space is None:
return False
if x.edge_links.min() < 0:
return False
if x.edge_links.max() >= len(x.nodes):
return False
if len(x.edges) != len(x.edge_links):
return False
if any(edge not in self.edge_space for edge in x.edges):
return False
if any(node not in self.node_space for node in x.nodes):
return False
return True
def __repr__(self) -> str:
"""A string representation of this space.
The representation will include node_space and edge_space
Returns:
A representation of the space
"""
return f"Graph({self.node_space}, {self.edge_space})"
def __eq__(self, other) -> bool:
"""Check whether `other` is equivalent to this instance."""
return (
isinstance(other, Graph)
and (self.node_space == other.node_space)
and (self.edge_space == other.edge_space)
)
def to_jsonable(self, sample_n: NamedTuple) -> list:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as list of dicts
ret_n = []
for sample in sample_n:
ret = {}
ret["nodes"] = sample.nodes.tolist()
if sample.edges is not None:
ret["edges"] = sample.edges.tolist()
ret["edge_links"] = sample.edge_links.tolist()
ret_n.append(ret)
return ret_n
def from_jsonable(self, sample_n: Sequence[dict]) -> list:
"""Convert a JSONable data type to a batch of samples from this space."""
ret = []
for sample in sample_n:
if "edges" in sample:
ret_n = GraphInstance(
np.asarray(sample["nodes"]),
np.asarray(sample["edges"]),
np.asarray(sample["edge_links"]),
)
else:
ret_n = GraphInstance(
np.asarray(sample["nodes"]),
None,
None,
)
ret.append(ret_n)
return ret

View File

@@ -10,7 +10,17 @@ from typing import TypeVar, Union, cast
import numpy as np
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
from gym.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
@singledispatch
@@ -77,7 +87,15 @@ def flatten(space: Space[T], x: T) -> np.ndarray:
x: The value to flatten
Returns:
The flattened ``x``, always returns a 1D array.
- The flattened ``x``, always returns a 1D array for non-graph spaces.
- For graph spaces, returns `GraphInstance` where:
- `nodes` are n x k arrays
- `edges` are either:
- m x k arrays
- None
- `edge_links` are either:
- m x 2 arrays
- None
Raises:
NotImplementedError: If the space is not defined in ``gym.spaces``.
@@ -118,6 +136,26 @@ def _flatten_dict(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
@flatten.register(Graph)
def _flatten_graph(space, x) -> np.ndarray:
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
def _graph_unflatten(space, x):
ret = None
if space is not None and x is not None:
if isinstance(space, Box):
ret = x.reshape(x.shape[0], -1)
elif isinstance(space, Discrete):
ret = np.zeros((x.shape[0], space.n - space.start), dtype=space.dtype)
ret[np.arange(x.shape[0]), x - space.start] = 1
return ret
nodes = _graph_unflatten(space.node_space, x.nodes)
edges = _graph_unflatten(space.edge_space, x.edges)
return GraphInstance(nodes, edges, x.edge_links)
@singledispatch
def unflatten(space: Space[T], x: np.ndarray) -> T:
"""Unflatten a data point from a space.
@@ -181,14 +219,40 @@ def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
)
@unflatten.register(Graph)
def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space.
The size of the outcome is actually not fixed, but determined based on the number of
nodes and edges in the graph.
"""
def _graph_unflatten(space, x):
ret = None
if space is not None and x is not None:
if isinstance(space, Box):
ret = x.reshape(-1, *space.shape)
elif isinstance(space, Discrete):
ret = np.asarray(np.nonzero(x))[-1, :]
return ret
nodes = _graph_unflatten(space.node_space, x.nodes)
edges = _graph_unflatten(space.edge_space, x.edges)
return GraphInstance(nodes, edges, x.edge_links)
@singledispatch
def flatten_space(space: Space) -> Box:
"""Flatten a space into a single ``Box``.
This is equivalent to :func:`flatten`, but operates on the space itself. The
result always is a `Box` with flat boundaries. The box has exactly
:func:`flatdim` dimensions. Flattening a sample of the original space
has the same effect as taking a sample of the flattenend space.
result for non-graph spaces is always a `Box` with flat boundaries. While
the result for graph spaces is always a `Graph` with `node_space` being a `Box`
with flat boundaries and `edge_space` being a `Box` with flat boundaries or
`None`. The box has exactly :func:`flatdim` dimensions. Flattening a sample
of the original space has the same effect as taking a sample of the flattenend
space.
Example::
@@ -216,6 +280,15 @@ def flatten_space(space: Space) -> Box:
>>> flatten(space, space.sample()) in flatten_space(space)
True
Example that flattens a graph::
>>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
>>> flatten_space(space)
Graph(Box(-100.0, 100.0, (12,), float32), Box(0, 1, (5,), int64))
>>> flatten(space, space.sample()) in flatten_space(space)
True
Args:
space: The space to flatten
@@ -258,3 +331,13 @@ def _flatten_space_dict(space: Dict) -> Box:
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
)
@flatten_space.register(Graph)
def _flatten_space_graph(space: Graph) -> Graph:
return Graph(
node_space=flatten_space(space.node_space),
edge_space=flatten_space(space.edge_space)
if space.edge_space is not None
else None,
)

View File

@@ -6,7 +6,7 @@ import tempfile
import numpy as np
import pytest
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
from gym.spaces import Box, Dict, Discrete, Graph, MultiBinary, MultiDiscrete, Tuple
@pytest.mark.parametrize(
@@ -40,6 +40,9 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
),
}
),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(5), edge_space=None),
],
)
def test_roundtripping(space):
@@ -94,6 +97,9 @@ def test_roundtripping(space):
),
}
),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(5), edge_space=None),
],
)
def test_equality(space):
@@ -130,6 +136,12 @@ def test_equality(space):
),
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
(
Graph(
node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)
),
Graph(node_space=Discrete(5), edge_space=None),
),
],
)
def test_inequality(spaces):
@@ -192,6 +204,12 @@ def test_sample(space):
Box(low=np.array([-np.inf, 0.0]), high=np.array([0.0, np.inf])),
Box(low=np.array([-np.inf, 1.0]), high=np.array([0.0, np.inf])),
),
(
Graph(
node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)
),
Graph(node_space=Discrete(5), edge_space=None),
),
],
)
def test_class_inequality(spaces):
@@ -306,6 +324,10 @@ def test_box_dtype_check():
),
}
),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
Graph(node_space=Discrete(5), edge_space=None),
],
)
def test_seed_returns_list(space):
@@ -365,6 +387,10 @@ def sample_equal(sample1, sample2):
),
}
),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
Graph(node_space=Discrete(5), edge_space=None),
],
)
def test_seed_reproducibility(space):
@@ -405,10 +431,23 @@ def test_seed_reproducibility(space):
),
}
),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
Graph(node_space=Discrete(5), edge_space=None),
],
)
def test_seed_subspace_incorrelated(space):
subspaces = space.spaces if isinstance(space, Tuple) else space.spaces.values()
subspaces = []
if isinstance(space, Tuple):
subspaces = space.spaces
elif isinstance(space, Dict):
subspaces = space.spaces.values()
elif isinstance(space, Graph):
if space.edge_space is not None:
subspaces = [space.node_space, space.edge_space]
else:
subspaces = [space.node_space]
space.seed(0)
states = [
@@ -657,6 +696,10 @@ def test_box_legacy_state_pickling():
),
}
),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
Graph(node_space=Discrete(5), edge_space=None),
],
)
def test_pickle(space):

View File

@@ -3,9 +3,18 @@ from collections import OrderedDict
import numpy as np
import pytest
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, utils
from gym.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
Tuple,
utils,
)
spaces = [
homogeneous_spaces = [
Discrete(3),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Box(low=0.0, high=np.inf, shape=(2, 2), dtype=np.float16),
@@ -33,14 +42,20 @@ spaces = [
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
graph_spaces = [
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(5), edge_space=None),
]
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
def test_flatdim(space, flatdim):
dim = utils.flatdim(space)
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
@pytest.mark.parametrize("space", spaces)
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_space_boxes(space):
flat_space = utils.flatten_space(space)
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
@@ -49,18 +64,18 @@ def test_flatten_space_boxes(space):
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", spaces)
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
def test_flat_space_contains_flat_points(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
flat_space = utils.flatten_space(space)
for i, flat_sample in enumerate(flattened_samples):
assert (
flat_sample in flat_space
assert flat_space.contains(
flat_sample
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
@pytest.mark.parametrize("space", spaces)
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_dim(space):
sample = utils.flatten(space, space.sample())
(single_dim,) = sample.shape
@@ -68,7 +83,7 @@ def test_flatten_dim(space):
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", spaces)
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
def test_flatten_roundtripping(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
@@ -131,7 +146,7 @@ expected_flattened_dtypes = [
@pytest.mark.parametrize(
["original_space", "expected_flattened_dtype"],
zip(spaces, expected_flattened_dtypes),
zip(homogeneous_spaces, expected_flattened_dtypes),
)
def test_dtypes(original_space, expected_flattened_dtype):
flattened_space = utils.flatten_space(original_space)
@@ -212,7 +227,7 @@ expected_flattened_samples = [
@pytest.mark.parametrize(
["space", "sample", "expected_flattened_sample"],
zip(spaces, samples, expected_flattened_samples),
zip(homogeneous_spaces, samples, expected_flattened_samples),
)
def test_flatten(space, sample, expected_flattened_sample):
assert sample in space
@@ -225,7 +240,7 @@ def test_flatten(space, sample, expected_flattened_sample):
@pytest.mark.parametrize(
["space", "flattened_sample", "expected_sample"],
zip(spaces, expected_flattened_samples, samples),
zip(homogeneous_spaces, expected_flattened_samples, samples),
)
def test_unflatten(space, flattened_sample, expected_sample):
sample = utils.unflatten(space, flattened_sample)
@@ -256,7 +271,8 @@ expected_flattened_spaces = [
@pytest.mark.parametrize(
["space", "expected_flattened_space"], zip(spaces, expected_flattened_spaces)
["space", "expected_flattened_space"],
zip(homogeneous_spaces, expected_flattened_spaces),
)
def test_flatten_space(space, expected_flattened_space):
flattened_space = utils.flatten_space(space)