2022-05-10 17:18:06 +02:00
""" Implementation of a space that represents closed boxes in euclidean space. """
2022-11-15 14:09:22 +00:00
from __future__ import annotations
from typing import Any , Iterable , Mapping , Sequence , SupportsFloat
2022-01-24 23:22:11 +01:00
2016-04-27 08:00:58 -07:00
import numpy as np
2022-11-15 14:09:22 +00:00
from numpy . typing import NDArray
2018-11-29 02:27:27 +01:00
2022-09-16 23:41:27 +01:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium . spaces . space import Space
2022-03-31 12:50:38 -07:00
2019-01-30 22:39:55 +01:00
2022-11-15 14:09:22 +00:00
def _short_repr(arr: NDArray[Any]) -> str:
    """Build a compact string form of a numpy array.

    A non-empty array whose entries are all equal is rendered as that single
    value; anything else (including an empty array) is rendered in full.

    Args:
        arr: The array to represent

    Returns:
        A short representation of the array
    """
    # An empty array has no minimum, so fall straight through to the full form.
    if arr.size == 0:
        return str(arr)

    smallest = np.min(arr)
    # Equal extrema mean every entry is the same scalar.
    if smallest == np.max(arr):
        return str(smallest)
    return str(arr)
2022-11-15 14:09:22 +00:00
def is_float_integer(var: Any) -> bool:
    """Report whether ``var`` is an integer or floating point scalar.

    The check is made on ``type(var)``, so both Python and numpy scalar types
    are recognised.
    """
    var_type = type(var)
    return any(
        np.issubdtype(var_type, numeric_kind)
        for numeric_kind in (np.integer, np.floating)
    )
