mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-31 02:06:08 +00:00
Fix seed method for spaces.Tuple and spaces.Dict (#2365)

* Fix seed method for Tuple and Dict
* Improve stochasticity
* Update test cases for seed method
* Update test cases for seed method

  Update test cases for seed method

  Update test cases for seed method
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import numpy as np
|
||||||
from .space import Space
|
from .space import Space
|
||||||
|
|
||||||
|
|
||||||
@@ -52,7 +53,24 @@ class Dict(Space):
|
|||||||
) # None for shape and dtype, since it'll require special handling
|
) # None for shape and dtype, since it'll require special handling
|
||||||
|
|
||||||
def seed(self, seed=None):
    """Call the parent ``seed`` and then seed every subspace individually.

    One subseed per subspace is drawn from ``self.np_random``; each subspace
    is seeded with its own subseed, in the order of ``self.spaces.values()``.

    Returns the object returned by the parent ``seed`` call, after the first
    entry of every subspace's ``seed`` result has been appended to it.
    """
    seeds = super().seed(seed)
    n_subspaces = len(self.spaces)
    upper = np.iinfo(int).max

    # Prefer distinct subseeds; if drawing without replacement raises
    # ValueError, draw again and allow repeats.
    # NOTE(review): if self.np_random is a legacy numpy RandomState, the
    # replace=False draw may fail for a population this large regardless of
    # the number of subspaces -- confirm which branch actually runs.
    try:
        subseeds = self.np_random.choice(upper, size=n_subspaces, replace=False)
    except ValueError:
        subseeds = self.np_random.choice(upper, size=n_subspaces, replace=True)

    for child, child_seed in zip(self.spaces.values(), subseeds):
        # Cast to a plain int before handing the value to the subspace.
        child_result = child.seed(int(child_seed))
        seeds.append(child_result[0])

    return seeds
|
||||||
|
|
||||||
def sample(self):
    """Return an ``OrderedDict`` mapping each key to a sample of its subspace.

    Keys appear in the iteration order of ``self.spaces.items()``.
    """
    draws = OrderedDict()
    for key, subspace in self.spaces.items():
        draws[key] = subspace.sample()
    return draws
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
import json # note: ujson fails this test due to float equality
|
import json # note: ujson fails this test due to float equality
|
||||||
from copy import copy
|
import copy
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -80,7 +81,7 @@ def test_roundtripping(space):
|
|||||||
)
|
)
|
||||||
def test_equality(space):
    """A shallow copy of a space must compare equal to the space itself."""
    duplicate = copy.copy(space)
    assert space == duplicate, "Expected {} to equal {}".format(space, duplicate)
|
||||||
|
|
||||||
|
|
||||||
@@ -193,3 +194,123 @@ def test_box_dtype_check():
|
|||||||
# float64 is not in float32 space
|
# float64 is not in float32 space
|
||||||
assert not space.contains(np.array(0.5))
|
assert not space.contains(np.array(0.5))
|
||||||
assert not space.contains(np.array(1))
|
assert not space.contains(np.array(1))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_returns_list(space):
    """``seed`` must return a non-empty list of ints for ``None`` and for ``0``."""
    # Same order as before: first the None seed, then the explicit one.
    for requested in (None, 0):
        returned = space.seed(requested)
        assert isinstance(returned, list)
        assert len(returned) >= 1
        assert all(isinstance(entry, int) for entry in returned)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_sample_hashable(sample):
    """Recursively convert a sample into a structure built only from tuples.

    Conversion rules:

    * ``np.ndarray`` -> nested tuples mirroring the array's shape; a 0-d
      array becomes its Python scalar.
    * ``list`` / ``tuple`` -> tuple of converted elements.
    * ``dict`` -> tuple of ``(key, converted value)`` pairs in the dict's
      iteration order.
    * Anything else is returned unchanged.

    The result is hashable as long as the leaf values (and dict keys) are.
    """
    if isinstance(sample, np.ndarray):
        # tolist() gives nested lists for multi-dimensional arrays and a bare
        # scalar for 0-d arrays. Recursing (instead of a single tuple() call)
        # turns the inner lists into tuples too and avoids calling tuple()
        # on a scalar.
        return convert_sample_hashable(sample.tolist())
    if isinstance(sample, (list, tuple)):
        return tuple(convert_sample_hashable(item) for item in sample)
    if isinstance(sample, dict):
        return tuple(
            (key, convert_sample_hashable(value)) for key, value in sample.items()
        )

    return sample
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_reproducibility(space):
    """Two copies of a space seeded with the same value must behave alike.

    Both the value returned by ``seed(0)`` and the first sample drawn
    afterwards are compared between the space and a deep copy of it.
    """
    original = space
    clone = copy.deepcopy(space)

    # Both copies are seeded with None first; the checks below then follow
    # an explicit seed(0) on each of them.
    original.seed(None)
    clone.seed(None)

    assert original.seed(0) == clone.seed(0)

    first_draw = original.sample()
    second_draw = clone.sample()
    assert convert_sample_hashable(first_draw) == convert_sample_hashable(second_draw)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "space",
    [
        Tuple([Discrete(100), Discrete(100)]),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_subspace_incorrelated(space):
    """After seeding a composite space, no two subspaces share a PRNG state."""
    if isinstance(space, Tuple):
        subspaces = space.spaces
    else:
        subspaces = space.spaces.values()

    space.seed(0)

    # Each state is converted to a hashable form so the states can be
    # collected into a set; duplicates would shrink the set.
    states = [
        convert_sample_hashable(member.np_random.get_state())
        for member in subspaces
    ]
    distinct_states = set(states)

    assert len(states) == len(distinct_states)
|
||||||
|
@@ -19,7 +19,24 @@ class Tuple(Space):
|
|||||||
super(Tuple, self).__init__(None, None)
|
super(Tuple, self).__init__(None, None)
|
||||||
|
|
||||||
def seed(self, seed=None):
    """Call the parent ``seed``, then seed each subspace with its own subseed.

    ``len(self.spaces)`` subseeds are drawn from ``self.np_random`` and
    assigned to the subspaces in the order of ``self.spaces``.

    Returns the object returned by the parent ``seed`` call, after the first
    entry of every subspace's ``seed`` result has been appended to it.
    """
    seeds = super().seed(seed)
    count = len(self.spaces)
    population = np.iinfo(int).max

    # Try for unique subseeds first; on ValueError fall back to a draw that
    # permits duplicates.
    # NOTE(review): if self.np_random is a legacy numpy RandomState, the
    # replace=False draw may fail for a population this large regardless of
    # the number of subspaces -- confirm which branch actually runs.
    try:
        subseeds = self.np_random.choice(population, size=count, replace=False)
    except ValueError:
        subseeds = self.np_random.choice(population, size=count, replace=True)

    for member, member_seed in zip(self.spaces, subseeds):
        # Cast to a plain int before handing the value to the subspace.
        member_result = member.seed(int(member_seed))
        seeds.append(member_result[0])

    return seeds
|
||||||
|
|
||||||
def sample(self):
    """Return a tuple holding one sample from each subspace, in order."""
    return tuple(subspace.sample() for subspace in self.spaces)
|
||||||
|
Reference in New Issue
Block a user