Add Graph to Spaces (#2869)

This commit is contained in:
Jet
2022-06-09 15:42:58 +01:00
committed by GitHub
parent f9e2b92e00
commit a6274a55f0
5 changed files with 376 additions and 20 deletions

View File

@@ -10,7 +10,17 @@ from typing import TypeVar, Union, cast
import numpy as np
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
from gym.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
@singledispatch
@@ -77,7 +87,15 @@ def flatten(space: Space[T], x: T) -> np.ndarray:
x: The value to flatten
Returns:
The flattened ``x``, always returns a 1D array.
- The flattened ``x``, always returns a 1D array for non-graph spaces.
- For graph spaces, returns `GraphInstance` where:
- `nodes` are n x k arrays
- `edges` are either:
- m x k arrays
- None
- `edge_links` are either:
- m x 2 arrays
- None
Raises:
NotImplementedError: If the space is not defined in ``gym.spaces``.
@@ -118,6 +136,26 @@ def _flatten_dict(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
@flatten.register(Graph)
def _flatten_graph(space, x) -> GraphInstance:
    """Flatten the node and edge features of a graph sample.

    We're not using ``flatten()`` for :class:`Box` and :class:`Discrete` here
    because a graph is not a homogeneous space: features are flattened
    row-wise (one row per node / edge) rather than into a single 1D array.
    See the ``flatten`` docstring.

    Args:
        space: The :class:`Graph` space that ``x`` belongs to.
        x: The graph sample, exposing ``nodes``, ``edges`` and ``edge_links``.

    Returns:
        A :class:`GraphInstance` whose ``nodes`` and ``edges`` are the
        row-wise flattened features (``None`` when the matching space or
        value is ``None``) and whose ``edge_links`` are passed through
        unchanged.
    """

    def _flatten_elements(element_space, elements):
        """Flatten a stack of node or edge features, one row per element."""
        if element_space is None or elements is None:
            return None
        if isinstance(element_space, Box):
            # Keep the leading (per-element) axis, collapse everything else.
            return elements.reshape(elements.shape[0], -1)
        if isinstance(element_space, Discrete):
            # One-hot encode each element. The width is ``n`` (the number of
            # values in the space); ``start`` only shifts the column index.
            # Using ``n - start`` as the width overflows for ``start > 0`` and
            # disagrees with the ``(n,)`` Box produced by ``flatten_space``.
            count = elements.shape[0]
            onehot = np.zeros((count, element_space.n), dtype=element_space.dtype)
            onehot[np.arange(count), elements - element_space.start] = 1
            return onehot
        # Any other element space is left unflattened, as before.
        return None

    nodes = _flatten_elements(space.node_space, x.nodes)
    edges = _flatten_elements(space.edge_space, x.edges)
    return GraphInstance(nodes, edges, x.edge_links)
@singledispatch
def unflatten(space: Space[T], x: np.ndarray) -> T:
"""Unflatten a data point from a space.
@@ -181,14 +219,40 @@ def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
)
@unflatten.register(Graph)
def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
    """Restore the node and edge features of a flattened graph sample.

    We're not using ``unflatten()`` for :class:`Box` and :class:`Discrete`
    because a graph is not a homogeneous space. The size of the outcome is
    actually not fixed, but determined based on the number of nodes and edges
    in the graph.

    Args:
        space: The :class:`Graph` space the sample originally came from.
        x: The flattened graph sample, exposing ``nodes``, ``edges`` and
            ``edge_links``.

    Returns:
        A :class:`GraphInstance` whose ``nodes`` and ``edges`` are restored
        to the layout of ``space.node_space`` / ``space.edge_space``
        (``None`` when the matching space or value is ``None``) and whose
        ``edge_links`` are passed through unchanged.
    """

    def _unflatten_elements(element_space, elements):
        """Undo the row-wise flattening of a stack of node or edge features."""
        if element_space is None or elements is None:
            return None
        if isinstance(element_space, Box):
            # One leading per-element axis followed by the Box's own shape.
            return elements.reshape(-1, *element_space.shape)
        if isinstance(element_space, Discrete):
            # The column of each row's hot entry is the value's offset from
            # ``start``; add ``start`` back so values round-trip for spaces
            # that do not begin at zero.
            return np.nonzero(elements)[-1] + element_space.start
        # Any other element space is left as ``None``, as before.
        return None

    nodes = _unflatten_elements(space.node_space, x.nodes)
    edges = _unflatten_elements(space.edge_space, x.edges)
    return GraphInstance(nodes, edges, x.edge_links)
@singledispatch
def flatten_space(space: Space) -> Box:
"""Flatten a space into a single ``Box``.
This is equivalent to :func:`flatten`, but operates on the space itself. The
result always is a `Box` with flat boundaries. The box has exactly
:func:`flatdim` dimensions. Flattening a sample of the original space
has the same effect as taking a sample of the flattenend space.
result for non-graph spaces is always a `Box` with flat boundaries. While
the result for graph spaces is always a `Graph` with `node_space` being a `Box`
with flat boundaries and `edge_space` being a `Box` with flat boundaries or
`None`. The box has exactly :func:`flatdim` dimensions. Flattening a sample
of the original space has the same effect as taking a sample of the flattened
space.
Example::
@@ -216,6 +280,15 @@ def flatten_space(space: Space) -> Box:
>>> flatten(space, space.sample()) in flatten_space(space)
True
Example that flattens a graph::
>>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
>>> flatten_space(space)
Graph(Box(-100.0, 100.0, (12,), float32), Box(0, 1, (5,), int64))
>>> flatten(space, space.sample()) in flatten_space(space)
True
Args:
space: The space to flatten
@@ -258,3 +331,13 @@ def _flatten_space_dict(space: Dict) -> Box:
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
)
@flatten_space.register(Graph)
def _flatten_space_graph(space: Graph) -> Graph:
    """Flatten the node and edge spaces of a :class:`Graph` space.

    Args:
        space: The graph space to flatten.

    Returns:
        A new :class:`Graph` whose ``node_space`` is the flattened node space
        and whose ``edge_space`` is the flattened edge space, or ``None`` when
        the original graph has no edge space.
    """
    flat_node_space = flatten_space(space.node_space)

    # A graph without edge features keeps ``edge_space=None``.
    if space.edge_space is None:
        flat_edge_space = None
    else:
        flat_edge_space = flatten_space(space.edge_space)

    return Graph(node_space=flat_node_space, edge_space=flat_edge_space)