2023-02-22 15:05:58 +00:00
""" Space-based utility functions for vector environments.
2024-08-29 16:52:43 +01:00
- ``batch_space``: Create a (batched) space containing multiple copies of a single space.
- ``batch_differing_spaces``: Create a (batched) space containing copies of different compatible spaces (sharing a common dtype and shape).
- ``concatenate``: Concatenate multiple samples from an (unbatched) space into a single object.
- ``iterate``: Iterate over the elements of a batched sample of a (batched) space.
- ``create_empty_array``: Create an empty (possibly nested, normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
"""
2024-06-10 17:07:47 +01:00
2023-02-18 21:08:08 +00:00
from __future__ import annotations
2024-09-20 13:44:07 +01:00
import typing
2025-06-07 17:57:58 +01:00
from collections . abc import Callable , Iterable , Iterator
2023-02-18 21:08:08 +00:00
from copy import deepcopy
from functools import singledispatch
2025-06-07 17:57:58 +01:00
from typing import Any
2023-02-18 21:08:08 +00:00
import numpy as np
from gymnasium . error import CustomSpaceError
from gymnasium . spaces import (
Box ,
Dict ,
Discrete ,
2023-02-22 15:05:58 +00:00
Graph ,
GraphInstance ,
2023-02-18 21:08:08 +00:00
MultiBinary ,
MultiDiscrete ,
2024-03-11 13:30:50 +01:00
OneOf ,
2023-02-22 15:05:58 +00:00
Sequence ,
2023-02-18 21:08:08 +00:00
Space ,
2023-02-22 15:05:58 +00:00
Text ,
2023-02-18 21:08:08 +00:00
Tuple ,
)
2023-02-22 15:05:58 +00:00
from gymnasium . spaces . space import T_cov
2023-02-18 21:08:08 +00:00
2024-08-29 16:52:43 +01:00
# Names exported by ``from ... import *``: the five single-dispatch entry points below.
__all__ = ["batch_space", "batch_differing_spaces", "iterate", "concatenate", "create_empty_array"]
2023-02-18 21:08:08 +00:00
@singledispatch
def batch_space(space: Space[Any], n: int = 1) -> Space[Any]:
    """Batch a space into a single space describing ``n`` stacked samples of it.

    Args:
        space: Space (e.g. the observation space for a single environment in the vectorized environment).
        n: Number of spaces to batch by (e.g. the number of environments in a vectorized environment).

    Returns:
        Batched space of size `n`.

    Raises:
        TypeError: If ``space`` is not a :class:`gymnasium.Space` instance. Space types without a
            dedicated implementation are batched as a ``Tuple`` of ``n`` reseeded copies instead.

    Example:
        >>> from gymnasium.spaces import Box, Dict
        >>> import numpy as np
        >>> space = Dict({
        ...     'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
        ...     'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
        ... })
        >>> batch_space(space, n=5)
        Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
    """
    # Every ``Space`` instance dispatches to a registered implementation (``Space`` itself is
    # registered as a fallback), so this generic body only runs for non-Space objects.
    raise TypeError(
        f"The space provided to `batch_space` is not a gymnasium Space instance, type: {type(space)}, {space}"
    )
@batch_space.register(Box)
def _batch_space_box(space: Box, n: int = 1):
    """Batch a Box by repeating its bounds along a new leading axis of length ``n``."""
    reps = (n,) + (1,) * space.low.ndim
    batched_low = np.tile(space.low, reps)
    batched_high = np.tile(space.high, reps)
    return Box(
        low=batched_low,
        high=batched_high,
        dtype=space.dtype,
        seed=deepcopy(space.np_random),
    )
@batch_space.register(Discrete)
def _batch_space_discrete(space: Discrete, n: int = 1):
    """Batch a Discrete space as a MultiDiscrete with ``n`` identical entries."""
    batch_shape = (n,)
    nvec = np.full(batch_shape, space.n, dtype=space.dtype)
    starts = np.full(batch_shape, space.start, dtype=space.dtype)
    return MultiDiscrete(
        nvec,
        dtype=space.dtype,
        seed=deepcopy(space.np_random),
        start=starts,
    )
