Files
Gymnasium/gym/spaces/box.py

219 lines
7.4 KiB
Python
Raw Normal View History

2016-04-27 08:00:58 -07:00
import numpy as np
from .space import Space
from gym import logger
2016-04-27 08:00:58 -07:00
def _short_repr(arr):
"""Create a shortened string representation of a numpy array.
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
Otherwise, return a string representation of the entire array.
"""
if arr.size != 0 and np.min(arr) == np.max(arr):
return str(np.min(arr))
return str(arr)
class Box(Space):
    """
    A (possibly unbounded) box in R^n. Specifically, a Box represents the
    Cartesian product of n closed intervals. Each interval has the form of one
    of [a, b], (-oo, b], [a, oo), or (-oo, oo).

    There are two common use cases:

    * Identical bound for each dimension::

        >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
        Box(-1.0, 2.0, (3, 4), float32)

    * Independent bound for each dimension::

        >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
        Box([-1. -2.], [2. 4.], (2,), float32)

    Infinite bounds (``np.inf`` / ``-np.inf``) are accepted; for integer dtypes
    they are replaced by the finite stand-ins returned by ``get_inf``.
    """

    def __init__(self, low, high, shape=None, dtype=np.float32, seed=None):
        """Build the box from its lower and upper bounds.

        Args:
            low: Lower bound(s). Either a scalar (broadcast to ``shape``) or an
                array-like (ndarray, list, tuple) of the box's shape.
                ``-np.inf`` marks a coordinate that is unbounded below.
            high: Upper bound(s), same conventions as ``low`` with ``np.inf``.
            shape: Box shape. Required when both bounds are scalars; otherwise
                inferred from whichever bound is array-like.
            dtype: Element dtype of the box; must not be ``None``.
            seed: Forwarded to ``Space.__init__``.

        Raises:
            AssertionError: If ``dtype`` is ``None`` or the bound shapes
                disagree with each other or with ``shape``.
            ValueError: If ``shape`` is not given and both bounds are scalars.
        """
        assert dtype is not None, "dtype must be explicitly provided. "
        self.dtype = np.dtype(dtype)

        # Accept array-likes (lists, tuples) as bounds. ndarrays pass through
        # np.asarray unchanged and scalars are left as scalars, so existing
        # callers see identical behavior.
        if not np.isscalar(low):
            low = np.asarray(low)
        if not np.isscalar(high):
            high = np.asarray(high)

        # determine shape if it isn't provided directly
        if shape is not None:
            shape = tuple(shape)
            assert (
                np.isscalar(low) or low.shape == shape
            ), "low.shape doesn't match provided shape"
            assert (
                np.isscalar(high) or high.shape == shape
            ), "high.shape doesn't match provided shape"
        elif not np.isscalar(low):
            shape = low.shape
            assert (
                np.isscalar(high) or high.shape == shape
            ), "high.shape doesn't match low.shape"
        elif not np.isscalar(high):
            shape = high.shape
            assert (
                np.isscalar(low) or low.shape == shape
            ), "low.shape doesn't match high.shape"
        else:
            raise ValueError(
                "shape must be provided or inferred from the shapes of low or high"
            )

        # handle infinite bounds and broadcast at the same time if needed
        if np.isscalar(low):
            low = get_inf(dtype, "-") if np.isinf(low) else low
            low = np.full(shape, low, dtype=dtype)
        else:
            if np.any(np.isinf(low)):
                # create new array with dtype, but maintain old one to preserve np.inf
                temp_low = low.astype(dtype)
                temp_low[np.isinf(low)] = get_inf(dtype, "-")
                low = temp_low

        if np.isscalar(high):
            high = get_inf(dtype, "+") if np.isinf(high) else high
            high = np.full(shape, high, dtype=dtype)
        else:
            if np.any(np.isinf(high)):
                # create new array with dtype, but maintain old one to preserve np.inf
                temp_high = high.astype(dtype)
                temp_high[np.isinf(high)] = get_inf(dtype, "+")
                high = temp_high

        self._shape = shape
        self.low = low
        self.high = high

        # Warn when casting the bounds to the box dtype loses float precision
        # (get_precision reports non-float dtypes as infinitely precise).
        low_precision = get_precision(self.low.dtype)
        high_precision = get_precision(self.high.dtype)
        dtype_precision = get_precision(self.dtype)
        if min(low_precision, high_precision) > dtype_precision:
            logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
        self.low = self.low.astype(self.dtype)
        self.high = self.high.astype(self.dtype)

        self.low_repr = _short_repr(self.low)
        self.high_repr = _short_repr(self.high)

        # Boolean arrays which indicate the interval type for each coordinate
        self.bounded_below = -np.inf < self.low
        self.bounded_above = np.inf > self.high

        super().__init__(self.shape, self.dtype, seed)

    def is_bounded(self, manner="both"):
        """Return whether every coordinate is bounded in the requested manner.

        Args:
            manner: ``"both"``, ``"below"`` or ``"above"``.

        Raises:
            ValueError: If ``manner`` is not one of the values above.
        """
        below = np.all(self.bounded_below)
        above = np.all(self.bounded_above)
        if manner == "both":
            return below and above
        elif manner == "below":
            return below
        elif manner == "above":
            return above
        else:
            raise ValueError("manner is not in {'below', 'above', 'both'}")

    def sample(self):
        """
        Generates a single random sample inside of the Box.

        In creating a sample of the box, each coordinate is sampled according to
        the form of the interval:

        * [a, b] : uniform distribution
        * [a, oo) : shifted exponential distribution
        * (-oo, b] : shifted negative exponential distribution
        * (-oo, oo) : normal distribution

        Randomness comes from ``self.np_random``; the result is cast to
        ``self.dtype``.
        """
        # For non-float dtypes the upper bound is widened by one so that a
        # uniform draw on [low, high + 1) that is then rounded down can
        # still produce ``high`` itself.
        high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
        sample = np.empty(self.shape)

        # Masking arrays which classify the coordinates according to interval type
        unbounded = ~self.bounded_below & ~self.bounded_above
        upp_bounded = ~self.bounded_below & self.bounded_above
        low_bounded = self.bounded_below & ~self.bounded_above
        bounded = self.bounded_below & self.bounded_above

        # Vectorized sampling by interval type
        sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
        sample[low_bounded] = (
            self.np_random.exponential(size=low_bounded[low_bounded].shape)
            + self.low[low_bounded]
        )
        sample[upp_bounded] = (
            -self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
            + self.high[upp_bounded]
        )
        sample[bounded] = self.np_random.uniform(
            low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
        )
        if self.dtype.kind == "i":
            sample = np.floor(sample)

        return sample.astype(self.dtype)

    def contains(self, x):
        """Return ``True`` if ``x`` is an element of this box.

        ``x`` belongs to the box when its dtype can be cast to the box dtype,
        its shape equals the box shape, and it lies within ``[low, high]``
        element-wise. Non-ndarray inputs are converted first (with a warning);
        inputs that cannot be converted to the box dtype are reported as not
        contained rather than raising.
        """
        if not isinstance(x, np.ndarray):
            logger.warn("Casting input x to numpy array.")
            try:
                x = np.asarray(x, dtype=self.dtype)
            except (ValueError, TypeError, OverflowError):
                return False

        # bool(): the chain ends in np.all(...), which yields np.bool_.
        return bool(
            np.can_cast(x.dtype, self.dtype)
            and x.shape == self.shape
            and np.all(x >= self.low)
            and np.all(x <= self.high)
        )

    def to_jsonable(self, sample_n):
        """Convert a batch of samples to nested Python lists."""
        return np.array(sample_n).tolist()

    def from_jsonable(self, sample_n):
        """Convert nested lists back into a list of numpy arrays, one per sample."""
        return [np.asarray(sample) for sample in sample_n]

    def __repr__(self):
        return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"

    def __eq__(self, other):
        """Boxes are equal when shape, dtype and (approximately) both bounds match."""
        return (
            isinstance(other, Box)
            and (self.shape == other.shape)
            and (self.dtype == other.dtype)
            and np.allclose(self.low, other.low)
            and np.allclose(self.high, other.high)
        )
