2021-09-02 22:15:34 +08:00
import copy
2022-09-03 22:56:29 +01:00
import itertools
2022-03-31 12:50:38 -07:00
import json # note: ujson fails this test due to float equality
2022-03-02 23:38:26 +08:00
import pickle
import tempfile
2025-06-07 17:57:58 +01:00
from collections . abc import Callable
2018-09-24 20:11:03 +02:00
2016-04-27 08:00:58 -07:00
import numpy as np
2017-02-11 22:17:02 -08:00
import pytest
2023-02-06 14:30:09 +03:30
import scipy . stats
2018-09-24 20:11:03 +02:00
2024-04-28 16:10:35 +01:00
from gymnasium . error import Error
2022-09-08 10:10:07 +01:00
from gymnasium . spaces import Box , Discrete , MultiBinary , MultiDiscrete , Space , Text
from gymnasium . utils import seeding
from gymnasium . utils . env_checker import data_equivalence
2022-09-03 22:56:29 +01:00
from tests . spaces . utils import (
TESTING_FUNDAMENTAL_SPACES ,
TESTING_FUNDAMENTAL_SPACES_IDS ,
TESTING_SPACES ,
TESTING_SPACES_IDS ,
2022-07-11 16:39:04 +01:00
)
2016-04-27 08:00:58 -07:00
2022-12-04 22:24:02 +08:00
2022-09-03 22:56:29 +01:00
# Each generated case takes roughly 1ms, so producing this many cases is acceptable.
# Build every ordered pair of different entries of TESTING_SPACES that share a type.
# NOTE: itertools.groupby only merges *adjacent* items with equal keys, so this
# relies on TESTING_SPACES listing spaces of the same type contiguously.
TESTING_SPACES_PERMUTATIONS = [
    pair
    for _space_type, same_type_spaces in itertools.groupby(TESTING_SPACES, key=type)
    for pair in itertools.permutations(tuple(same_type_spaces), 2)
]
2022-09-03 22:56:29 +01:00
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_roundtripping(space: Space):
    """Check that samples survive `to_jsonable` -> JSON text -> `from_jsonable` unchanged."""
    sample_1 = space.sample()
    sample_2 = space.sample()

    # Serialise both samples as one batch, push them through an actual JSON
    # encode/decode cycle, then rebuild the python objects.
    encoded = json.dumps(space.to_jsonable([sample_1, sample_2]))
    decoded = json.loads(encoded)
    sample_1_prime, sample_2_prime = space.from_jsonable(decoded)

    # Each restored sample must be equivalent to the one it came from.
    assert data_equivalence(
        sample_1, sample_1_prime
    ), f"sample 1: {sample_1}, prime: {sample_1_prime}"
    assert data_equivalence(
        sample_2, sample_2_prime
    ), f"sample 2: {sample_2}, prime: {sample_2_prime}"
2018-09-24 20:11:03 +02:00
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize(
    "space_1,space_2",
    TESTING_SPACES_PERMUTATIONS,
    ids=[f"({first}, {second})" for first, second in TESTING_SPACES_PERMUTATIONS],
)
def test_space_equality(space_1, space_2):
    """Check that `space.__eq__` and `space.__ne__` behave as expected.

    Each case is an ordered pair taken from TESTING_SPACES_PERMUTATIONS, i.e. two
    different testing spaces of the same type: every space must equal itself and
    must differ from its partner.
    """
    # Reflexivity for both members of the pair
    assert space_1 == space_1
    assert space_2 == space_2
    # The two paired spaces are expected to compare unequal
    assert space_1 != space_2
