mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 17:45:07 +00:00
Modify Space.seed
such that the return can be used as seeding values (#1033)
This commit is contained in:
@@ -107,44 +107,45 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return all(space.is_np_flattenable for space in self.spaces.values())
|
return all(space.is_np_flattenable for space in self.spaces.values())
|
||||||
|
|
||||||
def seed(self, seed: dict[str, Any] | int | None = None) -> list[int]:
|
def seed(self, seed: int | dict[str, Any] | None = None) -> dict[str, int]:
|
||||||
"""Seed the PRNG of this space and all subspaces.
|
"""Seed the PRNG of this space and all subspaces.
|
||||||
|
|
||||||
Depending on the type of seed, the subspaces will be seeded differently
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
|
|
||||||
* ``None`` - All the subspaces will use a random initial seed
|
* ``None`` - All the subspaces will use a random initial seed
|
||||||
* ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all of the subspaces.
|
* ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all subspaces, though is very unlikely.
|
||||||
* ``Dict`` - Using all the keys in the seed dictionary, the values are used to seed the subspaces. This allows the seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
|
* ``Dict`` - A dictionary of seeds for each subspace, requires a seed key for every subspace. This supports seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
seed: An optional int or dictionary of subspace keys to int to seed each PRNG. See above for more details.
|
||||||
"""
|
|
||||||
seeds: list[int] = []
|
|
||||||
|
|
||||||
if isinstance(seed, dict):
|
Returns:
|
||||||
assert (
|
A dictionary for the seed values of the subspaces
|
||||||
seed.keys() == self.spaces.keys()
|
"""
|
||||||
), f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
|
if seed is None:
|
||||||
for key in seed.keys():
|
return {key: subspace.seed(None) for (key, subspace) in self.spaces.items()}
|
||||||
seeds += self.spaces[key].seed(seed[key])
|
|
||||||
elif isinstance(seed, int):
|
elif isinstance(seed, int):
|
||||||
seeds = super().seed(seed)
|
super().seed(seed)
|
||||||
# Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces
|
# Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces
|
||||||
subseeds = self.np_random.integers(
|
subseeds = self.np_random.integers(
|
||||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||||
)
|
)
|
||||||
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
return {
|
||||||
seeds += subspace.seed(int(subseed))
|
key: subspace.seed(int(subseed))
|
||||||
elif seed is None:
|
for (key, subspace), subseed in zip(self.spaces.items(), subseeds)
|
||||||
for space in self.spaces.values():
|
}
|
||||||
seeds += space.seed(None)
|
elif isinstance(seed, dict):
|
||||||
|
if seed.keys() != self.spaces.keys():
|
||||||
|
raise ValueError(
|
||||||
|
f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {key: self.spaces[key].seed(seed[key]) for key in seed.keys()}
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
|
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return seeds
|
|
||||||
|
|
||||||
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
|
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
|
@@ -31,25 +31,19 @@ class Graph(Space[GraphInstance]):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from gymnasium.spaces import Graph, Box, Discrete
|
>>> from gymnasium.spaces import Graph, Box, Discrete
|
||||||
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=42)
|
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=123)
|
||||||
>>> observation_space.sample()
|
>>> observation_space.sample(num_nodes=4, num_edges=8)
|
||||||
GraphInstance(nodes=array([[-12.224312 , 71.71958 , 39.473606 ],
|
GraphInstance(nodes=array([[ 36.47037 , -89.235794, -55.928024],
|
||||||
[-81.16453 , 95.12447 , 52.22794 ],
|
[-63.125637, -64.81882 , 62.4189 ],
|
||||||
[ 57.21286 , -74.37727 , -9.922812 ],
|
[ 84.669 , -44.68512 , 63.950912],
|
||||||
[-25.840395 , 85.353 , 28.773024 ],
|
[ 77.97854 , 2.594091, -51.00708 ]], dtype=float32), edges=array([2, 0, 2, 1, 2, 0, 2, 1]), edge_links=array([[3, 0],
|
||||||
[ 64.55232 , -11.317161 , -54.552258 ],
|
[0, 0],
|
||||||
[ 10.916958 , -87.23655 , 65.52624 ],
|
[0, 1],
|
||||||
[ 26.33288 , 51.61755 , -29.094807 ],
|
[0, 2],
|
||||||
[ 94.1396 , 78.62422 , 55.6767 ],
|
[1, 0],
|
||||||
[-61.072258 , -6.6557994, -91.23925 ],
|
[1, 0],
|
||||||
[-69.142105 , 36.60979 , 48.95243 ]], dtype=float32), edges=array([2, 0, 1, 1, 0, 0, 1, 0]), edge_links=array([[7, 5],
|
[0, 1],
|
||||||
[6, 9],
|
[0, 2]], dtype=int32))
|
||||||
[4, 1],
|
|
||||||
[8, 6],
|
|
||||||
[7, 0],
|
|
||||||
[3, 7],
|
|
||||||
[8, 4],
|
|
||||||
[8, 8]], dtype=int32))
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -110,6 +104,76 @@ class Graph(Space[GraphInstance]):
|
|||||||
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
|
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def seed(
|
||||||
|
self, seed: int | tuple[int, int] | tuple[int, int, int] | None = None
|
||||||
|
) -> tuple[int, int] | tuple[int, int, int]:
|
||||||
|
"""Seeds the PRNG of this space and node / edge subspace.
|
||||||
|
|
||||||
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
|
|
||||||
|
* ``None`` - The root, node and edge spaces PRNG are randomly initialized
|
||||||
|
* ``Int`` - The integer is used to seed the :class:`Graph` space that is used to generate seed values for the node and edge subspaces.
|
||||||
|
* ``Tuple[int, int]`` - Seeds the :class:`Graph` and node subspace with a particular value. Only if edge subspace isn't specified
|
||||||
|
* ``Tuple[int, int, int]`` - Seeds the :class:`Graph`, node and edge subspaces with a particular value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: An optional int or tuple of ints for this space and the node / edge subspaces. See above for more details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of two or three ints depending on if the edge subspace is specified.
|
||||||
|
"""
|
||||||
|
if seed is None:
|
||||||
|
if self.edge_space is None:
|
||||||
|
return super().seed(None), self.node_space.seed(None)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
super().seed(None),
|
||||||
|
self.node_space.seed(None),
|
||||||
|
self.edge_space.seed(None),
|
||||||
|
)
|
||||||
|
elif isinstance(seed, int):
|
||||||
|
if self.edge_space is None:
|
||||||
|
super_seed = super().seed(seed)
|
||||||
|
node_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
|
||||||
|
# this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
|
||||||
|
super().seed(seed)
|
||||||
|
return super_seed, self.node_space.seed(node_seed)
|
||||||
|
else:
|
||||||
|
super_seed = super().seed(seed)
|
||||||
|
node_seed, edge_seed = self.np_random.integers(
|
||||||
|
np.iinfo(np.int32).max, size=(2,)
|
||||||
|
)
|
||||||
|
# this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
|
||||||
|
super().seed(seed)
|
||||||
|
return (
|
||||||
|
super_seed,
|
||||||
|
self.node_space.seed(int(node_seed)),
|
||||||
|
self.edge_space.seed(int(edge_seed)),
|
||||||
|
)
|
||||||
|
elif isinstance(seed, (list, tuple)):
|
||||||
|
if self.edge_space is None:
|
||||||
|
if len(seed) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expects a tuple of two values for Graph and node space, actual length: {len(seed)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().seed(seed[0]), self.node_space.seed(seed[1])
|
||||||
|
else:
|
||||||
|
if len(seed) != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expects a tuple of three values for Graph, node and edge space, actual length: {len(seed)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
super().seed(seed[0]),
|
||||||
|
self.node_space.seed(seed[1]),
|
||||||
|
self.edge_space.seed(seed[2]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Expects `None`, int or tuple of ints, actual type: {type(seed)}"
|
||||||
|
)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
mask: None
|
mask: None
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import collections.abc
|
|
||||||
import typing
|
import typing
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable
|
||||||
|
|
||||||
@@ -17,11 +16,11 @@ class OneOf(Space[Any]):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from gymnasium.spaces import OneOf, Box, Discrete
|
>>> from gymnasium.spaces import OneOf, Box, Discrete
|
||||||
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
|
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=123)
|
||||||
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
|
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
|
||||||
(1, array([-0.3991573 , 0.21649833], dtype=float32))
|
|
||||||
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
|
|
||||||
(0, 0)
|
(0, 0)
|
||||||
|
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
|
||||||
|
(1, array([-0.00711833, -0.7257502 ], dtype=float32))
|
||||||
>>> observation_space[0]
|
>>> observation_space[0]
|
||||||
Discrete(2)
|
Discrete(2)
|
||||||
>>> observation_space[1]
|
>>> observation_space[1]
|
||||||
@@ -57,42 +56,48 @@ class OneOf(Space[Any]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return all(space.is_np_flattenable for space in self.spaces)
|
return all(space.is_np_flattenable for space in self.spaces)
|
||||||
|
|
||||||
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
|
def seed(self, seed: int | tuple[int, ...] | None = None) -> tuple[int, ...]:
|
||||||
"""Seed the PRNG of this space and all subspaces.
|
"""Seed the PRNG of this space and all subspaces.
|
||||||
|
|
||||||
Depending on the type of seed, the subspaces will be seeded differently
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
|
|
||||||
* ``None`` - All the subspaces will use a random initial seed
|
* ``None`` - All the subspaces will use a random initial seed
|
||||||
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
|
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
|
||||||
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
* ``Tuple[int, ...]`` - Values used to seed the subspaces, first value seeds the OneOf and subsequent seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
seed: An optional int or tuple of ints to seed the OneOf space and subspaces. See above for more details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of ints used to seed the OneOf space and subspaces
|
||||||
"""
|
"""
|
||||||
if isinstance(seed, collections.abc.Sequence):
|
if seed is None:
|
||||||
assert (
|
super_seed = super().seed(None)
|
||||||
len(seed) == len(self.spaces) + 1
|
return (super_seed,) + tuple(space.seed(None) for space in self.spaces)
|
||||||
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
|
|
||||||
seeds = super().seed(seed[0])
|
|
||||||
for subseed, space in zip(seed, self.spaces):
|
|
||||||
seeds += space.seed(subseed)
|
|
||||||
elif isinstance(seed, int):
|
elif isinstance(seed, int):
|
||||||
seeds = super().seed(seed)
|
super_seed = super().seed(seed)
|
||||||
subseeds = self.np_random.integers(
|
subseeds = self.np_random.integers(
|
||||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||||
)
|
)
|
||||||
for subspace, subseed in zip(self.spaces, subseeds):
|
# this is necessary such that after int or list/tuple seeding, the OneOf PRNG are equivalent
|
||||||
seeds += subspace.seed(int(subseed))
|
super().seed(seed)
|
||||||
elif seed is None:
|
return (super_seed,) + tuple(
|
||||||
seeds = super().seed(None)
|
space.seed(int(subseed))
|
||||||
for space in self.spaces:
|
for space, subseed in zip(self.spaces, subseeds)
|
||||||
seeds += space.seed(None)
|
)
|
||||||
else:
|
elif isinstance(seed, (tuple, list)):
|
||||||
raise TypeError(
|
if len(seed) != len(self.spaces) + 1:
|
||||||
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
|
raise ValueError(
|
||||||
|
f"Expects that the subspaces of seeds equals the number of subspaces + 1. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return seeds
|
return (super().seed(seed[0]),) + tuple(
|
||||||
|
space.seed(subseed) for space, subseed in zip(self.spaces, seed[1:])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected None, int, or tuple of ints, actual type: {type(seed)}"
|
||||||
|
)
|
||||||
|
|
||||||
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
|
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
|
||||||
"""Generates a single random sample inside this space.
|
"""Generates a single random sample inside this space.
|
||||||
|
@@ -19,12 +19,18 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from gymnasium.spaces import Sequence, Box
|
>>> from gymnasium.spaces import Sequence, Box
|
||||||
>>> observation_space = Sequence(Box(0, 1), seed=2)
|
|
||||||
>>> observation_space.sample()
|
|
||||||
(array([0.26161215], dtype=float32),)
|
|
||||||
>>> observation_space = Sequence(Box(0, 1), seed=0)
|
>>> observation_space = Sequence(Box(0, 1), seed=0)
|
||||||
>>> observation_space.sample()
|
>>> observation_space.sample()
|
||||||
(array([0.6369617], dtype=float32), array([0.26978672], dtype=float32), array([0.04097353], dtype=float32))
|
(array([0.6822636], dtype=float32), array([0.18933342], dtype=float32), array([0.19049619], dtype=float32))
|
||||||
|
>>> observation_space.sample()
|
||||||
|
(array([0.83506], dtype=float32), array([0.9053838], dtype=float32), array([0.5836242], dtype=float32), array([0.63214064], dtype=float32))
|
||||||
|
|
||||||
|
Example with stacked observations
|
||||||
|
>>> observation_space = Sequence(Box(0, 1), stack=True, seed=0)
|
||||||
|
>>> observation_space.sample()
|
||||||
|
array([[0.6822636 ],
|
||||||
|
[0.18933342],
|
||||||
|
[0.19049619]], dtype=float32)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -53,11 +59,39 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
|||||||
# None for shape and dtype, since it'll require special handling
|
# None for shape and dtype, since it'll require special handling
|
||||||
super().__init__(None, None, seed)
|
super().__init__(None, None, seed)
|
||||||
|
|
||||||
def seed(self, seed: int | None = None) -> list[int]:
|
def seed(self, seed: int | tuple[int, int] | None = None) -> tuple[int, int]:
|
||||||
"""Seed the PRNG of this space and the feature space."""
|
"""Seed the PRNG of the Sequence space and the feature space.
|
||||||
seeds = super().seed(seed)
|
|
||||||
seeds += self.feature_space.seed(seed)
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
return seeds
|
|
||||||
|
* ``None`` - All the subspaces will use a random initial seed
|
||||||
|
* ``Int`` - The integer is used to seed the :class:`Sequence` space that is used to generate a seed value for the feature space.
|
||||||
|
* ``Tuple of ints`` - A tuple for the :class:`Sequence` and feature space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: An optional int or tuple of ints to seed the PRNG. See above for more details
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of the seeding values for the Sequence and feature space
|
||||||
|
"""
|
||||||
|
if seed is None:
|
||||||
|
return super().seed(None), self.feature_space.seed(None)
|
||||||
|
elif isinstance(seed, int):
|
||||||
|
super_seed = super().seed(seed)
|
||||||
|
feature_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
|
||||||
|
# this is necessary such that after int or list/tuple seeding, the Sequence PRNG are equivalent
|
||||||
|
super().seed(seed)
|
||||||
|
return super_seed, self.feature_space.seed(feature_seed)
|
||||||
|
elif isinstance(seed, (tuple, list)):
|
||||||
|
if len(seed) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expects the seed to have two elements for the Sequence and feature space, actual length: {len(seed)}"
|
||||||
|
)
|
||||||
|
return super().seed(seed[0]), self.feature_space.seed(seed[1])
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected None, int, tuple of ints, actual type: {type(seed)}"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_np_flattenable(self):
|
def is_np_flattenable(self):
|
||||||
|
@@ -102,13 +102,20 @@ class Space(Generic[T_cov]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def seed(self, seed: int | None = None) -> list[int]:
|
def seed(self, seed: int | None = None) -> int | list[int] | dict[str, int]:
|
||||||
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
"""Seed the pseudorandom number generator (PRNG) of this space and, if applicable, the PRNGs of subspaces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed value for the space. This is expanded for composite spaces to accept multiple values. For further details, please refer to the space's documentation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The seed values used for all the PRNGs, for composite spaces this can be a tuple or dictionary of values.
|
||||||
|
"""
|
||||||
self._np_random, np_random_seed = seeding.np_random(seed)
|
self._np_random, np_random_seed = seeding.np_random(seed)
|
||||||
return [np_random_seed]
|
return np_random_seed
|
||||||
|
|
||||||
def contains(self, x: Any) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space, equivalent to ``sample in space``."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __contains__(self, x: Any) -> bool:
|
def __contains__(self, x: Any) -> bool:
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import collections.abc
|
|
||||||
import typing
|
import typing
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable
|
||||||
|
|
||||||
@@ -47,7 +46,7 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return all(space.is_np_flattenable for space in self.spaces)
|
return all(space.is_np_flattenable for space in self.spaces)
|
||||||
|
|
||||||
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
|
def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
|
||||||
"""Seed the PRNG of this space and all subspaces.
|
"""Seed the PRNG of this space and all subspaces.
|
||||||
|
|
||||||
Depending on the type of seed, the subspaces will be seeded differently
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
@@ -58,32 +57,35 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||||
"""
|
|
||||||
seeds: list[int] = []
|
|
||||||
|
|
||||||
if isinstance(seed, collections.abc.Sequence):
|
Returns:
|
||||||
assert len(seed) == len(
|
A tuple of the seed values for all subspaces
|
||||||
self.spaces
|
"""
|
||||||
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
|
if seed is None:
|
||||||
for subseed, space in zip(seed, self.spaces):
|
return tuple(space.seed(None) for space in self.spaces)
|
||||||
seeds += space.seed(subseed)
|
|
||||||
elif isinstance(seed, int):
|
elif isinstance(seed, int):
|
||||||
seeds = super().seed(seed)
|
super().seed(seed)
|
||||||
subseeds = self.np_random.integers(
|
subseeds = self.np_random.integers(
|
||||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||||
)
|
)
|
||||||
for subspace, subseed in zip(self.spaces, subseeds):
|
return tuple(
|
||||||
seeds += subspace.seed(int(subseed))
|
subspace.seed(int(subseed))
|
||||||
elif seed is None:
|
for subspace, subseed in zip(self.spaces, subseeds)
|
||||||
for space in self.spaces:
|
)
|
||||||
seeds += space.seed(seed)
|
elif isinstance(seed, (tuple, list)):
|
||||||
|
if len(seed) != len(self.spaces):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return tuple(
|
||||||
|
space.seed(subseed) for subseed, space in zip(seed, self.spaces)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
|
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return seeds
|
|
||||||
|
|
||||||
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
|
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
|
||||||
"""Generates a single random sample inside this space.
|
"""Generates a single random sample inside this space.
|
||||||
|
|
||||||
|
@@ -31,19 +31,19 @@ from gymnasium.utils.passive_env_checker import (
|
|||||||
|
|
||||||
|
|
||||||
def data_equivalence(data_1, data_2, exact: bool = False) -> bool:
|
def data_equivalence(data_1, data_2, exact: bool = False) -> bool:
|
||||||
"""Assert equality between data 1 and 2, i.e observations, actions, info.
|
"""Assert equality between data 1 and 2, i.e. observations, actions, info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_1: data structure 1
|
data_1: data structure 1
|
||||||
data_2: data structure 2
|
data_2: data structure 2
|
||||||
exact: whether to compare array exactly or not if false compares with absolute and realive torrelance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).
|
exact: whether to compare array exactly or not if false compares with absolute and relative tolerance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If observation 1 and 2 are equivalent
|
If observation 1 and 2 are equivalent
|
||||||
"""
|
"""
|
||||||
if type(data_1) is not type(data_2):
|
if type(data_1) is not type(data_2):
|
||||||
return False
|
return False
|
||||||
if isinstance(data_1, dict):
|
elif isinstance(data_1, dict):
|
||||||
return data_1.keys() == data_2.keys() and all(
|
return data_1.keys() == data_2.keys() and all(
|
||||||
data_equivalence(data_1[k], data_2[k], exact) for k in data_1.keys()
|
data_equivalence(data_1[k], data_2[k], exact) for k in data_1.keys()
|
||||||
)
|
)
|
||||||
|
@@ -6,6 +6,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.spaces import Box, Dict, Discrete
|
from gymnasium.spaces import Box, Dict, Discrete
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
|
|
||||||
|
|
||||||
def test_dict_init():
|
def test_dict_init():
|
||||||
@@ -56,8 +57,7 @@ DICT_SPACE = Dict(
|
|||||||
|
|
||||||
|
|
||||||
def test_dict_seeding():
|
def test_dict_seeding():
|
||||||
seeds = DICT_SPACE.seed(
|
seeding_values = {
|
||||||
{
|
|
||||||
"a": 0,
|
"a": 0,
|
||||||
"b": {
|
"b": {
|
||||||
"b_1": 1,
|
"b_1": 1,
|
||||||
@@ -65,8 +65,8 @@ def test_dict_seeding():
|
|||||||
},
|
},
|
||||||
"c": 3,
|
"c": 3,
|
||||||
}
|
}
|
||||||
)
|
seeded_values = DICT_SPACE.seed(seeding_values)
|
||||||
assert all(isinstance(seed, int) for seed in seeds)
|
assert data_equivalence(seeded_values, seeding_values)
|
||||||
|
|
||||||
# "Unpack" the dict sub-spaces into individual spaces
|
# "Unpack" the dict sub-spaces into individual spaces
|
||||||
a = Box(low=0, high=1, shape=(3, 3), seed=0)
|
a = Box(low=0, high=1, shape=(3, 3), seed=0)
|
||||||
@@ -84,7 +84,7 @@ def test_dict_seeding():
|
|||||||
|
|
||||||
def test_int_seeding():
|
def test_int_seeding():
|
||||||
seeds = DICT_SPACE.seed(1)
|
seeds = DICT_SPACE.seed(1)
|
||||||
assert all(isinstance(seed, int) for seed in seeds)
|
assert isinstance(seeds, dict)
|
||||||
|
|
||||||
# rng, seeds = seeding.np_random(1)
|
# rng, seeds = seeding.np_random(1)
|
||||||
# subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False)
|
# subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False)
|
||||||
@@ -92,10 +92,10 @@ def test_int_seeding():
|
|||||||
# b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False)
|
# b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False)
|
||||||
|
|
||||||
# "Unpack" the dict sub-spaces into individual spaces
|
# "Unpack" the dict sub-spaces into individual spaces
|
||||||
a = Box(low=0, high=1, shape=(3, 3), seed=seeds[1])
|
a = Box(low=0, high=1, shape=(3, 3), seed=seeds["a"])
|
||||||
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds[3])
|
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds["b"]["b_1"])
|
||||||
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds[4])
|
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds["b"]["b_2"])
|
||||||
c = Discrete(5, seed=seeds[5])
|
c = Discrete(5, seed=seeds["c"])
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
dict_sample = DICT_SPACE.sample()
|
dict_sample = DICT_SPACE.sample()
|
||||||
@@ -107,7 +107,7 @@ def test_int_seeding():
|
|||||||
|
|
||||||
def test_none_seeding():
|
def test_none_seeding():
|
||||||
seeds = DICT_SPACE.seed(None)
|
seeds = DICT_SPACE.seed(None)
|
||||||
assert len(seeds) == 4 and all(isinstance(seed, int) for seed in seeds)
|
assert isinstance(seeds, dict)
|
||||||
|
|
||||||
|
|
||||||
def test_bad_seed():
|
def test_bad_seed():
|
||||||
|
@@ -2,7 +2,6 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
|
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
|
||||||
from gymnasium.utils.env_checker import data_equivalence
|
|
||||||
|
|
||||||
|
|
||||||
def test_oneof_inheritance():
|
def test_oneof_inheritance():
|
||||||
@@ -21,26 +20,18 @@ def test_oneof_inheritance():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"spaces, seed, expected_len",
|
"spaces, seed",
|
||||||
[
|
[
|
||||||
([Discrete(5), Box(-1, 1, shape=(3,))], None, 3),
|
([Discrete(5), Box(-1, 1, shape=(3,))], None),
|
||||||
([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3),
|
([Discrete(5), Box(-1, 1, shape=(3,))], 123),
|
||||||
([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3),
|
([Discrete(5), Box(-1, 1, shape=(3,))], (123, 456, 789)),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_oneof_seeds(spaces, seed, expected_len):
|
def test_oneof_seeds(spaces, seed):
|
||||||
oneof_space = OneOf(spaces)
|
oneof_space = OneOf(spaces)
|
||||||
seeds = oneof_space.seed(seed)
|
seeds = oneof_space.seed(seed)
|
||||||
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
|
assert isinstance(seeds, tuple)
|
||||||
assert len(seeds) == expected_len
|
assert len(seeds) == len(spaces) + 1
|
||||||
|
|
||||||
sample1 = oneof_space.sample()
|
|
||||||
|
|
||||||
seeds2 = oneof_space.seed(seed)
|
|
||||||
sample2 = oneof_space.sample()
|
|
||||||
|
|
||||||
data_equivalence(seeds, seeds2)
|
|
||||||
data_equivalence(sample1, sample2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -71,6 +62,6 @@ def test_bad_oneof_seed():
|
|||||||
space = OneOf([Box(0, 1), Box(0, 1)])
|
space = OneOf([Box(0, 1), Box(0, 1)])
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
|
match="Expected None, int, or tuple of ints, actual type: <class 'float'>",
|
||||||
):
|
):
|
||||||
space.seed(0.0)
|
space.seed(0.0)
|
||||||
|
@@ -9,6 +9,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
|
|
||||||
|
from gymnasium.error import Error
|
||||||
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
|
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
|
||||||
from gymnasium.utils import seeding
|
from gymnasium.utils import seeding
|
||||||
from gymnasium.utils.env_checker import data_equivalence
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
@@ -509,13 +510,13 @@ def test_seed_reproducibility(space):
|
|||||||
space_2 = copy.deepcopy(space)
|
space_2 = copy.deepcopy(space)
|
||||||
|
|
||||||
for seed in range(5):
|
for seed in range(5):
|
||||||
assert space_1.seed(seed) == space_2.seed(seed)
|
assert data_equivalence(space_1.seed(seed), space_2.seed(seed))
|
||||||
# With the same seed, the two spaces should be identical
|
# With the same seed, the two spaces should be identical
|
||||||
assert all(
|
assert all(
|
||||||
data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
|
data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert space_1.seed(123) != space_2.seed(456)
|
assert not data_equivalence(space_1.seed(123), space_2.seed(456))
|
||||||
# Due to randomness, it is difficult to test that random seeds produce different answers
|
# Due to randomness, it is difficult to test that random seeds produce different answers
|
||||||
# Therefore, taking 10 samples and checking that they are not all the same.
|
# Therefore, taking 10 samples and checking that they are not all the same.
|
||||||
assert not all(
|
assert not all(
|
||||||
@@ -604,3 +605,22 @@ def test_space_pickling(space):
|
|||||||
file_unpickled_sample = file_unpickled_space.sample()
|
file_unpickled_sample = file_unpickled_space.sample()
|
||||||
assert data_equivalence(space_sample, unpickled_sample)
|
assert data_equivalence(space_sample, unpickled_sample)
|
||||||
assert data_equivalence(space_sample, file_unpickled_sample)
|
assert data_equivalence(space_sample, file_unpickled_sample)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||||
|
@pytest.mark.parametrize("initial_seed", [None, 123])
|
||||||
|
def test_space_seeding_output(space, initial_seed, num_samples=5):
|
||||||
|
seeding_values = space.seed(initial_seed)
|
||||||
|
samples = [space.sample() for _ in range(num_samples)]
|
||||||
|
|
||||||
|
reseeded_values = space.seed(seeding_values)
|
||||||
|
resamples = [space.sample() for _ in range(num_samples)]
|
||||||
|
|
||||||
|
assert data_equivalence(seeding_values, reseeded_values)
|
||||||
|
assert data_equivalence(samples, resamples)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||||
|
def test_invalid_space_seed(space):
|
||||||
|
with pytest.raises((ValueError, TypeError, Error)):
|
||||||
|
space.seed("abc")
|
||||||
|
@@ -37,17 +37,16 @@ def test_sequence_inheritance():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"space, seed, expected_len",
|
"space, seed",
|
||||||
[
|
[
|
||||||
(Tuple([Discrete(5), Discrete(4)]), None, 2),
|
(Tuple([Discrete(5), Discrete(4)]), None),
|
||||||
(Tuple([Discrete(5), Discrete(4)]), 123, 3),
|
(Tuple([Discrete(5), Discrete(4)]), 123),
|
||||||
(Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
|
(Tuple([Discrete(5), Discrete(4)]), (123, 456)),
|
||||||
(
|
(
|
||||||
Tuple(
|
Tuple(
|
||||||
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
|
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
|
||||||
),
|
),
|
||||||
(123, (456, 789)),
|
(123, (456, 789)),
|
||||||
3,
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
Tuple(
|
Tuple(
|
||||||
@@ -57,22 +56,21 @@ def test_sequence_inheritance():
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
(123, {"position": 456, "velocity": 789}),
|
(123, {"position": 456, "velocity": 789}),
|
||||||
3,
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_seeds(space, seed, expected_len):
|
def test_seeds(space, seed):
|
||||||
seeds = space.seed(seed)
|
seeds1 = space.seed(seed)
|
||||||
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
|
assert isinstance(seeds1, tuple)
|
||||||
assert len(seeds) == expected_len
|
assert len(seeds1) == len(space)
|
||||||
|
|
||||||
sample1 = space.sample()
|
sample1 = space.sample()
|
||||||
|
|
||||||
seeds2 = space.seed(seed)
|
seeds2 = space.seed(seeds1)
|
||||||
sample2 = space.sample()
|
sample2 = space.sample()
|
||||||
|
|
||||||
data_equivalence(seeds, seeds2)
|
assert data_equivalence(seeds1, seeds2)
|
||||||
data_equivalence(sample1, sample2)
|
assert data_equivalence(sample1, sample2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@@ -50,6 +50,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
# Sequence
|
# Sequence
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
@@ -60,6 +61,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
|
|||||||
4,
|
4,
|
||||||
5,
|
5,
|
||||||
]
|
]
|
||||||
|
assert len(TESTING_SPACES) == len(TESTING_SPACES_EXPECTED_FLATDIMS)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@@ -100,6 +100,7 @@ TESTING_COMPOSITE_SPACES = [
|
|||||||
b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
|
b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
|
||||||
),
|
),
|
||||||
# Graph spaces
|
# Graph spaces
|
||||||
|
Graph(node_space=Box(-1, 1, shape=(2,)), edge_space=None),
|
||||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||||
Graph(node_space=Discrete(3), edge_space=Discrete(4)),
|
Graph(node_space=Discrete(3), edge_space=Discrete(4)),
|
||||||
|
Reference in New Issue
Block a user