Modify Space.seed such that the return can be used as seeding values (#1033)

This commit is contained in:
Mark Towers
2024-04-28 16:10:35 +01:00
committed by GitHub
parent d1964978f1
commit 8bf2543e34
13 changed files with 271 additions and 146 deletions

View File

@@ -107,44 +107,45 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces.values())
def seed(self, seed: dict[str, Any] | int | None = None) -> list[int]:
def seed(self, seed: int | dict[str, Any] | None = None) -> dict[str, int]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all of the subspaces.
* ``Dict`` - Using all the keys in the seed dictionary, the values are used to seed the subspaces. This allows the seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
* ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all subspaces, though is very unlikely.
* ``Dict`` - A dictionary of seeds for each subspace, requires a seed key for every subspace. This supports seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
"""
seeds: list[int] = []
seed: An optional int or dictionary of subspace keys to int to seed each PRNG. See above for more details.
if isinstance(seed, dict):
assert (
seed.keys() == self.spaces.keys()
), f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
for key in seed.keys():
seeds += self.spaces[key].seed(seed[key])
Returns:
A dictionary for the seed values of the subspaces
"""
if seed is None:
return {key: subspace.seed(None) for (key, subspace) in self.spaces.items()}
elif isinstance(seed, int):
seeds = super().seed(seed)
super().seed(seed)
# Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces
subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces)
)
for subspace, subseed in zip(self.spaces.values(), subseeds):
seeds += subspace.seed(int(subseed))
elif seed is None:
for space in self.spaces.values():
seeds += space.seed(None)
return {
key: subspace.seed(int(subseed))
for (key, subspace), subseed in zip(self.spaces.items(), subseeds)
}
elif isinstance(seed, dict):
if seed.keys() != self.spaces.keys():
raise ValueError(
f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
)
return {key: self.spaces[key].seed(seed[key]) for key in seed.keys()}
else:
raise TypeError(
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
)
return seeds
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
"""Generates a single random sample from this space.

View File

@@ -31,25 +31,19 @@ class Graph(Space[GraphInstance]):
Example:
>>> from gymnasium.spaces import Graph, Box, Discrete
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=42)
>>> observation_space.sample()
GraphInstance(nodes=array([[-12.224312 , 71.71958 , 39.473606 ],
[-81.16453 , 95.12447 , 52.22794 ],
[ 57.21286 , -74.37727 , -9.922812 ],
[-25.840395 , 85.353 , 28.773024 ],
[ 64.55232 , -11.317161 , -54.552258 ],
[ 10.916958 , -87.23655 , 65.52624 ],
[ 26.33288 , 51.61755 , -29.094807 ],
[ 94.1396 , 78.62422 , 55.6767 ],
[-61.072258 , -6.6557994, -91.23925 ],
[-69.142105 , 36.60979 , 48.95243 ]], dtype=float32), edges=array([2, 0, 1, 1, 0, 0, 1, 0]), edge_links=array([[7, 5],
[6, 9],
[4, 1],
[8, 6],
[7, 0],
[3, 7],
[8, 4],
[8, 8]], dtype=int32))
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=123)
>>> observation_space.sample(num_nodes=4, num_edges=8)
GraphInstance(nodes=array([[ 36.47037 , -89.235794, -55.928024],
[-63.125637, -64.81882 , 62.4189 ],
[ 84.669 , -44.68512 , 63.950912],
[ 77.97854 , 2.594091, -51.00708 ]], dtype=float32), edges=array([2, 0, 2, 1, 2, 0, 2, 1]), edge_links=array([[3, 0],
[0, 0],
[0, 1],
[0, 2],
[1, 0],
[1, 0],
[0, 1],
[0, 2]], dtype=int32))
"""
def __init__(
@@ -110,6 +104,76 @@ class Graph(Space[GraphInstance]):
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
)
def seed(
self, seed: int | tuple[int, int] | tuple[int, int, int] | None = None
) -> tuple[int, int] | tuple[int, int, int]:
"""Seeds the PRNG of this space and node / edge subspace.
Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - The root, node and edge spaces PRNG are randomly initialized
* ``Int`` - The integer is used to seed the :class:`Graph` space that is used to generate seed values for the node and edge subspaces.
* ``Tuple[int, int]`` - Seeds the :class:`Graph` and node subspace with a particular value. Only if edge subspace isn't specified
* ``Tuple[int, int, int]`` - Seeds the :class:`Graph`, node and edge subspaces with a particular value.
Args:
seed: An optional int or tuple of ints for this space and the node / edge subspaces. See above for more details.
Returns:
A tuple of two or three ints depending on if the edge subspace is specified.
"""
if seed is None:
if self.edge_space is None:
return super().seed(None), self.node_space.seed(None)
else:
return (
super().seed(None),
self.node_space.seed(None),
self.edge_space.seed(None),
)
elif isinstance(seed, int):
if self.edge_space is None:
super_seed = super().seed(seed)
node_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
# this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
super().seed(seed)
return super_seed, self.node_space.seed(node_seed)
else:
super_seed = super().seed(seed)
node_seed, edge_seed = self.np_random.integers(
np.iinfo(np.int32).max, size=(2,)
)
# this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
super().seed(seed)
return (
super_seed,
self.node_space.seed(int(node_seed)),
self.edge_space.seed(int(edge_seed)),
)
elif isinstance(seed, (list, tuple)):
if self.edge_space is None:
if len(seed) != 2:
raise ValueError(
f"Expects a tuple of two values for Graph and node space, actual length: {len(seed)}"
)
return super().seed(seed[0]), self.node_space.seed(seed[1])
else:
if len(seed) != 3:
raise ValueError(
f"Expects a tuple of three values for Graph, node and edge space, actual length: {len(seed)}"
)
return (
super().seed(seed[0]),
self.node_space.seed(seed[1]),
self.edge_space.seed(seed[2]),
)
else:
raise TypeError(
f"Expects `None`, int or tuple of ints, actual type: {type(seed)}"
)
def sample(
self,
mask: None

View File

@@ -1,7 +1,6 @@
"""Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations
import collections.abc
import typing
from typing import Any, Iterable
@@ -17,11 +16,11 @@ class OneOf(Space[Any]):
Example:
>>> from gymnasium.spaces import OneOf, Box, Discrete
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=123)
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
(1, array([-0.3991573 , 0.21649833], dtype=float32))
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
(0, 0)
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
(1, array([-0.00711833, -0.7257502 ], dtype=float32))
>>> observation_space[0]
Discrete(2)
>>> observation_space[1]
@@ -57,42 +56,48 @@ class OneOf(Space[Any]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
def seed(self, seed: int | tuple[int, ...] | None = None) -> tuple[int, ...]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
* ``Tuple[int, ...]`` - Values used to seed the subspaces, first value seeds the OneOf and subsequent seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
seed: An optional int or tuple of ints to seed the OneOf space and subspaces. See above for more details.
Returns:
A tuple of ints used to seed the OneOf space and subspaces
"""
if isinstance(seed, collections.abc.Sequence):
assert (
len(seed) == len(self.spaces) + 1
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
seeds = super().seed(seed[0])
for subseed, space in zip(seed, self.spaces):
seeds += space.seed(subseed)
if seed is None:
super_seed = super().seed(None)
return (super_seed,) + tuple(space.seed(None) for space in self.spaces)
elif isinstance(seed, int):
seeds = super().seed(seed)
super_seed = super().seed(seed)
subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces)
)
for subspace, subseed in zip(self.spaces, subseeds):
seeds += subspace.seed(int(subseed))
elif seed is None:
seeds = super().seed(None)
for space in self.spaces:
seeds += space.seed(None)
else:
raise TypeError(
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
# this is necessary such that after int or list/tuple seeding, the OneOf PRNG are equivalent
super().seed(seed)
return (super_seed,) + tuple(
space.seed(int(subseed))
for space, subseed in zip(self.spaces, subseeds)
)
elif isinstance(seed, (tuple, list)):
if len(seed) != len(self.spaces) + 1:
raise ValueError(
f"Expects that the subspaces of seeds equals the number of subspaces + 1. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
)
return seeds
return (super().seed(seed[0]),) + tuple(
space.seed(subseed) for space, subseed in zip(self.spaces, seed[1:])
)
else:
raise TypeError(
f"Expected None, int, or tuple of ints, actual type: {type(seed)}"
)
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
"""Generates a single random sample inside this space.

View File

@@ -19,12 +19,18 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
Example:
>>> from gymnasium.spaces import Sequence, Box
>>> observation_space = Sequence(Box(0, 1), seed=2)
>>> observation_space.sample()
(array([0.26161215], dtype=float32),)
>>> observation_space = Sequence(Box(0, 1), seed=0)
>>> observation_space.sample()
(array([0.6369617], dtype=float32), array([0.26978672], dtype=float32), array([0.04097353], dtype=float32))
(array([0.6822636], dtype=float32), array([0.18933342], dtype=float32), array([0.19049619], dtype=float32))
>>> observation_space.sample()
(array([0.83506], dtype=float32), array([0.9053838], dtype=float32), array([0.5836242], dtype=float32), array([0.63214064], dtype=float32))
Example with stacked observations
>>> observation_space = Sequence(Box(0, 1), stack=True, seed=0)
>>> observation_space.sample()
array([[0.6822636 ],
[0.18933342],
[0.19049619]], dtype=float32)
"""
def __init__(
@@ -53,11 +59,39 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
# None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed)
def seed(self, seed: int | None = None) -> list[int]:
"""Seed the PRNG of this space and the feature space."""
seeds = super().seed(seed)
seeds += self.feature_space.seed(seed)
return seeds
def seed(self, seed: int | tuple[int, int] | None = None) -> tuple[int, int]:
"""Seed the PRNG of the Sequence space and the feature space.
Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Sequence` space that is used to generate a seed value for the feature space.
* ``Tuple of ints`` - A tuple for the :class:`Sequence` and feature space.
Args:
seed: An optional int or tuple of ints to seed the PRNG. See above for more details
Returns:
A tuple of the seeding values for the Sequence and feature space
"""
if seed is None:
return super().seed(None), self.feature_space.seed(None)
elif isinstance(seed, int):
super_seed = super().seed(seed)
feature_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
# this is necessary such that after int or list/tuple seeding, the Sequence PRNG are equivalent
super().seed(seed)
return super_seed, self.feature_space.seed(feature_seed)
elif isinstance(seed, (tuple, list)):
if len(seed) != 2:
raise ValueError(
f"Expects the seed to have two elements for the Sequence and feature space, actual length: {len(seed)}"
)
return super().seed(seed[0]), self.feature_space.seed(seed[1])
else:
raise TypeError(
f"Expected None, int, tuple of ints, actual type: {type(seed)}"
)
@property
def is_np_flattenable(self):

View File

@@ -102,13 +102,20 @@ class Space(Generic[T_cov]):
"""
raise NotImplementedError
def seed(self, seed: int | None = None) -> list[int]:
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
def seed(self, seed: int | None = None) -> int | list[int] | dict[str, int]:
"""Seed the pseudorandom number generator (PRNG) of this space and, if applicable, the PRNGs of subspaces.
Args:
seed: The seed value for the space. This is expanded for composite spaces to accept multiple values. For further details, please refer to the space's documentation.
Returns:
The seed values used for all the PRNGs, for composite spaces this can be a tuple or dictionary of values.
"""
self._np_random, np_random_seed = seeding.np_random(seed)
return [np_random_seed]
return np_random_seed
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
"""Return boolean specifying if x is a valid member of this space, equivalent to ``sample in space``."""
raise NotImplementedError
def __contains__(self, x: Any) -> bool:

View File

@@ -1,7 +1,6 @@
"""Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations
import collections.abc
import typing
from typing import Any, Iterable
@@ -47,7 +46,7 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
@@ -58,32 +57,35 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
"""
seeds: list[int] = []
if isinstance(seed, collections.abc.Sequence):
assert len(seed) == len(
self.spaces
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
for subseed, space in zip(seed, self.spaces):
seeds += space.seed(subseed)
Returns:
A tuple of the seed values for all subspaces
"""
if seed is None:
return tuple(space.seed(None) for space in self.spaces)
elif isinstance(seed, int):
seeds = super().seed(seed)
super().seed(seed)
subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces)
)
for subspace, subseed in zip(self.spaces, subseeds):
seeds += subspace.seed(int(subseed))
elif seed is None:
for space in self.spaces:
seeds += space.seed(seed)
return tuple(
subspace.seed(int(subseed))
for subspace, subseed in zip(self.spaces, subseeds)
)
elif isinstance(seed, (tuple, list)):
if len(seed) != len(self.spaces):
raise ValueError(
f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
)
return tuple(
space.seed(subseed) for subseed, space in zip(seed, self.spaces)
)
else:
raise TypeError(
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
)
return seeds
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
"""Generates a single random sample inside this space.

View File

@@ -31,19 +31,19 @@ from gymnasium.utils.passive_env_checker import (
def data_equivalence(data_1, data_2, exact: bool = False) -> bool:
"""Assert equality between data 1 and 2, i.e observations, actions, info.
"""Assert equality between data 1 and 2, i.e. observations, actions, info.
Args:
data_1: data structure 1
data_2: data structure 2
exact: whether to compare array exactly or not if false compares with absolute and realive torrelance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).
exact: whether to compare array exactly or not if false compares with absolute and relative tolerance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).
Returns:
If observation 1 and 2 are equivalent
"""
if type(data_1) is not type(data_2):
return False
if isinstance(data_1, dict):
elif isinstance(data_1, dict):
return data_1.keys() == data_2.keys() and all(
data_equivalence(data_1[k], data_2[k], exact) for k in data_1.keys()
)

