From 0e99e3c62436aaf7bc36a6e52a542d6bb75edb98 Mon Sep 17 00:00:00 2001 From: Jordan Terry Date: Tue, 31 May 2022 23:53:13 -0400 Subject: [PATCH] Added support for unpickling legacy Box (#2851) (#2854) Co-authored-by: Mark Towers --- gym/spaces/box.py | 13 ++++++++++++- tests/spaces/test_spaces.py | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/gym/spaces/box.py b/gym/spaces/box.py index 012eb5aaf..e9b62c0a2 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -1,5 +1,5 @@ """Implementation of a space that represents closed boxes in euclidean space.""" -from typing import List, Optional, Sequence, SupportsFloat, Tuple, Type, Union +from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union import numpy as np @@ -232,6 +232,17 @@ class Box(Space[np.ndarray]): and np.allclose(self.high, other.high) ) + def __setstate__(self, state: Dict): + """Sets the state of the box for unpickling a box with legacy support.""" + super().__setstate__(state) + + # legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state + if not hasattr(self, "low_repr"): + self.low_repr = _short_repr(self.low) + + if not hasattr(self, "high_repr"): + self.high_repr = _short_repr(self.high) + def get_inf(dtype, sign: str) -> SupportsFloat: """Returns an infinite that doesn't break things. diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 85739028e..a65217db5 100644 --- a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -610,6 +610,28 @@ def test_discrete_legacy_state_pickling(): assert d.n == 3 +def test_box_legacy_state_pickling(): + legacy_state = { + "dtype": np.dtype("float32"), + "_shape": (5,), + "low": np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32), + "high": np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32), + "bounded_below": np.array([True, True, True, True, True]), + "bounded_above": np.array([True, True, True, True, True]), + "_np_random": None, + } + + b = Box(-1, 1, ()) + assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__ + del b.__dict__["low_repr"] + del b.__dict__["high_repr"] + assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__ + + b.__setstate__(legacy_state) + assert b.low_repr == "0.0" + assert b.high_repr == "1.0" + + @pytest.mark.parametrize( "space", [