mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-06 16:01:45 +00:00
* Updated cartpole-v0 to v1 to prevent warning and added pytest.mark.filterwarnings for tests where warnings are unavoidable * Change np.bool to bool as numpy raises a warning and bool is the suggested solution * Seeding randint will be deprecated in the future; integers is the new solution * Fixed errors thrown when the video recorder is deleted but not closed * spaces.Box expects a floating array, updated all cases where this was not true and modified float32 to float64 as float arrays default to float64. Otherwise spaces.Box raises a warning that the dtype precision (float32) is lower than the array precision (float64). * Added pytest.mark.filterwarnings to prevent the raising of an intended warning * Added comment to explain why a warning is raised that can't be prevented without a version update to the environment * Added comment to explain why warning is raised * Changed values to float as expected by the box, which defaults to float64 * Removed --forked from pytest as the pytest-forked project is no longer being maintained and was not raising warnings as expected * When AsyncVectorEnv has shared_memory=True then a ValueError is raised before _state is initialised. Therefore, on destruction of the env an error is thrown in .close_extra as _state does not exist * Possible fix for an error in test_call_async_vector_env by ensuring that pygame resources are released * Pygame throws an error with ALSA when closed, using a fix from PettingZoo (https://github.com/Farama-Foundation/PettingZoo/blob/master/pettingzoo/__init__.py). 
We use the dsp audiodriver to prevent this issue * Modification due to running pre-commit locally * Updated cartpole-v0 to v1 to prevent warning and added pytest.mark.filterwarnings for tests where warnings are unavoidable * Change np.bool to bool as numpy raises a warning and bool is the suggested solution * Seeding randint will be deprecated in the future; integers is the new solution * Fixed errors thrown when the video recorder is deleted but not closed * spaces.Box expects a floating array, updated all cases where this was not true and modified float32 to float64 as float arrays default to float64. Otherwise spaces.Box raises a warning that the dtype precision (float32) is lower than the array precision (float64). * Added pytest.mark.filterwarnings to prevent the raising of an intended warning * Added comment to explain why a warning is raised that can't be prevented without a version update to the environment * Added comment to explain why warning is raised * Changed values to float as expected by the box, which defaults to float64 * Removed --forked from pytest as the pytest-forked project is no longer being maintained and was not raising warnings as expected * When AsyncVectorEnv has shared_memory=True then a ValueError is raised before _state is initialised. Therefore, on destruction of the env an error is thrown in .close_extra as _state does not exist * Possible fix for an error in test_call_async_vector_env by ensuring that pygame resources are released * Pygame throws an error with ALSA when closed, using a fix from PettingZoo (https://github.com/Farama-Foundation/PettingZoo/blob/master/pettingzoo/__init__.py). We use the dsp audiodriver to prevent this issue * Modification due to running pre-commit locally
658 lines
20 KiB
Python
658 lines
20 KiB
Python
import json # note: ujson fails this test due to float equality
|
|
import copy
|
|
import pickle
|
|
import tempfile
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
    ],
)
def test_roundtripping(space):
    """Two samples must survive ``to_jsonable`` -> JSON text -> ``from_jsonable``.

    The comparison is done on the jsonable form of each sample, before and
    after the round trip, rather than on the raw samples themselves.
    """
    originals = [space.sample(), space.sample()]
    for drawn in originals:
        assert space.contains(drawn)

    # Go through real JSON text so non-serialisable values would surface here.
    text = json.dumps(space.to_jsonable(originals))
    first_back, second_back = space.from_jsonable(json.loads(text))

    for before, after in zip(originals, (first_back, second_back)):
        s1 = space.to_jsonable([before])
        s1p = space.to_jsonable([after])
        assert s1 == s1p, f"Expected {s1} to equal {s1p}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
        Box(low=-np.inf, high=np.inf, shape=(1, 3)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(6),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
    ],
)
def test_equality(space):
    """A space must compare equal to an independent deep copy of itself."""
    space1, space2 = space, copy.deepcopy(space)
    assert space1 == space2, f"Expected {space1} to equal {space2}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spaces",
    [
        (Discrete(3), Discrete(4)),
        (Discrete(3), Discrete(3, start=-1)),
        (MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
        (MultiBinary(8), MultiBinary(7)),
        (
            Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
            Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 9.0]), dtype=np.float64),
        ),
        (
            Box(low=-np.inf, high=0.0, shape=(2, 1)),
            Box(low=0.0, high=np.inf, shape=(2, 1)),
        ),
        (Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
        (
            Tuple([Discrete(5), Discrete(10)]),
            Tuple([Discrete(5, start=7), Discrete(10)]),
        ),
        (Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
        (Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
    ],
)
def test_inequality(spaces):
    """Each fixture pair differs in one parameter (size, start, bounds or key) and must compare unequal."""
    first, second = spaces
    space1, space2 = first, second
    assert space1 != space2, f"Expected {space1} != {space2}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(5),
        Discrete(8, start=-20),
        Box(low=0, high=255, shape=(2,), dtype="uint8"),
        Box(low=-np.inf, high=np.inf, shape=(3, 3)),
        Box(low=1.0, high=np.inf, shape=(3, 3)),
        Box(low=-np.inf, high=2.0, shape=(3, 3)),
    ],
)
def test_sample(space):
    """The empirical mean of seeded samples must be close to the expected mean.

    Expected means used here:

    * bounded ``Box``: the midpoint ``(low + high) / 2``;
    * ``Box`` bounded only below / only above: ``low + 1`` / ``high - 1``.
      This assumes the sampler's one-sided offset from the finite bound has
      mean 1.
    * unbounded ``Box``: ``0``;
    * ``Discrete``: ``start + (n - 1) / 2``, the mean of the integers
      ``start, start + 1, ..., start + n - 1``.

    Raises:
        NotImplementedError: for a space type with no expected mean defined.
    """
    space.seed(0)
    n_trials = 100
    samples = np.array([space.sample() for _ in range(n_trials)])

    if isinstance(space, Box):
        if space.is_bounded():
            expected_mean = (space.high + space.low) / 2
        elif space.is_bounded("below"):
            expected_mean = 1 + space.low
        elif space.is_bounded("above"):
            expected_mean = -1 + space.high
        else:
            expected_mean = 0.0
    elif isinstance(space, Discrete):
        # The values are {start, ..., start + n - 1}.
        # Their mean is start + (n - 1) / 2, not start + n / 2.
        expected_mean = space.start + (space.n - 1) / 2
    else:
        raise NotImplementedError

    # assert_allclose takes (actual, desired); the tolerance is deliberately loose.
    np.testing.assert_allclose(samples.mean(), expected_mean, atol=3.0 * samples.std())
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spaces",
    [
        (Discrete(5), MultiBinary(5)),
        (
            Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
            MultiDiscrete([2, 2, 8]),
        ),
        (
            Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8),
            Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        ),
        (Dict({"position": Discrete(5)}), Tuple([Discrete(5)])),
        (Dict({"position": Discrete(5)}), Discrete(5)),
        (Tuple((Discrete(5),)), Discrete(5)),
        (
            Box(low=np.array([-np.inf, 0.0]), high=np.array([0.0, np.inf])),
            Box(low=np.array([-np.inf, 1.0]), high=np.array([0.0, np.inf])),
        ),
    ],
)
def test_class_inequality(spaces):
    """Both spaces of a pair equal themselves, and they differ from each other in both directions."""
    lhs = spaces[0]
    rhs = spaces[1]
    # Reflexivity first, so a broken __eq__ is reported before the inequality checks.
    assert lhs == lhs
    assert rhs == rhs
    # Inequality must hold regardless of which operand drives the comparison.
    assert lhs != rhs
    assert rhs != lhs
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space_fn",
    [
        lambda: Dict(space1="abc"),
        lambda: Dict({"space1": "abc"}),
        lambda: Tuple(["abc"]),
    ],
)
def test_bad_space_calls(space_fn):
    """Dict and Tuple must raise AssertionError when given a plain string as a sub-space.

    Each fixture is a zero-argument factory, so the failing construction runs
    inside ``pytest.raises`` rather than at collection time.
    """
    expected_error = AssertionError
    with pytest.raises(expected_error):
        space_fn()
