mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 15:11:51 +00:00
Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)
Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com> Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
@@ -47,14 +47,14 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return all(space.is_np_flattenable for space in self.spaces)
|
||||
|
||||
def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
|
||||
def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]:
|
||||
"""Seed the PRNG of this space and all subspaces.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
|
||||
* ``None`` - All the subspaces will use a random initial seed
|
||||
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
|
||||
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
||||
* ``List`` / ``Tuple`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
||||
|
||||
Args:
|
||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||
|
@@ -573,3 +573,102 @@ def _flatten_space_oneof(space: OneOf) -> Box:
|
||||
|
||||
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
|
||||
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
|
||||
|
||||
|
||||
@singledispatch
def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool:
    """Returns if two spaces share a common dtype and shape (plus any critical variables).

    This function is primarily used to check for compatibility of different spaces in a vector environment.

    This is the ``singledispatch`` fallback, it only runs when no implementation is registered
    for the type of ``space_1`` and therefore always raises.

    Args:
        space_1: A Gymnasium space
        space_2: A Gymnasium space

    Returns:
        If the two spaces share a common dtype and shape (plus any critical variables).

    Raises:
        NotImplementedError: If both arguments are spaces but no implementation is registered for ``space_1``.
        TypeError: If either argument is not a :class:`Space`.
    """
    if isinstance(space_1, Space) and isinstance(space_2, Space):
        raise NotImplementedError(
            "`is_space_dtype_shape_equiv` doesn't support generic Gymnasium spaces, "
            f"actual types: {type(space_1)} and {type(space_2)}. "
            "Register an implementation with `is_space_dtype_shape_equiv.register`."
        )
    else:
        raise TypeError(
            "`is_space_dtype_shape_equiv` expects two Gymnasium spaces, "
            f"actual types: {type(space_1)} and {type(space_2)}"
        )
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Box)
@is_space_dtype_shape_equiv.register(Discrete)
@is_space_dtype_shape_equiv.register(MultiDiscrete)
@is_space_dtype_shape_equiv.register(MultiBinary)
def _is_space_fundamental_dtype_shape_equiv(space_1, space_2):
    """Compare two fundamental spaces by exact type, then shape, then dtype."""
    # singledispatch only inspects the first argument, so the type of the second one is checked by hand
    if type(space_1) is not type(space_2):
        return False
    if space_1.shape != space_2.shape:
        return False
    return space_1.dtype == space_2.dtype
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Text)
def _is_space_text_dtype_shape_equiv(space_1: Text, space_2):
    """Two Text spaces are equivalent when their max length and character set match."""
    if not isinstance(space_2, Text):
        return False
    if space_1.max_length != space_2.max_length:
        return False
    return space_1.character_set == space_2.character_set
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Dict)
def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2):
    """Two Dict spaces are equivalent when they share keys and every paired subspace is equivalent."""
    if not isinstance(space_2, Dict):
        return False
    if space_1.keys() != space_2.keys():
        return False

    for key in space_1.keys():
        if not is_space_dtype_shape_equiv(space_1[key], space_2[key]):
            return False
    return True
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Tuple)
def _is_space_tuple_dtype_shape_equiv(space_1, space_2):
    """Two Tuple spaces are equivalent when they have the same length and every paired subspace is equivalent.

    Args:
        space_1: The Tuple space that the dispatch selected this implementation for.
        space_2: The space to compare against, may be any type.

    Returns:
        ``True`` if ``space_2`` is a Tuple with the same number of subspaces and each pair of subspaces
        is dtype / shape equivalent, otherwise ``False``.
    """
    return (
        isinstance(space_2, Tuple)
        # without the length check a longer `space_2` would be accepted and a shorter one would raise an IndexError
        and len(space_1) == len(space_2)
        and all(
            is_space_dtype_shape_equiv(subspace_1, subspace_2)
            for subspace_1, subspace_2 in zip(space_1.spaces, space_2.spaces)
        )
    )
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Graph)
def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2):
    """Two Graph spaces are equivalent when their node spaces match and their edge spaces match (or are both absent)."""
    if not isinstance(space_2, Graph):
        return False
    if not is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space):
        return False

    edge_1, edge_2 = space_1.edge_space, space_2.edge_space
    if edge_1 is None or edge_2 is None:
        # a missing edge space is only compatible with another missing edge space
        return edge_1 is None and edge_2 is None
    return is_space_dtype_shape_equiv(edge_1, edge_2)
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(OneOf)
def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2):
    """Two OneOf spaces are equivalent when they have the same length and every paired subspace is equivalent."""
    if not isinstance(space_2, OneOf):
        return False
    if len(space_1) != len(space_2):
        return False

    for index in range(len(space_1)):
        if not is_space_dtype_shape_equiv(space_1[index], space_2[index]):
            return False
    return True
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Sequence)
def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2):
    """Two Sequence spaces are equivalent when their ``stack`` flag and feature spaces match."""
    if not isinstance(space_2, Sequence):
        return False
    if space_1.stack is not space_2.stack:
        return False
    return is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space)
|
||||
|
@@ -14,7 +14,7 @@ from typing import Any, Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import logger
|
||||
from gymnasium import Space, logger
|
||||
from gymnasium.core import ActType, Env, ObsType, RenderFrame
|
||||
from gymnasium.error import (
|
||||
AlreadyPendingCallError,
|
||||
@@ -22,8 +22,10 @@ from gymnasium.error import (
|
||||
CustomSpaceError,
|
||||
NoAsyncCallError,
|
||||
)
|
||||
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
|
||||
from gymnasium.vector.utils import (
|
||||
CloudpickleWrapper,
|
||||
batch_differing_spaces,
|
||||
batch_space,
|
||||
clear_mpi_env_vars,
|
||||
concatenate,
|
||||
@@ -98,6 +100,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
observation_mode: str | Space = "same",
|
||||
):
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
@@ -113,11 +116,15 @@ class AsyncVectorEnv(VectorEnv):
|
||||
so for some environments you may want to have it set to ``False``.
|
||||
worker: If set, then use that worker in a subprocess instead of a default one.
|
||||
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
||||
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
|
||||
'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype,
|
||||
warning, may raise unexpected errors. Passing a tuple of two spaces, ``(observation_space, single_observation_space)`` in that order, allows defining a custom batched
|
||||
and single observation space, warning, may raise unexpected errors.
|
||||
|
||||
Warnings:
|
||||
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
||||
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
||||
from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
||||
from the code for ``_worker`` (or ``_async_worker``) method, and add changes.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||
@@ -128,6 +135,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.env_fns = env_fns
|
||||
self.shared_memory = shared_memory
|
||||
self.copy = copy
|
||||
self.observation_mode = observation_mode
|
||||
|
||||
self.num_envs = len(env_fns)
|
||||
|
||||
@@ -139,13 +147,30 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.metadata = dummy_env.metadata
|
||||
self.render_mode = dummy_env.render_mode
|
||||
|
||||
self.single_observation_space = dummy_env.observation_space
|
||||
self.single_action_space = dummy_env.action_space
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
|
||||
if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
|
||||
assert isinstance(observation_mode[0], Space)
|
||||
assert isinstance(observation_mode[1], Space)
|
||||
self.observation_space, self.single_observation_space = observation_mode
|
||||
else:
|
||||
if observation_mode == "same":
|
||||
self.single_observation_space = dummy_env.observation_space
|
||||
self.observation_space = batch_space(
|
||||
self.single_observation_space, self.num_envs
|
||||
)
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
elif observation_mode == "different":
|
||||
# the environment is created and instantly destroyed, which might cause issues for some environments
|
||||
# but I don't believe there is anything else we can do, for users with issues, pre-compute the spaces and use the custom option.
|
||||
env_spaces = [env().observation_space for env in self.env_fns]
|
||||
|
||||
self.single_observation_space = env_spaces[0]
|
||||
self.observation_space = batch_differing_spaces(env_spaces)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
|
||||
)
|
||||
|
||||
dummy_env.close()
|
||||
del dummy_env
|
||||
@@ -162,9 +187,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
)
|
||||
except CustomSpaceError as e:
|
||||
raise ValueError(
|
||||
"Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), "
|
||||
"and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. "
|
||||
"Set `shared_memory=False` if you use custom observation spaces."
|
||||
"Using `AsyncVector(..., shared_memory=True)` caused an error, you can disable this feature with `shared_memory=False` however this is slower."
|
||||
) from e
|
||||
else:
|
||||
_obs_buffer = None
|
||||
@@ -591,20 +614,33 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
def _check_spaces(self):
|
||||
self._assert_is_running()
|
||||
spaces = (self.single_observation_space, self.single_action_space)
|
||||
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(("_check_spaces", spaces))
|
||||
pipe.send(
|
||||
(
|
||||
"_check_spaces",
|
||||
(
|
||||
self.observation_mode,
|
||||
self.single_observation_space,
|
||||
self.single_action_space,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
same_observation_spaces, same_action_spaces = zip(*results)
|
||||
|
||||
if not all(same_observation_spaces):
|
||||
if self.observation_mode == "same":
|
||||
raise RuntimeError(
|
||||
f"Some environments have an observation space different from `{self.single_observation_space}`. "
|
||||
"In order to batch observations, the observation spaces from all environments must be equal."
|
||||
"AsyncVectorEnv(..., observation_mode='same') however some of the sub-environments observation spaces are not equivalent. If this is intentional, use `observation_mode='different'` instead."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"AsyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environment's observation spaces do not share a common shape and dtype."
|
||||
)
|
||||
|
||||
if not all(same_action_spaces):
|
||||
raise RuntimeError(
|
||||
f"Some environments have an action space different from `{self.single_action_space}`. "
|
||||
@@ -714,9 +750,20 @@ def _async_worker(
|
||||
env.set_wrapper_attr(name, value)
|
||||
pipe.send((None, True))
|
||||
elif command == "_check_spaces":
|
||||
obs_mode, single_obs_space, single_action_space = data
|
||||
|
||||
pipe.send(
|
||||
(
|
||||
(data[0] == observation_space, data[1] == action_space),
|
||||
(
|
||||
(
|
||||
single_obs_space == observation_space
|
||||
if obs_mode == "same"
|
||||
else is_space_dtype_shape_equiv(
|
||||
single_obs_space, observation_space
|
||||
)
|
||||
),
|
||||
single_action_space == action_space,
|
||||
),
|
||||
True,
|
||||
)
|
||||
)
|
||||
|
@@ -7,9 +7,16 @@ from typing import Any, Callable, Iterator, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env
|
||||
from gymnasium import Env, Space
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
|
||||
from gymnasium.vector.utils import (
|
||||
batch_differing_spaces,
|
||||
batch_space,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
iterate,
|
||||
)
|
||||
from gymnasium.vector.vector_env import ArrayType, VectorEnv
|
||||
|
||||
|
||||
@@ -57,19 +64,23 @@ class SyncVectorEnv(VectorEnv):
|
||||
self,
|
||||
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
|
||||
copy: bool = True,
|
||||
observation_mode: str | Space = "same",
|
||||
):
|
||||
"""Vectorized environment that serially runs multiple environments.
|
||||
|
||||
Args:
|
||||
env_fns: iterable of callable functions that create the environments.
|
||||
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
||||
|
||||
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
|
||||
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a tuple of two spaces,
|
||||
``(observation_space, single_observation_space)`` in that order, allows the user to set a custom batched and single observation space not covered by 'same' or 'different'.
|
||||
Raises:
|
||||
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||
(or, by default, the observation space of the first sub-environment).
|
||||
"""
|
||||
self.copy = copy
|
||||
self.env_fns = env_fns
|
||||
self.observation_mode = observation_mode
|
||||
|
||||
# Initialise all sub-environments
|
||||
self.envs = [env_fn() for env_fn in env_fns]
|
||||
@@ -80,16 +91,43 @@ class SyncVectorEnv(VectorEnv):
|
||||
self.metadata = self.envs[0].metadata
|
||||
self.render_mode = self.envs[0].render_mode
|
||||
|
||||
# Initialises the single spaces from the sub-environments
|
||||
self.single_observation_space = self.envs[0].observation_space
|
||||
self.single_action_space = self.envs[0].action_space
|
||||
self._check_spaces()
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
|
||||
# Initialise the obs and action space based on the single versions and num of sub-environments
|
||||
if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
|
||||
assert isinstance(observation_mode[0], Space)
|
||||
assert isinstance(observation_mode[1], Space)
|
||||
self.observation_space, self.single_observation_space = observation_mode
|
||||
else:
|
||||
if observation_mode == "same":
|
||||
self.single_observation_space = self.envs[0].observation_space
|
||||
self.observation_space = batch_space(
|
||||
self.single_observation_space, self.num_envs
|
||||
)
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
elif observation_mode == "different":
|
||||
self.single_observation_space = self.envs[0].observation_space
|
||||
self.observation_space = batch_differing_spaces(
|
||||
[env.observation_space for env in self.envs]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
|
||||
)
|
||||
|
||||
# check sub-environment obs and action spaces
|
||||
for env in self.envs:
|
||||
if observation_mode == "same":
|
||||
assert (
|
||||
env.observation_space == self.single_observation_space
|
||||
), f"SyncVectorEnv(..., observation_mode='same') however the sub-environments observation spaces are not equivalent. single_observation_space={self.single_observation_space}, sub-environment observation_space={env.observation_space}. If this is intentional, use `observation_mode='different'` instead."
|
||||
else:
|
||||
assert is_space_dtype_shape_equiv(
|
||||
env.observation_space, self.single_observation_space
|
||||
), f"SyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environments observation spaces do not share a common shape and dtype, single_observation_space={self.single_observation_space}, sub-environment observation space={env.observation_space}"
|
||||
|
||||
assert (
|
||||
env.action_space == self.single_action_space
|
||||
), f"Sub-environment action space doesn't make the `single_action_space`, action_space={env.action_space}, single_action_space={self.single_action_space}"
|
||||
|
||||
# Initialise attributes used in `step` and `reset`
|
||||
self._observations = create_empty_array(
|
||||
@@ -265,20 +303,3 @@ class SyncVectorEnv(VectorEnv):
|
||||
"""Close the environments."""
|
||||
if hasattr(self, "envs"):
|
||||
[env.close() for env in self.envs]
|
||||
|
||||
def _check_spaces(self) -> bool:
|
||||
"""Check that each of the environments obs and action spaces are equivalent to the single obs and action space."""
|
||||
for env in self.envs:
|
||||
if not (env.observation_space == self.single_observation_space):
|
||||
raise RuntimeError(
|
||||
f"Some environments have an observation space different from `{self.single_observation_space}`. "
|
||||
"In order to batch observations, the observation spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
if not (env.action_space == self.single_action_space):
|
||||
raise RuntimeError(
|
||||
f"Some environments have an action space different from `{self.single_action_space}`. "
|
||||
"In order to batch actions, the action spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
return True
|
||||
|
@@ -7,6 +7,7 @@ from gymnasium.vector.utils.shared_memory import (
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gymnasium.vector.utils.space_utils import (
|
||||
batch_differing_spaces,
|
||||
batch_space,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
@@ -16,6 +17,7 @@ from gymnasium.vector.utils.space_utils import (
|
||||
|
||||
__all__ = [
|
||||
"batch_space",
|
||||
"batch_differing_spaces",
|
||||
"iterate",
|
||||
"concatenate",
|
||||
"create_empty_array",
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Space-based utility functions for vector environments.
|
||||
|
||||
- ``batch_space``: Create a (batched) space, containing multiple copies of a single space.
|
||||
- ``batch_space``: Create a (batched) space containing multiple copies of a single space.
|
||||
- ``batch_differing_spaces``: Create a (batched) space containing copies of different compatible spaces (share a common dtype and shape)
|
||||
- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
|
||||
- ``iterate``: Iterate over the elements of a (batched) space and items.
|
||||
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
|
||||
@@ -32,7 +33,13 @@ from gymnasium.spaces import (
|
||||
from gymnasium.spaces.space import T_cov
|
||||
|
||||
|
||||
__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"]
|
||||
__all__ = [
|
||||
"batch_space",
|
||||
"batch_differing_spaces",
|
||||
"iterate",
|
||||
"concatenate",
|
||||
"create_empty_array",
|
||||
]
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -139,6 +146,116 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
|
||||
return batched_space
|
||||
|
||||
|
||||
@singledispatch
def batch_differing_spaces(spaces: list[Space]):
    """Batch a Sequence of spaces that allows the subspaces to contain minor differences.

    The first argument is a list, so ``singledispatch`` cannot pick an implementation from it;
    this entry point instead dispatches manually on the type of the first space in the list.

    Args:
        spaces: A non-empty list of spaces that are all instances of the type of the first space.

    Returns:
        The batched space built by the implementation registered for the type of the first space.

    Raises:
        AssertionError: If the list is empty, the spaces are not all of the first space's type,
            or no implementation is registered for that type.
    """
    assert len(spaces) > 0, "Expects a non-empty list of spaces"
    space_type = type(spaces[0])
    assert all(
        isinstance(space, space_type) for space in spaces
    ), f"Expects all spaces to be the same type, actual types: {[type(space) for space in spaces]}"
    assert (
        space_type in batch_differing_spaces.registry
    ), f"Requires the Space type to have a registered `batch_differing_spaces`, current list: {batch_differing_spaces.registry}"

    return batch_differing_spaces.dispatch(space_type)(spaces)
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Box)
def _batch_differing_spaces_box(spaces: list[Box]):
    """Batch Box spaces with a common dtype and shape by stacking their bounds along a new leading axis."""
    reference = spaces[0]

    assert all(
        reference.dtype == space.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        reference.low.shape == space.low.shape for space in spaces
    ), f"Expected all Box.low shape to be equal, actually {[space.low.shape for space in spaces]}"
    assert all(
        reference.high.shape == space.high.shape for space in spaces
    ), f"Expected all Box.high shape to be equal, actually {[space.high.shape for space in spaces]}"

    stacked_low = np.array([space.low for space in spaces])
    stacked_high = np.array([space.high for space in spaces])

    # the batched space is seeded with a copy of the first space's random generator
    return Box(
        low=stacked_low,
        high=stacked_high,
        dtype=reference.dtype,
        seed=deepcopy(reference.np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Discrete)
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
    """Batch Discrete spaces into a MultiDiscrete with one ``n`` / ``start`` entry per space."""
    sizes = np.array([space.n for space in spaces])
    offsets = np.array([space.start for space in spaces])
    rng_copy = deepcopy(spaces[0].np_random)

    return MultiDiscrete(nvec=sizes, start=offsets, seed=rng_copy)
|
||||
|
||||
|
||||
@batch_differing_spaces.register(MultiDiscrete)
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
    """Batch MultiDiscrete spaces into a Box bounded by ``start`` and ``start + nvec - 1`` of each space."""
    reference = spaces[0]

    assert all(
        reference.dtype == space.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        reference.nvec.shape == space.nvec.shape for space in spaces
    ), f"Expects all MultiDiscrete.nvec shape, actually {[space.nvec.shape for space in spaces]}"
    assert all(
        reference.start.shape == space.start.shape for space in spaces
    ), f"Expects all MultiDiscrete.start shape, actually {[space.start.shape for space in spaces]}"

    lower = np.array([space.start for space in spaces])
    # the largest valid value of each entry is `start + nvec - 1`
    upper = np.array([space.start + space.nvec for space in spaces]) - 1

    return Box(
        low=lower,
        high=upper,
        dtype=reference.dtype,
        seed=deepcopy(reference.np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(MultiBinary)
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
    """Batch MultiBinary spaces of a common shape into a ``[0, 1]`` Box with a new leading batch axis."""
    reference = spaces[0]
    assert all(reference.shape == space.shape for space in spaces)

    batched_shape = (len(spaces),) + reference.shape
    return Box(
        low=0,
        high=1,
        shape=batched_shape,
        dtype=reference.dtype,
        seed=deepcopy(reference.np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Tuple)
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
    """Batch Tuple spaces position by position into a Tuple of batched subspaces.

    Args:
        spaces: The Tuple spaces to batch, all with the same number of subspaces.

    Returns:
        A Tuple space whose i-th subspace is the batch of the i-th subspace of every input space,
        seeded with a copy of the first space's random generator.

    Raises:
        AssertionError: If the Tuple spaces do not all have the same number of subspaces.
    """
    # `zip` stops at the shortest input, so without this check trailing subspaces would be silently dropped
    assert all(
        len(spaces[0].spaces) == len(space.spaces) for space in spaces
    ), f"Expected all Tuple spaces to have the same number of subspaces, actually {[len(space.spaces) for space in spaces]}"

    return Tuple(
        tuple(
            batch_differing_spaces(subspaces)
            for subspaces in zip(*[space.spaces for space in spaces])
        ),
        seed=deepcopy(spaces[0].np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Dict)
def _batch_differing_spaces_dict(spaces: list[Dict]):
    """Batch Dict spaces with identical keys into a Dict of per-key batched subspaces."""
    reference = spaces[0]
    assert all(reference.keys() == space.keys() for space in spaces)

    batched_subspaces = {}
    for key in reference.keys():
        batched_subspaces[key] = batch_differing_spaces(
            [space[key] for space in spaces]
        )

    return Dict(batched_subspaces, seed=deepcopy(reference.np_random))
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Graph)
@batch_differing_spaces.register(Text)
@batch_differing_spaces.register(Sequence)
@batch_differing_spaces.register(OneOf)
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
    """Batch spaces without a dedicated batched form as a Tuple of deep copies of each space."""
    copied_spaces = [deepcopy(space) for space in spaces]
    rng_copy = deepcopy(spaces[0].np_random)

    return Tuple(copied_spaces, seed=rng_copy)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
|
||||
"""Iterate over the elements of a (batched) space.
|
||||
|
@@ -6,7 +6,13 @@ import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces import Box, Graph, Sequence, utils
|
||||
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from gymnasium.vector.utils import (
|
||||
create_shared_memory,
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
||||
|
||||
|
||||
@@ -162,3 +168,40 @@ def test_unflatten_multidiscrete_error():
|
||||
value = np.array([0, 0])
|
||||
with pytest.raises(ValueError):
|
||||
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_is_space_dtype_shape_equiv(space):
    """Every testing space must be dtype / shape equivalent to itself."""
    equivalent_to_itself = is_space_dtype_shape_equiv(space, space)
    assert equivalent_to_itself is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space_1", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_all_space_pairs_for_is_space_dtype_shape_equiv(space_1):
    """Practically check that the `is_space_dtype_shape_equiv` works as expected for `shared_memory`."""
    for space_2 in TESTING_SPACES:
        if not is_space_dtype_shape_equiv(space_1, space_2):
            continue

        # equivalent spaces must be able to share a buffer that was created from only one of them
        try:
            shared_memory = create_shared_memory(space_1, n=2)
        except TypeError as err:
            assert (
                "has a dynamic shape so its not possible to make a static shared memory."
                in str(err)
            )
            pytest.skip("Skipping space with dynamic shape")

        space_1.seed(123)
        space_2.seed(123)
        sample_1 = space_1.sample()
        sample_2 = space_2.sample()

        # each space writes its own sample, both are then read back through `space_1`
        write_to_shared_memory(space_1, 0, sample_1, shared_memory)
        write_to_shared_memory(space_2, 1, sample_2, shared_memory)

        read_samples = read_from_shared_memory(space_1, shared_memory, n=2)
        read_sample_1, read_sample_2 = read_samples

        assert data_equivalence(sample_1, read_sample_1)
        assert data_equivalence(sample_2, read_sample_2)
|
||||
|
121
tests/vector/test_observation_mode.py
Normal file
121
tests/vector/test_observation_mode.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.spaces import Box, Dict, Discrete
|
||||
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
||||
from gymnasium.vector.utils import batch_differing_spaces
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
def create_env(obs_space):
    """Return a zero-argument factory building a ``GenericTestEnv`` with the given observation space."""
    return partial(GenericTestEnv, observation_space=obs_space)
|
||||
|
||||
|
||||
# Test cases for both SyncVectorEnv and AsyncVectorEnv
|
||||
@pytest.mark.parametrize(
    "vector_env_fn",
    [SyncVectorEnv, AsyncVectorEnv, partial(AsyncVectorEnv, shared_memory=False)],
    ids=[
        "SyncVectorEnv",
        "AsyncVectorEnv(shared_memory=True)",
        "AsyncVectorEnv(shared_memory=False)",
    ],
)
class TestVectorEnvObservationModes:
    """Checks the ``observation_mode`` argument against each vector environment constructor."""

    def test_invalid_observation_mode(self, vector_env_fn):
        """An unknown ``observation_mode`` string is rejected with a ``ValueError``."""
        expected_message = re.escape(
            "Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got invalid"
        )
        env_fns = [create_env(Box(low=0, high=1, shape=(5,))) for _ in range(3)]

        with pytest.raises(ValueError, match=expected_message):
            vector_env_fn(env_fns, observation_mode="invalid")

    def test_obs_mode_same_different_spaces(self, vector_env_fn):
        """``observation_mode='same'`` fails when the sub-environment spaces have different bounds."""
        env_fns = [
            create_env(Box(low=0, high=upper, shape=(2,))) for upper in range(1, 4)
        ]

        with pytest.raises(
            (AssertionError, RuntimeError),
            match="the sub-environments observation spaces are not equivalent. .*If this is intentional, use `observation_mode='different'` instead.",
        ):
            vector_env_fn(env_fns, observation_mode="same")

    @pytest.mark.parametrize(
        "observation_mode",
        [
            "different",
            (
                Box(
                    low=0,
                    high=np.repeat(np.arange(1, 4), 5).reshape((3, 5)),
                    shape=(3, 5),
                ),
                Box(low=0, high=1, shape=(5,)),
            ),
        ],
    )
    def test_obs_mode_different_different_spaces(self, vector_env_fn, observation_mode):
        """Spaces with a common shape but different bounds can be batched, reset and stepped."""
        spaces = [Box(low=0, high=upper, shape=(5,)) for upper in range(1, 4)]
        env_fns = [create_env(space) for space in spaces]

        envs = vector_env_fn(env_fns, observation_mode=observation_mode)

        expected_batched_space = batch_differing_spaces(spaces)
        assert envs.observation_space == expected_batched_space
        assert envs.single_observation_space == spaces[0]

        envs.reset()
        envs.step(envs.action_space.sample())
        envs.close()

    @pytest.mark.parametrize(
        "observation_mode",
        [
            "different",
            (Box(low=0, high=4, shape=(3, 5)), Box(low=0, high=4, shape=(5,))),
        ],
    )
    def test_obs_mode_different_different_shapes(self, vector_env_fn, observation_mode):
        """Sub-environment spaces with different shapes are rejected."""
        env_fns = [
            create_env(Box(low=0, high=1, shape=(size,))) for size in range(1, 4)
        ]

        # no `match`: only the exception type is pinned here, not the message
        with pytest.raises((AssertionError, RuntimeError)):
            vector_env_fn(env_fns, observation_mode=observation_mode)

    @pytest.mark.parametrize(
        "observation_mode",
        [
            "same",
            "different",
            (Box(low=0, high=4, shape=(3, 5)), Box(low=0, high=4, shape=(5,))),
        ],
    )
    def test_mixed_observation_spaces(self, vector_env_fn, observation_mode):
        """Sub-environment spaces of different types are rejected in every observation mode."""
        mixed_spaces = [
            Box(low=0, high=1, shape=(3,)),
            Discrete(5),
            Dict({"a": Discrete(2), "b": Box(low=0, high=1, shape=(2,))}),
        ]
        env_fns = [create_env(space) for space in mixed_spaces]

        # no `match`: only the exception type is pinned here, not the message
        with pytest.raises((AssertionError, RuntimeError)):
            vector_env_fn(env_fns, observation_mode=observation_mode)
|
@@ -1,5 +1,7 @@
|
||||
"""Test the `SyncVectorEnv` implementation."""
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
@@ -139,7 +141,12 @@ def test_check_spaces_sync_vector_env():
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
# FrozenLake-v1 - Discrete(16), action_space: Discrete(4)
|
||||
env_fns[1] = make_env("FrozenLake-v1", 1)
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=re.escape(
|
||||
"SyncVectorEnv(..., observation_mode='same') however the sub-environments observation spaces are not equivalent."
|
||||
),
|
||||
):
|
||||
env = SyncVectorEnv(env_fns)
|
||||
env.close()
|
||||
|
||||
|
@@ -4,13 +4,20 @@ import copy
|
||||
import re
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium import Space
|
||||
from gymnasium.error import CustomSpaceError
|
||||
from gymnasium.spaces import Tuple
|
||||
from gymnasium.spaces import Box, Tuple
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||
from gymnasium.vector.utils import (
|
||||
batch_differing_spaces,
|
||||
batch_space,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
iterate,
|
||||
)
|
||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace
|
||||
from tests.vector.utils.utils import is_rng_equal
|
||||
|
||||
@@ -68,13 +75,13 @@ def test_batch_space_deterministic(space: Space, n: int, base_seed: int):
|
||||
space_a = space
|
||||
space_a.seed(base_seed)
|
||||
space_b = copy.deepcopy(space_a)
|
||||
is_rng_equal(space_a.np_random, space_b.np_random)
|
||||
assert is_rng_equal(space_a.np_random, space_b.np_random)
|
||||
assert space_a.np_random is not space_b.np_random
|
||||
|
||||
# Batch the spaces and check that the np_random are not reference equal
|
||||
space_a_batched = batch_space(space_a, n)
|
||||
space_b_batched = batch_space(space_b, n)
|
||||
is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
||||
assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
||||
assert space_a_batched.np_random is not space_b_batched.np_random
|
||||
# Check that the batched space is not reference equal to the original space
|
||||
assert space_a.np_random is not space_a_batched.np_random
|
||||
@@ -101,7 +108,7 @@ def test_batch_space_different_samples(space: Space, n: int, base_seed: int):
|
||||
|
||||
batched_space = batch_space(space, n)
|
||||
assert space.np_random is not batched_space.np_random
|
||||
is_rng_equal(space.np_random, batched_space.np_random)
|
||||
assert is_rng_equal(space.np_random, batched_space.np_random)
|
||||
|
||||
batched_sample = batched_space.sample()
|
||||
unbatched_samples = list(iterate(batched_space, batched_sample))
|
||||
@@ -149,3 +156,68 @@ def test_custom_space():
|
||||
|
||||
empty_array = create_empty_array(custom_space)
|
||||
assert empty_array is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "spaces,expected_space",
    [
        (
            (
                Box(low=0, high=1, shape=(2,), dtype=np.float32),
                Box(low=2, high=np.array([3, 5], dtype=np.float32)),
            ),
            Box(low=np.array([[0, 0], [2, 2]]), high=np.array([[1, 1], [3, 5]])),
        ),
    ],
)
def test_varying_spaces(spaces: "list[Space]", expected_space):
    """Check ``batch_differing_spaces`` against a hand-written batched space.

    The spaces are batched together and compared with ``expected_space``. A
    sample is then drawn from the batched space, split back into one entry per
    sub-space with ``iterate``, and each entry must be contained in the
    sub-space at the same position.

    Args:
        spaces: The individual spaces to batch together.
        expected_space: The space the batching is expected to produce.
    """
    combined_space = batch_differing_spaces(spaces)
    assert combined_space == expected_space

    combined_sample = combined_space.sample()
    per_space_samples = iterate(combined_space, combined_sample)
    for original_space, split_sample in zip(spaces, per_space_samples):
        assert split_sample in original_space
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 3])
def test_batch_differing_space_vs_batch_space(space, n):
    """Check that both batching helpers agree when all sub-spaces are identical.

    ``batch_space(space, n)`` must equal the result of
    ``batch_differing_spaces`` applied to ``n`` independent deep copies of the
    same space.

    Args:
        space: The space to batch.
        n: How many copies of ``space`` are batched together.
    """
    batched_space = batch_space(space, n)

    # The local names below appear verbatim in the assertion message through
    # the ``=`` f-string specifier, so they must not be renamed.
    space_copies = [copy.deepcopy(space) for _ in range(n)]
    batched_spaces = batch_differing_spaces(space_copies)

    assert batched_space == batched_spaces, f"{batched_space=}, {batched_spaces=}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 2, 5], ids=["n=1", "n=2", "n=5"])
