2022-06-09 15:42:58 +01:00
""" Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space. """
2024-06-10 17:07:47 +01:00
2022-11-15 14:09:22 +00:00
from __future__ import annotations
from typing import Any , NamedTuple , Sequence
2022-06-09 15:42:58 +01:00
import numpy as np
2022-11-15 14:09:22 +00:00
from numpy . typing import NDArray
2022-06-09 15:42:58 +01:00
2022-11-15 14:09:22 +00:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium . spaces . box import Box
from gymnasium . spaces . discrete import Discrete
from gymnasium . spaces . multi_discrete import MultiDiscrete
from gymnasium . spaces . space import Space
2022-06-09 15:42:58 +01:00
2022-09-03 23:39:23 +01:00
class GraphInstance(NamedTuple):
    """A single graph sampled from (or checked against) a :class:`Graph` space.

    The three fields are positional, in this order:

    * nodes (np.ndarray): an (n x ...) sized array holding the features of the n nodes, where (...) must follow the shape of the node space.
    * edges (Optional[np.ndarray]): an (m x ...) sized array holding the features of the m edges, where (...) must follow the shape of the edge space.
    * edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints holding, for each edge, the indices of the two nodes it connects.
    """

    # Per-node features; always present.
    nodes: NDArray[Any]
    # Per-edge features; ``None`` when the graph carries no edges.
    edges: NDArray[Any] | None
    # Node-index pairs, one row per edge; ``None`` when the graph carries no edges.
    edge_links: NDArray[Any] | None