2019-02-05 17:49:29 -08:00
2023-02-06 14:30:09 +03:30
# Significance level used by the chi-squared and Kolmogorov-Smirnov checks in this module:
# chi-squared statistics are compared against `scipy.stats.chi2.isf(ALPHA, ...)` and
# KS p-values are required to be at least ALPHA.
ALPHA = 0.05
2022-06-26 23:23:15 +01:00
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize(
    "space", TESTING_FUNDAMENTAL_SPACES, ids=TESTING_FUNDAMENTAL_SPACES_IDS
)
def test_sample(space: Space, n_trials: int = 1_000):
    """Check that `space.sample()` follows the expected distribution.

    The space is seeded with 0 and sampled `n_trials` times. `Box` elements are
    delegated to `ks_test`, `chi2_test` or `binary_chi2_test` depending on the dtype
    kind; all other spaces are checked with a chi-squared statistic against
    `scipy.stats.chi2.isf(ALPHA, df=...)`.

    Example code using `scipy.stats.chisquare` that should give the same result

    >>> import scipy.stats
    >>> variance = np.sum(np.square(observed_frequency - expected_frequency) / expected_frequency)
    >>> f'X2 at alpha=0.05 = {scipy.stats.chi2.isf(0.05, df=4)}'
    >>> f'p-value = {scipy.stats.chi2.sf(variance, df=4)}'
    >>> scipy.stats.chisquare(f_obs=observed_frequency)
    """

    def _chi2_statistic(expected, observed):
        # Pearson statistic: sum over categories of (E - O)^2 / E
        return np.sum(np.square(expected - observed) / expected)

    def _critical_value(dof):
        # Largest statistic accepted at significance level ALPHA
        return scipy.stats.chi2.isf(ALPHA, df=dof)

    space.seed(0)
    samples = np.array([space.sample() for _ in range(n_trials)])
    assert len(samples) == n_trials

    if isinstance(space, Box):
        # Pick the per-element statistical test from the dtype kind
        element_tests = {
            "f": ks_test,
            "i": chi2_test,
            "u": chi2_test,
            "b": binary_chi2_test,
        }
        if space.dtype.kind not in element_tests:
            raise NotImplementedError(f"Unknown test for Box(dtype={space.dtype})")
        test_function = element_tests[space.dtype.kind]

        assert space.shape == space.low.shape == space.high.shape
        assert space.shape == samples.shape[1:]

        # (n_trials, *space.shape) => (*space.shape, n_trials)
        per_element = np.moveaxis(samples, 0, -1)

        # Every element of the box is tested independently with its own bounds
        for index in np.ndindex(space.shape):
            test_function(
                per_element[index],
                space.low[index],
                space.high[index],
                space.bounded_below[index],
                space.bounded_above[index],
            )

    elif isinstance(space, Discrete):
        expected_frequency = np.ones(space.n) * n_trials / space.n
        observed_frequency = np.zeros(space.n)
        for drawn in samples:
            # shift by `start` so the first valid value lands in bin 0
            observed_frequency[drawn - space.start] += 1

        assert observed_frequency.shape == expected_frequency.shape
        assert np.sum(observed_frequency) == n_trials

        statistic = _chi2_statistic(expected_frequency, observed_frequency)
        assert statistic < _critical_value(space.n - 1)

    elif isinstance(space, MultiBinary):
        expected_frequency = n_trials / 2
        observed_frequency = np.sum(samples, axis=0)
        assert observed_frequency.shape == space.shape

        # With only two categories the squared deviation is the same for the 0 and
        # the 1 category, so the statistic is twice the term of a single category.
        statistic = (
            2 * np.square(observed_frequency - expected_frequency) / expected_frequency
        )
        assert statistic.shape == space.shape
        assert np.all(statistic < _critical_value(1))

    elif isinstance(space, MultiDiscrete):
        # MultiDiscrete can be multi-axis, so the helpers recurse through `nvec` and
        # the expected / observed containers may have non-regular shapes.
        def _generate_frequency(dim, func):
            if not isinstance(dim, np.ndarray):
                return func(dim)
            return np.array(
                [_generate_frequency(inner_dim, func) for inner_dim in dim],
                dtype=object,
            )

        def _update_observed_frequency(obs_sample, obs_freq):
            if not isinstance(obs_sample, np.ndarray):
                obs_freq[obs_sample] += 1
                return
            for inner_sample, inner_freq in zip(obs_sample, obs_freq):
                _update_observed_frequency(inner_sample, inner_freq)

        def _chi_squared_test(dim, exp_freq, obs_freq):
            if isinstance(dim, np.ndarray):
                for inner_dim, inner_exp, inner_obs in zip(dim, exp_freq, obs_freq):
                    _chi_squared_test(inner_dim, inner_exp, inner_obs)
                return

            assert exp_freq.shape == (dim,) and obs_freq.shape == (dim,)
            assert np.sum(obs_freq) == n_trials
            assert np.sum(exp_freq) == n_trials
            assert _chi2_statistic(exp_freq, obs_freq) < _critical_value(dim - 1)

        expected_frequency = _generate_frequency(
            space.nvec, lambda dim: np.ones(dim) * n_trials / dim
        )
        observed_frequency = _generate_frequency(space.nvec, lambda dim: np.zeros(dim))
        for drawn in samples:
            _update_observed_frequency(drawn - space.start, observed_frequency)

        _chi_squared_test(space.nvec, expected_frequency, observed_frequency)

    elif isinstance(space, Text):
        n_chars = len(space.character_set)
        # Expected characters per sample: midpoint of [min_length, max_length]
        mean_length = space.min_length + (space.max_length - space.min_length) / 2
        expected_frequency = np.ones(n_chars) * n_trials * mean_length / n_chars

        observed_frequency = np.zeros(n_chars)
        for text in samples:
            for character in text:
                observed_frequency[space.character_index(character)] += 1

        assert observed_frequency.shape == expected_frequency.shape
        assert np.sum(observed_frequency) == sum(len(text) for text in samples)

        statistic = _chi2_statistic(expected_frequency, observed_frequency)
        assert statistic < _critical_value(n_chars - 1)

    else:
        raise NotImplementedError(f"Unknown sample testing for {type(space)}")
