2022-05-10 17:18:06 +02:00
""" Implementation of a space that represents closed boxes in euclidean space. """
2022-05-31 23:53:13 -04:00
from typing import Dict , List , Optional , Sequence , SupportsFloat , Tuple , Type , Union
2022-01-24 23:22:11 +01:00
2016-04-27 08:00:58 -07:00
import numpy as np
2018-11-29 02:27:27 +01:00
2022-09-08 10:10:07 +01:00
import gymnasium . error
from gymnasium import logger
from gymnasium . spaces . space import Space
2022-03-31 12:50:38 -07:00
2019-01-30 22:39:55 +01:00
2022-01-24 23:22:11 +01:00
def _short_repr ( arr : np . ndarray ) - > str :
2022-01-13 19:41:53 +01:00
""" Create a shortened string representation of a numpy array.
If arr is a multiple of the all - ones vector , return a string representation of the multiplier .
Otherwise , return a string representation of the entire array .
2022-05-25 14:46:41 +01:00
Args :
arr : The array to represent
Returns :
A short representation of the array
2022-01-13 19:41:53 +01:00
"""
if arr . size != 0 and np . min ( arr ) == np . max ( arr ) :
return str ( np . min ( arr ) )
return str ( arr )
2022-06-29 16:17:25 +01:00
def is_float_integer(var) -> bool:
    """Tell whether ``var`` is a scalar integer or floating point value.

    The decision is made on ``type(var)``, so both Python and numpy scalar types are
    recognised, while containers such as ``np.ndarray`` are not.

    Args:
        var: The value to inspect

    Returns:
        True if the type of ``var`` is a sub-dtype of ``np.integer`` or ``np.floating``
    """
    var_type = type(var)
    return any(
        np.issubdtype(var_type, abstract_kind)
        for abstract_kind in (np.integer, np.floating)
    )
