mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-03 22:54:23 +00:00
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset
* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"
This reverts commit 519dfd9117.
* Remove bad pushed commits
* Fixed spelling in core.py
* Pins pytest to the last py 3.6 version
* Add support for action masking in Space.sample(mask=...)
* Fix action mask
* Fix action_mask
* Fix action_mask
* Added docstrings, fixed bugs and added taxi examples
* Fixed bugs
* Add tests for sample
* Add docstrings and test space sample mask Discrete and MultiBinary
* Add MultiDiscrete sampling and tests
* Remove sample mask from graph
* Update gym/spaces/multi_discrete.py
Co-authored-by: Markus Krimmel <montcyril@gmail.com>
* Updates based on Marcus28 and jjshoots for Graph.py
* Updates based on Marcus28 and jjshoots for Graph.py
* jjshoot review
* jjshoot review
* Update assert check
* Update type hints
Co-authored-by: Markus Krimmel <montcyril@gmail.com>
979 lines
33 KiB
Python
979 lines
33 KiB
Python
import copy
|
|
import json # note: ujson fails this test due to float equality
|
|
import pickle
|
|
import tempfile
|
|
from typing import List, Union
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gym import Space
|
|
from gym.spaces import Box, Dict, Discrete, Graph, MultiBinary, MultiDiscrete, Tuple
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
        Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
        Graph(node_space=Discrete(5), edge_space=None),
    ],
)
def test_roundtripping(space):
    """Samples must survive ``to_jsonable`` -> JSON text -> ``from_jsonable`` unchanged.

    Two samples are serialized together, restored, and each restored sample is compared
    with its original through their jsonable forms rather than as raw samples.
    """
    originals = [space.sample(), space.sample()]
    for sample in originals:
        assert space.contains(sample)

    json_text = json.dumps(space.to_jsonable(originals))
    first_back, second_back = space.from_jsonable(json.loads(json_text))

    for before, after in zip(originals, (first_back, second_back)):
        expected = space.to_jsonable([before])
        actual = space.to_jsonable([after])
        assert expected == actual, f"Expected {expected} to equal {actual}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
        Box(low=-np.inf, high=np.inf, shape=(1, 3)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2), Discrete(2, start=-6))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(6),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
        Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
        Graph(node_space=Discrete(5), edge_space=None),
    ],
)
def test_equality(space):
    """A deep copy of a space must compare equal to the original space."""
    duplicate = copy.deepcopy(space)
    assert space == duplicate, f"Expected {space} to equal {duplicate}"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spaces",
    [
        (Discrete(3), Discrete(4)),
        (Discrete(3), Discrete(3, start=-1)),
        (MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
        (MultiBinary(8), MultiBinary(7)),
        (
            Box(
                low=np.array([-10.0, 0.0]),
                high=np.array([10.0, 10.0]),
                dtype=np.float64,
            ),
            Box(
                low=np.array([-10.0, 0.0]), high=np.array([10.0, 9.0]), dtype=np.float64
            ),
        ),
        (
            Box(low=-np.inf, high=0.0, shape=(2, 1)),
            Box(low=0.0, high=np.inf, shape=(2, 1)),
        ),
        (Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
        (
            Tuple([Discrete(5), Discrete(10)]),
            Tuple([Discrete(5, start=7), Discrete(10)]),
        ),
        (Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
        (Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
        (
            Graph(
                node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)
            ),
            Graph(node_space=Discrete(5), edge_space=None),
        ),
    ],
)
def test_inequality(spaces):
    """Each pair differs in exactly some parameter, so the two spaces must not be equal."""
    left, right = spaces
    assert left != right, f"Expected {left} != {right}"
|
|
|
|
|
|
# Critical values of the chi-squared distribution at significance level alpha = 0.05,
# indexed by degrees of freedom. Entries 1..9 were generated with:
#   [scipy.stats.chi2.isf(0.05, df=df) for df in range(1, 10)]
# Entry 0 (zero degrees of freedom, i.e. a single possible outcome) is 0.01 rather than 0,
# so that a test statistic of exactly 0 still passes a strict ``statistic < CHI_SQUARED[df]``.
CHI_SQUARED = np.array(
    [
        0.01,
        3.8414588206941285,
        5.991464547107983,
        7.814727903251178,
        9.487729036781158,
        11.070497693516355,
        12.59158724374398,
        14.067140449340167,
        15.507313055865454,
        16.91897760462045,
    ]
)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(1),
        Discrete(5),
        Discrete(8, start=-20),
        Box(low=0, high=255, shape=(2,), dtype=np.uint8),
        Box(low=-np.inf, high=np.inf, shape=(3,)),
        Box(low=1.0, high=np.inf, shape=(3,)),
        Box(low=-np.inf, high=2.0, shape=(3,)),
        Box(low=np.array([0, 2]), high=np.array([10, 4])),
        MultiDiscrete([3, 5]),
        MultiDiscrete(np.array([[3, 5], [2, 1]])),
        MultiBinary([2, 4]),
    ],
)
def test_sample(space: Space, n_trials: int = 1_000):
    """Test that space samples are uniformly distributed using Pearson's chi-squared test.

    Only ``Discrete``, ``MultiBinary`` and ``MultiDiscrete`` spaces are checked
    statistically. For any other space (e.g. ``Box``) this test only checks that
    ``n_trials`` samples can be drawn.

    The statistic computed below corresponds to this scipy code::

        import scipy.stats
        variance = np.sum(np.square(observed_frequency - expected_frequency) / expected_frequency)
        f'X2 at alpha=0.05 = {scipy.stats.chi2.isf(0.05, df=4)}'
        f'p-value = {scipy.stats.chi2.sf(variance, df=4)}'
        scipy.stats.chisquare(f_obs=observed_frequency)
    """
    space.seed(0)
    samples = np.array([space.sample() for _ in range(n_trials)])
    assert len(samples) == n_trials

    # TODO: add a distribution test for Box spaces
    if isinstance(space, Discrete):
        expected_frequency = np.ones(space.n) * n_trials / space.n
        observed_frequency = np.zeros(space.n)
        for sample in samples:
            # Shift by ``start`` so that the lowest value maps to index 0.
            observed_frequency[sample - space.start] += 1
        degrees_of_freedom = space.n - 1

        assert observed_frequency.shape == expected_frequency.shape
        assert np.sum(observed_frequency) == n_trials

        variance = np.sum(
            np.square(expected_frequency - observed_frequency) / expected_frequency
        )
        assert variance < CHI_SQUARED[degrees_of_freedom]
    elif isinstance(space, MultiBinary):
        expected_frequency = n_trials / 2
        observed_frequency = np.sum(samples, axis=0)
        assert observed_frequency.shape == space.shape

        # Binary outcome: the deviations of the 0-count and the 1-count have equal squares,
        # so the two-category statistic is twice the 1-count term.
        variance = (
            2 * np.square(observed_frequency - expected_frequency) / expected_frequency
        )
        assert variance.shape == space.shape
        assert np.all(variance < CHI_SQUARED[1])
    elif isinstance(space, MultiDiscrete):
        # ``nvec`` may be multi-dimensional, so these helpers recurse and the expected /
        # observed frequency containers are ragged (object-dtype) numpy arrays.
        def _generate_frequency(dim, func):
            if isinstance(dim, np.ndarray):
                return np.array(
                    [_generate_frequency(sub_dim, func) for sub_dim in dim],
                    dtype=object,
                )
            else:
                return func(dim)

        def _update_observed_frequency(obs_sample, obs_freq):
            if isinstance(obs_sample, np.ndarray):
                for sub_sample, sub_freq in zip(obs_sample, obs_freq):
                    _update_observed_frequency(sub_sample, sub_freq)
            else:
                obs_freq[obs_sample] += 1

        expected_frequency = _generate_frequency(
            space.nvec, lambda dim: np.ones(dim) * n_trials / dim
        )
        observed_frequency = _generate_frequency(space.nvec, lambda dim: np.zeros(dim))
        for sample in samples:
            _update_observed_frequency(sample, observed_frequency)

        def _chi_squared_test(dim, exp_freq, obs_freq):
            if isinstance(dim, np.ndarray):
                for sub_dim, sub_exp_freq, sub_obs_freq in zip(dim, exp_freq, obs_freq):
                    _chi_squared_test(sub_dim, sub_exp_freq, sub_obs_freq)
            else:
                assert exp_freq.shape == (dim,) and obs_freq.shape == (dim,)
                assert np.sum(obs_freq) == n_trials
                assert np.sum(exp_freq) == n_trials
                _variance = np.sum(np.square(exp_freq - obs_freq) / exp_freq)
                _degrees_of_freedom = dim - 1
                assert _variance < CHI_SQUARED[_degrees_of_freedom]

        _chi_squared_test(space.nvec, expected_frequency, observed_frequency)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space,mask",
    [
        (Discrete(5), np.array([0, 1, 1, 0, 1], dtype=np.int8)),
        (Discrete(4, start=-20), np.array([1, 1, 0, 1], dtype=np.int8)),
        (Discrete(4, start=1), np.array([0, 0, 0, 0], dtype=np.int8)),
        (MultiBinary([3, 2]), np.array([[0, 1], [1, 1], [0, 0]], dtype=np.int8)),
        (
            MultiDiscrete([5, 3]),
            (
                np.array([0, 1, 1, 0, 1], dtype=np.int8),
                np.array([0, 1, 1], dtype=np.int8),
            ),
        ),
        (
            MultiDiscrete(np.array([4, 2])),
            (np.array([0, 0, 0, 0], dtype=np.int8), np.array([1, 1], dtype=np.int8)),
        ),
        (
            MultiDiscrete(np.array([[2, 2], [4, 3]])),
            (
                (np.array([0, 1], dtype=np.int8), np.array([1, 1], dtype=np.int8)),
                (
                    np.array([0, 1, 1, 0], dtype=np.int8),
                    np.array([1, 0, 0], dtype=np.int8),
                ),
            ),
        ),
    ],
)
def test_space_sample_mask(space, mask, n_trials: int = 100):
    """Check masked sampling against expected frequencies with Pearson's chi-squared test.

    Expected counts are uniform over the entries whose mask value is 1. For a Discrete
    (sub-)space whose mask has no 1 entry, all samples are expected at index 0.
    Expected counts of 0 are clipped to 1 in the denominator to avoid division by zero.
    """
    space.seed(1)
    samples = np.array([space.sample(mask) for _ in range(n_trials)])

    def _expected_counts(size, sub_mask):
        # Uniform over allowed entries, or all mass on index 0 when nothing is allowed.
        if np.any(sub_mask == 1):
            return np.ones(size) * (n_trials / np.sum(sub_mask)) * sub_mask
        counts = np.zeros(size)
        counts[0] = n_trials
        return counts

    if isinstance(space, Discrete):
        expected = _expected_counts(space.n, mask)
        observed = np.zeros(space.n)
        for value in samples:
            observed[value - space.start] += 1
        dof = max(np.sum(mask) - 1, 0)

        assert observed.shape == expected.shape
        assert np.sum(observed) == n_trials
        assert np.sum(expected) == n_trials
        statistic = np.sum(
            np.square(expected - observed) / np.clip(expected, 1, None)
        )
        assert statistic < CHI_SQUARED[dof]

    elif isinstance(space, MultiBinary):
        expected = np.ones(space.shape) * mask * (n_trials / 2)
        observed = np.sum(samples, axis=0)
        assert space.shape == expected.shape == observed.shape

        # Binary outcome: the 0-count and 1-count deviations have equal squares, hence the 2.
        statistic = 2 * np.square(observed - expected) / np.clip(expected, 1, None)
        assert statistic.shape == space.shape
        assert np.all(statistic < CHI_SQUARED[1])

    elif isinstance(space, MultiDiscrete):
        # ``nvec`` may be multi-dimensional, so frequency tables are nested Python lists of
        # 1-D arrays (one per sub-space) instead of one regular numpy array.
        def _nested(nvec: Union[np.ndarray, int], sub_mask, leaf_fn: callable) -> List:
            if not isinstance(nvec, np.ndarray):
                return leaf_fn(nvec, sub_mask)
            return [_nested(n, m, leaf_fn) for n, m in zip(nvec, sub_mask)]

        def _tally(sample, table):
            if not isinstance(sample, np.ndarray):
                table[sample] += 1
                return
            for sub_sample, sub_table in zip(sample, table):
                _tally(sub_sample, sub_table)

        def _leaf_expected(size, sub_mask):
            if np.any(sub_mask == 1):
                assert size == len(sub_mask)
            return _expected_counts(size, sub_mask)

        expected = _nested(space.nvec, mask, _leaf_expected)
        observed = _nested(space.nvec, mask, lambda size, _: np.zeros(size))
        for sample in samples:
            _tally(sample, observed)

        def _check(nvec, sub_mask, exp_counts, obs_counts):
            if isinstance(nvec, np.ndarray):
                for parts in zip(nvec, sub_mask, exp_counts, obs_counts):
                    _check(*parts)
                return
            assert exp_counts.shape == (nvec,) and obs_counts.shape == (nvec,)
            assert np.sum(obs_counts) == n_trials
            assert np.sum(exp_counts) == n_trials
            stat = np.sum(
                np.square(exp_counts - obs_counts) / np.clip(exp_counts, 1, None)
            )
            assert stat < CHI_SQUARED[max(np.sum(sub_mask) - 1, 0)]

        _check(space.nvec, mask, expected, observed)

    else:
        raise NotImplementedError()
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space,mask",
    [
        (
            Dict(a=Discrete(2), b=MultiDiscrete([2, 4])),
            {
                "a": np.array([0, 1], dtype=np.int8),
                "b": (
                    np.array([0, 1], dtype=np.int8),
                    np.array([1, 1, 0, 0], dtype=np.int8),
                ),
            },
        ),
        (
            Tuple([Box(0, 1, ()), Discrete(3), MultiBinary([2, 1])]),
            (
                None,
                np.array([0, 1, 0], dtype=np.int8),
                np.array([[0], [1]], dtype=np.int8),
            ),
        ),
        (
            Dict(a=Tuple([Box(0, 1, ()), Discrete(3)]), b=Discrete(3)),
            {
                "a": (None, np.array([1, 0, 0], dtype=np.int8)),
                "b": np.array([0, 1, 1], dtype=np.int8),
            },
        ),
        (Graph(node_space=Discrete(5), edge_space=Discrete(3)), None),
        (
            Graph(node_space=Discrete(3), edge_space=Box(low=0, high=1, shape=(5,))),
            None,
        ),
        (
            Graph(
                node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3)
            ),
            None,
        ),
    ],
)
def test_composite_space_sample_mask(space, mask):
    """Smoke test: sampling a composite space with a nested mask (or ``None``) must not raise.

    The drawn sample itself is not inspected.
    """
    space.sample(mask)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spaces",
    [
        (Discrete(5), MultiBinary(5)),
        (
            Box(
                low=np.array([-10.0, 0.0]),
                high=np.array([10.0, 10.0]),
                dtype=np.float64,
            ),
            MultiDiscrete([2, 2, 8]),
        ),
        (
            Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8),
            Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        ),
        (Dict({"position": Discrete(5)}), Tuple([Discrete(5)])),
        (Dict({"position": Discrete(5)}), Discrete(5)),
        (Tuple((Discrete(5),)), Discrete(5)),
        (
            Box(low=np.array([-np.inf, 0.0]), high=np.array([0.0, np.inf])),
            Box(low=np.array([-np.inf, 1.0]), high=np.array([0.0, np.inf])),
        ),
        (
            Graph(
                node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)
            ),
            Graph(node_space=Discrete(5), edge_space=None),
        ),
    ],
)
def test_class_inequality(spaces):
    """Each space equals itself, but the two spaces of a pair are unequal in both directions."""
    first, second = spaces[0], spaces[1]

    # Reflexive equality.
    assert first == first
    assert second == second

    # Symmetric inequality.
    assert first != second
    assert second != first
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space_fn",
    [
        lambda: Dict(space1="abc"),
        lambda: Dict({"space1": "abc"}),
        lambda: Tuple(["abc"]),
    ],
)
def test_bad_space_calls(space_fn):
    """Composite spaces built from a non-``Space`` sub-space (here a str) raise AssertionError."""
    with pytest.raises(AssertionError):
        space_fn()
|
|
|
|
|
|
def test_seed_Dict():
    """Seeding a Dict with a nested seed dict seeds every leaf sub-space with its own value.

    Each leaf is compared against a stand-alone copy of that space seeded with the same value.
    """
    test_space = Dict(
        {
            "a": Box(low=0, high=1, shape=(3, 3)),
            "b": Dict(
                {
                    "b_1": Box(low=-100, high=100, shape=(2,)),
                    "b_2": Box(low=-1, high=1, shape=(2,)),
                }
            ),
            "c": Discrete(5),
        }
    )
    seed_dict = {
        "a": 0,
        "b": {
            "b_1": 1,
            "b_2": 2,
        },
        "c": 3,
    }
    test_space.seed(seed_dict)

    # Stand-alone equivalents of the Dict's leaves, seeded like the entries of seed_dict.
    a = Box(low=0, high=1, shape=(3, 3))
    a.seed(0)
    b_1 = Box(low=-100, high=100, shape=(2,))
    b_1.seed(1)
    b_2 = Box(low=-1, high=1, shape=(2,))
    b_2.seed(2)
    c = Discrete(5)
    c.seed(3)

    for _ in range(10):
        combined = test_space.sample()
        assert (combined["a"] == a.sample()).all()
        assert (combined["b"]["b_1"] == b_1.sample()).all()
        assert (combined["b"]["b_2"] == b_2.sample()).all()
        assert combined["c"] == c.sample()
