diff --git a/gym/spaces/box.py b/gym/spaces/box.py index a2981861c..a4af9c2f8 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -25,16 +25,29 @@ class Box(Space): assert dtype is not None, 'dtype must be explicitly provided. ' self.dtype = np.dtype(dtype) - if shape is None: - assert low.shape == high.shape, 'box dimension mismatch. ' - self.shape = low.shape - self.low = low - self.high = high + # determine shape if it isn't provided directly + if shape is not None: + shape = tuple(shape) + assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape" + assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape" + elif not np.isscalar(low): + shape = low.shape + assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape" + elif not np.isscalar(high): + shape = high.shape + assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape" else: - assert np.isscalar(low) and np.isscalar(high), 'box requires scalar bounds. ' - self.shape = tuple(shape) - self.low = np.full(self.shape, low, dtype=dtype) - self.high = np.full(self.shape, high, dtype=dtype) + raise ValueError("shape must be provided or inferred from the shapes of low or high") + + if np.isscalar(low): + low = np.full(shape, low) + + if np.isscalar(high): + high = np.full(shape, high) + + self.shape = shape + self.low = low + self.high = high def _get_precision(dtype): if np.issubdtype(dtype, np.floating):