spaces equality fixes and tests (#1375)

* spaces equality fixes and tests

* squash-merged master

* added better equality tests and more checks against bad space creation
This commit is contained in:
pzhokhov
2019-03-23 23:18:19 -07:00
committed by GitHub
parent b219d36441
commit 07645bd11e
7 changed files with 33 additions and 6 deletions

View File

@@ -57,4 +57,4 @@ class Box(Space):
return "Box" + str(self.shape) return "Box" + str(self.shape)
def __eq__(self, other): def __eq__(self, other):
return np.allclose(self.low, other.low) and np.allclose(self.high, other.high) return isinstance(other, Box) and np.allclose(self.low, other.low) and np.allclose(self.high, other.high)

View File

@@ -36,6 +36,8 @@ class Dict(Space):
assert (spaces is None) or (not spaces_kwargs), 'Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)' assert (spaces is None) or (not spaces_kwargs), 'Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)'
if spaces is None: if spaces is None:
spaces = spaces_kwargs spaces = spaces_kwargs
for space in spaces.values():
assert isinstance(space, gym.Space), 'Values of the dict should be instances of gym.Space'
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict): if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
spaces = OrderedDict(sorted(list(spaces.items()))) spaces = OrderedDict(sorted(list(spaces.items())))
if isinstance(spaces, list): if isinstance(spaces, list):
@@ -80,4 +82,4 @@ class Dict(Space):
return ret return ret
def __eq__(self, other): def __eq__(self, other):
return self.spaces == other.spaces return isinstance(other, Dict) and self.spaces == other.spaces

View File

@@ -34,4 +34,4 @@ class Discrete(Space):
return "Discrete(%d)" % self.n return "Discrete(%d)" % self.n
def __eq__(self, other): def __eq__(self, other):
return self.n == other.n return isinstance(other, Discrete) and self.n == other.n

View File

@@ -28,4 +28,4 @@ class MultiBinary(Space):
return "MultiBinary({})".format(self.n) return "MultiBinary({})".format(self.n)
def __eq__(self, other): def __eq__(self, other):
return self.n == other.n return isinstance(other, MultiBinary) and self.n == other.n

View File

@@ -54,4 +54,4 @@ class MultiDiscrete(Space):
return "MultiDiscrete({})".format(self.nvec) return "MultiDiscrete({})".format(self.nvec)
def __eq__(self, other): def __eq__(self, other):
return np.all(self.nvec == other.nvec) return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)

View File

@@ -84,3 +84,26 @@ def test_sample(space):
else: else:
raise NotImplementedError raise NotImplementedError
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std()) np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
@pytest.mark.parametrize("spaces", [
(Discrete(5), MultiBinary(5)),
(Box(low=np.array([-10, 0]), high=np.array([10,10]), dtype=np.float32), MultiDiscrete([2, 2, 8])),
(Dict({"position": Discrete(5)}), Tuple([Discrete(5)])),
(Dict({"position": Discrete(5)}), Discrete(5)),
(Tuple((Discrete(5),)), Discrete(5)),
])
def test_class_inequality(spaces):
assert spaces[0] == spaces[0]
assert spaces[1] == spaces[1]
assert spaces[0] != spaces[1]
assert spaces[1] != spaces[0]
@pytest.mark.parametrize("space_fn", [
lambda: Dict(space1='abc'),
lambda: Dict({'space1': 'abc'}),
lambda: Tuple(['abc'])
])
def test_bad_space_calls(space_fn):
with pytest.raises(AssertionError):
space_fn()

View File

@@ -11,6 +11,8 @@ class Tuple(Space):
""" """
def __init__(self, spaces): def __init__(self, spaces):
self.spaces = spaces self.spaces = spaces
for space in spaces:
assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
super(Tuple, self).__init__(None, None) super(Tuple, self).__init__(None, None)
def seed(self, seed): def seed(self, seed):
@@ -43,4 +45,4 @@ class Tuple(Space):
return len(self.spaces) return len(self.spaces)
def __eq__(self, other): def __eq__(self, other):
return self.spaces == other.spaces return isinstance(other, Tuple) and self.spaces == other.spaces