|
|
|
|
|
|
def test_seed_Dict():
    """Seeding a nested Dict with a matching dict of seeds must seed every leaf
    exactly as if that leaf space had been created and seeded on its own."""
    test_space = Dict(
        {
            "a": Box(low=0, high=1, shape=(3, 3)),
            "b": Dict(
                {
                    "b_1": Box(low=-100, high=100, shape=(2,)),
                    "b_2": Box(low=-1, high=1, shape=(2,)),
                }
            ),
            "c": Discrete(5),
        }
    )
    test_space.seed({"a": 0, "b": {"b_1": 1, "b_2": 2}, "c": 3})

    # Stand-alone twins of every leaf, each seeded with its own entry above.
    a = Box(low=0, high=1, shape=(3, 3))
    a.seed(0)
    b_1 = Box(low=-100, high=100, shape=(2,))
    b_1.seed(1)
    b_2 = Box(low=-1, high=1, shape=(2,))
    b_2.seed(2)
    c = Discrete(5)
    c.seed(3)

    for _ in range(10):
        nested = test_space.sample()
        assert (nested["a"] == a.sample()).all()
        assert (nested["b"]["b_1"] == b_1.sample()).all()
        assert (nested["b"]["b_2"] == b_2.sample()).all()
        assert nested["c"] == c.sample()
