# Mirror of https://github.com/Farama-Foundation/Gymnasium.git
# Synced 2025-08-15 19:31:27 +00:00
# 55 lines, 1.9 KiB, Python
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
import gym
|
||
|
from gym.vector.sync_vector_env import SyncVectorEnv
|
||
|
from tests.vector.utils import make_env
|
||
|
|
||
|
# Shared configuration for the vector-env info tests in this module.
ENV_ID = "CartPole-v1"  # environment id used to build every sub-environment
NUM_ENVS = 3  # number of sub-environments in each vector env
ENV_STEPS = 50  # upper bound on the number of steps each test takes
SEED = 42  # seed used for resets, action sampling and make_env
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("asynchronous", [True, False])
def test_vector_env_info(asynchronous):
    """Check the layout of the ``terminal_observation`` info of a vector env.

    Steps ``NUM_ENVS`` copies of ``ENV_ID`` for ``ENV_STEPS`` steps. On every
    step where at least one sub-environment reports done, asserts that
    ``infos["terminal_observation"]`` and ``infos["_terminal_observation"]``
    are numpy arrays with one entry per sub-environment and that their
    entries agree with ``dones``.

    Args:
        asynchronous: Forwarded to ``gym.vector.make``; the test runs once
            with ``True`` and once with ``False``.
    """
    env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, asynchronous=asynchronous)
    try:
        env.reset(seed=SEED)
        for _ in range(ENV_STEPS):
            # The action space is re-seeded before every sample, so each step
            # samples from a freshly seeded space.
            env.action_space.seed(SEED)
            action = env.action_space.sample()
            _, _, dones, infos = env.step(action)
            if any(dones):
                # Both the values and their mask span all sub-environments.
                assert len(infos["terminal_observation"]) == NUM_ENVS
                assert len(infos["_terminal_observation"]) == NUM_ENVS

                assert isinstance(infos["terminal_observation"], np.ndarray)
                assert isinstance(infos["_terminal_observation"], np.ndarray)

                for i, done in enumerate(dones):
                    if done:
                        assert infos["_terminal_observation"][i]
                    else:
                        # Sub-environments that are not done carry no
                        # terminal observation.
                        assert not infos["_terminal_observation"][i]
                        assert infos["terminal_observation"][i] is None
    finally:
        # Always release the vector env (and any workers it started), even
        # when an assertion above fails.
        env.close()
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
    """Check terminal-observation info when several envs end on the same step.

    The first ``concurrent_ends`` sub-environments are always given action
    ``0`` and the remaining ones action ``1``. On the first step where any
    sub-environment is done, asserts that each of the first
    ``concurrent_ends`` sub-environments is done and flagged in
    ``infos["_terminal_observation"]``, and that every other sub-environment
    is unflagged with a ``None`` terminal observation. The test stops after
    that first check.

    Args:
        concurrent_ends: Number of sub-environments that receive the shared
            action ``0`` (parametrized over 1, 2 and 3).
    """
    # envs that need to terminate together will have the same action
    actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
    envs = SyncVectorEnv([make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)])
    try:
        for _ in range(ENV_STEPS):
            _, _, dones, infos = envs.step(actions)
            if any(dones):
                for i, done in enumerate(dones):
                    if i < concurrent_ends:
                        assert done
                        assert infos["_terminal_observation"][i]
                    else:
                        assert not infos["_terminal_observation"][i]
                        assert infos["terminal_observation"][i] is None
                # Only the first step with a finished env is inspected.
                return
    finally:
        # Runs on the early return and on assertion failure alike, so the
        # sub-environments are always closed.
        envs.close()
|