|
|
|
|
|
|
def test_box_dtype_check():
    """``Box.contains`` only accepts arrays whose dtype matches the space's dtype.

    Related issues:
    https://github.com/openai/gym/issues/2357
    https://github.com/openai/gym/issues/2298
    """
    space = Box(0, 2, tuple(), dtype=np.float32)

    # A float32 value inside the bounds is accepted.
    assert space.contains(np.array(0.5, dtype=np.float32))

    # Arrays with numpy's default dtypes (not float32) are rejected even when in bounds.
    for foreign in (np.array(0.5), np.array(1)):
        assert not space.contains(foreign)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(3, start=-4),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
        Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
        Graph(node_space=Discrete(5), edge_space=None),
    ],
)
def test_seed_returns_list(space):
    """``space.seed`` returns a non-empty list of ints, for both ``None`` and an explicit seed."""
    for seed_arg in (None, 0):
        seeds = space.seed(seed_arg)
        assert isinstance(seeds, list)
        assert len(seeds) >= 1
        assert all(isinstance(s, int) for s in seeds)
|
|
|
|
|
|
def convert_sample_hashable(sample):
    """Recursively convert a space sample into a hashable structure of nested tuples.

    - numpy arrays are converted through ``tolist()`` and then recursed into, so n-d
      arrays become nested tuples and 0-d arrays become plain Python scalars;
    - lists and tuples (including namedtuples) become plain tuples of converted items;
    - dicts become a tuple of ``(key, converted_value)`` pairs in iteration order;
    - anything else is returned unchanged.

    Args:
        sample: A (possibly nested) sample drawn from a space.

    Returns:
        An equivalent value built from tuples and scalars that can be hashed and compared.
    """
    if isinstance(sample, np.ndarray):
        # ``tolist`` gives nested lists for n-d arrays (converted to tuples by recursing)
        # and a bare Python scalar for 0-d arrays (which ``tuple()`` could not accept).
        return convert_sample_hashable(sample.tolist())
    if isinstance(sample, (list, tuple)):
        return tuple(convert_sample_hashable(s) for s in sample)
    if isinstance(sample, dict):
        return tuple(
            (key, convert_sample_hashable(value)) for key, value in sample.items()
        )

    return sample
