Fix Box casting and sampling edge-cases (#774)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
James Mochizuki-Freeman
2024-03-23 12:19:52 -04:00
committed by GitHub
parent 144feb865a
commit 89bedf1f33
3 changed files with 485 additions and 324 deletions

View File

@@ -9,20 +9,54 @@ from gymnasium.spaces import Box
@pytest.mark.parametrize(
    "dtype, error, message",
    [
        (
            None,
            ValueError,
            "Box dtype must be explicitly provided, cannot be None.",
        ),
        (0, TypeError, "Cannot interpret '0' as a data type"),
        ("unknown", TypeError, "data type 'unknown' not understood"),
        (np.zeros(1), TypeError, "Cannot construct a dtype from an array"),
        # disabled datatypes
        (
            np.complex64,
            ValueError,
            "Invalid Box dtype (complex64), must be an integer, floating, or bool dtype",
        ),
        (
            complex,
            ValueError,
            "Invalid Box dtype (complex128), must be an integer, floating, or bool dtype",
        ),
        (
            object,
            ValueError,
            "Invalid Box dtype (object), must be an integer, floating, or bool dtype",
        ),
        (
            str,
            ValueError,
            "Invalid Box dtype (<U0), must be an integer, floating, or bool dtype",
        ),
    ],
)
def test_dtype_errors(dtype, error, message):
    """Check that ``Box`` rejects a missing, uninterpretable, or disallowed dtype.

    Each case builds ``Box(low=0, high=1, dtype=dtype)`` and expects an
    exception of type ``error`` whose text contains ``message``.

    Args:
        dtype: The value passed as the ``dtype`` argument of ``Box``.
        error: The exception class expected from the constructor.
        message: Text expected inside the exception message; it is escaped so
            that brackets and parentheses are matched literally.
    """
    with pytest.raises(error, match=re.escape(message)):
        Box(low=0, high=1, dtype=dtype)
@pytest.mark.parametrize(
"box, expected_shape",
[
# Test with same 1-dim low and high shape
(Box(low=np.zeros(2), high=np.ones(2)), (2,)),
# Test with same multi-dim low and high shape
(Box(low=np.zeros((2, 1)), high=np.ones((2, 1))), (2, 1)),
# Test with scalar low high and different shape
(Box(low=0, high=1, shape=(5, 2)), (5, 2)),
(Box(low=0, high=1), (1,)), # Test with int and int
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
@@ -39,112 +73,209 @@ def test_shape_inference(box, expected_shape):
@pytest.mark.parametrize(
    "value,valid",
    [
        (1, True),
        (1.0, True),
        (np.int32(1), True),
        (np.float32(1.0), True),
        (np.zeros(2, dtype=np.float32), True),
        (np.zeros((2, 2), dtype=np.float32), True),
        (np.inf, True),
        (np.nan, True),  # odd, but deliberately accepted
        (True, False),
        (np.bool_(True), False),
        (1 + 1j, False),
        (np.complex128(1 + 1j), False),
        ("string", False),
    ],
)
def test_low_high_values(value, valid: bool):
    """Check which kinds of value ``Box`` accepts as a bound.

    ``value`` is passed as ``high`` alongside ``low=-np.inf``. Accepted values
    must construct the space without emitting any warning; rejected values
    must raise ``ValueError`` mentioning the expected bound types.
    """
    if not valid:
        expected = re.escape(
            "expected their types to be np.ndarray, an integer or a float"
        )
        with pytest.raises(ValueError, match=expected):
            Box(low=-np.inf, high=value)
        return

    with warnings.catch_warnings(record=True) as recorded:
        Box(low=-np.inf, high=value)

    # Show the captured messages if the "no warnings" expectation is violated.
    emitted = tuple(entry.message for entry in recorded)
    assert len(recorded) == 0, emitted
@pytest.mark.parametrize(
    "low, high, shape, error_type, message",
    [
        # `shape` is not an iterable
        (
            0,
            1,
            1,
            TypeError,
            "Expected Box shape to be an iterable, actual type=<class 'int'>",
        ),
        # `shape` contains non-integer elements
        (
            0,
            1,
            (None,),
            TypeError,
            "Expected all Box shape elements to be integer, actual type=(<class 'NoneType'>,)",
        ),
        (
            0,
            1,
            (1, None),
            TypeError,
            "Expected all Box shape elements to be integer, actual type=(<class 'int'>, <class 'NoneType'>)",
        ),
        (
            0,
            1,
            (np.int64(1), None),
            TypeError,
            "Expected all Box shape elements to be integer, actual type=(<class 'numpy.int64'>, <class 'NoneType'>)",
        ),
        # `low` and `high` disagree with each other (shape is inferred)
        (
            np.zeros(3),
            np.ones(2),
            None,
            ValueError,
            "Box low.shape and high.shape don't match, low.shape=(3,), high.shape=(2,)",
        ),
        # `low` and/or `high` disagree with the explicitly provided shape
        (
            np.zeros(2),
            np.ones(2),
            (3,),
            ValueError,
            "Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)",
        ),
        (
            np.zeros(2),
            1,
            (3,),
            ValueError,
            "Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)",
        ),
        (
            0,
            np.ones(2),
            (3,),
            ValueError,
            "Box high.shape doesn't match provided shape, high.shape=(2,), shape=(3,)",
        ),
    ],
)
def test_shape_errors(low, high, shape, error_type, message):
    """Check the errors raised when ``shape`` is invalid or inconsistent with the bounds.

    Args:
        low: Lower bound passed to ``Box``.
        high: Upper bound passed to ``Box``.
        shape: Value passed as the ``shape`` argument (``None`` leaves it to
            be inferred by ``Box``).
        error_type: The exception class expected from the constructor.
        message: The complete expected exception message; the pattern is
            anchored at both ends, so partial matches do not pass.
    """
    with pytest.raises(error_type, match=f"^{re.escape(message)}$"):
        Box(low=low, high=high, shape=shape)
def test_dtype_check():
@pytest.mark.parametrize(
    "low, high, dtype",
    [
        # floats
        (0, 65505.0, np.float16),
        (-65505.0, 0, np.float16),
        # signed int
        (0, 32768, np.int16),
        (-32769, 0, np.int16),
        # unsigned int
        (-1, 100, np.uint8),
        (0, 300, np.uint8),
        # boolean
        (-1, 1, np.bool_),
        (0, 2, np.bool_),
        # array inputs
        (np.array([-1, 0]), np.array([0, 100]), np.uint8),
        (np.array([[-1], [0]]), np.array([[0], [100]]), np.uint8),
        (np.array([0, 0]), np.array([0, 300]), np.uint8),
        (np.array([[0], [0]]), np.array([[0], [300]]), np.uint8),
    ],
)
def test_out_of_bounds_error(low, high, dtype):
    """Check that a finite bound outside the range of ``dtype`` is rejected.

    Each case places either ``low`` or ``high`` (scalar or array) outside the
    values representable by ``dtype`` and expects a ``ValueError`` whose
    message mentions the dtype range.
    """
    expected = re.escape("is out of bounds of the dtype range,")
    with pytest.raises(ValueError, match=expected):
        Box(low=low, high=high, dtype=dtype)
@pytest.mark.parametrize(
    "low, high, dtype",
    [
        # Floats
        (np.nan, 0, np.float32),
        (0, np.nan, np.float32),
        (np.array([0, np.nan]), np.ones(2), np.float32),
        # Signed ints
        (np.nan, 0, np.int32),
        (0, np.nan, np.int32),
        (np.array([0, np.nan]), np.ones(2), np.int32),
        # Unsigned ints
        # NOTE(review): the NaN cases below are disabled and the reason is not
        # recorded here. If re-enabled they would be checked against the
        # unsigned-int message, because the expected message is chosen by dtype.
        # (np.nan, 0, np.uint8),
        # (0, np.nan, np.uint8),
        # (np.array([0, np.nan]), np.ones(2), np.uint8),
        (-np.inf, 1, np.uint8),
        (np.array([-np.inf, 0]), 1, np.uint8),
        (0, np.inf, np.uint8),
        (0, np.array([1, np.inf]), np.uint8),
        # boolean
        (-np.inf, 1, np.bool_),
        (0, np.inf, np.bool_),
    ],
)
def test_invalid_low_high(low, high, dtype):
    """Check that NaN bounds, and infinite bounds for uint8/bool spaces, raise ``ValueError``.

    The expected message fragment depends only on ``dtype``: ``np.uint8`` and
    ``np.bool_`` cases expect the unsigned-int message, every other case
    expects the NaN message.
    """
    if dtype in (np.uint8, np.bool_):
        fragment = "Box unsigned int dtype don't support"
    else:
        fragment = "value can be equal to `np.nan`,"

    with pytest.raises(ValueError, match=re.escape(fragment)):
        Box(low=low, high=high, dtype=dtype)
@pytest.mark.parametrize(
    "low, high, dtype",
    [
        # floats
        (0, 1, float),
        (0, 1, np.float64),
        (0, 1, np.float32),
        (0, 1, np.float16),
        (np.zeros(2), np.ones(2), np.float32),
        (np.zeros(2), 1, np.float32),
        (-np.inf, 1, np.float32),
        (np.array([-np.inf, 0]), 1, np.float32),
        (0, np.inf, np.float32),
        (0, np.array([np.inf, 1]), np.float32),
        (-np.inf, np.inf, np.float32),
        (np.full((2,), -np.inf), np.full((2,), np.inf), np.float32),
        # signed ints
        (0, 1, int),
        (0, 1, np.int64),
        (0, 1, np.int32),
        (0, 1, np.int16),
        (0, 1, np.int8),
        (np.zeros(2), np.ones(2), np.int32),
        (np.zeros(2), 1, np.int32),
        (-np.inf, 1, np.int32),
        (np.array([-np.inf, 0]), 1, np.int32),
        (0, np.inf, np.int32),
        (0, np.array([np.inf, 1]), np.int32),
        # unsigned ints
        (0, 1, np.uint64),
        (0, 1, np.uint32),
        (0, 1, np.uint16),
        (0, 1, np.uint8),
        # boolean
        (0, 1, np.bool_),
    ],
)
def test_valid_low_high(low, high, dtype):
    """Check that valid bound/dtype combinations build a usable, correctly typed space.

    The space, its ``low`` and ``high`` arrays, and a seeded sample must all
    carry ``dtype``, and the sample must be contained in the space. The only
    warning tolerated during construction and sampling is the one about
    precision being lowered by casting to float32.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        space = Box(low=low, high=high, dtype=dtype)
        assert space.dtype == dtype
        assert space.low.dtype == dtype
        assert space.high.dtype == dtype

        space.seed(0)
        sample = space.sample()
        assert sample.dtype == dtype
        assert space.contains(sample)

    for warn in caught_warnings:
        # str() works for any warning instance, including ones whose first
        # argument is missing or is not a string.
        text = str(warn.message)
        if "precision lowered by casting to float32" not in text:
            pytest.fail(f"Unexpected {warn.category.__name__}: {text}")
def test_contains_dtype():
"""Tests the Box contains function with different dtypes."""
# Related Issues:
# https://github.com/openai/gym/issues/2357
@@ -163,102 +294,54 @@ def test_dtype_check():
@pytest.mark.parametrize(
    "lowhighshape",
    [
        dict(low=0, high=np.inf, shape=(2,)),
        dict(low=-np.inf, high=0, shape=(2,)),
        dict(low=-np.inf, high=np.inf, shape=(2,)),
        dict(low=0, high=np.inf, shape=(2, 3)),
        dict(low=-np.inf, high=0, shape=(2, 3)),
        dict(low=-np.inf, high=np.inf, shape=(2, 3)),
        dict(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf])),
    ],
)
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_infinite_space(lowhighshape, dtype):
    """Check boundedness, dtype and sampling for spaces whose bounds are 0 or infinite.

    Every bound in the parametrization is either 0 or +/-inf. ``space.low``
    and ``space.high`` are modified within ``Box.__init__``, so an infinite
    bound is recognised by the stored bound being different from 0 rather
    than by comparing against ``np.inf``.

    Args:
        lowhighshape: Keyword arguments (``low``, ``high`` and optionally
            ``shape``) forwarded to ``Box``.
        dtype: The dtype the space is built with.
    """
    space = Box(**lowhighshape, dtype=dtype)

    assert np.all(space.low < space.high)

    # A side counts as unbounded when any of its stored bounds is non-zero.
    # The flags are converted to plain bools and compared by value: an
    # identity test (`is not`) between a Python bool and a NumPy scalar is
    # always true and would make these assertions vacuous.
    # NOTE(review): these expectations mirror the explicit `is True` /
    # `is False` checks this block replaced; confirm they still hold for the
    # integer dtypes against `Box.is_bounded`.
    unbounded_above = bool(np.any(space.high != 0))
    unbounded_below = bool(np.any(space.low != 0))
    assert space.is_bounded("above") == (not unbounded_above)
    assert space.is_bounded("below") == (not unbounded_below)
    assert space.is_bounded("both") == (not (unbounded_below or unbounded_above))

    # check for dtype
    assert space.high.dtype == space.dtype
    assert space.low.dtype == space.dtype

    with pytest.raises(
        ValueError,
        match=re.escape("manner is not in {'below', 'above', 'both'}, actual value:"),
    ):
        space.is_bounded("test")

    # Check sample
    space.seed(0)
    sample = space.sample()

    # check if space contains sample
    assert sample in space

    # manually check that the sign of the sample is within the bounds
    assert np.all(np.sign(sample) <= np.sign(space.high))
    assert np.all(np.sign(space.low) <= np.sign(sample))
def test_legacy_state_pickling():
legacy_state = {
@@ -290,56 +373,3 @@ def test_sample_mask():
match=re.escape("Box.sample cannot be provided a mask, actual value: "),
):
space.sample(mask=np.array([0, 1, 0], dtype=np.int8))
@pytest.mark.parametrize(
    "low, high, shape, dtype, reason",
    [
        # low greater than high
        (
            5.0,
            3.0,
            (),
            np.float32,
            "Some low values are greater than high, low=5.0, high=3.0",
        ),
        (
            np.array([5.0, 6.0]),
            np.array([1.0, 5.99]),
            (2,),
            np.float32,
            "Some low values are greater than high, low=[5. 6.], high=[1. 5.99]",
        ),
        # low equal to +inf
        (
            np.inf,
            np.inf,
            (),
            np.float32,
            "No low value can be equal to `np.inf`, low=inf",
        ),
        (
            np.array([0, np.inf]),
            np.array([np.inf, np.inf]),
            (2,),
            np.float32,
            "No low value can be equal to `np.inf`, low=[ 0. inf]",
        ),
        # high equal to -inf
        (
            -np.inf,
            -np.inf,
            (),
            np.float32,
            "No high value can be equal to `-np.inf`, high=-inf",
        ),
        (
            np.array([-np.inf, -np.inf]),
            np.array([0, -np.inf]),
            (2,),
            np.float32,
            "No high value can be equal to `-np.inf`, high=[ 0. -inf]",
        ),
    ],
)
def test_invalid_low_high(low, high, dtype, shape, reason):
    """Check that degenerate bounds, such as `Box(np.inf, -np.inf)`, are rejected.

    Each case expects ``ValueError`` with ``reason`` contained in its message.
    """
    expected = re.escape(reason)
    with pytest.raises(ValueError, match=expected):
        Box(low=low, high=high, dtype=dtype, shape=shape)