Files
Gymnasium/tests/experimental/vector/utils/test_space_utils.py

157 lines
6.4 KiB
Python

"""Testing `gymnasium.experimental.vector.utils.space_utils` functions."""
import copy
import re
from typing import Iterable
import pytest
from gymnasium import Space
from gymnasium.error import CustomSpaceError
from gymnasium.experimental.vector.utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
from gymnasium.spaces import Tuple
from gymnasium.utils.env_checker import data_equivalence
from tests.experimental.vector.utils.utils import is_rng_equal
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 4], ids=[f"n={n}" for n in [1, 4]])
def test_batch_space_concatenate_iterate_create_empty_array(space: Space, n: int):
    """Test all space_utils functions using them together.

    For a ``space`` and batch size ``n`` this checks that:

    * ``batch_space`` returns a `Space` whose samples belong to it;
    * ``iterate`` splits a batched sample into ``n`` members of ``space``;
    * ``create_empty_array`` produces an object that ``iterate`` splits into ``n`` items;
    * ``concatenate`` merges ``n`` samples of ``space`` into a member of the batched
      space, and iterating that result recovers the original samples in order.
    """
    batched_space = batch_space(space, n)
    assert isinstance(batched_space, Space)

    # A batched sample must belong to the batched space and split into `n` members of `space`.
    batched_sample = batched_space.sample()
    assert batched_sample in batched_space
    sample_iter = iterate(batched_space, batched_sample)
    assert isinstance(sample_iter, Iterable)
    split_sample = list(sample_iter)
    assert len(split_sample) == n
    assert all(element in space for element in split_sample)

    # Membership of the created array in `batched_space` (or of its items in `space`)
    # is intentionally not asserted; only that it splits into `n` items.
    empty_array = create_empty_array(space, n)
    assert len(list(iterate(batched_space, empty_array))) == n

    # Concatenate fresh samples using the created array. `concatenate` does not
    # necessarily return that same object, so only membership of the result is checked.
    originals = [space.sample() for _ in range(n)]
    assert all(element in space for element in originals)
    combined = concatenate(space, originals, empty_array)
    assert combined in batched_space

    # Splitting the concatenated result must give back the original samples, in order.
    combined_iter = iterate(batched_space, combined)
    assert isinstance(combined_iter, Iterable)
    recovered = list(combined_iter)
    assert len(recovered) == n
    for recovered_item, original_item in zip(recovered, originals):
        assert data_equivalence(recovered_item, original_item)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]])
@pytest.mark.parametrize(
    "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_batch_space_deterministic(space: Space, n: int, base_seed: int):
    """Tests the batched spaces are deterministic by using a copied version.

    A seeded space is deep-copied and both copies are batched. The two batched
    spaces must start from equal but distinct rng objects, must not share the
    original space's rng, and must therefore produce equivalent samples.
    """
    # Copy the space and check the rngs have equal state but are not the same object.
    space_a = space
    space_a.seed(base_seed)
    space_b = copy.deepcopy(space_a)
    # `is_rng_equal` is a predicate (presumably returning a bool); without `assert`
    # its result would be discarded and the check would never run.
    assert is_rng_equal(space_a.np_random, space_b.np_random)
    assert space_a.np_random is not space_b.np_random

    # Batch the spaces and check the same property on the batched rngs.
    space_a_batched = batch_space(space_a, n)
    space_b_batched = batch_space(space_b, n)
    assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
    assert space_a_batched.np_random is not space_b_batched.np_random
    # Check that the batched space does not share the origin space's rng object.
    assert space_a.np_random is not space_a_batched.np_random

    # Advance only the original space's rng. If batched space a shared it,
    # its samples would diverge from batched space b's.
    space_a.sample()
    space_a_batched_sample = space_a_batched.sample()
    space_b_batched_sample = space_b_batched.sample()
    for a_sample, b_sample in zip(
        iterate(space_a_batched, space_a_batched_sample),
        iterate(space_b_batched, space_b_batched_sample),
    ):
        assert data_equivalence(a_sample, b_sample)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]])
@pytest.mark.parametrize(
    "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_batch_space_different_samples(space: Space, n: int, base_seed: int):
    """Tests that the values sampled at each batch index differ.

    This catches the case where one rng is copied into every sub-space, which
    would make every index of a batched sample identical.
    """
    space.seed(base_seed)
    batched_space = batch_space(space, n)
    # The batched space gets its own rng object, starting from the original's state.
    assert space.np_random is not batched_space.np_random
    # `is_rng_equal` is a predicate (presumably returning a bool); it must be asserted
    # or the check is silently skipped.
    assert is_rng_equal(space.np_random, batched_space.np_random)

    batched_sample = batched_space.sample()
    unbatched_samples = list(iterate(batched_space, batched_sample))
    assert len(unbatched_samples) == n
    assert all(item in space for item in unbatched_samples)
    # At least one index must differ from the first; on failure, show the samples.
    assert not all(
        data_equivalence(element, unbatched_samples[0]) for element in unbatched_samples
    ), unbatched_samples
@pytest.mark.parametrize(
    "func, n_args",
    [(batch_space, 1), (concatenate, 2), (iterate, 1), (create_empty_array, 2)],
)
def test_non_space(func, n_args):
    """Check each vector utility raises `TypeError` when given a non-Space.

    A plain string is passed as the space, and every remaining positional
    argument is filled with ``None``.
    """
    func_name = func.__name__
    expected_message = re.escape(
        f"The space provided to `{func_name}` is not a gymnasium Space instance, type: <class 'str'>, space"
    )
    filler = [None] * n_args

    with pytest.raises(TypeError, match=expected_message):
        func("space", *filler)
def test_custom_space():
    """Check how the space utilities handle an unregistered custom space.

    With a ``CustomSpace``:

    * ``batch_space`` falls back to a `Tuple` of repeated sub-spaces;
    * ``iterate`` raises `CustomSpaceError`;
    * ``concatenate`` returns the items as a tuple;
    * ``create_empty_array`` returns ``None``.
    """
    custom_space = CustomSpace()

    assert batch_space(custom_space, n=2) == Tuple([custom_space] * 2)

    iterate_error = re.escape(
        "Space of type `<class 'tests.spaces.utils.CustomSpace'>` doesn't have an registered `iterate` function. Register `<class 'tests.spaces.utils.CustomSpace'>` for `iterate` to support it."
    )
    with pytest.raises(CustomSpaceError, match=iterate_error):
        iterate(custom_space, None)

    assert concatenate(custom_space, (None, None), out=None) == (None, None)
    assert create_empty_array(custom_space) is None