2024-03-11 13:30:50 +01:00
""" Implementation of a space that represents the cartesian product of other spaces. """
2024-06-10 17:07:47 +01:00
2024-03-11 13:30:50 +01:00
from __future__ import annotations
import typing
2025-06-07 17:57:58 +01:00
from collections . abc import Iterable
from typing import Any
2024-03-11 13:30:50 +01:00
import numpy as np
from gymnasium . spaces . space import Space
class OneOf ( Space [ Any ] ) :
""" An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances.
Elements of this space are elements of one of the constituent spaces .
Example :
>> > from gymnasium . spaces import OneOf , Box , Discrete
2024-04-28 16:10:35 +01:00
>> > observation_space = OneOf ( ( Discrete ( 2 ) , Box ( - 1 , 1 , shape = ( 2 , ) ) ) , seed = 123 )
2025-02-21 13:39:23 +00:00
>> > observation_space . sample ( ) # the first element is the space index (Discrete in this case) and the second element is the sample from Discrete
2024-09-03 12:30:58 +01:00
( np . int64 ( 0 ) , np . int64 ( 0 ) )
2025-02-21 13:39:23 +00:00
>> > observation_space . sample ( ) # this time the Box space was sampled as index=1
2024-09-03 12:30:58 +01:00
( np . int64 ( 1 ) , array ( [ - 0.00711833 , - 0.7257502 ] , dtype = float32 ) )
2024-03-11 13:30:50 +01:00
>> > observation_space [ 0 ]
Discrete ( 2 )
>> > observation_space [ 1 ]
Box ( - 1.0 , 1.0 , ( 2 , ) , float32 )
>> > len ( observation_space )
2
"""
def __init__ (
self ,
spaces : Iterable [ Space [ Any ] ] ,
seed : int | typing . Sequence [ int ] | np . random . Generator | None = None ,
) :
r """ Constructor of :class:`OneOf` space.
The generated instance will represent the cartesian product : math : ` \text { spaces } [ 0 ] \times . . . \times \text { spaces } [ - 1 ] ` .
Args :
spaces ( Iterable [ Space ] ) : The spaces that are involved in the cartesian product .
seed : Optionally , you can use this argument to seed the RNGs of the ` ` spaces ` ` to ensure reproducible sampling .
"""
2024-04-06 13:20:10 +01:00
assert isinstance ( spaces , Iterable ) , f " { spaces } is not an iterable "
2024-03-11 13:30:50 +01:00
self . spaces = tuple ( spaces )
assert len ( self . spaces ) > 0 , " Empty `OneOf` spaces are not supported. "
for space in self . spaces :
assert isinstance (
space , Space
) , f " { space } does not inherit from `gymnasium.Space`. Actual Type: { type ( space ) } "
super ( ) . __init__ ( None , None , seed )
@property
def is_np_flattenable ( self ) :
""" Checks whether this space can be flattened to a :class:`spaces.Box`. """
return all ( space . is_np_flattenable for space in self . spaces )
2024-04-28 16:10:35 +01:00
def seed ( self , seed : int | tuple [ int , . . . ] | None = None ) - > tuple [ int , . . . ] :
2024-03-11 13:30:50 +01:00
""" Seed the PRNG of this space and all subspaces.
Depending on the type of seed , the subspaces will be seeded differently
* ` ` None ` ` - All the subspaces will use a random initial seed
* ` ` Int ` ` - The integer is used to seed the : class : ` Tuple ` space that is used to generate seed values for each of the subspaces . Warning , this does not guarantee unique seeds for all the subspaces .
2024-04-28 16:10:35 +01:00
* ` ` Tuple [ int , . . . ] ` ` - Values used to seed the subspaces , first value seeds the OneOf and subsequent seed the subspaces . This allows the seeding of multiple composite subspaces ` ` [ 42 , 54 , . . . ] ` ` .
2024-03-11 13:30:50 +01:00
Args :
2024-04-28 16:10:35 +01:00
seed : An optional int or tuple of ints to seed the OneOf space and subspaces . See above for more details .
Returns :
A tuple of ints used to seed the OneOf space and subspaces
2024-03-11 13:30:50 +01:00
"""
2024-04-28 16:10:35 +01:00
if seed is None :
super_seed = super ( ) . seed ( None )
return ( super_seed , ) + tuple ( space . seed ( None ) for space in self . spaces )
2024-03-11 13:30:50 +01:00
elif isinstance ( seed , int ) :
2024-04-28 16:10:35 +01:00
super_seed = super ( ) . seed ( seed )
2024-03-11 13:30:50 +01:00
subseeds = self . np_random . integers (
np . iinfo ( np . int32 ) . max , size = len ( self . spaces )
)
2024-04-28 16:10:35 +01:00
# this is necessary such that after int or list/tuple seeding, the OneOf PRNG are equivalent
super ( ) . seed ( seed )
return ( super_seed , ) + tuple (
space . seed ( int ( subseed ) )
for space , subseed in zip ( self . spaces , subseeds )
)
elif isinstance ( seed , ( tuple , list ) ) :
if len ( seed ) != len ( self . spaces ) + 1 :
raise ValueError (
f " Expects that the subspaces of seeds equals the number of subspaces + 1. Actual length of seeds: { len ( seed ) } , length of subspaces: { len ( self . spaces ) } "
)
return ( super ( ) . seed ( seed [ 0 ] ) , ) + tuple (
space . seed ( subseed ) for space , subseed in zip ( self . spaces , seed [ 1 : ] )
)
2024-03-11 13:30:50 +01:00
else :
raise TypeError (
2024-04-28 16:10:35 +01:00
f " Expected None, int, or tuple of ints, actual type: { type ( seed ) } "
2024-03-11 13:30:50 +01:00
)
2025-02-21 13:39:23 +00:00
def sample (
self ,
mask : tuple [ Any | None , . . . ] | None = None ,
probability : tuple [ Any | None , . . . ] | None = None ,
) - > tuple [ int , Any ] :
2024-03-11 13:30:50 +01:00
""" Generates a single random sample inside this space.
This method draws independent samples from the subspaces .
Args :
mask : An optional tuple of optional masks for each of the subspace ' s samples,
expects the same number of masks as spaces
2025-02-21 13:39:23 +00:00
probability : An optional tuple of optional probability masks for each of the subspace ' s samples,
expects the same number of probability masks as spaces
2024-03-11 13:30:50 +01:00
Returns :
Tuple of the subspace ' s samples
"""
2024-04-06 13:20:10 +01:00
subspace_idx = self . np_random . integers ( 0 , len ( self . spaces ) , dtype = np . int64 )
2024-03-11 13:30:50 +01:00
subspace = self . spaces [ subspace_idx ]
2025-02-21 13:39:23 +00:00
if mask is not None and probability is not None :
raise ValueError (
f " Only one of `mask` or `probability` can be provided, actual values: mask= { mask } , probability= { probability } "
)
elif mask is not None :
2024-03-11 13:30:50 +01:00
assert isinstance (
mask , tuple
2025-02-21 13:39:23 +00:00
) , f " Expected type of `mask` is tuple, actual type: { type ( mask ) } "
2024-03-11 13:30:50 +01:00
assert len ( mask ) == len (
self . spaces
2025-02-21 13:39:23 +00:00
) , f " Expected length of `mask` is { len ( self . spaces ) } , actual length: { len ( mask ) } "
subspace_sample = subspace . sample ( mask = mask [ subspace_idx ] )
elif probability is not None :
assert isinstance (
probability , tuple
) , f " Expected type of `probability` is tuple, actual type: { type ( probability ) } "
assert len ( probability ) == len (
self . spaces
) , f " Expected length of `probability` is { len ( self . spaces ) } , actual length: { len ( probability ) } "
2024-03-11 13:30:50 +01:00
2025-02-21 13:39:23 +00:00
subspace_sample = subspace . sample ( probability = probability [ subspace_idx ] )
else :
subspace_sample = subspace . sample ( )
2024-03-11 13:30:50 +01:00
2025-02-21 13:39:23 +00:00
return subspace_idx , subspace_sample
2024-03-11 13:30:50 +01:00
def contains ( self , x : tuple [ int , Any ] ) - > bool :
""" Return boolean specifying if x is a valid member of this space. """
2024-04-06 13:20:10 +01:00
# subspace_idx, subspace_value = x
return (
isinstance ( x , tuple )
and len ( x ) == 2
and isinstance ( x [ 0 ] , ( np . int64 , int ) )
and 0 < = x [ 0 ] < len ( self . spaces )
and self . spaces [ x [ 0 ] ] . contains ( x [ 1 ] )
)
2024-03-11 13:30:50 +01:00
def __repr__ ( self ) - > str :
""" Gives a string representation of this space. """
return " OneOf( " + " , " . join ( [ str ( s ) for s in self . spaces ] ) + " ) "
def to_jsonable (
self , sample_n : typing . Sequence [ tuple [ int , Any ] ]
) - > list [ list [ Any ] ] :
""" Convert a batch of samples from this space to a JSONable data type. """
return [
[ int ( i ) , self . spaces [ i ] . to_jsonable ( [ subsample ] ) [ 0 ] ]
for ( i , subsample ) in sample_n
]
def from_jsonable ( self , sample_n : list [ list [ Any ] ] ) - > list [ tuple [ Any , . . . ] ] :
""" Convert a JSONable data type to a batch of samples from this space. """
return [
2024-04-06 13:20:10 +01:00
(
np . int64 ( space_idx ) ,
self . spaces [ space_idx ] . from_jsonable ( [ jsonable_sample ] ) [ 0 ] ,
)
2024-03-11 13:30:50 +01:00
for space_idx , jsonable_sample in sample_n
]
def __getitem__ ( self , index : int ) - > Space [ Any ] :
""" Get the subspace at specific `index`. """
return self . spaces [ index ]
def __len__ ( self ) - > int :
""" Get the number of subspaces that are involved in the cartesian product. """
return len ( self . spaces )
def __eq__ ( self , other : Any ) - > bool :
""" Check whether ``other`` is equivalent to this instance. """
return isinstance ( other , OneOf ) and self . spaces == other . spaces