|
|
|
|
|
|
def sample_equal(sample1, sample2):
    """Return True when two (possibly nested) samples hold the same values.

    Both samples are normalised with ``convert_sample_hashable`` and then compared with ``==``.
    """
    left = convert_sample_hashable(sample1)
    right = convert_sample_hashable(sample2)
    return left == right
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(3, start=-4),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
        Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
        Graph(node_space=Discrete(5), edge_space=None),
    ],
)
def test_seed_reproducibility(space):
    """Two copies seeded with the same value report the same seeds and draw the same sample.

    Both copies are first seeded with ``None`` so they start from different RNG states.
    """
    original, clone = space, copy.deepcopy(space)
    for candidate in (original, clone):
        candidate.seed(None)

    assert original.seed(0) == clone.seed(0)
    assert sample_equal(original.sample(), clone.sample())
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Tuple([Discrete(100), Discrete(100)]),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple([Discrete(5), Discrete(5, start=10)]),
        Tuple(
            [
                Discrete(5),
                Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]),
                    high=np.array([1.0, 5.0]),
                    dtype=np.float64,
                ),
            }
        ),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
        Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
        Graph(node_space=Discrete(5), edge_space=None),
    ],
)
def test_seed_subspace_incorrelated(space):
    """Seeding a composite space gives each of its sub-spaces a distinct RNG state."""
    if isinstance(space, Tuple):
        subspaces = space.spaces
    elif isinstance(space, Dict):
        subspaces = space.spaces.values()
    elif isinstance(space, Graph):
        subspaces = [space.node_space]
        if space.edge_space is not None:
            subspaces.append(space.edge_space)
    else:
        subspaces = []

    space.seed(0)

    # Bit-generator states are nested dicts; make them hashable so duplicates collapse in a set.
    states = [
        convert_sample_hashable(sub.np_random.bit_generator.state) for sub in subspaces
    ]
    assert len(states) == len(set(states))
