Fix Box casting and sampling edge-cases (#774)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
James Mochizuki-Freeman
2024-03-23 12:19:52 -04:00
committed by GitHub
parent 144feb865a
commit 89bedf1f33
3 changed files with 485 additions and 324 deletions

View File

@@ -10,7 +10,7 @@ import gymnasium as gym
from gymnasium.spaces.space import Space
def _short_repr(arr: NDArray[Any]) -> str:
def array_short_repr(arr: NDArray[Any]) -> str:
"""Create a shortened string representation of a numpy array.
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
@@ -28,7 +28,7 @@ def _short_repr(arr: NDArray[Any]) -> str:
def is_float_integer(var: Any) -> bool:
"""Checks if a variable is an integer or float."""
"""Checks if a scalar variable is an integer or float (does not include bool)."""
return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating)
@@ -80,72 +80,232 @@ class Box(Space[NDArray[Any]]):
ValueError: If no shape information is provided (shape is None, low is None and high is None) then a
value error is raised.
"""
assert (
dtype is not None
), "Box dtype must be explicitly provided, cannot be None."
# determine dtype
if dtype is None:
raise ValueError("Box dtype must be explicitly provided, cannot be None.")
self.dtype = np.dtype(dtype)
# determine shape if it isn't provided directly
# * check that dtype is an accepted dtype
if not (
np.issubdtype(self.dtype, np.integer)
or np.issubdtype(self.dtype, np.floating)
or self.dtype == np.bool_
):
raise ValueError(
f"Invalid Box dtype ({self.dtype}), must be an integer, floating, or bool dtype"
)
# determine shape
if shape is not None:
assert all(
np.issubdtype(type(dim), np.integer) for dim in shape
), f"Expected all shape elements to be an integer, actual type: {tuple(type(dim) for dim in shape)}"
shape = tuple(int(dim) for dim in shape) # This changes any np types to int
if not isinstance(shape, Iterable):
raise TypeError(
f"Expected Box shape to be an iterable, actual type={type(shape)}"
)
elif not all(np.issubdtype(type(dim), np.integer) for dim in shape):
raise TypeError(
f"Expected all Box shape elements to be integer, actual type={tuple(type(dim) for dim in shape)}"
)
# Casts the `shape` argument to tuple[int, ...] (otherwise dim can `np.int64`)
shape = tuple(int(dim) for dim in shape)
elif isinstance(low, np.ndarray) and isinstance(high, np.ndarray):
if low.shape != high.shape:
raise ValueError(
f"Box low.shape and high.shape don't match, low.shape={low.shape}, high.shape={high.shape}"
)
shape = low.shape
elif isinstance(low, np.ndarray):
shape = low.shape
elif isinstance(high, np.ndarray):
shape = high.shape
elif is_float_integer(low) and is_float_integer(high):
shape = (1,)
shape = (1,) # low and high are scalars
else:
raise ValueError(
f"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: {type(low)}, high: {type(high)}"
"Box shape is not specified, therefore inferred from low and high. Expected low and high to be np.ndarray, integer, or float."
f"Actual types low={type(low)}, high={type(high)}"
)
# Capture the boundedness information before replacing np.inf with get_inf
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
self.bounded_below: NDArray[np.bool_] = -np.inf < _low
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
self.bounded_above: NDArray[np.bool_] = np.inf > _high
low = _broadcast(low, self.dtype, shape)
high = _broadcast(high, self.dtype, shape)
assert isinstance(low, np.ndarray)
assert (
low.shape == shape
), f"low.shape doesn't match provided shape, low.shape: {low.shape}, shape: {shape}"
assert isinstance(high, np.ndarray)
assert (
high.shape == shape
), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"
# check that we don't have invalid low or high
if np.any(low > high):
raise ValueError(
f"Some low values are greater than high, low={low}, high={high}"
)
if np.any(np.isposinf(low)):
raise ValueError(f"No low value can be equal to `np.inf`, low={low}")
if np.any(np.isneginf(high)):
raise ValueError(f"No high value can be equal to `-np.inf`, high={high}")
self._shape: tuple[int, ...] = shape
low_precision = get_precision(low.dtype)
high_precision = get_precision(high.dtype)
dtype_precision = get_precision(self.dtype)
if min(low_precision, high_precision) > dtype_precision:
gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
self.low = low.astype(self.dtype)
self.high = high.astype(self.dtype)
# Cast scalar values to `np.ndarray` and capture the boundedness information
# disallowed cases
# * out of range - this must be checked before casting low and high to the dtype; after casting, every value lies within the dtype range and can no longer be detected as out of range
# * nan - must be checked beforehand as casting to an int dtype can turn `nan` into another value
# * unsigned int with inf or -inf - special case that is disallowed
self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high)
if self.dtype == np.bool_:
dtype_min, dtype_max = 0, 1
elif np.issubdtype(self.dtype, np.floating):
dtype_min = float(np.finfo(self.dtype).min)
dtype_max = float(np.finfo(self.dtype).max)
else:
dtype_min = int(np.iinfo(self.dtype).min)
dtype_max = int(np.iinfo(self.dtype).max)
# Cast `low` and `high` to ndarray for the dtype min and max for out of range tests
self.low, self.bounded_below = self._cast_low(low, dtype_min)
self.high, self.bounded_above = self._cast_high(high, dtype_max)
# recheck shape for case where shape and (low or high) are provided
if self.low.shape != shape:
raise ValueError(
f"Box low.shape doesn't match provided shape, low.shape={self.low.shape}, shape={self.shape}"
)
if self.high.shape != shape:
raise ValueError(
f"Box high.shape doesn't match provided shape, high.shape={self.high.shape}, shape={self.shape}"
)
# check that low <= high
if np.any(self.low > self.high):
raise ValueError(
f"Box all low values must be less than or equal to high (some values break this), low={self.low}, high={self.high}"
)
self.low_repr = array_short_repr(self.low)
self.high_repr = array_short_repr(self.high)
super().__init__(self.shape, self.dtype, seed)
def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
    """Casts the input Box low value to ndarray with provided dtype.

    Scalar inputs are broadcast to ``self.shape``; array inputs are validated and
    cast to ``self.dtype``. The array passed in by the caller is never modified.

    Args:
        low: The input box low value (a scalar integer/float or an ``np.ndarray``)
        dtype_min: The dtype's minimum value

    Returns:
        The updated low value and for what values the input is bounded (below)

    Raises:
        ValueError: If ``low`` contains ``np.nan``, contains ``-np.inf`` while the Box
            dtype is an unsigned int or bool, has a finite value below ``dtype_min``,
            or is neither a scalar nor an array of floating, integer, or bool dtype.
    """
    if is_float_integer(low):
        bounded_below = -np.inf < np.full(self.shape, low, dtype=float)

        if np.isnan(low):
            raise ValueError(f"No low value can be equal to `np.nan`, low={low}")
        elif np.isneginf(low):
            if self.dtype.kind == "i":  # signed int
                low = dtype_min
            elif self.dtype.kind in {"u", "b"}:  # unsigned int and bool
                raise ValueError(
                    f"Box unsigned int dtype don't support `-np.inf`, low={low}"
                )
        elif low < dtype_min:
            raise ValueError(
                f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}"
            )

        low = np.full(self.shape, low, dtype=self.dtype)
        return low, bounded_below

    # cast for low - array
    if not isinstance(low, np.ndarray):
        raise ValueError(
            f"Box low must be a np.ndarray, integer, or float, actual type={type(low)}"
        )
    elif not (
        np.issubdtype(low.dtype, np.floating)
        or np.issubdtype(low.dtype, np.integer)
        or low.dtype == np.bool_
    ):
        raise ValueError(
            f"Box low must be a floating, integer, or bool dtype, actual dtype={low.dtype}"
        )
    elif np.any(np.isnan(low)):
        raise ValueError(f"No low value can be equal to `np.nan`, low={low}")

    bounded_below = -np.inf < low
    neg_inf = np.isneginf(low)
    has_neg_inf = bool(np.any(neg_inf))

    if has_neg_inf and self.dtype.kind in {"u", "b"}:  # unsigned int and bool
        raise ValueError(
            f"Box unsigned int dtype don't support `-np.inf`, low={low}"
        )

    # Range-check every finite entry, even when some other entries are `-np.inf`;
    # the infinite entries are excluded because they are handled separately below.
    if low.dtype != self.dtype and np.any(low[~neg_inf] < dtype_min):
        raise ValueError(
            f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}"
        )

    if (
        np.issubdtype(low.dtype, np.floating)
        and np.issubdtype(self.dtype, np.floating)
        and np.finfo(self.dtype).precision < np.finfo(low.dtype).precision
    ):
        gym.logger.warn(
            f"Box low's precision lowered by casting to {self.dtype}, current low.dtype={low.dtype}"
        )

    if has_neg_inf and self.dtype.kind == "i":  # signed int
        # Cast a copy with the infinite entries masked out, then write the dtype
        # minimum directly into the integer array. This leaves the caller's array
        # untouched and avoids casting `-np.inf` to an integer.
        casted_low = np.where(neg_inf, 0, low).astype(self.dtype)
        casted_low[neg_inf] = dtype_min
        return casted_low, bounded_below

    return low.astype(self.dtype), bounded_below
