mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 01:27:29 +00:00
Allow Box more flexibility in how shape is provided (#1873)
The current implementation of Box doesn't allow passing an array as low or high while also providing a shape. The current implementation of Box doesn't allow passing an array as low or high and a constant as the other bound. Co-authored-by: pzhokhov <peterz@openai.com>
This commit is contained in:
@@ -25,16 +25,29 @@ class Box(Space):
|
||||
assert dtype is not None, 'dtype must be explicitly provided. '
|
||||
self.dtype = np.dtype(dtype)
|
||||
|
||||
if shape is None:
|
||||
assert low.shape == high.shape, 'box dimension mismatch. '
|
||||
self.shape = low.shape
|
||||
self.low = low
|
||||
self.high = high
|
||||
# determine shape if it isn't provided directly
|
||||
if shape is not None:
|
||||
shape = tuple(shape)
|
||||
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape"
|
||||
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape"
|
||||
elif not np.isscalar(low):
|
||||
shape = low.shape
|
||||
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape"
|
||||
elif not np.isscalar(high):
|
||||
shape = high.shape
|
||||
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape"
|
||||
else:
|
||||
assert np.isscalar(low) and np.isscalar(high), 'box requires scalar bounds. '
|
||||
self.shape = tuple(shape)
|
||||
self.low = np.full(self.shape, low, dtype=dtype)
|
||||
self.high = np.full(self.shape, high, dtype=dtype)
|
||||
raise ValueError("shape must be provided or inferred from the shapes of low or high")
|
||||
|
||||
if np.isscalar(low):
|
||||
low = np.full(shape, low)
|
||||
|
||||
if np.isscalar(high):
|
||||
high = np.full(shape, high)
|
||||
|
||||
self.shape = shape
|
||||
self.low = low
|
||||
self.high = high
|
||||
|
||||
def _get_precision(dtype):
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
|
Reference in New Issue
Block a user