2023-02-06 14:30:09 +03:30
def ks_test(sample, low, high, bounded_below, bounded_above, alpha=None):
    """Assert that `sample` passes a Kolmogorov-Smirnov test.

    The reference distribution is picked from the bound flags:

    * bounded on both sides: ``U(low, high)``
    * bounded below only: ``low + Exp(1.0)``
    * bounded above only: ``high - Exp(1.0)``
    * unbounded: ``N(0.0, 1.0)``

    Args:
        sample: Array of samples for a single element of the space.
        low: Lower bound; only read when `bounded_below` is true.
        high: Upper bound; only read when `bounded_above` is true.
        bounded_below: Whether the element has a finite lower bound.
        bounded_above: Whether the element has a finite upper bound.
        alpha: Significance level of the test. ``None`` (the default) uses the
            module-level ``ALPHA``.

    Raises:
        AssertionError: If the p-value of the test is smaller than the
            significance level.
    """
    if alpha is None:
        alpha = ALPHA

    if bounded_below and bounded_above:
        # X ~ U(low, high); scipy parametrises uniform as (loc, scale)
        dist = scipy.stats.uniform(low, high - low)
    elif bounded_below and not bounded_above:
        # X ~ low + Exp(1.0)
        # => X - low ~ Exp(1.0)
        dist = scipy.stats.expon
        sample = sample - low
    elif not bounded_below and bounded_above:
        # X ~ high - Exp(1.0)
        # => high - X ~ Exp(1.0)
        dist = scipy.stats.expon
        sample = high - sample
    else:
        # X ~ N(0.0, 1.0)
        dist = scipy.stats.norm

    _, p_value = scipy.stats.kstest(sample, dist.cdf)
    assert p_value >= alpha, f"KS test rejected: p-value {p_value} < alpha {alpha}"
