2016-04-27 08:00:58 -07:00
|
|
|
import numpy as np
|
2018-11-29 02:27:27 +01:00
|
|
|
|
2019-01-30 22:39:55 +01:00
|
|
|
from .space import Space
|
2019-12-06 14:13:46 +01:00
|
|
|
from gym import logger
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2019-01-30 22:39:55 +01:00
|
|
|
|
|
|
|
class Box(Space):
    """
    A (possibly unbounded) box in R^n. Specifically, a Box represents the
    Cartesian product of n closed intervals. Each interval has the form of one
    of [a, b], (-oo, b], [a, oo), or (-oo, oo).

    There are two common use cases:

    * Identical bound for each dimension::

        >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
        Box(-1.0, 2.0, (3, 4), float32)

    * Independent bound for each dimension::

        >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
        Box(-2.0, 4.0, (2,), float32)

    The repr summarizes the bounds as the minimum of ``low`` and the maximum of
    ``high``, followed by the shape and dtype; it does not list every
    per-dimension bound.
    """
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2019-03-25 00:39:32 +01:00
|
|
|
def __init__(self, low, high, shape=None, dtype=np.float32):
    """
    Build a box from shared scalar or per-dimension bounds.

    Args:
        low: lower bound(s). A scalar is broadcast to the box shape; otherwise
            an array whose ``.shape`` must equal the box shape.
        high: upper bound(s), following the same conventions as ``low``.
        shape: optional explicit shape. When omitted it is taken from
            ``low`` if that is non-scalar, else from ``high``.
        dtype: element dtype of the box; must not be ``None``.

    Raises:
        AssertionError: if ``dtype`` is ``None`` or a bound's shape disagrees
            with the resolved shape.
        ValueError: if no shape is given and both bounds are scalars.

    Both bounds are cast to ``dtype``. A warning is logged when that cast
    lowers the floating-point precision of the bounds.
    """
    assert dtype is not None, "dtype must be explicitly provided. "
    self.dtype = np.dtype(dtype)

    low_is_scalar = np.isscalar(low)
    high_is_scalar = np.isscalar(high)

    # Resolve the shape: explicit argument first, then low, then high.
    if shape is not None:
        shape = tuple(shape)
        assert low_is_scalar or low.shape == shape, "low.shape doesn't match provided shape"
        assert high_is_scalar or high.shape == shape, "high.shape doesn't match provided shape"
    elif not low_is_scalar:
        shape = low.shape
        assert high_is_scalar or high.shape == shape, "high.shape doesn't match low.shape"
    elif not high_is_scalar:
        shape = high.shape
        assert low_is_scalar or low.shape == shape, "low.shape doesn't match high.shape"
    else:
        raise ValueError("shape must be provided or inferred from the shapes of low or high")

    # Expand scalar bounds into full arrays of the resolved shape.
    if low_is_scalar:
        low = np.full(shape, low, dtype=dtype)
    if high_is_scalar:
        high = np.full(shape, high, dtype=dtype)

    self.shape = shape
    self.low = low
    self.high = high

    def _precision_of(kind):
        # Decimal digits of precision for floating dtypes; anything else is
        # treated as exact (infinite precision).
        if np.issubdtype(kind, np.floating):
            return np.finfo(kind).precision
        return np.inf

    bound_precision = min(_precision_of(self.low.dtype), _precision_of(self.high.dtype))
    if bound_precision > _precision_of(self.dtype):
        logger.warn("Box bound precision lowered by casting to {}".format(self.dtype))
    self.low = self.low.astype(self.dtype)
    self.high = self.high.astype(self.dtype)

    # Per-coordinate flags telling which ends of each interval are finite.
    self.bounded_below = -np.inf < self.low
    self.bounded_above = np.inf > self.high

    super(Box, self).__init__(self.shape, self.dtype)
