Files
Gymnasium/tests/vector/test_async_vector_env.py

299 lines
10 KiB
Python
Raw Normal View History

from multiprocessing import TimeoutError
import numpy as np
import pytest
from gym.error import AlreadyPendingCallError, ClosedEnvironmentError, NoAsyncCallError
from gym.spaces import Box, Discrete, MultiDiscrete, Tuple
from gym.vector.async_vector_env import AsyncVectorEnv
from tests.vector.utils import (
2021-07-29 02:26:34 +02:00
CustomSpace,
make_custom_space_env,
2021-07-29 02:26:34 +02:00
make_env,
make_slow_env,
)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_create_async_vector_env(shared_memory):
    """An AsyncVectorEnv built from 8 CartPole factories reports ``num_envs == 8``."""
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]
    # Construct outside any ``try``: if the constructor fails, its own exception is
    # reported instead of an UnboundLocalError from ``env.close()``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    env.close()
    assert env.num_envs == 8
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory):
    """``reset`` returns batched observations, optionally paired with an info dict.

    Exercises the default call, ``return_info=False`` and ``return_info=True``,
    each on a fresh 8-worker CartPole AsyncVectorEnv.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]

    def _reset_once(**reset_kwargs):
        # Fresh env per call. It is constructed outside ``try`` so a constructor
        # failure is not masked by an UnboundLocalError from ``env.close()``.
        env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
        try:
            result = env.reset(**reset_kwargs)
        finally:
            env.close()
        return env, result

    def _check_observations(env, observations):
        # The batch must match both the vector space and 8 stacked single spaces.
        assert isinstance(env.observation_space, Box)
        assert isinstance(observations, np.ndarray)
        assert observations.dtype == env.observation_space.dtype
        assert observations.shape == (8,) + env.single_observation_space.shape
        assert observations.shape == env.observation_space.shape

    env, observations = _reset_once()
    _check_observations(env, observations)

    env, observations = _reset_once(return_info=False)
    _check_observations(env, observations)

    env, (observations, infos) = _reset_once(return_info=True)
    _check_observations(env, observations)
    # ``infos`` is one dict for the whole batch, not a sequence of per-env dicts.
    # Iterating it yields its keys, so a per-element ``isinstance(info, dict)``
    # check would only pass vacuously for an empty dict.
    assert isinstance(infos, dict)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_async_vector_env(shared_memory, use_single_action_space):
    """One ``step`` yields batched observations, rewards and done flags for 8 envs.

    Actions come either from 8 samples of the single (``Discrete``) action space
    or from one sample of the batched (``MultiDiscrete``) action space.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]
    # Construct before ``try`` so ``finally`` never sees an unbound ``env``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        assert isinstance(env.single_action_space, Discrete)
        assert isinstance(env.action_space, MultiDiscrete)
        if use_single_action_space:
            actions = [env.single_action_space.sample() for _ in range(8)]
        else:
            actions = env.action_space.sample()
        observations, rewards, dones, _ = env.step(actions)
    finally:
        env.close()

    assert isinstance(env.observation_space, Box)
    assert isinstance(observations, np.ndarray)
    assert observations.dtype == env.observation_space.dtype
    assert observations.shape == (8,) + env.single_observation_space.shape
    assert observations.shape == env.observation_space.shape

    assert isinstance(rewards, np.ndarray)
    assert isinstance(rewards[0], (float, np.floating))
    assert rewards.ndim == 1
    assert rewards.size == 8

    assert isinstance(dones, np.ndarray)
    assert dones.dtype == np.bool_
    assert dones.ndim == 1
    assert dones.size == 8
@pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory):
    """``call`` invokes a method or reads an attribute on every worker, returning a tuple."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    # Construct before ``try`` so ``finally`` never sees an unbound ``env``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        images = env.call("render", mode="rgb_array")
        gravity = env.call("gravity")
    finally:
        env.close()

    assert isinstance(images, tuple)
    assert len(images) == 4
    for image in images:
        assert isinstance(image, np.ndarray)

    # ``gravity`` is read as a value (a float per worker), not called.
    assert isinstance(gravity, tuple)
    assert len(gravity) == 4
    for value in gravity:
        assert isinstance(value, float)
        assert value == 9.8
@pytest.mark.parametrize("shared_memory", [True, False])
def test_set_attr_async_vector_env(shared_memory):
    """``set_attr`` with one value per worker is reflected by ``get_attr``."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    # Construct before ``try`` so ``finally`` never sees an unbound ``env``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
        gravity = env.get_attr("gravity")
        assert gravity == (9.81, 3.72, 8.87, 1.62)
    finally:
        env.close()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_copy_async_vector_env(shared_memory):
    """With ``copy=True``, mutating returned observations leaves the env's buffer intact."""
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]
    # Construct before ``try`` so ``finally`` never sees an unbound ``env``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
    try:
        observations = env.reset()
        # Sentinel value assumed never to occur in a fresh CartPole observation.
        observations[0] = 128
        # NOTE(review): relies on AsyncVectorEnv exposing its batch buffer as
        # ``env.observations``; the mutation must not have reached it.
        assert not np.all(env.observations[0] == 128)
    finally:
        env.close()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_no_copy_async_vector_env(shared_memory):
    """With ``copy=False``, returned observations alias the env's own buffer."""
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]
    # Construct before ``try`` so ``finally`` never sees an unbound ``env``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
    try:
        observations = env.reset()
        # Sentinel value assumed never to occur in a fresh CartPole observation.
        observations[0] = 128
        # NOTE(review): relies on AsyncVectorEnv exposing its batch buffer as
        # ``env.observations``; without copying, the mutation must be visible there.
        assert np.all(env.observations[0] == 128)
    finally:
        env.close()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_timeout_async_vector_env(shared_memory):
    """``reset_wait`` raises ``TimeoutError`` when workers are slower than the timeout."""
    env_fns = [make_slow_env(0.3, i) for i in range(4)]
    # Construct before ``try``: a constructor failure must not surface as an
    # UnboundLocalError, nor be confused with the expected timeout.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset_async()
        # Only the waiting call is expected to time out.
        with pytest.raises(TimeoutError):
            env.reset_wait(timeout=0.1)
    finally:
        # Workers may still be mid-reset after the timeout; terminate them.
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_timeout_async_vector_env(shared_memory):
    """``step_wait`` raises ``TimeoutError`` when one worker's step outlasts the timeout."""
    env_fns = [make_slow_env(0.0, i) for i in range(4)]
    # Construct before ``try``: a constructor failure must not surface as an
    # UnboundLocalError, nor be confused with the expected timeout.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        # Presumably each action is the slow env's step delay (seconds): the third
        # worker (0.3) should exceed the 0.1 s timeout below. TODO confirm in utils.
        env.step_async([0.1, 0.1, 0.3, 0.1])
        with pytest.raises(TimeoutError):
            env.step_wait(timeout=0.1)
    finally:
        # Workers may still be mid-step after the timeout; terminate them.
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_out_of_order_async_vector_env(shared_memory):
    """Out-of-order reset calls raise errors that name the relevant call.

    * ``reset_wait`` without ``reset_async`` -> ``NoAsyncCallError`` named "reset".
    * ``reset_async`` while a step is pending -> ``AlreadyPendingCallError``
      named "step".
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        with pytest.raises(NoAsyncCallError) as exc_info:
            env.reset_wait()
        assert exc_info.value.name == "reset"
    finally:
        env.close(terminate=True)

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        actions = env.action_space.sample()
        env.reset()
        env.step_async(actions)
        # A step is still pending, so the error must name that pending "step" call.
        with pytest.raises(AlreadyPendingCallError) as exc_info:
            env.reset_async()
        assert exc_info.value.name == "step"
    finally:
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_out_of_order_async_vector_env(shared_memory):
    """Out-of-order step calls raise errors that name the relevant call.

    * ``step_wait`` without ``step_async`` -> ``NoAsyncCallError`` named "step".
    * ``step_async`` while a reset is pending -> ``AlreadyPendingCallError``
      named "reset".
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        with pytest.raises(NoAsyncCallError) as exc_info:
            env.step_wait()
        assert exc_info.value.name == "step"
    finally:
        env.close(terminate=True)

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        actions = env.action_space.sample()
        env.reset_async()
        # A reset is still pending, so the error must name that pending "reset" call.
        with pytest.raises(AlreadyPendingCallError) as exc_info:
            env.step_async(actions)
        assert exc_info.value.name == "reset"
    finally:
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_already_closed_async_vector_env(shared_memory):
    """Using an AsyncVectorEnv after ``close`` raises ``ClosedEnvironmentError``."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    env.close()
    # Only the post-close call may satisfy the expectation; a failure during
    # construction or close must not be mistaken for it.
    with pytest.raises(ClosedEnvironmentError):
        env.reset()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_check_spaces_async_vector_env(shared_memory):
    """Sub-environments with mismatched spaces are rejected with ``RuntimeError``."""
    # Seven CartPole-v1 workers (observation_space: Box(4,), action_space: Discrete(2))
    # plus one FrozenLake-v1 worker (observation_space: Discrete(16),
    # action_space: Discrete(4)) placed at index 1.
    env_fns = [
        make_env("FrozenLake-v1", 1) if index == 1 else make_env("CartPole-v1", index)
        for index in range(8)
    ]
    with pytest.raises(RuntimeError):
        env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
        # Reached only if construction unexpectedly succeeds; clean up the workers.
        env.close(terminate=True)
def test_custom_space_async_vector_env():
    """Custom (non-gym) spaces work when ``shared_memory=False``.

    Batched action/observation spaces become ``Tuple`` spaces, and observations
    are plain tuples with one entry per worker.
    """
    env_fns = [make_custom_space_env(i) for i in range(4)]
    # Construct before ``try`` so ``finally`` never sees an unbound ``env``.
    env = AsyncVectorEnv(env_fns, shared_memory=False)
    try:
        reset_observations = env.reset()
        assert isinstance(env.single_action_space, CustomSpace)
        assert isinstance(env.action_space, Tuple)

        actions = ("action-2", "action-3", "action-5", "action-7")
        step_observations, rewards, dones, _ = env.step(actions)
    finally:
        env.close()

    assert isinstance(env.single_observation_space, CustomSpace)
    assert isinstance(env.observation_space, Tuple)

    assert isinstance(reset_observations, tuple)
    assert reset_observations == ("reset", "reset", "reset", "reset")

    assert isinstance(step_observations, tuple)
    assert step_observations == (
        "step(action-2)",
        "step(action-3)",
        "step(action-5)",
        "step(action-7)",
    )
def test_custom_space_async_vector_env_shared_memory():
    """Requesting ``shared_memory=True`` with custom spaces raises ``ValueError``."""
    env_fns = list(map(make_custom_space_env, range(4)))
    with pytest.raises(ValueError):
        env = AsyncVectorEnv(env_fns, shared_memory=True)
        # Reached only if construction unexpectedly succeeds; clean up the workers.
        env.close(terminate=True)