View File

@@ -6,6 +6,7 @@ import numpy as np
import pytest
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils.env_checker import data_equivalence
def test_dict_init():
@@ -56,8 +57,7 @@ DICT_SPACE = Dict(
def test_dict_seeding():
seeds = DICT_SPACE.seed(
{
seeding_values = {
"a": 0,
"b": {
"b_1": 1,
@@ -65,8 +65,8 @@ def test_dict_seeding():
},
"c": 3,
}
)
assert all(isinstance(seed, int) for seed in seeds)
seeded_values = DICT_SPACE.seed(seeding_values)
assert data_equivalence(seeded_values, seeding_values)
# "Unpack" the dict sub-spaces into individual spaces
a = Box(low=0, high=1, shape=(3, 3), seed=0)
@@ -84,7 +84,7 @@ def test_dict_seeding():
def test_int_seeding():
seeds = DICT_SPACE.seed(1)
assert all(isinstance(seed, int) for seed in seeds)
assert isinstance(seeds, dict)
# rng, seeds = seeding.np_random(1)
# subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False)
@@ -92,10 +92,10 @@ def test_int_seeding():
# b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False)
# "Unpack" the dict sub-spaces into individual spaces
a = Box(low=0, high=1, shape=(3, 3), seed=seeds[1])
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds[3])
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds[4])
c = Discrete(5, seed=seeds[5])
a = Box(low=0, high=1, shape=(3, 3), seed=seeds["a"])
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds["b"]["b_1"])
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds["b"]["b_2"])
c = Discrete(5, seed=seeds["c"])
for i in range(10):
dict_sample = DICT_SPACE.sample()
@@ -107,7 +107,7 @@ def test_int_seeding():
def test_none_seeding():
seeds = DICT_SPACE.seed(None)
assert len(seeds) == 4 and all(isinstance(seed, int) for seed in seeds)
assert isinstance(seeds, dict)
def test_bad_seed():

