Modify Space.seed such that the return can be used as seeding values (#1033)

This commit is contained in:
Mark Towers
2024-04-28 16:10:35 +01:00
committed by GitHub
parent d1964978f1
commit 8bf2543e34
13 changed files with 271 additions and 146 deletions

View File

@@ -107,44 +107,45 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" """Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces.values()) return all(space.is_np_flattenable for space in self.spaces.values())
def seed(self, seed: dict[str, Any] | int | None = None) -> list[int]: def seed(self, seed: int | dict[str, Any] | None = None) -> dict[str, int]:
"""Seed the PRNG of this space and all subspaces. """Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - All the subspaces will use a random initial seed * ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all of the subspaces. * ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all subspaces, though is very unlikely.
* ``Dict`` - Using all the keys in the seed dictionary, the values are used to seed the subspaces. This allows the seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``). * ``Dict`` - A dictionary of seeds for each subspace, requires a seed key for every subspace. This supports seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
Args: Args:
seed: An optional list of ints or int to seed the (sub-)spaces. seed: An optional int or dictionary of subspace keys to int to seed each PRNG. See above for more details.
"""
seeds: list[int] = []
if isinstance(seed, dict): Returns:
assert ( A dictionary for the seed values of the subspaces
seed.keys() == self.spaces.keys() """
), f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}" if seed is None:
for key in seed.keys(): return {key: subspace.seed(None) for (key, subspace) in self.spaces.items()}
seeds += self.spaces[key].seed(seed[key])
elif isinstance(seed, int): elif isinstance(seed, int):
seeds = super().seed(seed) super().seed(seed)
# Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces # Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces
subseeds = self.np_random.integers( subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces) np.iinfo(np.int32).max, size=len(self.spaces)
) )
for subspace, subseed in zip(self.spaces.values(), subseeds): return {
seeds += subspace.seed(int(subseed)) key: subspace.seed(int(subseed))
elif seed is None: for (key, subspace), subseed in zip(self.spaces.items(), subseeds)
for space in self.spaces.values(): }
seeds += space.seed(None) elif isinstance(seed, dict):
if seed.keys() != self.spaces.keys():
raise ValueError(
f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
)
return {key: self.spaces[key].seed(seed[key]) for key in seed.keys()}
else: else:
raise TypeError( raise TypeError(
f"Expected seed type: dict, int or None, actual type: {type(seed)}" f"Expected seed type: dict, int or None, actual type: {type(seed)}"
) )
return seeds
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]: def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
"""Generates a single random sample from this space. """Generates a single random sample from this space.

View File