2022-11-15 14:09:22 +00:00
class Box(Space[NDArray[Any]]):
    r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.

    Specifically, a Box represents the Cartesian product of n closed intervals.
    Each interval has the form of one of :math:`[a, b]`, :math:`(-\infty, b]`,
    :math:`[a, \infty)`, or :math:`(-\infty, \infty)`.

    There are two common use cases:

    * Identical bound for each dimension::

        >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
        Box(-1.0, 2.0, (3, 4), float32)

    * Independent bound for each dimension::

        >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
        Box([-1. -2.], [2. 4.], (2,), float32)
    """

    def __init__(
        self,
        low: SupportsFloat | NDArray[Any],
        high: SupportsFloat | NDArray[Any],
        shape: Sequence[int] | None = None,
        dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Box`.

        The argument ``low`` specifies the lower bound of each dimension and ``high`` specifies the upper bounds.
        I.e., the space that is constructed will be the product of the intervals :math:`[\text{low}[i], \text{high}[i]]`.

        If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
        this value across all dimensions.

        Args:
            low (SupportsFloat | np.ndarray): Lower bounds of the intervals. If integer, must be at least ``-2**63``.
            high (SupportsFloat | np.ndarray): Upper bounds of the intervals. If integer, must be at most ``2**63 - 2``.
            shape (Optional[Sequence[int]]): The shape of the space. If omitted, it is inferred from ``low`` or
                ``high`` when either is an ``np.ndarray``; two scalar bounds default to a shape of ``(1,)``.
            dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.

        Raises:
            ValueError: If the shape cannot be inferred (``shape`` is None and neither bound is an ``np.ndarray``
                or an integer/float scalar), if any ``low`` value exceeds the matching ``high`` value, if a
                ``low`` value is ``np.inf``, or if a ``high`` value is ``-np.inf``.
        """
        assert (
            dtype is not None
        ), "Box dtype must be explicitly provided, cannot be None."
        self.dtype = np.dtype(dtype)

        # determine shape if it isn't provided directly
        if shape is not None:
            assert all(
                np.issubdtype(type(dim), np.integer) for dim in shape
            ), f"Expected all shape elements to be an integer, actual type: {tuple(type(dim) for dim in shape)}"
            shape = tuple(int(dim) for dim in shape)  # This changes any np types to int
        elif isinstance(low, np.ndarray):
            shape = low.shape
        elif isinstance(high, np.ndarray):
            shape = high.shape
        elif is_float_integer(low) and is_float_integer(high):
            shape = (1,)
        else:
            raise ValueError(
                f"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: {type(low)}, high: {type(high)}"
            )

        # Capture the boundedness information before the bounds are cast to the
        # target dtype (infinite bounds are replaced by finite values for integer dtypes)
        _low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
        self.bounded_below: NDArray[np.bool_] = -np.inf < _low
        _high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
        self.bounded_above: NDArray[np.bool_] = np.inf > _high

        # Record the precision of the caller's bounds *before* they are cast to
        # ``self.dtype``; measuring it afterwards would always equal the target
        # precision and the warning below could never be emitted. Scalar bounds
        # are materialised directly in the target dtype, so they lose nothing.
        dtype_precision = get_precision(self.dtype)
        low_precision = (
            get_precision(low.dtype) if isinstance(low, np.ndarray) else dtype_precision
        )
        high_precision = (
            get_precision(high.dtype)
            if isinstance(high, np.ndarray)
            else dtype_precision
        )

        low = _broadcast(low, self.dtype, shape)
        high = _broadcast(high, self.dtype, shape)

        assert isinstance(low, np.ndarray)
        assert (
            low.shape == shape
        ), f"low.shape doesn't match provided shape, low.shape: {low.shape}, shape: {shape}"
        assert isinstance(high, np.ndarray)
        assert (
            high.shape == shape
        ), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"

        # check that we don't have invalid low or high
        if np.any(low > high):
            raise ValueError(
                f"Some low values are greater than high, low={low}, high={high}"
            )
        if np.any(np.isposinf(low)):
            raise ValueError(f"No low value can be equal to `np.inf`, low={low}")
        if np.any(np.isneginf(high)):
            raise ValueError(f"No high value can be equal to `-np.inf`, high={high}")

        self._shape: tuple[int, ...] = shape

        if min(low_precision, high_precision) > dtype_precision:
            gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
        self.low = low.astype(self.dtype)
        self.high = high.astype(self.dtype)

        # Cached so that ``__repr__`` does not rescan the bounds on every call.
        self.low_repr = _short_repr(self.low)
        self.high_repr = _short_repr(self.high)

        super().__init__(self.shape, self.dtype, seed)

    @property
    def shape(self) -> tuple[int, ...]:
        """Has stricter type than gym.Space - never None."""
        return self._shape

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return True

    def is_bounded(self, manner: str = "both") -> bool:
        """Checks whether the box is bounded in some sense.

        Args:
            manner (str): One of ``"both"``, ``"below"``, ``"above"``.

        Returns:
            If the space is bounded

        Raises:
            ValueError: If `manner` is neither ``"both"`` nor ``"below"`` or ``"above"``
        """
        below = bool(np.all(self.bounded_below))
        above = bool(np.all(self.bounded_above))
        if manner == "both":
            return below and above
        elif manner == "below":
            return below
        elif manner == "above":
            return above
        else:
            raise ValueError(
                f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
            )

    def sample(self, mask: None = None) -> NDArray[Any]:
        r"""Generates a single random sample inside the Box.

        In creating a sample of the box, each coordinate is sampled (independently) from a distribution
        that is chosen according to the form of the interval:

        * :math:`[a, b]` : uniform distribution
        * :math:`[a, \infty)` : shifted exponential distribution
        * :math:`(-\infty, b]` : shifted negative exponential distribution
        * :math:`(-\infty, \infty)` : normal distribution

        Args:
            mask: A mask for sampling values from the Box space, currently unsupported.

        Returns:
            A sampled value from the Box

        Raises:
            gym.error.Error: If ``mask`` is not None.
        """
        if mask is not None:
            raise gym.error.Error(
                f"Box.sample cannot be provided a mask, actual value: {mask}"
            )

        # For non-float dtypes the samples are floored below, so the upper bound is
        # shifted by one to keep ``self.high`` itself reachable.
        high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
        sample = np.empty(self.shape)

        # Masking arrays which classify the coordinates according to interval type
        unbounded = ~self.bounded_below & ~self.bounded_above
        upp_bounded = ~self.bounded_below & self.bounded_above
        low_bounded = self.bounded_below & ~self.bounded_above
        bounded = self.bounded_below & self.bounded_above

        # Vectorized sampling by interval type
        sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)

        sample[low_bounded] = (
            self.np_random.exponential(size=low_bounded[low_bounded].shape)
            + self.low[low_bounded]
        )

        sample[upp_bounded] = (
            -self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
            + high[upp_bounded]
        )

        sample[bounded] = self.np_random.uniform(
            low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
        )

        if self.dtype.kind in ["i", "u", "b"]:
            sample = np.floor(sample)

        return sample.astype(self.dtype)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Inputs that are not ``np.ndarray`` are converted with the space's dtype
        (a warning is logged); inputs that cannot be converted are not members.
        """
        if not isinstance(x, np.ndarray):
            gym.logger.warn("Casting input x to numpy array.")
            try:
                x = np.asarray(x, dtype=self.dtype)
            except (ValueError, TypeError):
                return False

        return bool(
            np.can_cast(x.dtype, self.dtype)
            and x.shape == self.shape
            and np.all(x >= self.low)
            and np.all(x <= self.high)
        )

    def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]:
        """Convert a batch of samples from this space to a JSONable data type."""
        return [sample.tolist() for sample in sample_n]

    def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
        """Convert a JSONable data type to a batch of samples from this space."""
        return [np.asarray(sample, dtype=self.dtype) for sample in sample_n]

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include bounds, shape and dtype.
        If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.

        Returns:
            A representation of the space
        """
        return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"

    def __eq__(self, other: Any) -> bool:
        """Check whether `other` is equivalent to this instance.

        Two boxes are equal when they have the same shape and dtype and their
        bounds are element-wise close (``np.allclose``).
        """
        return (
            isinstance(other, Box)
            and (self.shape == other.shape)
            and (self.dtype == other.dtype)
            and np.allclose(self.low, other.low)
            and np.allclose(self.high, other.high)
        )

    def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
        """Sets the state of the box for unpickling a box with legacy support."""
        super().__setstate__(state)

        # legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
        if not hasattr(self, "low_repr"):
            self.low_repr = _short_repr(self.low)

        if not hasattr(self, "high_repr"):
            self.high_repr = _short_repr(self.high)