|
|
|
|
|
|
def test_tuple():
    """Tuple spaces support len, count, iteration, reversed iteration and index like a tuple."""
    spaces = [Discrete(5), Discrete(10), Discrete(5)]
    space_tuple = Tuple(spaces)

    assert len(space_tuple) == len(spaces)
    assert space_tuple.count(Discrete(5)) == 2
    assert space_tuple.count(MultiBinary(2)) == 0

    for idx, sub in enumerate(space_tuple):
        assert sub == spaces[idx]
    last = len(spaces) - 1
    for idx, sub in enumerate(reversed(space_tuple)):
        assert sub == spaces[last - idx]

    assert space_tuple.index(Discrete(5)) == 0
    assert space_tuple.index(Discrete(5), 1) == 2
    # Discrete(10) sits at index 1, outside the searched range [0, 1).
    with pytest.raises(ValueError):
        space_tuple.index(Discrete(10), 0, 1)
|
|
|
|
|
|
def test_multidiscrete_as_tuple():
    """MultiDiscrete can be indexed and sliced like a (nested) tuple of Discrete spaces.

    Integer indexing down to a single entry yields a Discrete; slicing yields a MultiDiscrete,
    and a full slice yields an equal but distinct space.
    """
    # 1D multi-discrete
    flat = MultiDiscrete([3, 4, 5])
    assert flat.shape == (3,)
    for key, expected in [
        (np.s_[0], Discrete(3)),
        (np.s_[0:1], MultiDiscrete([3])),
        (np.s_[0:2], MultiDiscrete([3, 4])),
    ]:
        assert flat[key] == expected
    assert flat[:] == flat and flat[:] is not flat
    assert len(flat) == 3

    # 2D multi-discrete
    grid = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert grid.shape == (2, 3)
    for key, expected in [
        (np.s_[0, 1], Discrete(4)),
        (np.s_[0], MultiDiscrete([3, 4, 5])),
        (np.s_[0:1], MultiDiscrete([[3, 4, 5]])),
        (np.s_[0:2, :], MultiDiscrete([[3, 4, 5], [6, 7, 8]])),
        (np.s_[:, 0:1], MultiDiscrete([[3], [6]])),
        (np.s_[0:2, 0:2], MultiDiscrete([[3, 4], [6, 7]])),
    ]:
        assert grid[key] == expected
    assert grid[:] == grid and grid[:] is not grid
    assert grid[:, :] == grid and grid[:, :] is not grid
