mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
"""Implementation of a space that represents closed boxes in euclidean space."""
|
"""Implementation of a space that represents closed boxes in euclidean space."""
|
||||||
from typing import List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -232,6 +232,17 @@ class Box(Space[np.ndarray]):
|
|||||||
and np.allclose(self.high, other.high)
|
and np.allclose(self.high, other.high)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __setstate__(self, state: Dict):
|
||||||
|
"""Sets the state of the box for unpickling a box with legacy support."""
|
||||||
|
super().__setstate__(state)
|
||||||
|
|
||||||
|
# legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
|
||||||
|
if not hasattr(self, "low_repr"):
|
||||||
|
self.low_repr = _short_repr(self.low)
|
||||||
|
|
||||||
|
if not hasattr(self, "high_repr"):
|
||||||
|
self.high_repr = _short_repr(self.high)
|
||||||
|
|
||||||
|
|
||||||
def get_inf(dtype, sign: str) -> SupportsFloat:
|
def get_inf(dtype, sign: str) -> SupportsFloat:
|
||||||
"""Returns an infinite that doesn't break things.
|
"""Returns an infinite that doesn't break things.
|
||||||
|
@@ -610,6 +610,28 @@ def test_discrete_legacy_state_pickling():
|
|||||||
assert d.n == 3
|
assert d.n == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_box_legacy_state_pickling():
|
||||||
|
legacy_state = {
|
||||||
|
"dtype": np.dtype("float32"),
|
||||||
|
"_shape": (5,),
|
||||||
|
"low": np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||||
|
"high": np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
|
||||||
|
"bounded_below": np.array([True, True, True, True, True]),
|
||||||
|
"bounded_above": np.array([True, True, True, True, True]),
|
||||||
|
"_np_random": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
b = Box(-1, 1, ())
|
||||||
|
assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__
|
||||||
|
del b.__dict__["low_repr"]
|
||||||
|
del b.__dict__["high_repr"]
|
||||||
|
assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__
|
||||||
|
|
||||||
|
b.__setstate__(legacy_state)
|
||||||
|
assert b.low_repr == "0.0"
|
||||||
|
assert b.high_repr == "1.0"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"space",
|
"space",
|
||||||
[
|
[
|
||||||
|
Reference in New Issue
Block a user