Readded overwritten changes for offset functionality for Discrete spaces (#2470)

Co-authored-by: J K Terry <justinkterry@gmail.com>
This commit is contained in:
Ishan Manchanda
2021-10-30 21:42:01 +05:30
committed by GitHub
parent 531d4d02db
commit 103b7633f5
4 changed files with 36 additions and 7 deletions

View File

@@ -11,6 +11,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
"space",
[
Discrete(3),
Discrete(5, start=-2),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
@@ -20,6 +21,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
MultiDiscrete([2, 2, 100]),
MultiBinary(10),
Dict(
@@ -56,6 +58,7 @@ def test_roundtripping(space):
"space",
[
Discrete(3),
Discrete(5, start=-2),
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
Box(low=-np.inf, high=np.inf, shape=(1, 3)),
Tuple([Discrete(5), Discrete(10)]),
@@ -66,6 +69,7 @@ def test_roundtripping(space):
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
MultiDiscrete([2, 2, 100]),
MultiBinary(6),
Dict(
@@ -88,6 +92,7 @@ def test_equality(space):
"spaces",
[
(Discrete(3), Discrete(4)),
(Discrete(3), Discrete(3, start=-1)),
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
(MultiBinary(8), MultiBinary(7)),
(
@@ -99,6 +104,10 @@ def test_equality(space):
Box(low=0.0, high=np.inf, shape=(2, 1)),
),
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
(
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5, start=7), Discrete(10)]),
),
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
],
@@ -112,6 +121,7 @@ def test_inequality(spaces):
"space",
[
Discrete(5),
Discrete(8, start=-20),
Box(low=0, high=255, shape=(2,), dtype="uint8"),
Box(low=-np.inf, high=np.inf, shape=(3, 3)),
Box(low=1.0, high=np.inf, shape=(3, 3)),
@@ -133,7 +143,7 @@ def test_sample(space):
else:
expected_mean = 0.0
elif isinstance(space, Discrete):
expected_mean = space.n / 2
expected_mean = space.start + space.n / 2
else:
raise NotImplementedError
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
@@ -246,6 +256,7 @@ def test_box_dtype_check():
"space",
[
Discrete(3),
Discrete(3, start=-4),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
@@ -298,6 +309,7 @@ def sample_equal(sample1, sample2):
"space",
[
Discrete(3),
Discrete(3, start=-4),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
@@ -335,6 +347,7 @@ def test_seed_reproducibility(space):
[
Tuple([Discrete(100), Discrete(100)]),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Discrete(5, start=10)]),
Tuple(
[
Discrete(5),