mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
* WIP refactor info API sync vector. * Add missing untracked file. * Add info strategy to reset_wait. * Add interface and docstring. * info with strategy pattern on async vector env. * Add default to async vecenv. * episode statistics for async vector env. * Add tests info strategy format. * Add info strategy to reset_wait. * refactor and cleanup. * Code cleanup. Add tests. * Add tests for video recording with new info format. * fix test case. * fix camelcase. * rename enum. * update tests, docstrings, cleanup. * Changes brax strategy to numpy. add_strategy method in StrategyFactory. Add tests. * fix docstring and logging format. * Set Brax info format as default. Remove classic info format. Update tests. * breaking the wrong loop. * WIP: wrapper. * Add wrapper for brax to classic info. * WIP: wrapper with nested RecordEpisodeStatistics. * Add tests. Refactor docstrings. Cleanup. * cleanup. * patch conflicts. * rebase and conflicts. * new pre-commit conventions. * docstring. * renaming. * incorporate info_processor in vecEnv. * renaming. Create info dict only if needed. * remove all brax references. update docstring. Update duplicate test. * reviews. * pre-commit. * reviews. * docstring. * cleanup blank lines. * add support for numpy dtypes. * docstring fix. * formatting. * naming. * assert correct info from wrappers chaining. Test correct wrappers chaining. naming. * simplify episode_statistics. * change args order. * update tests. * wip: refactor episode_statistics. * Add test for add_vector_episode_statistics.
61 lines
2.3 KiB
Python
61 lines
2.3 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
from gym.spaces import Tuple
|
|
from gym.vector.async_vector_env import AsyncVectorEnv
|
|
from gym.vector.sync_vector_env import SyncVectorEnv
|
|
from gym.vector.vector_env import VectorEnv
|
|
from tests.vector.utils import CustomSpace, make_env
|
|
|
|
|
|
@pytest.mark.parametrize("shared_memory", [True, False])
def test_vector_env_equal(shared_memory):
    """Check that AsyncVectorEnv and SyncVectorEnv produce identical rollouts.

    Both vectorized environments are built from the same four CartPole-v1
    factories, reset with the same seed and stepped with the same sampled
    actions. Their spaces, observations, rewards, dones and terminal-observation
    infos are compared at every step.

    Args:
        shared_memory: Forwarded to ``AsyncVectorEnv`` so the test runs both
            with and without shared memory.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    num_steps = 100

    # Pre-bind both handles: if either constructor raises, the ``finally``
    # clause must only close what was actually created, instead of masking
    # the real error with a NameError on an unbound name.
    async_env = None
    sync_env = None
    try:
        async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
        sync_env = SyncVectorEnv(env_fns)

        assert async_env.num_envs == sync_env.num_envs
        assert async_env.observation_space == sync_env.observation_space
        assert async_env.single_observation_space == sync_env.single_observation_space
        assert async_env.action_space == sync_env.action_space
        assert async_env.single_action_space == sync_env.single_action_space

        async_observations = async_env.reset(seed=0)
        sync_observations = sync_env.reset(seed=0)
        # ``np.array_equal`` also checks shapes; ``np.all(a == b)`` could
        # silently broadcast mismatched arrays into a passing comparison.
        assert np.array_equal(async_observations, sync_observations)

        for _ in range(num_steps):
            actions = async_env.action_space.sample()
            assert actions in sync_env.action_space

            # fmt: off
            async_observations, async_rewards, async_dones, async_infos = async_env.step(actions)
            sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions)
            # fmt: on

            if np.any(sync_dones):
                assert "terminal_observation" in async_infos
                assert "_terminal_observation" in async_infos
                assert "terminal_observation" in sync_infos
                assert "_terminal_observation" in sync_infos

                # Both implementations must flag the same sub-environments
                # and report the same terminal observation for each of them.
                async_mask = async_infos["_terminal_observation"]
                sync_mask = sync_infos["_terminal_observation"]
                assert np.array_equal(async_mask, sync_mask)
                for env_idx in np.flatnonzero(sync_mask):
                    assert np.array_equal(
                        async_infos["terminal_observation"][env_idx],
                        sync_infos["terminal_observation"][env_idx],
                    )

            assert np.array_equal(async_observations, sync_observations)
            assert np.array_equal(async_rewards, sync_rewards)
            assert np.array_equal(async_dones, sync_dones)

    finally:
        if async_env is not None:
            async_env.close()
        if sync_env is not None:
            sync_env.close()
|
|
|
|
|
|
def test_custom_space_vector_env():
    """Check how a VectorEnv exposes a non-standard (custom) space.

    The per-env ``single_*`` spaces must remain the ``CustomSpace`` instances,
    while the batched ``observation_space`` / ``action_space`` must be a
    ``Tuple``.
    """
    custom_space_env = VectorEnv(4, CustomSpace(), CustomSpace())

    # (per-env attribute, batched attribute) pairs, checked in order.
    space_attribute_pairs = (
        ("single_observation_space", "observation_space"),
        ("single_action_space", "action_space"),
    )
    for single_attr, batched_attr in space_attribute_pairs:
        assert isinstance(getattr(custom_space_env, single_attr), CustomSpace)
        assert isinstance(getattr(custom_space_env, batched_attr), Tuple)
|