mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-27 00:37:19 +00:00
Fix RandomNumberGenerator pickling (#2639)
* Fix RandomNumberGenerator pickling * Add test for RNG pickling fix * Fix Python 3.7 compatibility about tuple unpacking See: https://bugs.python.org/issue32117 * Fix formatting issue * Add test for space pickling
This commit is contained in:
@@ -88,6 +88,31 @@ class RandomNumberGenerator(np.random.Generator):
|
||||
set_state.__doc__ = np.random.set_state.__doc__
|
||||
seed.__doc__ = np.random.seed.__doc__
|
||||
|
||||
def __reduce__(self):
|
||||
# np.random.Generator defines __reduce__, but it's hard-coded to
|
||||
# return a Generator instead of its subclass RandomNumberGenerator.
|
||||
# We need to override it here, otherwise sampling from a Space will
|
||||
# be broken after pickling and unpickling, due to using the deprecated
|
||||
# methods defined above.
|
||||
# See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
|
||||
# And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
|
||||
_, init_args, *args = np.random.Generator.__reduce__(self)
|
||||
return (RandomNumberGenerator._generator_ctor, init_args, *args)
|
||||
|
||||
@staticmethod
|
||||
def _generator_ctor(bit_generator_name="MT19937"):
|
||||
# Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
|
||||
# Ported from numpy.random._pickle.__generator_ctor function.
|
||||
from numpy.random._pickle import BitGenerators
|
||||
|
||||
if bit_generator_name in BitGenerators:
|
||||
bit_generator = BitGenerators[bit_generator_name]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{bit_generator_name} is not a known BitGenerator module."
|
||||
)
|
||||
return RandomNumberGenerator(bit_generator())
|
||||
|
||||
|
||||
RNG = RandomNumberGenerator
|
||||
|
||||
|
@@ -1,5 +1,7 @@
|
||||
import json # note: ujson fails this test due to float equality
|
||||
import copy
|
||||
import pickle
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -539,3 +541,50 @@ def test_infinite_space(space):
|
||||
assert (
|
||||
space.low.dtype == space.dtype
|
||||
), "Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
[
|
||||
Discrete(3),
|
||||
Discrete(5, start=-2),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_pickle(space):
|
||||
space.sample()
|
||||
|
||||
# Pickle and unpickle with a string
|
||||
pickled = pickle.dumps(space)
|
||||
space2 = pickle.loads(pickled)
|
||||
|
||||
# Pickle and unpickle with a file
|
||||
with tempfile.TemporaryFile() as f:
|
||||
pickle.dump(space, f)
|
||||
f.seek(0)
|
||||
space3 = pickle.load(f)
|
||||
|
||||
sample = space.sample()
|
||||
sample2 = space2.sample()
|
||||
sample3 = space3.sample()
|
||||
assert sample_equal(sample, sample2)
|
||||
assert sample_equal(sample, sample3)
|
||||
|
@@ -1,3 +1,5 @@
|
||||
import pickle
|
||||
|
||||
from gym import error
|
||||
from gym.utils import seeding
|
||||
|
||||
@@ -16,3 +18,13 @@ def test_valid_seeds():
|
||||
for seed in [0, 1]:
|
||||
random, seed1 = seeding.np_random(seed)
|
||||
assert seed == seed1
|
||||
|
||||
|
||||
def test_rng_pickle():
|
||||
rng, _ = seeding.np_random(seed=0)
|
||||
pickled = pickle.dumps(rng)
|
||||
rng2 = pickle.loads(pickled)
|
||||
assert isinstance(
|
||||
rng2, seeding.RandomNumberGenerator
|
||||
), "Unpickled object is not a RandomNumberGenerator"
|
||||
assert rng.random() == rng2.random()
|
||||
|
Reference in New Issue
Block a user