mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 17:45:07 +00:00
Modify Space.seed
such that the return can be used as seeding values (#1033)
This commit is contained in:
@@ -107,44 +107,45 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return all(space.is_np_flattenable for space in self.spaces.values())
|
||||
|
||||
def seed(self, seed: dict[str, Any] | int | None = None) -> list[int]:
|
||||
def seed(self, seed: int | dict[str, Any] | None = None) -> dict[str, int]:
|
||||
"""Seed the PRNG of this space and all subspaces.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
|
||||
* ``None`` - All the subspaces will use a random initial seed
|
||||
* ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all of the subspaces.
|
||||
* ``Dict`` - Using all the keys in the seed dictionary, the values are used to seed the subspaces. This allows the seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
|
||||
* ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all subspaces, though is very unlikely.
|
||||
* ``Dict`` - A dictionary of seeds for each subspace, requires a seed key for every subspace. This supports seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
|
||||
|
||||
Args:
|
||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||
"""
|
||||
seeds: list[int] = []
|
||||
seed: An optional int or dictionary of subspace keys to int to seed each PRNG. See above for more details.
|
||||
|
||||
if isinstance(seed, dict):
|
||||
assert (
|
||||
seed.keys() == self.spaces.keys()
|
||||
), f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
|
||||
for key in seed.keys():
|
||||
seeds += self.spaces[key].seed(seed[key])
|
||||
Returns:
|
||||
A dictionary for the seed values of the subspaces
|
||||
"""
|
||||
if seed is None:
|
||||
return {key: subspace.seed(None) for (key, subspace) in self.spaces.items()}
|
||||
elif isinstance(seed, int):
|
||||
seeds = super().seed(seed)
|
||||
super().seed(seed)
|
||||
# Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces
|
||||
subseeds = self.np_random.integers(
|
||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||
)
|
||||
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
||||
seeds += subspace.seed(int(subseed))
|
||||
elif seed is None:
|
||||
for space in self.spaces.values():
|
||||
seeds += space.seed(None)
|
||||
return {
|
||||
key: subspace.seed(int(subseed))
|
||||
for (key, subspace), subseed in zip(self.spaces.items(), subseeds)
|
||||
}
|
||||
elif isinstance(seed, dict):
|
||||
if seed.keys() != self.spaces.keys():
|
||||
raise ValueError(
|
||||
f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
|
||||
)
|
||||
|
||||
return {key: self.spaces[key].seed(seed[key]) for key in seed.keys()}
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
|
||||
)
|
||||
|
||||
return seeds
|
||||
|
||||
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
|
@@ -31,25 +31,19 @@ class Graph(Space[GraphInstance]):
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Graph, Box, Discrete
|
||||
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=42)
|
||||
>>> observation_space.sample()
|
||||
GraphInstance(nodes=array([[-12.224312 , 71.71958 , 39.473606 ],
|
||||
[-81.16453 , 95.12447 , 52.22794 ],
|
||||
[ 57.21286 , -74.37727 , -9.922812 ],
|
||||
[-25.840395 , 85.353 , 28.773024 ],
|
||||
[ 64.55232 , -11.317161 , -54.552258 ],
|
||||
[ 10.916958 , -87.23655 , 65.52624 ],
|
||||
[ 26.33288 , 51.61755 , -29.094807 ],
|
||||
[ 94.1396 , 78.62422 , 55.6767 ],
|
||||
[-61.072258 , -6.6557994, -91.23925 ],
|
||||
[-69.142105 , 36.60979 , 48.95243 ]], dtype=float32), edges=array([2, 0, 1, 1, 0, 0, 1, 0]), edge_links=array([[7, 5],
|
||||
[6, 9],
|
||||
[4, 1],
|
||||
[8, 6],
|
||||
[7, 0],
|
||||
[3, 7],
|
||||
[8, 4],
|
||||
[8, 8]], dtype=int32))
|
||||
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=123)
|
||||
>>> observation_space.sample(num_nodes=4, num_edges=8)
|
||||
GraphInstance(nodes=array([[ 36.47037 , -89.235794, -55.928024],
|
||||
[-63.125637, -64.81882 , 62.4189 ],
|
||||
[ 84.669 , -44.68512 , 63.950912],
|
||||
[ 77.97854 , 2.594091, -51.00708 ]], dtype=float32), edges=array([2, 0, 2, 1, 2, 0, 2, 1]), edge_links=array([[3, 0],
|
||||
[0, 0],
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[1, 0],
|
||||
[1, 0],
|
||||
[0, 1],
|
||||
[0, 2]], dtype=int32))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -110,6 +104,76 @@ class Graph(Space[GraphInstance]):
|
||||
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
|
||||
)
|
||||
|
||||
def seed(
|
||||
self, seed: int | tuple[int, int] | tuple[int, int, int] | None = None
|
||||
) -> tuple[int, int] | tuple[int, int, int]:
|
||||
"""Seeds the PRNG of this space and node / edge subspace.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
|
||||
* ``None`` - The root, node and edge spaces PRNG are randomly initialized
|
||||
* ``Int`` - The integer is used to seed the :class:`Graph` space that is used to generate seed values for the node and edge subspaces.
|
||||
* ``Tuple[int, int]`` - Seeds the :class:`Graph` and node subspace with a particular value. Only if edge subspace isn't specified
|
||||
* ``Tuple[int, int, int]`` - Seeds the :class:`Graph`, node and edge subspaces with a particular value.
|
||||
|
||||
Args:
|
||||
seed: An optional int or tuple of ints for this space and the node / edge subspaces. See above for more details.
|
||||
|
||||
Returns:
|
||||
A tuple of two or three ints depending on if the edge subspace is specified.
|
||||
"""
|
||||
if seed is None:
|
||||
if self.edge_space is None:
|
||||
return super().seed(None), self.node_space.seed(None)
|
||||
else:
|
||||
return (
|
||||
super().seed(None),
|
||||
self.node_space.seed(None),
|
||||
self.edge_space.seed(None),
|
||||
)
|
||||
elif isinstance(seed, int):
|
||||
if self.edge_space is None:
|
||||
super_seed = super().seed(seed)
|
||||
node_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
|
||||
# this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
|
||||
super().seed(seed)
|
||||
return super_seed, self.node_space.seed(node_seed)
|
||||
else:
|
||||
super_seed = super().seed(seed)
|
||||
node_seed, edge_seed = self.np_random.integers(
|
||||
np.iinfo(np.int32).max, size=(2,)
|
||||
)
|
||||
# this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
|
||||
super().seed(seed)
|
||||
return (
|
||||
super_seed,
|
||||
self.node_space.seed(int(node_seed)),
|
||||
self.edge_space.seed(int(edge_seed)),
|
||||
)
|
||||
elif isinstance(seed, (list, tuple)):
|
||||
if self.edge_space is None:
|
||||
if len(seed) != 2:
|
||||
raise ValueError(
|
||||
f"Expects a tuple of two values for Graph and node space, actual length: {len(seed)}"
|
||||
)
|
||||
|
||||
return super().seed(seed[0]), self.node_space.seed(seed[1])
|
||||
else:
|
||||
if len(seed) != 3:
|
||||
raise ValueError(
|
||||
f"Expects a tuple of three values for Graph, node and edge space, actual length: {len(seed)}"
|
||||
)
|
||||
|
||||
return (
|
||||
super().seed(seed[0]),
|
||||
self.node_space.seed(seed[1]),
|
||||
self.edge_space.seed(seed[2]),
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expects `None`, int or tuple of ints, actual type: {type(seed)}"
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
mask: None
|
||||
|
@@ -1,7 +1,6 @@
|
||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any, Iterable
|
||||
|
||||
@@ -17,11 +16,11 @@ class OneOf(Space[Any]):
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import OneOf, Box, Discrete
|
||||
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
|
||||
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=123)
|
||||
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
|
||||
(1, array([-0.3991573 , 0.21649833], dtype=float32))
|
||||
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
|
||||
(0, 0)
|
||||
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
|
||||
(1, array([-0.00711833, -0.7257502 ], dtype=float32))
|
||||
>>> observation_space[0]
|
||||
Discrete(2)
|
||||
>>> observation_space[1]
|
||||
@@ -57,42 +56,48 @@ class OneOf(Space[Any]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return all(space.is_np_flattenable for space in self.spaces)
|
||||
|
||||
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
|
||||
def seed(self, seed: int | tuple[int, ...] | None = None) -> tuple[int, ...]:
|
||||
"""Seed the PRNG of this space and all subspaces.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
|
||||
* ``None`` - All the subspaces will use a random initial seed
|
||||
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
|
||||
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
||||
* ``Tuple[int, ...]`` - Values used to seed the subspaces, first value seeds the OneOf and subsequent seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
||||
|
||||
Args:
|
||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||
seed: An optional int or tuple of ints to seed the OneOf space and subspaces. See above for more details.
|
||||
|
||||
Returns:
|
||||
A tuple of ints used to seed the OneOf space and subspaces
|
||||
"""
|
||||
if isinstance(seed, collections.abc.Sequence):
|
||||
assert (
|
||||
len(seed) == len(self.spaces) + 1
|
||||
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
|
||||
seeds = super().seed(seed[0])
|
||||
for subseed, space in zip(seed, self.spaces):
|
||||
seeds += space.seed(subseed)
|
||||
if seed is None:
|
||||
super_seed = super().seed(None)
|
||||
return (super_seed,) + tuple(space.seed(None) for space in self.spaces)
|
||||
elif isinstance(seed, int):
|
||||
seeds = super().seed(seed)
|
||||
super_seed = super().seed(seed)
|
||||
subseeds = self.np_random.integers(
|
||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||
)
|
||||
for subspace, subseed in zip(self.spaces, subseeds):
|
||||
seeds += subspace.seed(int(subseed))
|
||||
elif seed is None:
|
||||
seeds = super().seed(None)
|
||||
for space in self.spaces:
|
||||
seeds += space.seed(None)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
|
||||
# this is necessary such that after int or list/tuple seeding, the OneOf PRNG are equivalent
|
||||
super().seed(seed)
|
||||
return (super_seed,) + tuple(
|
||||
space.seed(int(subseed))
|
||||
for space, subseed in zip(self.spaces, subseeds)
|
||||
)
|
||||
elif isinstance(seed, (tuple, list)):
|
||||
if len(seed) != len(self.spaces) + 1:
|
||||
raise ValueError(
|
||||
f"Expects that the subspaces of seeds equals the number of subspaces + 1. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
|
||||
)
|
||||
|
||||
return seeds
|
||||
return (super().seed(seed[0]),) + tuple(
|
||||
space.seed(subseed) for space, subseed in zip(self.spaces, seed[1:])
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected None, int, or tuple of ints, actual type: {type(seed)}"
|
||||
)
|
||||
|
||||
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
|
||||
"""Generates a single random sample inside this space.
|
||||
|
@@ -19,12 +19,18 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Sequence, Box
|
||||
>>> observation_space = Sequence(Box(0, 1), seed=2)
|
||||
>>> observation_space.sample()
|
||||
(array([0.26161215], dtype=float32),)
|
||||
>>> observation_space = Sequence(Box(0, 1), seed=0)
|
||||
>>> observation_space.sample()
|
||||
(array([0.6369617], dtype=float32), array([0.26978672], dtype=float32), array([0.04097353], dtype=float32))
|
||||
(array([0.6822636], dtype=float32), array([0.18933342], dtype=float32), array([0.19049619], dtype=float32))
|
||||
>>> observation_space.sample()
|
||||
(array([0.83506], dtype=float32), array([0.9053838], dtype=float32), array([0.5836242], dtype=float32), array([0.63214064], dtype=float32))
|
||||
|
||||
Example with stacked observations
|
||||
>>> observation_space = Sequence(Box(0, 1), stack=True, seed=0)
|
||||
>>> observation_space.sample()
|
||||
array([[0.6822636 ],
|
||||
[0.18933342],
|
||||
[0.19049619]], dtype=float32)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -53,11 +59,39 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
# None for shape and dtype, since it'll require special handling
|
||||
super().__init__(None, None, seed)
|
||||
|
||||
def seed(self, seed: int | None = None) -> list[int]:
|
||||
"""Seed the PRNG of this space and the feature space."""
|
||||
seeds = super().seed(seed)
|
||||
seeds += self.feature_space.seed(seed)
|
||||
return seeds
|
||||
def seed(self, seed: int | tuple[int, int] | None = None) -> tuple[int, int]:
|
||||
"""Seed the PRNG of the Sequence space and the feature space.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
|
||||
* ``None`` - All the subspaces will use a random initial seed
|
||||
* ``Int`` - The integer is used to seed the :class:`Sequence` space that is used to generate a seed value for the feature space.
|
||||
* ``Tuple of ints`` - A tuple for the :class:`Sequence` and feature space.
|
||||
|
||||
Args:
|
||||
seed: An optional int or tuple of ints to seed the PRNG. See above for more details
|
||||
|
||||
Returns:
|
||||
A tuple of the seeding values for the Sequence and feature space
|
||||
"""
|
||||
if seed is None:
|
||||
return super().seed(None), self.feature_space.seed(None)
|
||||
elif isinstance(seed, int):
|
||||
super_seed = super().seed(seed)
|
||||
feature_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
|
||||
# this is necessary such that after int or list/tuple seeding, the Sequence PRNG are equivalent
|
||||
super().seed(seed)
|
||||
return super_seed, self.feature_space.seed(feature_seed)
|
||||
elif isinstance(seed, (tuple, list)):
|
||||
if len(seed) != 2:
|
||||
raise ValueError(
|
||||
f"Expects the seed to have two elements for the Sequence and feature space, actual length: {len(seed)}"
|
||||
)
|
||||
return super().seed(seed[0]), self.feature_space.seed(seed[1])
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected None, int, tuple of ints, actual type: {type(seed)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_np_flattenable(self):
|
||||
|
@@ -102,13 +102,20 @@ class Space(Generic[T_cov]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def seed(self, seed: int | None = None) -> list[int]:
|
||||
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
||||
def seed(self, seed: int | None = None) -> int | list[int] | dict[str, int]:
|
||||
"""Seed the pseudorandom number generator (PRNG) of this space and, if applicable, the PRNGs of subspaces.
|
||||
|
||||
Args:
|
||||
seed: The seed value for the space. This is expanded for composite spaces to accept multiple values. For further details, please refer to the space's documentation.
|
||||
|
||||
Returns:
|
||||
The seed values used for all the PRNGs, for composite spaces this can be a tuple or dictionary of values.
|
||||
"""
|
||||
self._np_random, np_random_seed = seeding.np_random(seed)
|
||||
return [np_random_seed]
|
||||
return np_random_seed
|
||||
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
"""Return boolean specifying if x is a valid member of this space, equivalent to ``sample in space``."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __contains__(self, x: Any) -> bool:
|
||||
|
@@ -1,7 +1,6 @@
|
||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any, Iterable
|
||||
|
||||
@@ -47,7 +46,7 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return all(space.is_np_flattenable for space in self.spaces)
|
||||
|
||||
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
|
||||
def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
|
||||
"""Seed the PRNG of this space and all subspaces.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
@@ -58,32 +57,35 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
||||
|
||||
Args:
|
||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||
"""
|
||||
seeds: list[int] = []
|
||||
|
||||
if isinstance(seed, collections.abc.Sequence):
|
||||
assert len(seed) == len(
|
||||
self.spaces
|
||||
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
|
||||
for subseed, space in zip(seed, self.spaces):
|
||||
seeds += space.seed(subseed)
|
||||
Returns:
|
||||
A tuple of the seed values for all subspaces
|
||||
"""
|
||||
if seed is None:
|
||||
return tuple(space.seed(None) for space in self.spaces)
|
||||
elif isinstance(seed, int):
|
||||
seeds = super().seed(seed)
|
||||
super().seed(seed)
|
||||
subseeds = self.np_random.integers(
|
||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||
)
|
||||
for subspace, subseed in zip(self.spaces, subseeds):
|
||||
seeds += subspace.seed(int(subseed))
|
||||
elif seed is None:
|
||||
for space in self.spaces:
|
||||
seeds += space.seed(seed)
|
||||
return tuple(
|
||||
subspace.seed(int(subseed))
|
||||
for subspace, subseed in zip(self.spaces, subseeds)
|
||||
)
|
||||
elif isinstance(seed, (tuple, list)):
|
||||
if len(seed) != len(self.spaces):
|
||||
raise ValueError(
|
||||
f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
|
||||
)
|
||||
|
||||
return tuple(
|
||||
space.seed(subseed) for subseed, space in zip(seed, self.spaces)
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
|
||||
)
|
||||
|
||||
return seeds
|
||||
|
||||
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
|
||||
"""Generates a single random sample inside this space.
|
||||
|
||||
|
@@ -31,19 +31,19 @@ from gymnasium.utils.passive_env_checker import (
|
||||
|
||||
|
||||
def data_equivalence(data_1, data_2, exact: bool = False) -> bool:
|
||||
"""Assert equality between data 1 and 2, i.e observations, actions, info.
|
||||
"""Assert equality between data 1 and 2, i.e. observations, actions, info.
|
||||
|
||||
Args:
|
||||
data_1: data structure 1
|
||||
data_2: data structure 2
|
||||
exact: whether to compare array exactly or not if false compares with absolute and realive torrelance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).
|
||||
exact: whether to compare array exactly or not if false compares with absolute and relative tolerance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).
|
||||
|
||||
Returns:
|
||||
If observation 1 and 2 are equivalent
|
||||
"""
|
||||
if type(data_1) is not type(data_2):
|
||||
return False
|
||||
if isinstance(data_1, dict):
|
||||
elif isinstance(data_1, dict):
|
||||
return data_1.keys() == data_2.keys() and all(
|
||||
data_equivalence(data_1[k], data_2[k], exact) for k in data_1.keys()
|
||||
)
|
||||
|
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.spaces import Box, Dict, Discrete
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
|
||||
|
||||
def test_dict_init():
|
||||
@@ -56,8 +57,7 @@ DICT_SPACE = Dict(
|
||||
|
||||
|
||||
def test_dict_seeding():
|
||||
seeds = DICT_SPACE.seed(
|
||||
{
|
||||
seeding_values = {
|
||||
"a": 0,
|
||||
"b": {
|
||||
"b_1": 1,
|
||||
@@ -65,8 +65,8 @@ def test_dict_seeding():
|
||||
},
|
||||
"c": 3,
|
||||
}
|
||||
)
|
||||
assert all(isinstance(seed, int) for seed in seeds)
|
||||
seeded_values = DICT_SPACE.seed(seeding_values)
|
||||
assert data_equivalence(seeded_values, seeding_values)
|
||||
|
||||
# "Unpack" the dict sub-spaces into individual spaces
|
||||
a = Box(low=0, high=1, shape=(3, 3), seed=0)
|
||||
@@ -84,7 +84,7 @@ def test_dict_seeding():
|
||||
|
||||
def test_int_seeding():
|
||||
seeds = DICT_SPACE.seed(1)
|
||||
assert all(isinstance(seed, int) for seed in seeds)
|
||||
assert isinstance(seeds, dict)
|
||||
|
||||
# rng, seeds = seeding.np_random(1)
|
||||
# subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False)
|
||||
@@ -92,10 +92,10 @@ def test_int_seeding():
|
||||
# b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False)
|
||||
|
||||
# "Unpack" the dict sub-spaces into individual spaces
|
||||
a = Box(low=0, high=1, shape=(3, 3), seed=seeds[1])
|
||||
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds[3])
|
||||
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds[4])
|
||||
c = Discrete(5, seed=seeds[5])
|
||||
a = Box(low=0, high=1, shape=(3, 3), seed=seeds["a"])
|
||||
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds["b"]["b_1"])
|
||||
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds["b"]["b_2"])
|
||||
c = Discrete(5, seed=seeds["c"])
|
||||
|
||||
for i in range(10):
|
||||
dict_sample = DICT_SPACE.sample()
|
||||
@@ -107,7 +107,7 @@ def test_int_seeding():
|
||||
|
||||
def test_none_seeding():
|
||||
seeds = DICT_SPACE.seed(None)
|
||||
assert len(seeds) == 4 and all(isinstance(seed, int) for seed in seeds)
|
||||
assert isinstance(seeds, dict)
|
||||
|
||||
|
||||
def test_bad_seed():
|
||||
|
@@ -2,7 +2,6 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
|
||||
|
||||
def test_oneof_inheritance():
|
||||
@@ -21,26 +20,18 @@ def test_oneof_inheritance():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spaces, seed, expected_len",
|
||||
"spaces, seed",
|
||||
[
|
||||
([Discrete(5), Box(-1, 1, shape=(3,))], None, 3),
|
||||
([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3),
|
||||
([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3),
|
||||
([Discrete(5), Box(-1, 1, shape=(3,))], None),
|
||||
([Discrete(5), Box(-1, 1, shape=(3,))], 123),
|
||||
([Discrete(5), Box(-1, 1, shape=(3,))], (123, 456, 789)),
|
||||
],
|
||||
)
|
||||
def test_oneof_seeds(spaces, seed, expected_len):
|
||||
def test_oneof_seeds(spaces, seed):
|
||||
oneof_space = OneOf(spaces)
|
||||
seeds = oneof_space.seed(seed)
|
||||
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
|
||||
assert len(seeds) == expected_len
|
||||
|
||||
sample1 = oneof_space.sample()
|
||||
|
||||
seeds2 = oneof_space.seed(seed)
|
||||
sample2 = oneof_space.sample()
|
||||
|
||||
data_equivalence(seeds, seeds2)
|
||||
data_equivalence(sample1, sample2)
|
||||
assert isinstance(seeds, tuple)
|
||||
assert len(seeds) == len(spaces) + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -71,6 +62,6 @@ def test_bad_oneof_seed():
|
||||
space = OneOf([Box(0, 1), Box(0, 1)])
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
|
||||
match="Expected None, int, or tuple of ints, actual type: <class 'float'>",
|
||||
):
|
||||
space.seed(0.0)
|
||||
|
@@ -9,6 +9,7 @@ import numpy as np
|
||||
import pytest
|
||||
import scipy.stats
|
||||
|
||||
from gymnasium.error import Error
|
||||
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
@@ -509,13 +510,13 @@ def test_seed_reproducibility(space):
|
||||
space_2 = copy.deepcopy(space)
|
||||
|
||||
for seed in range(5):
|
||||
assert space_1.seed(seed) == space_2.seed(seed)
|
||||
assert data_equivalence(space_1.seed(seed), space_2.seed(seed))
|
||||
# With the same seed, the two spaces should be identical
|
||||
assert all(
|
||||
data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
|
||||
)
|
||||
|
||||
assert space_1.seed(123) != space_2.seed(456)
|
||||
assert not data_equivalence(space_1.seed(123), space_2.seed(456))
|
||||
# Due to randomness, it is difficult to test that random seeds produce different answers
|
||||
# Therefore, taking 10 samples and checking that they are not all the same.
|
||||
assert not all(
|
||||
@@ -604,3 +605,22 @@ def test_space_pickling(space):
|
||||
file_unpickled_sample = file_unpickled_space.sample()
|
||||
assert data_equivalence(space_sample, unpickled_sample)
|
||||
assert data_equivalence(space_sample, file_unpickled_sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
@pytest.mark.parametrize("initial_seed", [None, 123])
|
||||
def test_space_seeding_output(space, initial_seed, num_samples=5):
|
||||
seeding_values = space.seed(initial_seed)
|
||||
samples = [space.sample() for _ in range(num_samples)]
|
||||
|
||||
reseeded_values = space.seed(seeding_values)
|
||||
resamples = [space.sample() for _ in range(num_samples)]
|
||||
|
||||
assert data_equivalence(seeding_values, reseeded_values)
|
||||
assert data_equivalence(samples, resamples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_invalid_space_seed(space):
|
||||
with pytest.raises((ValueError, TypeError, Error)):
|
||||
space.seed("abc")
|
||||
|
@@ -37,17 +37,16 @@ def test_sequence_inheritance():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space, seed, expected_len",
|
||||
"space, seed",
|
||||
[
|
||||
(Tuple([Discrete(5), Discrete(4)]), None, 2),
|
||||
(Tuple([Discrete(5), Discrete(4)]), 123, 3),
|
||||
(Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
|
||||
(Tuple([Discrete(5), Discrete(4)]), None),
|
||||
(Tuple([Discrete(5), Discrete(4)]), 123),
|
||||
(Tuple([Discrete(5), Discrete(4)]), (123, 456)),
|
||||
(
|
||||
Tuple(
|
||||
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
|
||||
),
|
||||
(123, (456, 789)),
|
||||
3,
|
||||
),
|
||||
(
|
||||
Tuple(
|
||||
@@ -57,22 +56,21 @@ def test_sequence_inheritance():
|
||||
)
|
||||
),
|
||||
(123, {"position": 456, "velocity": 789}),
|
||||
3,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_seeds(space, seed, expected_len):
|
||||
seeds = space.seed(seed)
|
||||
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
|
||||
assert len(seeds) == expected_len
|
||||
def test_seeds(space, seed):
|
||||
seeds1 = space.seed(seed)
|
||||
assert isinstance(seeds1, tuple)
|
||||
assert len(seeds1) == len(space)
|
||||
|
||||
sample1 = space.sample()
|
||||
|
||||
seeds2 = space.seed(seed)
|
||||
seeds2 = space.seed(seeds1)
|
||||
sample2 = space.sample()
|
||||
|
||||
data_equivalence(seeds, seeds2)
|
||||
data_equivalence(sample1, sample2)
|
||||
assert data_equivalence(seeds1, seeds2)
|
||||
assert data_equivalence(sample1, sample2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@@ -50,6 +50,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
# Sequence
|
||||
None,
|
||||
None,
|
||||
@@ -60,6 +61,7 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
|
||||
4,
|
||||
5,
|
||||
]
|
||||
assert len(TESTING_SPACES) == len(TESTING_SPACES_EXPECTED_FLATDIMS)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@@ -100,6 +100,7 @@ TESTING_COMPOSITE_SPACES = [
|
||||
b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
|
||||
),
|
||||
# Graph spaces
|
||||
Graph(node_space=Box(-1, 1, shape=(2,)), edge_space=None),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Discrete(3), edge_space=Discrete(4)),
|
||||
|
Reference in New Issue
Block a user