def chi2_test(sample, low, high, bounded_below, bounded_above, alpha=None):
    """Assert that integer `sample` passes a chi-squared goodness-of-fit test.

    The reference distribution is picked from the bound flags:

    * bounded on both sides: uniform over the integers ``low..high``
    * bounded below only: ``low + Geom(1 - e^-1)``
    * bounded above only: ``high - Geom(1 - e^-1)``
    * unbounded: ``floor(N(0.0, 1.0))``

    For the open-ended cases the probability mass beyond the observed range is
    folded into the outermost bins so that the expected counts sum to the
    number of samples.

    Args:
        sample: 1-D integer array of samples for a single element of the space.
        low: Lower bound; only read when `bounded_below` is true.
        high: Upper bound; only read when `bounded_above` is true.
        bounded_below: Whether the element has a finite lower bound.
        bounded_above: Whether the element has a finite upper bound.
        alpha: Significance level of the test. ``None`` (the default) uses the
            module-level ``ALPHA``.

    Raises:
        AssertionError: If the chi-squared statistic is not below the critical
            value for the significance level.
    """
    (n_trials,) = sample.shape
    if alpha is None:
        alpha = ALPHA

    if bounded_below and bounded_above:
        # X ~ U(low, high)
        n_categories = int(high) - int(low) + 1
        observed_frequency = np.bincount(sample - low, minlength=n_categories)
        assert observed_frequency.shape == (n_categories,)
        expected_frequency = np.ones(n_categories) * n_trials / n_categories
    elif bounded_below or bounded_above:
        # Exactly one side is bounded here (the both-bounded case is handled above).
        # X ~ low + Geom(1 - e^-1)  =>  X - low ~ Geom(1 - e^-1)
        # X ~ high - Geom(1 - e^-1) =>  high - X ~ Geom(1 - e^-1)
        distance_from_bound = sample - low if bounded_below else high - sample
        dist = scipy.stats.geom(1 - 1 / np.e)
        observed_frequency = np.bincount(distance_from_bound)
        x = np.arange(len(observed_frequency))
        # scipy's geometric distribution starts at 1, the distances start at 0
        expected_frequency = dist.pmf(x + 1) * n_trials
        # fold the unobserved tail mass into the last bin
        expected_frequency[-1] += n_trials - np.sum(expected_frequency)
    else:
        # X ~ floor(N(0.0, 1.0))
        # => pmf(x) = cdf(x + 1) - cdf(x)
        lowest = np.min(sample)
        observed_frequency = np.bincount(sample - lowest)
        normal_dist = scipy.stats.norm(0, 1)
        x = lowest + np.arange(len(observed_frequency))
        expected_frequency = normal_dist.cdf(x + 1) - normal_dist.cdf(x)
        # fold the lower tail into the first bin and the upper tail into the last
        expected_frequency[0] += normal_dist.cdf(lowest)
        expected_frequency *= n_trials
        expected_frequency[-1] += n_trials - np.sum(expected_frequency)

    assert observed_frequency.shape == expected_frequency.shape
    variance = np.sum(
        np.square(expected_frequency - observed_frequency) / expected_frequency
    )
    degrees_of_freedom = len(observed_frequency) - 1
    critical_value = scipy.stats.chi2.isf(alpha, df=degrees_of_freedom)
    assert (
        variance < critical_value
    ), f"chi-squared test rejected: statistic {variance} >= critical value {critical_value}"
def binary_chi2_test(sample, low, high, bounded_below, bounded_above, alpha=None):
    """Assert that boolean `sample` passes a chi-squared goodness-of-fit test.

    When `low` and `high` are both 0 (or both 1) every sample must equal that
    value. Otherwise the samples are tested against a fair coin.

    Args:
        sample: 1-D array of boolean (0/1) samples for a single element.
        low: Lower bound of the element.
        high: Upper bound of the element.
        bounded_below: Must be true.
        bounded_above: Must be true.
        alpha: Significance level of the test. ``None`` (the default) uses the
            module-level ``ALPHA``. Not used when `low` equals `high`.

    Raises:
        AssertionError: If either bound flag is false, if a fixed-value element
            produced a different value, or if the chi-squared statistic is not
            below the critical value.
    """
    assert bounded_below
    assert bounded_above

    (n_trials,) = sample.shape

    if low == high == 0:
        assert np.all(sample == 0)
    elif low == high == 1:
        assert np.all(sample == 1)
    else:
        if alpha is None:
            alpha = ALPHA
        expected_frequency = n_trials / 2
        observed_frequency = np.sum(sample)

        # With two categories the squared deviation is identical for the 0 and the
        # 1 category, so the statistic is twice the term of a single category.
        variance = (
            2 * np.square(observed_frequency - expected_frequency) / expected_frequency
        )

        critical_value = scipy.stats.chi2.isf(alpha, df=1)
        assert (
            variance < critical_value
        ), f"chi-squared test rejected: statistic {variance} >= critical value {critical_value}"
