2022-06-09 15:42:58 +01:00
""" Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space. """
2022-11-15 14:09:22 +00:00
from __future__ import annotations
from typing import Any , NamedTuple , Sequence
2022-06-09 15:42:58 +01:00
import numpy as np
2022-11-15 14:09:22 +00:00
from numpy . typing import NDArray
2022-06-09 15:42:58 +01:00
2022-11-15 14:09:22 +00:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium . spaces . box import Box
from gymnasium . spaces . discrete import Discrete
from gymnasium . spaces . multi_discrete import MultiDiscrete
from gymnasium . spaces . space import Space
2022-06-09 15:42:58 +01:00
2022-09-03 23:39:23 +01:00
class GraphInstance(NamedTuple):
    """One concrete graph, i.e. a single element of a :class:`Graph` space.

    Fields, in positional order:

    * ``nodes``: array of shape ``(n, ...)`` with the features of the ``n`` nodes;
      the trailing ``(...)`` dimensions must follow the shape of the node space.
    * ``edges``: array of shape ``(m, ...)`` with the features of the ``m`` edges;
      the trailing ``(...)`` dimensions must follow the shape of the edge space.
      ``None`` when the graph carries no edges.
    * ``edge_links``: integer array of shape ``(m, 2)``; each row holds the indices
      of the two nodes joined by the corresponding edge. ``None`` when the graph
      carries no edges.
    """

    # per-node features, one row per node
    nodes: NDArray[Any]
    # per-edge features, one row per edge (optional)
    edges: NDArray[Any] | None
    # (m, 2) node-index pairs, one row per edge (optional)
    edge_links: NDArray[Any] | None
