2022-05-10 17:18:06 +02:00
|
|
|
"""Implementation of a space that represents closed boxes in euclidean space."""
|
2022-05-31 23:53:13 -04:00
|
|
|
from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
2022-01-24 23:22:11 +01:00
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
import numpy as np
|
2018-11-29 02:27:27 +01:00
|
|
|
|
2019-12-06 14:13:46 +01:00
|
|
|
from gym import logger
|
2022-04-24 17:14:33 +01:00
|
|
|
from gym.spaces.space import Space
|
|
|
|
from gym.utils import seeding
|
2022-03-31 12:50:38 -07:00
|
|
|
|
2019-01-30 22:39:55 +01:00
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
def _short_repr(arr: np.ndarray) -> str:
|
2022-01-13 19:41:53 +01:00
|
|
|
"""Create a shortened string representation of a numpy array.
|
|
|
|
|
|
|
|
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
|
|
|
Otherwise, return a string representation of the entire array.
|
2022-05-25 14:46:41 +01:00
|
|
|
|
|
|
|
Args:
|
|
|
|
arr: The array to represent
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A short representation of the array
|
2022-01-13 19:41:53 +01:00
|
|
|
"""
|
|
|
|
if arr.size != 0 and np.min(arr) == np.max(arr):
|
|
|
|
return str(np.min(arr))
|
|
|
|
return str(arr)
|
|
|
|
|
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
class Box(Space[np.ndarray]):
    r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.

    Specifically, a Box represents the Cartesian product of n closed intervals.
    Each interval has the form of one of :math:`[a, b]`, :math:`(-\infty, b]`,
    :math:`[a, \infty)`, or :math:`(-\infty, \infty)`.

    There are two common use cases:

    * Identical bound for each dimension::

        >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
        Box(-1.0, 2.0, (3, 4), float32)

    * Independent bound for each dimension::

        >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
        Box([-1. -2.], [2. 4.], (2,), float32)
    """

    def __init__(
        self,
        low: Union[SupportsFloat, np.ndarray],
        high: Union[SupportsFloat, np.ndarray],
        shape: Optional[Sequence[int]] = None,
        dtype: Type = np.float32,
        seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
    ):
        r"""Constructor of :class:`Box`.

        The argument ``low`` specifies the lower bound of each dimension and ``high`` specifies the upper bounds.
        I.e., the space that is constructed will be the product of the intervals :math:`[\text{low}[i], \text{high}[i]]`.

        If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
        this value across all dimensions.

        Args:
            low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
            high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
            shape (Optional[Sequence[int]]): This only needs to be specified if both ``low`` and ``high`` are scalars and determines the shape of the space.
                Otherwise, the shape is inferred from the shape of ``low`` or ``high``.
            dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.

        Raises:
            ValueError: If no shape information is available, i.e. ``shape`` is None and
                both ``low`` and ``high`` are scalars.
        """
        assert dtype is not None, "dtype must be explicitly provided. "
        self.dtype = np.dtype(dtype)

        # determine shape if it isn't provided directly
        if shape is not None:
            shape = tuple(shape)
        elif not np.isscalar(low):
            shape = low.shape  # type: ignore
        elif not np.isscalar(high):
            shape = high.shape  # type: ignore
        else:
            raise ValueError(
                "shape must be provided or inferred from the shapes of low or high"
            )
        assert isinstance(shape, tuple)

        # Capture the boundedness information before replacing np.inf with get_inf
        _low = np.full(shape, low, dtype=float) if np.isscalar(low) else low
        self.bounded_below = -np.inf < _low  # type: ignore
        _high = np.full(shape, high, dtype=float) if np.isscalar(high) else high
        self.bounded_above = np.inf > _high  # type: ignore

        low = _broadcast(low, dtype, shape, inf_sign="-")  # type: ignore
        high = _broadcast(high, dtype, shape, inf_sign="+")  # type: ignore

        assert isinstance(low, np.ndarray)
        assert low.shape == shape, "low.shape doesn't match provided shape"
        assert isinstance(high, np.ndarray)
        assert high.shape == shape, "high.shape doesn't match provided shape"

        self._shape: Tuple[int, ...] = shape

        # Warn when both bounds carry more precision than the space dtype can hold.
        low_precision = get_precision(low.dtype)
        high_precision = get_precision(high.dtype)
        dtype_precision = get_precision(self.dtype)
        if min(low_precision, high_precision) > dtype_precision:  # type: ignore
            logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
        self.low = low.astype(self.dtype)
        self.high = high.astype(self.dtype)

        # Cache the bound strings once; they are reused by __repr__.
        self.low_repr = _short_repr(self.low)
        self.high_repr = _short_repr(self.high)

        super().__init__(self.shape, self.dtype, seed)

    @property
    def shape(self) -> Tuple[int, ...]:
        """Has stricter type than gym.Space - never None."""
        return self._shape

    def is_bounded(self, manner: str = "both") -> bool:
        """Checks whether the box is bounded in some sense.

        Args:
            manner (str): One of ``"both"``, ``"below"``, ``"above"``.

        Returns:
            If the space is bounded

        Raises:
            ValueError: If `manner` is neither ``"both"`` nor ``"below"`` or ``"above"``
        """
        below = bool(np.all(self.bounded_below))
        above = bool(np.all(self.bounded_above))
        if manner == "both":
            return below and above
        elif manner == "below":
            return below
        elif manner == "above":
            return above
        else:
            raise ValueError("manner is not in {'below', 'above', 'both'}")

    def sample(self) -> np.ndarray:
        r"""Generates a single random sample inside the Box.

        In creating a sample of the box, each coordinate is sampled (independently) from a distribution
        that is chosen according to the form of the interval:

        * :math:`[a, b]` : uniform distribution
        * :math:`[a, \infty)` : shifted exponential distribution
        * :math:`(-\infty, b]` : shifted negative exponential distribution
        * :math:`(-\infty, \infty)` : normal distribution

        Returns:
            A sampled value from the Box
        """
        # For non-float dtypes the upper bound is widened by one so that, after
        # the cast back to the space dtype, ``high`` itself can be drawn.
        high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
        sample = np.empty(self.shape)

        # Masking arrays which classify the coordinates according to interval
        # type
        unbounded = ~self.bounded_below & ~self.bounded_above
        upp_bounded = ~self.bounded_below & self.bounded_above
        low_bounded = self.bounded_below & ~self.bounded_above
        bounded = self.bounded_below & self.bounded_above

        # Vectorized sampling by interval type
        sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)

        sample[low_bounded] = (
            self.np_random.exponential(size=low_bounded[low_bounded].shape)
            + self.low[low_bounded]
        )

        sample[upp_bounded] = (
            -self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
            + self.high[upp_bounded]
        )

        sample[bounded] = self.np_random.uniform(
            low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
        )
        if self.dtype.kind == "i":
            sample = np.floor(sample)

        return sample.astype(self.dtype)

    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Inputs that are not numpy arrays are first converted to an array of the
        space dtype (a warning is logged). An input that cannot be converted is
        reported as not contained rather than raising.
        """
        if not isinstance(x, np.ndarray):
            logger.warn("Casting input x to numpy array.")
            try:
                x = np.asarray(x, dtype=self.dtype)
            except (ValueError, TypeError, OverflowError):
                # Something that cannot even be expressed in the space dtype
                # is by definition not an element of the space.
                return False

        return bool(
            np.can_cast(x.dtype, self.dtype)
            and x.shape == self.shape
            and np.all(x >= self.low)
            and np.all(x <= self.high)
        )

    def to_jsonable(self, sample_n):
        """Convert a batch of samples from this space to a JSONable data type."""
        return np.array(sample_n).tolist()

    def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> List[np.ndarray]:
        """Convert a JSONable data type to a batch of samples from this space.

        Each sample is restored with the dtype of the space, so that values
        produced by :meth:`to_jsonable` are members of the space again after
        the round trip.
        """
        return [np.asarray(sample, dtype=self.dtype) for sample in sample_n]

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include bounds, shape and dtype.
        If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.

        Returns:
            A representation of the space
        """
        return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"

    def __eq__(self, other) -> bool:
        """Check whether `other` is equivalent to this instance.

        Two boxes compare equal when they have the same shape and their bounds
        are close in the sense of ``np.allclose``; the dtype is not compared.
        """
        return (
            isinstance(other, Box)
            and (self.shape == other.shape)
            and np.allclose(self.low, other.low)
            and np.allclose(self.high, other.high)
        )

    def __setstate__(self, state: Dict):
        """Sets the state of the box for unpickling a box with legacy support."""
        super().__setstate__(state)

        # legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
        if not hasattr(self, "low_repr"):
            self.low_repr = _short_repr(self.low)

        if not hasattr(self, "high_repr"):
            self.high_repr = _short_repr(self.high)
|
|
|
|
|
2022-01-11 04:45:41 +00:00
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
def get_inf(dtype, sign: str) -> SupportsFloat:
    """Return a stand-in for infinity that is representable in ``dtype``.

    Floating point dtypes get a true infinity. Signed integer dtypes get a
    value two steps inside the extreme of their range.

    Args:
        dtype: An `np.dtype`
        sign (str): must be either `"+"` or `"-"`

    Returns:
        Gets an infinite value with the sign and dtype

    Raises:
        TypeError: Unknown sign, use either '+' or '-'
        ValueError: Unknown dtype for infinite bounds
    """
    kind = np.dtype(dtype).kind

    # Resolve the dtype first: an unsupported dtype is reported before the
    # sign is inspected.
    if kind == "f":
        positive, negative = np.inf, -np.inf
    elif kind == "i":
        limits = np.iinfo(dtype)
        positive, negative = limits.max - 2, limits.min + 2
    else:
        raise ValueError(f"Unknown dtype {dtype} for infinite bounds")

    if sign == "+":
        return positive
    if sign == "-":
        return negative
    raise TypeError(f"Unknown sign {sign}, use either '+' or '-'")
|
|
|
|
|
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
def get_precision(dtype) -> SupportsFloat:
    """Get precision of a data type.

    Floating point dtypes report the ``precision`` field of ``np.finfo``;
    every other dtype is reported as having infinite precision.
    """
    is_floating = np.issubdtype(dtype, np.floating)
    return np.finfo(dtype).precision if is_floating else np.inf
|
2022-01-24 23:22:11 +01:00
|
|
|
|
|
|
|
|
|
|
|
def _broadcast(
|
|
|
|
value: Union[SupportsFloat, np.ndarray],
|
|
|
|
dtype,
|
2022-05-25 15:28:19 +01:00
|
|
|
shape: Tuple[int, ...],
|
2022-01-24 23:22:11 +01:00
|
|
|
inf_sign: str,
|
|
|
|
) -> np.ndarray:
|
2022-05-10 17:18:06 +02:00
|
|
|
"""Handle infinite bounds and broadcast at the same time if needed."""
|
2022-01-24 23:22:11 +01:00
|
|
|
if np.isscalar(value):
|
|
|
|
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
|
|
|
|
value = np.full(shape, value, dtype=dtype)
|
|
|
|
else:
|
|
|
|
assert isinstance(value, np.ndarray)
|
|
|
|
if np.any(np.isinf(value)):
|
|
|
|
# create new array with dtype, but maintain old one to preserve np.inf
|
|
|
|
temp = value.astype(dtype)
|
|
|
|
temp[np.isinf(value)] = get_inf(dtype, inf_sign)
|
|
|
|
value = temp
|
|
|
|
return value
|