Box Boundedness determined using get_inf instead of np.inf (#2630)

* Box Boundedness determined using get_inf instead of np.inf

* Store original array entries for determining boundedness

* Capture the boundedness before casting away the np.inf

* Removed requirement that integer spaces be bounded above and below

* np full casts away the inf, so using dtype float for boundedness evaluation

* Removed commented code

* But the type ignore hint back in

* Spacing change from black code formatter
This commit is contained in:
Edward Rusu
2022-03-02 07:51:06 -08:00
committed by GitHub
parent 15b5c6c29f
commit d1f35fe587
2 changed files with 23 additions and 27 deletions

View File

@@ -61,8 +61,14 @@ class Box(Space[np.ndarray]):
) )
assert isinstance(shape, tuple) assert isinstance(shape, tuple)
# Capture the boundedness information before replacing np.inf with get_inf
_low = np.full(shape, low, dtype=float) if np.isscalar(low) else low
self.bounded_below = -np.inf < _low
_high = np.full(shape, high, dtype=float) if np.isscalar(high) else high
self.bounded_above = np.inf > _high
low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore
high = _broadcast(high, dtype, shape, inf_sign="+") high = _broadcast(high, dtype, shape, inf_sign="+") # type: ignore
assert isinstance(low, np.ndarray) assert isinstance(low, np.ndarray)
assert low.shape == shape, "low.shape doesn't match provided shape" assert low.shape == shape, "low.shape doesn't match provided shape"
@@ -82,10 +88,6 @@ class Box(Space[np.ndarray]):
self.low_repr = _short_repr(self.low) self.low_repr = _short_repr(self.low)
self.high_repr = _short_repr(self.high) self.high_repr = _short_repr(self.high)
# Boolean arrays which indicate the interval type for each coordinate
self.bounded_below = -np.inf < self.low
self.bounded_above = np.inf > self.high
super().__init__(self.shape, self.dtype, seed) super().__init__(self.shape, self.dtype, seed)
@property @property

View File

@@ -510,29 +510,23 @@ def test_infinite_space(space):
# check that int bounds are bounded for everything # check that int bounds are bounded for everything
# but floats are unbounded for infinite # but floats are unbounded for infinite
if space.dtype.kind == "f":
if np.any(space.high != 0): if np.any(space.high != 0):
assert ( assert (
space.is_bounded("above") == False space.is_bounded("above") == False
), "float dtype inf upper bound supposed to be unbounded" ), "inf upper bound supposed to be unbounded"
else: else:
assert ( assert (
space.is_bounded("above") == True space.is_bounded("above") == True
), "float dtype non-inf upper bound supposed to be bounded" ), "non-inf upper bound supposed to be bounded"
if np.any(space.low != 0): if np.any(space.low != 0):
assert ( assert (
space.is_bounded("below") == False space.is_bounded("below") == False
), "float dtype inf lower bound supposed to be unbounded" ), "inf lower bound supposed to be unbounded"
else: else:
assert ( assert (
space.is_bounded("below") == True space.is_bounded("below") == True
), "float dtype non-inf lower bound supposed to be bounded" ), "non-inf lower bound supposed to be bounded"
elif space.dtype.kind == "i":
assert (
space.is_bounded("both") == True
), "int dtypes should be bounded on both ends"
# check for dtype # check for dtype
assert ( assert (