Batch action_space in VectorEnv (#2280)

* Batch the action space in VectorEnv and add iterate utility function

* Add tests for iterate

* Add tests for action spaces in SyncVectorEnv and AsyncVectorEnv

* Black formatting

* Use singledispatch for iterate utility function

* Update the ordering of the arguments in the docstring

* Fix ordering in docstring example of iterate

* Check for same action spaces in vectorized environments

* Separate Discrete from other space types in iterate singledispatch
This commit is contained in:
Tristan Deleu
2021-12-08 21:31:41 -05:00
committed by GitHub
parent cdb72ea552
commit fbe3631aa9
9 changed files with 202 additions and 38 deletions

View File

@@ -22,6 +22,7 @@ from gym.vector.utils import (
write_to_shared_memory,
read_from_shared_memory,
concatenate,
iterate,
CloudpickleWrapper,
clear_mpi_env_vars,
)
@@ -188,7 +189,7 @@ class AsyncVectorEnv(VectorEnv):
child_pipe.close()
self._state = AsyncState.DEFAULT
self._check_observation_spaces()
self._check_spaces()
def seed(self, seed=None):
super().seed(seed=seed)
@@ -318,6 +319,7 @@ class AsyncVectorEnv(VectorEnv):
self._state.value,
)
actions = iterate(self.action_space, actions)
for pipe, action in zip(self.parent_pipes, actions):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
@@ -450,18 +452,25 @@ class AsyncVectorEnv(VectorEnv):
return False
return True
def _check_observation_spaces(self):
def _check_spaces(self):
self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space)
for pipe in self.parent_pipes:
pipe.send(("_check_observation_space", self.single_observation_space))
same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
pipe.send(("_check_spaces", spaces))
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
if not all(same_spaces):
same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces):
raise RuntimeError(
"Some environments have an observation space "
"different from `{}`. In order to batch observations, the "
"observation spaces from all environments must be "
"equal.".format(self.single_observation_space)
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not all(same_action_spaces):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)
def _assert_is_running(self):
@@ -515,13 +524,18 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
elif command == "close":
pipe.send((None, True))
break
elif command == "_check_observation_space":
pipe.send((data == env.observation_space, True))
elif command == "_check_spaces":
pipe.send(
(
(data[0] == env.observation_space, data[1] == env.action_space),
True,
)
)
else:
raise RuntimeError(
"Received unknown command `{0}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, "
"`_check_observation_space`}.".format(command)
"`_check_spaces`}.".format(command)
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
@@ -559,13 +573,15 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
elif command == "close":
pipe.send((None, True))
break
elif command == "_check_observation_space":
pipe.send((data == observation_space, True))
elif command == "_check_spaces":
pipe.send(
((data[0] == observation_space, data[1] == env.action_space), True)
)
else:
raise RuntimeError(
"Received unknown command `{0}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, "
"`_check_observation_space`}.".format(command)
"`_check_spaces`}.".format(command)
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])

View File

@@ -6,7 +6,7 @@ from copy import deepcopy
from gym import logger
from gym.logger import warn
from gym.vector.vector_env import VectorEnv
from gym.vector.utils import concatenate, create_empty_array
from gym.vector.utils import concatenate, iterate, create_empty_array
__all__ = ["SyncVectorEnv"]
@@ -67,7 +67,7 @@ class SyncVectorEnv(VectorEnv):
action_space=action_space,
)
self._check_observation_spaces()
self._check_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
@@ -105,7 +105,7 @@ class SyncVectorEnv(VectorEnv):
return deepcopy(self.observations) if self.copy else self.observations
def step_async(self, actions):
self._actions = actions
self._actions = iterate(self.action_space, actions)
def step_wait(self):
observations, infos = [], []
@@ -131,15 +131,21 @@ class SyncVectorEnv(VectorEnv):
"""Close the environments."""
[env.close() for env in self.envs]
def _check_observation_spaces(self):
def _check_spaces(self):
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
break
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not (env.action_space == self.single_action_space):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)
else:
return True
raise RuntimeError(
"Some environments have an observation space "
"different from `{}`. In order to batch observations, the "
"observation spaces from all environments must be "
"equal.".format(self.single_observation_space)
)

View File

