mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 01:27:29 +00:00
Batch action_space in VectorEnv (#2280)
* Batch the action space in VectorEnv and add iterate utility function
* Add tests for iterate
* Add tests for action spaces in SyncVectorEnv and AsyncVectorEnv
* Black formatting
* Use singledispatch for iterate utility function
* Update the ordering of the arguments in the docstring
* Fix ordering in docstring example of iterate
* Check for same action spaces in vectorized environments
* Separate Discrete from other space types in iterate singledispatch
This commit is contained in:
@@ -22,6 +22,7 @@ from gym.vector.utils import (
|
||||
write_to_shared_memory,
|
||||
read_from_shared_memory,
|
||||
concatenate,
|
||||
iterate,
|
||||
CloudpickleWrapper,
|
||||
clear_mpi_env_vars,
|
||||
)
|
||||
@@ -188,7 +189,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
child_pipe.close()
|
||||
|
||||
self._state = AsyncState.DEFAULT
|
||||
self._check_observation_spaces()
|
||||
self._check_spaces()
|
||||
|
||||
def seed(self, seed=None):
|
||||
super().seed(seed=seed)
|
||||
@@ -318,6 +319,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state.value,
|
||||
)
|
||||
|
||||
actions = iterate(self.action_space, actions)
|
||||
for pipe, action in zip(self.parent_pipes, actions):
|
||||
pipe.send(("step", action))
|
||||
self._state = AsyncState.WAITING_STEP
|
||||
@@ -450,18 +452,25 @@ class AsyncVectorEnv(VectorEnv):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_observation_spaces(self):
|
||||
def _check_spaces(self):
|
||||
self._assert_is_running()
|
||||
spaces = (self.single_observation_space, self.single_action_space)
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(("_check_observation_space", self.single_observation_space))
|
||||
same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
pipe.send(("_check_spaces", spaces))
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
if not all(same_spaces):
|
||||
same_observation_spaces, same_action_spaces = zip(*results)
|
||||
if not all(same_observation_spaces):
|
||||
raise RuntimeError(
|
||||
"Some environments have an observation space "
|
||||
"different from `{}`. In order to batch observations, the "
|
||||
"observation spaces from all environments must be "
|
||||
"equal.".format(self.single_observation_space)
|
||||
"Some environments have an observation space different from "
|
||||
f"`{self.single_observation_space}`. In order to batch observations, "
|
||||
"the observation spaces from all environments must be equal."
|
||||
)
|
||||
if not all(same_action_spaces):
|
||||
raise RuntimeError(
|
||||
"Some environments have an action space different from "
|
||||
f"`{self.single_action_space}`. In order to batch actions, the "
|
||||
"action spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
def _assert_is_running(self):
|
||||
@@ -515,13 +524,18 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
elif command == "close":
|
||||
pipe.send((None, True))
|
||||
break
|
||||
elif command == "_check_observation_space":
|
||||
pipe.send((data == env.observation_space, True))
|
||||
elif command == "_check_spaces":
|
||||
pipe.send(
|
||||
(
|
||||
(data[0] == env.observation_space, data[1] == env.action_space),
|
||||
True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Received unknown command `{0}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, "
|
||||
"`_check_observation_space`}.".format(command)
|
||||
"`_check_spaces`}.".format(command)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception):
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
@@ -559,13 +573,15 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
elif command == "close":
|
||||
pipe.send((None, True))
|
||||
break
|
||||
elif command == "_check_observation_space":
|
||||
pipe.send((data == observation_space, True))
|
||||
elif command == "_check_spaces":
|
||||
pipe.send(
|
||||
((data[0] == observation_space, data[1] == env.action_space), True)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Received unknown command `{0}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, "
|
||||
"`_check_observation_space`}.".format(command)
|
||||
"`_check_spaces`}.".format(command)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception):
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
|
@@ -6,7 +6,7 @@ from copy import deepcopy
|
||||
from gym import logger
|
||||
from gym.logger import warn
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
from gym.vector.utils import concatenate, create_empty_array
|
||||
from gym.vector.utils import concatenate, iterate, create_empty_array
|
||||
|
||||
__all__ = ["SyncVectorEnv"]
|
||||
|
||||
@@ -67,7 +67,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
action_space=action_space,
|
||||
)
|
||||
|
||||
self._check_observation_spaces()
|
||||
self._check_spaces()
|
||||
self.observations = create_empty_array(
|
||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
||||
)
|
||||
@@ -105,7 +105,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
|
||||
def step_async(self, actions):
|
||||
self._actions = actions
|
||||
self._actions = iterate(self.action_space, actions)
|
||||
|
||||
def step_wait(self):
|
||||
observations, infos = [], []
|
||||
@@ -131,15 +131,21 @@ class SyncVectorEnv(VectorEnv):
|
||||
"""Close the environments."""
|
||||
[env.close() for env in self.envs]
|
||||
|
||||
def _check_observation_spaces(self):
|
||||
def _check_spaces(self):
|
||||
for env in self.envs:
|
||||
if not (env.observation_space == self.single_observation_space):
|
||||
break
|
||||
raise RuntimeError(
|
||||
"Some environments have an observation space different from "
|
||||
f"`{self.single_observation_space}`. In order to batch observations, "
|
||||
"the observation spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
if not (env.action_space == self.single_action_space):
|
||||
raise RuntimeError(
|
||||
"Some environments have an action space different from "
|
||||
f"`{self.single_action_space}`. In order to batch actions, the "
|
||||
"action spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
else:
|
||||
return True
|
||||
raise RuntimeError(
|
||||
"Some environments have an observation space "
|
||||
"different from `{}`. In order to batch observations, the "
|
||||
"observation spaces from all environments must be "
|
||||
"equal.".format(self.single_observation_space)
|
||||
)
|
||||
|
@@ -5,7 +5,7 @@ from gym.vector.utils.shared_memory import (
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space, iterate
|
||||
|
||||
__all__ = [
|
||||
"CloudpickleWrapper",
|
||||
@@ -17,4 +17,5 @@ __all__ = [
|
||||
"write_to_shared_memory",
|
||||
"_BaseGymSpaces",
|
||||
"batch_space",
|
||||
"iterate",
|
||||
]
|
||||
|
@@ -1,10 +1,12 @@
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from functools import singledispatch
|
||||
|
||||
from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
||||
from gym.error import CustomSpaceError
|
||||
|
||||
_BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
|
||||
__all__ = ["_BaseGymSpaces", "batch_space"]
|
||||
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
|
||||
|
||||
|
||||
def batch_space(space, n=1):
|
||||
@@ -86,3 +88,94 @@ def batch_space_dict(space, n=1):
|
||||
|
||||
def batch_space_custom(space, n=1):
|
||||
return Tuple(tuple(space for _ in range(n)))
|
||||
|
||||
|
||||
@singledispatch
def iterate(space, items):
    """Iterate over the elements of a (batched) space.

    This is the generic fallback of a `functools.singledispatch` function:
    it only runs when no handler has been registered for the type of
    `space`, and it always raises.

    Parameters
    ----------
    space : `gym.spaces.Space` instance
        Space to which `items` belong.

    items : samples of `space`
        Items to be iterated over.

    Returns
    -------
    iterator : `Iterable` instance
        Iterator over the elements in `items`.

    Raises
    ------
    ValueError
        If no handler is registered for the type of `space`.

    Example
    -------
    >>> import numpy as np
    >>> from gym.spaces import Box, Dict
    >>> space = Dict({
    ... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
    ... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
    >>> items = space.sample()
    >>> it = iterate(space, items)
    >>> first = next(it)
    >>> first['position'].shape, first['velocity'].shape
    ((3,), (2,))
    >>> second = next(it)
    >>> next(it)
    Traceback (most recent call last):
        ...
    StopIteration
    """
    space_type = type(space)
    raise ValueError(
        f"Space of type `{space_type}` is not a valid `gym.Space` instance."
    )
|
||||
|
||||
|
||||
@iterate.register(Discrete)
def iterate_discrete(space, items):
    """Handler of `iterate` for `Discrete` spaces; always raises.

    Parameters
    ----------
    space : `gym.spaces.Discrete` instance
        Space that `iterate` was dispatched on (unused).

    items : samples of `space`
        Ignored.

    Raises
    ------
    TypeError
        Unconditionally.
    """
    # NOTE(review): presumably a `Discrete` sample is a single scalar with no
    # batch axis to walk over -- confirm against `batch_space`.
    message = "Unable to iterate over a space of type `Discrete`."
    raise TypeError(message)
|
||||
|
||||
|
||||
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def iterate_base(space, items):
    """Handler of `iterate` for `Box`, `MultiDiscrete` and `MultiBinary`.

    Parameters
    ----------
    space : `Box`, `MultiDiscrete` or `MultiBinary` instance
        Space that `iterate` was dispatched on (unused beyond dispatch).

    items : samples of `space`
        Object to iterate over.

    Returns
    -------
    iterator
        The result of ``iter(items)``.

    Raises
    ------
    TypeError
        If `items` does not support iteration. The original error is kept
        as the explicit cause of the raised exception.
    """
    try:
        return iter(items)
    except TypeError as err:
        # Chain explicitly so the traceback reads as "caused by" rather than
        # "during handling of the above exception, another exception occurred".
        raise TypeError(
            f"Unable to iterate over the following elements: {items}"
        ) from err
|
||||
|
||||
|
||||
@iterate.register(Tuple)
def iterate_tuple(space, items):
    """Handler of `iterate` for `Tuple` spaces.

    Parameters
    ----------
    space : `gym.spaces.Tuple` instance
        Space whose `spaces` attribute lists the subspaces.

    items : samples of `space`
        Indexable by subspace position, unless every subspace is custom.

    Returns
    -------
    iterator
        ``iter(items)`` when every subspace is a custom `Space`; otherwise a
        `zip` over the per-subspace iterators, so each element is a tuple
        with one entry per subspace.
    """
    known_types = _BaseGymSpaces + (Tuple, Dict)

    def _is_custom(subspace):
        # A custom space is a `Space` that is none of the built-in types.
        return isinstance(subspace, Space) and not isinstance(subspace, known_types)

    # If this is a tuple of custom subspaces only, `items` is iterated as is.
    if all(map(_is_custom, space.spaces)):
        return iter(items)

    per_subspace = [
        iterate(subspace, items[position])
        for position, subspace in enumerate(space.spaces)
    ]
    return zip(*per_subspace)
|
||||
|
||||
|
||||
@iterate.register(Dict)
def iterate_dict(space, items):
    """Handler of `iterate` for `Dict` spaces (generator).

    Parameters
    ----------
    space : `gym.spaces.Dict` instance
        Space whose `spaces` mapping gives the subspace for each key.

    items : samples of `space`
        Mapping indexed with the same keys as `space.spaces`.

    Yields
    ------
    `collections.OrderedDict`
        One mapping per element, with the keys in the order of
        `space.spaces` and each value taken from that key's iterator.
    """
    pairs = [
        (key, iterate(subspace, items[key]))
        for key, subspace in space.spaces.items()
    ]
    # Transpose [(key, iterator), ...] into (keys, iterators).
    keys, iterators = zip(*pairs)
    for element in zip(*iterators):
        yield OrderedDict(zip(keys, element))
|
||||
|
||||
|
||||
@iterate.register(Space)
def iterate_custom(space, items):
    """Handler of `iterate` registered on the `Space` base class; always raises.

    Parameters
    ----------
    space : `gym.spaces.Space` instance
        Space that `iterate` was dispatched on; reported in the error.

    items : samples of `space`
        Reported in the error message.

    Raises
    ------
    CustomSpaceError
        Unconditionally.
    """
    message = (
        f"Unable to iterate over {items}, since {space} "
        "is a custom `gym.Space` instance (i.e. not one of "
        "`Box`, `Dict`, etc...)."
    )
    raise CustomSpaceError(message)
|
||||
|
@@ -36,7 +36,7 @@ class VectorEnv(gym.Env):
|
||||
self.num_envs = num_envs
|
||||
self.is_vector_env = True
|
||||
self.observation_space = batch_space(observation_space, n=num_envs)
|
||||
self.action_space = Tuple((action_space,) * num_envs)
|
||||
self.action_space = batch_space(action_space, n=num_envs)
|
||||
|
||||
self.closed = False
|
||||
self.viewer = None
|
||||
|
@@ -2,7 +2,7 @@ import pytest
|
||||
import numpy as np
|
||||
|
||||
from multiprocessing import TimeoutError
|
||||
from gym.spaces import Box, Tuple
|
||||
from gym.spaces import Box, Tuple, Discrete, MultiDiscrete
|
||||
from gym.error import AlreadyPendingCallError, NoAsyncCallError, ClosedEnvironmentError
|
||||
from tests.vector.utils import (
|
||||
CustomSpace,
|
||||
@@ -48,6 +48,10 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset()
|
||||
|
||||
assert isinstance(env.single_action_space, Discrete)
|
||||
assert isinstance(env.action_space, MultiDiscrete)
|
||||
|
||||
if use_single_action_space:
|
||||
actions = [env.single_action_space.sample() for _ in range(8)]
|
||||
else:
|
||||
@@ -189,10 +193,10 @@ def test_already_closed_async_vector_env(shared_memory):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_check_observations_async_vector_env(shared_memory):
|
||||
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
|
||||
def test_check_spaces_async_vector_env(shared_memory):
|
||||
# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
|
||||
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
|
||||
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
|
||||
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
|
||||
env_fns[1] = make_env("MemorizeDigits-v0", 1)
|
||||
with pytest.raises(RuntimeError):
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
@@ -204,6 +208,10 @@ def test_custom_space_async_vector_env():
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=False)
|
||||
reset_observations = env.reset()
|
||||
|
||||
assert isinstance(env.single_action_space, CustomSpace)
|
||||
assert isinstance(env.action_space, Tuple)
|
||||
|
||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||
step_observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
|
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from gym.spaces import Box, MultiDiscrete, Tuple, Dict
|
||||
from tests.vector.utils import spaces, custom_spaces, CustomSpace
|
||||
|
||||
from gym.vector.utils.spaces import batch_space
|
||||
from gym.vector.utils.spaces import batch_space, iterate
|
||||
|
||||
expected_batch_spaces_4 = [
|
||||
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
|
||||
@@ -103,3 +103,29 @@ def test_batch_space(space, expected_batch_space_4):
|
||||
def test_batch_space_custom_space(space, expected_batch_space_4):
|
||||
batch_space_4 = batch_space(space, n=4)
|
||||
assert batch_space_4 == expected_batch_space_4
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "space,batch_space",
    list(zip(spaces, expected_batch_spaces_4)),
    ids=[space.__class__.__name__ for space in spaces],
)
def test_iterate(space, batch_space):
    """Iterating a sample of a 4-way batched space yields 4 members of `space`."""
    items = batch_space.sample()
    elements = list(iterate(batch_space, items))

    # Count explicitly instead of reading the loop index after the loop: with
    # an empty iterator the index would be unbound and the test would error
    # with a NameError rather than fail with a clear assertion.
    assert len(elements) == 4
    for element in elements:
        assert element in space
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "space,batch_space",
    list(zip(custom_spaces, expected_custom_batch_spaces_4)),
    ids=[space.__class__.__name__ for space in custom_spaces],
)
def test_iterate_custom_space(space, batch_space):
    """Iterating a sample of a 4-way batched custom space yields 4 members of `space`."""
    items = batch_space.sample()
    elements = list(iterate(batch_space, items))

    # Count explicitly instead of reading the loop index after the loop: with
    # an empty iterator the index would be unbound and the test would error
    # with a NameError rather than fail with a clear assertion.
    assert len(elements) == 4
    for element in elements:
        assert element in space
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Box, Tuple
|
||||
from gym.spaces import Box, Tuple, Discrete, MultiDiscrete
|
||||
from tests.vector.utils import CustomSpace, make_env, make_custom_space_env
|
||||
|
||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||
@@ -38,6 +38,10 @@ def test_step_sync_vector_env(use_single_action_space):
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset()
|
||||
|
||||
assert isinstance(env.single_action_space, Discrete)
|
||||
assert isinstance(env.action_space, MultiDiscrete)
|
||||
|
||||
if use_single_action_space:
|
||||
actions = [env.single_action_space.sample() for _ in range(8)]
|
||||
else:
|
||||
@@ -63,10 +67,10 @@ def test_step_sync_vector_env(use_single_action_space):
|
||||
assert dones.size == 8
|
||||
|
||||
|
||||
def test_check_observations_sync_vector_env():
|
||||
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
|
||||
def test_check_spaces_sync_vector_env():
|
||||
# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
|
||||
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
|
||||
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
|
||||
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
|
||||
env_fns[1] = make_env("MemorizeDigits-v0", 1)
|
||||
with pytest.raises(RuntimeError):
|
||||
env = SyncVectorEnv(env_fns)
|
||||
@@ -78,6 +82,10 @@ def test_custom_space_sync_vector_env():
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
reset_observations = env.reset()
|
||||
|
||||
assert isinstance(env.single_action_space, CustomSpace)
|
||||
assert isinstance(env.action_space, Tuple)
|
||||
|
||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||
step_observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
|
@@ -73,6 +73,12 @@ class UnittestSlowEnv(gym.Env):
|
||||
class CustomSpace(gym.Space):
    """Minimal custom space for the vector-env tests.

    Any string is a member, and every `CustomSpace` compares equal to every
    other `CustomSpace`.
    """

    def __eq__(self, other):
        """Return True when `other` is a `CustomSpace` instance."""
        return isinstance(other, CustomSpace)

    def contains(self, x):
        """Return True when `x` is a string."""
        return isinstance(x, str)

    def sample(self):
        """Return the fixed string ``"sample"``."""
        return "sample"
|
||||
|
||||
|
Reference in New Issue
Block a user