@@ -31,25 +31,19 @@ class Graph(Space[GraphInstance]):
Example: Example:
>>> from gymnasium.spaces import Graph, Box, Discrete >>> from gymnasium.spaces import Graph, Box, Discrete
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=42) >>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=123)
>>> observation_space.sample() >>> observation_space.sample(num_nodes=4, num_edges=8)
GraphInstance(nodes=array([[-12.224312 , 71.71958 , 39.473606 ], GraphInstance(nodes=array([[ 36.47037 , -89.235794, -55.928024],
[-81.16453 , 95.12447 , 52.22794 ], [-63.125637, -64.81882 , 62.4189 ],
[ 57.21286 , -74.37727 , -9.922812 ], [ 84.669 , -44.68512 , 63.950912],
[-25.840395 , 85.353 , 28.773024 ], [ 77.97854 , 2.594091, -51.00708 ]], dtype=float32), edges=array([2, 0, 2, 1, 2, 0, 2, 1]), edge_links=array([[3, 0],
[ 64.55232 , -11.317161 , -54.552258 ], [0, 0],
[ 10.916958 , -87.23655 , 65.52624 ], [0, 1],
[ 26.33288 , 51.61755 , -29.094807 ], [0, 2],
[ 94.1396 , 78.62422 , 55.6767 ], [1, 0],
[-61.072258 , -6.6557994, -91.23925 ], [1, 0],
[-69.142105 , 36.60979 , 48.95243 ]], dtype=float32), edges=array([2, 0, 1, 1, 0, 0, 1, 0]), edge_links=array([[7, 5], [0, 1],
[6, 9], [0, 2]], dtype=int32))
[4, 1],
[8, 6],
[7, 0],
[3, 7],
[8, 4],
[8, 8]], dtype=int32))
""" """
def __init__( def __init__(
@@ -110,6 +104,76 @@ class Graph(Space[GraphInstance]):
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}." f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
) )
def seed(
    self, seed: int | tuple[int, int] | tuple[int, int, int] | None = None
) -> tuple[int, int] | tuple[int, int, int]:
    """Seed the PRNG of the Graph space together with its node / edge subspaces.

    How the subspaces are seeded depends on the type of ``seed``:

    * ``None`` - The Graph, node and (if present) edge PRNGs are each seeded with ``None``.
    * ``int`` - The Graph PRNG is seeded with the integer and then used to draw the seed
      values for the node and (if present) edge subspaces.
    * sequence of two values - Seeds for the Graph and the node subspace; only valid when
      there is no edge subspace.
    * sequence of three values - Seeds for the Graph, node and edge subspaces.

    Args:
        seed: ``None``, an int, or a tuple / list with one entry per PRNG as described above.

    Returns:
        ``(graph_seed, node_seed)`` when there is no edge subspace, otherwise
        ``(graph_seed, node_seed, edge_seed)``.

    Raises:
        ValueError: If a tuple / list seed does not have the expected number of entries.
        TypeError: If ``seed`` is not ``None``, an int, a tuple or a list.
    """
    has_edges = self.edge_space is not None

    # Case 1: no seed given, every PRNG gets seeded with ``None``.
    if seed is None:
        root_seed = super().seed(None)
        node_seed = self.node_space.seed(None)
        if not has_edges:
            return root_seed, node_seed
        return root_seed, node_seed, self.edge_space.seed(None)

    # Case 2: a single integer, sub-seeds are drawn from the freshly seeded Graph PRNG.
    if isinstance(seed, int):
        root_seed = super().seed(seed)
        upper_bound = np.iinfo(np.int32).max
        if has_edges:
            node_subseed, edge_subseed = self.np_random.integers(
                upper_bound, size=(2,)
            )
        else:
            node_subseed = self.np_random.integers(upper_bound)
            edge_subseed = None

        # Re-seed the Graph PRNG so that its state after int seeding is the same as
        # after seeding with the tuple that this method returns.
        super().seed(seed)

        node_seed = self.node_space.seed(int(node_subseed))
        if not has_edges:
            return root_seed, node_seed
        return root_seed, node_seed, self.edge_space.seed(int(edge_subseed))

    # Case 3: explicit per-PRNG seeds, the length is validated before anything is seeded.
    if isinstance(seed, (list, tuple)):
        if not has_edges:
            if len(seed) != 2:
                raise ValueError(
                    f"Expects a tuple of two values for Graph and node space, actual length: {len(seed)}"
                )
            root_seed = super().seed(seed[0])
            return root_seed, self.node_space.seed(seed[1])

        if len(seed) != 3:
            raise ValueError(
                f"Expects a tuple of three values for Graph, node and edge space, actual length: {len(seed)}"
            )
        root_seed = super().seed(seed[0])
        node_seed = self.node_space.seed(seed[1])
        return root_seed, node_seed, self.edge_space.seed(seed[2])

    raise TypeError(
        f"Expects `None`, int or tuple of ints, actual type: {type(seed)}"
    )