2022-09-03 23:39:23 +01:00
2022-06-09 15:42:58 +01:00
2022-11-15 14:09:22 +00:00
class Graph(Space[GraphInstance]):
    r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.

    Example:
        >>> from gymnasium.spaces import Graph, Box, Discrete
        >>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=42)
        >>> observation_space.sample()
        GraphInstance(nodes=array([[-12.224312 ,  71.71958  ,  39.473606 ],
               [-81.16453  ,  95.12447  ,  52.22794  ],
               [ 57.21286  , -74.37727  ,  -9.922812 ],
               [-25.840395 ,  85.353    ,  28.773024 ],
               [ 64.55232  , -11.317161 , -54.552258 ],
               [ 10.916958 , -87.23655  ,  65.52624  ],
               [ 26.33288  ,  51.61755  , -29.094807 ],
               [ 94.1396   ,  78.62422  ,  55.6767   ],
               [-61.072258 ,  -6.6557994, -91.23925  ],
               [-69.142105 ,  36.60979  ,  48.95243  ]], dtype=float32), edges=array([2, 0, 1, 1, 0, 0, 1, 0]), edge_links=array([[7, 5],
               [6, 9],
               [4, 1],
               [8, 6],
               [7, 0],
               [3, 7],
               [8, 4],
               [8, 8]]))
    """

    def __init__(
        self,
        node_space: Box | Discrete,
        edge_space: None | Box | Discrete,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Graph`.

        The argument ``node_space`` specifies the base space that each node feature will use.
        This argument must be either a Box or Discrete instance.

        The argument ``edge_space`` specifies the base space that each edge feature will use.
        This argument must be either a None, Box or Discrete instance.

        Args:
            node_space (Union[Box, Discrete]): space of the node features.
            edge_space (Union[None, Box, Discrete]): space of the edge features.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
        """
        assert isinstance(
            node_space, (Box, Discrete)
        ), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
        if edge_space is not None:
            # Report the type of the offending ``edge_space`` (not ``node_space``).
            assert isinstance(
                edge_space, (Box, Discrete)
            ), f"Values of the edge_space should be instances of None Box or Discrete, got {type(edge_space)}"

        self.node_space = node_space
        self.edge_space = edge_space

        # A graph has no fixed shape or dtype, hence both are passed as None.
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return False

    def _generate_sample_space(
        self, base_space: None | Box | Discrete, num: int
    ) -> Box | MultiDiscrete | None:
        """Build a batched space from which ``num`` node or edge features can be sampled at once.

        Args:
            base_space: The per-element space (the node space or the edge space).
            num: The number of elements the batched space has to cover.

        Returns:
            ``None`` if ``num`` is 0 or ``base_space`` is ``None``; a :class:`Box` of shape
            ``(num,) + base_space.shape`` for a Box base space; a :class:`MultiDiscrete` with
            ``num`` entries for a Discrete base space. The returned space shares this
            space's random number generator.

        Raises:
            TypeError: If ``base_space`` is neither ``None``, a Box nor a Discrete.
        """
        if num == 0 or base_space is None:
            return None

        if isinstance(base_space, Box):
            return Box(
                low=np.array(max(1, num) * [base_space.low]),
                high=np.array(max(1, num) * [base_space.high]),
                shape=(num,) + base_space.shape,
                dtype=base_space.dtype,
                seed=self.np_random,
            )
        elif isinstance(base_space, Discrete):
            return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
        else:
            raise TypeError(
                f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
            )

    def sample(
        self,
        mask: None
        | (
            tuple[
                NDArray[Any] | tuple[Any, ...] | None,
                NDArray[Any] | tuple[Any, ...] | None,
            ]
        ) = None,
        num_nodes: int = 10,
        num_edges: int | None = None,
    ) -> GraphInstance:
        """Generates a single sample graph with ``num_nodes`` nodes from the Graph space.

        Args:
            mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
                (Box spaces don't support sample masks).
                If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
            num_nodes: The number of nodes that will be sampled, the default is 10 nodes
            num_edges: An optional number of edges, otherwise, a random number is drawn from
                ``[0, num_nodes * (num_nodes - 1))`` (0 when there is a single node)

        Returns:
            A :class:`GraphInstance` with attributes `.nodes`, `.edges`, and `.edge_links`.
            `.edges` and `.edge_links` are ``None`` when no edge is sampled.
        """
        assert (
            num_nodes > 0
        ), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"

        if mask is not None:
            node_space_mask, edge_space_mask = mask
        else:
            node_space_mask, edge_space_mask = None, None

        # we only have edges when we have at least 2 nodes
        if num_edges is None:
            if num_nodes > 1:
                # the drawn edge count is strictly below `n*(n-1)`; self connections and
                # two-way connections may appear since the links are sampled independently
                num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
            else:
                num_edges = 0

            # the single edge mask is repeated once per sampled edge
            if edge_space_mask is not None:
                edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
        else:
            if self.edge_space is None:
                gym.logger.warn(
                    f"The number of edges is set ({num_edges}) but the edge space is None."
                )
            assert (
                num_edges >= 0
            ), f"Expects the number of edges to be greater than 0, actual value: {num_edges}"
        assert num_edges is not None

        sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
        sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)

        # `num_nodes > 0` and the node space is never None, so a node sample space exists
        assert sampled_node_space is not None
        sampled_nodes = sampled_node_space.sample(node_space_mask)
        sampled_edges = (
            sampled_edge_space.sample(edge_space_mask)
            if sampled_edge_space is not None
            else None
        )

        sampled_edge_links = None
        if sampled_edges is not None and num_edges > 0:
            # every edge connects two node indices drawn from [0, num_nodes)
            sampled_edge_links = self.np_random.integers(
                low=0, high=num_nodes, size=(num_edges, 2)
            )

        return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)

    def contains(self, x: GraphInstance) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` is a member when it is a :class:`GraphInstance` whose ``nodes`` is an array
        with every row inside the node space and either

        * both ``edges`` and ``edge_links`` are ``None``, or
        * both are arrays, the edge space is not ``None``, every edge is inside the edge
          space, and ``edge_links`` is an integer array of shape ``(len(edges), 2)`` whose
          entries are valid node indices.
        """
        if not isinstance(x, GraphInstance):
            return False

        # Checks the nodes
        if not isinstance(x.nodes, np.ndarray):
            return False
        if not all(node in self.node_space for node in x.nodes):
            return False

        # Check the edges and edge links which are optional
        has_edge_arrays = isinstance(x.edges, np.ndarray) and isinstance(
            x.edge_links, np.ndarray
        )
        if not has_edge_arrays:
            # without a full pair of arrays the only valid graph is the edge-less one
            return x.edges is None and x.edge_links is None

        if self.edge_space is None:
            return False
        if not all(edge in self.edge_space for edge in x.edges):
            return False
        if not np.issubdtype(x.edge_links.dtype, np.integer):
            return False
        if x.edge_links.shape != (len(x.edges), 2):
            return False

        # every link endpoint has to index an existing node
        return bool(
            np.all(
                np.logical_and(
                    x.edge_links >= 0,
                    x.edge_links < len(x.nodes),
                )
            )
        )

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include node_space and edge_space

        Returns:
            A representation of the space
        """
        return f"Graph({self.node_space}, {self.edge_space})"

    def __eq__(self, other: Any) -> bool:
        """Check whether `other` is equivalent to this instance."""
        return (
            isinstance(other, Graph)
            and (self.node_space == other.node_space)
            and (self.edge_space == other.edge_space)
        )

    def to_jsonable(
        self, sample_n: Sequence[GraphInstance]
    ) -> list[dict[str, list[int | float]]]:
        """Convert a batch of samples from this space to a JSONable data type.

        Each sample becomes a dict with a ``"nodes"`` entry; the ``"edges"`` and
        ``"edge_links"`` entries are only written when both are present on the sample.
        """
        ret_n = []
        for sample in sample_n:
            ret = {"nodes": sample.nodes.tolist()}
            if sample.edges is not None and sample.edge_links is not None:
                ret["edges"] = sample.edges.tolist()
                ret["edge_links"] = sample.edge_links.tolist()
            ret_n.append(ret)
        return ret_n

    def from_jsonable(
        self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
    ) -> list[GraphInstance]:
        """Convert a JSONable data type to a batch of samples from this space.

        A dict without an ``"edges"`` entry is restored as a graph whose ``edges`` and
        ``edge_links`` are ``None``.
        """
        ret: list[GraphInstance] = []
        for sample in sample_n:
            if "edges" in sample:
                ret_n = GraphInstance(
                    np.asarray(sample["nodes"]),
                    np.asarray(sample["edges"]),
                    np.asarray(sample["edge_links"]),
                )
            else:
                ret_n = GraphInstance(
                    np.asarray(sample["nodes"]),
                    None,
                    None,
                )
            ret.append(ret_n)
        return ret