mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-06 16:01:45 +00:00
41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
![]() |
import numpy as np
|
||
|
|
||
|
from gym.spaces import Discrete
|
||
|
|
||
|
|
||
|
def test_space_legacy_pickling():
    """Test the legacy pickle of Discrete that is missing the `start` parameter.

    The legacy state stores ``shape`` and ``np_random`` under their public names
    and has no ``start`` entry. Restoring it must still yield a space whose
    attributes match the state and whose ``start`` defaults to 0.
    """
    legacy_state = {
        "shape": (
            1,
            2,
            3,
        ),
        "dtype": np.int64,
        "np_random": np.random.default_rng(),
        "n": 3,
    }
    space = Discrete(1)
    space.__setstate__(legacy_state)

    # The caller's dict must be left untouched: the second `__setstate__` call
    # below reuses `legacy_state` and only exercises the legacy (missing-start)
    # path if the dict still lacks a "start" key.
    assert "start" not in legacy_state

    assert space.shape == legacy_state["shape"]
    # The restored generator should be the very object from the legacy state.
    assert space.np_random is legacy_state["np_random"]
    assert space.n == 3
    assert space.dtype == legacy_state["dtype"]

    # `start` is present at this point; delete it to mimic an instance that
    # never had a `start` attribute, so the next restore must supply it itself.
    assert "start" in space.__dict__
    del space.__dict__["start"]  # legacy did not include start param
    assert "start" not in space.__dict__

    # Restoring the legacy state again must re-create `start` with default 0.
    space.__setstate__(legacy_state)
    assert space.start == 0
|
||
|
|
||
|
|
||
|
def test_sample_mask():
    """Check that `Discrete.sample` honours `start` and an optional int8 action mask.

    With ``n=4`` and ``start=2`` the sampled actions lie in 2..5, and mask
    position ``i`` selects action ``start + i``.
    """
    space = Discrete(4, start=2)

    # Without a mask, samples fall in [start, start + n).
    unmasked = space.sample()
    assert 2 <= unmasked < 6

    # Only position 1 allowed -> action 3.
    only_second = np.array([0, 1, 0, 0], dtype=np.int8)
    assert space.sample(mask=only_second) == 3

    # Nothing allowed -> the sample is `start`.
    nothing_allowed = np.array([0, 0, 0, 0], dtype=np.int8)
    assert space.sample(mask=nothing_allowed) == 2

    # Positions 1 and 3 allowed -> action 3 or 5.
    second_or_fourth = np.array([0, 1, 0, 1], dtype=np.int8)
    assert space.sample(mask=second_or_fourth) in [3, 5]
|