View File

@@ -2,7 +2,6 @@ import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
from gymnasium.utils.env_checker import data_equivalence
def test_oneof_inheritance():
@@ -21,26 +20,18 @@ def test_oneof_inheritance():
@pytest.mark.parametrize(
"spaces, seed, expected_len",
"spaces, seed",
[
([Discrete(5), Box(-1, 1, shape=(3,))], None, 3),
([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3),
([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3),
([Discrete(5), Box(-1, 1, shape=(3,))], None),
([Discrete(5), Box(-1, 1, shape=(3,))], 123),
([Discrete(5), Box(-1, 1, shape=(3,))], (123, 456, 789)),
],
)
def test_oneof_seeds(spaces, seed, expected_len):
def test_oneof_seeds(spaces, seed):
oneof_space = OneOf(spaces)
seeds = oneof_space.seed(seed)
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
assert len(seeds) == expected_len
sample1 = oneof_space.sample()
seeds2 = oneof_space.seed(seed)
sample2 = oneof_space.sample()
data_equivalence(seeds, seeds2)
data_equivalence(sample1, sample2)
assert isinstance(seeds, tuple)
assert len(seeds) == len(spaces) + 1
@pytest.mark.parametrize(
@@ -71,6 +62,6 @@ def test_bad_oneof_seed():
space = OneOf([Box(0, 1), Box(0, 1)])
with pytest.raises(
TypeError,
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
match="Expected None, int, or tuple of ints, actual type: <class 'float'>",
):
space.seed(0.0)

View File

@@ -9,6 +9,7 @@ import numpy as np
import pytest
import scipy.stats
from gymnasium.error import Error
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
from gymnasium.utils import seeding
from gymnasium.utils.env_checker import data_equivalence
@@ -509,13 +510,13 @@ def test_seed_reproducibility(space):
space_2 = copy.deepcopy(space)
for seed in range(5):
assert space_1.seed(seed) == space_2.seed(seed)
assert data_equivalence(space_1.seed(seed), space_2.seed(seed))
# With the same seed, the two spaces should be identical
assert all(
data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
)
assert space_1.seed(123) != space_2.seed(456)
assert not data_equivalence(space_1.seed(123), space_2.seed(456))
# Due to randomness, it is difficult to test that random seeds produce different answers
# Therefore, taking 10 samples and checking that they are not all the same.
assert not all(
@@ -604,3 +605,22 @@ def test_space_pickling(space):
file_unpickled_sample = file_unpickled_space.sample()
assert data_equivalence(space_sample, unpickled_sample)
assert data_equivalence(space_sample, file_unpickled_sample)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("initial_seed", [None, 123])
def test_space_seeding_output(space, initial_seed, num_samples=5):
seeding_values = space.seed(initial_seed)
samples = [space.sample() for _ in range(num_samples)]
reseeded_values = space.seed(seeding_values)
resamples = [space.sample() for _ in range(num_samples)]
assert data_equivalence(seeding_values, reseeded_values)
assert data_equivalence(samples, resamples)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_invalid_space_seed(space):
with pytest.raises((ValueError, TypeError, Error)):
space.seed("abc")

View File

@@ -37,17 +37,16 @@ def test_sequence_inheritance():
@pytest.mark.parametrize(
"space, seed, expected_len",
"space, seed",
[
(Tuple([Discrete(5), Discrete(4)]), None, 2),
(Tuple([Discrete(5), Discrete(4)]), 123, 3),
(Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
(Tuple([Discrete(5), Discrete(4)]), None),
(Tuple([Discrete(5), Discrete(4)]), 123),
(Tuple([Discrete(5), Discrete(4)]), (123, 456)),
(
Tuple(
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
),
(123, (456, 789)),
3,
),
(
Tuple(
@@ -57,22 +56,21 @@ def test_sequence_inheritance():
)
),
(123, {"position": 456, "velocity": 789}),
3,
),
],
)
def test_seeds(space, seed, expected_len):
seeds = space.seed(seed)
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
assert len(seeds) == expected_len
def test_seeds(space, seed):
seeds1 = space.seed(seed)
assert isinstance(seeds1, tuple)
assert len(seeds1) == len(space)
sample1 = space.sample()
seeds2 = space.seed(seed)
seeds2 = space.seed(seeds1)
sample2 = space.sample()
data_equivalence(seeds, seeds2)
data_equivalence(sample1, sample2)
assert data_equivalence(seeds1, seeds2)
assert data_equivalence(sample1, sample2)
@pytest.mark.parametrize(

View File

@@ -50,6 +50,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
None,
None,
None,
None,
# Sequence
None,
None,
@@ -60,6 +61,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
4,
5,
]
assert len(TESTING_SPACES) == len(TESTING_SPACES_EXPECTED_FLATDIMS)
@pytest.mark.parametrize(

View File

@@ -100,6 +100,7 @@ TESTING_COMPOSITE_SPACES = [
b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
),
# Graph spaces
Graph(node_space=Box(-1, 1, shape=(2,)), edge_space=None),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(3), edge_space=Discrete(4)),