def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
    """Casts the input Box high value to ndarray with provided dtype.

    Scalar inputs are broadcast to ``self.shape``; array inputs are validated and
    cast to ``self.dtype``. The array passed in by the caller is never modified.

    Args:
        high: The input box high value (a scalar integer/float or an ``np.ndarray``)
        dtype_max: The dtype's maximum value

    Returns:
        The updated high value and for what values the input is bounded (above)

    Raises:
        ValueError: If ``high`` contains ``np.nan``, contains ``np.inf`` while the Box
            dtype is an unsigned int or bool, has a finite value above ``dtype_max``,
            or is neither a scalar nor an array of floating, integer, or bool dtype.
    """
    if is_float_integer(high):
        bounded_above = np.full(self.shape, high, dtype=float) < np.inf

        if np.isnan(high):
            raise ValueError(f"No high value can be equal to `np.nan`, high={high}")
        elif np.isposinf(high):
            if self.dtype.kind == "i":  # signed int
                high = dtype_max
            elif self.dtype.kind in {"u", "b"}:  # unsigned int and bool
                raise ValueError(
                    f"Box unsigned int dtype don't support `np.inf`, high={high}"
                )
        elif high > dtype_max:
            raise ValueError(
                f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}"
            )

        high = np.full(self.shape, high, dtype=self.dtype)
        return high, bounded_above

    # cast for high - array
    if not isinstance(high, np.ndarray):
        raise ValueError(
            f"Box high must be a np.ndarray, integer, or float, actual type={type(high)}"
        )
    elif not (
        np.issubdtype(high.dtype, np.floating)
        or np.issubdtype(high.dtype, np.integer)
        or high.dtype == np.bool_
    ):
        raise ValueError(
            f"Box high must be a floating, integer, or bool dtype, actual dtype={high.dtype}"
        )
    elif np.any(np.isnan(high)):
        raise ValueError(f"No high value can be equal to `np.nan`, high={high}")

    bounded_above = high < np.inf
    pos_inf = np.isposinf(high)
    has_pos_inf = bool(np.any(pos_inf))

    if has_pos_inf and self.dtype.kind in {"u", "b"}:  # unsigned int and bool
        raise ValueError(
            f"Box unsigned int dtype don't support `np.inf`, high={high}"
        )

    # Range-check every finite entry, even when some other entries are `np.inf`;
    # the infinite entries are excluded because they are handled separately below.
    if high.dtype != self.dtype and np.any(dtype_max < high[~pos_inf]):
        raise ValueError(
            f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}"
        )

    if (
        np.issubdtype(high.dtype, np.floating)
        and np.issubdtype(self.dtype, np.floating)
        and np.finfo(self.dtype).precision < np.finfo(high.dtype).precision
    ):
        gym.logger.warn(
            f"Box high's precision lowered by casting to {self.dtype}, current high.dtype={high.dtype}"
        )

    if has_pos_inf and self.dtype.kind == "i":  # signed int
        # Cast a copy with the infinite entries masked out, then write the dtype
        # maximum directly into the integer array. Writing it into the float array
        # first would round e.g. the int64 maximum up to 2**63, which is out of
        # range for the subsequent integer cast; it would also mutate the caller's
        # array.
        casted_high = np.where(pos_inf, 0, high).astype(self.dtype)
        casted_high[pos_inf] = dtype_max
        return casted_high, bounded_above

    return high.astype(self.dtype), bounded_above