|
Cleanup, removal of unmaintained code (#836)
* add dtype to Box
* remove board_game, debugging, safety, parameter_tuning environments
* massive set of breaking changes
- remove python logging module
- _step, _reset, _seed, _close => non underscored method
- remove benchmark and scoring folder
* Improve render("human"), now resizable, closable window.
* get rid of default step and reset in wrappers, so it doesn’t silently fail for people with underscore methods
* CubeCrash unit test environment
* followup fixes
* MemorizeDigits unit test environment
* refactored spaces a bit
fixed indentation
disabled test_env_semantics
* fix unit tests
* fixes
* CubeCrash, MemorizeDigits tested
* gym backwards compatibility patch
* gym backwards compatibility, followup fixes
* changelist, add spaces to main namespaces
* undo_logger_setup for backwards compat
* remove configuration.py
2018-01-25 18:20:14 -08:00
|
|
|
|
2019-06-28 18:54:31 -04:00
|
|
|
def is_bounded(self, manner="both"):
    """
    Report whether every coordinate is bounded on the requested side(s).

    Args:
        manner: ``"below"``, ``"above"`` or ``"both"`` (the default).

    Returns:
        A (NumPy) boolean aggregated with ``np.all`` over ``bounded_below``
        and/or ``bounded_above``.

    Raises:
        ValueError: if ``manner`` is not one of the three accepted values.
    """
    all_below = np.all(self.bounded_below)
    all_above = np.all(self.bounded_above)
    if manner == "both":
        return all_below and all_above
    if manner == "below":
        return all_below
    if manner == "above":
        return all_above
    raise ValueError("manner is not in {'below', 'above', 'both'}")
|
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
def sample(self):
    """
    Generate a single random sample inside of the Box.

    In creating a sample of the box, each coordinate is sampled according to
    the form of the interval:

    * [a, b] : uniform distribution
    * [a, oo) : shifted exponential distribution
    * (-oo, b] : shifted negative exponential distribution
    * (-oo, oo) : normal distribution

    For discrete dtypes (signed/unsigned integers and booleans), a bounded
    coordinate is drawn from ``[low, high + 1)`` and then floored. This gives
    an integer in ``low..high`` inclusive before the final cast.

    Returns:
        An array of shape ``self.shape`` cast to ``self.dtype``.
    """
    # Non-float dtypes widen the upper bound by one so ``high`` itself can be
    # drawn once the float sample is floored.
    high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
    sample = np.empty(self.shape)

    # Masking arrays which classify the coordinates according to interval type
    unbounded = ~self.bounded_below & ~self.bounded_above
    upp_bounded = ~self.bounded_below & self.bounded_above
    low_bounded = self.bounded_below & ~self.bounded_above
    bounded = self.bounded_below & self.bounded_above

    # Vectorized sampling by interval type (call order kept stable so seeded
    # random streams reproduce the same samples).
    sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
    sample[low_bounded] = self.np_random.exponential(size=low_bounded[low_bounded].shape) + self.low[low_bounded]
    sample[upp_bounded] = -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) + self.high[upp_bounded]
    sample[bounded] = self.np_random.uniform(low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape)

    # Floor every discrete kind, not only signed ints. Without this, a bool
    # Box casts any nonzero draw from [0, 2) to True, so it almost never
    # samples False.
    if self.dtype.kind in ("i", "u", "b"):
        sample = np.floor(sample)

    return sample.astype(self.dtype)
|
2020-04-25 00:24:35 +02:00
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
def contains(self, x):
    """
    Return whether ``x`` is a point of this Box.

    ``x`` must have exactly ``self.shape`` and satisfy ``low <= x <= high``
    elementwise, with both ends inclusive. Inputs without a ``.shape``
    attribute (lists, tuples, plain Python numbers) are first converted with
    ``np.asarray``. Previously only lists were converted, so tuples and
    scalars raised ``AttributeError``. Objects that already expose ``.shape``
    (ndarrays, NumPy scalars) are checked as-is, as before.
    """
    if not hasattr(x, "shape"):
        x = np.asarray(x)  # Promote array-likes to an array for the checks below
    return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
|
2016-04-27 08:00:58 -07:00
|
|
|
|
|
|
|
def to_jsonable(self, sample_n):
    """
    Convert a batch of samples into nested Python lists suitable for JSON.
    """
    stacked = np.asarray(sample_n)
    return stacked.tolist()
|
2018-09-24 20:11:03 +02:00
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
def from_jsonable(self, sample_n):
    """
    Turn JSON-decoded samples back into a list with one array per sample.
    """
    return list(map(np.asarray, sample_n))
|
|
|
|
|
|
|
|
def __repr__(self):
    """
    Summarize the box as ``Box(min(low), max(high), shape, dtype)``.
    """
    lowest = self.low.min()
    highest = self.high.max()
    return "Box({}, {}, {}, {})".format(lowest, highest, self.shape, self.dtype)
|
2018-11-29 02:27:27 +01:00
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
def __eq__(self, other):
    """
    Compare two boxes.

    Boxes are equal when ``other`` is a Box with the same shape and its
    ``low``/``high`` bounds are elementwise close (``np.allclose``). The dtype
    is not compared.
    """
    if not isinstance(other, Box):
        return False
    if self.shape != other.shape:
        return False
    # Shapes match, so allclose compares element-for-element without broadcasting.
    return np.allclose(self.low, other.low) and np.allclose(self.high, other.high)
|