2022-01-11 04:45:41 +00:00
2022-11-15 14:09:22 +00:00
def get_precision(dtype: np.dtype) -> SupportsFloat:
    """Return the decimal precision of a data type.

    Floating point dtypes report ``np.finfo(dtype).precision``; every other
    dtype is reported as ``np.inf``.
    """
    if not np.issubdtype(dtype, np.floating):
        return np.inf
    return np.finfo(dtype).precision
2022-01-24 23:22:11 +01:00
def _broadcast(
    value: SupportsFloat | NDArray[Any],
    dtype: np.dtype,
    shape: tuple[int, ...],
) -> NDArray[Any]:
    """Broadcast a bound to ``shape`` in ``dtype``, taking care of infinite values.

    Infinite bounds cannot be stored in a signed integer array, e.g.:

        >>> import numpy as np
        >>> np.full((2,), np.inf, dtype=np.int32)
        array([-2147483648, -2147483648], dtype=int32)

    so for signed integer dtypes ``-inf`` is replaced by ``iinfo.min + 2`` and
    ``+inf`` by ``iinfo.max - 2``.

    Args:
        value: A scalar (integer or float) or an ``np.ndarray`` bound.
        dtype: The dtype of the returned array.
        shape: The shape a scalar ``value`` is expanded to; arrays keep their own shape.

    Returns:
        The bound as an array of ``dtype``.

    Raises:
        TypeError: If ``value`` is neither an integer/float scalar nor an ``np.ndarray``.
    """
    if is_float_integer(value):
        # Scalars: swap an infinite value for its finite stand-in, then fill.
        if np.dtype(dtype).kind == "i" and np.isinf(value):
            limits = np.iinfo(dtype)
            value = limits.min + 2 if value < 0 else limits.max - 2
        return np.full(shape, value, dtype=dtype)

    if not isinstance(value, np.ndarray):
        # only np.ndarray allowed beyond this point
        raise TypeError(
            f"Unknown dtype for `value`, expected `np.ndarray` or float/integer, got {type(value)}"
        )

    # Cast first: the integer limits cannot be written into a float array exactly.
    converted = value.astype(dtype)
    # Only positions holding an infinity in the *original* array are overwritten.
    if np.dtype(dtype).kind == "i":
        limits = np.iinfo(dtype)
        converted[np.isneginf(value)] = limits.min + 2
        converted[np.isposinf(value)] = limits.max - 2
    return converted