Allow Box more flexibility in how shape is provided (#1873)

The current implementation of Box doesn't allow passing an array as low or high while also providing a shape.
The current implementation of Box doesn't allow passing an array as low or high and a constant as the other bound.

Co-authored-by: pzhokhov <peterz@openai.com>
This commit is contained in:
Zach Dwiel
2020-05-08 17:56:14 -04:00
committed by GitHub
parent dfbfab6237
commit 174a27b7fc

View File

@@ -25,16 +25,29 @@ class Box(Space):
assert dtype is not None, 'dtype must be explicitly provided. '
self.dtype = np.dtype(dtype)
if shape is None:
assert low.shape == high.shape, 'box dimension mismatch. '
self.shape = low.shape
self.low = low
self.high = high
# determine shape if it isn't provided directly
if shape is not None:
shape = tuple(shape)
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape"
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape"
elif not np.isscalar(low):
shape = low.shape
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape"
elif not np.isscalar(high):
shape = high.shape
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape"
else:
assert np.isscalar(low) and np.isscalar(high), 'box requires scalar bounds. '
self.shape = tuple(shape)
self.low = np.full(self.shape, low, dtype=dtype)
self.high = np.full(self.shape, high, dtype=dtype)
raise ValueError("shape must be provided or inferred from the shapes of low or high")
if np.isscalar(low):
low = np.full(shape, low)
if np.isscalar(high):
high = np.full(shape, high)
self.shape = shape
self.low = low
self.high = high
def _get_precision(dtype):
if np.issubdtype(dtype, np.floating):