2022-06-29 16:17:25 +01:00
import re
2022-08-30 19:47:26 +01:00
import warnings
2022-06-29 16:17:25 +01:00
import numpy as np
import pytest
2022-09-16 23:41:27 +01:00
import gymnasium as gym
2022-09-08 10:10:07 +01:00
from gymnasium . spaces import Box
2022-06-29 16:17:25 +01:00
@pytest.mark.parametrize(
    "box,expected_shape",
    [
        # 1-dim `low` and `high` arrays of the same shape
        (Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32), (2,)),
        # multi-dim `low` and `high` arrays of the same shape
        (Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32), (2, 1)),
        # scalar `low` and `high` with an explicitly given shape
        (Box(low=0, high=1, shape=(5, 2)), (5, 2)),
        # scalar bounds and no shape: python int / python int
        (Box(low=0, high=1), (1,)),
        # python float / python float
        (Box(low=0.0, high=1.0), (1,)),
        # numpy float64 scalars taken out of 1-element arrays
        (Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
        # python float / python int
        (Box(low=0.0, high=1), (1,)),
        # python int / numpy int32
        (Box(low=0, high=np.int32(1)), (1,)),
        # scalar `low`, array `high`: the array provides the shape
        (Box(low=0, high=np.ones(3)), (3,)),
        # array `low`, scalar `high`: the array provides the shape
        (Box(low=np.zeros(3), high=1.0), (3,)),
    ],
)
def test_shape_inference(box, expected_shape):
    """Check that `Box` infers `expected_shape` and that its samples share it."""
    sampled = box.sample()
    assert box.shape == expected_shape
    assert sampled.shape == expected_shape
@pytest.mark.parametrize(
    "value,valid",
    [
        (1, True),
        (1.0, True),
        (np.int32(1), True),
        (np.float32(1.0), True),
        (np.zeros(2, dtype=np.float32), True),
        (np.zeros((2, 2), dtype=np.float32), True),
        (np.inf, True),
        (np.nan, True),  # This is a weird case that we allow
        (True, False),
        (np.bool_(True), False),
        (1 + 1j, False),
        (np.complex128(1 + 1j), False),
        ("string", False),
    ],
)
def test_low_high_values(value, valid: bool):
    """Test which bound values the `Box` constructor accepts.

    `low` is fixed to `-np.inf`, so only the type of `high` (= `value`) varies.
    Valid values must build a `Box` without emitting any warning; invalid
    values must raise a `ValueError` about the accepted bound types.
    """
    if valid:
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning. Otherwise a warning silenced by an active
            # "ignore" filter, or deduplicated by the default once-per-location
            # rule, would not be recorded and the assertion below could pass
            # vacuously.
            warnings.simplefilter("always")
            Box(low=-np.inf, high=value)
        assert len(caught_warnings) == 0, tuple(
            warning.message for warning in caught_warnings
        )
    else:
        with pytest.raises(
            ValueError,
            match=re.escape(
                "expected their types to be np.ndarray, an integer or a float"
            ),
        ):
            Box(low=-np.inf, high=value)
