mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 09:55:39 +00:00
Fix seed
method for spaces.Tuple
and spaces.Dict
(#2365)
* Fix seed method for Tuple and Dict * Improve stochasticity * Update test cases for seed method * Update test cases for seed method Update test cases for seed method Update test cases for seed method
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
from .space import Space
|
||||
|
||||
|
||||
@@ -52,7 +53,24 @@ class Dict(Space):
|
||||
) # None for shape and dtype, since it'll require special handling
|
||||
|
||||
def seed(self, seed=None):
|
||||
[space.seed(seed) for space in self.spaces.values()]
|
||||
seed = super().seed(seed)
|
||||
try:
|
||||
subseeds = self.np_random.choice(
|
||||
np.iinfo(int).max,
|
||||
size=len(self.spaces),
|
||||
replace=False, # unique subseed for each subspace
|
||||
)
|
||||
except ValueError:
|
||||
subseeds = self.np_random.choice(
|
||||
np.iinfo(int).max,
|
||||
size=len(self.spaces),
|
||||
replace=True, # we get more than INT_MAX subspaces
|
||||
)
|
||||
|
||||
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
||||
seed.append(subspace.seed(int(subseed))[0])
|
||||
|
||||
return seed
|
||||
|
||||
def sample(self):
|
||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import json # note: ujson fails this test due to float equality
|
||||
from copy import copy
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -80,7 +81,7 @@ def test_roundtripping(space):
|
||||
)
|
||||
def test_equality(space):
|
||||
space1 = space
|
||||
space2 = copy(space)
|
||||
space2 = copy.copy(space)
|
||||
assert space1 == space2, "Expected {} to equal {}".format(space1, space2)
|
||||
|
||||
|
||||
@@ -193,3 +194,123 @@ def test_box_dtype_check():
|
||||
# float64 is not in float32 space
|
||||
assert not space.contains(np.array(0.5))
|
||||
assert not space.contains(np.array(1))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_seed_returns_list(space):
|
||||
def assert_integer_list(seed):
|
||||
assert isinstance(seed, list)
|
||||
assert len(seed) >= 1
|
||||
assert all([isinstance(s, int) for s in seed])
|
||||
|
||||
assert_integer_list(space.seed(None))
|
||||
assert_integer_list(space.seed(0))
|
||||
|
||||
|
||||
def convert_sample_hashable(sample):
|
||||
if isinstance(sample, np.ndarray):
|
||||
return tuple(sample.tolist())
|
||||
if isinstance(sample, (list, tuple)):
|
||||
return tuple(convert_sample_hashable(s) for s in sample)
|
||||
if isinstance(sample, dict):
|
||||
return tuple(
|
||||
(key, convert_sample_hashable(value)) for key, value in sample.items()
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_seed_reproducibility(space):
|
||||
space1 = space
|
||||
space2 = copy.deepcopy(space)
|
||||
|
||||
space1.seed(None)
|
||||
space2.seed(None)
|
||||
|
||||
assert space1.seed(0) == space2.seed(0)
|
||||
|
||||
sample1, sample2 = space1.sample(), space2.sample()
|
||||
assert convert_sample_hashable(sample1) == convert_sample_hashable(sample2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
[
|
||||
Tuple([Discrete(100), Discrete(100)]),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_seed_subspace_incorrelated(space):
|
||||
subspaces = space.spaces if isinstance(space, Tuple) else space.spaces.values()
|
||||
|
||||
space.seed(0)
|
||||
states = [
|
||||
convert_sample_hashable(subspace.np_random.get_state())
|
||||
for subspace in subspaces
|
||||
]
|
||||
|
||||
assert len(states) == len(set(states))
|
||||
|
@@ -19,7 +19,24 @@ class Tuple(Space):
|
||||
super(Tuple, self).__init__(None, None)
|
||||
|
||||
def seed(self, seed=None):
|
||||
[space.seed(seed) for space in self.spaces]
|
||||
seed = super().seed(seed)
|
||||
try:
|
||||
subseeds = self.np_random.choice(
|
||||
np.iinfo(int).max,
|
||||
size=len(self.spaces),
|
||||
replace=False, # unique subseed for each subspace
|
||||
)
|
||||
except ValueError:
|
||||
subseeds = self.np_random.choice(
|
||||
np.iinfo(int).max,
|
||||
size=len(self.spaces),
|
||||
replace=True, # we get more than INT_MAX subspaces
|
||||
)
|
||||
|
||||
for subspace, subseed in zip(self.spaces, subseeds):
|
||||
seed.append(subspace.seed(int(subseed))[0])
|
||||
|
||||
return seed
|
||||
|
||||
def sample(self):
|
||||
return tuple([space.sample() for space in self.spaces])
|
||||
|
Reference in New Issue
Block a user