2022-09-03 22:56:29 +01:00
# Random generator created from seed 1; it is used below to draw the random
# character masks for the Text cases of `test_space_sample_mask`.
SAMPLE_MASK_RNG, _ = seeding.np_random(1)
2022-06-26 23:23:15 +01:00
@pytest.mark.parametrize(
    "space,mask",
    itertools.zip_longest(
        TESTING_FUNDAMENTAL_SPACES,
        [
            # Discrete
            np.array([1, 1, 0], dtype=np.int8),
            np.array([0, 0, 0], dtype=np.int8),
            # Box
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            # Multi-discrete
            (np.array([1, 1], dtype=np.int8), np.array([0, 0], dtype=np.int8)),
            (
                (np.array([1, 0], dtype=np.int8), np.array([0, 1, 1], dtype=np.int8)),
                (np.array([1, 1, 0], dtype=np.int8), np.array([0, 1], dtype=np.int8)),
            ),
            (np.array([1, 1], dtype=np.int8), np.array([0, 0], dtype=np.int8)),
            (
                (np.array([1, 0], dtype=np.int8), np.array([0, 1, 1], dtype=np.int8)),
                (np.array([1, 1, 0], dtype=np.int8), np.array([0, 1], dtype=np.int8)),
            ),
            # Multi-binary
            np.array([0, 1, 0, 1, 0, 2, 1, 1], dtype=np.int8),
            np.array([[0, 1, 2], [0, 2, 1]], dtype=np.int8),
            # Text
            (None, SAMPLE_MASK_RNG.integers(low=0, high=2, size=62, dtype=np.int8)),
            (4, SAMPLE_MASK_RNG.integers(low=0, high=2, size=62, dtype=np.int8)),
            (None, np.array([1, 1, 0, 1, 0, 0], dtype=np.int8)),
        ],
    ),
    ids=TESTING_FUNDAMENTAL_SPACES_IDS,
)
def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
    """Check that sampling a space with a mask follows the expected distribution.

    The checks mirror those of `test_sample`, with the expected frequencies
    adjusted for the mask. `Box` cases carry no mask and are skipped.
    """

    def _clipped_chi2_statistic(expected, observed):
        # Masked-out categories have an expected count of 0; clipping the
        # denominator to at least 1 avoids dividing by zero for them.
        return np.sum(np.square(expected - observed) / np.clip(expected, 1, None))

    def _assert_fits(statistic, dof):
        if dof == 0:
            # At most one allowed category: the observation must match exactly
            assert statistic == 0
        else:
            assert statistic < scipy.stats.chi2.isf(ALPHA, df=dof)

    if isinstance(space, Box):
        # The box space can't have a sample mask
        assert mask is None
        return
    assert mask is not None

    space.seed(1)
    samples = np.array([space.sample(mask) for _ in range(n_trials)])

    if isinstance(space, Discrete):
        if np.any(mask == 1):
            # Uniform over the allowed categories, zero elsewhere
            expected_frequency = np.ones(space.n) * (n_trials / np.sum(mask)) * mask
        else:
            # Nothing is allowed: every sample is expected in the first category
            expected_frequency = np.zeros(space.n)
            expected_frequency[0] = n_trials

        observed_frequency = np.zeros(space.n)
        for drawn in samples:
            observed_frequency[drawn - space.start] += 1
        degrees_of_freedom = max(np.sum(mask) - 1, 0)

        assert observed_frequency.shape == expected_frequency.shape
        assert np.sum(observed_frequency) == n_trials
        assert np.sum(expected_frequency) == n_trials

        statistic = _clipped_chi2_statistic(expected_frequency, observed_frequency)
        _assert_fits(statistic, degrees_of_freedom)

    elif isinstance(space, MultiBinary):
        # Mask value 2 means "sample freely" (expected rate 0.5); 0 and 1 are
        # used directly as the expected rate of ones.
        expected_frequency = (
            np.ones(space.shape) * np.where(mask == 2, 0.5, mask) * n_trials
        )
        observed_frequency = np.sum(samples, axis=0)
        assert space.shape == expected_frequency.shape == observed_frequency.shape

        statistic = (
            2
            * np.square(observed_frequency - expected_frequency)
            / np.clip(expected_frequency, 1, None)
        )
        assert statistic.shape == space.shape
        assert np.all(statistic < scipy.stats.chi2.isf(ALPHA, df=1))

    elif isinstance(space, MultiDiscrete):
        # MultiDiscrete can be multi-axis, so the helpers recurse through `nvec`
        # (and the mask alongside it); the resulting containers may be non-regular.
        def _generate_frequency(_dim: np.ndarray | int, _mask, func: Callable) -> list:
            if not isinstance(_dim, np.ndarray):
                return func(_dim, _mask)
            return [
                _generate_frequency(inner_dim, inner_mask, func)
                for inner_dim, inner_mask in zip(_dim, _mask)
            ]

        def _update_observed_frequency(obs_sample, obs_freq):
            if not isinstance(obs_sample, np.ndarray):
                obs_freq[obs_sample] += 1
                return
            for inner_sample, inner_freq in zip(obs_sample, obs_freq):
                _update_observed_frequency(inner_sample, inner_freq)

        def _exp_freq_fn(_dim: int, _mask: np.ndarray):
            if np.any(_mask == 1):
                assert _dim == len(_mask)
                return np.ones(_dim) * (n_trials / np.sum(_mask)) * _mask
            # Nothing is allowed: every sample is expected in the first category
            freq = np.zeros(_dim)
            freq[0] = n_trials
            return freq

        def _chi_squared_test(dim, _mask, exp_freq, obs_freq):
            if isinstance(dim, np.ndarray):
                for inner_dim, inner_mask, inner_exp, inner_obs in zip(
                    dim, _mask, exp_freq, obs_freq
                ):
                    _chi_squared_test(inner_dim, inner_mask, inner_exp, inner_obs)
                return

            assert exp_freq.shape == (dim,) and obs_freq.shape == (dim,)
            assert np.sum(obs_freq) == n_trials
            assert np.sum(exp_freq) == n_trials
            _statistic = _clipped_chi2_statistic(exp_freq, obs_freq)
            _assert_fits(_statistic, max(np.sum(_mask) - 1, 0))

        expected_frequency = _generate_frequency(space.nvec, mask, _exp_freq_fn)
        observed_frequency = _generate_frequency(
            space.nvec, mask, lambda dim, _: np.zeros(dim)
        )
        for drawn in samples:
            _update_observed_frequency(drawn - space.start, observed_frequency)

        _chi_squared_test(space.nvec, mask, expected_frequency, observed_frequency)

    elif isinstance(space, Text):
        length, charlist_mask = mask
        n_chars = len(space.character_set)

        # Without a fixed length the expected length is the midpoint of the range
        expected_length = (
            space.min_length + (space.max_length - space.min_length) / 2
            if length is None
            else length
        )

        if np.any(charlist_mask == 1):
            expected_frequency = (
                np.ones(n_chars)
                * n_trials
                * expected_length
                / np.sum(charlist_mask)
                * charlist_mask
            )
        else:
            expected_frequency = np.zeros(n_chars)

        observed_frequency = np.zeros(n_chars)
        for text in samples:
            for char in text:
                observed_frequency[space.character_index(char)] += 1

        degrees_of_freedom = max(np.sum(charlist_mask) - 1, 0)

        assert observed_frequency.shape == expected_frequency.shape
        assert np.sum(observed_frequency) == sum(len(text) for text in samples)

        statistic = _clipped_chi2_statistic(expected_frequency, observed_frequency)
        _assert_fits(statistic, degrees_of_freedom)

    else:
        raise NotImplementedError()
