Add missing __eq__ and __repr__ methods (#1178)

* Add missing equality + repr methods

* Update gym.spaces tests
This commit is contained in:
Antonin RAFFIN
2018-09-24 20:11:03 +02:00
committed by pzhokhov
parent 42f9e14c00
commit 2234f94e7b
7 changed files with 66 additions and 3 deletions

View File

@@ -35,15 +35,18 @@ class Box(gym.Space):
def sample(self):
return gym.spaces.np_random.uniform(low=self.low, high=self.high + (0 if self.dtype.kind == 'f' else 1), size=self.low.shape).astype(self.dtype)
def contains(self, x):
return x.shape == self.shape and (x >= self.low).all() and (x <= self.high).all()
def to_jsonable(self, sample_n):
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n):
return [np.asarray(sample) for sample in sample_n]
def __repr__(self):
return "Box" + str(self.shape)
def __eq__(self, other):
return np.allclose(self.low, other.low) and np.allclose(self.high, other.high)

View File

@@ -71,3 +71,5 @@ class Dict(gym.Space):
ret.append(entry)
return ret
def __eq__(self, other):
return self.spaces == other.spaces

View File

@@ -11,8 +11,10 @@ class Discrete(gym.Space):
def __init__(self, n):
self.n = n
gym.Space.__init__(self, (), np.int64)
def sample(self):
return gym.spaces.np_random.randint(self.n)
def contains(self, x):
if isinstance(x, int):
as_int = x
@@ -24,5 +26,6 @@ class Discrete(gym.Space):
def __repr__(self):
return "Discrete(%d)" % self.n
def __eq__(self, other):
return self.n == other.n

View File

@@ -5,12 +5,21 @@ class MultiBinary(gym.Space):
def __init__(self, n):
self.n = n
gym.Space.__init__(self, (self.n,), np.int8)
def sample(self):
return gym.spaces.np_random.randint(low=0, high=2, size=self.n).astype(self.dtype)
def contains(self, x):
return ((x==0) | (x==1)).all()
def to_jsonable(self, sample_n):
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n):
return [np.asarray(sample) for sample in sample_n]
def __repr__(self):
return "MultiBinary({})".format(self.n)
def __eq__(self, other):
return self.n == other.n

View File

@@ -8,12 +8,21 @@ class MultiDiscrete(gym.Space):
"""
self.nvec = np.asarray(nvec, dtype=np.int32)
gym.Space.__init__(self, (self.nvec.shape,), np.int8)
def sample(self):
return (gym.spaces.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x):
return (0 <= x).all() and (x < self.nvec).all() and x.dtype.kind in 'ui'
def to_jsonable(self, sample_n):
return [sample.tolist() for sample in sample_n]
def from_jsonable(self, sample_n):
return np.array(sample_n)
def __repr__(self):
return "MultiDiscrete({})".format(self.nvec)
def __eq__(self, other):
return np.all(self.nvec == other.nvec)

View File

@@ -1,6 +1,9 @@
import json # note: ujson fails this test due to float equality
from copy import copy
import numpy as np
import pytest
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
@@ -30,3 +33,34 @@ def test_roundtripping(space):
s2p = space.to_jsonable([sample_2_prime])
assert s1 == s1p, "Expected {} to equal {}".format(s1, s1p)
assert s2 == s2p, "Expected {} to equal {}".format(s2, s2p)
@pytest.mark.parametrize("space", [
Discrete(3),
Box(low=np.array([-10, 0]),high=np.array([10, 10])),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(low=np.array([0, 0]),high=np.array([1, 5]))]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 100]),
MultiBinary(6),
Dict({"position": Discrete(5), "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]))}),
])
def test_equality(space):
space1 = space
space2 = copy(space)
assert space1 == space2, "Expected {} to equal {}".format(space1, space2)
@pytest.mark.parametrize("spaces", [
(Discrete(3), Discrete(4)),
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
(MultiBinary(8), MultiBinary(7)),
(Box(low=np.array([-10, 0]),high=np.array([10, 10])),
Box(low=np.array([-10, 0]),high=np.array([10, 9]))),
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
])
def test_inequality(spaces):
space1, space2 = spaces
assert space1 != space2, "Expected {} != {}".format(space1, space2)

View File

@@ -30,3 +30,6 @@ class Tuple(gym.Space):
def from_jsonable(self, sample_n):
return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]
def __eq__(self, other):
return self.spaces == other.spaces