2023-02-18 21:08:08 +00:00
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space: MultiDiscrete, n: int = 1):
    """Batch a MultiDiscrete as an integer Box covering ``[start, start + nvec - 1]`` per element."""
    reps = (n,) + (1,) * space.nvec.ndim
    lower = np.tile(space.start, reps)
    # Inclusive upper bound: the largest value of each entry is start + nvec - 1.
    upper = lower + np.tile(space.nvec, reps) - 1
    return Box(low=lower, high=upper, dtype=space.dtype, seed=deepcopy(space.np_random))
@batch_space.register(MultiBinary)
def _batch_space_multibinary(space: MultiBinary, n: int = 1):
    """Batch a MultiBinary as a {0, 1}-bounded Box with a leading axis of length ``n``."""
    batched_shape = (n,) + space.shape
    rng_copy = deepcopy(space.np_random)
    return Box(low=0, high=1, shape=batched_shape, dtype=space.dtype, seed=rng_copy)
@batch_space.register(Tuple)
def _batch_space_tuple(space: Tuple, n: int = 1):
    """Batch each Tuple component independently, keeping the Tuple structure."""
    batched_subspaces = tuple(batch_space(sub, n=n) for sub in space.spaces)
    return Tuple(batched_subspaces, seed=deepcopy(space.np_random))
@batch_space.register(Dict)
def _batch_space_dict(space: Dict, n: int = 1):
    """Batch each Dict value independently under its original key."""
    batched = {}
    for key, subspace in space.items():
        batched[key] = batch_space(subspace, n=n)
    return Dict(batched, seed=deepcopy(space.np_random))
2023-02-22 15:05:58 +00:00
@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(OneOf)
@batch_space.register(Space)
def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
    """Fallback batching: a ``Tuple`` of ``n`` deep copies of ``space``, each given its own seed."""
    # Copy rather than reuse ``space``: otherwise the batched entries would share
    # ``space.np_random`` with the original, and sampling from either would advance
    # the same generator.
    copies = tuple(deepcopy(space) for _ in range(n))
    batched_space = Tuple(copies, seed=deepcopy(space.np_random))

    # Draw the per-copy seeds from a copy of the generator so the original's state is untouched.
    seed_rng = deepcopy(space.np_random)
    per_copy_seeds = [int(s) for s in seed_rng.integers(0, 1e8, n)]
    batched_space.seed(per_copy_seeds)
    return batched_space
2024-08-29 16:52:43 +01:00
@singledispatch
def batch_differing_spaces(spaces: typing.Sequence[Space]) -> Space:
    """Batch a sequence of spaces of the same type whose parameters may differ.

    Args:
        spaces: A sequence of Spaces with minor differences (the same space type but different parameters).

    Returns:
        A batched space

    Raises:
        AssertionError: If ``spaces`` is empty, mixes space types, or its space type has no
            registered ``batch_differing_spaces`` implementation.

    Example:
        >>> from gymnasium.spaces import Discrete
        >>> spaces = [Discrete(3), Discrete(5), Discrete(4), Discrete(8)]
        >>> batch_differing_spaces(spaces)
        MultiDiscrete([3 5 4 8])
    """
    assert len(spaces) > 0, "Expects a non-empty list of spaces"
    space_type = type(spaces[0])
    assert all(
        isinstance(space, space_type) for space in spaces
    ), f"Expects all spaces to be the same type, actual types: {[type(space) for space in spaces]}"

    # The argument is a sequence (no implementation is registered for list/tuple), so the
    # singledispatch wrapper lands here; dispatch manually on the element type instead.
    # Resolving through ``dispatch`` (not an exact ``registry`` lookup) lets subclasses of
    # registered spaces reuse their parent's implementation. A type that resolves back to
    # this generic function is rejected, since calling it again would recurse forever.
    impl = batch_differing_spaces.dispatch(space_type)
    assert (
        impl is not batch_differing_spaces.registry[object]
    ), f"Requires the Space type to have a registered `batch_differing_space`, current list: {batch_differing_spaces.registry}"
    return impl(spaces)