2021-09-12 00:54:52 +08:00
2022-09-03 22:56:29 +01:00
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_seed_reproducibility(space):
    """Check that seeding a space and a deep copy of it identically yields identical samples."""
    space_1, space_2 = space, copy.deepcopy(space)

    for seed in range(5):
        assert data_equivalence(space_1.seed(seed), space_2.seed(seed))
        # Seeded identically, both spaces must produce the same stream of samples
        assert all(
            data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
        )

    assert not data_equivalence(space_1.seed(123), space_2.seed(456))
    # A single pair of samples may coincide by chance even with different seeds,
    # so draw up to 10 pairs and require that at least one pair differs.
    assert any(
        not data_equivalence(space_1.sample(), space_2.sample()) for _ in range(10)
    )
2021-09-02 22:15:34 +08:00
2022-09-03 22:56:29 +01:00
# Space classes found in TESTING_SPACES, de-duplicated in first-seen order.
SPACE_CLS = list(dict.fromkeys(map(type, TESTING_SPACES)))
# Constructor keyword arguments, one entry per class in SPACE_CLS (paired by position).
SPACE_KWARGS = [
    {"n": 3},  # Discrete
    {"low": 1, "high": 10},  # Box
    {"nvec": [3, 2]},  # MultiDiscrete
    {"n": 2},  # MultiBinary
    {"max_length": 5},  # Text
    {"spaces": (Discrete(3), Discrete(2))},  # Tuple
    {"spaces": {"a": Discrete(3), "b": Discrete(2)}},  # Dict
    {"node_space": Discrete(4), "edge_space": Discrete(3)},  # Graph
    {"space": Discrete(4)},  # Sequence
    {"spaces": (Discrete(3), Discrete(5))},  # OneOf
]
# Both lists are zipped together below, so they must stay the same length.
assert len(SPACE_CLS) == len(SPACE_KWARGS)
2022-01-11 04:45:41 +00:00
@pytest.mark.parametrize(
    "space_cls,kwarg",
    list(zip(SPACE_CLS, SPACE_KWARGS)),
    ids=[f"{cls}" for cls in SPACE_CLS],
)
def test_seed_np_random(space_cls, kwarg):
    """Check that a generator passed as `seed` at construction becomes the space's `np_random`.

    The space must hold the very same generator object, not a copy of it.
    """
    rng, _ = seeding.np_random(123)
    seeded_space = space_cls(seed=rng, **kwarg)

    assert seeded_space.np_random is rng