def sample( def sample(
self, self,
mask: None mask: None

View File

@@ -1,7 +1,6 @@
"""Implementation of a space that represents the cartesian product of other spaces.""" """Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations from __future__ import annotations
import collections.abc
import typing import typing
from typing import Any, Iterable from typing import Any, Iterable
@@ -17,11 +16,11 @@ class OneOf(Space[Any]):
Example: Example:
>>> from gymnasium.spaces import OneOf, Box, Discrete >>> from gymnasium.spaces import OneOf, Box, Discrete
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42) >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=123)
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box >>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
(1, array([-0.3991573 , 0.21649833], dtype=float32))
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
(0, 0) (0, 0)
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
(1, array([-0.00711833, -0.7257502 ], dtype=float32))
>>> observation_space[0] >>> observation_space[0]
Discrete(2) Discrete(2)
>>> observation_space[1] >>> observation_space[1]
@@ -57,42 +56,48 @@ class OneOf(Space[Any]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" """Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces) return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]: def seed(self, seed: int | tuple[int, ...] | None = None) -> tuple[int, ...]:
"""Seed the PRNG of this space and all subspaces. """Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - All the subspaces will use a random initial seed * ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces. * ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``. * ``Tuple[int, ...]`` - Values used to seed the subspaces, first value seeds the OneOf and subsequent seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
Args: Args:
seed: An optional list of ints or int to seed the (sub-)spaces. seed: An optional int or tuple of ints to seed the OneOf space and subspaces. See above for more details.
Returns:
A tuple of ints used to seed the OneOf space and subspaces
""" """
if isinstance(seed, collections.abc.Sequence): if seed is None:
assert ( super_seed = super().seed(None)
len(seed) == len(self.spaces) + 1 return (super_seed,) + tuple(space.seed(None) for space in self.spaces)
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
seeds = super().seed(seed[0])
for subseed, space in zip(seed, self.spaces):
seeds += space.seed(subseed)
elif isinstance(seed, int): elif isinstance(seed, int):
seeds = super().seed(seed) super_seed = super().seed(seed)
subseeds = self.np_random.integers( subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces) np.iinfo(np.int32).max, size=len(self.spaces)
) )
for subspace, subseed in zip(self.spaces, subseeds): # this is necessary such that after int or list/tuple seeding, the OneOf PRNG are equivalent
seeds += subspace.seed(int(subseed)) super().seed(seed)
elif seed is None: return (super_seed,) + tuple(
seeds = super().seed(None) space.seed(int(subseed))
for space in self.spaces: for space, subseed in zip(self.spaces, subseeds)
seeds += space.seed(None) )
else: elif isinstance(seed, (tuple, list)):
raise TypeError( if len(seed) != len(self.spaces) + 1:
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}" raise ValueError(
f"Expects that the subspaces of seeds equals the number of subspaces + 1. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
) )
return seeds return (super().seed(seed[0]),) + tuple(
space.seed(subseed) for space, subseed in zip(self.spaces, seed[1:])
)
else:
raise TypeError(
f"Expected None, int, or tuple of ints, actual type: {type(seed)}"
)
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]: def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
"""Generates a single random sample inside this space. """Generates a single random sample inside this space.

View File

