mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)
Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com> Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
@@ -4,13 +4,20 @@ import copy
|
||||
import re
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium import Space
|
||||
from gymnasium.error import CustomSpaceError
|
||||
from gymnasium.spaces import Tuple
|
||||
from gymnasium.spaces import Box, Tuple
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||
from gymnasium.vector.utils import (
|
||||
batch_differing_spaces,
|
||||
batch_space,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
iterate,
|
||||
)
|
||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace
|
||||
from tests.vector.utils.utils import is_rng_equal
|
||||
|
||||
@@ -68,13 +75,13 @@ def test_batch_space_deterministic(space: Space, n: int, base_seed: int):
|
||||
space_a = space
|
||||
space_a.seed(base_seed)
|
||||
space_b = copy.deepcopy(space_a)
|
||||
is_rng_equal(space_a.np_random, space_b.np_random)
|
||||
assert is_rng_equal(space_a.np_random, space_b.np_random)
|
||||
assert space_a.np_random is not space_b.np_random
|
||||
|
||||
# Batch the spaces and check that the np_random are not reference equal
|
||||
space_a_batched = batch_space(space_a, n)
|
||||
space_b_batched = batch_space(space_b, n)
|
||||
is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
||||
assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
||||
assert space_a_batched.np_random is not space_b_batched.np_random
|
||||
# Create that the batched space is not reference equal to the origin spaces
|
||||
assert space_a.np_random is not space_a_batched.np_random
|
||||
@@ -101,7 +108,7 @@ def test_batch_space_different_samples(space: Space, n: int, base_seed: int):
|
||||
|
||||
batched_space = batch_space(space, n)
|
||||
assert space.np_random is not batched_space.np_random
|
||||
is_rng_equal(space.np_random, batched_space.np_random)
|
||||
assert is_rng_equal(space.np_random, batched_space.np_random)
|
||||
|
||||
batched_sample = batched_space.sample()
|
||||
unbatched_samples = list(iterate(batched_space, batched_sample))
|
||||
@@ -149,3 +156,68 @@ def test_custom_space():
|
||||
|
||||
empty_array = create_empty_array(custom_space)
|
||||
assert empty_array is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spaces,expected_space",
|
||||
[
|
||||
(
|
||||
(
|
||||
Box(low=0, high=1, shape=(2,), dtype=np.float32),
|
||||
Box(low=2, high=np.array([3, 5], dtype=np.float32)),
|
||||
),
|
||||
Box(low=np.array([[0, 0], [2, 2]]), high=np.array([[1, 1], [3, 5]])),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_varying_spaces(spaces: "list[Space]", expected_space):
|
||||
"""Test the batch spaces function."""
|
||||
batched_space = batch_differing_spaces(spaces)
|
||||
assert batched_space == expected_space
|
||||
|
||||
batch_samples = batched_space.sample()
|
||||
for sub_space, sub_sample in zip(spaces, iterate(batched_space, batch_samples)):
|
||||
assert sub_sample in sub_space
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
@pytest.mark.parametrize("n", [1, 3])
|
||||
def test_batch_differing_space_vs_batch_space(space, n):
|
||||
"""Test the batch_spaces and batch_space functions."""
|
||||
batched_space = batch_space(space, n)
|
||||
batched_spaces = batch_differing_spaces([copy.deepcopy(space) for _ in range(n)])
|
||||
|
||||
assert batched_space == batched_spaces, f"{batched_space=}, {batched_spaces=}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]])
|
||||
@pytest.mark.parametrize(
|
||||
"base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
|
||||
)
|
||||
def test_batch_differing_spaces_deterministic(space: Space, n: int, base_seed: int):
|
||||
"""Tests the batched spaces are deterministic by using a copied version."""
|
||||
# Copy the spaces and check that the np_random are not reference equal
|
||||
space_a = space
|
||||
space_a.seed(base_seed)
|
||||
space_b = copy.deepcopy(space_a)
|
||||
assert is_rng_equal(space_a.np_random, space_b.np_random)
|
||||
assert space_a.np_random is not space_b.np_random
|
||||
|
||||
# Batch the spaces and check that the np_random are not reference equal
|
||||
space_a_batched = batch_differing_spaces([space_a for _ in range(n)])
|
||||
space_b_batched = batch_differing_spaces([space_b for _ in range(n)])
|
||||
assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
||||
assert space_a_batched.np_random is not space_b_batched.np_random
|
||||
# Create that the batched space is not reference equal to the origin spaces
|
||||
assert space_a.np_random is not space_a_batched.np_random
|
||||
|
||||
# Check that batched space a and b random number generator are not effected by the original space
|
||||
space_a.sample()
|
||||
space_a_batched_sample = space_a_batched.sample()
|
||||
space_b_batched_sample = space_b_batched.sample()
|
||||
for a_sample, b_sample in zip(
|
||||
iterate(space_a_batched, space_a_batched_sample),
|
||||
iterate(space_b_batched, space_b_batched_sample),
|
||||
):
|
||||
assert data_equivalence(a_sample, b_sample)
|
||||
|
Reference in New Issue
Block a user