mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""Implementation of a space that represents closed boxes in euclidean space."""
|
||||
from typing import List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
||||
from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -232,6 +232,17 @@ class Box(Space[np.ndarray]):
|
||||
and np.allclose(self.high, other.high)
|
||||
)
|
||||
|
||||
def __setstate__(self, state: Dict):
|
||||
"""Sets the state of the box for unpickling a box with legacy support."""
|
||||
super().__setstate__(state)
|
||||
|
||||
# legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
|
||||
if not hasattr(self, "low_repr"):
|
||||
self.low_repr = _short_repr(self.low)
|
||||
|
||||
if not hasattr(self, "high_repr"):
|
||||
self.high_repr = _short_repr(self.high)
|
||||
|
||||
|
||||
def get_inf(dtype, sign: str) -> SupportsFloat:
|
||||
"""Returns an infinite that doesn't break things.
|
||||
|
@@ -610,6 +610,28 @@ def test_discrete_legacy_state_pickling():
|
||||
assert d.n == 3
|
||||
|
||||
|
||||
def test_box_legacy_state_pickling():
|
||||
legacy_state = {
|
||||
"dtype": np.dtype("float32"),
|
||||
"_shape": (5,),
|
||||
"low": np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"high": np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
|
||||
"bounded_below": np.array([True, True, True, True, True]),
|
||||
"bounded_above": np.array([True, True, True, True, True]),
|
||||
"_np_random": None,
|
||||
}
|
||||
|
||||
b = Box(-1, 1, ())
|
||||
assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__
|
||||
del b.__dict__["low_repr"]
|
||||
del b.__dict__["high_repr"]
|
||||
assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__
|
||||
|
||||
b.__setstate__(legacy_state)
|
||||
assert b.low_repr == "0.0"
|
||||
assert b.high_repr == "1.0"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
[
|
||||
|
Reference in New Issue
Block a user