def seed(self, seed=None):
    """Seed this Dict space and every subspace it contains.

    The space's own generator is seeded through the parent ``seed``
    implementation; that generator is then used to derive one sub-seed
    per subspace, so each subspace gets a different random stream while
    the whole procedure stays reproducible for a given ``seed``.

    Args:
        seed: Value forwarded unchanged to the parent ``seed`` method.

    Returns:
        The list returned by the parent ``seed`` call, extended with the
        first seed reported by each subspace, in ``self.spaces`` order.
    """
    seeds = super().seed(seed)

    count = len(self.spaces)
    upper = np.iinfo(int).max

    # Sampling *without* replacement from a population this large is not
    # viable: legacy NumPy materialises a permutation of the whole range,
    # which either fails outright or tries to allocate gigabytes. Draw
    # with replacement instead and redraw on the (very unlikely)
    # collision, so sub-seeds are still unique.
    while True:
        subseeds = self.np_random.choice(upper, size=count, replace=True)
        # `count > upper` means uniqueness is impossible; accept the draw.
        if count > upper or np.unique(subseeds).size == count:
            break

    for subspace, subseed in zip(self.spaces.values(), subseeds):
        # Subspaces return a list of seeds; only the first one is recorded.
        seeds.append(subspace.seed(int(subseed))[0])

    return seeds
def convert_sample_hashable(sample):
    """Recursively convert a sample into a hashable, comparable value.

    Conversion rules:
      * ``np.ndarray`` -> nested tuples mirroring the array's dimensions
        (a 0-d array becomes its scalar value);
      * ``list`` / ``tuple`` -> tuple of converted elements;
      * ``dict`` -> tuple of ``(key, converted_value)`` pairs in the
        dict's iteration order;
      * anything else is returned unchanged.

    Args:
        sample: A sample, or a nested container of samples.

    Returns:
        The converted value, built only from tuples and the original
        leaf objects.
    """
    if isinstance(sample, np.ndarray):
        # `tolist()` yields nested *lists* for arrays with more than one
        # dimension; recurse so every level ends up as a tuple.
        return convert_sample_hashable(sample.tolist())
    if isinstance(sample, (list, tuple)):
        return tuple(convert_sample_hashable(item) for item in sample)
    if isinstance(sample, dict):
        return tuple(
            (key, convert_sample_hashable(value)) for key, value in sample.items()
        )

    return sample
def seed(self, seed=None):
    """Seed this Tuple space and every subspace it contains.

    The space's own generator is seeded through the parent ``seed``
    implementation; that generator is then used to derive one sub-seed
    per subspace, so each subspace gets a different random stream while
    the whole procedure stays reproducible for a given ``seed``.

    Args:
        seed: Value forwarded unchanged to the parent ``seed`` method.

    Returns:
        The list returned by the parent ``seed`` call, extended with the
        first seed reported by each subspace, in ``self.spaces`` order.
    """
    seeds = super().seed(seed)

    count = len(self.spaces)
    upper = np.iinfo(int).max

    # Sampling *without* replacement from a population this large is not
    # viable: legacy NumPy materialises a permutation of the whole range,
    # which either fails outright or tries to allocate gigabytes. Draw
    # with replacement instead and redraw on the (very unlikely)
    # collision, so sub-seeds are still unique.
    while True:
        subseeds = self.np_random.choice(upper, size=count, replace=True)
        # `count > upper` means uniqueness is impossible; accept the draw.
        if count > upper or np.unique(subseeds).size == count:
            break

    for subspace, subseed in zip(self.spaces, subseeds):
        # Subspaces return a list of seeds; only the first one is recorded.
        seeds.append(subspace.seed(int(subseed))[0])

    return seeds