@@ -19,12 +19,18 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
Example: Example:
>>> from gymnasium.spaces import Sequence, Box >>> from gymnasium.spaces import Sequence, Box
>>> observation_space = Sequence(Box(0, 1), seed=2)
>>> observation_space.sample()
(array([0.26161215], dtype=float32),)
>>> observation_space = Sequence(Box(0, 1), seed=0) >>> observation_space = Sequence(Box(0, 1), seed=0)
>>> observation_space.sample() >>> observation_space.sample()
(array([0.6369617], dtype=float32), array([0.26978672], dtype=float32), array([0.04097353], dtype=float32)) (array([0.6822636], dtype=float32), array([0.18933342], dtype=float32), array([0.19049619], dtype=float32))
>>> observation_space.sample()
(array([0.83506], dtype=float32), array([0.9053838], dtype=float32), array([0.5836242], dtype=float32), array([0.63214064], dtype=float32))
Example with stacked observations
>>> observation_space = Sequence(Box(0, 1), stack=True, seed=0)
>>> observation_space.sample()
array([[0.6822636 ],
[0.18933342],
[0.19049619]], dtype=float32)
""" """
def __init__( def __init__(
@@ -53,11 +59,39 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
# None for shape and dtype, since it'll require special handling # None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed) super().__init__(None, None, seed)
def seed(self, seed: int | None = None) -> list[int]: def seed(self, seed: int | tuple[int, int] | None = None) -> tuple[int, int]:
"""Seed the PRNG of this space and the feature space.""" """Seed the PRNG of the Sequence space and the feature space.
seeds = super().seed(seed)
seeds += self.feature_space.seed(seed) Depending on the type of seed, the subspaces will be seeded differently
return seeds
* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Sequence` space that is used to generate a seed value for the feature space.
* ``Tuple of ints`` - A tuple for the :class:`Sequence` and feature space.
Args:
seed: An optional int or tuple of ints to seed the PRNG. See above for more details
Returns:
A tuple of the seeding values for the Sequence and feature space
"""
if seed is None:
return super().seed(None), self.feature_space.seed(None)
elif isinstance(seed, int):
super_seed = super().seed(seed)
feature_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
# this is necessary such that after int or list/tuple seeding, the Sequence PRNG are equivalent
super().seed(seed)
return super_seed, self.feature_space.seed(feature_seed)
elif isinstance(seed, (tuple, list)):
if len(seed) != 2:
raise ValueError(
f"Expects the seed to have two elements for the Sequence and feature space, actual length: {len(seed)}"
)
return super().seed(seed[0]), self.feature_space.seed(seed[1])
else:
raise TypeError(
f"Expected None, int, tuple of ints, actual type: {type(seed)}"
)
@property @property
def is_np_flattenable(self): def is_np_flattenable(self):

View File

@@ -102,13 +102,20 @@ class Space(Generic[T_cov]):
""" """
raise NotImplementedError raise NotImplementedError
def seed(self, seed: int | None = None) -> list[int]: def seed(self, seed: int | None = None) -> int | list[int] | dict[str, int]:
"""Seed the PRNG of this space and possibly the PRNGs of subspaces.""" """Seed the pseudorandom number generator (PRNG) of this space and, if applicable, the PRNGs of subspaces.
Args:
seed: The seed value for the space. This is expanded for composite spaces to accept multiple values. For further details, please refer to the space's documentation.
Returns:
The seed values used for all the PRNGs, for composite spaces this can be a tuple or dictionary of values.
"""
self._np_random, np_random_seed = seeding.np_random(seed) self._np_random, np_random_seed = seeding.np_random(seed)
return [np_random_seed] return np_random_seed
def contains(self, x: Any) -> bool: def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space, equivalent to ``sample in space``."""
raise NotImplementedError raise NotImplementedError
def __contains__(self, x: Any) -> bool: def __contains__(self, x: Any) -> bool:

View File

@@ -1,7 +1,6 @@
"""Implementation of a space that represents the cartesian product of other spaces.""" """Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations from __future__ import annotations
import collections.abc
import typing import typing
from typing import Any, Iterable from typing import Any, Iterable
@@ -47,7 +46,7 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" """Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces) return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]: def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
"""Seed the PRNG of this space and all subspaces. """Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently Depending on the type of seed, the subspaces will be seeded differently
@@ -58,32 +57,35 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
Args: Args:
seed: An optional list of ints or int to seed the (sub-)spaces. seed: An optional list of ints or int to seed the (sub-)spaces.
"""
seeds: list[int] = []
if isinstance(seed, collections.abc.Sequence): Returns:
assert len(seed) == len( A tuple of the seed values for all subspaces
self.spaces """
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}" if seed is None:
for subseed, space in zip(seed, self.spaces): return tuple(space.seed(None) for space in self.spaces)
seeds += space.seed(subseed)
elif isinstance(seed, int): elif isinstance(seed, int):
seeds = super().seed(seed) super().seed(seed)
subseeds = self.np_random.integers( subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces) np.iinfo(np.int32).max, size=len(self.spaces)
) )
for subspace, subseed in zip(self.spaces, subseeds): return tuple(
seeds += subspace.seed(int(subseed)) subspace.seed(int(subseed))
elif seed is None: for subspace, subseed in zip(self.spaces, subseeds)
for space in self.spaces: )
seeds += space.seed(seed) elif isinstance(seed, (tuple, list)):
if len(seed) != len(self.spaces):
raise ValueError(
f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
)
return tuple(
space.seed(subseed) for subseed, space in zip(seed, self.spaces)
)
else: else:
raise TypeError( raise TypeError(
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}" f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
) )
return seeds
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]: def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
"""Generates a single random sample inside this space. """Generates a single random sample inside this space.

View File

@@ -31,19 +31,19 @@ from gymnasium.utils.passive_env_checker import (
def data_equivalence(data_1, data_2, exact: bool = False) -> bool: def data_equivalence(data_1, data_2, exact: bool = False) -> bool:
"""Assert equality between data 1 and 2, i.e observations, actions, info. """Assert equality between data 1 and 2, i.e. observations, actions, info.
Args: Args:
data_1: data structure 1 data_1: data structure 1
data_2: data structure 2 data_2: data structure 2
exact: whether to compare array exactly or not if false compares with absolute and realive torrelance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)). exact: whether to compare array exactly or not if false compares with absolute and relative tolerance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).
Returns: Returns:
If observation 1 and 2 are equivalent If observation 1 and 2 are equivalent
""" """
if type(data_1) is not type(data_2): if type(data_1) is not type(data_2):
return False return False
if isinstance(data_1, dict): elif isinstance(data_1, dict):
return data_1.keys() == data_2.keys() and all( return data_1.keys() == data_2.keys() and all(
data_equivalence(data_1[k], data_2[k], exact) for k in data_1.keys() data_equivalence(data_1[k], data_2[k], exact) for k in data_1.keys()
) )

View File

