mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 21:06:59 +00:00
Readded overwritten changes for offset functionality for Discrete spaces (#2470)
Co-authored-by: J K Terry <justinkterry@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(5, start=-2),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
@@ -20,6 +21,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict(
|
||||
@@ -56,6 +58,7 @@ def test_roundtripping(space):
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(5, start=-2),
|
||||
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(1, 3)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
@@ -66,6 +69,7 @@ def test_roundtripping(space):
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(6),
|
||||
Dict(
|
||||
@@ -88,6 +92,7 @@ def test_equality(space):
|
||||
"spaces",
|
||||
[
|
||||
(Discrete(3), Discrete(4)),
|
||||
(Discrete(3), Discrete(3, start=-1)),
|
||||
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
|
||||
(MultiBinary(8), MultiBinary(7)),
|
||||
(
|
||||
@@ -99,6 +104,10 @@ def test_equality(space):
|
||||
Box(low=0.0, high=np.inf, shape=(2, 1)),
|
||||
),
|
||||
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
|
||||
(
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5, start=7), Discrete(10)]),
|
||||
),
|
||||
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
|
||||
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
|
||||
],
|
||||
@@ -112,6 +121,7 @@ def test_inequality(spaces):
|
||||
"space",
|
||||
[
|
||||
Discrete(5),
|
||||
Discrete(8, start=-20),
|
||||
Box(low=0, high=255, shape=(2,), dtype="uint8"),
|
||||
Box(low=-np.inf, high=np.inf, shape=(3, 3)),
|
||||
Box(low=1.0, high=np.inf, shape=(3, 3)),
|
||||
@@ -133,7 +143,7 @@ def test_sample(space):
|
||||
else:
|
||||
expected_mean = 0.0
|
||||
elif isinstance(space, Discrete):
|
||||
expected_mean = space.n / 2
|
||||
expected_mean = space.start + space.n / 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
|
||||
@@ -246,6 +256,7 @@ def test_box_dtype_check():
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(3, start=-4),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
@@ -298,6 +309,7 @@ def sample_equal(sample1, sample2):
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(3, start=-4),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
@@ -335,6 +347,7 @@ def test_seed_reproducibility(space):
|
||||
[
|
||||
Tuple([Discrete(100), Discrete(100)]),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Discrete(5, start=10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
|
Reference in New Issue
Block a user