Files
Gymnasium/tests/spaces/test_discrete.py

35 lines
1.0 KiB
Python
Raw Normal View History

from copy import deepcopy
import numpy as np
2022-09-08 10:10:07 +01:00
from gymnasium.spaces import Discrete
def test_space_legacy_pickling():
    """Check `Discrete.__setstate__` with a current state and a legacy state lacking `start`."""
    # Case 1: a state that still carries `start` must restore it unchanged.
    original = Discrete(1, start=2)
    current_state = original.__dict__

    restored = Discrete(1)
    restored.__setstate__(current_state)

    assert restored == original
    assert restored.start == 2

    # Case 2: emulate a legacy pickle by dropping the `start` entry from a
    # copy of the state (deep-copied so the reference space stays intact).
    reference = Discrete(1)
    state_without_start = deepcopy(reference.__dict__)
    state_without_start.pop("start")

    # The target is deliberately built with a different `n`; loading the
    # legacy state must overwrite it and fall back to `start == 0`.
    restored_legacy = Discrete(2)
    restored_legacy.__setstate__(state_without_start)

    assert restored_legacy == reference
    assert restored_legacy.start == 0
def test_sample_mask():
    """Check `Discrete.sample` with and without a mask on a space whose `start` is non-zero."""
    space = Discrete(4, start=2)

    # Masks are indexed from 0, while sampled values are offset by `start`.
    only_index_one = np.array([0, 1, 0, 0], dtype=np.int8)
    nothing_allowed = np.array([0, 0, 0, 0], dtype=np.int8)
    index_one_or_three = np.array([0, 1, 0, 1], dtype=np.int8)

    # Without a mask, the sample lies in [start, start + n) == [2, 6).
    assert 2 <= space.sample() < 6

    # A single permitted index pins the result: index 1 -> value 3.
    assert space.sample(mask=only_index_one) == 3

    # With every index masked out, the sample is expected to equal `start`.
    assert space.sample(mask=nothing_allowed) == 2

    # Two permitted indices (1 and 3) -> value is one of 3 or 5.
    assert space.sample(mask=index_one_or_three) in (3, 5)