diff --git a/gym/spaces/box.py b/gym/spaces/box.py index 7588a0b1f..e593db829 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -1,51 +1,48 @@ import numpy as np -import gym -from gym import logger from .space import Space class Box(Space): - """ - A box in R^n. - I.e., each coordinate is bounded. + """A box in R^n, i.e.each coordinate is bounded. + + There are two common use cases: + + * Identical bound for each dimension:: + >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32) + Box(3, 4) + + * Independent bound for each dimension:: + >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32) + Box(2,) - Example usage: - self.action_space = spaces.Box(low=-10, high=10, shape=(1,)) """ - def __init__(self, low=None, high=None, shape=None, dtype=None): - """ - Two kinds of valid input: - Box(low=-1.0, high=1.0, shape=(3,4)) # low and high are scalars, and shape is provided - Box(low=np.array([-1.0,-2.0]), high=np.array([2.0,4.0])) # low and high are arrays of the same shape - """ + def __init__(self, low, high, shape=None, dtype=np.float32): + assert dtype is not None, 'dtype must be explicitly provided. ' + self.dtype = np.dtype(dtype) + if shape is None: assert low.shape == high.shape - shape = low.shape + self.shape = low.shape + self.low = low + self.high = high else: assert np.isscalar(low) and np.isscalar(high) + self.shape = tuple(shape) + self.low = np.full(self.shape, low) + self.high = np.full(self.shape, high) low = low + np.zeros(shape) high = high + np.zeros(shape) - if dtype is None: # Autodetect type - if (high == 255).all(): - dtype = np.uint8 - else: - dtype = np.float32 - logger.warn("gym.spaces.Box autodetected dtype as {}. Please provide explicit dtype.".format(dtype)) - self.low = low.astype(dtype) - self.high = high.astype(dtype) - super(Box, self).__init__(shape, dtype) - self.np_random = np.random.RandomState() - - def seed(self, seed): - self.np_random.seed(seed) + self.low = self.low.astype(self.dtype) + self.high = self.high.astype(self.dtype) + super(Box, self).__init__(self.shape, self.dtype) def sample(self): high = self.high if self.dtype.kind == 'f' else self.high.astype('int64') + 1 - return self.np_random.uniform(low=self.low, high=high, size=self.low.shape).astype(self.dtype) + return self.np_random.uniform(low=self.low, high=high, size=self.shape).astype(self.dtype) def contains(self, x): - return x.shape == self.shape and (x >= self.low).all() and (x <= self.high).all() + return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high) def to_jsonable(self, sample_n): return np.array(sample_n).tolist()