2022-01-24 23:22:11 +01:00
class Box(Space[np.ndarray]):
    r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.

    Specifically, a Box represents the Cartesian product of n closed intervals.
    Each interval has the form of one of :math:`[a, b]`, :math:`(-\infty, b]`,
    :math:`[a, \infty)`, or :math:`(-\infty, \infty)`.

    There are two common use cases:

    * Identical bound for each dimension::

        >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
        Box(-1.0, 2.0, (3, 4), float32)

    * Independent bound for each dimension::

        >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
        Box([-1. -2.], [2. 4.], (2,), float32)
    """

    def __init__(
        self,
        low: Union[SupportsFloat, np.ndarray],
        high: Union[SupportsFloat, np.ndarray],
        shape: Optional[Sequence[int]] = None,
        dtype: Type = np.float32,
        seed: Optional[Union[int, np.random.Generator]] = None,
    ):
        r"""Constructor of :class:`Box`.

        The argument ``low`` specifies the lower bound of each dimension and ``high`` specifies the upper bounds.
        I.e., the space that is constructed will be the product of the intervals :math:`[\text{low}[i], \text{high}[i]]`.

        If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
        this value across all dimensions.

        Args:
            low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
            high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
            shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
                `low` and `high` scalars defaulting to a shape of (1,)
            dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.

        Raises:
            ValueError: If ``shape`` is None and the shape cannot be inferred because neither ``low`` nor ``high`` is
                an ``np.ndarray`` and they are not both integer/float scalars.
            AssertionError: If ``dtype`` is None, if ``shape`` contains a non-integer element, or if a bound
                array does not have the resolved shape.
        """
        assert (
            dtype is not None
        ), "Box dtype must be explicitly provided, cannot be None."
        self.dtype = np.dtype(dtype)

        # determine shape if it isn't provided directly
        if shape is not None:
            assert all(
                np.issubdtype(type(dim), np.integer) for dim in shape
            ), f"Expect all shape elements to be an integer, actual type: {tuple(type(dim) for dim in shape)}"
            shape = tuple(int(dim) for dim in shape)  # This changes any np types to int
        elif isinstance(low, np.ndarray):
            shape = low.shape
        elif isinstance(high, np.ndarray):
            shape = high.shape
        elif is_float_integer(low) and is_float_integer(high):
            shape = (1,)
        else:
            raise ValueError(
                f"Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: {type(low)}, high: {type(high)}"
            )

        # Capture the boundedness information before replacing np.inf with get_inf
        _low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
        self.bounded_below = -np.inf < _low
        _high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
        self.bounded_above = np.inf > _high

        low = _broadcast(low, dtype, shape, inf_sign="-")  # type: ignore
        high = _broadcast(high, dtype, shape, inf_sign="+")  # type: ignore

        assert isinstance(low, np.ndarray)
        assert (
            low.shape == shape
        ), f"low.shape doesn't match provided shape, low.shape: {low.shape}, shape: {shape}"
        assert isinstance(high, np.ndarray)
        assert (
            high.shape == shape
        ), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"

        self._shape: Tuple[int, ...] = shape

        # Warn when both bounds carry more precision than the space dtype can hold.
        low_precision = get_precision(low.dtype)
        high_precision = get_precision(high.dtype)
        dtype_precision = get_precision(self.dtype)
        if min(low_precision, high_precision) > dtype_precision:  # type: ignore
            logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
        self.low = low.astype(self.dtype)
        self.high = high.astype(self.dtype)

        # Cached so that __repr__ does not have to rescan the bound arrays.
        self.low_repr = _short_repr(self.low)
        self.high_repr = _short_repr(self.high)

        super().__init__(self.shape, self.dtype, seed)

    @property
    def shape(self) -> Tuple[int, ...]:
        """Has stricter type than gymnasium.Space - never None."""
        return self._shape

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return True

    def is_bounded(self, manner: str = "both") -> bool:
        """Checks whether the box is bounded in some sense.

        Args:
            manner (str): One of ``"both"``, ``"below"``, ``"above"``.

        Returns:
            If the space is bounded

        Raises:
            ValueError: If `manner` is neither ``"both"`` nor ``"below"`` or ``"above"``
        """
        below = bool(np.all(self.bounded_below))
        above = bool(np.all(self.bounded_above))
        if manner == "both":
            return below and above
        elif manner == "below":
            return below
        elif manner == "above":
            return above
        else:
            raise ValueError(
                f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
            )

    def sample(self, mask: None = None) -> np.ndarray:
        r"""Generates a single random sample inside the Box.

        In creating a sample of the box, each coordinate is sampled (independently) from a distribution
        that is chosen according to the form of the interval:

        * :math:`[a, b]` : uniform distribution
        * :math:`[a, \infty)` : shifted exponential distribution
        * :math:`(-\infty, b]` : shifted negative exponential distribution
        * :math:`(-\infty, \infty)` : normal distribution

        Args:
            mask: A mask for sampling values from the Box space, currently unsupported.

        Returns:
            A sampled value from the Box

        Raises:
            gymnasium.error.Error: If ``mask`` is not None.
        """
        if mask is not None:
            raise gymnasium.error.Error(
                f"Box.sample cannot be provided a mask, actual value: {mask}"
            )

        # For non-float dtypes the upper bound is widened by one so that, after flooring /
        # truncation, the inclusive upper bound itself can be drawn.
        high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
        sample = np.empty(self.shape)

        # Masking arrays which classify the coordinates according to interval
        # type
        unbounded = ~self.bounded_below & ~self.bounded_above
        upp_bounded = ~self.bounded_below & self.bounded_above
        low_bounded = self.bounded_below & ~self.bounded_above
        bounded = self.bounded_below & self.bounded_above

        # Vectorized sampling by interval type
        sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)

        sample[low_bounded] = (
            self.np_random.exponential(size=low_bounded[low_bounded].shape)
            + self.low[low_bounded]
        )

        sample[upp_bounded] = (
            -self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
            + self.high[upp_bounded]
        )

        sample[bounded] = self.np_random.uniform(
            low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
        )
        if self.dtype.kind == "i":
            sample = np.floor(sample)

        return sample.astype(self.dtype)

    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        if not isinstance(x, np.ndarray):
            logger.warn("Casting input x to numpy array.")
            try:
                x = np.asarray(x, dtype=self.dtype)
            except (ValueError, TypeError):
                # Anything that cannot be converted to an array of our dtype is not a member.
                return False

        return bool(
            np.can_cast(x.dtype, self.dtype)
            and x.shape == self.shape
            and np.all(x >= self.low)
            and np.all(x <= self.high)
        )

    def to_jsonable(self, sample_n: Sequence[np.ndarray]) -> list:
        """Convert a batch of samples from this space to a JSONable data type."""
        return np.array(sample_n).tolist()

    def from_jsonable(self, sample_n: Sequence[Union[float, int]]) -> List[np.ndarray]:
        """Convert a JSONable data type to a batch of samples from this space.

        Each sample is rebuilt with the dtype of the space. Without the explicit dtype, JSON
        floats would come back as ``float64`` and would be rejected by :meth:`contains` for
        lower-precision spaces such as the default ``float32`` Box.
        """
        return [np.asarray(sample, dtype=self.dtype) for sample in sample_n]

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include bounds, shape and dtype.
        If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.

        Returns:
            A representation of the space
        """
        return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"

    def __eq__(self, other) -> bool:
        """Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
        # The dtype is deliberately left out of the comparison; only shape and
        # (approximately equal) bounds are considered.
        return (
            isinstance(other, Box)
            and (self.shape == other.shape)
            and np.allclose(self.low, other.low)
            and np.allclose(self.high, other.high)
        )

    def __setstate__(self, state: Dict):
        """Sets the state of the box for unpickling a box with legacy support."""
        super().__setstate__(state)

        # legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
        if not hasattr(self, "low_repr"):
            self.low_repr = _short_repr(self.low)

        if not hasattr(self, "high_repr"):
            self.high_repr = _short_repr(self.high)
2022-01-11 04:45:41 +00:00
2022-01-24 23:22:11 +01:00
def get_inf(dtype, sign: str) -> SupportsFloat:
    """Returns an infinite that doesn't break things.

    For floating point dtypes this is a true infinity. For signed integer dtypes, which
    cannot hold an infinity, it is a value two away from the extreme representable by
    ``dtype``.

    Args:
        dtype: An `np.dtype`
        sign (str): must be either `"+"` or `"-"`

    Returns:
        Gets an infinite value with the sign and dtype

    Raises:
        TypeError: Unknown sign, use either '+' or '-'
        ValueError: Unknown dtype for infinite bounds
    """
    kind = np.dtype(dtype).kind

    # The dtype is validated before the sign, so an unsupported dtype always wins.
    if kind not in ("f", "i"):
        raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
    if sign not in ("+", "-"):
        raise TypeError(f"Unknown sign {sign}, use either '+' or '-'")

    positive = sign == "+"
    if kind == "f":
        return np.inf if positive else -np.inf

    limits = np.iinfo(dtype)
    return limits.max - 2 if positive else limits.min + 2
2022-01-24 23:22:11 +01:00
def get_precision(dtype) -> SupportsFloat:
    """Get precision of a data type.

    Args:
        dtype: The dtype to inspect

    Returns:
        ``np.finfo(dtype).precision`` for floating point dtypes, and ``np.inf`` for every
        other dtype.
    """
    if not np.issubdtype(dtype, np.floating):
        return np.inf
    return np.finfo(dtype).precision
2022-01-24 23:22:11 +01:00
def _broadcast(
    value: Union[SupportsFloat, np.ndarray],
    dtype,
    shape: Tuple[int, ...],
    inf_sign: str,
) -> np.ndarray:
    """Handle infinite bounds and broadcast at the same time if needed.

    Args:
        value: Either an integer/float scalar or an ``np.ndarray`` of bounds
        dtype: The dtype used for the filled array and for the infinity replacement
        shape: The shape to fill when ``value`` is a scalar
        inf_sign: ``"+"`` or ``"-"``, forwarded to ``get_inf``

    Returns:
        For a scalar, a new array of ``shape`` and ``dtype`` filled with the value (or with
        ``get_inf(dtype, inf_sign)`` if the value is infinite). For an array without
        infinities, the very same array object. For an array with infinities, a copy cast
        to ``dtype`` whose infinite entries are replaced by ``get_inf(dtype, inf_sign)``.
    """
    if not is_float_integer(value):
        assert isinstance(value, np.ndarray)
        inf_mask = np.isinf(value)
        if not np.any(inf_mask):
            return value

        # The mask comes from the original array because the cast below may not
        # preserve np.inf.
        converted = value.astype(dtype)
        converted[inf_mask] = get_inf(dtype, inf_sign)
        return converted

    fill_value = get_inf(dtype, inf_sign) if np.isinf(value) else value  # type: ignore
    return np.full(shape, fill_value, dtype=dtype)