mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-15 19:31:27 +00:00
* fix box infinite bounds * Revert "fix box infinite bounds" This reverts commit 8e27a01fda839f522fc5a0855350e5a543359c47. * fix box space infinite bounds * experiencing unit test failures with master build on windows, shifting to Linux to test * box space infinite bounds fixed, all unit test pass, though the solution is still less elegant than I'd like * bracket fix * black formatting * remove redundant casting to np.array for array spaces * previous changes unintentionally changed the sampling method for float dtype infinite, this commit fixes that. Float infinite bounds are sampled either using shifted exponential or normal depending on circumstance, but int infinite bounds are always sampled uniformly. * added more tests as per #2517, made some error messages more descriptive, changed logic for calculating inf for ints * accidentally commented out something I shouldn't have commented out * moved get_inf and get_precision to non-static functions * Move get_precision and get_inf outside of class * Don't edit code within github's editor lol:wq * black
542 lines
17 KiB
Python
542 lines
17 KiB
Python
import json # note: ujson fails this test due to float equality
|
|
import copy
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space",
|
|
[
|
|
Discrete(3),
|
|
Discrete(5, start=-2),
|
|
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
|
Tuple([Discrete(5), Discrete(10)]),
|
|
Tuple(
|
|
[
|
|
Discrete(5),
|
|
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
|
]
|
|
),
|
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
|
Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
|
|
MultiDiscrete([2, 2, 100]),
|
|
MultiBinary(10),
|
|
Dict(
|
|
{
|
|
"position": Discrete(5),
|
|
"velocity": Box(
|
|
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
),
|
|
}
|
|
),
|
|
],
|
|
)
|
|
def test_roundtripping(space):
|
|
sample_1 = space.sample()
|
|
sample_2 = space.sample()
|
|
assert space.contains(sample_1)
|
|
assert space.contains(sample_2)
|
|
json_rep = space.to_jsonable([sample_1, sample_2])
|
|
|
|
json_roundtripped = json.loads(json.dumps(json_rep))
|
|
|
|
samples_after_roundtrip = space.from_jsonable(json_roundtripped)
|
|
sample_1_prime, sample_2_prime = samples_after_roundtrip
|
|
|
|
s1 = space.to_jsonable([sample_1])
|
|
s1p = space.to_jsonable([sample_1_prime])
|
|
s2 = space.to_jsonable([sample_2])
|
|
s2p = space.to_jsonable([sample_2_prime])
|
|
assert s1 == s1p, "Expected {} to equal {}".format(s1, s1p)
|
|
assert s2 == s2p, "Expected {} to equal {}".format(s2, s2p)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space",
|
|
[
|
|
Discrete(3),
|
|
Discrete(5, start=-2),
|
|
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
|
|
Box(low=-np.inf, high=np.inf, shape=(1, 3)),
|
|
Tuple([Discrete(5), Discrete(10)]),
|
|
Tuple(
|
|
[
|
|
Discrete(5),
|
|
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
|
]
|
|
),
|
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
|
Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
|
|
MultiDiscrete([2, 2, 100]),
|
|
MultiBinary(6),
|
|
Dict(
|
|
{
|
|
"position": Discrete(5),
|
|
"velocity": Box(
|
|
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
),
|
|
}
|
|
),
|
|
],
|
|
)
|
|
def test_equality(space):
|
|
space1 = space
|
|
space2 = copy.copy(space)
|
|
assert space1 == space2, "Expected {} to equal {}".format(space1, space2)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"spaces",
|
|
[
|
|
(Discrete(3), Discrete(4)),
|
|
(Discrete(3), Discrete(3, start=-1)),
|
|
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
|
|
(MultiBinary(8), MultiBinary(7)),
|
|
(
|
|
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
|
|
Box(low=np.array([-10, 0]), high=np.array([10, 9]), dtype=np.float32),
|
|
),
|
|
(
|
|
Box(low=-np.inf, high=0.0, shape=(2, 1)),
|
|
Box(low=0.0, high=np.inf, shape=(2, 1)),
|
|
),
|
|
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
|
|
(
|
|
Tuple([Discrete(5), Discrete(10)]),
|
|
Tuple([Discrete(5, start=7), Discrete(10)]),
|
|
),
|
|
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
|
|
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
|
|
],
|
|
)
|
|
def test_inequality(spaces):
|
|
space1, space2 = spaces
|
|
assert space1 != space2, "Expected {} != {}".format(space1, space2)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space",
|
|
[
|
|
Discrete(5),
|
|
Discrete(8, start=-20),
|
|
Box(low=0, high=255, shape=(2,), dtype="uint8"),
|
|
Box(low=-np.inf, high=np.inf, shape=(3, 3)),
|
|
Box(low=1.0, high=np.inf, shape=(3, 3)),
|
|
Box(low=-np.inf, high=2.0, shape=(3, 3)),
|
|
],
|
|
)
|
|
def test_sample(space):
|
|
space.seed(0)
|
|
n_trials = 100
|
|
samples = np.array([space.sample() for _ in range(n_trials)])
|
|
expected_mean = 0.0
|
|
if isinstance(space, Box):
|
|
if space.is_bounded():
|
|
expected_mean = (space.high + space.low) / 2
|
|
elif space.is_bounded("below"):
|
|
expected_mean = 1 + space.low
|
|
elif space.is_bounded("above"):
|
|
expected_mean = -1 + space.high
|
|
else:
|
|
expected_mean = 0.0
|
|
elif isinstance(space, Discrete):
|
|
expected_mean = space.start + space.n / 2
|
|
else:
|
|
raise NotImplementedError
|
|
np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"spaces",
|
|
[
|
|
(Discrete(5), MultiBinary(5)),
|
|
(
|
|
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
|
|
MultiDiscrete([2, 2, 8]),
|
|
),
|
|
(
|
|
Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8),
|
|
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
|
),
|
|
(Dict({"position": Discrete(5)}), Tuple([Discrete(5)])),
|
|
(Dict({"position": Discrete(5)}), Discrete(5)),
|
|
(Tuple((Discrete(5),)), Discrete(5)),
|
|
(
|
|
Box(low=np.array([-np.inf, 0.0]), high=np.array([0.0, np.inf])),
|
|
Box(low=np.array([-np.inf, 1.0]), high=np.array([0.0, np.inf])),
|
|
),
|
|
],
|
|
)
|
|
def test_class_inequality(spaces):
|
|
assert spaces[0] == spaces[0]
|
|
assert spaces[1] == spaces[1]
|
|
assert spaces[0] != spaces[1]
|
|
assert spaces[1] != spaces[0]
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space_fn",
|
|
[
|
|
lambda: Dict(space1="abc"),
|
|
lambda: Dict({"space1": "abc"}),
|
|
lambda: Tuple(["abc"]),
|
|
],
|
|
)
|
|
def test_bad_space_calls(space_fn):
|
|
with pytest.raises(AssertionError):
|
|
space_fn()
|
|
|
|
|
|
def test_seed_Dict():
|
|
test_space = Dict(
|
|
{
|
|
"a": Box(low=0, high=1, shape=(3, 3)),
|
|
"b": Dict(
|
|
{
|
|
"b_1": Box(low=-100, high=100, shape=(2,)),
|
|
"b_2": Box(low=-1, high=1, shape=(2,)),
|
|
}
|
|
),
|
|
"c": Discrete(5),
|
|
}
|
|
)
|
|
|
|
seed_dict = {
|
|
"a": 0,
|
|
"b": {
|
|
"b_1": 1,
|
|
"b_2": 2,
|
|
},
|
|
"c": 3,
|
|
}
|
|
|
|
test_space.seed(seed_dict)
|
|
|
|
# "Unpack" the dict sub-spaces into individual spaces
|
|
a = Box(low=0, high=1, shape=(3, 3))
|
|
a.seed(0)
|
|
b_1 = Box(low=-100, high=100, shape=(2,))
|
|
b_1.seed(1)
|
|
b_2 = Box(low=-1, high=1, shape=(2,))
|
|
b_2.seed(2)
|
|
c = Discrete(5)
|
|
c.seed(3)
|
|
|
|
for i in range(10):
|
|
test_s = test_space.sample()
|
|
a_s = a.sample()
|
|
assert (test_s["a"] == a_s).all()
|
|
b_1_s = b_1.sample()
|
|
assert (test_s["b"]["b_1"] == b_1_s).all()
|
|
b_2_s = b_2.sample()
|
|
assert (test_s["b"]["b_2"] == b_2_s).all()
|
|
c_s = c.sample()
|
|
assert test_s["c"] == c_s
|
|
|
|
|
|
def test_box_dtype_check():
|
|
# Related Issues:
|
|
# https://github.com/openai/gym/issues/2357
|
|
# https://github.com/openai/gym/issues/2298
|
|
|
|
space = Box(0, 2, tuple(), dtype=np.float32)
|
|
|
|
# casting will match the correct type
|
|
assert space.contains(0.5)
|
|
|
|
# float64 is not in float32 space
|
|
assert not space.contains(np.array(0.5))
|
|
assert not space.contains(np.array(1))
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space",
|
|
[
|
|
Discrete(3),
|
|
Discrete(3, start=-4),
|
|
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
|
Tuple([Discrete(5), Discrete(10)]),
|
|
Tuple(
|
|
[
|
|
Discrete(5),
|
|
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
|
]
|
|
),
|
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
|
MultiDiscrete([2, 2, 100]),
|
|
MultiBinary(10),
|
|
Dict(
|
|
{
|
|
"position": Discrete(5),
|
|
"velocity": Box(
|
|
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
),
|
|
}
|
|
),
|
|
],
|
|
)
|
|
def test_seed_returns_list(space):
|
|
def assert_integer_list(seed):
|
|
assert isinstance(seed, list)
|
|
assert len(seed) >= 1
|
|
assert all([isinstance(s, int) for s in seed])
|
|
|
|
assert_integer_list(space.seed(None))
|
|
assert_integer_list(space.seed(0))
|
|
|
|
|
|
def convert_sample_hashable(sample):
|
|
if isinstance(sample, np.ndarray):
|
|
return tuple(sample.tolist())
|
|
if isinstance(sample, (list, tuple)):
|
|
return tuple(convert_sample_hashable(s) for s in sample)
|
|
if isinstance(sample, dict):
|
|
return tuple(
|
|
(key, convert_sample_hashable(value)) for key, value in sample.items()
|
|
)
|
|
|
|
return sample
|
|
|
|
|
|
def sample_equal(sample1, sample2):
|
|
return convert_sample_hashable(sample1) == convert_sample_hashable(sample2)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space",
|
|
[
|
|
Discrete(3),
|
|
Discrete(3, start=-4),
|
|
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
|
Tuple([Discrete(5), Discrete(10)]),
|
|
Tuple(
|
|
[
|
|
Discrete(5),
|
|
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
|
]
|
|
),
|
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
|
MultiDiscrete([2, 2, 100]),
|
|
MultiBinary(10),
|
|
Dict(
|
|
{
|
|
"position": Discrete(5),
|
|
"velocity": Box(
|
|
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
),
|
|
}
|
|
),
|
|
],
|
|
)
|
|
def test_seed_reproducibility(space):
|
|
space1 = space
|
|
space2 = copy.deepcopy(space)
|
|
|
|
space1.seed(None)
|
|
space2.seed(None)
|
|
|
|
assert space1.seed(0) == space2.seed(0)
|
|
assert sample_equal(space1.sample(), space2.sample())
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space",
|
|
[
|
|
Tuple([Discrete(100), Discrete(100)]),
|
|
Tuple([Discrete(5), Discrete(10)]),
|
|
Tuple([Discrete(5), Discrete(5, start=10)]),
|
|
Tuple(
|
|
[
|
|
Discrete(5),
|
|
Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
|
]
|
|
),
|
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
|
Dict(
|
|
{
|
|
"position": Discrete(5),
|
|
"velocity": Box(
|
|
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
),
|
|
}
|
|
),
|
|
],
|
|
)
|
|
def test_seed_subspace_incorrelated(space):
|
|
subspaces = space.spaces if isinstance(space, Tuple) else space.spaces.values()
|
|
|
|
space.seed(0)
|
|
states = [
|
|
convert_sample_hashable(subspace.np_random.bit_generator.state)
|
|
for subspace in subspaces
|
|
]
|
|
|
|
assert len(states) == len(set(states))
|
|
|
|
|
|
def test_multidiscrete_as_tuple():
|
|
# 1D multi-discrete
|
|
space = MultiDiscrete([3, 4, 5])
|
|
|
|
assert space.shape == (3,)
|
|
assert space[0] == Discrete(3)
|
|
assert space[0:1] == MultiDiscrete([3])
|
|
assert space[0:2] == MultiDiscrete([3, 4])
|
|
assert space[:] == space and space[:] is not space
|
|
assert len(space) == 3
|
|
|
|
# 2D multi-discrete
|
|
space = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
|
|
|
|
assert space.shape == (2, 3)
|
|
assert space[0, 1] == Discrete(4)
|
|
assert space[0] == MultiDiscrete([3, 4, 5])
|
|
assert space[0:1] == MultiDiscrete([[3, 4, 5]])
|
|
assert space[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
|
|
assert space[:, 0:1] == MultiDiscrete([[3], [6]])
|
|
assert space[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
|
|
assert space[:] == space and space[:] is not space
|
|
assert space[:, :] == space and space[:, :] is not space
|
|
|
|
|
|
def test_multidiscrete_subspace_reproducibility():
|
|
# 1D multi-discrete
|
|
space = MultiDiscrete([100, 200, 300])
|
|
space.seed(None)
|
|
|
|
assert sample_equal(space[0].sample(), space[0].sample())
|
|
assert sample_equal(space[0:1].sample(), space[0:1].sample())
|
|
assert sample_equal(space[0:2].sample(), space[0:2].sample())
|
|
assert sample_equal(space[:].sample(), space[:].sample())
|
|
assert sample_equal(space[:].sample(), space.sample())
|
|
|
|
# 2D multi-discrete
|
|
space = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
|
|
space.seed(None)
|
|
|
|
assert sample_equal(space[0, 1].sample(), space[0, 1].sample())
|
|
assert sample_equal(space[0].sample(), space[0].sample())
|
|
assert sample_equal(space[0:1].sample(), space[0:1].sample())
|
|
assert sample_equal(space[0:2, :].sample(), space[0:2, :].sample())
|
|
assert sample_equal(space[:, 0:1].sample(), space[:, 0:1].sample())
|
|
assert sample_equal(space[0:2, 0:2].sample(), space[0:2, 0:2].sample())
|
|
assert sample_equal(space[:].sample(), space[:].sample())
|
|
assert sample_equal(space[:, :].sample(), space[:, :].sample())
|
|
assert sample_equal(space[:, :].sample(), space.sample())
|
|
|
|
|
|
def test_space_legacy_state_pickling():
|
|
legacy_state = {
|
|
"shape": (
|
|
1,
|
|
2,
|
|
3,
|
|
),
|
|
"dtype": np.int64,
|
|
"np_random": np.random.default_rng(),
|
|
"n": 3,
|
|
}
|
|
space = Discrete(1)
|
|
space.__setstate__(legacy_state)
|
|
|
|
assert space.shape == legacy_state["shape"]
|
|
assert space._shape == legacy_state["shape"]
|
|
assert space.np_random == legacy_state["np_random"]
|
|
assert space._np_random == legacy_state["np_random"]
|
|
assert space.n == 3
|
|
assert space.dtype == legacy_state["dtype"]
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"space",
|
|
[
|
|
Box(low=0, high=np.inf, shape=(2,), dtype=np.int32),
|
|
Box(low=0, high=np.inf, shape=(2,), dtype=np.float32),
|
|
Box(low=0, high=np.inf, shape=(2,), dtype=np.int64),
|
|
Box(low=0, high=np.inf, shape=(2,), dtype=np.float64),
|
|
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32),
|
|
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32),
|
|
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64),
|
|
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64),
|
|
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32),
|
|
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
|
|
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64),
|
|
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64),
|
|
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32),
|
|
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32),
|
|
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64),
|
|
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64),
|
|
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32),
|
|
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32),
|
|
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64),
|
|
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64),
|
|
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32),
|
|
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32),
|
|
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64),
|
|
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64),
|
|
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32),
|
|
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32),
|
|
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64),
|
|
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64),
|
|
],
|
|
)
|
|
def test_infinite_space(space):
|
|
# for this test, make sure that spaces that are passed in have only 0 or infinite bounds
|
|
# because space.high and space.low are both modified within the init
|
|
# so we check for infinite when we know it's not 0
|
|
space.seed(0)
|
|
|
|
assert np.all(space.high > space.low), "High bound not higher than low bound"
|
|
|
|
sample = space.sample()
|
|
|
|
# check if space contains sample
|
|
assert space.contains(
|
|
sample
|
|
), "Sample {sample} not inside space according to `space.contains()`"
|
|
|
|
# manually check that the sign of the sample is within the bounds
|
|
assert np.all(
|
|
np.sign(space.high) >= np.sign(sample)
|
|
), f"Sign of sample {sample} is less than space upper bound {space.high}"
|
|
assert np.all(
|
|
np.sign(space.low) <= np.sign(sample)
|
|
), f"Sign of sample {sample} is more than space lower bound {space.low}"
|
|
|
|
# check that int bounds are bounded for everything
|
|
# but floats are unbounded for infinite
|
|
if space.dtype.kind == "f":
|
|
if np.any(space.high != 0):
|
|
assert (
|
|
space.is_bounded("above") == False
|
|
), "float dtype inf upper bound supposed to be unbounded"
|
|
else:
|
|
assert (
|
|
space.is_bounded("above") == True
|
|
), "float dtype non-inf upper bound supposed to be bounded"
|
|
|
|
if np.any(space.low != 0):
|
|
assert (
|
|
space.is_bounded("below") == False
|
|
), "float dtype inf lower bound supposed to be unbounded"
|
|
else:
|
|
assert (
|
|
space.is_bounded("below") == True
|
|
), "float dtype non-inf lower bound supposed to be bounded"
|
|
|
|
elif space.dtype.kind == "i":
|
|
assert (
|
|
space.is_bounded("both") == True
|
|
), "int dtypes should be bounded on both ends"
|
|
|
|
# check for dtype
|
|
assert (
|
|
space.high.dtype == space.dtype
|
|
), "High's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
|
assert (
|
|
space.low.dtype == space.dtype
|
|
), "Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
|