mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-28 09:17:18 +00:00
Fix RandomNumberGenerator pickling (#2639)
* Fix RandomNumberGenerator pickling * Add test for RNG pickling fix * Fix Python 3.7 compatibility about tuple unpacking See: https://bugs.python.org/issue32117 * Fix formatting issue * Add test for space pickling
This commit is contained in:
@@ -88,6 +88,31 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
set_state.__doc__ = np.random.set_state.__doc__
|
||||
seed.__doc__ = np.random.seed.__doc__
|
||||
|
||||
def __reduce__(self):
|
||||
# np.random.Generator defines __reduce__, but it's hard-coded to
|
||||
# return a Generator instead of its subclass RandomNumberGenerator.
|
||||
# We need to override it here, otherwise sampling from a Space will
|
||||
# be broken after pickling and unpickling, due to using the deprecated
|
||||
# methods defined above.
|
||||
# See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
|
||||
# And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
|
||||
_, init_args, *args = np.random.Generator.__reduce__(self)
|
||||
return (RandomNumberGenerator._generator_ctor, init_args, *args)
|
||||
|
||||
@staticmethod
|
||||
def _generator_ctor(bit_generator_name="MT19937"):
|
||||
# Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
|
||||
# Ported from numpy.random._pickle.__generator_ctor function.
|
||||
from numpy.random._pickle import BitGenerators
|
||||
|
||||
if bit_generator_name in BitGenerators:
|
||||
bit_generator = BitGenerators[bit_generator_name]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{bit_generator_name} is not a known BitGenerator module."
|
||||
)
|
||||
return RandomNumberGenerator(bit_generator())
|
||||
|
||||
|
||||
RNG = RandomNumberGenerator
|
||||
|
||||
|
Reference in New Issue
Block a user