Batch action_space in VectorEnv (#2280)

* Batch the action space in VectorEnv and add iterate utility function

* Add tests for iterate

* Add tests for action spaces in SyncVectorEnv and AsyncVectorEnv

* Black formatting

* Use singledispatch for iterate utility function

* Update the ordering of the arguments in the docstring

* Fix ordering in docstring example of iterate

* Check for same action spaces in vectorized environments

* Separate Discrete from other space types in iterate singledispatch
This commit is contained in:
Tristan Deleu
2021-12-08 21:31:41 -05:00
committed by GitHub
parent cdb72ea552
commit fbe3631aa9
9 changed files with 202 additions and 38 deletions

View File

@@ -2,7 +2,7 @@ import pytest
import numpy as np
from multiprocessing import TimeoutError
-from gym.spaces import Box, Tuple
+from gym.spaces import Box, Tuple, Discrete, MultiDiscrete
from gym.error import AlreadyPendingCallError, NoAsyncCallError, ClosedEnvironmentError
from tests.vector.utils import (
CustomSpace,
@@ -48,6 +48,10 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset()
+assert isinstance(env.single_action_space, Discrete)
+assert isinstance(env.action_space, MultiDiscrete)
if use_single_action_space:
actions = [env.single_action_space.sample() for _ in range(8)]
else:
@@ -189,10 +193,10 @@ def test_already_closed_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
-def test_check_observations_async_vector_env(shared_memory):
-# CubeCrash-v0 - observation_space: Box(40, 32, 3)
+def test_check_spaces_async_vector_env(shared_memory):
+# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
-# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
+# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
env_fns[1] = make_env("MemorizeDigits-v0", 1)
with pytest.raises(RuntimeError):
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -204,6 +208,10 @@ def test_custom_space_async_vector_env():
try:
env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset()
+assert isinstance(env.single_action_space, CustomSpace)
+assert isinstance(env.action_space, Tuple)
actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
finally: