mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Merge v1.0.0 (#682)
Co-authored-by: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com> Co-authored-by: Jet <38184875+jjshoots@users.noreply.github.com> Co-authored-by: Omar Younis <42100908+younik@users.noreply.github.com>
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
"""Test the `SyncVectorEnv` implementation."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
|
||||
from gymnasium.vector.sync_vector_env import SyncVectorEnv
|
||||
from gymnasium.vector import SyncVectorEnv
|
||||
from tests.envs.utils import all_testing_env_specs
|
||||
from tests.vector.utils import (
|
||||
from tests.vector.testing_utils import (
|
||||
CustomSpace,
|
||||
assert_rng_equal,
|
||||
make_custom_space_env,
|
||||
@@ -14,6 +16,7 @@ from tests.vector.utils import (
|
||||
|
||||
|
||||
def test_create_sync_vector_env():
|
||||
"""Tests creating the sync vector environment."""
|
||||
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
||||
env = SyncVectorEnv(env_fns)
|
||||
env.close()
|
||||
@@ -22,6 +25,7 @@ def test_create_sync_vector_env():
|
||||
|
||||
|
||||
def test_reset_sync_vector_env():
|
||||
"""Tests sync vector `reset` function."""
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations, infos = env.reset()
|
||||
@@ -29,7 +33,6 @@ def test_reset_sync_vector_env():
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert isinstance(infos, dict)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
@@ -39,10 +42,9 @@ def test_reset_sync_vector_env():
|
||||
|
||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||
def test_step_sync_vector_env(use_single_action_space):
|
||||
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset()
|
||||
"""Test sync vector `steps` function."""
|
||||
env = SyncVectorEnv([make_env("FrozenLake-v1", i) for i in range(8)])
|
||||
env.reset()
|
||||
|
||||
assert isinstance(env.single_action_space, Discrete)
|
||||
assert isinstance(env.action_space, MultiDiscrete)
|
||||
@@ -51,7 +53,7 @@ def test_step_sync_vector_env(use_single_action_space):
|
||||
actions = [env.single_action_space.sample() for _ in range(8)]
|
||||
else:
|
||||
actions = env.action_space.sample()
|
||||
observations, rewards, terminateds, truncateds, _ = env.step(actions)
|
||||
observations, rewards, terminations, truncations, _ = env.step(actions)
|
||||
|
||||
env.close()
|
||||
|
||||
@@ -66,18 +68,35 @@ def test_step_sync_vector_env(use_single_action_space):
|
||||
assert rewards.ndim == 1
|
||||
assert rewards.size == 8
|
||||
|
||||
assert isinstance(terminateds, np.ndarray)
|
||||
assert terminateds.dtype == np.bool_
|
||||
assert terminateds.ndim == 1
|
||||
assert terminateds.size == 8
|
||||
assert isinstance(terminations, np.ndarray)
|
||||
assert terminations.dtype == np.bool_
|
||||
assert terminations.ndim == 1
|
||||
assert terminations.size == 8
|
||||
|
||||
assert isinstance(truncateds, np.ndarray)
|
||||
assert truncateds.dtype == np.bool_
|
||||
assert truncateds.ndim == 1
|
||||
assert truncateds.size == 8
|
||||
assert isinstance(truncations, np.ndarray)
|
||||
assert truncations.dtype == np.bool_
|
||||
assert truncations.ndim == 1
|
||||
assert truncations.size == 8
|
||||
|
||||
|
||||
def test_render_sync_vector():
|
||||
envs = SyncVectorEnv(
|
||||
[make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(3)]
|
||||
)
|
||||
assert envs.render_mode == "rgb_array"
|
||||
|
||||
envs.reset()
|
||||
rendered_frames = envs.render()
|
||||
assert isinstance(rendered_frames, tuple)
|
||||
assert len(rendered_frames) == envs.num_envs
|
||||
assert all(isinstance(frame, np.ndarray) for frame in rendered_frames)
|
||||
|
||||
envs = SyncVectorEnv([make_env("CartPole-v1", i) for i in range(3)])
|
||||
assert envs.render_mode is None
|
||||
|
||||
|
||||
def test_call_sync_vector_env():
|
||||
"""Test sync vector `call` on sub-environments."""
|
||||
env_fns = [
|
||||
make_env("CartPole-v1", i, render_mode="rgb_array_list") for i in range(4)
|
||||
]
|
||||
@@ -103,6 +122,7 @@ def test_call_sync_vector_env():
|
||||
|
||||
|
||||
def test_set_attr_sync_vector_env():
|
||||
"""Test sync vector `set_attr` function."""
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
@@ -114,6 +134,7 @@ def test_set_attr_sync_vector_env():
|
||||
|
||||
|
||||
def test_check_spaces_sync_vector_env():
|
||||
"""Tests the sync vector `check_spaces` function."""
|
||||
# CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2)
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
# FrozenLake-v1 - Discrete(16), action_space: Discrete(4)
|
||||
@@ -124,6 +145,7 @@ def test_check_spaces_sync_vector_env():
|
||||
|
||||
|
||||
def test_custom_space_sync_vector_env():
|
||||
"""Test the use of custom spaces with sync vector environment."""
|
||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
@@ -131,18 +153,15 @@ def test_custom_space_sync_vector_env():
|
||||
|
||||
assert isinstance(env.single_action_space, CustomSpace)
|
||||
assert isinstance(env.action_space, Tuple)
|
||||
assert isinstance(infos, dict)
|
||||
|
||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||
step_observations, rewards, terminateds, truncateds, infos = env.step(actions)
|
||||
step_observations, _, _, _, _ = env.step(actions)
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.single_observation_space, CustomSpace)
|
||||
assert isinstance(env.observation_space, Tuple)
|
||||
|
||||
assert isinstance(infos, dict)
|
||||
|
||||
assert isinstance(reset_observations, tuple)
|
||||
assert reset_observations == ("reset", "reset", "reset", "reset")
|
||||
|
||||
@@ -156,6 +175,7 @@ def test_custom_space_sync_vector_env():
|
||||
|
||||
|
||||
def test_sync_vector_env_seed():
|
||||
"""Test seeding for sync vector environments."""
|
||||
env = make_env("BipedalWalker-v3", seed=123)()
|
||||
sync_vector_env = SyncVectorEnv([make_env("BipedalWalker-v3", seed=123)])
|
||||
|
||||
@@ -165,12 +185,14 @@ def test_sync_vector_env_seed():
|
||||
vector_action = sync_vector_env.action_space.sample()
|
||||
assert np.all(env_action == vector_action)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
|
||||
)
|
||||
def test_sync_vector_determinism(spec: EnvSpec, seed: int = 123, n: int = 3):
|
||||
"""Check that for all environments, the sync vector envs produce the same action samples using the same seeds"""
|
||||
"""Check that for all environments, the sync vector envs produce the same action samples using the same seeds."""
|
||||
env_1 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)])
|
||||
env_2 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)])
|
||||
assert_rng_equal(env_1.action_space.np_random, env_2.action_space.np_random)
|
||||
@@ -179,3 +201,6 @@ def test_sync_vector_determinism(spec: EnvSpec, seed: int = 123, n: int = 3):
|
||||
env_1_samples = env_1.action_space.sample()
|
||||
env_2_samples = env_2.action_space.sample()
|
||||
assert np.all(env_1_samples == env_2_samples)
|
||||
|
||||
env_1.close()
|
||||
env_2.close()
|
||||
|
Reference in New Issue
Block a user