Files
Gymnasium/gym/spaces/graph.py

212 lines
8.0 KiB
Python
Raw Normal View History

2022-06-09 15:42:58 +01:00
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from collections import namedtuple
from typing import NamedTuple, Optional, Sequence, Union
import numpy as np
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.space import Space
from gym.utils import seeding
class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"])):
    r"""A NamedTuple representing a graph object.

    ``edges`` and ``edge_links`` may both be ``None`` for a graph without edges;
    this is what ``Graph.sample`` and ``Graph.from_jsonable`` produce in that case.

    Args:
        nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes.
            (...) must adhere to the shape of the node space.
        edges (np.ndarray): an (m x ...) sized array representing the features for m edges.
            (...) must adhere to the shape of the edge space.
        edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.

    Returns:
        A NamedTuple representing a graph with `.nodes`, `.edges`, and `.edge_links`.
    """
class Graph(Space):
    r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.

    Example usage::

        self.observation_space = spaces.Graph(node_space=spaces.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
    """

    def __init__(
        self,
        node_space: Union[Box, Discrete],
        edge_space: Union[None, Box, Discrete],
        seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
    ):
        r"""Constructor of :class:`Graph`.

        The argument ``node_space`` specifies the base space that each node feature will use.
        This argument must be either a Box or Discrete instance.

        The argument ``edge_space`` specifies the base space that each edge feature will use.
        This argument must be either None, or a Box or Discrete instance.

        Args:
            node_space (Union[Box, Discrete]): space of the node features.
            edge_space (Union[None, Box, Discrete]): space of the edge features.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.

        Raises:
            AssertionError: if ``node_space`` is not a Box or Discrete, or if
                ``edge_space`` is neither None nor a Box or Discrete.
        """
        assert isinstance(
            node_space, (Box, Discrete)
        ), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
        if edge_space is not None:
            # Report the type of the offending argument (edge_space, not node_space).
            assert isinstance(
                edge_space, (Box, Discrete)
            ), f"Values of the edge_space should be instances of None Box or Discrete, got {type(edge_space)}"

        self.node_space = node_space
        self.edge_space = edge_space

        super().__init__(None, None, seed)

    def _generate_sample_space(
        self, base_space: Union[None, Box, Discrete], num: int
    ) -> Optional[Union[Box, MultiDiscrete]]:
        """Build a space that samples ``num`` stacked elements of ``base_space``.

        Args:
            base_space: the per-node or per-edge space, or None.
            num: how many elements (nodes or edges) the returned space should cover.

        Returns:
            ``None`` when ``num`` is 0 or ``base_space`` is None; a :class:`Box` of
            shape ``(num, *base_space.shape)`` for a Box base space; a
            :class:`MultiDiscrete` with ``num`` entries for a Discrete base space.

        Raises:
            AssertionError: if ``base_space`` is of any other type.
        """
        # the possibility of this space having nothing
        if num == 0:
            return None

        if isinstance(base_space, Box):
            # num is non-zero here, so the per-element bounds are simply repeated
            # num times along a new leading axis.
            return Box(
                low=np.array([base_space.low] * num),
                high=np.array([base_space.high] * num),
                shape=(num, *base_space.shape),
                dtype=base_space.dtype,
                seed=self._np_random,
            )
        if isinstance(base_space, Discrete):
            return MultiDiscrete(nvec=[base_space.n] * num, seed=self._np_random)
        if base_space is None:
            return None

        raise AssertionError(
            f"Only Box and Discrete can be accepted as a base_space, got {type(base_space)}, you should not have gotten this error."
        )

    def _sample_sample_space(self, sample_space) -> Optional[np.ndarray]:
        """Sample from ``sample_space``, or return None when there is no space to sample."""
        if sample_space is None:
            return None
        return sample_space.sample()

    def sample(self) -> GraphInstance:
        """Generates a single sample graph from the Graph space.

        The number of nodes is drawn with ``np_random.integers(low=1, high=10)``.
        NOTE(review): if ``np_random`` follows ``numpy.random.Generator.integers``
        semantics, the upper bound is exclusive, i.e. between 1 and 9 nodes — confirm.
        The number of edges is drawn from ``[0, num_nodes ** 2)`` when there is more
        than one node, and is 0 otherwise.

        Returns:
            A NamedTuple representing a graph with attributes .nodes, .edges, and .edge_links.
            ``.edges`` and ``.edge_links`` are None when no edges were sampled or
            when ``edge_space`` is None.
        """
        num_nodes = self.np_random.integers(low=1, high=10)

        # we only have edges when we have at least 2 nodes
        num_edges = 0
        if num_nodes > 1:
            # maximal number of edges is (n*n) allowing self connections and two way is allowed
            num_edges = self.np_random.integers(num_nodes * num_nodes)

        node_sample_space = self._generate_sample_space(self.node_space, num_nodes)
        edge_sample_space = self._generate_sample_space(self.edge_space, num_edges)

        sampled_nodes = self._sample_sample_space(node_sample_space)
        sampled_edges = self._sample_sample_space(edge_sample_space)

        # Edges are only sampled when num_edges > 0 and an edge space exists, so a
        # non-None edge array is exactly the case that needs links.
        sampled_edge_links = None
        if sampled_edges is not None:
            sampled_edge_links = self.np_random.integers(
                low=0, high=num_nodes, size=(num_edges, 2)
            )

        return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)

    def contains(self, x: GraphInstance) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Returns False when:
        - x is not a GraphInstance
        - nodes is None
        - any node in nodes is not contained in Graph.node_space
        - has edges but Graph.edge_space is None
        - has edges but edge_links is not a numpy array of dtype int
        - edge_links is not of shape (len(edges), 2)
        - edge_links has index less than 0
        - edge_links has index greater than or equal to the number of nodes
        - any edge in edges is not contained in Graph.edge_space
        """
        if not isinstance(x, GraphInstance):
            return False
        if x.nodes is None:
            return False

        if x.edges is not None:
            if self.edge_space is None:
                return False

            links = x.edge_links
            # Edges without links (or links that are not an array) are malformed;
            # report that as "not contained" rather than raising.
            if not isinstance(links, np.ndarray):
                return False
            if not np.issubdtype(links.dtype, np.integer):
                return False
            if links.shape != (len(x.edges), 2):
                return False
            # min()/max() raise on zero-size arrays, so only range-check when
            # there is at least one link.
            if links.size > 0 and (links.min() < 0 or links.max() >= len(x.nodes)):
                return False
            if any(edge not in self.edge_space for edge in x.edges):
                return False

        return all(node in self.node_space for node in x.nodes)

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include node_space and edge_space

        Returns:
            A representation of the space
        """
        return f"Graph({self.node_space}, {self.edge_space})"

    def __eq__(self, other) -> bool:
        """Check whether `other` is equivalent to this instance."""
        return (
            isinstance(other, Graph)
            and (self.node_space == other.node_space)
            and (self.edge_space == other.edge_space)
        )

    def to_jsonable(self, sample_n: Sequence[GraphInstance]) -> list:
        """Convert a batch of samples from this space to a JSONable data type.

        Args:
            sample_n: an iterable of graph samples.

        Returns:
            A list with one dict per sample. Each dict has a ``"nodes"`` entry, plus
            ``"edges"`` and ``"edge_links"`` entries when the sample has edges.
        """
        # serialize as list of dicts
        ret_n = []
        for sample in sample_n:
            ret = {"nodes": sample.nodes.tolist()}
            if sample.edges is not None:
                ret["edges"] = sample.edges.tolist()
                ret["edge_links"] = sample.edge_links.tolist()
            ret_n.append(ret)
        return ret_n

    def from_jsonable(self, sample_n: Sequence[dict]) -> list:
        """Convert a JSONable data type to a batch of samples from this space.

        Args:
            sample_n: a sequence of dicts as produced by :meth:`to_jsonable`.

        Returns:
            A list of :class:`GraphInstance`; samples without an ``"edges"`` key get
            ``None`` for both ``edges`` and ``edge_links``.
        """
        ret = []
        for sample in sample_n:
            nodes = np.asarray(sample["nodes"])
            if "edges" in sample:
                graph = GraphInstance(
                    nodes,
                    np.asarray(sample["edges"]),
                    np.asarray(sample["edge_links"]),
                )
            else:
                graph = GraphInstance(nodes, None, None)
            ret.append(graph)
        return ret