mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-16 19:49:13 +00:00
Add missing __eq__
and __repr__
methods (#1178)
* Add missing equality + repr methods * Update gym.spaces tests
This commit is contained in:
@@ -35,15 +35,18 @@ class Box(gym.Space):
|
||||
|
||||
def sample(self):
|
||||
return gym.spaces.np_random.uniform(low=self.low, high=self.high + (0 if self.dtype.kind == 'f' else 1), size=self.low.shape).astype(self.dtype)
|
||||
|
||||
def contains(self, x):
|
||||
return x.shape == self.shape and (x >= self.low).all() and (x <= self.high).all()
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
|
||||
def __repr__(self):
|
||||
return "Box" + str(self.shape)
|
||||
|
||||
def __eq__(self, other):
|
||||
return np.allclose(self.low, other.low) and np.allclose(self.high, other.high)
|
||||
|
@@ -71,3 +71,5 @@ class Dict(gym.Space):
|
||||
ret.append(entry)
|
||||
return ret
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.spaces == other.spaces
|
||||
|
@@ -11,8 +11,10 @@ class Discrete(gym.Space):
|
||||
def __init__(self, n):
|
||||
self.n = n
|
||||
gym.Space.__init__(self, (), np.int64)
|
||||
|
||||
def sample(self):
|
||||
return gym.spaces.np_random.randint(self.n)
|
||||
|
||||
def contains(self, x):
|
||||
if isinstance(x, int):
|
||||
as_int = x
|
||||
@@ -24,5 +26,6 @@ class Discrete(gym.Space):
|
||||
|
||||
def __repr__(self):
|
||||
return "Discrete(%d)" % self.n
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.n == other.n
|
||||
|
@@ -5,12 +5,21 @@ class MultiBinary(gym.Space):
|
||||
def __init__(self, n):
|
||||
self.n = n
|
||||
gym.Space.__init__(self, (self.n,), np.int8)
|
||||
|
||||
def sample(self):
|
||||
return gym.spaces.np_random.randint(low=0, high=2, size=self.n).astype(self.dtype)
|
||||
|
||||
def contains(self, x):
|
||||
return ((x==0) | (x==1)).all()
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
|
||||
def __repr__(self):
|
||||
return "MultiBinary({})".format(self.n)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.n == other.n
|
||||
|
@@ -8,12 +8,21 @@ class MultiDiscrete(gym.Space):
|
||||
"""
|
||||
self.nvec = np.asarray(nvec, dtype=np.int32)
|
||||
gym.Space.__init__(self, (self.nvec.shape,), np.int8)
|
||||
|
||||
def sample(self):
|
||||
return (gym.spaces.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||
|
||||
def contains(self, x):
|
||||
return (0 <= x).all() and (x < self.nvec).all() and x.dtype.kind in 'ui'
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
return [sample.tolist() for sample in sample_n]
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
return np.array(sample_n)
|
||||
|
||||
def __repr__(self):
|
||||
return "MultiDiscrete({})".format(self.nvec)
|
||||
|
||||
def __eq__(self, other):
|
||||
return np.all(self.nvec == other.nvec)
|
||||
|
@@ -1,6 +1,9 @@
|
||||
import json # note: ujson fails this test due to float equality
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
|
||||
|
||||
@@ -30,3 +33,34 @@ def test_roundtripping(space):
|
||||
s2p = space.to_jsonable([sample_2_prime])
|
||||
assert s1 == s1p, "Expected {} to equal {}".format(s1, s1p)
|
||||
assert s2 == s2p, "Expected {} to equal {}".format(s2, s2p)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [
|
||||
Discrete(3),
|
||||
Box(low=np.array([-10, 0]),high=np.array([10, 10])),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Box(low=np.array([0, 0]),high=np.array([1, 5]))]),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(6),
|
||||
Dict({"position": Discrete(5), "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]))}),
|
||||
])
|
||||
def test_equality(space):
|
||||
space1 = space
|
||||
space2 = copy(space)
|
||||
assert space1 == space2, "Expected {} to equal {}".format(space1, space2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spaces", [
|
||||
(Discrete(3), Discrete(4)),
|
||||
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
|
||||
(MultiBinary(8), MultiBinary(7)),
|
||||
(Box(low=np.array([-10, 0]),high=np.array([10, 10])),
|
||||
Box(low=np.array([-10, 0]),high=np.array([10, 9]))),
|
||||
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
|
||||
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
|
||||
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
|
||||
])
|
||||
def test_inequality(spaces):
|
||||
space1, space2 = spaces
|
||||
assert space1 != space2, "Expected {} != {}".format(space1, space2)
|
||||
|
@@ -30,3 +30,6 @@ class Tuple(gym.Space):
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.spaces == other.spaces
|
||||
|
Reference in New Issue
Block a user