mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 09:37:29 +00:00
Fix Box casting and sampling edge-cases (#774)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
144feb865a
commit
89bedf1f33
@@ -10,7 +10,7 @@ import gymnasium as gym
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
|
||||
def _short_repr(arr: NDArray[Any]) -> str:
|
||||
def array_short_repr(arr: NDArray[Any]) -> str:
|
||||
"""Create a shortened string representation of a numpy array.
|
||||
|
||||
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
||||
@@ -28,7 +28,7 @@ def _short_repr(arr: NDArray[Any]) -> str:
|
||||
|
||||
|
||||
def is_float_integer(var: Any) -> bool:
|
||||
"""Checks if a variable is an integer or float."""
|
||||
"""Checks if a scalar variable is an integer or float (does not include bool)."""
|
||||
return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating)
|
||||
|
||||
|
||||
@@ -80,72 +80,232 @@ class Box(Space[NDArray[Any]]):
|
||||
ValueError: If no shape information is provided (shape is None, low is None and high is None) then a
|
||||
value error is raised.
|
||||
"""
|
||||
assert (
|
||||
dtype is not None
|
||||
), "Box dtype must be explicitly provided, cannot be None."
|
||||
# determine dtype
|
||||
if dtype is None:
|
||||
raise ValueError("Box dtype must be explicitly provided, cannot be None.")
|
||||
self.dtype = np.dtype(dtype)
|
||||
|
||||
# determine shape if it isn't provided directly
|
||||
# * check that dtype is an accepted dtype
|
||||
if not (
|
||||
np.issubdtype(self.dtype, np.integer)
|
||||
or np.issubdtype(self.dtype, np.floating)
|
||||
or self.dtype == np.bool_
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid Box dtype ({self.dtype}), must be an integer, floating, or bool dtype"
|
||||
)
|
||||
|
||||
# determine shape
|
||||
if shape is not None:
|
||||
assert all(
|
||||
np.issubdtype(type(dim), np.integer) for dim in shape
|
||||
), f"Expected all shape elements to be an integer, actual type: {tuple(type(dim) for dim in shape)}"
|
||||
shape = tuple(int(dim) for dim in shape) # This changes any np types to int
|
||||
if not isinstance(shape, Iterable):
|
||||
raise TypeError(
|
||||
f"Expected Box shape to be an iterable, actual type={type(shape)}"
|
||||
)
|
||||
elif not all(np.issubdtype(type(dim), np.integer) for dim in shape):
|
||||
raise TypeError(
|
||||
f"Expected all Box shape elements to be integer, actual type={tuple(type(dim) for dim in shape)}"
|
||||
)
|
||||
|
||||
# Casts the `shape` argument to tuple[int, ...] (otherwise dim can `np.int64`)
|
||||
shape = tuple(int(dim) for dim in shape)
|
||||
elif isinstance(low, np.ndarray) and isinstance(high, np.ndarray):
|
||||
if low.shape != high.shape:
|
||||
raise ValueError(
|
||||
f"Box low.shape and high.shape don't match, low.shape={low.shape}, high.shape={high.shape}"
|
||||
)
|
||||
shape = low.shape
|
||||
elif isinstance(low, np.ndarray):
|
||||
shape = low.shape
|
||||
elif isinstance(high, np.ndarray):
|
||||
shape = high.shape
|
||||
elif is_float_integer(low) and is_float_integer(high):
|
||||
shape = (1,)
|
||||
shape = (1,) # low and high are scalars
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: {type(low)}, high: {type(high)}"
|
||||
"Box shape is not specified, therefore inferred from low and high. Expected low and high to be np.ndarray, integer, or float."
|
||||
f"Actual types low={type(low)}, high={type(high)}"
|
||||
)
|
||||
|
||||
# Capture the boundedness information before replacing np.inf with get_inf
|
||||
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
|
||||
self.bounded_below: NDArray[np.bool_] = -np.inf < _low
|
||||
|
||||
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
||||
self.bounded_above: NDArray[np.bool_] = np.inf > _high
|
||||
|
||||
low = _broadcast(low, self.dtype, shape)
|
||||
high = _broadcast(high, self.dtype, shape)
|
||||
|
||||
assert isinstance(low, np.ndarray)
|
||||
assert (
|
||||
low.shape == shape
|
||||
), f"low.shape doesn't match provided shape, low.shape: {low.shape}, shape: {shape}"
|
||||
assert isinstance(high, np.ndarray)
|
||||
assert (
|
||||
high.shape == shape
|
||||
), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"
|
||||
|
||||
# check that we don't have invalid low or high
|
||||
if np.any(low > high):
|
||||
raise ValueError(
|
||||
f"Some low values are greater than high, low={low}, high={high}"
|
||||
)
|
||||
if np.any(np.isposinf(low)):
|
||||
raise ValueError(f"No low value can be equal to `np.inf`, low={low}")
|
||||
if np.any(np.isneginf(high)):
|
||||
raise ValueError(f"No high value can be equal to `-np.inf`, high={high}")
|
||||
|
||||
self._shape: tuple[int, ...] = shape
|
||||
|
||||
low_precision = get_precision(low.dtype)
|
||||
high_precision = get_precision(high.dtype)
|
||||
dtype_precision = get_precision(self.dtype)
|
||||
if min(low_precision, high_precision) > dtype_precision:
|
||||
gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
|
||||
self.low = low.astype(self.dtype)
|
||||
self.high = high.astype(self.dtype)
|
||||
# Cast scalar values to `np.ndarray` and capture the boundedness information
|
||||
# disallowed cases
|
||||
# * out of range - this must be done before casting to low and high otherwise, the value is within dtype and cannot be out of range
|
||||
# * nan - must be done beforehand as int dtype can cast `nan` to another value
|
||||
# * unsign int inf and -inf - special case that is disallowed
|
||||
|
||||
self.low_repr = _short_repr(self.low)
|
||||
self.high_repr = _short_repr(self.high)
|
||||
if self.dtype == np.bool_:
|
||||
dtype_min, dtype_max = 0, 1
|
||||
elif np.issubdtype(self.dtype, np.floating):
|
||||
dtype_min = float(np.finfo(self.dtype).min)
|
||||
dtype_max = float(np.finfo(self.dtype).max)
|
||||
else:
|
||||
dtype_min = int(np.iinfo(self.dtype).min)
|
||||
dtype_max = int(np.iinfo(self.dtype).max)
|
||||
|
||||
# Cast `low` and `high` to ndarray for the dtype min and max for out of range tests
|
||||
self.low, self.bounded_below = self._cast_low(low, dtype_min)
|
||||
self.high, self.bounded_above = self._cast_high(high, dtype_max)
|
||||
|
||||
# recheck shape for case where shape and (low or high) are provided
|
||||
if self.low.shape != shape:
|
||||
raise ValueError(
|
||||
f"Box low.shape doesn't match provided shape, low.shape={self.low.shape}, shape={self.shape}"
|
||||
)
|
||||
if self.high.shape != shape:
|
||||
raise ValueError(
|
||||
f"Box high.shape doesn't match provided shape, high.shape={self.high.shape}, shape={self.shape}"
|
||||
)
|
||||
|
||||
# check that low <= high
|
||||
if np.any(self.low > self.high):
|
||||
raise ValueError(
|
||||
f"Box all low values must be less than or equal to high (some values break this), low={self.low}, high={self.high}"
|
||||
)
|
||||
|
||||
self.low_repr = array_short_repr(self.low)
|
||||
self.high_repr = array_short_repr(self.high)
|
||||
|
||||
super().__init__(self.shape, self.dtype, seed)
|
||||
|
||||
def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Casts the input Box low value to ndarray with provided dtype.
|
||||
|
||||
Args:
|
||||
low: The input box low value
|
||||
dtype_min: The dtype's minimum value
|
||||
|
||||
Returns:
|
||||
The updated low value and for what values the input is bounded (below)
|
||||
"""
|
||||
if is_float_integer(low):
|
||||
bounded_below = -np.inf < np.full(self.shape, low, dtype=float)
|
||||
|
||||
if np.isnan(low):
|
||||
raise ValueError(f"No low value can be equal to `np.nan`, low={low}")
|
||||
elif np.isneginf(low):
|
||||
if self.dtype.kind == "i": # signed int
|
||||
low = dtype_min
|
||||
elif self.dtype.kind in {"u", "b"}: # unsigned int and bool
|
||||
raise ValueError(
|
||||
f"Box unsigned int dtype don't support `-np.inf`, low={low}"
|
||||
)
|
||||
elif low < dtype_min:
|
||||
raise ValueError(
|
||||
f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}"
|
||||
)
|
||||
|
||||
low = np.full(self.shape, low, dtype=self.dtype)
|
||||
return low, bounded_below
|
||||
else: # cast for low - array
|
||||
if not isinstance(low, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Box low must be a np.ndarray, integer, or float, actual type={type(low)}"
|
||||
)
|
||||
elif not (
|
||||
np.issubdtype(low.dtype, np.floating)
|
||||
or np.issubdtype(low.dtype, np.integer)
|
||||
or low.dtype == np.bool_
|
||||
):
|
||||
raise ValueError(
|
||||
f"Box low must be a floating, integer, or bool dtype, actual dtype={low.dtype}"
|
||||
)
|
||||
elif np.any(np.isnan(low)):
|
||||
raise ValueError(f"No low value can be equal to `np.nan`, low={low}")
|
||||
|
||||
bounded_below = -np.inf < low
|
||||
|
||||
if np.any(np.isneginf(low)):
|
||||
if self.dtype.kind == "i": # signed int
|
||||
low[np.isneginf(low)] = dtype_min
|
||||
elif self.dtype.kind in {"u", "b"}: # unsigned int and bool
|
||||
raise ValueError(
|
||||
f"Box unsigned int dtype don't support `-np.inf`, low={low}"
|
||||
)
|
||||
elif low.dtype != self.dtype and np.any(low < dtype_min):
|
||||
raise ValueError(
|
||||
f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}"
|
||||
)
|
||||
|
||||
if (
|
||||
np.issubdtype(low.dtype, np.floating)
|
||||
and np.issubdtype(self.dtype, np.floating)
|
||||
and np.finfo(self.dtype).precision < np.finfo(low.dtype).precision
|
||||
):
|
||||
gym.logger.warn(
|
||||
f"Box low's precision lowered by casting to {self.dtype}, current low.dtype={low.dtype}"
|
||||
)
|
||||
return low.astype(self.dtype), bounded_below
|
||||
|
||||
def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Casts the input Box high value to ndarray with provided dtype.
|
||||
|
||||
Args:
|
||||
high: The input box high value
|
||||
dtype_max: The dtype's maximum value
|
||||
|
||||
Returns:
|
||||
The updated high value and for what values the input is bounded (above)
|
||||
"""
|
||||
if is_float_integer(high):
|
||||
bounded_above = np.full(self.shape, high, dtype=float) < np.inf
|
||||
|
||||
if np.isnan(high):
|
||||
raise ValueError(f"No high value can be equal to `np.nan`, high={high}")
|
||||
elif np.isposinf(high):
|
||||
if self.dtype.kind == "i": # signed int
|
||||
high = dtype_max
|
||||
elif self.dtype.kind in {"u", "b"}: # unsigned int
|
||||
raise ValueError(
|
||||
f"Box unsigned int dtype don't support `np.inf`, high={high}"
|
||||
)
|
||||
elif high > dtype_max:
|
||||
raise ValueError(
|
||||
f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}"
|
||||
)
|
||||
|
||||
high = np.full(self.shape, high, dtype=self.dtype)
|
||||
return high, bounded_above
|
||||
else:
|
||||
if not isinstance(high, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Box high must be a np.ndarray, integer, or float, actual type={type(high)}"
|
||||
)
|
||||
elif not (
|
||||
np.issubdtype(high.dtype, np.floating)
|
||||
or np.issubdtype(high.dtype, np.integer)
|
||||
or high.dtype == np.bool_
|
||||
):
|
||||
raise ValueError(
|
||||
f"Box high must be a floating or integer dtype, actual dtype={high.dtype}"
|
||||
)
|
||||
elif np.any(np.isnan(high)):
|
||||
raise ValueError(f"No high value can be equal to `np.nan`, high={high}")
|
||||
|
||||
bounded_above = high < np.inf
|
||||
|
||||
posinf = np.isposinf(high)
|
||||
if np.any(posinf):
|
||||
if self.dtype.kind == "i": # signed int
|
||||
high[posinf] = dtype_max
|
||||
elif self.dtype.kind in {"u", "b"}: # unsigned int
|
||||
raise ValueError(
|
||||
f"Box unsigned int dtype don't support `np.inf`, high={high}"
|
||||
)
|
||||
elif high.dtype != self.dtype and np.any(dtype_max < high):
|
||||
raise ValueError(
|
||||
f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}"
|
||||
)
|
||||
|
||||
if (
|
||||
np.issubdtype(high.dtype, np.floating)
|
||||
and np.issubdtype(self.dtype, np.floating)
|
||||
and np.finfo(self.dtype).precision < np.finfo(high.dtype).precision
|
||||
):
|
||||
gym.logger.warn(
|
||||
f"Box high's precision lowered by casting to {self.dtype}, current high.dtype={high.dtype}"
|
||||
)
|
||||
return high.astype(self.dtype), bounded_above
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
@@ -232,7 +392,24 @@ class Box(Space[NDArray[Any]]):
|
||||
if self.dtype.kind in ["i", "u", "b"]:
|
||||
sample = np.floor(sample)
|
||||
|
||||
return sample.astype(self.dtype)
|
||||
# clip values that would underflow/overflow
|
||||
if np.issubdtype(self.dtype, np.signedinteger):
|
||||
dtype_min = np.iinfo(self.dtype).min + 2
|
||||
dtype_max = np.iinfo(self.dtype).max - 2
|
||||
sample = sample.clip(min=dtype_min, max=dtype_max)
|
||||
elif np.issubdtype(self.dtype, np.unsignedinteger):
|
||||
dtype_min = np.iinfo(self.dtype).min
|
||||
dtype_max = np.iinfo(self.dtype).max
|
||||
sample = sample.clip(min=dtype_min, max=dtype_max)
|
||||
|
||||
sample = sample.astype(self.dtype)
|
||||
|
||||
# float64 values have lower than integer precision near int64 min/max, so clip
|
||||
# again in case something has been cast to an out-of-bounds value
|
||||
if self.dtype == np.int64:
|
||||
sample = sample.clip(min=self.low, max=self.high)
|
||||
|
||||
return sample
|
||||
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
@@ -285,53 +462,7 @@ class Box(Space[NDArray[Any]]):
|
||||
|
||||
# legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
|
||||
if not hasattr(self, "low_repr"):
|
||||
self.low_repr = _short_repr(self.low)
|
||||
self.low_repr = array_short_repr(self.low)
|
||||
|
||||
if not hasattr(self, "high_repr"):
|
||||
self.high_repr = _short_repr(self.high)
|
||||
|
||||
|
||||
def get_precision(dtype: np.dtype) -> SupportsFloat:
|
||||
"""Get precision of a data type."""
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
return np.finfo(dtype).precision
|
||||
else:
|
||||
return np.inf
|
||||
|
||||
|
||||
def _broadcast(
|
||||
value: SupportsFloat | NDArray[Any],
|
||||
dtype: np.dtype,
|
||||
shape: tuple[int, ...],
|
||||
) -> NDArray[Any]:
|
||||
"""Handle infinite bounds and broadcast at the same time if needed.
|
||||
|
||||
This is needed primarily because:
|
||||
>>> import numpy as np
|
||||
>>> np.full((2,), np.inf, dtype=np.int32)
|
||||
array([-2147483648, -2147483648], dtype=int32)
|
||||
"""
|
||||
if is_float_integer(value):
|
||||
if np.isneginf(value) and np.dtype(dtype).kind == "i":
|
||||
value = np.iinfo(dtype).min + 2
|
||||
elif np.isposinf(value) and np.dtype(dtype).kind == "i":
|
||||
value = np.iinfo(dtype).max - 2
|
||||
|
||||
return np.full(shape, value, dtype=dtype)
|
||||
|
||||
elif isinstance(value, np.ndarray):
|
||||
# this is needed because we can't stuff np.iinfo(int).min into an array of dtype float
|
||||
casted_value = value.astype(dtype)
|
||||
|
||||
# change bounds only if values are negative or positive infinite
|
||||
if np.dtype(dtype).kind == "i":
|
||||
casted_value[np.isneginf(value)] = np.iinfo(dtype).min + 2
|
||||
casted_value[np.isposinf(value)] = np.iinfo(dtype).max - 2
|
||||
|
||||
return casted_value
|
||||
|
||||
else:
|
||||
# only np.ndarray allowed beyond this point
|
||||
raise TypeError(
|
||||
f"Unknown dtype for `value`, expected `np.ndarray` or float/integer, got {type(value)}"
|
||||
)
|
||||
self.high_repr = array_short_repr(self.high)
|
||||
|
@@ -9,20 +9,54 @@ from gymnasium.spaces import Box
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"box,expected_shape",
|
||||
"dtype, error, message",
|
||||
[
|
||||
( # Test with same 1-dim low and high shape
|
||||
Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32),
|
||||
(2,),
|
||||
(
|
||||
None,
|
||||
ValueError,
|
||||
"Box dtype must be explicitly provided, cannot be None.",
|
||||
),
|
||||
( # Test with same multi-dim low and high shape
|
||||
Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32),
|
||||
(2, 1),
|
||||
(0, TypeError, "Cannot interpret '0' as a data type"),
|
||||
("unknown", TypeError, "data type 'unknown' not understood"),
|
||||
(np.zeros(1), TypeError, "Cannot construct a dtype from an array"),
|
||||
# disabled datatypes
|
||||
(
|
||||
np.complex64,
|
||||
ValueError,
|
||||
"Invalid Box dtype (complex64), must be an integer, floating, or bool dtype",
|
||||
),
|
||||
( # Test with scalar low high and different shape
|
||||
Box(low=0, high=1, shape=(5, 2)),
|
||||
(5, 2),
|
||||
(
|
||||
complex,
|
||||
ValueError,
|
||||
"Invalid Box dtype (complex128), must be an integer, floating, or bool dtype",
|
||||
),
|
||||
(
|
||||
object,
|
||||
ValueError,
|
||||
"Invalid Box dtype (object), must be an integer, floating, or bool dtype",
|
||||
),
|
||||
(
|
||||
str,
|
||||
ValueError,
|
||||
"Invalid Box dtype (<U0), must be an integer, floating, or bool dtype",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_dtype_errors(dtype, error, message):
|
||||
"""Test errors due to dtype mismatch either to being invalid or disallowed."""
|
||||
with pytest.raises(error, match=re.escape(message)):
|
||||
Box(low=0, high=1, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"box, expected_shape",
|
||||
[
|
||||
# Test with same 1-dim low and high shape
|
||||
(Box(low=np.zeros(2), high=np.ones(2)), (2,)),
|
||||
# Test with same multi-dim low and high shape
|
||||
(Box(low=np.zeros((2, 1)), high=np.ones((2, 1))), (2, 1)),
|
||||
# Test with scalar low high and different shape
|
||||
(Box(low=0, high=1, shape=(5, 2)), (5, 2)),
|
||||
(Box(low=0, high=1), (1,)), # Test with int and int
|
||||
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
|
||||
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
|
||||
@@ -39,112 +73,209 @@ def test_shape_inference(box, expected_shape):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,valid",
|
||||
[
|
||||
(1, True),
|
||||
(1.0, True),
|
||||
(np.int32(1), True),
|
||||
(np.float32(1.0), True),
|
||||
(np.zeros(2, dtype=np.float32), True),
|
||||
(np.zeros((2, 2), dtype=np.float32), True),
|
||||
(np.inf, True),
|
||||
(np.nan, True), # This is a weird case that we allow
|
||||
(True, False),
|
||||
(np.bool_(True), False),
|
||||
(1 + 1j, False),
|
||||
(np.complex128(1 + 1j), False),
|
||||
("string", False),
|
||||
],
|
||||
)
|
||||
def test_low_high_values(value, valid: bool):
|
||||
"""Test what `low` and `high` values are valid for `Box` space."""
|
||||
if valid:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
Box(low=-np.inf, high=value)
|
||||
assert len(caught_warnings) == 0, tuple(
|
||||
warning.message for warning in caught_warnings
|
||||
)
|
||||
else:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"expected their types to be np.ndarray, an integer or a float"
|
||||
),
|
||||
):
|
||||
Box(low=-np.inf, high=value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"low,high,kwargs,error,message",
|
||||
"low, high, shape, error_type, message",
|
||||
[
|
||||
(
|
||||
0,
|
||||
1,
|
||||
{"dtype": None},
|
||||
AssertionError,
|
||||
"Box dtype must be explicitly provided, cannot be None.",
|
||||
1,
|
||||
TypeError,
|
||||
"Expected Box shape to be an iterable, actual type=<class 'int'>",
|
||||
),
|
||||
(
|
||||
0,
|
||||
1,
|
||||
{"shape": (None,)},
|
||||
AssertionError,
|
||||
"Expected all shape elements to be an integer, actual type: (<class 'NoneType'>,)",
|
||||
(None,),
|
||||
TypeError,
|
||||
"Expected all Box shape elements to be integer, actual type=(<class 'NoneType'>,)",
|
||||
),
|
||||
(
|
||||
0,
|
||||
1,
|
||||
{
|
||||
"shape": (
|
||||
1,
|
||||
None,
|
||||
)
|
||||
},
|
||||
AssertionError,
|
||||
"Expected all shape elements to be an integer, actual type: (<class 'int'>, <class 'NoneType'>)",
|
||||
(1, None),
|
||||
TypeError,
|
||||
"Expected all Box shape elements to be integer, actual type=(<class 'int'>, <class 'NoneType'>)",
|
||||
),
|
||||
(
|
||||
0,
|
||||
1,
|
||||
{
|
||||
"shape": (
|
||||
np.int64(1),
|
||||
None,
|
||||
)
|
||||
},
|
||||
AssertionError,
|
||||
"Expected all shape elements to be an integer, actual type: (<class 'numpy.int64'>, <class 'NoneType'>)",
|
||||
),
|
||||
(
|
||||
None,
|
||||
None,
|
||||
{},
|
||||
ValueError,
|
||||
"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: <class 'NoneType'>, high: <class 'NoneType'>",
|
||||
),
|
||||
(
|
||||
0,
|
||||
None,
|
||||
{},
|
||||
ValueError,
|
||||
"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: <class 'int'>, high: <class 'NoneType'>",
|
||||
(np.int64(1), None),
|
||||
TypeError,
|
||||
"Expected all Box shape elements to be integer, actual type=(<class 'numpy.int64'>, <class 'NoneType'>)",
|
||||
),
|
||||
(
|
||||
np.zeros(3),
|
||||
np.ones(2),
|
||||
{},
|
||||
AssertionError,
|
||||
"high.shape doesn't match provided shape, high.shape: (2,), shape: (3,)",
|
||||
None,
|
||||
ValueError,
|
||||
"Box low.shape and high.shape don't match, low.shape=(3,), high.shape=(2,)",
|
||||
),
|
||||
(
|
||||
np.zeros(2),
|
||||
np.ones(2),
|
||||
(3,),
|
||||
ValueError,
|
||||
"Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)",
|
||||
),
|
||||
(
|
||||
np.zeros(2),
|
||||
1,
|
||||
(3,),
|
||||
ValueError,
|
||||
"Box low.shape doesn't match provided shape, low.shape=(2,), shape=(3,)",
|
||||
),
|
||||
(
|
||||
0,
|
||||
np.ones(2),
|
||||
(3,),
|
||||
ValueError,
|
||||
"Box high.shape doesn't match provided shape, high.shape=(2,), shape=(3,)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_init_errors(low, high, kwargs, error, message):
|
||||
"""Test all constructor errors."""
|
||||
with pytest.raises(error, match=f"^{re.escape(message)}$"):
|
||||
Box(low=low, high=high, **kwargs)
|
||||
def test_shape_errors(low, high, shape, error_type, message):
|
||||
"""Test errors due to shape mismatch."""
|
||||
with pytest.raises(error_type, match=f"^{re.escape(message)}$"):
|
||||
Box(low=low, high=high, shape=shape)
|
||||
|
||||
|
||||
def test_dtype_check():
|
||||
@pytest.mark.parametrize(
|
||||
"low, high, dtype",
|
||||
[
|
||||
# floats
|
||||
(0, 65505.0, np.float16),
|
||||
(-65505.0, 0, np.float16),
|
||||
# signed int
|
||||
(0, 32768, np.int16),
|
||||
(-32769, 0, np.int16),
|
||||
# unsigned int
|
||||
(-1, 100, np.uint8),
|
||||
(0, 300, np.uint8),
|
||||
# boolean
|
||||
(-1, 1, np.bool_),
|
||||
(0, 2, np.bool_),
|
||||
# array inputs
|
||||
(
|
||||
np.array([-1, 0]),
|
||||
np.array([0, 100]),
|
||||
np.uint8,
|
||||
),
|
||||
(
|
||||
np.array([[-1], [0]]),
|
||||
np.array([[0], [100]]),
|
||||
np.uint8,
|
||||
),
|
||||
(
|
||||
np.array([0, 0]),
|
||||
np.array([0, 300]),
|
||||
np.uint8,
|
||||
),
|
||||
(
|
||||
np.array([[0], [0]]),
|
||||
np.array([[0], [300]]),
|
||||
np.uint8,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_out_of_bounds_error(low, high, dtype):
|
||||
with pytest.raises(
|
||||
ValueError, match=re.escape("is out of bounds of the dtype range,")
|
||||
):
|
||||
Box(low=low, high=high, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"low, high, dtype",
|
||||
[
|
||||
# Floats
|
||||
(np.nan, 0, np.float32),
|
||||
(0, np.nan, np.float32),
|
||||
(np.array([0, np.nan]), np.ones(2), np.float32),
|
||||
# Signed ints
|
||||
(np.nan, 0, np.int32),
|
||||
(0, np.nan, np.int32),
|
||||
(np.array([0, np.nan]), np.ones(2), np.int32),
|
||||
# Unsigned ints
|
||||
# (np.nan, 0, np.uint8),
|
||||
# (0, np.nan, np.uint8),
|
||||
# (np.array([0, np.nan]), np.ones(2), np.uint8),
|
||||
(-np.inf, 1, np.uint8),
|
||||
(np.array([-np.inf, 0]), 1, np.uint8),
|
||||
(0, np.inf, np.uint8),
|
||||
(0, np.array([1, np.inf]), np.uint8),
|
||||
# boolean
|
||||
(-np.inf, 1, np.bool_),
|
||||
(0, np.inf, np.bool_),
|
||||
],
|
||||
)
|
||||
def test_invalid_low_high(low, high, dtype):
|
||||
if dtype == np.uint8 or dtype == np.bool_:
|
||||
with pytest.raises(
|
||||
ValueError, match=re.escape("Box unsigned int dtype don't support")
|
||||
):
|
||||
Box(low=low, high=high, dtype=dtype)
|
||||
else:
|
||||
with pytest.raises(
|
||||
ValueError, match=re.escape("value can be equal to `np.nan`,")
|
||||
):
|
||||
Box(low=low, high=high, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"low, high, dtype",
|
||||
[
|
||||
# floats
|
||||
(0, 1, float),
|
||||
(0, 1, np.float64),
|
||||
(0, 1, np.float32),
|
||||
(0, 1, np.float16),
|
||||
(np.zeros(2), np.ones(2), np.float32),
|
||||
(np.zeros(2), 1, np.float32),
|
||||
(-np.inf, 1, np.float32),
|
||||
(np.array([-np.inf, 0]), 1, np.float32),
|
||||
(0, np.inf, np.float32),
|
||||
(0, np.array([np.inf, 1]), np.float32),
|
||||
(-np.inf, np.inf, np.float32),
|
||||
(np.full((2,), -np.inf), np.full((2,), np.inf), np.float32),
|
||||
# signed ints
|
||||
(0, 1, int),
|
||||
(0, 1, np.int64),
|
||||
(0, 1, np.int32),
|
||||
(0, 1, np.int16),
|
||||
(0, 1, np.int8),
|
||||
(np.zeros(2), np.ones(2), np.int32),
|
||||
(np.zeros(2), 1, np.int32),
|
||||
(-np.inf, 1, np.int32),
|
||||
(np.array([-np.inf, 0]), 1, np.int32),
|
||||
(0, np.inf, np.int32),
|
||||
(0, np.array([np.inf, 1]), np.int32),
|
||||
# unsigned ints
|
||||
(0, 1, np.uint64),
|
||||
(0, 1, np.uint32),
|
||||
(0, 1, np.uint16),
|
||||
(0, 1, np.uint8),
|
||||
# boolean
|
||||
(0, 1, np.bool_),
|
||||
],
|
||||
)
|
||||
def test_valid_low_high(low, high, dtype):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
space = Box(low=low, high=high, dtype=dtype)
|
||||
assert space.dtype == dtype
|
||||
assert space.low.dtype == dtype
|
||||
assert space.high.dtype == dtype
|
||||
|
||||
space.seed(0)
|
||||
sample = space.sample()
|
||||
assert sample.dtype == dtype
|
||||
assert space.contains(sample)
|
||||
|
||||
for warn in caught_warnings:
|
||||
if "precision lowered by casting to float32" not in warn.message.args[0]:
|
||||
raise Exception(warn)
|
||||
|
||||
|
||||
def test_contains_dtype():
|
||||
"""Tests the Box contains function with different dtypes."""
|
||||
# Related Issues:
|
||||
# https://github.com/openai/gym/issues/2357
|
||||
@@ -163,102 +294,54 @@ def test_dtype_check():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
"lowhighshape",
|
||||
[
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.int32),
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.float32),
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.int64),
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.float64),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64),
|
||||
dict(low=0, high=np.inf, shape=(2,)),
|
||||
dict(low=-np.inf, high=0, shape=(2,)),
|
||||
dict(low=-np.inf, high=np.inf, shape=(2,)),
|
||||
dict(low=0, high=np.inf, shape=(2, 3)),
|
||||
dict(low=-np.inf, high=0, shape=(2, 3)),
|
||||
dict(low=-np.inf, high=np.inf, shape=(2, 3)),
|
||||
dict(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf])),
|
||||
],
|
||||
)
|
||||
def test_infinite_space(space):
|
||||
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
|
||||
def test_infinite_space(lowhighshape, dtype):
|
||||
"""
|
||||
To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
|
||||
are both modified within the init, we check for infinite when we know it's not 0
|
||||
"""
|
||||
space = Box(**lowhighshape, dtype=dtype)
|
||||
|
||||
assert np.all(
|
||||
space.low < space.high
|
||||
), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})"
|
||||
assert np.all(space.low < space.high)
|
||||
|
||||
space.seed(0)
|
||||
sample = space.sample()
|
||||
|
||||
# check if space contains sample
|
||||
assert (
|
||||
sample in space
|
||||
), f"Sample ({sample}) not inside space according to `space.contains()`"
|
||||
|
||||
# manually check that the sign of the sample is within the bounds
|
||||
assert np.all(
|
||||
np.sign(sample) <= np.sign(space.high)
|
||||
), f"Sign of sample ({sample}) is less than space upper bound ({space.high})"
|
||||
assert np.all(
|
||||
np.sign(space.low) <= np.sign(sample)
|
||||
), f"Sign of sample ({sample}) is more than space lower bound ({space.low})"
|
||||
|
||||
# check that int bounds are bounded for everything
|
||||
# but floats are unbounded for infinite
|
||||
if np.any(space.high != 0):
|
||||
assert (
|
||||
space.is_bounded("above") is False
|
||||
), "inf upper bound supposed to be unbounded"
|
||||
else:
|
||||
assert (
|
||||
space.is_bounded("above") is True
|
||||
), "non-inf upper bound supposed to be bounded"
|
||||
|
||||
if np.any(space.low != 0):
|
||||
assert (
|
||||
space.is_bounded("below") is False
|
||||
), "inf lower bound supposed to be unbounded"
|
||||
else:
|
||||
assert (
|
||||
space.is_bounded("below") is True
|
||||
), "non-inf lower bound supposed to be bounded"
|
||||
|
||||
if np.any(space.low != 0) or np.any(space.high != 0):
|
||||
assert space.is_bounded("both") is False
|
||||
else:
|
||||
assert space.is_bounded("both") is True
|
||||
# check that int bounds are bounded for everything but floats are unbounded for infinite
|
||||
assert space.is_bounded("above") is not np.any(space.high != 0)
|
||||
assert space.is_bounded("below") is not np.any(space.low != 0)
|
||||
assert space.is_bounded("both") is not (
|
||||
np.any(space.high != 0) | np.any(space.high != 0)
|
||||
)
|
||||
|
||||
# check for dtype
|
||||
assert (
|
||||
space.high.dtype == space.dtype
|
||||
), f"High's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
||||
assert (
|
||||
space.low.dtype == space.dtype
|
||||
), f"Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
||||
assert space.high.dtype == space.dtype
|
||||
assert space.low.dtype == space.dtype
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="manner is not in {'below', 'above', 'both'}, actual value:"
|
||||
):
|
||||
space.is_bounded("test")
|
||||
|
||||
# Check sample
|
||||
space.seed(0)
|
||||
sample = space.sample()
|
||||
|
||||
# check if space contains sample
|
||||
assert sample in space
|
||||
|
||||
# manually check that the sign of the sample is within the bounds
|
||||
assert np.all(np.sign(sample) <= np.sign(space.high))
|
||||
assert np.all(np.sign(space.low) <= np.sign(sample))
|
||||
|
||||
|
||||
def test_legacy_state_pickling():
|
||||
legacy_state = {
|
||||
@@ -290,56 +373,3 @@ def test_sample_mask():
|
||||
match=re.escape("Box.sample cannot be provided a mask, actual value: "),
|
||||
):
|
||||
space.sample(mask=np.array([0, 1, 0], dtype=np.int8))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"low, high, shape, dtype, reason",
|
||||
[
|
||||
(
|
||||
5.0,
|
||||
3.0,
|
||||
(),
|
||||
np.float32,
|
||||
"Some low values are greater than high, low=5.0, high=3.0",
|
||||
),
|
||||
(
|
||||
np.array([5.0, 6.0]),
|
||||
np.array([1.0, 5.99]),
|
||||
(2,),
|
||||
np.float32,
|
||||
"Some low values are greater than high, low=[5. 6.], high=[1. 5.99]",
|
||||
),
|
||||
(
|
||||
np.inf,
|
||||
np.inf,
|
||||
(),
|
||||
np.float32,
|
||||
"No low value can be equal to `np.inf`, low=inf",
|
||||
),
|
||||
(
|
||||
np.array([0, np.inf]),
|
||||
np.array([np.inf, np.inf]),
|
||||
(2,),
|
||||
np.float32,
|
||||
"No low value can be equal to `np.inf`, low=[ 0. inf]",
|
||||
),
|
||||
(
|
||||
-np.inf,
|
||||
-np.inf,
|
||||
(),
|
||||
np.float32,
|
||||
"No high value can be equal to `-np.inf`, high=-inf",
|
||||
),
|
||||
(
|
||||
np.array([-np.inf, -np.inf]),
|
||||
np.array([0, -np.inf]),
|
||||
(2,),
|
||||
np.float32,
|
||||
"No high value can be equal to `-np.inf`, high=[ 0. -inf]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_low_high(low, high, dtype, shape, reason):
|
||||
"""Tests that we don't allow spaces with degenerate bounds, such as `Box(np.inf, -np.inf)`."""
|
||||
with pytest.raises(ValueError, match=re.escape(reason)):
|
||||
Box(low=low, high=high, dtype=dtype, shape=shape)
|
||||
|
@@ -45,7 +45,7 @@ def custom_environments():
|
||||
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
|
||||
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
|
||||
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
|
||||
("CartPole-v1", "DtypeObservation", {"dtype": np.int32}),
|
||||
("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}),
|
||||
# ("CartPole-v1", "RenderObservation", {}), # not implemented
|
||||
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
|
||||
# ("CartPole-v1", "FrameStackObservation", {}), # not implemented
|
||||
|
Reference in New Issue
Block a user