Files
Gymnasium/tests/vector/test_async_vector_env.py

299 lines
10 KiB
Python
Raw Normal View History

from multiprocessing import TimeoutError
import numpy as np
import pytest
from gym.error import AlreadyPendingCallError, ClosedEnvironmentError, NoAsyncCallError
from gym.spaces import Box, Discrete, MultiDiscrete, Tuple
from gym.vector.async_vector_env import AsyncVectorEnv
from tests.vector.utils import (
2021-07-29 02:26:34 +02:00
CustomSpace,
make_custom_space_env,
2021-07-29 02:26:34 +02:00
make_env,
make_slow_env,
)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_create_async_vector_env(shared_memory):
    """Build an ``AsyncVectorEnv`` from 8 CartPole factories and check ``num_envs``.

    Runs once with ``shared_memory=True`` and once with ``shared_memory=False``.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]

    # Construct outside the ``try``: if the constructor raises, ``env`` is
    # never bound and a ``finally: env.close()`` would raise
    # UnboundLocalError on top of the real failure.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        assert env.num_envs == 8
    finally:
        env.close()
2021-07-29 02:26:34 +02:00
def _assert_batched_reset_observations(env, observations, num_envs=8):
    """Assert ``observations`` is an ndarray matching ``env``'s batched Box space."""
    assert isinstance(env.observation_space, Box)
    assert isinstance(observations, np.ndarray)
    assert observations.dtype == env.observation_space.dtype
    assert observations.shape == (num_envs,) + env.single_observation_space.shape
    assert observations.shape == env.observation_space.shape


@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory):
    """Check the result of ``reset`` for the three supported call forms.

    Exercises ``reset()``, ``reset(return_info=False)`` and
    ``reset(return_info=True)``, each on a freshly created vector env.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]

    # Each env is constructed *before* its ``try`` so a constructor failure
    # is reported as-is rather than being followed by an UnboundLocalError
    # from ``env.close()`` in the ``finally`` clause.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        observations = env.reset()
    finally:
        env.close()
    _assert_batched_reset_observations(env, observations)

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        observations = env.reset(return_info=False)
    finally:
        env.close()
    _assert_batched_reset_observations(env, observations)

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        observations, infos = env.reset(return_info=True)
    finally:
        env.close()
    _assert_batched_reset_observations(env, observations)
    # With ``return_info=True`` a list of per-env info dicts is also expected.
    assert isinstance(infos, list)
    assert all(isinstance(info, dict) for info in infos)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_async_vector_env(shared_memory, use_single_action_space):
    """Step 8 CartPole envs once and check the returned batches.

    When ``use_single_action_space`` is true the actions are a list of 8
    samples drawn from ``single_action_space``; otherwise they are one sample
    drawn from the batched ``action_space``.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        # The reset observations are not inspected here; only the step output is.
        env.reset()

        assert isinstance(env.single_action_space, Discrete)
        assert isinstance(env.action_space, MultiDiscrete)

        if use_single_action_space:
            actions = [env.single_action_space.sample() for _ in range(8)]
        else:
            actions = env.action_space.sample()
        observations, rewards, dones, _ = env.step(actions)
    finally:
        env.close()

    assert isinstance(env.observation_space, Box)
    assert isinstance(observations, np.ndarray)
    assert observations.dtype == env.observation_space.dtype
    assert observations.shape == (8,) + env.single_observation_space.shape
    assert observations.shape == env.observation_space.shape

    assert isinstance(rewards, np.ndarray)
    assert isinstance(rewards[0], (float, np.floating))
    assert rewards.ndim == 1
    assert rewards.size == 8

    assert isinstance(dones, np.ndarray)
    assert dones.dtype == np.bool_
    assert dones.ndim == 1
    assert dones.size == 8
@pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory):
    """Check ``env.call`` for a method (``render``) and an attribute (``gravity``).

    Both calls are expected to return a tuple with one entry per sub-env (4).
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        images = env.call("render", mode="rgb_array")
        gravity = env.call("gravity")
    finally:
        env.close()

    assert isinstance(images, tuple)
    assert len(images) == 4
    for image in images:
        assert isinstance(image, np.ndarray)

    assert isinstance(gravity, tuple)
    assert len(gravity) == 4
    for value in gravity:
        assert isinstance(value, float)
        assert value == 9.8
@pytest.mark.parametrize("shared_memory", [True, False])
def test_set_attr_async_vector_env(shared_memory):
    """Set a distinct ``gravity`` on each of 4 sub-envs and read the values back."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
        gravity = env.get_attr("gravity")
        assert gravity == (9.81, 3.72, 8.87, 1.62)
    finally:
        env.close()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_copy_async_vector_env(shared_memory):
    """Reset an env created with ``copy=True`` and write into the returned observations.

    The test passes as long as the in-place write does not raise.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
    try:
        observations = env.reset()
        observations[0] = 0
    finally:
        env.close()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_no_copy_async_vector_env(shared_memory):
    """Reset an env created with ``copy=False`` and write into the returned observations.

    The test passes as long as the in-place write does not raise.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(8)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
    try:
        observations = env.reset()
        observations[0] = 0
    finally:
        env.close()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_timeout_async_vector_env(shared_memory):
    """``reset_wait`` with a 0.1 timeout must raise ``TimeoutError`` on slow envs.

    The envs come from ``make_slow_env(0.3, i)``; ``TimeoutError`` is the one
    imported from ``multiprocessing`` at the top of this module.
    """
    env_fns = [make_slow_env(0.3, i) for i in range(4)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset_async()
        # Only the waiting call is expected to time out; keeping construction
        # and cleanup outside ``pytest.raises`` stops them from satisfying it.
        with pytest.raises(TimeoutError):
            env.reset_wait(timeout=0.1)
    finally:
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_timeout_async_vector_env(shared_memory):
    """``step_wait`` with a 0.1 timeout must raise ``TimeoutError``.

    The envs come from ``make_slow_env(0.0, i)`` and are stepped with the
    actions ``[0.1, 0.1, 0.3, 0.1]``; ``TimeoutError`` is the one imported
    from ``multiprocessing`` at the top of this module.
    """
    env_fns = [make_slow_env(0.0, i) for i in range(4)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        env.step_async([0.1, 0.1, 0.3, 0.1])
        # Only the waiting call is expected to time out.
        with pytest.raises(TimeoutError):
            env.step_wait(timeout=0.1)
    finally:
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_out_of_order_async_vector_env(shared_memory):
    """Check the errors raised when reset calls are issued out of order.

    * ``reset_wait`` without a prior ``reset_async`` -> ``NoAsyncCallError``
      whose ``name`` is ``"reset"``.
    * ``reset_async`` while a ``step_async`` is pending ->
      ``AlreadyPendingCallError`` whose ``name`` is ``"step"``.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        with pytest.raises(NoAsyncCallError) as exc_info:
            env.reset_wait()
        assert exc_info.value.name == "reset"
    finally:
        env.close(terminate=True)

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        actions = env.action_space.sample()
        env.reset()
        env.step_async(actions)
        # Inspect the exception type that is actually expected here. The
        # previous ``except NoAsyncCallError`` clause could never match an
        # AlreadyPendingCallError, so the ``name`` assertion was dead code.
        with pytest.raises(AlreadyPendingCallError) as exc_info:
            env.reset_async()
        assert exc_info.value.name == "step"
    finally:
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_out_of_order_async_vector_env(shared_memory):
    """Check the errors raised when step calls are issued out of order.

    * ``step_wait`` without a prior ``step_async`` -> ``NoAsyncCallError``
      whose ``name`` is ``"step"``.
    * ``step_async`` while a ``reset_async`` is pending ->
      ``AlreadyPendingCallError`` whose ``name`` is ``"reset"``.
    """
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        # Inspect the exception type that is actually expected here. The
        # previous ``except AlreadyPendingCallError`` clause could never match
        # a NoAsyncCallError, so the ``name`` assertion was dead code.
        with pytest.raises(NoAsyncCallError) as exc_info:
            env.step_wait()
        assert exc_info.value.name == "step"
    finally:
        env.close(terminate=True)

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        actions = env.action_space.sample()
        env.reset_async()
        with pytest.raises(AlreadyPendingCallError) as exc_info:
            env.step_async(actions)
        assert exc_info.value.name == "reset"
    finally:
        env.close(terminate=True)
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_already_closed_async_vector_env(shared_memory):
    """Calling ``reset`` on a closed vector env must raise ``ClosedEnvironmentError``."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]

    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    env.close()

    # Only ``reset`` sits inside ``pytest.raises``; otherwise an error of the
    # same type from the constructor or ``close()`` would make the test pass.
    with pytest.raises(ClosedEnvironmentError):
        env.reset()
2021-07-29 02:26:34 +02:00
@pytest.mark.parametrize("shared_memory", [True, False])
def test_check_spaces_async_vector_env(shared_memory):
    """Building a vector env from factories with differing spaces must raise RuntimeError."""
    # CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2)
    cartpole_factories = [make_env("CartPole-v1", idx) for idx in range(8)]
    # FrozenLake-v1 - Discrete(16), action_space: Discrete(4)
    frozen_lake_factory = make_env("FrozenLake-v1", 1)

    # Same list as the CartPole one, with slot 1 swapped for the FrozenLake factory.
    factories = cartpole_factories[:1] + [frozen_lake_factory] + cartpole_factories[2:]

    with pytest.raises(RuntimeError):
        vector_env = AsyncVectorEnv(factories, shared_memory=shared_memory)
        # Only reached if construction unexpectedly succeeds.
        vector_env.close(terminate=True)
def test_custom_space_async_vector_env():
    """Reset and step 4 custom-space envs with ``shared_memory=False``.

    Observations are expected to come back as tuples of strings:
    ``"reset"`` after a reset and ``"step(<action>)"`` after a step.
    """
    env_fns = [make_custom_space_env(i) for i in range(4)]

    # Constructed before the ``try`` so ``env`` is always bound in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=False)
    try:
        reset_observations = env.reset()

        assert isinstance(env.single_action_space, CustomSpace)
        assert isinstance(env.action_space, Tuple)

        actions = ("action-2", "action-3", "action-5", "action-7")
        # Only the observations are checked below.
        step_observations, _, _, _ = env.step(actions)
    finally:
        env.close()

    assert isinstance(env.single_observation_space, CustomSpace)
    assert isinstance(env.observation_space, Tuple)

    assert isinstance(reset_observations, tuple)
    assert reset_observations == ("reset", "reset", "reset", "reset")

    assert isinstance(step_observations, tuple)
    assert step_observations == (
        "step(action-2)",
        "step(action-3)",
        "step(action-5)",
        "step(action-7)",
    )
def test_custom_space_async_vector_env_shared_memory():
    """Requesting ``shared_memory=True`` for custom-space envs must raise ValueError."""
    factories = list(map(make_custom_space_env, range(4)))

    with pytest.raises(ValueError):
        vector_env = AsyncVectorEnv(factories, shared_memory=True)
        # Only reached if construction unexpectedly succeeds.
        vector_env.close(terminate=True)