def get_inf(dtype, sign):
    """Return a stand-in for an infinite bound that is usable with ``dtype``.

    Floating-point dtypes get the real ``np.inf`` / ``-np.inf``. Signed integer
    dtypes have no infinity, so a value two steps inside the dtype's
    representable range is returned instead (presumably to leave headroom for
    arithmetic on the bound, e.g. the ``+ 1`` applied to integer upper bounds
    when sampling).

    Args:
        dtype: Anything accepted by ``np.dtype`` (a dtype, a scalar type, or
            a string such as ``"int32"``).
        sign: ``"+"`` for positive infinity, ``"-"`` for negative infinity.

    Raises:
        TypeError: If ``sign`` is neither ``"+"`` nor ``"-"``.
        ValueError: If ``dtype`` is neither floating-point nor signed integer.
    """
    kind = np.dtype(dtype).kind
    if kind == "f":
        pos_inf, neg_inf = np.inf, -np.inf
    elif kind == "i":
        info = np.iinfo(dtype)
        pos_inf, neg_inf = info.max - 2, info.min + 2
    else:
        raise ValueError(f"Unknown dtype {dtype} for infinite bounds")

    if sign == "+":
        return pos_inf
    if sign == "-":
        return neg_inf
    raise TypeError(f"Unknown sign {sign}, use either '+' or '-'")
def get_precision(dtype):
    """Return the decimal precision associated with ``dtype``.

    Floating-point dtypes report ``np.finfo(dtype).precision`` (the approximate
    number of significant decimal digits). Every other dtype is treated as
    exact and reports ``np.inf``.
    """
    if not np.issubdtype(dtype, np.floating):
        return np.inf
    return np.finfo(dtype).precision