2022-03-02 11:14:59 -05:00
2022-09-03 22:56:29 +01:00
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_sample_contains(space):
    """Check that samples drawn from a space are contained in that space.

    Afterwards, for every testing space, check that `contains` returns a `bool`
    without raising. The result is only type-checked because a sample of one space
    may or may not be contained in another.
    """
    for _ in range(10):
        own_sample = space.sample()
        # Exercise both the `in` operator and the explicit method
        assert own_sample in space
        assert space.contains(own_sample)

    for other_space in TESTING_SPACES:
        sample = other_space.sample()
        # NOTE(review): this queries `other_space` with its own sample rather than
        # `space`; the docstring suggests `space.contains(sample)` may have been
        # intended - confirm before changing.
        space_contains = other_space.contains(sample)
        assert isinstance(
            space_contains, bool
        ), f"{space_contains}, {type(space_contains)}, {space}, {other_space}, {sample}"
2022-05-31 23:53:13 -04:00
2022-09-03 22:56:29 +01:00
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_repr(space):
    """Check that converting a space with `str` succeeds and yields a string."""
    text = str(space)
    assert isinstance(text, str)
2022-05-31 23:53:13 -04:00
2022-09-03 22:56:29 +01:00
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_space_pickling(space):
    """Check that a pickled and restored space equals the original and samples identically."""
    space.seed(0)

    # Round trip through an in-memory bytes object
    unpickled_space = pickle.loads(pickle.dumps(space))
    assert space == unpickled_space

    # Round trip through a real (temporary) file
    with tempfile.TemporaryFile() as pickle_file:
        pickle.dump(space, pickle_file)
        pickle_file.seek(0)  # rewind so that the load starts at the beginning
        file_unpickled_space = pickle.load(pickle_file)
    assert space == file_unpickled_space

    # The space was seeded before pickling, so the original and both restored
    # copies are expected to produce the same next sample.
    space_sample = space.sample()
    unpickled_sample = unpickled_space.sample()
    file_unpickled_sample = file_unpickled_space.sample()

    assert data_equivalence(space_sample, unpickled_sample)
    assert data_equivalence(space_sample, file_unpickled_sample)
2024-04-28 16:10:35 +01:00
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("initial_seed", [None, 123])
def test_space_seeding_output(space, initial_seed, num_samples=5):
    """Check that re-seeding with the value returned by `seed` replays the same samples."""
    # First pass: seed (possibly with None) and record what `seed` reports
    seeding_values = space.seed(initial_seed)
    samples = [space.sample() for _ in range(num_samples)]

    # Second pass: feed the reported value back in and sample again
    reseeded_values = space.seed(seeding_values)
    resamples = [space.sample() for _ in range(num_samples)]

    assert data_equivalence(seeding_values, reseeded_values)
    assert data_equivalence(samples, resamples)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_invalid_space_seed(space):
    """Check that seeding a space with a string is rejected with an error."""
    accepted_errors = (ValueError, TypeError, Error)
    with pytest.raises(accepted_errors):
        space.seed("abc")