2022-06-29 16:17:25 +01:00
@pytest.mark.parametrize(
    "low,high,kwargs,error,message",
    [
        # `dtype` may not be given as None
        (
            0,
            1,
            {"dtype": None},
            AssertionError,
            "Box dtype must be explicitly provided, cannot be None.",
        ),
        # every element of `shape` must be an integer
        (
            0,
            1,
            {"shape": (None,)},
            AssertionError,
            "Expected all shape elements to be an integer, "
            "actual type: (<class 'NoneType'>,)",
        ),
        (
            0,
            1,
            {"shape": (1, None)},
            AssertionError,
            "Expected all shape elements to be an integer, "
            "actual type: (<class 'int'>, <class 'NoneType'>)",
        ),
        (
            0,
            1,
            {"shape": (np.int64(1), None)},
            AssertionError,
            "Expected all shape elements to be an integer, "
            "actual type: (<class 'numpy.int64'>, <class 'NoneType'>)",
        ),
        # without `shape`, both bounds must be arrays or numeric scalars
        (
            None,
            None,
            {},
            ValueError,
            "Box shape is inferred from low and high, expected their types to be "
            "np.ndarray, an integer or a float, "
            "actual type low: <class 'NoneType'>, high: <class 'NoneType'>",
        ),
        (
            0,
            None,
            {},
            ValueError,
            "Box shape is inferred from low and high, expected their types to be "
            "np.ndarray, an integer or a float, "
            "actual type low: <class 'int'>, high: <class 'NoneType'>",
        ),
        # array bounds of different shapes
        (
            np.zeros(3),
            np.ones(2),
            {},
            AssertionError,
            "high.shape doesn't match provided shape, high.shape: (2,), shape: (3,)",
        ),
    ],
)
def test_init_errors(low, high, kwargs, error, message):
    """Check each constructor error: its exception type and its full message."""
    # Anchor both ends so the entire message must match, not just a substring.
    exact_pattern = "^" + re.escape(message) + "$"
    with pytest.raises(error, match=exact_pattern):
        Box(low=low, high=high, **kwargs)
2022-09-03 22:56:29 +01:00
def test_dtype_check():
    """Check `Box` membership for arrays whose dtype differs from the space's.

    Related issues:
    https://github.com/openai/gym/issues/2357
    https://github.com/openai/gym/issues/2298
    """
    space = Box(0, 1, (), dtype=np.float32)
    # (dtype of a 0.5-valued array, whether the float32 space must contain it)
    membership = [
        (np.float32, True),  # same dtype as the space
        (np.float16, True),  # float16 is accepted by a float32 space
        (np.float64, False),  # float64 is rejected by a float32 space
    ]
    for candidate_dtype, expected in membership:
        candidate = np.array(0.5, dtype=candidate_dtype)
        assert (candidate in space) == expected
@pytest.mark.parametrize(
    "space",
    [
        # Every shape x (low, high) x dtype combination, generated in a fixed
        # order so the parametrize ids (space0, space1, ...) are stable.
        Box(low=low, high=high, shape=shape, dtype=dtype)
        for shape in [(2,), (2, 3)]
        for low, high in [(0, np.inf), (-np.inf, 0), (-np.inf, np.inf)]
        for dtype in [np.int32, np.float32, np.int64, np.float64]
    ]
    + [
        # Per-element bounds: the first element has high 0 and low -inf,
        # the second has low 0 and high inf.
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=dtype)
        for dtype in [np.int32, np.float32, np.int64, np.float64]
    ],
)
def test_infinite_space(space):
    """Test `Box` spaces whose bounds are each either 0 or infinite.

    The constructor rewrites `space.low` and `space.high` (an integer dtype
    cannot hold `np.inf`, for instance), so rather than comparing against
    `np.inf` the test treats every non-zero bound as an infinite one.
    """
    assert np.all(
        space.low < space.high
    ), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})"

    space.seed(0)
    sample = space.sample()

    # The space must accept its own sample.
    assert (
        sample in space
    ), f"Sample ({sample}) not inside space according to `space.contains()`"

    # Independently of `contains`, compare the sample's sign with the signs of
    # the bounds; signs survive even if an infinite bound was rewritten.
    assert np.all(
        np.sign(sample) <= np.sign(space.high)
    ), f"Sign of sample ({sample}) is greater than space upper bound ({space.high})"
    assert np.all(
        np.sign(space.low) <= np.sign(sample)
    ), f"Sign of sample ({sample}) is less than space lower bound ({space.low})"

    # Every non-zero bound here is infinite, so the space must report itself
    # as unbounded on that side, for integer and float dtypes alike.
    if np.any(space.high != 0):
        assert (
            space.is_bounded("above") is False
        ), "inf upper bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("above") is True
        ), "non-inf upper bound supposed to be bounded"

    if np.any(space.low != 0):
        assert (
            space.is_bounded("below") is False
        ), "inf lower bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("below") is True
        ), "non-inf lower bound supposed to be bounded"

    if np.any(space.low != 0) or np.any(space.high != 0):
        assert space.is_bounded("both") is False
    else:
        assert space.is_bounded("both") is True

    # The stored bounds must keep the space's dtype.
    assert (
        space.high.dtype == space.dtype
    ), f"High's dtype {space.high.dtype} doesn't match `space.dtype` ({space.dtype})"
    assert (
        space.low.dtype == space.dtype
    ), f"Low's dtype {space.low.dtype} doesn't match `space.dtype` ({space.dtype})"

    # `is_bounded` rejects any manner other than "below", "above" or "both".
    # Escape the expected text so its braces match literally instead of being
    # interpreted as regex syntax.
    with pytest.raises(
        ValueError,
        match=re.escape("manner is not in {'below', 'above', 'both'}, actual value:"),
    ):
        space.is_bounded("test")
