Fix seed method for spaces.Tuple and spaces.Dict (#2365)

* Fix seed method for Tuple and Dict

* Improve stochasticity

* Update test cases for seed method

* Update test cases for seed method

Update test cases for seed method

Update test cases for seed method
This commit is contained in:
Xuehai Pan
2021-09-02 22:15:34 +08:00
committed by GitHub
parent c00c1babb9
commit bb8e8063e9
3 changed files with 160 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
from collections import OrderedDict
import numpy as np
from .space import Space
@@ -52,7 +53,24 @@ class Dict(Space):
) # None for shape and dtype, since it'll require special handling
def seed(self, seed=None):
    """Seed this space's RNG and give every subspace its own derived seed.

    The space itself is seeded through ``super().seed(seed)``. One subseed
    per subspace is then drawn from ``self.np_random`` so that sibling
    subspaces are not all seeded with the same value.

    Args:
        seed: Value forwarded unchanged to ``super().seed``.

    Returns:
        The list returned by ``super().seed(seed)``, extended with the first
        element of each subspace's own ``seed()`` result, in the iteration
        order of ``self.spaces.values()``.
    """
    seed = super().seed(seed)

    upper = np.iinfo(int).max  # subseeds are drawn from range(upper)
    count = len(self.spaces)
    try:
        # Prefer a unique subseed for each subspace.
        subseeds = self.np_random.choice(upper, size=count, replace=False)
    except ValueError:
        # Drawing without replacement was rejected (originally documented
        # as "we get more than INT_MAX subspaces"); allow repeats instead.
        # NOTE(review): NumPy may also raise ValueError when it cannot
        # build a population this large -- confirm which case fires.
        subseeds = self.np_random.choice(upper, size=count, replace=True)

    for subspace, subseed in zip(self.spaces.values(), subseeds):
        # A subspace returns a list of seeds; only its first entry is kept.
        seed.append(subspace.seed(int(subseed))[0])

    return seed
def sample(self):
    """Return an ``OrderedDict`` mapping each key to a sample of its subspace.

    Keys appear in the iteration order of ``self.spaces``.
    """
    drawn = OrderedDict()
    for key, subspace in self.spaces.items():
        drawn[key] = subspace.sample()
    return drawn

View File

@@ -1,5 +1,6 @@
import json # note: ujson fails this test due to float equality
from copy import copy
import copy
from collections import OrderedDict
import numpy as np
import pytest
@@ -80,7 +81,7 @@ def test_roundtripping(space):
)
def test_equality(space):
space1 = space
space2 = copy(space)
space2 = copy.copy(space)
assert space1 == space2, "Expected {} to equal {}".format(space1, space2)
@@ -193,3 +194,123 @@ def test_box_dtype_check():
# float64 is not in float32 space
assert not space.contains(np.array(0.5))
assert not space.contains(np.array(1))
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_returns_list(space):
    """``space.seed`` must return a non-empty list of ``int`` values.

    Checked once with ``None`` and once with an explicit seed of ``0``.
    """

    def check_seed_list(returned):
        # The result has to be a plain list holding at least one int.
        assert isinstance(returned, list)
        assert len(returned) >= 1
        assert all(isinstance(entry, int) for entry in returned)

    for seed_value in (None, 0):
        check_seed_list(space.seed(seed_value))
def convert_sample_hashable(sample):
    """Recursively convert a sample into nested tuples.

    The result can be compared with ``==`` and, provided the leaf values are
    hashable, placed in a ``set``.

    Args:
        sample: A NumPy array, list, tuple, dict, or any other value.

    Returns:
        * arrays: the ``tolist()`` form converted recursively, so an n-d
          array becomes nested tuples and a 0-d array becomes its scalar;
        * lists / tuples: a tuple of converted elements;
        * dicts: a tuple of ``(key, converted value)`` pairs in iteration
          order;
        * anything else: returned unchanged.
    """
    if isinstance(sample, np.ndarray):
        # tolist() yields a scalar for 0-d arrays and nested lists for
        # multi-dimensional ones; recursing turns every level into a tuple
        # instead of leaving unhashable inner lists behind.
        return convert_sample_hashable(sample.tolist())
    if isinstance(sample, (list, tuple)):
        return tuple(convert_sample_hashable(s) for s in sample)
    if isinstance(sample, dict):
        return tuple(
            (key, convert_sample_hashable(value)) for key, value in sample.items()
        )
    return sample
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_reproducibility(space):
    """Two copies of a space seeded with the same value must agree.

    Both the lists returned by ``seed(0)`` and the first samples drawn
    afterwards are compared.
    """
    original = space
    clone = copy.deepcopy(space)

    # Seed both with None first -- presumably so the two copies do not
    # simply share the RNG state captured by the deep copy (confirm).
    original.seed(None)
    clone.seed(None)

    assert original.seed(0) == clone.seed(0)

    first_draw = original.sample()
    second_draw = clone.sample()
    assert convert_sample_hashable(first_draw) == convert_sample_hashable(second_draw)
@pytest.mark.parametrize(
    "space",
    [
        Tuple([Discrete(100), Discrete(100)]),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
                ),
            }
        ),
    ],
)
def test_seed_subspace_incorrelated(space):
    """After seeding a composite space, no two subspaces share an RNG state."""
    if isinstance(space, Tuple):
        subspaces = space.spaces
    else:
        subspaces = space.spaces.values()

    space.seed(0)

    # Snapshot each subspace's generator state in a hashable form.
    states = []
    for subspace in subspaces:
        states.append(convert_sample_hashable(subspace.np_random.get_state()))

    # Duplicates would collapse in the set and shrink its length.
    assert len(states) == len(set(states))

View File

@@ -19,7 +19,24 @@ class Tuple(Space):
super(Tuple, self).__init__(None, None)
def seed(self, seed=None):
    """Seed this space's RNG and give every subspace its own derived seed.

    The space itself is seeded through ``super().seed(seed)``. One subseed
    per subspace is then drawn from ``self.np_random`` so that sibling
    subspaces are not all seeded with the same value.

    Args:
        seed: Value forwarded unchanged to ``super().seed``.

    Returns:
        The list returned by ``super().seed(seed)``, extended with the first
        element of each subspace's own ``seed()`` result, in the order of
        ``self.spaces``.
    """
    seed = super().seed(seed)

    upper = np.iinfo(int).max  # subseeds are drawn from range(upper)
    count = len(self.spaces)
    try:
        # Prefer a unique subseed for each subspace.
        subseeds = self.np_random.choice(upper, size=count, replace=False)
    except ValueError:
        # Drawing without replacement was rejected (originally documented
        # as "we get more than INT_MAX subspaces"); allow repeats instead.
        # NOTE(review): NumPy may also raise ValueError when it cannot
        # build a population this large -- confirm which case fires.
        subseeds = self.np_random.choice(upper, size=count, replace=True)

    for subspace, subseed in zip(self.spaces, subseeds):
        # A subspace returns a list of seeds; only its first entry is kept.
        seed.append(subspace.seed(int(subseed))[0])

    return seed
def sample(self):
    """Return a tuple holding one sample from each subspace, in order."""
    return tuple(subspace.sample() for subspace in self.spaces)