mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 17:45:07 +00:00
spaces equality fixes and tests (#1375)
* spaces equality fixes and tests * squash-merged master * added better equality tests and more checks against bad space creation
This commit is contained in:
@@ -57,4 +57,4 @@ class Box(Space):
|
|||||||
return "Box" + str(self.shape)
|
return "Box" + str(self.shape)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return np.allclose(self.low, other.low) and np.allclose(self.high, other.high)
|
return isinstance(other, Box) and np.allclose(self.low, other.low) and np.allclose(self.high, other.high)
|
||||||
|
@@ -36,6 +36,8 @@ class Dict(Space):
|
|||||||
assert (spaces is None) or (not spaces_kwargs), 'Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)'
|
assert (spaces is None) or (not spaces_kwargs), 'Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)'
|
||||||
if spaces is None:
|
if spaces is None:
|
||||||
spaces = spaces_kwargs
|
spaces = spaces_kwargs
|
||||||
|
for space in spaces.values():
|
||||||
|
assert isinstance(space, gym.Space), 'Values of the dict should be instances of gym.Space'
|
||||||
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
||||||
spaces = OrderedDict(sorted(list(spaces.items())))
|
spaces = OrderedDict(sorted(list(spaces.items())))
|
||||||
if isinstance(spaces, list):
|
if isinstance(spaces, list):
|
||||||
@@ -80,4 +82,4 @@ class Dict(Space):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.spaces == other.spaces
|
return isinstance(other, Dict) and self.spaces == other.spaces
|
||||||
|
@@ -34,4 +34,4 @@ class Discrete(Space):
|
|||||||
return "Discrete(%d)" % self.n
|
return "Discrete(%d)" % self.n
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.n == other.n
|
return isinstance(other, Discrete) and self.n == other.n
|
||||||
|
@@ -28,4 +28,4 @@ class MultiBinary(Space):
|
|||||||
return "MultiBinary({})".format(self.n)
|
return "MultiBinary({})".format(self.n)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.n == other.n
|
return isinstance(other, MultiBinary) and self.n == other.n
|
||||||
|
@@ -54,4 +54,4 @@ class MultiDiscrete(Space):
|
|||||||
return "MultiDiscrete({})".format(self.nvec)
|
return "MultiDiscrete({})".format(self.nvec)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return np.all(self.nvec == other.nvec)
|
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||||
|
@@ -84,3 +84,26 @@ def test_sample(space):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
|
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("spaces", [
|
||||||
|
(Discrete(5), MultiBinary(5)),
|
||||||
|
(Box(low=np.array([-10, 0]), high=np.array([10,10]), dtype=np.float32), MultiDiscrete([2, 2, 8])),
|
||||||
|
(Dict({"position": Discrete(5)}), Tuple([Discrete(5)])),
|
||||||
|
(Dict({"position": Discrete(5)}), Discrete(5)),
|
||||||
|
(Tuple((Discrete(5),)), Discrete(5)),
|
||||||
|
])
|
||||||
|
def test_class_inequality(spaces):
|
||||||
|
assert spaces[0] == spaces[0]
|
||||||
|
assert spaces[1] == spaces[1]
|
||||||
|
assert spaces[0] != spaces[1]
|
||||||
|
assert spaces[1] != spaces[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space_fn", [
|
||||||
|
lambda: Dict(space1='abc'),
|
||||||
|
lambda: Dict({'space1': 'abc'}),
|
||||||
|
lambda: Tuple(['abc'])
|
||||||
|
])
|
||||||
|
def test_bad_space_calls(space_fn):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
space_fn()
|
||||||
|
@@ -11,6 +11,8 @@ class Tuple(Space):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, spaces):
|
def __init__(self, spaces):
|
||||||
self.spaces = spaces
|
self.spaces = spaces
|
||||||
|
for space in spaces:
|
||||||
|
assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
|
||||||
super(Tuple, self).__init__(None, None)
|
super(Tuple, self).__init__(None, None)
|
||||||
|
|
||||||
def seed(self, seed):
|
def seed(self, seed):
|
||||||
@@ -43,4 +45,4 @@ class Tuple(Space):
|
|||||||
return len(self.spaces)
|
return len(self.spaces)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.spaces == other.spaces
|
return isinstance(other, Tuple) and self.spaces == other.spaces
|
||||||
|
Reference in New Issue
Block a user