diff --git a/gym/spaces/box.py b/gym/spaces/box.py index f3ff2c73f..d0d41f27a 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -35,15 +35,18 @@ class Box(gym.Space): def sample(self): return gym.spaces.np_random.uniform(low=self.low, high=self.high + (0 if self.dtype.kind == 'f' else 1), size=self.low.shape).astype(self.dtype) + def contains(self, x): return x.shape == self.shape and (x >= self.low).all() and (x <= self.high).all() def to_jsonable(self, sample_n): return np.array(sample_n).tolist() + def from_jsonable(self, sample_n): return [np.asarray(sample) for sample in sample_n] def __repr__(self): return "Box" + str(self.shape) + def __eq__(self, other): return np.allclose(self.low, other.low) and np.allclose(self.high, other.high) diff --git a/gym/spaces/dict_space.py b/gym/spaces/dict_space.py index 669f2f848..56cb91a4e 100644 --- a/gym/spaces/dict_space.py +++ b/gym/spaces/dict_space.py @@ -71,3 +71,5 @@ class Dict(gym.Space): ret.append(entry) return ret + def __eq__(self, other): + return self.spaces == other.spaces diff --git a/gym/spaces/discrete.py b/gym/spaces/discrete.py index 72c2afa9c..c737640d8 100644 --- a/gym/spaces/discrete.py +++ b/gym/spaces/discrete.py @@ -11,8 +11,10 @@ class Discrete(gym.Space): def __init__(self, n): self.n = n gym.Space.__init__(self, (), np.int64) + def sample(self): return gym.spaces.np_random.randint(self.n) + def contains(self, x): if isinstance(x, int): as_int = x @@ -24,5 +26,6 @@ class Discrete(gym.Space): def __repr__(self): return "Discrete(%d)" % self.n + def __eq__(self, other): return self.n == other.n diff --git a/gym/spaces/multi_binary.py b/gym/spaces/multi_binary.py index dd3f1d364..cfa3364c3 100644 --- a/gym/spaces/multi_binary.py +++ b/gym/spaces/multi_binary.py @@ -5,12 +5,21 @@ class MultiBinary(gym.Space): def __init__(self, n): self.n = n gym.Space.__init__(self, (self.n,), np.int8) + def sample(self): return gym.spaces.np_random.randint(low=0, high=2, size=self.n).astype(self.dtype) + def contains(self, x): return ((x==0) | (x==1)).all() def to_jsonable(self, sample_n): return np.array(sample_n).tolist() + def from_jsonable(self, sample_n): return [np.asarray(sample) for sample in sample_n] + + def __repr__(self): + return "MultiBinary({})".format(self.n) + + def __eq__(self, other): + return self.n == other.n diff --git a/gym/spaces/multi_discrete.py b/gym/spaces/multi_discrete.py index 16d79c5de..ef9213125 100644 --- a/gym/spaces/multi_discrete.py +++ b/gym/spaces/multi_discrete.py @@ -8,12 +8,21 @@ class MultiDiscrete(gym.Space): """ self.nvec = np.asarray(nvec, dtype=np.int32) gym.Space.__init__(self, (self.nvec.shape,), np.int8) + def sample(self): return (gym.spaces.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype) + def contains(self, x): return (0 <= x).all() and (x < self.nvec).all() and x.dtype.kind in 'ui' def to_jsonable(self, sample_n): return [sample.tolist() for sample in sample_n] + def from_jsonable(self, sample_n): return np.array(sample_n) + + def __repr__(self): + return "MultiDiscrete({})".format(self.nvec) + + def __eq__(self, other): + return np.all(self.nvec == other.nvec) diff --git a/gym/spaces/tests/test_spaces.py b/gym/spaces/tests/test_spaces.py index dec8ecdf1..cadc86ff0 100644 --- a/gym/spaces/tests/test_spaces.py +++ b/gym/spaces/tests/test_spaces.py @@ -1,16 +1,19 @@ import json # note: ujson fails this test due to float equality +from copy import copy + import numpy as np import pytest + from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict @pytest.mark.parametrize("space", [ Discrete(3), Tuple([Discrete(5), Discrete(10)]), - Tuple([Discrete(5), Box(low=np.array([0,0]),high=np.array([1,5]))]), + Tuple([Discrete(5), Box(low=np.array([0, 0]),high=np.array([1, 5]))]), Tuple((Discrete(5), Discrete(2), Discrete(2))), - MultiDiscrete([ 2, 2, 100]), - Dict({"position": Discrete(5), "velocity": Box(low=np.array([0,0]),high=np.array([1,5]))}), + MultiDiscrete([2, 2, 100]), + Dict({"position": Discrete(5), "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]))}), ]) def test_roundtripping(space): sample_1 = space.sample() @@ -30,3 +33,34 @@ def test_roundtripping(space): s2p = space.to_jsonable([sample_2_prime]) assert s1 == s1p, "Expected {} to equal {}".format(s1, s1p) assert s2 == s2p, "Expected {} to equal {}".format(s2, s2p) + + +@pytest.mark.parametrize("space", [ + Discrete(3), + Box(low=np.array([-10, 0]),high=np.array([10, 10])), + Tuple([Discrete(5), Discrete(10)]), + Tuple([Discrete(5), Box(low=np.array([0, 0]),high=np.array([1, 5]))]), + Tuple((Discrete(5), Discrete(2), Discrete(2))), + MultiDiscrete([2, 2, 100]), + MultiBinary(6), + Dict({"position": Discrete(5), "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]))}), + ]) +def test_equality(space): + space1 = space + space2 = copy(space) + assert space1 == space2, "Expected {} to equal {}".format(space1, space2) + + +@pytest.mark.parametrize("spaces", [ + (Discrete(3), Discrete(4)), + (MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])), + (MultiBinary(8), MultiBinary(7)), + (Box(low=np.array([-10, 0]),high=np.array([10, 10])), + Box(low=np.array([-10, 0]),high=np.array([10, 9]))), + (Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])), + (Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})), + (Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})), + ]) +def test_inequality(spaces): + space1, space2 = spaces + assert space1 != space2, "Expected {} != {}".format(space1, space2) diff --git a/gym/spaces/tuple_space.py b/gym/spaces/tuple_space.py index 453663781..473aa6529 100644 --- a/gym/spaces/tuple_space.py +++ b/gym/spaces/tuple_space.py @@ -30,3 +30,6 @@ class Tuple(gym.Space): def from_jsonable(self, sample_n): return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])] + + def __eq__(self, other): + return self.spaces == other.spaces