Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)

Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com>
Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
Mark Towers
2024-08-29 16:52:43 +01:00
committed by GitHub
parent ad8734d89b
commit 64fac8e80b
10 changed files with 584 additions and 55 deletions

View File

@@ -4,13 +4,20 @@ import copy
import re
from typing import Iterable
import numpy as np
import pytest
from gymnasium import Space
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import Tuple
from gymnasium.spaces import Box, Tuple
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
from gymnasium.vector.utils import (
batch_differing_spaces,
batch_space,
concatenate,
create_empty_array,
iterate,
)
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace
from tests.vector.utils.utils import is_rng_equal
@@ -68,13 +75,13 @@ def test_batch_space_deterministic(space: Space, n: int, base_seed: int):
space_a = space
space_a.seed(base_seed)
space_b = copy.deepcopy(space_a)
is_rng_equal(space_a.np_random, space_b.np_random)
assert is_rng_equal(space_a.np_random, space_b.np_random)
assert space_a.np_random is not space_b.np_random
# Batch the spaces and check that the np_random are not reference equal
space_a_batched = batch_space(space_a, n)
space_b_batched = batch_space(space_b, n)
is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
assert space_a_batched.np_random is not space_b_batched.np_random
# Create that the batched space is not reference equal to the origin spaces
assert space_a.np_random is not space_a_batched.np_random
@@ -101,7 +108,7 @@ def test_batch_space_different_samples(space: Space, n: int, base_seed: int):
batched_space = batch_space(space, n)
assert space.np_random is not batched_space.np_random
is_rng_equal(space.np_random, batched_space.np_random)
assert is_rng_equal(space.np_random, batched_space.np_random)
batched_sample = batched_space.sample()
unbatched_samples = list(iterate(batched_space, batched_sample))
@@ -149,3 +156,68 @@ def test_custom_space():
empty_array = create_empty_array(custom_space)
assert empty_array is None
@pytest.mark.parametrize(
"spaces,expected_space",
[
(
(
Box(low=0, high=1, shape=(2,), dtype=np.float32),
Box(low=2, high=np.array([3, 5], dtype=np.float32)),
),
Box(low=np.array([[0, 0], [2, 2]]), high=np.array([[1, 1], [3, 5]])),
),
],
)
def test_varying_spaces(spaces: "list[Space]", expected_space):
"""Test the batch spaces function."""
batched_space = batch_differing_spaces(spaces)
assert batched_space == expected_space
batch_samples = batched_space.sample()
for sub_space, sub_sample in zip(spaces, iterate(batched_space, batch_samples)):
assert sub_sample in sub_space
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 3])
def test_batch_differing_space_vs_batch_space(space, n):
"""Test the batch_spaces and batch_space functions."""
batched_space = batch_space(space, n)
batched_spaces = batch_differing_spaces([copy.deepcopy(space) for _ in range(n)])
assert batched_space == batched_spaces, f"{batched_space=}, {batched_spaces=}"
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]])
@pytest.mark.parametrize(
"base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_batch_differing_spaces_deterministic(space: Space, n: int, base_seed: int):
"""Tests the batched spaces are deterministic by using a copied version."""
# Copy the spaces and check that the np_random are not reference equal
space_a = space
space_a.seed(base_seed)
space_b = copy.deepcopy(space_a)
assert is_rng_equal(space_a.np_random, space_b.np_random)
assert space_a.np_random is not space_b.np_random
# Batch the spaces and check that the np_random are not reference equal
space_a_batched = batch_differing_spaces([space_a for _ in range(n)])
space_b_batched = batch_differing_spaces([space_b for _ in range(n)])
assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
assert space_a_batched.np_random is not space_b_batched.np_random
# Create that the batched space is not reference equal to the origin spaces
assert space_a.np_random is not space_a_batched.np_random
# Check that batched space a and b random number generator are not effected by the original space
space_a.sample()
space_a_batched_sample = space_a_batched.sample()
space_b_batched_sample = space_b_batched.sample()
for a_sample, b_sample in zip(
iterate(space_a_batched, space_a_batched_sample),
iterate(space_b_batched, space_b_batched_sample),
):
assert data_equivalence(a_sample, b_sample)