"""Testing `gymnasium.experimental.vector.utils.space_utils` functions."""

import copy
import re
from typing import Iterable

import pytest

from gymnasium import Space
from gymnasium.error import CustomSpaceError
from gymnasium.experimental.vector.utils import (
    batch_space,
    concatenate,
    create_empty_array,
    iterate,
)
from gymnasium.spaces import Tuple
from gymnasium.utils.env_checker import data_equivalence
from tests.experimental.vector.utils.utils import is_rng_equal
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace


@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 4], ids=[f"n={n}" for n in [1, 4]])
def test_batch_space_concatenate_iterate_create_empty_array(space: Space, n: int):
    """Test all space_utils functions using them together.

    Batches ``space`` ``n`` times, checks that ``iterate`` splits a batched sample
    back into ``n`` members of ``space``, and checks that ``concatenate`` followed
    by ``iterate`` round-trips ``n`` independent samples.
    """
    # Batch the space and create a sample
    batched_space = batch_space(space, n)
    assert isinstance(batched_space, Space)
    batched_sample = batched_space.sample()
    assert batched_sample in batched_space

    # Check the batched samples are within the original space
    iterated_samples = iterate(batched_space, batched_sample)
    assert isinstance(iterated_samples, Iterable)
    unbatched_samples = list(iterated_samples)
    assert len(unbatched_samples) == n
    assert all(item in space for item in unbatched_samples)

    # Create an empty array and check that it iterates into `n` items.
    array = create_empty_array(space, n)
    # We deliberately do NOT check that the generated array is within the
    # batched_space (or that its items are within `space`).
    # assert array in batched_space
    unbatched_array = list(iterate(batched_space, array))
    assert len(unbatched_array) == n
    # assert all(item in space for item in unbatched_array)

    # Generate samples from the original space and concatenate them, using
    # `array` as the output buffer, into a single batched object.
    space_samples = [space.sample() for _ in range(n)]
    assert all(item in space for item in space_samples)
    concatenated_samples_array = concatenate(space, space_samples, array)
    # `concatenate` does not necessarily return the `out` buffer it was given
    # assert array is concatenated_samples_array
    assert concatenated_samples_array in batched_space

    # Iterate over the samples and check that the concatenated samples == original samples
    iterated_samples = iterate(batched_space, concatenated_samples_array)
    assert isinstance(iterated_samples, Iterable)
    unbatched_samples = list(iterated_samples)
    assert len(unbatched_samples) == n
    for unbatched_sample, original_sample in zip(unbatched_samples, space_samples):
        assert data_equivalence(unbatched_sample, original_sample)


@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]])
@pytest.mark.parametrize(
    "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_batch_space_deterministic(space: Space, n: int, base_seed: int):
    """Tests the batched spaces are deterministic by using a copied version."""
    # Copy the spaces and check that the np_random are not reference equal.
    # NOTE(review): `space_a` is the shared parametrized instance (no copy), so
    # seeding it here also reseeds the object other tests receive.
    space_a = space
    space_a.seed(base_seed)
    space_b = copy.deepcopy(space_a)
    # NOTE(review): the result of `is_rng_equal` is discarded here and below. If
    # it returns a bool rather than asserting internally, these calls check
    # nothing; confirm against `tests/experimental/vector/utils/utils.py`.
    is_rng_equal(space_a.np_random, space_b.np_random)
    assert space_a.np_random is not space_b.np_random

    # Batch the spaces and check that the np_random are not reference equal
    space_a_batched = batch_space(space_a, n)
    space_b_batched = batch_space(space_b, n)
    is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
    assert space_a_batched.np_random is not space_b_batched.np_random
    # Check that the batched space is not reference equal to the original spaces
    assert space_a.np_random is not space_a_batched.np_random

    # Check that the random number generators of batched spaces a and b are not
    # affected by sampling from the original space
    space_a.sample()
    space_a_batched_sample = space_a_batched.sample()
    space_b_batched_sample = space_b_batched.sample()
    for a_sample, b_sample in zip(
        iterate(space_a_batched, space_a_batched_sample),
        iterate(space_b_batched, space_b_batched_sample),
    ):
        assert data_equivalence(a_sample, b_sample)


@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]])
@pytest.mark.parametrize(
    "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_batch_space_different_samples(space: Space, n: int, base_seed: int):
    """Tests that each index of a batched sample gets a different value.

    This catches a batched space whose sub-spaces all share a copy of the same
    rng, which would make every element of a batched sample identical.
    """
    space.seed(base_seed)

    batched_space = batch_space(space, n)
    assert space.np_random is not batched_space.np_random
    # NOTE(review): result discarded; see the note in test_batch_space_deterministic.
    is_rng_equal(space.np_random, batched_space.np_random)

    batched_sample = batched_space.sample()
    unbatched_samples = list(iterate(batched_space, batched_sample))
    assert len(unbatched_samples) == n
    assert all(item in space for item in unbatched_samples)
    assert not all(
        data_equivalence(element, unbatched_samples[0]) for element in unbatched_samples
    ), unbatched_samples


@pytest.mark.parametrize(
    "func, n_args",
    [(batch_space, 1), (concatenate, 2), (iterate, 1), (create_empty_array, 2)],
)
def test_non_space(func, n_args):
    """Test spaces for vector utility functions on the error produced with unknown spaces."""
    args = [None for _ in range(n_args)]
    func_name = func.__name__
    # The expected message embeds the repr of the offending argument's type
    # (`<class 'str'>` for the "space" string passed below); build it from the
    # type itself rather than hard-coding the repr.
    with pytest.raises(
        TypeError,
        match=re.escape(
            f"The space provided to `{func_name}` is not a gymnasium Space instance, type: {str}, space"
        ),
    ):
        func("space", *args)


def test_custom_space():
    """Test custom spaces with space util functions."""
    custom_space = CustomSpace()

    batched_space = batch_space(custom_space, n=2)
    assert batched_space == Tuple([custom_space, custom_space])

    # The expected message embeds the repr of the custom space's class.
    custom_space_type = type(custom_space)
    with pytest.raises(
        CustomSpaceError,
        match=re.escape(
            f"Space of type `{custom_space_type}` doesn't have an registered `iterate` function. Register `{custom_space_type}` for `iterate` to support it."
        ),
    ):
        iterate(custom_space, None)

    concatenated_items = concatenate(custom_space, (None, None), out=None)
    assert concatenated_items == (None, None)

    empty_array = create_empty_array(custom_space)
    assert empty_array is None