@@ -6,6 +6,7 @@ import numpy as np
import pytest import pytest
from gymnasium.spaces import Box, Dict, Discrete from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils.env_checker import data_equivalence
def test_dict_init(): def test_dict_init():
@@ -56,8 +57,7 @@ DICT_SPACE = Dict(
def test_dict_seeding(): def test_dict_seeding():
seeds = DICT_SPACE.seed( seeding_values = {
{
"a": 0, "a": 0,
"b": { "b": {
"b_1": 1, "b_1": 1,
@@ -65,8 +65,8 @@ def test_dict_seeding():
}, },
"c": 3, "c": 3,
} }
) seeded_values = DICT_SPACE.seed(seeding_values)
assert all(isinstance(seed, int) for seed in seeds) assert data_equivalence(seeded_values, seeding_values)
# "Unpack" the dict sub-spaces into individual spaces # "Unpack" the dict sub-spaces into individual spaces
a = Box(low=0, high=1, shape=(3, 3), seed=0) a = Box(low=0, high=1, shape=(3, 3), seed=0)
@@ -84,7 +84,7 @@ def test_dict_seeding():
def test_int_seeding(): def test_int_seeding():
seeds = DICT_SPACE.seed(1) seeds = DICT_SPACE.seed(1)
assert all(isinstance(seed, int) for seed in seeds) assert isinstance(seeds, dict)
# rng, seeds = seeding.np_random(1) # rng, seeds = seeding.np_random(1)
# subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False) # subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False)
@@ -92,10 +92,10 @@ def test_int_seeding():
# b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False) # b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False)
# "Unpack" the dict sub-spaces into individual spaces # "Unpack" the dict sub-spaces into individual spaces
a = Box(low=0, high=1, shape=(3, 3), seed=seeds[1]) a = Box(low=0, high=1, shape=(3, 3), seed=seeds["a"])
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds[3]) b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds["b"]["b_1"])
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds[4]) b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds["b"]["b_2"])
c = Discrete(5, seed=seeds[5]) c = Discrete(5, seed=seeds["c"])
for i in range(10): for i in range(10):
dict_sample = DICT_SPACE.sample() dict_sample = DICT_SPACE.sample()
@@ -107,7 +107,7 @@ def test_int_seeding():
def test_none_seeding(): def test_none_seeding():
seeds = DICT_SPACE.seed(None) seeds = DICT_SPACE.seed(None)
assert len(seeds) == 4 and all(isinstance(seed, int) for seed in seeds) assert isinstance(seeds, dict)
def test_bad_seed(): def test_bad_seed():

View File