@property
def shape(self) -> tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
@@ -232,7 +392,24 @@ class Box(Space[NDArray[Any]]):
if self.dtype.kind in ["i", "u", "b"]:
sample = np.floor(sample)
return sample.astype(self.dtype)
# clip values that would underflow/overflow
if np.issubdtype(self.dtype, np.signedinteger):
dtype_min = np.iinfo(self.dtype).min + 2
dtype_max = np.iinfo(self.dtype).max - 2
sample = sample.clip(min=dtype_min, max=dtype_max)
elif np.issubdtype(self.dtype, np.unsignedinteger):
dtype_min = np.iinfo(self.dtype).min
dtype_max = np.iinfo(self.dtype).max
sample = sample.clip(min=dtype_min, max=dtype_max)
sample = sample.astype(self.dtype)
# float64 values have lower than integer precision near int64 min/max, so clip
# again in case something has been cast to an out-of-bounds value
if self.dtype == np.int64:
sample = sample.clip(min=self.low, max=self.high)
return sample
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
@@ -285,53 +462,7 @@ class Box(Space[NDArray[Any]]):
# legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
if not hasattr(self, "low_repr"):
self.low_repr = _short_repr(self.low)
self.low_repr = array_short_repr(self.low)
if not hasattr(self, "high_repr"):
self.high_repr = _short_repr(self.high)
def get_precision(dtype: np.dtype) -> SupportsFloat:
"""Get precision of a data type."""
if np.issubdtype(dtype, np.floating):
return np.finfo(dtype).precision
else:
return np.inf
def _broadcast(
value: SupportsFloat | NDArray[Any],
dtype: np.dtype,
shape: tuple[int, ...],
) -> NDArray[Any]:
"""Handle infinite bounds and broadcast at the same time if needed.
This is needed primarily because:
>>> import numpy as np
>>> np.full((2,), np.inf, dtype=np.int32)
array([-2147483648, -2147483648], dtype=int32)
"""
if is_float_integer(value):
if np.isneginf(value) and np.dtype(dtype).kind == "i":
value = np.iinfo(dtype).min + 2
elif np.isposinf(value) and np.dtype(dtype).kind == "i":
value = np.iinfo(dtype).max - 2
return np.full(shape, value, dtype=dtype)
elif isinstance(value, np.ndarray):
# this is needed because we can't stuff np.iinfo(int).min into an array of dtype float
casted_value = value.astype(dtype)
# change bounds only if values are negative or positive infinite
if np.dtype(dtype).kind == "i":
casted_value[np.isneginf(value)] = np.iinfo(dtype).min + 2
casted_value[np.isposinf(value)] = np.iinfo(dtype).max - 2
return casted_value
else:
# only np.ndarray allowed beyond this point
raise TypeError(
f"Unknown dtype for `value`, expected `np.ndarray` or float/integer, got {type(value)}"
)
self.high_repr = array_short_repr(self.high)

