Added support for unpickling legacy Box (#2851) (#2854)

Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
Jordan Terry
2022-05-31 23:53:13 -04:00
committed by GitHub
parent 34fba52e4b
commit 0e99e3c624
2 changed files with 34 additions and 1 deletions

View File

@@ -1,5 +1,5 @@
"""Implementation of a space that represents closed boxes in euclidean space."""
from typing import List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
import numpy as np
@@ -232,6 +232,17 @@ class Box(Space[np.ndarray]):
and np.allclose(self.high, other.high)
)
def __setstate__(self, state: Dict):
"""Sets the state of the box for unpickling a box with legacy support."""
super().__setstate__(state)
# legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
if not hasattr(self, "low_repr"):
self.low_repr = _short_repr(self.low)
if not hasattr(self, "high_repr"):
self.high_repr = _short_repr(self.high)
def get_inf(dtype, sign: str) -> SupportsFloat:
"""Returns an infinite that doesn't break things.

View File

@@ -610,6 +610,28 @@ def test_discrete_legacy_state_pickling():
assert d.n == 3
def test_box_legacy_state_pickling():
legacy_state = {
"dtype": np.dtype("float32"),
"_shape": (5,),
"low": np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32),
"high": np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
"bounded_below": np.array([True, True, True, True, True]),
"bounded_above": np.array([True, True, True, True, True]),
"_np_random": None,
}
b = Box(-1, 1, ())
assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__
del b.__dict__["low_repr"]
del b.__dict__["high_repr"]
assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__
b.__setstate__(legacy_state)
assert b.low_repr == "0.0"
assert b.high_repr == "1.0"
@pytest.mark.parametrize(
"space",
[