From 15b5c6c29f86fba62b0573b0c62d0e49fb8b8a37 Mon Sep 17 00:00:00 2001
From: JYX
Date: Wed, 2 Mar 2022 23:38:26 +0800
Subject: [PATCH] Fix RandomNumberGenerator pickling (#2639)

* Fix RandomNumberGenerator pickling

* Add test for RNG pickling fix

* Fix Python 3.7 compatibility about tuple unpacking

See: https://bugs.python.org/issue32117

* Fix formatting issue

* Add test for space pickling
---
 gym/utils/seeding.py        | 25 +++++++++++++++++++
 tests/spaces/test_spaces.py | 49 +++++++++++++++++++++++++++++++++++++
 tests/utils/test_seeding.py | 12 +++++++++
 3 files changed, 86 insertions(+)

diff --git a/gym/utils/seeding.py b/gym/utils/seeding.py
index 84571a479..7c3f629d9 100644
--- a/gym/utils/seeding.py
+++ b/gym/utils/seeding.py
@@ -88,6 +88,31 @@ class RandomNumberGenerator(np.random.Generator):
     set_state.__doc__ = np.random.set_state.__doc__
     seed.__doc__ = np.random.seed.__doc__
 
+    def __reduce__(self):
+        # np.random.Generator defines __reduce__, but it's hard-coded to
+        # return a Generator instead of its subclass RandomNumberGenerator.
+        # We need to override it here, otherwise sampling from a Space will
+        # be broken after pickling and unpickling, due to using the deprecated
+        # methods defined above.
+        # See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
+        # And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
+        _, init_args, *args = np.random.Generator.__reduce__(self)
+        return (RandomNumberGenerator._generator_ctor, init_args, *args)
+
+    @staticmethod
+    def _generator_ctor(bit_generator_name="MT19937"):
+        # Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
+        # Ported from numpy.random._pickle.__generator_ctor function.
+        from numpy.random._pickle import BitGenerators
+
+        if bit_generator_name in BitGenerators:
+            bit_generator = BitGenerators[bit_generator_name]
+        else:
+            raise ValueError(
+                f"{bit_generator_name} is not a known BitGenerator module."
+            )
+        return RandomNumberGenerator(bit_generator())
+
 
 RNG = RandomNumberGenerator
 
diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py
index e0af7e7a3..ad5c4330d 100644
--- a/tests/spaces/test_spaces.py
+++ b/tests/spaces/test_spaces.py
@@ -1,5 +1,7 @@
 import json  # note: ujson fails this test due to float equality
 import copy
+import pickle
+import tempfile
 
 import numpy as np
 import pytest
@@ -539,3 +541,50 @@ def test_infinite_space(space):
     assert (
         space.low.dtype == space.dtype
     ), "Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
+
+
+@pytest.mark.parametrize(
+    "space",
+    [
+        Discrete(3),
+        Discrete(5, start=-2),
+        Box(low=0.0, high=np.inf, shape=(2, 2)),
+        Tuple([Discrete(5), Discrete(10)]),
+        Tuple(
+            [
+                Discrete(5),
+                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
+            ]
+        ),
+        Tuple((Discrete(5), Discrete(2), Discrete(2))),
+        Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
+        MultiDiscrete([2, 2, 100]),
+        MultiBinary(10),
+        Dict(
+            {
+                "position": Discrete(5),
+                "velocity": Box(
+                    low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
+                ),
+            }
+        ),
+    ],
+)
+def test_pickle(space):
+    space.sample()
+
+    # Pickle and unpickle with a string
+    pickled = pickle.dumps(space)
+    space2 = pickle.loads(pickled)
+
+    # Pickle and unpickle with a file
+    with tempfile.TemporaryFile() as f:
+        pickle.dump(space, f)
+        f.seek(0)
+        space3 = pickle.load(f)
+
+    sample = space.sample()
+    sample2 = space2.sample()
+    sample3 = space3.sample()
+    assert sample_equal(sample, sample2)
+    assert sample_equal(sample, sample3)
diff --git a/tests/utils/test_seeding.py b/tests/utils/test_seeding.py
index aa34e6a70..0aa59fdf8 100644
--- a/tests/utils/test_seeding.py
+++ b/tests/utils/test_seeding.py
@@ -1,3 +1,5 @@
+import pickle
+
 from gym import error
 from gym.utils import seeding
 
@@ -16,3 +18,13 @@ def test_valid_seeds():
     for seed in [0, 1]:
         random, seed1 = seeding.np_random(seed)
         assert seed == seed1
+
+
+def test_rng_pickle():
+    rng, _ = seeding.np_random(seed=0)
+    pickled = pickle.dumps(rng)
+    rng2 = pickle.loads(pickled)
+    assert isinstance(
+        rng2, seeding.RandomNumberGenerator
+    ), "Unpickled object is not a RandomNumberGenerator"
+    assert rng.random() == rng2.random()