@pytest.mark.parametrize("base_seed", [123, 456], ids=["seed=123", "seed=456"])
def test_batch_differing_spaces_deterministic(space: Space, n: int, base_seed: int):
    """Check that ``batch_differing_spaces`` is deterministic for equally seeded spaces.

    A seeded space and a deep copy of it are each batched ``n`` times. The two
    batched spaces must report equal random number generators (as judged by
    ``is_rng_equal``), must not share generator objects with each other or with
    the source space, and must produce equivalent samples even after the source
    space has been sampled from.

    Args:
        space: The space to batch; it is seeded in place.
        n: How many times the space is repeated in the batch.
        base_seed: The seed applied to ``space`` before copying.
    """
    # Seed the parametrized space in place, then take an independent deep copy.
    space_a = space
    space_a.seed(base_seed)
    space_b = copy.deepcopy(space_a)
    assert is_rng_equal(space_a.np_random, space_b.np_random)
    assert space_a.np_random is not space_b.np_random

    # Batch both spaces and check that their np_random are equal but not the same object.
    space_a_batched = batch_differing_spaces([space_a] * n)
    space_b_batched = batch_differing_spaces([space_b] * n)
    assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
    assert space_a_batched.np_random is not space_b_batched.np_random
    # Check that the batched space is not reference equal to the original space.
    assert space_a.np_random is not space_a_batched.np_random

    # Sampling from the original space must not affect the random number
    # generator of either batched space.
    space_a.sample()
    samples_from_a = space_a_batched.sample()
    samples_from_b = space_b_batched.sample()

    split_a = iterate(space_a_batched, samples_from_a)
    split_b = iterate(space_b_batched, samples_from_b)
    for a_sample, b_sample in zip(split_a, split_b):
        assert data_equivalence(a_sample, b_sample)
|
||||
|
Reference in New Issue
Block a user