@batch_differing_spaces.register(Box)
def _batch_differing_spaces_box(spaces: list[Box]):
    """Stack the bounds of same-dtype, same-shape Boxes into one Box with a leading batch axis."""
    reference = spaces[0]
    assert all(
        space.dtype == reference.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        space.low.shape == reference.low.shape for space in spaces
    ), f"Expected all Box.low shape to be equal, actually {[space.low.shape for space in spaces]}"
    assert all(
        space.high.shape == reference.high.shape for space in spaces
    ), f"Expected all Box.high shape to be equal, actually {[space.high.shape for space in spaces]}"

    stacked_low = np.array([space.low for space in spaces])
    stacked_high = np.array([space.high for space in spaces])
    return Box(
        low=stacked_low,
        high=stacked_high,
        dtype=reference.dtype,
        seed=deepcopy(reference.np_random),
    )
@batch_differing_spaces.register(Discrete)
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
    """Batch Discrete spaces with differing ``n``/``start`` into one MultiDiscrete.

    The inputs must share a dtype, which is forwarded to the MultiDiscrete so the result
    agrees with ``batch_space`` on a single Discrete space (which passes ``space.dtype``).
    """
    assert all(
        spaces[0].dtype == space.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    return MultiDiscrete(
        nvec=np.array([space.n for space in spaces]),
        start=np.array([space.start for space in spaces]),
        dtype=spaces[0].dtype,
        seed=deepcopy(spaces[0].np_random),
    )
@batch_differing_spaces.register(MultiDiscrete)
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
    """Batch same-shape MultiDiscretes into an integer Box spanning ``[start, start + nvec - 1]``."""
    assert all(
        spaces[0].dtype == space.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        spaces[0].nvec.shape == space.nvec.shape for space in spaces
    ), f"Expected all MultiDiscrete.nvec shapes to be equal, actually {[space.nvec.shape for space in spaces]}"
    assert all(
        spaces[0].start.shape == space.start.shape for space in spaces
    ), f"Expected all MultiDiscrete.start shapes to be equal, actually {[space.start.shape for space in spaces]}"
    return Box(
        low=np.array([space.start for space in spaces]),
        # Inclusive upper bound: the largest value of each entry is start + nvec - 1.
        high=np.array([space.start + space.nvec for space in spaces]) - 1,
        dtype=spaces[0].dtype,
        seed=deepcopy(spaces[0].np_random),
    )
@batch_differing_spaces.register(MultiBinary)
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
    """Batch same-shape MultiBinary spaces into a {0, 1}-bounded Box with a leading batch axis."""
    assert all(
        spaces[0].shape == space.shape for space in spaces
    ), f"Expected all MultiBinary shapes to be equal, actually {[space.shape for space in spaces]}"
    return Box(
        low=0,
        high=1,
        shape=(len(spaces),) + spaces[0].shape,
        dtype=spaces[0].dtype,
        seed=deepcopy(spaces[0].np_random),
    )
@batch_differing_spaces.register(Tuple)
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
    """Batch Tuple spaces position by position.

    All Tuples must have the same number of subspaces: ``zip`` below would otherwise
    silently drop the trailing subspaces of the longer Tuples.
    """
    assert all(
        len(space.spaces) == len(spaces[0].spaces) for space in spaces
    ), f"Expected all Tuple spaces to have the same number of subspaces, actually {[len(space.spaces) for space in spaces]}"
    return Tuple(
        tuple(
            batch_differing_spaces(subspaces)
            for subspaces in zip(*[space.spaces for space in spaces])
        ),
        seed=deepcopy(spaces[0].np_random),
    )
@batch_differing_spaces.register(Dict)
def _batch_differing_spaces_dict(spaces: list[Dict]):
    """Batch Dict spaces key by key; every Dict must have the same set of keys."""
    assert all(
        spaces[0].keys() == space.keys() for space in spaces
    ), f"Expected all Dict keys to be equal, actually {[list(space.keys()) for space in spaces]}"
    return Dict(
        {
            key: batch_differing_spaces([space[key] for space in spaces])
            for key in spaces[0].keys()
        },
        seed=deepcopy(spaces[0].np_random),
    )
@batch_differing_spaces.register(Graph)
@batch_differing_spaces.register(Text)
@batch_differing_spaces.register(Sequence)
@batch_differing_spaces.register(OneOf)
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
    """Batch spaces with no array-style batched form as a Tuple of independent deep copies."""
    copied = [deepcopy(space) for space in spaces]
    rng = deepcopy(spaces[0].np_random)
    return Tuple(copied, seed=rng)
2023-02-18 21:08:08 +00:00
@singledispatch
def iterate(space: Space[T_cov], items: T_cov) -> Iterator:
    """Iterate over the elements of a (batched) space.

    Args:
        space: (batched) space (e.g. `action_space` or `observation_space` from vectorized environment).
        items: Batched samples to be iterated over (e.g. sample from the space).

    Returns:
        An iterator over the elements of ``items``.

    Raises:
        CustomSpaceError: If ``space`` is a :class:`gymnasium.Space` whose type has no registered
            ``iterate`` implementation.
        TypeError: If ``space`` is not a :class:`gymnasium.Space` instance.

    Example:
        >>> from gymnasium.spaces import Box, Dict
        >>> import numpy as np
        >>> space = Dict({
        ... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
        ... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
        >>> items = space.sample()
        >>> it = iterate(space, items)
        >>> next(it)
        {'position': array([0.77395606, 0.43887845, 0.85859793], dtype=float32), 'velocity': array([0.77395606, 0.43887845], dtype=float32)}
        >>> next(it)
        {'position': array([0.697368  , 0.09417735, 0.97562236], dtype=float32), 'velocity': array([0.85859793, 0.697368  ], dtype=float32)}
        >>> next(it)
        Traceback (most recent call last):
            ...
        StopIteration
    """
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `iterate` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
    # A genuine Space reaching the generic implementation has no handler registered for its type.
    raise CustomSpaceError(
        f"Space of type `{type(space)}` doesn't have a registered `iterate` function. Register `{type(space)}` for `iterate` to support it."
    )
2023-02-18 21:08:08 +00:00
@iterate.register(Discrete)
def _iterate_discrete(space: Discrete, items: Iterable):
    """Reject iteration: a Discrete sample is a single value, not a batch.

    ``batch_space`` turns Discrete spaces into MultiDiscrete, so a batched space never
    contains a bare Discrete.
    """
    raise TypeError("Unable to iterate over a space of type `Discrete`.")
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def _iterate_base(space: Box | MultiDiscrete | MultiBinary, items: np.ndarray):
    """Iterate over the leading (batch) axis of an array-valued sample."""
    try:
        iterator = iter(items)
    except TypeError as err:
        raise TypeError(
            f"Unable to iterate over the following elements: {items}"
        ) from err
    return iterator
@iterate.register(Tuple)
def _iterate_tuple(space: Tuple, items: tuple[Any, ...]):
    """Iterate over a batched Tuple sample, producing one tuple per batch element.

    If every subspace has an ``iterate`` implementation, the per-subspace iterators are
    zipped together. Otherwise (e.g. the Tuple of copies that ``batch_space`` builds for
    custom spaces), ``items`` is iterated directly as a sequence of per-element samples.

    Raises:
        CustomSpaceError: If some subspace has no ``iterate`` implementation and
            ``iter(items)`` fails.
    """
    generic_impl = iterate.registry[object]
    # Resolve through ``dispatch`` so subclasses of registered spaces count as registered;
    # an exact ``type(...) in registry`` check would push them into the fallback below.
    unregistered_spaces = [
        type(subspace)
        for subspace in space
        if iterate.dispatch(type(subspace)) is generic_impl
    ]
    if not unregistered_spaces:
        return zip(*[iterate(subspace, items[i]) for i, subspace in enumerate(space)])

    try:
        return iter(items)
    except Exception as e:
        raise CustomSpaceError(
            f"Could not iterate through {space} as no custom iterate function is registered for {unregistered_spaces} and `iter(items)` raised the following error: {e}."
        ) from e
2023-02-18 21:08:08 +00:00
@iterate.register(Dict)
def _iterate_dict(space: Dict, items: dict[str, Any]):
    """Yield one ``{key: element}`` dict per batch element of a batched Dict sample.

    The per-key iterators advance in lockstep, and iteration stops when the shortest one
    is exhausted. A Dict without subspaces yields nothing, rather than failing to unpack
    an empty ``zip``.
    """
    keys = list(space.spaces.keys())
    per_key_iterators = [
        iterate(subspace, items[key]) for key, subspace in space.spaces.items()
    ]
    for values in zip(*per_key_iterators):
        yield dict(zip(keys, values))
2023-02-18 21:08:08 +00:00
@singledispatch
def concatenate(
    space: Space,
    items: Iterable,
    out: tuple[Any, ...] | dict[str, Any] | np.ndarray | None,
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
    """Concatenate multiple samples from space into a single object.

    Args:
        space: Space of each item (e.g. `single_action_space` from vectorized environment)
        items: Samples to be concatenated (e.g. all sample should be an element of the `space`).
        out: The output object (e.g. generated from `create_empty_array`). May be ``None`` for
            spaces without an array form, whose samples are simply collected into a tuple.

    Returns:
        The output object, can be the same object `out`.

    Raises:
        TypeError: If ``space`` is not a :class:`gymnasium.Space` instance.

    Example:
        >>> from gymnasium.spaces import Box
        >>> import numpy as np
        >>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
        >>> out = np.zeros((2, 3), dtype=np.float32)
        >>> items = [space.sample() for _ in range(2)]
        >>> concatenate(space, items, out)
        array([[0.77395606, 0.43887845, 0.85859793],
               [0.697368  , 0.09417735, 0.97562236]], dtype=float32)
    """
    # ``Space`` itself is registered as a fallback, so only non-Space objects get here.
    raise TypeError(
        f"The space provided to `concatenate` is not a gymnasium Space instance, type: {type(space)}, {space}"
    )
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    items: Iterable,
    out: np.ndarray,
) -> np.ndarray:
    """Stack array-like samples along a new leading axis, writing the result into ``out``."""
    stacked = np.stack(items, axis=0, out=out)
    return stacked
@concatenate.register(Tuple)
def _concatenate_tuple(
    space: Tuple, items: Iterable, out: tuple[Any, ...]
) -> tuple[Any, ...]:
    """Concatenate Tuple samples component by component into the matching entries of ``out``.

    ``items`` is traversed once per subspace, so it is materialised first. A one-shot
    iterator (e.g. a generator) would otherwise be exhausted after the first subspace.
    """
    samples = list(items)
    return tuple(
        concatenate(subspace, [sample[i] for sample in samples], out[i])
        for i, subspace in enumerate(space.spaces)
    )
@concatenate.register(Dict)
def _concatenate_dict(
    space: Dict, items: Iterable, out: dict[str, Any]
) -> dict[str, Any]:
    """Concatenate Dict samples key by key into the matching entries of ``out``.

    ``items`` is traversed once per key, so it is materialised first. A one-shot iterator
    (e.g. a generator) would otherwise be exhausted after the first key.
    """
    samples = list(items)
    return {
        key: concatenate(subspace, [sample[key] for sample in samples], out[key])
        for key, subspace in space.items()
    }
2023-02-18 21:08:08 +00:00
2023-02-22 15:05:58 +00:00
@concatenate.register(Graph)
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
@concatenate.register(OneOf)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
    """Collect samples of spaces without an array form into a plain tuple; ``out`` is ignored."""
    return (*items,)
@singledispatch
def create_empty_array(
    space: Space, n: int = 1, fn: Callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
    """Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.

    In most cases, the array will be contained within the batched space; however, this is not guaranteed.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment.
        fn: Function to apply when creating the empty numpy array. Examples of such functions are ``np.empty`` or ``np.zeros``.

    Returns:
        The output object. This object is a (possibly nested) numpy array.

    Raises:
        TypeError: If ``space`` is not a :class:`gymnasium.Space` instance.

    Example:
        >>> from gymnasium.spaces import Box, Dict
        >>> import numpy as np
        >>> space = Dict({
        ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
        ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
        >>> create_empty_array(space, n=2, fn=np.zeros)
        {'position': array([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32), 'velocity': array([[0., 0.],
               [0., 0.]], dtype=float32)}
    """
    # ``Space`` itself is registered as a fallback, so only non-Space objects get here.
    raise TypeError(
        f"The space provided to `create_empty_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
    )
2024-03-11 13:30:50 +01:00
# The ``fn``-filled array need not lie in the space, e.g. zeros for a Box whose ``low`` is above 0 somewhere.
@create_empty_array.register(Box)
# Likewise for Discrete: 0 is outside the space when ``start > 0`` or ``start + n <= 0``.
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_multi(space: Box, n: int = 1, fn=np.zeros) -> np.ndarray:
    """Allocate an array of shape ``(n, *space.shape)`` with the space's dtype via ``fn``."""
    batched_shape = (n,) + space.shape
    return fn(batched_shape, dtype=space.dtype)
2023-02-18 21:08:08 +00:00
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space: Tuple, n: int = 1, fn=np.zeros) -> tuple[Any, ...]:
    """Recursively allocate one buffer per Tuple subspace."""
    buffers = [create_empty_array(sub, n=n, fn=fn) for sub in space.spaces]
    return tuple(buffers)
@create_empty_array.register(Dict)
def _create_empty_array_dict(space: Dict, n: int = 1, fn=np.zeros) -> dict[str, Any]:
    """Recursively allocate one buffer per Dict subspace, keyed like the space."""
    buffers: dict[str, Any] = {}
    for key, subspace in space.items():
        buffers[key] = create_empty_array(subspace, n=n, fn=fn)
    return buffers
2023-02-18 21:08:08 +00:00
2023-02-22 15:05:58 +00:00
@create_empty_array.register(Graph)
def _create_empty_array_graph(
    space: Graph, n: int = 1, fn=np.zeros
) -> tuple[GraphInstance, ...]:
    """Build ``n`` placeholder GraphInstances, each with one node.

    When ``space.edge_space`` is set, each placeholder also gets one edge and one
    ``(1, 2)`` int64 edge link; otherwise edges and edge links are ``None``.
    """
    has_edges = space.edge_space is not None

    def _placeholder() -> GraphInstance:
        nodes = fn((1,) + space.node_space.shape, dtype=space.node_space.dtype)
        if not has_edges:
            return GraphInstance(nodes=nodes, edges=None, edge_links=None)
        edges = fn((1,) + space.edge_space.shape, dtype=space.edge_space.dtype)
        edge_links = fn((1, 2), dtype=np.int64)
        return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)

    return tuple(_placeholder() for _ in range(n))