|
|
|
|
|
|
def test_multidiscrete_subspace_reproducibility():
    """Sub-spaces taken from a MultiDiscrete sample reproducibly.

    Indexing the same sub-space twice gives identical samples, and a full slice samples the
    same value as the parent space.
    """
    # 1D multi-discrete
    flat = MultiDiscrete([100, 200, 300])
    flat.seed(None)

    for key in (np.s_[0], np.s_[0:1], np.s_[0:2], np.s_[:]):
        assert sample_equal(flat[key].sample(), flat[key].sample())
    assert sample_equal(flat[:].sample(), flat.sample())

    # 2D multi-discrete
    grid = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    grid.seed(None)

    for key in (
        np.s_[0, 1],
        np.s_[0],
        np.s_[0:1],
        np.s_[0:2, :],
        np.s_[:, 0:1],
        np.s_[0:2, 0:2],
        np.s_[:],
        np.s_[:, :],
    ):
        assert sample_equal(grid[key].sample(), grid[key].sample())
    assert sample_equal(grid[:, :].sample(), grid.sample())
|
|
|
|
|
|
def test_space_legacy_state_pickling():
    """Restoring a legacy state (public ``shape`` / ``np_random`` keys) fills both the public
    attributes and their underscore-prefixed counterparts."""
    legacy_state = {
        "shape": (1, 2, 3),
        "dtype": np.int64,
        "np_random": np.random.default_rng(),
        "n": 3,
    }
    space = Discrete(1)
    space.__setstate__(legacy_state)

    for attr in ("shape", "_shape"):
        assert getattr(space, attr) == legacy_state["shape"]
    for attr in ("np_random", "_np_random"):
        assert getattr(space, attr) == legacy_state["np_random"]
    assert space.n == 3
    assert space.dtype == legacy_state["dtype"]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    # Same 28 spaces (and order) as listing them by hand: for each shape, each
    # (low, high) pair, each dtype; followed by the mixed-bound array spaces.
    [
        Box(low=low, high=high, shape=shape, dtype=dtype)
        for shape in [(2,), (2, 3)]
        for low, high in [(0, np.inf), (-np.inf, 0), (-np.inf, np.inf)]
        for dtype in [np.int32, np.float32, np.int64, np.float64]
    ]
    + [
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=dtype)
        for dtype in [np.int32, np.float32, np.int64, np.float64]
    ],
)
def test_infinite_space(space):
    """Check sampling, boundedness flags and bound dtypes of Boxes with infinite bounds.

    Every parametrized space has bounds that are either 0 or infinite. Because
    ``space.high`` and ``space.low`` may be modified inside ``Box.__init__``, a bound is
    recognised as "was infinite" simply by being non-zero.
    """
    space.seed(0)

    assert np.all(space.high > space.low), "High bound not higher than low bound"

    sample = space.sample()

    # check if space contains sample
    assert space.contains(
        sample
    ), f"Sample {sample} not inside space according to `space.contains()`"

    # manually check that the sign of the sample is within the bounds
    assert np.all(
        np.sign(space.high) >= np.sign(sample)
    ), f"Sign of sample {sample} is greater than space upper bound {space.high}"
    assert np.all(
        np.sign(space.low) <= np.sign(sample)
    ), f"Sign of sample {sample} is less than space lower bound {space.low}"

    # A non-zero bound came from an infinite input, so that side must report as unbounded
    # (for every parametrized dtype); a zero bound is finite and must report as bounded.
    if np.any(space.high != 0):
        assert (
            space.is_bounded("above") is False
        ), "inf upper bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("above") is True
        ), "non-inf upper bound supposed to be bounded"

    if np.any(space.low != 0):
        assert (
            space.is_bounded("below") is False
        ), "inf lower bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("below") is True
        ), "non-inf lower bound supposed to be bounded"

    # check for dtype
    assert (
        space.high.dtype == space.dtype
    ), f"High's dtype {space.high.dtype} doesn't match `space.dtype` {space.dtype}"
    assert (
        space.low.dtype == space.dtype
    ), f"Low's dtype {space.low.dtype} doesn't match `space.dtype` {space.dtype}"