2022-09-03 23:39:23 +01:00
2022-06-09 15:42:58 +01:00
2022-11-15 14:09:22 +00:00
class Graph(Space[GraphInstance]):
    r"""A space representing graph information as a series of ``nodes`` connected with ``edges`` according to an adjacency matrix represented as a series of ``edge_links``.

    Example:
        >>> from gymnasium.spaces import Graph, Box, Discrete
        >>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=123)
        >>> observation_space.sample(num_nodes=4, num_edges=8)
        GraphInstance(nodes=array([[ 36.47037 , -89.235794, -55.928024],
               [-63.125637, -64.81882 ,  62.4189  ],
               [ 84.669   , -44.68512 ,  63.950912],
               [ 77.97854 ,   2.594091, -51.00708 ]], dtype=float32), edges=array([2, 0, 2, 1, 2, 0, 2, 1]), edge_links=array([[3, 0],
               [0, 0],
               [0, 1],
               [0, 2],
               [1, 0],
               [1, 0],
               [0, 1],
               [0, 2]], dtype=int32))
    """

    def __init__(
        self,
        node_space: Box | Discrete,
        edge_space: None | Box | Discrete,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Graph`.

        The argument ``node_space`` specifies the base space that each node feature will use.
        This argument must be either a Box or Discrete instance.

        The argument ``edge_space`` specifies the base space that each edge feature will use.
        This argument must be either a None, Box or Discrete instance.

        Args:
            node_space (Union[Box, Discrete]): space of the node features.
            edge_space (Union[None, Box, Discrete]): space of the edge features.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
        """
        assert isinstance(
            node_space, (Box, Discrete)
        ), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
        if edge_space is not None:
            assert isinstance(
                edge_space, (Box, Discrete)
            ), f"Values of the edge_space should be instances of None Box or Discrete, got {type(edge_space)}"

        self.node_space = node_space
        self.edge_space = edge_space

        # A graph has no fixed shape or dtype, so both are passed as None.
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return False

    def _generate_sample_space(
        self, base_space: None | Box | Discrete, num: int
    ) -> Box | MultiDiscrete | None:
        """Build a batched space holding ``num`` copies of ``base_space``.

        Args:
            base_space: The per-node or per-edge space, or ``None``.
            num: How many elements (nodes or edges) the batched space should hold.

        Returns:
            ``None`` if ``num`` is zero or ``base_space`` is ``None``; a :class:`Box` with a
            leading dimension of ``num`` for a Box base space; a :class:`MultiDiscrete` with
            ``num`` entries for a Discrete base space. The returned space shares this
            space's ``np_random`` generator.

        Raises:
            TypeError: If ``base_space`` is neither ``None``, Box nor Discrete.
        """
        if num == 0 or base_space is None:
            return None

        if isinstance(base_space, Box):
            # `num` is non-zero here, so the bounds are simply repeated `num` times.
            return Box(
                low=np.array(num * [base_space.low]),
                high=np.array(num * [base_space.high]),
                shape=(num,) + base_space.shape,
                dtype=base_space.dtype,
                seed=self.np_random,
            )
        elif isinstance(base_space, Discrete):
            return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
        else:
            raise TypeError(
                f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
            )

    def seed(
        self, seed: int | tuple[int, int] | tuple[int, int, int] | None = None
    ) -> tuple[int, int] | tuple[int, int, int]:
        """Seeds the PRNG of this space and node / edge subspace.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - The root, node and edge spaces PRNG are randomly initialized
        * ``Int`` - The integer is used to seed the :class:`Graph` space that is used to generate seed values for the node and edge subspaces.
        * ``Tuple[int, int]`` - Seeds the :class:`Graph` and node subspace with a particular value. Only if edge subspace isn't specified
        * ``Tuple[int, int, int]`` - Seeds the :class:`Graph`, node and edge subspaces with a particular value.

        Args:
            seed: An optional int or tuple of ints for this space and the node / edge subspaces. See above for more details.

        Returns:
            A tuple of two or three ints depending on if the edge subspace is specified.

        Raises:
            ValueError: If a list/tuple seed does not have exactly two (no edge space) or three (with edge space) entries.
            TypeError: If ``seed`` is not ``None``, an int, or a list/tuple.
        """
        if seed is None:
            if self.edge_space is None:
                return super().seed(None), self.node_space.seed(None)
            else:
                return (
                    super().seed(None),
                    self.node_space.seed(None),
                    self.edge_space.seed(None),
                )
        elif isinstance(seed, int):
            if self.edge_space is None:
                super_seed = super().seed(seed)
                node_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
                # this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
                super().seed(seed)
                return super_seed, self.node_space.seed(node_seed)
            else:
                super_seed = super().seed(seed)
                node_seed, edge_seed = self.np_random.integers(
                    np.iinfo(np.int32).max, size=(2,)
                )
                # this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
                super().seed(seed)
                return (
                    super_seed,
                    self.node_space.seed(int(node_seed)),
                    self.edge_space.seed(int(edge_seed)),
                )
        elif isinstance(seed, (list, tuple)):
            if self.edge_space is None:
                if len(seed) != 2:
                    raise ValueError(
                        f"Expects a tuple of two values for Graph and node space, actual length: {len(seed)}"
                    )
                return super().seed(seed[0]), self.node_space.seed(seed[1])
            else:
                if len(seed) != 3:
                    raise ValueError(
                        f"Expects a tuple of three values for Graph, node and edge space, actual length: {len(seed)}"
                    )
                return (
                    super().seed(seed[0]),
                    self.node_space.seed(seed[1]),
                    self.edge_space.seed(seed[2]),
                )
        else:
            raise TypeError(
                f"Expects `None`, int or tuple of ints, actual type: {type(seed)}"
            )

    def sample(
        self,
        mask: None | (
            tuple[
                NDArray[Any] | tuple[Any, ...] | None,
                NDArray[Any] | tuple[Any, ...] | None,
            ]
        ) = None,
        num_nodes: int = 10,
        num_edges: int | None = None,
    ) -> GraphInstance:
        """Generates a single sample graph with ``num_nodes`` nodes and ``num_edges`` edges.

        Args:
            mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
                (Box spaces don't support sample masks).
                If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
            num_nodes: The number of nodes that will be sampled, the default is `10` nodes
            num_edges: An optional number of edges. If omitted, a random number is drawn from
                ``[0, num_nodes * (num_nodes - 1))`` (``0`` when there is a single node).

        Returns:
            A :class:`GraphInstance` with attributes `.nodes`, `.edges`, and `.edge_links`.
            ``edges`` and ``edge_links`` are ``None`` when there is no edge space or no edges.
        """
        assert (
            num_nodes > 0
        ), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"

        if mask is not None:
            node_space_mask, edge_space_mask = mask
        else:
            node_space_mask, edge_space_mask = None, None

        # we only have edges when we have at least 2 nodes
        if num_edges is None:
            if num_nodes > 1:
                # drawn from [0, n*(n-1)); cast so downstream code sees a plain Python int
                num_edges = int(self.np_random.integers(num_nodes * (num_nodes - 1)))
            else:
                num_edges = 0

            # a single edge mask is replicated once per sampled edge
            if edge_space_mask is not None:
                edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
        else:
            if self.edge_space is None:
                gym.logger.warn(
                    f"The number of edges is set ({num_edges}) but the edge space is None."
                )
            assert (
                num_edges >= 0
            ), f"Expects the number of edges to be greater than or equal to 0, actual value: {num_edges}"
        assert num_edges is not None

        # The statement order below fixes the order of draws from the shared PRNG.
        sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
        sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)

        assert sampled_node_space is not None
        sampled_nodes = sampled_node_space.sample(node_space_mask)
        sampled_edges = (
            sampled_edge_space.sample(edge_space_mask)
            if sampled_edge_space is not None
            else None
        )

        sampled_edge_links = None
        if sampled_edges is not None and num_edges > 0:
            # each row holds the indices of the two nodes an edge connects
            sampled_edge_links = self.np_random.integers(
                low=0, high=num_nodes, size=(num_edges, 2), dtype=np.int32
            )

        return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)

    def contains(self, x: GraphInstance) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` is valid when it is a :class:`GraphInstance` whose nodes all belong to the
        node space and whose edges are either absent (``edges`` and ``edge_links`` both
        ``None``) or all belong to the edge space with integer ``edge_links`` of shape
        ``(len(edges), 2)`` indexing existing nodes.
        """
        if not isinstance(x, GraphInstance):
            return False

        nodes, edges, edge_links = x.nodes, x.edges, x.edge_links

        # 0-d arrays cannot be iterated per node, so they can never be valid.
        if not isinstance(nodes, np.ndarray) or nodes.ndim == 0:
            return False
        if not all(node in self.node_space for node in nodes):
            return False

        # Edges and edge links are optional, but must be given (or omitted) together.
        if not (isinstance(edges, np.ndarray) and isinstance(edge_links, np.ndarray)):
            return edges is None and edge_links is None

        if self.edge_space is None or edges.ndim == 0:
            return False
        if not all(edge in self.edge_space for edge in edges):
            return False
        if not np.issubdtype(edge_links.dtype, np.integer):
            return False
        if edge_links.shape != (len(edges), 2):
            return False

        # Every link must point at an existing node index.
        return bool(
            np.all(np.logical_and(edge_links >= 0, edge_links < len(nodes)))
        )

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include ``node_space`` and ``edge_space``

        Returns:
            A representation of the space
        """
        return f"Graph({self.node_space}, {self.edge_space})"

    def __eq__(self, other: Any) -> bool:
        """Check whether `other` is equivalent to this instance."""
        return (
            isinstance(other, Graph)
            and (self.node_space == other.node_space)
            and (self.edge_space == other.edge_space)
        )

    def to_jsonable(
        self, sample_n: Sequence[GraphInstance]
    ) -> list[dict[str, list[int | float]]]:
        """Convert a batch of samples from this space to a JSONable data type.

        Each sample becomes a dict with a ``"nodes"`` entry and, when the sample has both
        edges and edge links, ``"edges"`` and ``"edge_links"`` entries.
        """
        ret_n = []
        for sample in sample_n:
            ret = {"nodes": sample.nodes.tolist()}
            if sample.edges is not None and sample.edge_links is not None:
                ret["edges"] = sample.edges.tolist()
                ret["edge_links"] = sample.edge_links.tolist()
            ret_n.append(ret)
        return ret_n

    def from_jsonable(
        self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
    ) -> list[GraphInstance]:
        """Convert a JSONable data type to a batch of samples from this space.

        Entries with an ``"edges"`` key are rebuilt with edges and edge links (this
        requires the space to have an edge space); all others get ``None`` for both.
        """
        ret: list[GraphInstance] = []
        for sample in sample_n:
            nodes = np.asarray(sample["nodes"], dtype=self.node_space.dtype)
            if "edges" in sample:
                assert self.edge_space is not None
                instance = GraphInstance(
                    nodes,
                    np.asarray(sample["edges"], dtype=self.edge_space.dtype),
                    np.asarray(sample["edge_links"], dtype=np.int32),
                )
            else:
                instance = GraphInstance(nodes, None, None)
            ret.append(instance)
        return ret