import numpy as np
from gym import Space, spaces


class Discrete(Space):
    """
    {0,1,...,n-1}

    A space of ``n`` discrete, integer-valued points.

    Example usage:
    self.observation_space = spaces.Discrete(2)
    """

    def __init__(self, n):
        """Create a space containing the integers ``0 .. n-1``.

        Args:
            n: number of elements in the space.
        """
        self.n = n
        # Elements are scalars (empty shape) stored as int64.
        Space.__init__(self, (), np.int64)

    def sample(self):
        """Return a random element of the space, drawn with gym's shared RNG."""
        return spaces.np_random.randint(self.n)

    def contains(self, x):
        """Return True if ``x`` is an integer in ``[0, n)``.

        Accepts Python ``int`` values and NumPy integer scalars or 0-d
        integer arrays, both signed and unsigned. Any other type, or an
        array with a non-empty shape, is rejected.
        """
        if isinstance(x, int):
            as_int = x
        elif (isinstance(x, (np.generic, np.ndarray))
              # Compare the dtype *char*, not the *kind*:
              # np.typecodes['AllInteger'] is a string of type chars
              # ('bBhHiIlLqQpP'). Comparing kinds would reject every
              # unsigned dtype (kind 'u').
              and x.dtype.char in np.typecodes['AllInteger']
              and x.shape == ()):
            as_int = int(x)
        else:
            return False
        return 0 <= as_int < self.n

    def __repr__(self):
        return "Discrete(%d)" % self.n

    def __eq__(self, other):
        """Two Discrete spaces are equal when they have the same size.

        Comparing against any non-Discrete object returns False instead of
        raising AttributeError.
        """
        return isinstance(other, Discrete) and self.n == other.n