mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 17:45:07 +00:00
spaces equality fixes and tests (#1375)
* spaces equality fixes and tests * squash-merged master * added better equality tests and more checks against bad space creation
This commit is contained in:
@@ -57,4 +57,4 @@ class Box(Space):
|
||||
return "Box" + str(self.shape)
|
||||
|
||||
def __eq__(self, other):
|
||||
return np.allclose(self.low, other.low) and np.allclose(self.high, other.high)
|
||||
return isinstance(other, Box) and np.allclose(self.low, other.low) and np.allclose(self.high, other.high)
|
||||
|
@@ -36,6 +36,8 @@ class Dict(Space):
|
||||
assert (spaces is None) or (not spaces_kwargs), 'Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)'
|
||||
if spaces is None:
|
||||
spaces = spaces_kwargs
|
||||
for space in spaces.values():
|
||||
assert isinstance(space, gym.Space), 'Values of the dict should be instances of gym.Space'
|
||||
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
||||
spaces = OrderedDict(sorted(list(spaces.items())))
|
||||
if isinstance(spaces, list):
|
||||
@@ -80,4 +82,4 @@ class Dict(Space):
|
||||
return ret
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.spaces == other.spaces
|
||||
return isinstance(other, Dict) and self.spaces == other.spaces
|
||||
|
@@ -34,4 +34,4 @@ class Discrete(Space):
|
||||
return "Discrete(%d)" % self.n
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.n == other.n
|
||||
return isinstance(other, Discrete) and self.n == other.n
|
||||
|
@@ -28,4 +28,4 @@ class MultiBinary(Space):
|
||||
return "MultiBinary({})".format(self.n)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.n == other.n
|
||||
return isinstance(other, MultiBinary) and self.n == other.n
|
||||
|
@@ -54,4 +54,4 @@ class MultiDiscrete(Space):
|
||||
return "MultiDiscrete({})".format(self.nvec)
|
||||
|
||||
def __eq__(self, other):
|
||||
return np.all(self.nvec == other.nvec)
|
||||
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||
|
@@ -84,3 +84,26 @@ def test_sample(space):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
|
||||
|
||||
@pytest.mark.parametrize("spaces", [
|
||||
(Discrete(5), MultiBinary(5)),
|
||||
(Box(low=np.array([-10, 0]), high=np.array([10,10]), dtype=np.float32), MultiDiscrete([2, 2, 8])),
|
||||
(Dict({"position": Discrete(5)}), Tuple([Discrete(5)])),
|
||||
(Dict({"position": Discrete(5)}), Discrete(5)),
|
||||
(Tuple((Discrete(5),)), Discrete(5)),
|
||||
])
|
||||
def test_class_inequality(spaces):
|
||||
assert spaces[0] == spaces[0]
|
||||
assert spaces[1] == spaces[1]
|
||||
assert spaces[0] != spaces[1]
|
||||
assert spaces[1] != spaces[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space_fn", [
|
||||
lambda: Dict(space1='abc'),
|
||||
lambda: Dict({'space1': 'abc'}),
|
||||
lambda: Tuple(['abc'])
|
||||
])
|
||||
def test_bad_space_calls(space_fn):
|
||||
with pytest.raises(AssertionError):
|
||||
space_fn()
|
||||
|
@@ -11,6 +11,8 @@ class Tuple(Space):
|
||||
"""
|
||||
def __init__(self, spaces):
|
||||
self.spaces = spaces
|
||||
for space in spaces:
|
||||
assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
|
||||
super(Tuple, self).__init__(None, None)
|
||||
|
||||
def seed(self, seed):
|
||||
@@ -43,4 +45,4 @@ class Tuple(Space):
|
||||
return len(self.spaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.spaces == other.spaces
|
||||
return isinstance(other, Tuple) and self.spaces == other.spaces
|
||||
|
Reference in New Issue
Block a user