Fix RandomNumberGenerator pickling (#2639)

* Fix RandomNumberGenerator pickling

* Add test for RNG pickling fix

* Fix Python 3.7 compatibility about tuple unpacking

See: https://bugs.python.org/issue32117

* Fix formatting issue

* Add test for space pickling
This commit is contained in:
JYX
2022-03-02 23:38:26 +08:00
committed by GitHub
parent 8d4dff1b66
commit 15b5c6c29f
3 changed files with 86 additions and 0 deletions

View File

@@ -88,6 +88,31 @@ class RandomNumberGenerator(np.random.Generator):
set_state.__doc__ = np.random.set_state.__doc__
seed.__doc__ = np.random.seed.__doc__
def __reduce__(self):
# np.random.Generator defines __reduce__, but it's hard-coded to
# return a Generator instead of its subclass RandomNumberGenerator.
# We need to override it here, otherwise sampling from a Space will
# be broken after pickling and unpickling, due to using the deprecated
# methods defined above.
# See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
# And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
_, init_args, *args = np.random.Generator.__reduce__(self)
return (RandomNumberGenerator._generator_ctor, init_args, *args)
@staticmethod
def _generator_ctor(bit_generator_name="MT19937"):
# Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
# Ported from numpy.random._pickle.__generator_ctor function.
from numpy.random._pickle import BitGenerators
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(
f"{bit_generator_name} is not a known BitGenerator module."
)
return RandomNumberGenerator(bit_generator())
RNG = RandomNumberGenerator

View File

@@ -1,5 +1,7 @@
import json # note: ujson fails this test due to float equality
import copy
import pickle
import tempfile
import numpy as np
import pytest
@@ -539,3 +541,50 @@ def test_infinite_space(space):
assert (
space.low.dtype == space.dtype
), "Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
@pytest.mark.parametrize(
"space",
[
Discrete(3),
Discrete(5, start=-2),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
[
Discrete(5),
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
MultiDiscrete([2, 2, 100]),
MultiBinary(10),
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
],
)
def test_pickle(space):
space.sample()
# Pickle and unpickle with a string
pickled = pickle.dumps(space)
space2 = pickle.loads(pickled)
# Pickle and unpickle with a file
with tempfile.TemporaryFile() as f:
pickle.dump(space, f)
f.seek(0)
space3 = pickle.load(f)
sample = space.sample()
sample2 = space2.sample()
sample3 = space3.sample()
assert sample_equal(sample, sample2)
assert sample_equal(sample, sample3)

View File

@@ -1,3 +1,5 @@
import pickle
from gym import error
from gym.utils import seeding
@@ -16,3 +18,13 @@ def test_valid_seeds():
for seed in [0, 1]:
random, seed1 = seeding.np_random(seed)
assert seed == seed1
def test_rng_pickle():
rng, _ = seeding.np_random(seed=0)
pickled = pickle.dumps(rng)
rng2 = pickle.loads(pickled)
assert isinstance(
rng2, seeding.RandomNumberGenerator
), "Unpickled object is not a RandomNumberGenerator"
assert rng.random() == rng2.random()