|
|
|
|
|
|
def test_discrete_legacy_state_pickling():
    """A Discrete state saved before ``start`` existed restores with ``start == 0``."""
    d = Discrete(1)
    assert "start" in d.__dict__

    # Simulate an old instance: the legacy class had no ``start`` attribute.
    del d.__dict__["start"]
    assert "start" not in d.__dict__

    d.__setstate__({"n": 3})

    assert d.start == 0
    assert d.n == 3
|
|
|
|
|
|
def test_box_legacy_state_pickling():
    """A Box state saved without ``low_repr`` / ``high_repr`` gets them rebuilt on restore."""
    legacy_state = {
        "dtype": np.dtype("float32"),
        "_shape": (5,),
        "low": np.zeros(5, dtype=np.float32),
        "high": np.ones(5, dtype=np.float32),
        "bounded_below": np.full(5, True),
        "bounded_above": np.full(5, True),
        "_np_random": None,
    }

    b = Box(-1, 1, ())
    assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__
    # Simulate an old instance by stripping the attributes legacy pickles did not contain.
    for attr in ("low_repr", "high_repr"):
        del b.__dict__[attr]
    assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__

    b.__setstate__(legacy_state)
    assert b.low_repr == "0.0"
    assert b.high_repr == "1.0"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Discrete(5, start=-2),
        Box(low=0.0, high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple(
            [
                Discrete(5),
                Box(low=np.array([0.0, 0.0]), high=np.array([1, 5]), dtype=np.float64),
            ]
        ),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        Tuple((Discrete(5), Discrete(2, start=6), Discrete(2, start=-4))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict(
            {
                "position": Discrete(5),
                "velocity": Box(
                    low=np.array([0.0, 0.0]), high=np.array([1, 5]), dtype=np.float64
                ),
            }
        ),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
        Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
        Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
        Graph(node_space=Discrete(5), edge_space=None),
    ],
)
def test_pickle(space):
    """Copies made by pickling to bytes or to a file draw the same next sample as the original.

    The original is sampled once first, so its RNG is already in use when it is pickled.
    """
    space.sample()

    from_bytes = pickle.loads(pickle.dumps(space))

    with tempfile.TemporaryFile() as handle:
        pickle.dump(space, handle)
        handle.seek(0)
        from_file = pickle.load(handle)

    reference = space.sample()
    assert sample_equal(reference, from_bytes.sample())
    assert sample_equal(reference, from_file.sample())
|