2023-02-18 21:08:08 +00:00
""" Testing `gymnasium.experimental.vector.utils.space_utils` functions. """
import copy
2023-02-22 15:05:58 +00:00
import re
from typing import Iterable
2023-02-18 21:08:08 +00:00
import pytest
2023-02-22 15:05:58 +00:00
from gymnasium import Space
from gymnasium . error import CustomSpaceError
2023-02-18 21:08:08 +00:00
from gymnasium . experimental . vector . utils import (
batch_space ,
concatenate ,
create_empty_array ,
iterate ,
)
2023-02-22 15:05:58 +00:00
from gymnasium . spaces import Tuple
from gymnasium . utils . env_checker import data_equivalence
from tests . experimental . vector . utils . utils import is_rng_equal
from tests . spaces . utils import TESTING_SPACES , TESTING_SPACES_IDS , CustomSpace
2023-02-18 21:08:08 +00:00
2023-02-22 15:05:58 +00:00
@pytest.mark.parametrize ( " space " , TESTING_SPACES , ids = TESTING_SPACES_IDS )
@pytest.mark.parametrize ( " n " , [ 1 , 4 ] , ids = [ f " n= { n } " for n in [ 1 , 4 ] ] )
def test_batch_space_concatenate_iterate_create_empty_array ( space : Space , n : int ) :
""" Test all space_utils functions using them together. """
# Batch the space and create a sample
2023-02-18 21:08:08 +00:00
batched_space = batch_space ( space , n )
2023-02-22 15:05:58 +00:00
assert isinstance ( batched_space , Space )
2023-02-18 21:08:08 +00:00
batched_sample = batched_space . sample ( )
2023-02-22 15:05:58 +00:00
assert batched_sample in batched_space
# Check the batched samples are within the original space
iterated_samples = iterate ( batched_space , batched_sample )
assert isinstance ( iterated_samples , Iterable )
unbatched_samples = list ( iterated_samples )
assert len ( unbatched_samples ) == n
assert all ( item in space for item in unbatched_samples )
# Create an empty array and check that space is within the batch space
array = create_empty_array ( space , n )
# We do not check that the generated array is within the batched_space.
# assert array in batched_space
unbatched_array = list ( iterate ( batched_space , array ) )
assert len ( unbatched_array ) == n
# assert all(item in space for item in unbatched_array)
# Generate samples from the original space and concatenate using array into a single object
space_samples = [ space . sample ( ) for _ in range ( n ) ]
assert all ( item in space for item in space_samples )
concatenated_samples_array = concatenate ( space , space_samples , array )
# `concatenate` does not necessarily use the out object as the returned object
# assert out is concatenated_samples_array
assert concatenated_samples_array in batched_space
# Iterate over the samples and check that the concatenated samples == original samples
iterated_samples = iterate ( batched_space , concatenated_samples_array )
assert isinstance ( iterated_samples , Iterable )
unbatched_samples = list ( iterated_samples )
assert len ( unbatched_samples ) == n
for unbatched_sample , original_sample in zip ( unbatched_samples , space_samples ) :
assert data_equivalence ( unbatched_sample , original_sample )
@pytest.mark.parametrize ( " space " , TESTING_SPACES , ids = TESTING_SPACES_IDS )
2023-02-18 21:08:08 +00:00
@pytest.mark.parametrize ( " n " , [ 1 , 2 , 5 ] , ids = [ f " n= { n } " for n in [ 1 , 2 , 5 ] ] )
@pytest.mark.parametrize (
" base_seed " , [ 123 , 456 ] , ids = [ f " seed= { base_seed } " for base_seed in [ 123 , 456 ] ]
)
2023-02-22 15:05:58 +00:00
def test_batch_space_deterministic ( space : Space , n : int , base_seed : int ) :
2023-02-18 21:08:08 +00:00
""" Tests the batched spaces are deterministic by using a copied version. """
# Copy the spaces and check that the np_random are not reference equal
space_a = space
space_a . seed ( base_seed )
space_b = copy . deepcopy ( space_a )
2023-02-22 15:05:58 +00:00
is_rng_equal ( space_a . np_random , space_b . np_random )
2023-02-18 21:08:08 +00:00
assert space_a . np_random is not space_b . np_random
# Batch the spaces and check that the np_random are not reference equal
space_a_batched = batch_space ( space_a , n )
space_b_batched = batch_space ( space_b , n )
2023-02-22 15:05:58 +00:00
is_rng_equal ( space_a_batched . np_random , space_b_batched . np_random )
2023-02-18 21:08:08 +00:00
assert space_a_batched . np_random is not space_b_batched . np_random
# Create that the batched space is not reference equal to the origin spaces
assert space_a . np_random is not space_a_batched . np_random
# Check that batched space a and b random number generator are not effected by the original space
space_a . sample ( )
space_a_batched_sample = space_a_batched . sample ( )
space_b_batched_sample = space_b_batched . sample ( )
for a_sample , b_sample in zip (
iterate ( space_a_batched , space_a_batched_sample ) ,
iterate ( space_b_batched , space_b_batched_sample ) ,
) :
2023-02-22 15:05:58 +00:00
assert data_equivalence ( a_sample , b_sample )
@pytest.mark.parametrize ( " space " , TESTING_SPACES , ids = TESTING_SPACES_IDS )
@pytest.mark.parametrize ( " n " , [ 4 , 5 ] , ids = [ f " n= { n } " for n in [ 4 , 5 ] ] )
@pytest.mark.parametrize (
" base_seed " , [ 123 , 456 ] , ids = [ f " seed= { base_seed } " for base_seed in [ 123 , 456 ] ]
)
def test_batch_space_different_samples ( space : Space , n : int , base_seed : int ) :
""" Tests that the rng values produced at each index are different to prevent if the rng is copied for each subspace. """
space . seed ( base_seed )
batched_space = batch_space ( space , n )
assert space . np_random is not batched_space . np_random
is_rng_equal ( space . np_random , batched_space . np_random )
batched_sample = batched_space . sample ( )
unbatched_samples = list ( iterate ( batched_space , batched_sample ) )
assert len ( unbatched_samples ) == n
assert all ( item in space for item in unbatched_samples )
assert not all (
data_equivalence ( element , unbatched_samples [ 0 ] ) for element in unbatched_samples
) , unbatched_samples
@pytest.mark.parametrize (
" func, n_args " ,
[ ( batch_space , 1 ) , ( concatenate , 2 ) , ( iterate , 1 ) , ( create_empty_array , 2 ) ] ,
)
def test_non_space ( func , n_args ) :
""" Test spaces for vector utility functions on the error produced with unknown spaces. """
args = [ None for _ in range ( n_args ) ]
func_name = func . __name__
with pytest . raises (
TypeError ,
match = re . escape (
f " The space provided to ` { func_name } ` is not a gymnasium Space instance, type: <class ' str ' >, space "
) ,
) :
func ( " space " , * args )
def test_custom_space ( ) :
""" Test custom spaces with space util functions. """
custom_space = CustomSpace ( )
batched_space = batch_space ( custom_space , n = 2 )
assert batched_space == Tuple ( [ custom_space , custom_space ] )
with pytest . raises (
CustomSpaceError ,
match = re . escape (
" Space of type `<class ' tests.spaces.utils.CustomSpace ' >` doesn ' t have an registered `iterate` function. Register `<class ' tests.spaces.utils.CustomSpace ' >` for `iterate` to support it. "
) ,
) :
iterate ( custom_space , None )
concatenated_items = concatenate ( custom_space , ( None , None ) , out = None )
assert concatenated_items == ( None , None )
empty_array = create_empty_array ( custom_space )
assert empty_array is None