Modify Space.seed such that the return can be used as seeding values (#1033)

This commit is contained in:
Mark Towers
2024-04-28 16:10:35 +01:00
committed by GitHub
parent d1964978f1
commit 8bf2543e34
13 changed files with 271 additions and 146 deletions

View File

@@ -37,17 +37,16 @@ def test_sequence_inheritance():
@pytest.mark.parametrize(
"space, seed, expected_len",
"space, seed",
[
(Tuple([Discrete(5), Discrete(4)]), None, 2),
(Tuple([Discrete(5), Discrete(4)]), 123, 3),
(Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
(Tuple([Discrete(5), Discrete(4)]), None),
(Tuple([Discrete(5), Discrete(4)]), 123),
(Tuple([Discrete(5), Discrete(4)]), (123, 456)),
(
Tuple(
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
),
(123, (456, 789)),
3,
),
(
Tuple(
@@ -57,22 +56,21 @@ def test_sequence_inheritance():
)
),
(123, {"position": 456, "velocity": 789}),
3,
),
],
)
def test_seeds(space, seed, expected_len):
seeds = space.seed(seed)
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
assert len(seeds) == expected_len
def test_seeds(space, seed):
seeds1 = space.seed(seed)
assert isinstance(seeds1, tuple)
assert len(seeds1) == len(space)
sample1 = space.sample()
seeds2 = space.seed(seed)
seeds2 = space.seed(seeds1)
sample2 = space.sample()
data_equivalence(seeds, seeds2)
data_equivalence(sample1, sample2)
assert data_equivalence(seeds1, seeds2)
assert data_equivalence(sample1, sample2)
@pytest.mark.parametrize(