|
|
|
|
|
|
def test_box_dtype_check():
    """Box.contains must reject values whose dtype differs from the space's dtype.

    Related issues:
    https://github.com/openai/gym/issues/2357
    https://github.com/openai/gym/issues/2298
    """
    space = Box(0, 2, (), dtype=np.float32)

    # A float32 scalar array matches the space dtype and is accepted.
    assert space.contains(np.array(0.5, dtype=np.float32))

    # np.array's default float / integer dtypes are not float32, so both are rejected.
    for mismatched in (np.array(0.5), np.array(1)):
        assert not space.contains(mismatched)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(3, start=-4),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
    ],
)
def test_seed_returns_list(space):
    """``seed()`` must return a non-empty list of Python ints for both ``None`` and ``0``."""

    def assert_integer_list(seed):
        assert isinstance(seed, list)
        assert len(seed) >= 1
        assert all(isinstance(item, int) for item in seed)

    for seed_value in (None, 0):
        assert_integer_list(space.seed(seed_value))
|
|
|
|
|
|
def convert_sample_hashable(sample):
    """Recursively convert a space sample into a hashable, comparable value.

    * numpy arrays become (nested) tuples of Python scalars; a 0-d array
      becomes a bare scalar;
    * lists and tuples become tuples;
    * dicts become tuples of ``(key, value)`` pairs, in iteration order.

    Anything else is returned unchanged, e.g. Python or numpy scalars.

    Args:
        sample: a value as produced by ``Space.sample()``, or nested containers
            of such values.

    Returns:
        A value built only from tuples and leaf values, safe to use with ``==``
        and, for hashable leaves, as a set member or dict key.
    """
    if isinstance(sample, np.ndarray):
        # tolist() yields nested lists for n-d arrays and a plain scalar for 0-d
        # arrays. Recursing turns every level into a tuple. The previous
        # tuple(sample.tolist()) left inner lists unhashable and raised
        # TypeError on 0-d arrays.
        return convert_sample_hashable(sample.tolist())
    if isinstance(sample, (list, tuple)):
        return tuple(convert_sample_hashable(item) for item in sample)
    if isinstance(sample, dict):
        return tuple(
            (key, convert_sample_hashable(value)) for key, value in sample.items()
        )
    return sample