View File

@@ -9,20 +9,54 @@ from gymnasium.spaces import Box
@pytest.mark.parametrize(
"box,expected_shape",
"dtype, error, message",
[
( # Test with same 1-dim low and high shape
Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32),
(2,),
(
None,
ValueError,
"Box dtype must be explicitly provided, cannot be None.",
),
( # Test with same multi-dim low and high shape
Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32),
(2, 1),
(0, TypeError, "Cannot interpret '0' as a data type"),
("unknown", TypeError, "data type 'unknown' not understood"),
(np.zeros(1), TypeError, "Cannot construct a dtype from an array"),
# disabled datatypes
(
np.complex64,
ValueError,
"Invalid Box dtype (complex64), must be an integer, floating, or bool dtype",
),
( # Test with scalar low high and different shape
Box(low=0, high=1, shape=(5, 2)),
(5, 2),
(
complex,
ValueError,
"Invalid Box dtype (complex128), must be an integer, floating, or bool dtype",
),
(
object,
ValueError,
"Invalid Box dtype (object), must be an integer, floating, or bool dtype",
),
(
str,
ValueError,
"Invalid Box dtype (<U0), must be an integer, floating, or bool dtype",
),
],
)
def test_dtype_errors(dtype, error, message):
"""Test errors due to dtype mismatch either to being invalid or disallowed."""
with pytest.raises(error, match=re.escape(message)):
Box(low=0, high=1, dtype=dtype)
@pytest.mark.parametrize(
"box, expected_shape",
[
# Test with same 1-dim low and high shape
(Box(low=np.zeros(2), high=np.ones(2)), (2,)),
# Test with same multi-dim low and high shape
(Box(low=np.zeros((2, 1)), high=np.ones((2, 1))), (2, 1)),
# Test with scalar low high and different shape
(Box(low=0, high=1, shape=(5, 2)), (5, 2)),
(Box(low=0, high=1), (1,)), # Test with int and int
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
@@ -39,112 +73,209 @@ def test_shape_inference(box, expected_shape):
@pytest.mark.parametrize(
"value,valid",
[
(1, True),
(1.0, True),
(np.int32(1), True),
(np.float32(1.0), True),
(np.zeros(2, dtype=np.float32), True),
(np.zeros((2, 2), dtype=np.float32), True),
(np.inf, True),
(np.nan, True), # This is a weird case that we allow
(True, False),
(np.bool_(True), False),
(1 + 1j, False),
(np.complex128(1 + 1j), False),
("string", False),
],
)
def test_low_high_values(value, valid: bool):
"""Test what `low` and `high` values are valid for `Box` space."""
if valid:
with warnings.catch_warnings(record=True) as caught_warnings:
Box(low=-np.inf, high=value)
assert len(caught_warnings) == 0, tuple(
warning.message for warning in caught_warnings
)
else:
with pytest.raises(
ValueError,
match=re.escape(
"expected their types to be np.ndarray, an integer or a float"
),
):
Box(low=-np.inf, high=value)
@pytest.mark.parametrize(
"low,high,kwargs,error,message",
"low, high, shape, error_type, message",
[
(
0,
1,
{"dtype": None},
AssertionError,
"Box dtype must be explicitly provided, cannot be None.",
1,
TypeError,
"Expected Box shape to be an iterable, actual type=<class 'int'>",
),
(
0,
1,
{"shape": (None,)},
AssertionError,
"Expected all shape elements to be an integer, actual type: (<class 'NoneType'>,)",
(None,),
TypeError,
"Expected all Box shape elements to be integer, actual type=(<class 'NoneType'>,)",
),
(
0,
1,
{
"shape": (
1,
None,
)
},
AssertionError,
"Expected all shape elements to be an integer, actual type: (<class 'int'>, <class 'NoneType'>)",
(1, None),
TypeError,
"Expected all Box shape elements to be integer, actual type=(<class 'int'>, <class 'NoneType'>)",
),
(
0,
1,
{
"shape": (
np.int64(1),
None,
)
},
AssertionError,
"Expected all shape elements to be an integer, actual type: (<class 'numpy.int64'>, <class 'NoneType'>)",
),
(
None,
None,
{},
ValueError,
"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: <class 'NoneType'>, high: <class 'NoneType'>",
),
(
0,
None,
{},
ValueError,
"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: <class 'int'>, high: <class 'NoneType'>",
(np.int64(1), None),
TypeError,
"Expected all Box shape elements to be integer, actual type=(<class 'numpy.int64'>, <class 'NoneType'>)",
),
(
np.zeros(3),
np.ones(2),
{},
AssertionError,
"high.shape doesn't match provided shape, high.shape: (2,), shape: (3,)",
None,
ValueError,
"Box low.shape and high.shape don't match, low.shape=(3,), high.shape=(2,)",
),
(
np.zeros(2),
np.ones(2),
(3,),
ValueError,
"Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)",
),
(
np.zeros(2),
1,
(3,),
ValueError,
"Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)",
),
(
0,
np.ones(2),
(3,),
ValueError,
"Box high.shape doesn't match provided shape, high.shape=(2,), shape=(3,)",
),
],
)
def test_init_errors(low, high, kwargs, error, message):
"""Test all constructor errors."""
with pytest.raises(error, match=f"^{re.escape(message)}$"):
Box(low=low, high=high, **kwargs)
def test_shape_errors(low, high, shape, error_type, message):
"""Test errors due to shape mismatch."""
with pytest.raises(error_type, match=f"^{re.escape(message)}$"):
Box(low=low, high=high, shape=shape)
def test_dtype_check():
@pytest.mark.parametrize(
    "low, high, dtype",
    [
        # floats
        (0, 65505.0, np.float16),
        (-65505.0, 0, np.float16),
        # signed int
        (0, 32768, np.int16),
        (-32769, 0, np.int16),
        # unsigned int
        (-1, 100, np.uint8),
        (0, 300, np.uint8),
        # boolean
        (-1, 1, np.bool_),
        (0, 2, np.bool_),
        # array inputs (1-dim and 2-dim, low too small then high too large)
        (np.array([-1, 0]), np.array([0, 100]), np.uint8),
        (np.array([[-1], [0]]), np.array([[0], [100]]), np.uint8),
        (np.array([0, 0]), np.array([0, 300]), np.uint8),
        (np.array([[0], [0]]), np.array([[0], [300]]), np.uint8),
    ],
)
def test_out_of_bounds_error(low, high, dtype):
    """Test that a bound outside the range of the requested dtype raises a ValueError."""
    expected_message = re.escape("is out of bounds of the dtype range,")

    with pytest.raises(ValueError, match=expected_message):
        Box(low=low, high=high, dtype=dtype)
# Fragments of the two error messages that `Box` is expected to raise below.
_NAN_BOUND_MESSAGE = "value can be equal to `np.nan`,"
_UNSUPPORTED_INF_MESSAGE = "Box unsigned int dtype don't support"


@pytest.mark.parametrize(
    "low, high, dtype, message",
    [
        # Floats
        (np.nan, 0, np.float32, _NAN_BOUND_MESSAGE),
        (0, np.nan, np.float32, _NAN_BOUND_MESSAGE),
        (np.array([0, np.nan]), np.ones(2), np.float32, _NAN_BOUND_MESSAGE),
        # Signed ints
        (np.nan, 0, np.int32, _NAN_BOUND_MESSAGE),
        (0, np.nan, np.int32, _NAN_BOUND_MESSAGE),
        (np.array([0, np.nan]), np.ones(2), np.int32, _NAN_BOUND_MESSAGE),
        # Unsigned ints
        (np.nan, 0, np.uint8, _NAN_BOUND_MESSAGE),
        (0, np.nan, np.uint8, _NAN_BOUND_MESSAGE),
        (np.array([0, np.nan]), np.ones(2), np.uint8, _NAN_BOUND_MESSAGE),
        (-np.inf, 1, np.uint8, _UNSUPPORTED_INF_MESSAGE),
        (np.array([-np.inf, 0]), 1, np.uint8, _UNSUPPORTED_INF_MESSAGE),
        (0, np.inf, np.uint8, _UNSUPPORTED_INF_MESSAGE),
        (0, np.array([1, np.inf]), np.uint8, _UNSUPPORTED_INF_MESSAGE),
        # boolean
        (-np.inf, 1, np.bool_, _UNSUPPORTED_INF_MESSAGE),
        (0, np.inf, np.bool_, _UNSUPPORTED_INF_MESSAGE),
    ],
)
def test_invalid_low_high(low, high, dtype, message):
    """Test that NaN bounds, and infinite bounds for unsigned int / bool dtypes, are rejected.

    The expected message is given per case instead of being derived from the dtype,
    so that NaN bounds can also be checked for the unsigned dtypes.
    """
    with pytest.raises(ValueError, match=re.escape(message)):
        Box(low=low, high=high, dtype=dtype)
@pytest.mark.parametrize(
    "low, high, dtype",
    [
        # floats
        (0, 1, float),
        (0, 1, np.float64),
        (0, 1, np.float32),
        (0, 1, np.float16),
        (np.zeros(2), np.ones(2), np.float32),
        (np.zeros(2), 1, np.float32),
        (-np.inf, 1, np.float32),
        (np.array([-np.inf, 0]), 1, np.float32),
        (0, np.inf, np.float32),
        (0, np.array([np.inf, 1]), np.float32),
        (-np.inf, np.inf, np.float32),
        (np.full((2,), -np.inf), np.full((2,), np.inf), np.float32),
        # signed ints
        (0, 1, int),
        (0, 1, np.int64),
        (0, 1, np.int32),
        (0, 1, np.int16),
        (0, 1, np.int8),
        (np.zeros(2), np.ones(2), np.int32),
        (np.zeros(2), 1, np.int32),
        (-np.inf, 1, np.int32),
        (np.array([-np.inf, 0]), 1, np.int32),
        (0, np.inf, np.int32),
        (0, np.array([np.inf, 1]), np.int32),
        # unsigned ints
        (0, 1, np.uint64),
        (0, 1, np.uint32),
        (0, 1, np.uint16),
        (0, 1, np.uint8),
        # boolean
        (0, 1, np.bool_),
    ],
)
def test_valid_low_high(low, high, dtype):
    """Test that valid low/high/dtype combinations build a Box that samples values of that dtype.

    The only warning tolerated while doing so is the notice about precision being
    lowered by casting to float32; any other recorded warning fails the test.
    """
    tolerated_warning = "precision lowered by casting to float32"

    with warnings.catch_warnings(record=True) as caught_warnings:
        space = Box(low=low, high=high, dtype=dtype)

        # the space and both of its bounds must carry the requested dtype
        assert space.dtype == dtype
        assert space.low.dtype == dtype
        assert space.high.dtype == dtype

        # a seeded sample must have the requested dtype and lie within the space
        space.seed(0)
        sample = space.sample()
        assert sample.dtype == dtype
        assert space.contains(sample)

    for caught in caught_warnings:
        warning_text = caught.message.args[0]
        if tolerated_warning in warning_text:
            continue
        raise Exception(caught)
def test_contains_dtype():
"""Tests the Box contains function with different dtypes."""
# Related Issues:
# https://github.com/openai/gym/issues/2357
@@ -163,102 +294,54 @@ def test_dtype_check():
@pytest.mark.parametrize(
"space",
"lowhighshape",
[
Box(low=0, high=np.inf, shape=(2,), dtype=np.int32),
Box(low=0, high=np.inf, shape=(2,), dtype=np.float32),
Box(low=0, high=np.inf, shape=(2,), dtype=np.int64),
Box(low=0, high=np.inf, shape=(2,), dtype=np.float64),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64),
dict(low=0, high=np.inf, shape=(2,)),
dict(low=-np.inf, high=0, shape=(2,)),
dict(low=-np.inf, high=np.inf, shape=(2,)),
dict(low=0, high=np.inf, shape=(2, 3)),
dict(low=-np.inf, high=0, shape=(2, 3)),
dict(low=-np.inf, high=np.inf, shape=(2, 3)),
dict(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf])),
],
)
def test_infinite_space(space):
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_infinite_space(lowhighshape, dtype):
"""
To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
are both modified within the init, we check for infinite when we know it's not 0
"""
space = Box(**lowhighshape, dtype=dtype)
assert np.all(
space.low < space.high
), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})"
assert np.all(space.low < space.high)
space.seed(0)
sample = space.sample()
# check if space contains sample
assert (
sample in space
), f"Sample ({sample}) not inside space according to `space.contains()`"
# manually check that the sign of the sample is within the bounds
assert np.all(
np.sign(sample) <= np.sign(space.high)
), f"Sign of sample ({sample}) is less than space upper bound ({space.high})"
assert np.all(
np.sign(space.low) <= np.sign(sample)
), f"Sign of sample ({sample}) is more than space lower bound ({space.low})"
# check that int bounds are bounded for everything
# but floats are unbounded for infinite
if np.any(space.high != 0):
assert (
space.is_bounded("above") is False
), "inf upper bound supposed to be unbounded"
else:
assert (
space.is_bounded("above") is True
), "non-inf upper bound supposed to be bounded"
if np.any(space.low != 0):
assert (
space.is_bounded("below") is False
), "inf lower bound supposed to be unbounded"
else:
assert (
space.is_bounded("below") is True
), "non-inf lower bound supposed to be bounded"
if np.any(space.low != 0) or np.any(space.high != 0):
assert space.is_bounded("both") is False
else:
assert space.is_bounded("both") is True
# check that int bounds are bounded for everything but floats are unbounded for infinite
assert space.is_bounded("above") is not np.any(space.high != 0)
assert space.is_bounded("below") is not np.any(space.low != 0)
assert space.is_bounded("both") is not (
np.any(space.high != 0) | np.any(space.high != 0)
)
# check for dtype
assert (
space.high.dtype == space.dtype
), f"High's dtype {space.high.dtype} doesn't match `space.dtype`'"
assert (
space.low.dtype == space.dtype
), f"Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
assert space.high.dtype == space.dtype
assert space.low.dtype == space.dtype
with pytest.raises(
ValueError, match="manner is not in {'below', 'above', 'both'}, actual value:"
):
space.is_bounded("test")
# Check sample
space.seed(0)
sample = space.sample()
# check if space contains sample
assert sample in space
# manually check that the sign of the sample is within the bounds
assert np.all(np.sign(sample) <= np.sign(space.high))
assert np.all(np.sign(space.low) <= np.sign(sample))
def test_legacy_state_pickling():
legacy_state = {
@@ -290,56 +373,3 @@ def test_sample_mask():
match=re.escape("Box.sample cannot be provided a mask, actual value: "),
):
space.sample(mask=np.array([0, 1, 0], dtype=np.int8))
@pytest.mark.parametrize(
"low, high, shape, dtype, reason",
[
(
5.0,
3.0,
(),
np.float32,
"Some low values are greater than high, low=5.0, high=3.0",
),
(
np.array([5.0, 6.0]),
np.array([1.0, 5.99]),
(2,),
np.float32,
"Some low values are greater than high, low=[5. 6.], high=[1. 5.99]",
),
(
np.inf,
np.inf,
(),
np.float32,
"No low value can be equal to `np.inf`, low=inf",
),
(
np.array([0, np.inf]),
np.array([np.inf, np.inf]),
(2,),
np.float32,
"No low value can be equal to `np.inf`, low=[ 0. inf]",
),
(
-np.inf,
-np.inf,
(),
np.float32,
"No high value can be equal to `-np.inf`, high=-inf",
),
(
np.array([-np.inf, -np.inf]),
np.array([0, -np.inf]),
(2,),
np.float32,
"No high value can be equal to `-np.inf`, high=[ 0. -inf]",
),
],
)
def test_invalid_low_high(low, high, dtype, shape, reason):
"""Tests that we don't allow spaces with degenerate bounds, such as `Box(np.inf, -np.inf)`."""
with pytest.raises(ValueError, match=re.escape(reason)):
Box(low=low, high=high, dtype=dtype, shape=shape)

View File

@@ -45,7 +45,7 @@ def custom_environments():
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
("CartPole-v1", "DtypeObservation", {"dtype": np.int32}),
("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}),
# ("CartPole-v1", "RenderObservation", {}), # not implemented
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
# ("CartPole-v1", "FrameStackObservation", {}), # not implemented