mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-27 08:47:08 +00:00
Fix RandomNumberGenerator pickling (#2639)
* Fix RandomNumberGenerator pickling * Add test for RNG pickling fix * Fix Python 3.7 compatibility about tuple unpacking See: https://bugs.python.org/issue32117 * Fix formatting issue * Add test for space pickling
This commit is contained in:
@@ -88,6 +88,31 @@ class RandomNumberGenerator(np.random.Generator):
|
|||||||
set_state.__doc__ = np.random.set_state.__doc__
|
set_state.__doc__ = np.random.set_state.__doc__
|
||||||
seed.__doc__ = np.random.seed.__doc__
|
seed.__doc__ = np.random.seed.__doc__
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
# np.random.Generator defines __reduce__, but it's hard-coded to
|
||||||
|
# return a Generator instead of its subclass RandomNumberGenerator.
|
||||||
|
# We need to override it here, otherwise sampling from a Space will
|
||||||
|
# be broken after pickling and unpickling, due to using the deprecated
|
||||||
|
# methods defined above.
|
||||||
|
# See: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_generator.pyx#L221-L223
|
||||||
|
# And: https://github.com/numpy/numpy/blob/41d37b714caa1eef72f984d529f1d40ed48ce535/numpy/random/_pickle.py#L17-L37
|
||||||
|
_, init_args, *args = np.random.Generator.__reduce__(self)
|
||||||
|
return (RandomNumberGenerator._generator_ctor, init_args, *args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generator_ctor(bit_generator_name="MT19937"):
|
||||||
|
# Workaround method for RandomNumberGenerator pickling, see __reduce__ above.
|
||||||
|
# Ported from numpy.random._pickle.__generator_ctor function.
|
||||||
|
from numpy.random._pickle import BitGenerators
|
||||||
|
|
||||||
|
if bit_generator_name in BitGenerators:
|
||||||
|
bit_generator = BitGenerators[bit_generator_name]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"{bit_generator_name} is not a known BitGenerator module."
|
||||||
|
)
|
||||||
|
return RandomNumberGenerator(bit_generator())
|
||||||
|
|
||||||
|
|
||||||
RNG = RandomNumberGenerator
|
RNG = RandomNumberGenerator
|
||||||
|
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
import json # note: ujson fails this test due to float equality
|
import json # note: ujson fails this test due to float equality
|
||||||
import copy
|
import copy
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -539,3 +541,50 @@ def test_infinite_space(space):
|
|||||||
assert (
|
assert (
|
||||||
space.low.dtype == space.dtype
|
space.low.dtype == space.dtype
|
||||||
), "Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
), "Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"space",
|
||||||
|
[
|
||||||
|
Discrete(3),
|
||||||
|
Discrete(5, start=-2),
|
||||||
|
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||||
|
Tuple([Discrete(5), Discrete(10)]),
|
||||||
|
Tuple(
|
||||||
|
[
|
||||||
|
Discrete(5),
|
||||||
|
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||||
|
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
|
||||||
|
MultiDiscrete([2, 2, 100]),
|
||||||
|
MultiBinary(10),
|
||||||
|
Dict(
|
||||||
|
{
|
||||||
|
"position": Discrete(5),
|
||||||
|
"velocity": Box(
|
||||||
|
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_pickle(space):
|
||||||
|
space.sample()
|
||||||
|
|
||||||
|
# Pickle and unpickle with a string
|
||||||
|
pickled = pickle.dumps(space)
|
||||||
|
space2 = pickle.loads(pickled)
|
||||||
|
|
||||||
|
# Pickle and unpickle with a file
|
||||||
|
with tempfile.TemporaryFile() as f:
|
||||||
|
pickle.dump(space, f)
|
||||||
|
f.seek(0)
|
||||||
|
space3 = pickle.load(f)
|
||||||
|
|
||||||
|
sample = space.sample()
|
||||||
|
sample2 = space2.sample()
|
||||||
|
sample3 = space3.sample()
|
||||||
|
assert sample_equal(sample, sample2)
|
||||||
|
assert sample_equal(sample, sample3)
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
from gym import error
|
from gym import error
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
@@ -16,3 +18,13 @@ def test_valid_seeds():
|
|||||||
for seed in [0, 1]:
|
for seed in [0, 1]:
|
||||||
random, seed1 = seeding.np_random(seed)
|
random, seed1 = seeding.np_random(seed)
|
||||||
assert seed == seed1
|
assert seed == seed1
|
||||||
|
|
||||||
|
|
||||||
|
def test_rng_pickle():
|
||||||
|
rng, _ = seeding.np_random(seed=0)
|
||||||
|
pickled = pickle.dumps(rng)
|
||||||
|
rng2 = pickle.loads(pickled)
|
||||||
|
assert isinstance(
|
||||||
|
rng2, seeding.RandomNumberGenerator
|
||||||
|
), "Unpickled object is not a RandomNumberGenerator"
|
||||||
|
assert rng.random() == rng2.random()
|
||||||
|
Reference in New Issue
Block a user