@@ -2,7 +2,6 @@ import numpy as np
import pytest import pytest
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
from gymnasium.utils.env_checker import data_equivalence
def test_oneof_inheritance(): def test_oneof_inheritance():
@@ -21,26 +20,18 @@ def test_oneof_inheritance():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"spaces, seed, expected_len", "spaces, seed",
[ [
([Discrete(5), Box(-1, 1, shape=(3,))], None, 3), ([Discrete(5), Box(-1, 1, shape=(3,))], None),
([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3), ([Discrete(5), Box(-1, 1, shape=(3,))], 123),
([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3), ([Discrete(5), Box(-1, 1, shape=(3,))], (123, 456, 789)),
], ],
) )
def test_oneof_seeds(spaces, seed, expected_len): def test_oneof_seeds(spaces, seed):
oneof_space = OneOf(spaces) oneof_space = OneOf(spaces)
seeds = oneof_space.seed(seed) seeds = oneof_space.seed(seed)
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds) assert isinstance(seeds, tuple)
assert len(seeds) == expected_len assert len(seeds) == len(spaces) + 1
sample1 = oneof_space.sample()
seeds2 = oneof_space.seed(seed)
sample2 = oneof_space.sample()
data_equivalence(seeds, seeds2)
data_equivalence(sample1, sample2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -71,6 +62,6 @@ def test_bad_oneof_seed():
space = OneOf([Box(0, 1), Box(0, 1)]) space = OneOf([Box(0, 1), Box(0, 1)])
with pytest.raises( with pytest.raises(
TypeError, TypeError,
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>", match="Expected None, int, or tuple of ints, actual type: <class 'float'>",
): ):
space.seed(0.0) space.seed(0.0)

View File

@@ -9,6 +9,7 @@ import numpy as np
import pytest import pytest
import scipy.stats import scipy.stats
from gymnasium.error import Error
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
from gymnasium.utils import seeding from gymnasium.utils import seeding
from gymnasium.utils.env_checker import data_equivalence from gymnasium.utils.env_checker import data_equivalence
@@ -509,13 +510,13 @@ def test_seed_reproducibility(space):
space_2 = copy.deepcopy(space) space_2 = copy.deepcopy(space)
for seed in range(5): for seed in range(5):
assert space_1.seed(seed) == space_2.seed(seed) assert data_equivalence(space_1.seed(seed), space_2.seed(seed))
# With the same seed, the two spaces should be identical # With the same seed, the two spaces should be identical
assert all( assert all(
data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10) data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
) )
assert space_1.seed(123) != space_2.seed(456) assert not data_equivalence(space_1.seed(123), space_2.seed(456))
# Due to randomness, it is difficult to test that random seeds produce different answers # Due to randomness, it is difficult to test that random seeds produce different answers
# Therefore, taking 10 samples and checking that they are not all the same. # Therefore, taking 10 samples and checking that they are not all the same.
assert not all( assert not all(
@@ -604,3 +605,22 @@ def test_space_pickling(space):
file_unpickled_sample = file_unpickled_space.sample() file_unpickled_sample = file_unpickled_space.sample()
assert data_equivalence(space_sample, unpickled_sample) assert data_equivalence(space_sample, unpickled_sample)
assert data_equivalence(space_sample, file_unpickled_sample) assert data_equivalence(space_sample, file_unpickled_sample)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("initial_seed", [None, 123])
def test_space_seeding_output(space, initial_seed, num_samples=5):
    """Check that the value returned by ``Space.seed`` can be fed back in as a seed.

    The space is seeded with ``initial_seed`` and sampled ``num_samples`` times, then
    re-seeded with the value the first ``seed`` call returned and sampled again.
    Both the returned seed values and the samples must be equivalent.
    """
    first_seeds = space.seed(initial_seed)
    first_draws = []
    for _ in range(num_samples):
        first_draws.append(space.sample())

    # Round trip: use the returned seeding values as the new seed.
    second_seeds = space.seed(first_seeds)
    second_draws = []
    for _ in range(num_samples):
        second_draws.append(space.sample())

    assert data_equivalence(first_seeds, second_seeds)
    assert data_equivalence(first_draws, second_draws)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_invalid_space_seed(space):
    """Check that seeding a space with a string raises one of the accepted errors."""
    invalid_seed = "abc"
    accepted_errors = (ValueError, TypeError, Error)
    with pytest.raises(accepted_errors):
        space.seed(invalid_seed)

View File

@@ -37,17 +37,16 @@ def test_sequence_inheritance():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"space, seed, expected_len", "space, seed",
[ [
(Tuple([Discrete(5), Discrete(4)]), None, 2), (Tuple([Discrete(5), Discrete(4)]), None),
(Tuple([Discrete(5), Discrete(4)]), 123, 3), (Tuple([Discrete(5), Discrete(4)]), 123),
(Tuple([Discrete(5), Discrete(4)]), (123, 456), 2), (Tuple([Discrete(5), Discrete(4)]), (123, 456)),
( (
Tuple( Tuple(
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2)))) (Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
), ),
(123, (456, 789)), (123, (456, 789)),
3,
), ),
( (
Tuple( Tuple(
@@ -57,22 +56,21 @@ def test_sequence_inheritance():
) )
), ),
(123, {"position": 456, "velocity": 789}), (123, {"position": 456, "velocity": 789}),
3,
), ),
], ],
) )
def test_seeds(space, seed, expected_len): def test_seeds(space, seed):
seeds = space.seed(seed) seeds1 = space.seed(seed)
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds) assert isinstance(seeds1, tuple)
assert len(seeds) == expected_len assert len(seeds1) == len(space)
sample1 = space.sample() sample1 = space.sample()
seeds2 = space.seed(seed) seeds2 = space.seed(seeds1)
sample2 = space.sample() sample2 = space.sample()
data_equivalence(seeds, seeds2) assert data_equivalence(seeds1, seeds2)
data_equivalence(sample1, sample2) assert data_equivalence(sample1, sample2)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -50,6 +50,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
None, None,
None, None,
None, None,
None,
# Sequence # Sequence
None, None,
None, None,
@@ -60,6 +61,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
4, 4,
5, 5,
] ]
assert len(TESTING_SPACES) == len(TESTING_SPACES_EXPECTED_FLATDIMS)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -100,6 +100,7 @@ TESTING_COMPOSITE_SPACES = [
b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))), b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
), ),
# Graph spaces # Graph spaces
Graph(node_space=Box(-1, 1, shape=(2,)), edge_space=None),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)), Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))), Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(3), edge_space=Discrete(4)), Graph(node_space=Discrete(3), edge_space=Discrete(4)),