Fix RandomNumberGenerator pickling (#2639)

* Fix RandomNumberGenerator pickling

* Add test for RNG pickling fix

* Fix Python 3.7 compatibility about tuple unpacking

See: https://bugs.python.org/issue32117

* Fix formatting issue

* Add test for space pickling
This commit is contained in:
JYX
2022-03-02 23:38:26 +08:00
committed by GitHub
parent 8d4dff1b66
commit 15b5c6c29f
3 changed files with 86 additions and 0 deletions

View File

@@ -88,6 +88,31 @@ class RandomNumberGenerator(np.random.Generator):
set_state.__doc__ = np.random.set_state.__doc__
seed.__doc__ = np.random.seed.__doc__
def __reduce__(self):
# np.random.Generator defines __reduce__, but it's hard-coded to
# return a Generator instead of its subclass RandomNumberGenerator.
# We need to override it here, otherwise sampling from a Space will
# be broken after pickling and unpickling, due to using the deprecated
# methods defined above.
# See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
# And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
_, init_args, *args = np.random.Generator.__reduce__(self)
return (RandomNumberGenerator._generator_ctor, init_args, *args)
@staticmethod
def _generator_ctor(bit_generator_name="MT19937"):
# Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
# Ported from numpy.random._pickle.__generator_ctor function.
from numpy.random._pickle import BitGenerators
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(
f"{bit_generator_name} is not a known BitGenerator module."
)
return RandomNumberGenerator(bit_generator())
RNG = RandomNumberGenerator