def test_legacy_state_pickling():
    """Restore a `Box` from a state dict that lacks `low_repr`/`high_repr`.

    After `__setstate__` with the legacy state (bounds 0.0 and 1.0), the
    space must expose `low_repr == "0.0"` and `high_repr == "1.0"`.
    """
    legacy_state = {
        "dtype": np.dtype("float32"),
        "_shape": (5,),
        "low": np.zeros(5, dtype=np.float32),
        "high": np.ones(5, dtype=np.float32),
        "bounded_below": np.full(5, True),
        "bounded_above": np.full(5, True),
        "_np_random": None,
    }

    space = Box(-1, 1, ())
    repr_attrs = ("low_repr", "high_repr")

    # A freshly built Box carries both repr attributes; strip them so the
    # object looks like one whose state predates them.
    assert all(attr in space.__dict__ for attr in repr_attrs)
    for attr in repr_attrs:
        del space.__dict__[attr]
    assert not any(attr in space.__dict__ for attr in repr_attrs)

    space.__setstate__(legacy_state)

    assert space.low_repr == "0.0"
    assert space.high_repr == "1.0"
def test_sample_mask():
    """`Box.sample` must refuse any `mask` argument with a gymnasium error."""
    space = Box(0, 1)
    mask = np.array([0, 1, 0], dtype=np.int8)
    expected_message = re.escape("Box.sample cannot be provided a mask, actual value: ")

    with pytest.raises(gym.error.Error, match=expected_message):
        space.sample(mask=mask)
2023-05-15 14:07:57 +01:00
@pytest.mark.parametrize(
    "low, high, shape, dtype, reason",
    [
        # low greater than high
        (
            5.0,
            3.0,
            (),
            np.float32,
            "Some low values are greater than high, low=5.0, high=3.0",
        ),
        (
            np.array([5.0, 6.0]),
            np.array([1.0, 5.99]),
            (2,),
            np.float32,
            # numpy pads array elements to a common width when printing
            "Some low values are greater than high, low=[5. 6.], high=[1.   5.99]",
        ),
        # a low bound of +inf
        (
            np.inf,
            np.inf,
            (),
            np.float32,
            "No low value can be equal to `np.inf`, low=inf",
        ),
        (
            np.array([0, np.inf]),
            np.array([np.inf, np.inf]),
            (2,),
            np.float32,
            "No low value can be equal to `np.inf`, low=[ 0. inf]",
        ),
        # a high bound of -inf
        (
            -np.inf,
            -np.inf,
            (),
            np.float32,
            "No high value can be equal to `-np.inf`, high=-inf",
        ),
        (
            np.array([-np.inf, -np.inf]),
            np.array([0, -np.inf]),
            (2,),
            np.float32,
            "No high value can be equal to `-np.inf`, high=[  0. -inf]",
        ),
    ],
)
def test_invalid_low_high(low, high, shape, dtype, reason):
    """Test that `Box` rejects degenerate bounds such as `Box(np.inf, -np.inf)`.

    Each case must raise a `ValueError` whose message contains `reason`.
    The signature follows the argnames order (low, high, shape, dtype) so
    the case tuples above read positionally onto it.
    """
    with pytest.raises(ValueError, match=re.escape(reason)):
        Box(low=low, high=high, dtype=dtype, shape=shape)