# Gymnasium/tests/spaces/test_spaces.py
from __future__ import annotations

import copy
import itertools
import json # note: ujson fails this test due to float equality
import pickle
import tempfile
from collections.abc import Callable
import numpy as np
import pytest
import scipy.stats
from gymnasium.error import Error
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
from gymnasium.utils import seeding
from gymnasium.utils.env_checker import data_equivalence
from tests.spaces.utils import (
TESTING_FUNDAMENTAL_SPACES,
TESTING_FUNDAMENTAL_SPACES_IDS,
TESTING_SPACES,
TESTING_SPACES_IDS,
)
# Since each of these tests takes only ~1 ms, we don't mind generating so many of them.
# This generates all ordered pairs of distinct spaces of the same type in TESTING_SPACES.
TESTING_SPACES_PERMUTATIONS = list(
itertools.chain(
*[
list(itertools.permutations(list(group), r=2))
for key, group in itertools.groupby(
TESTING_SPACES, key=lambda space: type(space)
)
]
)
)
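# For example, with (hypothetical) contents [Discrete(3), Discrete(5), Box(0, 1)], this yields
# [(Discrete(3), Discrete(5)), (Discrete(5), Discrete(3))] and no Box pairs, since a lone Box
# has no partner of the same type. Note that itertools.groupby only groups *consecutive*
# elements, so this relies on TESTING_SPACES listing spaces of the same type next to each other.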
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_roundtripping(space: Space):
"""Tests if space samples passed to `to_jsonable` and `from_jsonable` produce the original samples."""
sample_1 = space.sample()
sample_2 = space.sample()
# Convert the samples to json, dump + load json and convert back to python
sample_json = space.to_jsonable([sample_1, sample_2])
sample_roundtripped = json.loads(json.dumps(sample_json))
sample_1_prime, sample_2_prime = space.from_jsonable(sample_roundtripped)
# Check if the samples are equivalent
assert data_equivalence(
sample_1, sample_1_prime
), f"sample 1: {sample_1}, prime: {sample_1_prime}"
assert data_equivalence(
sample_2, sample_2_prime
), f"sample 2: {sample_2}, prime: {sample_2_prime}"
@pytest.mark.parametrize(
"space_1,space_2",
TESTING_SPACES_PERMUTATIONS,
ids=[f"({s1}, {s2})" for s1, s2 in TESTING_SPACES_PERMUTATIONS],
)
def test_space_equality(space_1, space_2):
"""Check that `space.__eq__` works.
    `TESTING_SPACES_PERMUTATIONS` contains all ordered pairs of distinct testing spaces of the same type.
"""
assert space_1 == space_1
assert space_2 == space_2
assert space_1 != space_2
# significance level of chi2 and KS tests
ALPHA = 0.05
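# For reference, the critical value at this level with 4 degrees of freedom is
# scipy.stats.chi2.isf(0.05, df=4) ~= 9.49; each test below asserts that the observed statistic
# stays below the critical value for the appropriate degrees of freedom.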
@pytest.mark.parametrize(
"space", TESTING_FUNDAMENTAL_SPACES, ids=TESTING_FUNDAMENTAL_SPACES_IDS
)
def test_sample(space: Space, n_trials: int = 1_000):
"""Test the space sample has the expected distribution with the chi-squared test and KS test.
    Example code with `scipy.stats.chisquare` that should give the same result:
>>> import scipy.stats
>>> variance = np.sum(np.square(observed_frequency - expected_frequency) / expected_frequency)
>>> f'X2 at alpha=0.05 = {scipy.stats.chi2.isf(0.05, df=4)}'
>>> f'p-value = {scipy.stats.chi2.sf(variance, df=4)}'
>>> scipy.stats.chisquare(f_obs=observed_frequency)
"""
space.seed(0)
samples = np.array([space.sample() for _ in range(n_trials)])
assert len(samples) == n_trials
if isinstance(space, Box):
if space.dtype.kind == "f":
test_function = ks_test
elif space.dtype.kind in ["i", "u"]:
test_function = chi2_test
elif space.dtype.kind == "b":
test_function = binary_chi2_test
else:
raise NotImplementedError(f"Unknown test for Box(dtype={space.dtype})")
assert space.shape == space.low.shape == space.high.shape
assert space.shape == samples.shape[1:]
# (n_trials, *space.shape) => (*space.shape, n_trials)
samples = np.moveaxis(samples, 0, -1)
for index in np.ndindex(space.shape):
low = space.low[index]
high = space.high[index]
sample = samples[index]
bounded_below = space.bounded_below[index]
bounded_above = space.bounded_above[index]
test_function(sample, low, high, bounded_below, bounded_above)
elif isinstance(space, Discrete):
expected_frequency = np.ones(space.n) * n_trials / space.n
observed_frequency = np.zeros(space.n)
for sample in samples:
observed_frequency[sample - space.start] += 1
degrees_of_freedom = space.n - 1
assert observed_frequency.shape == expected_frequency.shape
assert np.sum(observed_frequency) == n_trials
variance = np.sum(
np.square(expected_frequency - observed_frequency) / expected_frequency
)
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
elif isinstance(space, MultiBinary):
expected_frequency = n_trials / 2
observed_frequency = np.sum(samples, axis=0)
assert observed_frequency.shape == space.shape
        # As this is a binary space, we can be lazy with the variance: np.square is symmetric for the 0 and 1 categories.
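        # With expected counts of n/2 in each category and observed counts k and n - k, the two
        # chi-squared terms are equal: (k - n/2)^2 / (n/2) + ((n - k) - n/2)^2 / (n/2)
        # = 2 * (k - n/2)^2 / (n/2), which is the shortcut used below.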
variance = (
2 * np.square(observed_frequency - expected_frequency) / expected_frequency
)
assert variance.shape == space.shape
assert np.all(variance < scipy.stats.chi2.isf(ALPHA, df=1))
elif isinstance(space, MultiDiscrete):
        # Due to the multi-axis capability of MultiDiscrete, these functions need to be recursive, and the expected / observed frequency structures are ragged (non-regular) in shape.
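        # For example, a (hypothetical) MultiDiscrete(np.array([[3, 2], [4, 5]])) has nvec of
        # shape (2, 2), so the frequency structures below are nested object arrays of
        # per-dimension count vectors, e.g. [[zeros(3), zeros(2)], [zeros(4), zeros(5)]].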
def _generate_frequency(dim, func):
if isinstance(dim, np.ndarray):
return np.array(
[_generate_frequency(sub_dim, func) for sub_dim in dim],
dtype=object,
)
else:
return func(dim)
def _update_observed_frequency(obs_sample, obs_freq):
if isinstance(obs_sample, np.ndarray):
for sub_sample, sub_freq in zip(obs_sample, obs_freq):
_update_observed_frequency(sub_sample, sub_freq)
else:
obs_freq[obs_sample] += 1
expected_frequency = _generate_frequency(
space.nvec, lambda dim: np.ones(dim) * n_trials / dim
)
observed_frequency = _generate_frequency(space.nvec, lambda dim: np.zeros(dim))
for sample in samples:
_update_observed_frequency(sample - space.start, observed_frequency)
def _chi_squared_test(dim, exp_freq, obs_freq):
if isinstance(dim, np.ndarray):
for sub_dim, sub_exp_freq, sub_obs_freq in zip(dim, exp_freq, obs_freq):
_chi_squared_test(sub_dim, sub_exp_freq, sub_obs_freq)
else:
assert exp_freq.shape == (dim,) and obs_freq.shape == (dim,)
assert np.sum(obs_freq) == n_trials
assert np.sum(exp_freq) == n_trials
_variance = np.sum(np.square(exp_freq - obs_freq) / exp_freq)
_degrees_of_freedom = dim - 1
assert _variance < scipy.stats.chi2.isf(ALPHA, df=_degrees_of_freedom)
_chi_squared_test(space.nvec, expected_frequency, observed_frequency)
elif isinstance(space, Text):
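        # Without a mask, Text.sample draws a length uniformly from [min_length, max_length] and
        # characters uniformly from the character set, so the expected total character count is
        # n_trials * (min_length + (max_length - min_length) / 2), spread evenly over the set.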
expected_frequency = (
np.ones(len(space.character_set))
* n_trials
* (space.min_length + (space.max_length - space.min_length) / 2)
/ len(space.character_set)
)
observed_frequency = np.zeros(len(space.character_set))
for sample in samples:
for x in sample:
observed_frequency[space.character_index(x)] += 1
degrees_of_freedom = len(space.character_set) - 1
assert observed_frequency.shape == expected_frequency.shape
assert np.sum(observed_frequency) == sum(len(sample) for sample in samples)
variance = np.sum(
np.square(expected_frequency - observed_frequency) / expected_frequency
)
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
else:
raise NotImplementedError(f"Unknown sample testing for {type(space)}")
def ks_test(sample, low, high, bounded_below, bounded_above):
"""Perform Kolmogorov-Smirnov test on the sample. Automatically picks the
distribution to test against based on the bounds.
"""
if bounded_below and bounded_above:
# X ~ U(low, high)
dist = scipy.stats.uniform(low, high - low)
elif bounded_below and not bounded_above:
# X ~ low + Exp(1.0)
# => X - low ~ Exp(1.0)
dist = scipy.stats.expon
sample = sample - low
elif not bounded_below and bounded_above:
# X ~ high - Exp(1.0)
# => high - X ~ Exp(1.0)
dist = scipy.stats.expon
sample = high - sample
else:
# X ~ N(0.0, 1.0)
dist = scipy.stats.norm
_, p_value = scipy.stats.kstest(sample, dist.cdf)
assert p_value >= ALPHA
def chi2_test(sample, low, high, bounded_below, bounded_above):
"""Perform chi-squared test on the sample. Automatically picks the distribution
to test against based on the bounds.
"""
(n_trials,) = sample.shape
if bounded_below and bounded_above:
# X ~ U(low, high)
        n_categories = int(high) - int(low) + 1
        observed_frequency = np.bincount(sample - low, minlength=n_categories)
        assert observed_frequency.shape == (n_categories,)
        expected_frequency = np.ones(n_categories) * n_trials / n_categories
elif bounded_below and not bounded_above:
# X ~ low + Geom(1 - e^-1)
# => X - low ~ Geom(1 - e^-1)
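        # Rationale (assuming the integer sample is a floored, shifted Exp(1), as in the float
        # case): if Y ~ Exp(1) then P(floor(Y) = k) = e^-k * (1 - 1/e) for k = 0, 1, ..., i.e. a
        # geometric distribution with p = 1 - 1/e. scipy.stats.geom is supported on {1, 2, ...},
        # hence pmf(x + 1) below; the leftover tail mass is folded into the last expected bin so
        # that both frequency arrays sum to n_trials.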
dist = scipy.stats.geom(1 - 1 / np.e)
observed_frequency = np.bincount(sample - low)
x = np.arange(len(observed_frequency))
expected_frequency = dist.pmf(x + 1) * n_trials
expected_frequency[-1] += n_trials - np.sum(expected_frequency)
elif not bounded_below and bounded_above:
# X ~ high - Geom(1 - e^-1)
# => high - X ~ Geom(1 - e^-1)
dist = scipy.stats.geom(1 - 1 / np.e)
observed_frequency = np.bincount(high - sample)
x = np.arange(len(observed_frequency))
expected_frequency = dist.pmf(x + 1) * n_trials
expected_frequency[-1] += n_trials - np.sum(expected_frequency)
else:
        # X ~ floor(N(0.0, 1.0))
# => pmf(x) = cdf(x + 1) - cdf(x)
lowest = np.min(sample)
observed_frequency = np.bincount(sample - lowest)
normal_dist = scipy.stats.norm(0, 1)
x = lowest + np.arange(len(observed_frequency))
expected_frequency = normal_dist.cdf(x + 1) - normal_dist.cdf(x)
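        # Fold the unobserved tails into the edge bins: all mass below `lowest` is added to the
        # first bin and any remaining mass is added to the last bin, so that the expected
        # frequencies sum exactly to n_trials.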
expected_frequency[0] += normal_dist.cdf(lowest)
expected_frequency *= n_trials
expected_frequency[-1] += n_trials - np.sum(expected_frequency)
assert observed_frequency.shape == expected_frequency.shape
variance = np.sum(
np.square(expected_frequency - observed_frequency) / expected_frequency
)
degrees_of_freedom = len(observed_frequency) - 1
critical_value = scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
assert variance < critical_value
def binary_chi2_test(sample, low, high, bounded_below, bounded_above):
"""Perform Chi-squared test on boolean samples."""
assert bounded_below
assert bounded_above
(n_trials,) = sample.shape
if low == high == 0:
assert np.all(sample == 0)
elif low == high == 1:
assert np.all(sample == 1)
else:
expected_frequency = n_trials / 2
observed_frequency = np.sum(sample)
        # As in the MultiBinary branch of `test_sample`, the two chi-squared terms are equal, so the statistic simplifies to twice one term.
variance = (
2 * np.square(observed_frequency - expected_frequency) / expected_frequency
)
critical_value = scipy.stats.chi2.isf(ALPHA, df=1)
assert variance < critical_value
SAMPLE_MASK_RNG, _ = seeding.np_random(1)
@pytest.mark.parametrize(
"space,mask",
itertools.zip_longest(
TESTING_FUNDAMENTAL_SPACES,
[
# Discrete
np.array([1, 1, 0], dtype=np.int8),
np.array([0, 0, 0], dtype=np.int8),
# Box
None,
None,
None,
None,
None,
None,
None,
None,
# Multi-discrete
(np.array([1, 1], dtype=np.int8), np.array([0, 0], dtype=np.int8)),
(
(np.array([1, 0], dtype=np.int8), np.array([0, 1, 1], dtype=np.int8)),
(np.array([1, 1, 0], dtype=np.int8), np.array([0, 1], dtype=np.int8)),
),
(np.array([1, 1], dtype=np.int8), np.array([0, 0], dtype=np.int8)),
(
(np.array([1, 0], dtype=np.int8), np.array([0, 1, 1], dtype=np.int8)),
(np.array([1, 1, 0], dtype=np.int8), np.array([0, 1], dtype=np.int8)),
),
# Multi-binary
np.array([0, 1, 0, 1, 0, 2, 1, 1], dtype=np.int8),
np.array([[0, 1, 2], [0, 2, 1]], dtype=np.int8),
# Text
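            # Each Text mask is a (length, character mask) tuple; size=62 matches the number of
            # characters in Text's default alphanumeric character set.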
(None, SAMPLE_MASK_RNG.integers(low=0, high=2, size=62, dtype=np.int8)),
(4, SAMPLE_MASK_RNG.integers(low=0, high=2, size=62, dtype=np.int8)),
(None, np.array([1, 1, 0, 1, 0, 0], dtype=np.int8)),
],
),
ids=TESTING_FUNDAMENTAL_SPACES_IDS,
)
def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
"""Tests that the sampling a space with a mask has the expected distribution.
The implemented code is similar to the `test_space_sample` that considers the mask applied.
"""
if isinstance(space, Box):
# The box space can't have a sample mask
assert mask is None
return
assert mask is not None
space.seed(1)
samples = np.array([space.sample(mask) for _ in range(n_trials)])
if isinstance(space, Discrete):
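        # A Discrete mask marks each action with 1 (may be sampled) or 0 (may not); allowed
        # actions are sampled uniformly. If every action is masked out, `sample` falls back to
        # `space.start`, which is why all expected mass lands on index 0 in that case.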
if np.any(mask == 1):
expected_frequency = np.ones(space.n) * (n_trials / np.sum(mask)) * mask
else:
expected_frequency = np.zeros(space.n)
expected_frequency[0] = n_trials
observed_frequency = np.zeros(space.n)
for sample in samples:
observed_frequency[sample - space.start] += 1
degrees_of_freedom = max(np.sum(mask) - 1, 0)
assert observed_frequency.shape == expected_frequency.shape
assert np.sum(observed_frequency) == n_trials
assert np.sum(expected_frequency) == n_trials
variance = np.sum(
np.square(expected_frequency - observed_frequency)
/ np.clip(expected_frequency, 1, None)
)
if degrees_of_freedom == 0:
assert variance == 0
else:
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
elif isinstance(space, MultiBinary):
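        # A MultiBinary mask value of 0 forces a 0, 1 forces a 1, and 2 samples uniformly,
        # so the expected count of ones per position is 0, n_trials or n_trials / 2 respectively.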
expected_frequency = (
np.ones(space.shape) * np.where(mask == 2, 0.5, mask) * n_trials
)
observed_frequency = np.sum(samples, axis=0)
assert space.shape == expected_frequency.shape == observed_frequency.shape
variance = (
2
* np.square(observed_frequency - expected_frequency)
/ np.clip(expected_frequency, 1, None)
)
assert variance.shape == space.shape
assert np.all(variance < scipy.stats.chi2.isf(ALPHA, df=1))
elif isinstance(space, MultiDiscrete):
        # Due to the multi-axis capability of MultiDiscrete, these functions need to be recursive, and the expected / observed frequency structures are ragged (non-regular) in shape.
def _generate_frequency(_dim: np.ndarray | int, _mask, func: Callable) -> list:
if isinstance(_dim, np.ndarray):
return [
_generate_frequency(sub_dim, sub_mask, func)
for sub_dim, sub_mask in zip(_dim, _mask)
]
else:
return func(_dim, _mask)
def _update_observed_frequency(obs_sample, obs_freq):
if isinstance(obs_sample, np.ndarray):
for sub_sample, sub_freq in zip(obs_sample, obs_freq):
_update_observed_frequency(sub_sample, sub_freq)
else:
obs_freq[obs_sample] += 1
def _exp_freq_fn(_dim: int, _mask: np.ndarray):
if np.any(_mask == 1):
assert _dim == len(_mask)
return np.ones(_dim) * (n_trials / np.sum(_mask)) * _mask
else:
freq = np.zeros(_dim)
freq[0] = n_trials
return freq
expected_frequency = _generate_frequency(
space.nvec, mask, lambda dim, _mask: _exp_freq_fn(dim, _mask)
)
observed_frequency = _generate_frequency(
space.nvec, mask, lambda dim, _: np.zeros(dim)
)
for sample in samples:
_update_observed_frequency(sample - space.start, observed_frequency)
def _chi_squared_test(dim, _mask, exp_freq, obs_freq):
if isinstance(dim, np.ndarray):
for sub_dim, sub_mask, sub_exp_freq, sub_obs_freq in zip(
dim, _mask, exp_freq, obs_freq
):
_chi_squared_test(sub_dim, sub_mask, sub_exp_freq, sub_obs_freq)
else:
assert exp_freq.shape == (dim,) and obs_freq.shape == (dim,)
assert np.sum(obs_freq) == n_trials
assert np.sum(exp_freq) == n_trials
_variance = np.sum(
np.square(exp_freq - obs_freq) / np.clip(exp_freq, 1, None)
)
_degrees_of_freedom = max(np.sum(_mask) - 1, 0)
if _degrees_of_freedom == 0:
assert _variance == 0
else:
assert _variance < scipy.stats.chi2.isf(
ALPHA, df=_degrees_of_freedom
)
_chi_squared_test(space.nvec, mask, expected_frequency, observed_frequency)
elif isinstance(space, Text):
length, charlist_mask = mask
if length is None:
expected_length = (
space.min_length + (space.max_length - space.min_length) / 2
)
else:
expected_length = length
if np.any(charlist_mask == 1):
expected_frequency = (
np.ones(len(space.character_set))
* n_trials
* expected_length
/ np.sum(charlist_mask)
* charlist_mask
)
else:
expected_frequency = np.zeros(len(space.character_set))
observed_frequency = np.zeros(len(space.character_set))
for sample in samples:
for char in sample:
observed_frequency[space.character_index(char)] += 1
degrees_of_freedom = max(np.sum(charlist_mask) - 1, 0)
assert observed_frequency.shape == expected_frequency.shape
assert np.sum(observed_frequency) == sum(len(sample) for sample in samples)
variance = np.sum(
np.square(expected_frequency - observed_frequency)
/ np.clip(expected_frequency, 1, None)
)
if degrees_of_freedom == 0:
assert variance == 0
else:
assert variance < scipy.stats.chi2.isf(ALPHA, df=degrees_of_freedom)
else:
        raise NotImplementedError(f"Unknown sample mask testing for {type(space)}")
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_seed_reproducibility(space):
"""Test that the set the space seed will reproduce the same samples."""
space_1 = space
space_2 = copy.deepcopy(space)
for seed in range(5):
assert data_equivalence(space_1.seed(seed), space_2.seed(seed))
# With the same seed, the two spaces should be identical
assert all(
data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
)
assert not data_equivalence(space_1.seed(123), space_2.seed(456))
    # Due to randomness, it is difficult to test that different seeds produce different samples.
    # Therefore, take 10 samples and check that they are not all identical.
assert not all(
data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
)
SPACE_CLS = list(dict.fromkeys(type(space) for space in TESTING_SPACES))
SPACE_KWARGS = [
{"n": 3}, # Discrete
{"low": 1, "high": 10}, # Box
{"nvec": [3, 2]}, # MultiDiscrete
{"n": 2}, # MultiBinary
{"max_length": 5}, # Text
{"spaces": (Discrete(3), Discrete(2))}, # Tuple
{"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
{"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
{"space": Discrete(4)}, # Sequence
{"spaces": (Discrete(3), Discrete(5))}, # OneOf
]
assert len(SPACE_CLS) == len(SPACE_KWARGS)
@pytest.mark.parametrize(
"space_cls,kwarg",
list(zip(SPACE_CLS, SPACE_KWARGS)),
ids=[f"{space_cls}" for space_cls in SPACE_CLS],
)
def test_seed_np_random(space_cls, kwarg):
"""During initialisation of a space, a rng instance can be passed to the space.
Test that the space's `np_random` is the rng instance
"""
rng, _ = seeding.np_random(123)
space = space_cls(seed=rng, **kwarg)
assert space.np_random is rng
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_sample_contains(space):
"""Test that samples are contained within the space.
Then test that for all other spaces, we test that an error is not raise with a sample and a bool is returned.
As other spaces can be contained with this space, we cannot test that the contains is always true or false.
"""
for _ in range(10):
sample = space.sample()
assert sample in space
assert space.contains(sample)
for other_space in TESTING_SPACES:
sample = other_space.sample()
        space_contains = space.contains(sample)
assert isinstance(
space_contains, bool
), f"{space_contains}, {type(space_contains)}, {space}, {other_space}, {sample}"
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_repr(space):
    """Tests that the space can be converted to a string with `str()`."""
assert isinstance(str(space), str)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_space_pickling(space):
"""Tests the spaces can be pickled with the unpickled version being equivalent to the original."""
space.seed(0)
    # Pickle and unpickle via an in-memory bytes object
pickled_space = pickle.dumps(space)
unpickled_space = pickle.loads(pickled_space)
assert space == unpickled_space
# Pickle and unpickle with a file
with tempfile.TemporaryFile() as f:
pickle.dump(space, f)
f.seek(0)
file_unpickled_space = pickle.load(f)
assert space == file_unpickled_space
# Check that space samples are the same
space_sample = space.sample()
unpickled_sample = unpickled_space.sample()
file_unpickled_sample = file_unpickled_space.sample()
assert data_equivalence(space_sample, unpickled_sample)
assert data_equivalence(space_sample, file_unpickled_sample)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("initial_seed", [None, 123])
def test_space_seeding_output(space, initial_seed, num_samples=5):
    """Tests that reseeding a space with the values returned by `seed` reproduces both the seed values and the samples."""
seeding_values = space.seed(initial_seed)
samples = [space.sample() for _ in range(num_samples)]
reseeded_values = space.seed(seeding_values)
resamples = [space.sample() for _ in range(num_samples)]
assert data_equivalence(seeding_values, reseeded_values)
assert data_equivalence(samples, resamples)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_invalid_space_seed(space):
    """Tests that seeding a space with an invalid type (a string) raises an error."""
with pytest.raises((ValueError, TypeError, Error)):
space.seed("abc")