2022-06-29 16:17:25 +01:00
import re
2022-08-30 19:47:26 +01:00
import warnings
2022-06-29 16:17:25 +01:00
import numpy as np
import pytest
2022-09-08 10:10:07 +01:00
import gymnasium . error
from gymnasium . spaces import Box
from gymnasium . spaces . box import get_inf
2022-06-29 16:17:25 +01:00
@pytest.mark.parametrize (
" box,expected_shape " ,
[
2022-09-03 22:56:29 +01:00
( # Test with same 1-dim low and high shape
Box ( low = np . zeros ( 2 ) , high = np . ones ( 2 ) , dtype = np . int32 ) ,
2022-06-29 16:17:25 +01:00
( 2 , ) ,
2022-09-03 22:56:29 +01:00
) ,
( # Test with same multi-dim low and high shape
Box ( low = np . zeros ( ( 2 , 1 ) ) , high = np . ones ( ( 2 , 1 ) ) , dtype = np . int32 ) ,
2022-06-29 16:17:25 +01:00
( 2 , 1 ) ,
2022-09-03 22:56:29 +01:00
) ,
( # Test with scalar low high and different shape
2022-06-29 16:17:25 +01:00
Box ( low = 0 , high = 1 , shape = ( 5 , 2 ) ) ,
( 5 , 2 ) ,
2022-09-03 22:56:29 +01:00
) ,
2022-06-29 16:17:25 +01:00
( Box ( low = 0 , high = 1 ) , ( 1 , ) ) , # Test with int and int
( Box ( low = 0.0 , high = 1.0 ) , ( 1 , ) ) , # Test with float and float
( Box ( low = np . zeros ( 1 ) [ 0 ] , high = np . ones ( 1 ) [ 0 ] ) , ( 1 , ) ) ,
( Box ( low = 0.0 , high = 1 ) , ( 1 , ) ) , # Test with float and int
( Box ( low = 0 , high = np . int32 ( 1 ) ) , ( 1 , ) ) , # Test with python int and numpy int32
( Box ( low = 0 , high = np . ones ( 3 ) ) , ( 3 , ) ) , # Test with array and scalar
( Box ( low = np . zeros ( 3 ) , high = 1.0 ) , ( 3 , ) ) , # Test with array and scalar
] ,
)
2022-09-03 22:56:29 +01:00
def test_shape_inference ( box , expected_shape ) :
""" Test that the shape inference is as expected. """
2022-06-29 16:17:25 +01:00
assert box . shape == expected_shape
assert box . sample ( ) . shape == expected_shape
@pytest.mark.parametrize (
" value,valid " ,
[
( 1 , True ) ,
( 1.0 , True ) ,
( np . int32 ( 1 ) , True ) ,
( np . float32 ( 1.0 ) , True ) ,
( np . zeros ( 2 , dtype = np . float32 ) , True ) ,
( np . zeros ( ( 2 , 2 ) , dtype = np . float32 ) , True ) ,
( np . inf , True ) ,
2022-09-03 22:56:29 +01:00
( np . nan , True ) , # This is a weird case that we allow
2022-06-29 16:17:25 +01:00
( True , False ) ,
( np . bool8 ( True ) , False ) ,
( 1 + 1 j , False ) ,
( np . complex128 ( 1 + 1 j ) , False ) ,
( " string " , False ) ,
] ,
)
2022-09-03 22:56:29 +01:00
def test_low_high_values ( value , valid : bool ) :
""" Test what `low` and `high` values are valid for `Box` space. """
2022-06-29 16:17:25 +01:00
if valid :
2022-08-30 19:47:26 +01:00
with warnings . catch_warnings ( record = True ) as caught_warnings :
2022-06-29 16:17:25 +01:00
Box ( low = value , high = value )
2022-08-30 19:47:26 +01:00
assert len ( caught_warnings ) == 0 , tuple (
warning . message for warning in caught_warnings
)
2022-06-29 16:17:25 +01:00
else :
with pytest . raises (
ValueError ,
2022-09-03 22:56:29 +01:00
match = re . escape (
" expect their types to be np.ndarray, an integer or a float "
) ,
2022-06-29 16:17:25 +01:00
) :
Box ( low = value , high = value )
@pytest.mark.parametrize (
" low,high,kwargs,error,message " ,
[
(
0 ,
1 ,
{ " dtype " : None } ,
AssertionError ,
" Box dtype must be explicitly provided, cannot be None. " ,
) ,
(
0 ,
1 ,
{ " shape " : ( None , ) } ,
AssertionError ,
" Expect all shape elements to be an integer, actual type: (<class ' NoneType ' >,) " ,
) ,
(
0 ,
1 ,
{
" shape " : (
1 ,
None ,
)
} ,
AssertionError ,
" Expect all shape elements to be an integer, actual type: (<class ' int ' >, <class ' NoneType ' >) " ,
) ,
(
0 ,
1 ,
{
" shape " : (
np . int64 ( 1 ) ,
None ,
)
} ,
AssertionError ,
" Expect all shape elements to be an integer, actual type: (<class ' numpy.int64 ' >, <class ' NoneType ' >) " ,
) ,
(
None ,
None ,
{ } ,
ValueError ,
" Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: <class ' NoneType ' >, high: <class ' NoneType ' > " ,
) ,
(
0 ,
None ,
{ } ,
ValueError ,
" Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: <class ' int ' >, high: <class ' NoneType ' > " ,
) ,
(
np . zeros ( 3 ) ,
np . ones ( 2 ) ,
{ } ,
AssertionError ,
" high.shape doesn ' t match provided shape, high.shape: (2,), shape: (3,) " ,
) ,
] ,
)
2022-09-03 22:56:29 +01:00
def test_init_errors ( low , high , kwargs , error , message ) :
""" Test all constructor errors. """
2022-06-29 16:17:25 +01:00
with pytest . raises ( error , match = f " ^ { re . escape ( message ) } $ " ) :
Box ( low = low , high = high , * * kwargs )
2022-09-03 22:56:29 +01:00
def test_dtype_check ( ) :
""" Tests the Box contains function with different dtypes. """
# Related Issues:
# https://github.com/openai/gym/issues/2357
# https://github.com/openai/gym/issues/2298
space = Box ( 0 , 1 , ( ) , dtype = np . float32 )
# casting will match the correct type
assert np . array ( 0.5 , dtype = np . float32 ) in space
# float16 is in float32 space
assert np . array ( 0.5 , dtype = np . float16 ) in space
# float64 is not in float32 space
assert np . array ( 0.5 , dtype = np . float64 ) not in space
@pytest.mark.parametrize (
" space " ,
[
Box ( low = 0 , high = np . inf , shape = ( 2 , ) , dtype = np . int32 ) ,
Box ( low = 0 , high = np . inf , shape = ( 2 , ) , dtype = np . float32 ) ,
Box ( low = 0 , high = np . inf , shape = ( 2 , ) , dtype = np . int64 ) ,
Box ( low = 0 , high = np . inf , shape = ( 2 , ) , dtype = np . float64 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , ) , dtype = np . int32 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , ) , dtype = np . float32 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , ) , dtype = np . int64 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , ) , dtype = np . float64 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , ) , dtype = np . int32 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , ) , dtype = np . float32 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , ) , dtype = np . int64 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , ) , dtype = np . float64 ) ,
Box ( low = 0 , high = np . inf , shape = ( 2 , 3 ) , dtype = np . int32 ) ,
Box ( low = 0 , high = np . inf , shape = ( 2 , 3 ) , dtype = np . float32 ) ,
Box ( low = 0 , high = np . inf , shape = ( 2 , 3 ) , dtype = np . int64 ) ,
Box ( low = 0 , high = np . inf , shape = ( 2 , 3 ) , dtype = np . float64 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , 3 ) , dtype = np . int32 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , 3 ) , dtype = np . float32 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , 3 ) , dtype = np . int64 ) ,
Box ( low = - np . inf , high = 0 , shape = ( 2 , 3 ) , dtype = np . float64 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , 3 ) , dtype = np . int32 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , 3 ) , dtype = np . float32 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , 3 ) , dtype = np . int64 ) ,
Box ( low = - np . inf , high = np . inf , shape = ( 2 , 3 ) , dtype = np . float64 ) ,
Box ( low = np . array ( [ - np . inf , 0 ] ) , high = np . array ( [ 0.0 , np . inf ] ) , dtype = np . int32 ) ,
Box ( low = np . array ( [ - np . inf , 0 ] ) , high = np . array ( [ 0.0 , np . inf ] ) , dtype = np . float32 ) ,
Box ( low = np . array ( [ - np . inf , 0 ] ) , high = np . array ( [ 0.0 , np . inf ] ) , dtype = np . int64 ) ,
Box ( low = np . array ( [ - np . inf , 0 ] ) , high = np . array ( [ 0.0 , np . inf ] ) , dtype = np . float64 ) ,
] ,
)
def test_infinite_space ( space ) :
"""
To test spaces that are passed in have only 0 or infinite bounds because ` space . high ` and ` space . low `
are both modified within the init , we check for infinite when we know it ' s not 0
"""
assert np . all (
space . low < space . high
) , f " Box low bound ( { space . low } ) is not lower than the high bound ( { space . high } ) "
space . seed ( 0 )
sample = space . sample ( )
# check if space contains sample
assert (
sample in space
) , f " Sample ( { sample } ) not inside space according to `space.contains()` "
# manually check that the sign of the sample is within the bounds
assert np . all (
np . sign ( sample ) < = np . sign ( space . high )
) , f " Sign of sample ( { sample } ) is less than space upper bound ( { space . high } ) "
assert np . all (
np . sign ( space . low ) < = np . sign ( sample )
) , f " Sign of sample ( { sample } ) is more than space lower bound ( { space . low } ) "
# check that int bounds are bounded for everything
# but floats are unbounded for infinite
if np . any ( space . high != 0 ) :
assert (
space . is_bounded ( " above " ) is False
) , " inf upper bound supposed to be unbounded "
else :
assert (
space . is_bounded ( " above " ) is True
) , " non-inf upper bound supposed to be bounded "
if np . any ( space . low != 0 ) :
assert (
space . is_bounded ( " below " ) is False
) , " inf lower bound supposed to be unbounded "
else :
assert (
space . is_bounded ( " below " ) is True
) , " non-inf lower bound supposed to be bounded "
if np . any ( space . low != 0 ) or np . any ( space . high != 0 ) :
assert space . is_bounded ( " both " ) is False
else :
assert space . is_bounded ( " both " ) is True
# check for dtype
assert (
space . high . dtype == space . dtype
) , f " High ' s dtype { space . high . dtype } doesn ' t match `space.dtype` ' "
assert (
space . low . dtype == space . dtype
) , f " Low ' s dtype { space . high . dtype } doesn ' t match `space.dtype` ' "
with pytest . raises (
ValueError , match = " manner is not in { ' below ' , ' above ' , ' both ' }, actual value: "
) :
space . is_bounded ( " test " )
def test_legacy_state_pickling ( ) :
legacy_state = {
" dtype " : np . dtype ( " float32 " ) ,
" _shape " : ( 5 , ) ,
" low " : np . array ( [ 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ] , dtype = np . float32 ) ,
" high " : np . array ( [ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ] , dtype = np . float32 ) ,
" bounded_below " : np . array ( [ True , True , True , True , True ] ) ,
" bounded_above " : np . array ( [ True , True , True , True , True ] ) ,
" _np_random " : None ,
}
b = Box ( - 1 , 1 , ( ) )
assert " low_repr " in b . __dict__ and " high_repr " in b . __dict__
del b . __dict__ [ " low_repr " ]
del b . __dict__ [ " high_repr " ]
assert " low_repr " not in b . __dict__ and " high_repr " not in b . __dict__
b . __setstate__ ( legacy_state )
assert b . low_repr == " 0.0 "
assert b . high_repr == " 1.0 "
def test_get_inf ( ) :
""" Tests that get inf function works as expected, primarily for coverage. """
assert get_inf ( np . float32 , " + " ) == np . inf
assert get_inf ( np . float16 , " - " ) == - np . inf
with pytest . raises (
TypeError , match = re . escape ( " Unknown sign *, use either ' + ' or ' - ' " )
) :
get_inf ( np . float32 , " * " )
assert get_inf ( np . int16 , " + " ) == 32765
assert get_inf ( np . int8 , " - " ) == - 126
with pytest . raises (
TypeError , match = re . escape ( " Unknown sign *, use either ' + ' or ' - ' " )
) :
get_inf ( np . int32 , " * " )
with pytest . raises (
ValueError ,
match = re . escape ( " Unknown dtype <class ' numpy.complex128 ' > for infinite bounds " ) ,
) :
get_inf ( np . complex_ , " + " )
def test_sample_mask ( ) :
""" Box cannot have a mask applied. """
space = Box ( 0 , 1 )
with pytest . raises (
2022-09-08 10:10:07 +01:00
gymnasium . error . Error ,
2022-09-03 22:56:29 +01:00
match = re . escape ( " Box.sample cannot be provided a mask, actual value: " ) ,
) :
space . sample ( mask = np . array ( [ 0 , 1 , 0 ] , dtype = np . int8 ) )