|
|
|
|
|
|
def sample_equal(sample1, sample2):
    """Return True when two space samples are equal after hashable conversion.

    Both samples go through ``convert_sample_hashable``, so numpy arrays and
    nested containers compare element-wise instead of raising on ``==``.
    """
    lhs = convert_sample_hashable(sample1)
    rhs = convert_sample_hashable(sample2)
    return lhs == rhs
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(3, start=-4),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
    ],
)
def test_seed_reproducibility(space):
    """A space and its deep copy, both seeded with 0, must report equal seeds
    and draw equal samples."""
    original, clone = space, copy.deepcopy(space)

    # Seed both with None first, so any agreement below comes from seed(0) alone.
    for candidate in (original, clone):
        candidate.seed(None)

    assert original.seed(0) == clone.seed(0)
    assert sample_equal(original.sample(), clone.sample())
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Tuple([Discrete(100), Discrete(100)]),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple([Discrete(5), Discrete(5, start=10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
    ],
)
def test_seed_subspace_incorrelated(space):
    """After seeding a composite space, no two sub-spaces may share a bit-generator state."""
    if isinstance(space, Tuple):
        subspaces = space.spaces
    else:
        subspaces = space.spaces.values()

    space.seed(0)
    states = [
        convert_sample_hashable(sub.np_random.bit_generator.state) for sub in subspaces
    ]

    # A duplicate state would collapse in the set and shrink its size.
    assert len(set(states)) == len(states)
|
|
|
|
|
|
def test_tuple():
    """A Tuple space behaves like a read-only sequence of its sub-spaces.

    Covers len, count, iteration, reversed and index (including the
    start/stop arguments of index).
    """
    spaces = [Discrete(5), Discrete(10), Discrete(5)]
    space_tuple = Tuple(spaces)

    assert len(space_tuple) == len(spaces)
    assert space_tuple.count(Discrete(5)) == 2
    assert space_tuple.count(MultiBinary(2)) == 0

    for actual, expected in zip(space_tuple, spaces):
        assert actual == expected
    for actual, expected in zip(reversed(space_tuple), reversed(spaces)):
        assert actual == expected

    assert space_tuple.index(Discrete(5)) == 0
    assert space_tuple.index(Discrete(5), 1) == 2
    # Discrete(10) sits at position 1, outside the searched range [0, 1).
    with pytest.raises(ValueError):
        space_tuple.index(Discrete(10), 0, 1)
|
|
|
|
|
|
def test_multidiscrete_as_tuple():
    """MultiDiscrete supports len() and numpy-style indexing into sub-spaces.

    A scalar index yields a Discrete; a slice yields a MultiDiscrete. A full
    slice yields an equal but distinct object.
    """
    # 1D multi-discrete
    flat = MultiDiscrete([3, 4, 5])
    assert flat.shape == (3,)
    assert flat[0] == Discrete(3)
    assert flat[0:1] == MultiDiscrete([3])
    assert flat[0:2] == MultiDiscrete([3, 4])
    assert flat[:] == flat and flat[:] is not flat
    assert len(flat) == 3

    # 2D multi-discrete
    grid = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert grid.shape == (2, 3)
    assert grid[0, 1] == Discrete(4)
    assert grid[0] == MultiDiscrete([3, 4, 5])
    assert grid[0:1] == MultiDiscrete([[3, 4, 5]])
    assert grid[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert grid[:, 0:1] == MultiDiscrete([[3], [6]])
    assert grid[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
    assert grid[:] == grid and grid[:] is not grid
    assert grid[:, :] == grid and grid[:, :] is not grid
|
|
|
|
|
|
def test_multidiscrete_subspace_reproducibility():
    """Indexing a MultiDiscrete twice with the same key must give sub-spaces
    that draw identical samples. A full slice must sample like the parent."""
    # 1D multi-discrete
    flat = MultiDiscrete([100, 200, 300])
    flat.seed(None)

    for key in (np.s_[0], np.s_[0:1], np.s_[0:2], np.s_[:]):
        assert sample_equal(flat[key].sample(), flat[key].sample())
    assert sample_equal(flat[:].sample(), flat.sample())

    # 2D multi-discrete
    grid = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    grid.seed(None)

    # np.s_ builds the same index objects as writing grid[...] directly.
    grid_keys = (
        np.s_[0, 1],
        np.s_[0],
        np.s_[0:1],
        np.s_[0:2, :],
        np.s_[:, 0:1],
        np.s_[0:2, 0:2],
        np.s_[:],
        np.s_[:, :],
    )
    for key in grid_keys:
        assert sample_equal(grid[key].sample(), grid[key].sample())
    assert sample_equal(grid[:, :].sample(), grid.sample())
|
|
|
|
|
|
def test_space_legacy_state_pickling():
    """``__setstate__`` must accept a legacy state keyed by ``shape``/``np_random``.

    Those values must then be visible through both the public attributes and
    their underscored counterparts.
    """
    rng = np.random.default_rng()
    legacy_state = {"shape": (1, 2, 3), "dtype": np.int64, "np_random": rng, "n": 3}

    space = Discrete(1)
    space.__setstate__(legacy_state)

    for attr_name in ("shape", "_shape"):
        assert getattr(space, attr_name) == legacy_state["shape"]
    for attr_name in ("np_random", "_np_random"):
        assert getattr(space, attr_name) == legacy_state["np_random"]
    assert space.n == 3
    assert space.dtype == legacy_state["dtype"]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    # Same 28 fixtures, in the same order, as the original explicit list.
    # For each shape, for each (low, high) pair, one Box per dtype in
    # int32, float32, int64, float64 order.
    [
        Box(low=low, high=high, shape=shape, dtype=dtype)
        for shape in [(2,), (2, 3)]
        for low, high in [(0, np.inf), (-np.inf, 0), (-np.inf, np.inf)]
        for dtype in [np.int32, np.float32, np.int64, np.float64]
    ]
    + [
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=dtype)
        for dtype in [np.int32, np.float32, np.int64, np.float64]
    ],
)
def test_infinite_space(space):
    """Check sampling, boundedness and dtype of Boxes whose bounds are 0 or infinite.

    Fixtures must use only ``0`` or ``±inf`` as bounds, because ``space.high``
    and ``space.low`` are both modified inside ``Box.__init__``. The test
    therefore treats any non-zero bound as an infinite one.
    """
    space.seed(0)

    assert np.all(space.high > space.low), "High bound not higher than low bound"

    sample = space.sample()

    # check if space contains sample
    assert space.contains(
        sample
    ), f"Sample {sample} not inside space according to `space.contains()`"

    # The sign of every sample element must lie between the signs of its bounds.
    assert np.all(
        np.sign(space.high) >= np.sign(sample)
    ), f"Sign of sample {sample} is greater than space upper bound {space.high}"
    assert np.all(
        np.sign(space.low) <= np.sign(sample)
    ), f"Sign of sample {sample} is less than space lower bound {space.low}"

    # Any non-zero bound in these fixtures is infinite. It must be reported as
    # unbounded for integer and float dtypes alike; zero bounds must be bounded.
    if np.any(space.high != 0):
        assert not space.is_bounded("above"), "inf upper bound supposed to be unbounded"
    else:
        assert space.is_bounded("above"), "non-inf upper bound supposed to be bounded"

    if np.any(space.low != 0):
        assert not space.is_bounded("below"), "inf lower bound supposed to be unbounded"
    else:
        assert space.is_bounded("below"), "non-inf lower bound supposed to be bounded"

    # check for dtype
    assert (
        space.high.dtype == space.dtype
    ), f"High's dtype {space.high.dtype} doesn't match `space.dtype` ({space.dtype})"
    assert (
        space.low.dtype == space.dtype
    ), f"Low's dtype {space.low.dtype} doesn't match `space.dtype` ({space.dtype})"
|
|
|
|
|
|
def test_discrete_legacy_state_pickling():
    """Restoring a legacy Discrete state that lacks ``start`` must leave ``start == 0``."""
    d = Discrete(1)
    instance_attrs = d.__dict__

    assert "start" in instance_attrs
    del instance_attrs["start"]  # legacy did not include start param
    assert "start" not in instance_attrs

    d.__setstate__({"n": 3})

    assert d.start == 0
    assert d.n == 3
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                # Float bounds, consistent with the other fixtures in this module.
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
    ],
)
def test_pickle(space):
    """Copies restored from pickle continue the original's RNG stream.

    Covers both an in-memory bytes pickle and a file pickle; each restored
    copy's next sample must equal the original's next sample.
    """
    # Draw once before pickling. Presumably this makes sure the RNG has been
    # created and advanced, so its state is part of what gets pickled.
    space.sample()

    # Pickle and unpickle via an in-memory bytes object
    pickled = pickle.dumps(space)
    space2 = pickle.loads(pickled)

    # Pickle and unpickle with a file
    with tempfile.TemporaryFile() as f:
        pickle.dump(space, f)
        f.seek(0)
        space3 = pickle.load(f)

    sample = space.sample()
    sample2 = space2.sample()
    sample3 = space3.sample()
    assert sample_equal(sample, sample2)
    assert sample_equal(sample, sample3)
|