mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 13:32:03 +00:00
Add Graph to Spaces (#2869)
This commit is contained in:
@@ -10,7 +10,17 @@ from typing import TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
||||
from gym.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Space,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -77,7 +87,15 @@ def flatten(space: Space[T], x: T) -> np.ndarray:
|
||||
x: The value to flatten
|
||||
|
||||
Returns:
|
||||
The flattened ``x``, always returns a 1D array.
|
||||
- The flattened ``x``, always returns a 1D array for non-graph spaces.
|
||||
- For graph spaces, returns `GraphInstance` where:
|
||||
- `nodes` are n x k arrays
|
||||
- `edges` are either:
|
||||
- m x k arrays
|
||||
- None
|
||||
- `edge_links` are either:
|
||||
- m x 2 arrays
|
||||
- None
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the space is not defined in ``gym.spaces``.
|
||||
@@ -118,6 +136,26 @@ def _flatten_dict(space, x) -> np.ndarray:
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
|
||||
|
||||
@flatten.register(Graph)
|
||||
def _flatten_graph(space, x) -> np.ndarray:
|
||||
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
||||
|
||||
def _graph_unflatten(space, x):
|
||||
ret = None
|
||||
if space is not None and x is not None:
|
||||
if isinstance(space, Box):
|
||||
ret = x.reshape(x.shape[0], -1)
|
||||
elif isinstance(space, Discrete):
|
||||
ret = np.zeros((x.shape[0], space.n - space.start), dtype=space.dtype)
|
||||
ret[np.arange(x.shape[0]), x - space.start] = 1
|
||||
return ret
|
||||
|
||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def unflatten(space: Space[T], x: np.ndarray) -> T:
|
||||
"""Unflatten a data point from a space.
|
||||
@@ -181,14 +219,40 @@ def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
|
||||
)
|
||||
|
||||
|
||||
@unflatten.register(Graph)
|
||||
def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space.
|
||||
|
||||
The size of the outcome is actually not fixed, but determined based on the number of
|
||||
nodes and edges in the graph.
|
||||
"""
|
||||
|
||||
def _graph_unflatten(space, x):
|
||||
ret = None
|
||||
if space is not None and x is not None:
|
||||
if isinstance(space, Box):
|
||||
ret = x.reshape(-1, *space.shape)
|
||||
elif isinstance(space, Discrete):
|
||||
ret = np.asarray(np.nonzero(x))[-1, :]
|
||||
return ret
|
||||
|
||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space: Space) -> Box:
|
||||
"""Flatten a space into a single ``Box``.
|
||||
|
||||
This is equivalent to :func:`flatten`, but operates on the space itself. The
|
||||
result always is a `Box` with flat boundaries. The box has exactly
|
||||
:func:`flatdim` dimensions. Flattening a sample of the original space
|
||||
has the same effect as taking a sample of the flattenend space.
|
||||
result for non-graph spaces is always a `Box` with flat boundaries. While
|
||||
the result for graph spaces is always a `Graph` with `node_space` being a `Box`
|
||||
with flat boundaries and `edge_space` being a `Box` with flat boundaries or
|
||||
`None`. The box has exactly :func:`flatdim` dimensions. Flattening a sample
|
||||
of the original space has the same effect as taking a sample of the flattenend
|
||||
space.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -216,6 +280,15 @@ def flatten_space(space: Space) -> Box:
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
True
|
||||
|
||||
|
||||
Example that flattens a graph::
|
||||
|
||||
>>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
|
||||
>>> flatten_space(space)
|
||||
Graph(Box(-100.0, 100.0, (12,), float32), Box(0, 1, (5,), int64))
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
True
|
||||
|
||||
Args:
|
||||
space: The space to flatten
|
||||
|
||||
@@ -258,3 +331,13 @@ def _flatten_space_dict(space: Dict) -> Box:
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Graph)
|
||||
def _flatten_space_graph(space: Graph) -> Graph:
|
||||
return Graph(
|
||||
node_space=flatten_space(space.node_space),
|
||||
edge_space=flatten_space(space.edge_space)
|
||||
if space.edge_space is not None
|
||||
else None,
|
||||
)
|
||||
|
Reference in New Issue
Block a user