@create_empty_array.register(Text)
def _create_empty_array_text(space: Text, n: int = 1, fn=np.zeros) -> tuple[str, ...]:
    """Return ``n`` minimum-length strings built from the space's first character; ``fn`` is unused."""
    placeholders = []
    for _ in range(n):
        placeholders.append(space.characters[0] * space.min_length)
    return tuple(placeholders)
@create_empty_array.register(Sequence)
def _create_empty_array_sequence(
    space: Sequence, n: int = 1, fn=np.zeros
) -> tuple[Any, ...]:
    """Allocate ``n`` placeholders for a Sequence space.

    Stacked sequences get a single-element buffer of the feature space each; unstacked
    sequences get an empty tuple each.
    """
    if not space.stack:
        return tuple(() for _ in range(n))
    return tuple(
        create_empty_array(space.feature_space, n=1, fn=fn) for _ in range(n)
    )
2024-03-11 13:30:50 +01:00
@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
    """Return ``n`` empty tuples as placeholders; neither ``space`` nor ``fn`` is consulted."""
    return tuple([() for _ in range(n)])
2023-02-18 21:08:08 +00:00
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
    """Fallback for Space types without a dedicated implementation: no buffer is preallocated.

    ``concatenate``'s fallback for such spaces ignores ``out``, so ``None`` suffices.
    """
    placeholder = None
    return placeholder