@@ -5,7 +5,7 @@ from gym.vector.utils.shared_memory import (
read_from_shared_memory,
write_to_shared_memory,
)
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space, iterate
__all__ = [
"CloudpickleWrapper",
@@ -17,4 +17,5 @@ __all__ = [
"write_to_shared_memory",
"_BaseGymSpaces",
"batch_space",
"iterate",
]

View File

@@ -1,10 +1,12 @@
import numpy as np
from collections import OrderedDict
from functools import singledispatch
from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
from gym.error import CustomSpaceError
_BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
__all__ = ["_BaseGymSpaces", "batch_space"]
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
def batch_space(space, n=1):
@@ -86,3 +88,94 @@ def batch_space_dict(space, n=1):
def batch_space_custom(space, n=1):
return Tuple(tuple(space for _ in range(n)))
@singledispatch
def iterate(space, items):
    """Iterate over the elements of a (batched) space.

    This is the generic fallback of a `singledispatch` function: it runs
    only when no implementation is registered for the type of `space`,
    and it always raises.

    Parameters
    ----------
    space : `gym.spaces.Space` instance
        Space to which `items` belongs.

    items : samples of `space`
        Items to be iterated over.

    Returns
    -------
    iterator : `Iterable` instance
        Iterator over the elements in `items`.

    Raises
    ------
    ValueError
        Always raised by this fallback; the message includes `type(space)`.

    Example
    -------
    The sampled values below are random and will differ between runs.

    >>> from gym.spaces import Box, Dict
    >>> space = Dict({
    ... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
    ... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
    >>> items = space.sample()
    >>> it = iterate(space, items)
    >>> next(it)
    {'position': array([0.99644893, 0.08304597, 0.7238421 ], dtype=float32),
    'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
    >>> next(it)
    {'position': array([0.67958736, 0.49076623, 0.38661423], dtype=float32),
    'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
    >>> next(it)
    StopIteration
    """
    template = "Space of type `{0}` is not a valid `gym.Space` " "instance."
    message = template.format(type(space))
    raise ValueError(message)
@iterate.register(Discrete)
def iterate_discrete(space, items):
    """`iterate` implementation for `Discrete` spaces: always raises.

    Both arguments are ignored.

    Raises
    ------
    TypeError
        Unconditionally.
    """
    error = TypeError("Unable to iterate over a space of type `Discrete`.")
    raise error
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def iterate_base(space, items):
    """`iterate` implementation for `Box`, `MultiDiscrete` and `MultiBinary`.

    Parameters
    ----------
    space : `Box`, `MultiDiscrete` or `MultiBinary` instance
        Only used for dispatch; not read by this function.

    items : samples of `space`
        Object to be iterated over.

    Returns
    -------
    iterator
        The result of `iter(items)`.

    Raises
    ------
    TypeError
        If `iter(items)` itself raises `TypeError`; the new message
        includes `items`.
    """
    try:
        elements = iter(items)
    except TypeError:
        raise TypeError(f"Unable to iterate over the following elements: {items}")
    else:
        return elements
@iterate.register(Tuple)
def iterate_tuple(space, items):
    """`iterate` implementation for `Tuple` spaces.

    Parameters
    ----------
    space : `Tuple` instance
        Space whose `spaces` attribute lists the subspaces.

    items : samples of `space`
        Indexable by subspace position, except in the custom-only case,
        where it is iterated directly.

    Returns
    -------
    iterator
        `iter(items)` when every subspace is a custom `Space`; otherwise
        a `zip` over the per-subspace iterators.
    """
    known_types = _BaseGymSpaces + (Tuple, Dict)

    # A subspace counts as "custom" when it is a `Space` that is none of
    # the built-in types handled by the other `iterate` implementations.
    custom_only = True
    for subspace in space.spaces:
        if not isinstance(subspace, Space) or isinstance(subspace, known_types):
            custom_only = False
            break

    # If this is a tuple of custom subspaces only, then simply iterate over items
    if custom_only:
        return iter(items)

    # Built eagerly so that a failing subspace raises at call time.
    sub_iterators = [
        iterate(subspace, items[position])
        for position, subspace in enumerate(space.spaces)
    ]
    return zip(*sub_iterators)
@iterate.register(Dict)
def iterate_dict(space, items):
    """`iterate` implementation for `Dict` spaces (a generator).

    Parameters
    ----------
    space : `Dict` instance
        Space whose `spaces` mapping gives the key -> subspace pairs.

    items : samples of `space`
        Mapping indexed with the same keys as `space.spaces`.

    Yields
    ------
    OrderedDict
        One mapping per element, with the keys in the order of
        `space.spaces` and one value taken from each subspace iterator.
    """
    pairs = [
        (name, iterate(subspace, items[name]))
        for name, subspace in space.spaces.items()
    ]
    # Transpose [(key, iterator), ...] into (keys, iterators).
    keys, values = zip(*pairs)
    for element in zip(*values):
        yield OrderedDict(zip(keys, element))
@iterate.register(Space)
def iterate_custom(space, items):
    """`iterate` implementation registered for the base `Space` type.

    Always raises; the message includes both `items` and `space`.

    Raises
    ------
    CustomSpaceError
        Unconditionally.
    """
    message = (
        f"Unable to iterate over {items}, since {space} "
        "is a custom `gym.Space` instance (i.e. not one of "
        "`Box`, `Dict`, etc...)."
    )
    raise CustomSpaceError(message)

View File

@@ -36,7 +36,7 @@ class VectorEnv(gym.Env):
self.num_envs = num_envs
self.is_vector_env = True
self.observation_space = batch_space(observation_space, n=num_envs)
self.action_space = Tuple((action_space,) * num_envs)
self.action_space = batch_space(action_space, n=num_envs)
self.closed = False
self.viewer = None

View File

@@ -2,7 +2,7 @@ import pytest
import numpy as np
from multiprocessing import TimeoutError
from gym.spaces import Box, Tuple
from gym.spaces import Box, Tuple, Discrete, MultiDiscrete
from gym.error import AlreadyPendingCallError, NoAsyncCallError, ClosedEnvironmentError
from tests.vector.utils import (
CustomSpace,
@@ -48,6 +48,10 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset()
assert isinstance(env.single_action_space, Discrete)
assert isinstance(env.action_space, MultiDiscrete)
if use_single_action_space:
actions = [env.single_action_space.sample() for _ in range(8)]
else:
@@ -189,10 +193,10 @@ def test_already_closed_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_check_observations_async_vector_env(shared_memory):
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
def test_check_spaces_async_vector_env(shared_memory):
# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
env_fns[1] = make_env("MemorizeDigits-v0", 1)
with pytest.raises(RuntimeError):
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -204,6 +208,10 @@ def test_custom_space_async_vector_env():
try:
env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset()
assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)
actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
finally:

View File

@@ -4,7 +4,7 @@ import numpy as np
from gym.spaces import Box, MultiDiscrete, Tuple, Dict
from tests.vector.utils import spaces, custom_spaces, CustomSpace
from gym.vector.utils.spaces import batch_space
from gym.vector.utils.spaces import batch_space, iterate
expected_batch_spaces_4 = [
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
@@ -103,3 +103,29 @@ def test_batch_space(space, expected_batch_space_4):
def test_batch_space_custom_space(space, expected_batch_space_4):
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4
@pytest.mark.parametrize(
    "space,batch_space",
    list(zip(spaces, expected_batch_spaces_4)),
    ids=[space.__class__.__name__ for space in spaces],
)
def test_iterate(space, batch_space):
    """Check that iterating a batched sample yields 4 elements of `space`."""
    items = batch_space.sample()
    # Materialize the iterator so the element count can be asserted directly.
    # Reading the loop index after the loop would raise NameError, not an
    # assertion failure, if the iterator were empty.
    elements = list(iterate(batch_space, items))
    assert len(elements) == 4
    for element in elements:
        assert element in space
@pytest.mark.parametrize(
    "space,batch_space",
    list(zip(custom_spaces, expected_custom_batch_spaces_4)),
    ids=[space.__class__.__name__ for space in custom_spaces],
)
def test_iterate_custom_space(space, batch_space):
    """Check that iterating a batched custom-space sample yields 4 elements of `space`."""
    items = batch_space.sample()
    # Materialize the iterator so the element count can be asserted directly.
    # Reading the loop index after the loop would raise NameError, not an
    # assertion failure, if the iterator were empty.
    elements = list(iterate(batch_space, items))
    assert len(elements) == 4
    for element in elements:
        assert element in space

View File

@@ -1,7 +1,7 @@
import pytest
import numpy as np
from gym.spaces import Box, Tuple
from gym.spaces import Box, Tuple, Discrete, MultiDiscrete
from tests.vector.utils import CustomSpace, make_env, make_custom_space_env
from gym.vector.sync_vector_env import SyncVectorEnv
@@ -38,6 +38,10 @@ def test_step_sync_vector_env(use_single_action_space):
try:
env = SyncVectorEnv(env_fns)
observations = env.reset()
assert isinstance(env.single_action_space, Discrete)
assert isinstance(env.action_space, MultiDiscrete)
if use_single_action_space:
actions = [env.single_action_space.sample() for _ in range(8)]
else:
@@ -63,10 +67,10 @@ def test_step_sync_vector_env(use_single_action_space):
assert dones.size == 8
def test_check_observations_sync_vector_env():
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
def test_check_spaces_sync_vector_env():
# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
env_fns[1] = make_env("MemorizeDigits-v0", 1)
with pytest.raises(RuntimeError):
env = SyncVectorEnv(env_fns)
@@ -78,6 +82,10 @@ def test_custom_space_sync_vector_env():
try:
env = SyncVectorEnv(env_fns)
reset_observations = env.reset()
assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)
actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
finally:

View File

@@ -73,6 +73,12 @@ class UnittestSlowEnv(gym.Env):
class CustomSpace(gym.Space):
"""Minimal custom observation space."""
def sample(self):
return "sample"
def contains(self, x):
return isinstance(x, str)
def __eq__(self, other):
return isinstance(other, CustomSpace)