mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-26 08:17:18 +00:00
Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)
Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com> Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
@@ -47,14 +47,14 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return all(space.is_np_flattenable for space in self.spaces)
|
return all(space.is_np_flattenable for space in self.spaces)
|
||||||
|
|
||||||
def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
|
def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]:
|
||||||
"""Seed the PRNG of this space and all subspaces.
|
"""Seed the PRNG of this space and all subspaces.
|
||||||
|
|
||||||
Depending on the type of seed, the subspaces will be seeded differently
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
|
|
||||||
* ``None`` - All the subspaces will use a random initial seed
|
* ``None`` - All the subspaces will use a random initial seed
|
||||||
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
|
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
|
||||||
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
* ``List`` / ``Tuple`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||||
|
@@ -573,3 +573,102 @@ def _flatten_space_oneof(space: OneOf) -> Box:
|
|||||||
|
|
||||||
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
|
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
|
||||||
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
|
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@singledispatch
|
||||||
|
def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool:
|
||||||
|
"""Returns if two spaces share a common dtype and shape (plus any critical variables).
|
||||||
|
|
||||||
|
This function is primarily used to check for compatibility of different spaces in a vector environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space_1: A Gymnasium space
|
||||||
|
space_2: A Gymnasium space
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the two spaces share a common dtype and shape (plus any critical variables).
|
||||||
|
"""
|
||||||
|
if isinstance(space_1, Space) and isinstance(space_2, Space):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"`check_dtype_shape_equivalence` doesn't support Generic Gymnasium Spaces, "
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError()
|
||||||
|
|
||||||
|
|
||||||
|
@is_space_dtype_shape_equiv.register(Box)
|
||||||
|
@is_space_dtype_shape_equiv.register(Discrete)
|
||||||
|
@is_space_dtype_shape_equiv.register(MultiDiscrete)
|
||||||
|
@is_space_dtype_shape_equiv.register(MultiBinary)
|
||||||
|
def _is_space_fundamental_dtype_shape_equiv(space_1, space_2):
|
||||||
|
return (
|
||||||
|
# this check is necessary as singledispatch only checks the first variable and there are many options
|
||||||
|
type(space_1) is type(space_2)
|
||||||
|
and space_1.shape == space_2.shape
|
||||||
|
and space_1.dtype == space_2.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_space_dtype_shape_equiv.register(Text)
|
||||||
|
def _is_space_text_dtype_shape_equiv(space_1: Text, space_2):
|
||||||
|
return (
|
||||||
|
isinstance(space_2, Text)
|
||||||
|
and space_1.max_length == space_2.max_length
|
||||||
|
and space_1.character_set == space_2.character_set
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_space_dtype_shape_equiv.register(Dict)
|
||||||
|
def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2):
|
||||||
|
return (
|
||||||
|
isinstance(space_2, Dict)
|
||||||
|
and space_1.keys() == space_2.keys()
|
||||||
|
and all(
|
||||||
|
is_space_dtype_shape_equiv(space_1[key], space_2[key])
|
||||||
|
for key in space_1.keys()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_space_dtype_shape_equiv.register(Tuple)
|
||||||
|
def _is_space_tuple_dtype_shape_equiv(space_1, space_2):
|
||||||
|
return isinstance(space_2, Tuple) and all(
|
||||||
|
is_space_dtype_shape_equiv(space_1[i], space_2[i]) for i in range(len(space_1))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_space_dtype_shape_equiv.register(Graph)
|
||||||
|
def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2):
|
||||||
|
return (
|
||||||
|
isinstance(space_2, Graph)
|
||||||
|
and is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space)
|
||||||
|
and (
|
||||||
|
(space_1.edge_space is None and space_2.edge_space is None)
|
||||||
|
or (
|
||||||
|
space_1.edge_space is not None
|
||||||
|
and space_2.edge_space is not None
|
||||||
|
and is_space_dtype_shape_equiv(space_1.edge_space, space_2.edge_space)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_space_dtype_shape_equiv.register(OneOf)
|
||||||
|
def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2):
|
||||||
|
return (
|
||||||
|
isinstance(space_2, OneOf)
|
||||||
|
and len(space_1) == len(space_2)
|
||||||
|
and all(
|
||||||
|
is_space_dtype_shape_equiv(space_1[i], space_2[i])
|
||||||
|
for i in range(len(space_1))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_space_dtype_shape_equiv.register(Sequence)
|
||||||
|
def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2):
|
||||||
|
return (
|
||||||
|
isinstance(space_2, Sequence)
|
||||||
|
and space_1.stack is space_2.stack
|
||||||
|
and is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space)
|
||||||
|
)
|
||||||
|
@@ -14,7 +14,7 @@ from typing import Any, Callable, Sequence
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium import logger
|
from gymnasium import Space, logger
|
||||||
from gymnasium.core import ActType, Env, ObsType, RenderFrame
|
from gymnasium.core import ActType, Env, ObsType, RenderFrame
|
||||||
from gymnasium.error import (
|
from gymnasium.error import (
|
||||||
AlreadyPendingCallError,
|
AlreadyPendingCallError,
|
||||||
@@ -22,8 +22,10 @@ from gymnasium.error import (
|
|||||||
CustomSpaceError,
|
CustomSpaceError,
|
||||||
NoAsyncCallError,
|
NoAsyncCallError,
|
||||||
)
|
)
|
||||||
|
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
|
||||||
from gymnasium.vector.utils import (
|
from gymnasium.vector.utils import (
|
||||||
CloudpickleWrapper,
|
CloudpickleWrapper,
|
||||||
|
batch_differing_spaces,
|
||||||
batch_space,
|
batch_space,
|
||||||
clear_mpi_env_vars,
|
clear_mpi_env_vars,
|
||||||
concatenate,
|
concatenate,
|
||||||
@@ -98,6 +100,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
) = None,
|
) = None,
|
||||||
|
observation_mode: str | Space = "same",
|
||||||
):
|
):
|
||||||
"""Vectorized environment that runs multiple environments in parallel.
|
"""Vectorized environment that runs multiple environments in parallel.
|
||||||
|
|
||||||
@@ -113,11 +116,15 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
so for some environments you may want to have it set to ``False``.
|
so for some environments you may want to have it set to ``False``.
|
||||||
worker: If set, then use that worker in a subprocess instead of a default one.
|
worker: If set, then use that worker in a subprocess instead of a default one.
|
||||||
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
||||||
|
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
|
||||||
|
'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype,
|
||||||
|
warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and
|
||||||
|
``observation_space``, warning, may raise unexpected errors.
|
||||||
|
|
||||||
Warnings:
|
Warnings:
|
||||||
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
||||||
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
||||||
from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
from the code for ``_worker`` (or ``_async_worker``) method, and add changes.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||||
@@ -128,6 +135,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
self.shared_memory = shared_memory
|
self.shared_memory = shared_memory
|
||||||
self.copy = copy
|
self.copy = copy
|
||||||
|
self.observation_mode = observation_mode
|
||||||
|
|
||||||
self.num_envs = len(env_fns)
|
self.num_envs = len(env_fns)
|
||||||
|
|
||||||
@@ -139,13 +147,30 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self.metadata = dummy_env.metadata
|
self.metadata = dummy_env.metadata
|
||||||
self.render_mode = dummy_env.render_mode
|
self.render_mode = dummy_env.render_mode
|
||||||
|
|
||||||
self.single_observation_space = dummy_env.observation_space
|
|
||||||
self.single_action_space = dummy_env.action_space
|
self.single_action_space = dummy_env.action_space
|
||||||
|
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||||
|
|
||||||
|
if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
|
||||||
|
assert isinstance(observation_mode[0], Space)
|
||||||
|
assert isinstance(observation_mode[1], Space)
|
||||||
|
self.observation_space, self.single_observation_space = observation_mode
|
||||||
|
else:
|
||||||
|
if observation_mode == "same":
|
||||||
|
self.single_observation_space = dummy_env.observation_space
|
||||||
self.observation_space = batch_space(
|
self.observation_space = batch_space(
|
||||||
self.single_observation_space, self.num_envs
|
self.single_observation_space, self.num_envs
|
||||||
)
|
)
|
||||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
elif observation_mode == "different":
|
||||||
|
# the environment is created and instantly destroy, might cause issues for some environment
|
||||||
|
# but I don't believe there is anything else we can do, for users with issues, pre-compute the spaces and use the custom option.
|
||||||
|
env_spaces = [env().observation_space for env in self.env_fns]
|
||||||
|
|
||||||
|
self.single_observation_space = env_spaces[0]
|
||||||
|
self.observation_space = batch_differing_spaces(env_spaces)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
dummy_env.close()
|
dummy_env.close()
|
||||||
del dummy_env
|
del dummy_env
|
||||||
@@ -162,9 +187,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
)
|
)
|
||||||
except CustomSpaceError as e:
|
except CustomSpaceError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), "
|
"Using `AsyncVector(..., shared_memory=True)` caused an error, you can disable this feature with `shared_memory=False` however this is slower."
|
||||||
"and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. "
|
|
||||||
"Set `shared_memory=False` if you use custom observation spaces."
|
|
||||||
) from e
|
) from e
|
||||||
else:
|
else:
|
||||||
_obs_buffer = None
|
_obs_buffer = None
|
||||||
@@ -591,20 +614,33 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def _check_spaces(self):
|
def _check_spaces(self):
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
spaces = (self.single_observation_space, self.single_action_space)
|
|
||||||
|
|
||||||
for pipe in self.parent_pipes:
|
for pipe in self.parent_pipes:
|
||||||
pipe.send(("_check_spaces", spaces))
|
pipe.send(
|
||||||
|
(
|
||||||
|
"_check_spaces",
|
||||||
|
(
|
||||||
|
self.observation_mode,
|
||||||
|
self.single_observation_space,
|
||||||
|
self.single_action_space,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||||
self._raise_if_errors(successes)
|
self._raise_if_errors(successes)
|
||||||
same_observation_spaces, same_action_spaces = zip(*results)
|
same_observation_spaces, same_action_spaces = zip(*results)
|
||||||
|
|
||||||
if not all(same_observation_spaces):
|
if not all(same_observation_spaces):
|
||||||
|
if self.observation_mode == "same":
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Some environments have an observation space different from `{self.single_observation_space}`. "
|
"AsyncVectorEnv(..., observation_mode='same') however some of the sub-environments observation spaces are not equivalent. If this is intentional, use `observation_mode='different'` instead."
|
||||||
"In order to batch observations, the observation spaces from all environments must be equal."
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"AsyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environment's observation spaces do not share a common shape and dtype."
|
||||||
|
)
|
||||||
|
|
||||||
if not all(same_action_spaces):
|
if not all(same_action_spaces):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Some environments have an action space different from `{self.single_action_space}`. "
|
f"Some environments have an action space different from `{self.single_action_space}`. "
|
||||||
@@ -714,9 +750,20 @@ def _async_worker(
|
|||||||
env.set_wrapper_attr(name, value)
|
env.set_wrapper_attr(name, value)
|
||||||
pipe.send((None, True))
|
pipe.send((None, True))
|
||||||
elif command == "_check_spaces":
|
elif command == "_check_spaces":
|
||||||
|
obs_mode, single_obs_space, single_action_space = data
|
||||||
|
|
||||||
pipe.send(
|
pipe.send(
|
||||||
(
|
(
|
||||||
(data[0] == observation_space, data[1] == action_space),
|
(
|
||||||
|
(
|
||||||
|
single_obs_space == observation_space
|
||||||
|
if obs_mode == "same"
|
||||||
|
else is_space_dtype_shape_equiv(
|
||||||
|
single_obs_space, observation_space
|
||||||
|
)
|
||||||
|
),
|
||||||
|
single_action_space == action_space,
|
||||||
|
),
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@@ -7,9 +7,16 @@ from typing import Any, Callable, Iterator, Sequence
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium import Env
|
from gymnasium import Env, Space
|
||||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
|
||||||
|
from gymnasium.vector.utils import (
|
||||||
|
batch_differing_spaces,
|
||||||
|
batch_space,
|
||||||
|
concatenate,
|
||||||
|
create_empty_array,
|
||||||
|
iterate,
|
||||||
|
)
|
||||||
from gymnasium.vector.vector_env import ArrayType, VectorEnv
|
from gymnasium.vector.vector_env import ArrayType, VectorEnv
|
||||||
|
|
||||||
|
|
||||||
@@ -57,19 +64,23 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
self,
|
self,
|
||||||
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
|
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
|
||||||
copy: bool = True,
|
copy: bool = True,
|
||||||
|
observation_mode: str | Space = "same",
|
||||||
):
|
):
|
||||||
"""Vectorized environment that serially runs multiple environments.
|
"""Vectorized environment that serially runs multiple environments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env_fns: iterable of callable functions that create the environments.
|
env_fns: iterable of callable functions that create the environments.
|
||||||
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
||||||
|
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
|
||||||
|
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
|
||||||
|
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||||
(or, by default, the observation space of the first sub-environment).
|
(or, by default, the observation space of the first sub-environment).
|
||||||
"""
|
"""
|
||||||
self.copy = copy
|
self.copy = copy
|
||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
|
self.observation_mode = observation_mode
|
||||||
|
|
||||||
# Initialise all sub-environments
|
# Initialise all sub-environments
|
||||||
self.envs = [env_fn() for env_fn in env_fns]
|
self.envs = [env_fn() for env_fn in env_fns]
|
||||||
@@ -80,16 +91,43 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
self.metadata = self.envs[0].metadata
|
self.metadata = self.envs[0].metadata
|
||||||
self.render_mode = self.envs[0].render_mode
|
self.render_mode = self.envs[0].render_mode
|
||||||
|
|
||||||
# Initialises the single spaces from the sub-environments
|
|
||||||
self.single_observation_space = self.envs[0].observation_space
|
|
||||||
self.single_action_space = self.envs[0].action_space
|
self.single_action_space = self.envs[0].action_space
|
||||||
self._check_spaces()
|
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||||
|
|
||||||
# Initialise the obs and action space based on the single versions and num of sub-environments
|
if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
|
||||||
|
assert isinstance(observation_mode[0], Space)
|
||||||
|
assert isinstance(observation_mode[1], Space)
|
||||||
|
self.observation_space, self.single_observation_space = observation_mode
|
||||||
|
else:
|
||||||
|
if observation_mode == "same":
|
||||||
|
self.single_observation_space = self.envs[0].observation_space
|
||||||
self.observation_space = batch_space(
|
self.observation_space = batch_space(
|
||||||
self.single_observation_space, self.num_envs
|
self.single_observation_space, self.num_envs
|
||||||
)
|
)
|
||||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
elif observation_mode == "different":
|
||||||
|
self.single_observation_space = self.envs[0].observation_space
|
||||||
|
self.observation_space = batch_differing_spaces(
|
||||||
|
[env.observation_space for env in self.envs]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check sub-environment obs and action spaces
|
||||||
|
for env in self.envs:
|
||||||
|
if observation_mode == "same":
|
||||||
|
assert (
|
||||||
|
env.observation_space == self.single_observation_space
|
||||||
|
), f"SyncVectorEnv(..., observation_mode='same') however the sub-environments observation spaces are not equivalent. single_observation_space={self.single_observation_space}, sub-environment observation_space={env.observation_space}. If this is intentional, use `observation_mode='different'` instead."
|
||||||
|
else:
|
||||||
|
assert is_space_dtype_shape_equiv(
|
||||||
|
env.observation_space, self.single_observation_space
|
||||||
|
), f"SyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environments observation spaces do not share a common shape and dtype, single_observation_space={self.single_observation_space}, sub-environment observation space={env.observation_space}"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
env.action_space == self.single_action_space
|
||||||
|
), f"Sub-environment action space doesn't make the `single_action_space`, action_space={env.action_space}, single_action_space={self.single_action_space}"
|
||||||
|
|
||||||
# Initialise attributes used in `step` and `reset`
|
# Initialise attributes used in `step` and `reset`
|
||||||
self._observations = create_empty_array(
|
self._observations = create_empty_array(
|
||||||
@@ -265,20 +303,3 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
"""Close the environments."""
|
"""Close the environments."""
|
||||||
if hasattr(self, "envs"):
|
if hasattr(self, "envs"):
|
||||||
[env.close() for env in self.envs]
|
[env.close() for env in self.envs]
|
||||||
|
|
||||||
def _check_spaces(self) -> bool:
|
|
||||||
"""Check that each of the environments obs and action spaces are equivalent to the single obs and action space."""
|
|
||||||
for env in self.envs:
|
|
||||||
if not (env.observation_space == self.single_observation_space):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Some environments have an observation space different from `{self.single_observation_space}`. "
|
|
||||||
"In order to batch observations, the observation spaces from all environments must be equal."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not (env.action_space == self.single_action_space):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Some environments have an action space different from `{self.single_action_space}`. "
|
|
||||||
"In order to batch actions, the action spaces from all environments must be equal."
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
@@ -7,6 +7,7 @@ from gymnasium.vector.utils.shared_memory import (
|
|||||||
write_to_shared_memory,
|
write_to_shared_memory,
|
||||||
)
|
)
|
||||||
from gymnasium.vector.utils.space_utils import (
|
from gymnasium.vector.utils.space_utils import (
|
||||||
|
batch_differing_spaces,
|
||||||
batch_space,
|
batch_space,
|
||||||
concatenate,
|
concatenate,
|
||||||
create_empty_array,
|
create_empty_array,
|
||||||
@@ -16,6 +17,7 @@ from gymnasium.vector.utils.space_utils import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"batch_space",
|
"batch_space",
|
||||||
|
"batch_differing_spaces",
|
||||||
"iterate",
|
"iterate",
|
||||||
"concatenate",
|
"concatenate",
|
||||||
"create_empty_array",
|
"create_empty_array",
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
"""Space-based utility functions for vector environments.
|
"""Space-based utility functions for vector environments.
|
||||||
|
|
||||||
- ``batch_space``: Create a (batched) space, containing multiple copies of a single space.
|
- ``batch_space``: Create a (batched) space containing multiple copies of a single space.
|
||||||
|
- ``batch_differing_spaces``: Create a (batched) space containing copies of different compatible spaces (share a common dtype and shape)
|
||||||
- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
|
- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
|
||||||
- ``Iterate``: Iterate over the elements of a (batched) space and items.
|
- ``Iterate``: Iterate over the elements of a (batched) space and items.
|
||||||
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
|
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
|
||||||
@@ -32,7 +33,13 @@ from gymnasium.spaces import (
|
|||||||
from gymnasium.spaces.space import T_cov
|
from gymnasium.spaces.space import T_cov
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"]
|
__all__ = [
|
||||||
|
"batch_space",
|
||||||
|
"batch_differing_spaces",
|
||||||
|
"iterate",
|
||||||
|
"concatenate",
|
||||||
|
"create_empty_array",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
@@ -139,6 +146,116 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
|
|||||||
return batched_space
|
return batched_space
|
||||||
|
|
||||||
|
|
||||||
|
@singledispatch
|
||||||
|
def batch_differing_spaces(spaces: list[Space]):
|
||||||
|
"""Batch a Sequence of spaces that allows the subspaces to contain minor differences."""
|
||||||
|
assert len(spaces) > 0, "Expects a non-empty list of spaces"
|
||||||
|
assert all(
|
||||||
|
isinstance(space, type(spaces[0])) for space in spaces
|
||||||
|
), f"Expects all spaces to be the same shape, actual types: {[type(space) for space in spaces]}"
|
||||||
|
assert (
|
||||||
|
type(spaces[0]) in batch_differing_spaces.registry
|
||||||
|
), f"Requires the Space type to have a registered `batch_differing_space`, current list: {batch_differing_spaces.registry}"
|
||||||
|
|
||||||
|
return batch_differing_spaces.dispatch(type(spaces[0]))(spaces)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_differing_spaces.register(Box)
|
||||||
|
def _batch_differing_spaces_box(spaces: list[Box]):
|
||||||
|
assert all(
|
||||||
|
spaces[0].dtype == space.dtype for space in spaces
|
||||||
|
), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
|
||||||
|
assert all(
|
||||||
|
spaces[0].low.shape == space.low.shape for space in spaces
|
||||||
|
), f"Expected all Box.low shape to be equal, actually {[space.low.shape for space in spaces]}"
|
||||||
|
assert all(
|
||||||
|
spaces[0].high.shape == space.high.shape for space in spaces
|
||||||
|
), f"Expected all Box.high shape to be equal, actually {[space.high.shape for space in spaces]}"
|
||||||
|
|
||||||
|
return Box(
|
||||||
|
low=np.array([space.low for space in spaces]),
|
||||||
|
high=np.array([space.high for space in spaces]),
|
||||||
|
dtype=spaces[0].dtype,
|
||||||
|
seed=deepcopy(spaces[0].np_random),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_differing_spaces.register(Discrete)
|
||||||
|
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
|
||||||
|
return MultiDiscrete(
|
||||||
|
nvec=np.array([space.n for space in spaces]),
|
||||||
|
start=np.array([space.start for space in spaces]),
|
||||||
|
seed=deepcopy(spaces[0].np_random),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_differing_spaces.register(MultiDiscrete)
|
||||||
|
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
|
||||||
|
assert all(
|
||||||
|
spaces[0].dtype == space.dtype for space in spaces
|
||||||
|
), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
|
||||||
|
assert all(
|
||||||
|
spaces[0].nvec.shape == space.nvec.shape for space in spaces
|
||||||
|
), f"Expects all MultiDiscrete.nvec shape, actually {[space.nvec.shape for space in spaces]}"
|
||||||
|
assert all(
|
||||||
|
spaces[0].start.shape == space.start.shape for space in spaces
|
||||||
|
), f"Expects all MultiDiscrete.start shape, actually {[space.start.shape for space in spaces]}"
|
||||||
|
|
||||||
|
return Box(
|
||||||
|
low=np.array([space.start for space in spaces]),
|
||||||
|
high=np.array([space.start + space.nvec for space in spaces]) - 1,
|
||||||
|
dtype=spaces[0].dtype,
|
||||||
|
seed=deepcopy(spaces[0].np_random),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_differing_spaces.register(MultiBinary)
|
||||||
|
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
|
||||||
|
assert all(spaces[0].shape == space.shape for space in spaces)
|
||||||
|
|
||||||
|
return Box(
|
||||||
|
low=0,
|
||||||
|
high=1,
|
||||||
|
shape=(len(spaces),) + spaces[0].shape,
|
||||||
|
dtype=spaces[0].dtype,
|
||||||
|
seed=deepcopy(spaces[0].np_random),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_differing_spaces.register(Tuple)
|
||||||
|
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
|
||||||
|
return Tuple(
|
||||||
|
tuple(
|
||||||
|
batch_differing_spaces(subspaces)
|
||||||
|
for subspaces in zip(*[space.spaces for space in spaces])
|
||||||
|
),
|
||||||
|
seed=deepcopy(spaces[0].np_random),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_differing_spaces.register(Dict)
|
||||||
|
def _batch_differing_spaces_dict(spaces: list[Dict]):
|
||||||
|
assert all(spaces[0].keys() == space.keys() for space in spaces)
|
||||||
|
|
||||||
|
return Dict(
|
||||||
|
{
|
||||||
|
key: batch_differing_spaces([space[key] for space in spaces])
|
||||||
|
for key in spaces[0].keys()
|
||||||
|
},
|
||||||
|
seed=deepcopy(spaces[0].np_random),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_differing_spaces.register(Graph)
|
||||||
|
@batch_differing_spaces.register(Text)
|
||||||
|
@batch_differing_spaces.register(Sequence)
|
||||||
|
@batch_differing_spaces.register(OneOf)
|
||||||
|
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
|
||||||
|
return Tuple(
|
||||||
|
[deepcopy(space) for space in spaces], seed=deepcopy(spaces[0].np_random)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
|
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
|
||||||
"""Iterate over the elements of a (batched) space.
|
"""Iterate over the elements of a (batched) space.
|
||||||
|
@@ -6,7 +6,13 @@ import pytest
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.spaces import Box, Graph, Sequence, utils
|
from gymnasium.spaces import Box, Graph, Sequence, utils
|
||||||
|
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
|
||||||
from gymnasium.utils.env_checker import data_equivalence
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
|
from gymnasium.vector.utils import (
|
||||||
|
create_shared_memory,
|
||||||
|
read_from_shared_memory,
|
||||||
|
write_to_shared_memory,
|
||||||
|
)
|
||||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
||||||
|
|
||||||
|
|
||||||
@@ -162,3 +168,40 @@ def test_unflatten_multidiscrete_error():
|
|||||||
value = np.array([0, 0])
|
value = np.array([0, 0])
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)
|
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||||
|
def test_is_space_dtype_shape_equiv(space):
|
||||||
|
assert is_space_dtype_shape_equiv(space, space) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space_1", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||||
|
def test_all_space_pairs_for_is_space_dtype_shape_equiv(space_1):
|
||||||
|
"""Practically check that the `is_space_dtype_shape_equiv` works as expected for `shared_memory`."""
|
||||||
|
for space_2 in TESTING_SPACES:
|
||||||
|
compatible = is_space_dtype_shape_equiv(space_1, space_2)
|
||||||
|
|
||||||
|
if compatible:
|
||||||
|
try:
|
||||||
|
shared_memory = create_shared_memory(space_1, n=2)
|
||||||
|
except TypeError as err:
|
||||||
|
assert (
|
||||||
|
"has a dynamic shape so its not possible to make a static shared memory."
|
||||||
|
in str(err)
|
||||||
|
)
|
||||||
|
pytest.skip("Skipping space with dynamic shape")
|
||||||
|
|
||||||
|
space_1.seed(123)
|
||||||
|
space_2.seed(123)
|
||||||
|
sample_1 = space_1.sample()
|
||||||
|
sample_2 = space_2.sample()
|
||||||
|
|
||||||
|
write_to_shared_memory(space_1, 0, sample_1, shared_memory)
|
||||||
|
write_to_shared_memory(space_2, 1, sample_2, shared_memory)
|
||||||
|
|
||||||
|
read_sample_1, read_sample_2 = read_from_shared_memory(
|
||||||
|
space_1, shared_memory, n=2
|
||||||
|
)
|
||||||
|
|
||||||
|
assert data_equivalence(sample_1, read_sample_1)
|
||||||
|
assert data_equivalence(sample_2, read_sample_2)
|
||||||
|
121
tests/vector/test_observation_mode.py
Normal file
121
tests/vector/test_observation_mode.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import re
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gymnasium.spaces import Box, Dict, Discrete
|
||||||
|
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
||||||
|
from gymnasium.vector.utils import batch_differing_spaces
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def create_env(obs_space):
|
||||||
|
return lambda: GenericTestEnv(observation_space=obs_space)
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for both SyncVectorEnv and AsyncVectorEnv
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"vector_env_fn",
|
||||||
|
[SyncVectorEnv, AsyncVectorEnv, partial(AsyncVectorEnv, shared_memory=False)],
|
||||||
|
ids=[
|
||||||
|
"SyncVectorEnv",
|
||||||
|
"AsyncVectorEnv(shared_memory=True)",
|
||||||
|
"AsyncVectorEnv(shared_memory=False)",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
class TestVectorEnvObservationModes:
|
||||||
|
|
||||||
|
def test_invalid_observation_mode(self, vector_env_fn):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape(
|
||||||
|
"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got invalid"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
vector_env_fn(
|
||||||
|
[create_env(Box(low=0, high=1, shape=(5,))) for _ in range(3)],
|
||||||
|
observation_mode="invalid",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_obs_mode_same_different_spaces(self, vector_env_fn):
|
||||||
|
spaces = [Box(low=0, high=i, shape=(2,)) for i in range(1, 4)]
|
||||||
|
with pytest.raises(
|
||||||
|
(AssertionError, RuntimeError),
|
||||||
|
match="the sub-environments observation spaces are not equivalent. .*If this is intentional, use `observation_mode='different'` instead.",
|
||||||
|
):
|
||||||
|
vector_env_fn(
|
||||||
|
[create_env(space) for space in spaces], observation_mode="same"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"observation_mode",
|
||||||
|
[
|
||||||
|
"different",
|
||||||
|
(
|
||||||
|
Box(
|
||||||
|
low=0,
|
||||||
|
high=np.repeat(np.arange(1, 4), 5).reshape((3, 5)),
|
||||||
|
shape=(3, 5),
|
||||||
|
),
|
||||||
|
Box(low=0, high=1, shape=(5,)),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_obs_mode_different_different_spaces(self, vector_env_fn, observation_mode):
|
||||||
|
spaces = [Box(low=0, high=i, shape=(5,)) for i in range(1, 4)]
|
||||||
|
envs = vector_env_fn(
|
||||||
|
[create_env(space) for space in spaces], observation_mode=observation_mode
|
||||||
|
)
|
||||||
|
assert envs.observation_space == batch_differing_spaces(spaces)
|
||||||
|
assert envs.single_observation_space == spaces[0]
|
||||||
|
|
||||||
|
envs.reset()
|
||||||
|
envs.step(envs.action_space.sample())
|
||||||
|
envs.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"observation_mode",
|
||||||
|
[
|
||||||
|
"different",
|
||||||
|
(Box(low=0, high=4, shape=(3, 5)), Box(low=0, high=4, shape=(5,))),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_obs_mode_different_different_shapes(self, vector_env_fn, observation_mode):
|
||||||
|
spaces = [Box(low=0, high=1, shape=(i + 1,)) for i in range(3)]
|
||||||
|
with pytest.raises(
|
||||||
|
(AssertionError, RuntimeError),
|
||||||
|
# match=re.escape(
|
||||||
|
# "Expected all Box.low shape to be equal, actually [(1,), (2,), (3,)]"
|
||||||
|
# ),
|
||||||
|
):
|
||||||
|
vector_env_fn(
|
||||||
|
[create_env(space) for space in spaces],
|
||||||
|
observation_mode=observation_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"observation_mode",
|
||||||
|
[
|
||||||
|
"same",
|
||||||
|
"different",
|
||||||
|
(Box(low=0, high=4, shape=(3, 5)), Box(low=0, high=4, shape=(5,))),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mixed_observation_spaces(self, vector_env_fn, observation_mode):
|
||||||
|
spaces = [
|
||||||
|
Box(low=0, high=1, shape=(3,)),
|
||||||
|
Discrete(5),
|
||||||
|
Dict({"a": Discrete(2), "b": Box(low=0, high=1, shape=(2,))}),
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
(AssertionError, RuntimeError),
|
||||||
|
# match=re.escape(
|
||||||
|
# "Expects all spaces to be the same shape, actual types: [<class 'gymnasium.spaces.box.Box'>, <class 'gymnasium.spaces.discrete.Discrete'>, <class 'gymnasium.spaces.dict.Dict'>]"
|
||||||
|
# ),
|
||||||
|
):
|
||||||
|
vector_env_fn(
|
||||||
|
[create_env(space) for space in spaces],
|
||||||
|
observation_mode=observation_mode,
|
||||||
|
)
|
@@ -1,5 +1,7 @@
|
|||||||
"""Test the `SyncVectorEnv` implementation."""
|
"""Test the `SyncVectorEnv` implementation."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -139,7 +141,12 @@ def test_check_spaces_sync_vector_env():
|
|||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
# FrozenLake-v1 - Discrete(16), action_space: Discrete(4)
|
# FrozenLake-v1 - Discrete(16), action_space: Discrete(4)
|
||||||
env_fns[1] = make_env("FrozenLake-v1", 1)
|
env_fns[1] = make_env("FrozenLake-v1", 1)
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match=re.escape(
|
||||||
|
"SyncVectorEnv(..., observation_mode='same') however the sub-environments observation spaces are not equivalent."
|
||||||
|
),
|
||||||
|
):
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
@@ -4,13 +4,20 @@ import copy
|
|||||||
import re
|
import re
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium import Space
|
from gymnasium import Space
|
||||||
from gymnasium.error import CustomSpaceError
|
from gymnasium.error import CustomSpaceError
|
||||||
from gymnasium.spaces import Tuple
|
from gymnasium.spaces import Box, Tuple
|
||||||
from gymnasium.utils.env_checker import data_equivalence
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
from gymnasium.vector.utils import (
|
||||||
|
batch_differing_spaces,
|
||||||
|
batch_space,
|
||||||
|
concatenate,
|
||||||
|
create_empty_array,
|
||||||
|
iterate,
|
||||||
|
)
|
||||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace
|
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace
|
||||||
from tests.vector.utils.utils import is_rng_equal
|
from tests.vector.utils.utils import is_rng_equal
|
||||||
|
|
||||||
@@ -68,13 +75,13 @@ def test_batch_space_deterministic(space: Space, n: int, base_seed: int):
|
|||||||
space_a = space
|
space_a = space
|
||||||
space_a.seed(base_seed)
|
space_a.seed(base_seed)
|
||||||
space_b = copy.deepcopy(space_a)
|
space_b = copy.deepcopy(space_a)
|
||||||
is_rng_equal(space_a.np_random, space_b.np_random)
|
assert is_rng_equal(space_a.np_random, space_b.np_random)
|
||||||
assert space_a.np_random is not space_b.np_random
|
assert space_a.np_random is not space_b.np_random
|
||||||
|
|
||||||
# Batch the spaces and check that the np_random are not reference equal
|
# Batch the spaces and check that the np_random are not reference equal
|
||||||
space_a_batched = batch_space(space_a, n)
|
space_a_batched = batch_space(space_a, n)
|
||||||
space_b_batched = batch_space(space_b, n)
|
space_b_batched = batch_space(space_b, n)
|
||||||
is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
||||||
assert space_a_batched.np_random is not space_b_batched.np_random
|
assert space_a_batched.np_random is not space_b_batched.np_random
|
||||||
# Create that the batched space is not reference equal to the origin spaces
|
# Create that the batched space is not reference equal to the origin spaces
|
||||||
assert space_a.np_random is not space_a_batched.np_random
|
assert space_a.np_random is not space_a_batched.np_random
|
||||||
@@ -101,7 +108,7 @@ def test_batch_space_different_samples(space: Space, n: int, base_seed: int):
|
|||||||
|
|
||||||
batched_space = batch_space(space, n)
|
batched_space = batch_space(space, n)
|
||||||
assert space.np_random is not batched_space.np_random
|
assert space.np_random is not batched_space.np_random
|
||||||
is_rng_equal(space.np_random, batched_space.np_random)
|
assert is_rng_equal(space.np_random, batched_space.np_random)
|
||||||
|
|
||||||
batched_sample = batched_space.sample()
|
batched_sample = batched_space.sample()
|
||||||
unbatched_samples = list(iterate(batched_space, batched_sample))
|
unbatched_samples = list(iterate(batched_space, batched_sample))
|
||||||
@@ -149,3 +156,68 @@ def test_custom_space():
|
|||||||
|
|
||||||
empty_array = create_empty_array(custom_space)
|
empty_array = create_empty_array(custom_space)
|
||||||
assert empty_array is None
|
assert empty_array is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"spaces,expected_space",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
(
|
||||||
|
Box(low=0, high=1, shape=(2,), dtype=np.float32),
|
||||||
|
Box(low=2, high=np.array([3, 5], dtype=np.float32)),
|
||||||
|
),
|
||||||
|
Box(low=np.array([[0, 0], [2, 2]]), high=np.array([[1, 1], [3, 5]])),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_varying_spaces(spaces: "list[Space]", expected_space):
|
||||||
|
"""Test the batch spaces function."""
|
||||||
|
batched_space = batch_differing_spaces(spaces)
|
||||||
|
assert batched_space == expected_space
|
||||||
|
|
||||||
|
batch_samples = batched_space.sample()
|
||||||
|
for sub_space, sub_sample in zip(spaces, iterate(batched_space, batch_samples)):
|
||||||
|
assert sub_sample in sub_space
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||||
|
@pytest.mark.parametrize("n", [1, 3])
|
||||||
|
def test_batch_differing_space_vs_batch_space(space, n):
|
||||||
|
"""Test the batch_spaces and batch_space functions."""
|
||||||
|
batched_space = batch_space(space, n)
|
||||||
|
batched_spaces = batch_differing_spaces([copy.deepcopy(space) for _ in range(n)])
|
||||||
|
|
||||||
|
assert batched_space == batched_spaces, f"{batched_space=}, {batched_spaces=}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||||
|
@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
|
||||||
|
)
|
||||||
|
def test_batch_differing_spaces_deterministic(space: Space, n: int, base_seed: int):
|
||||||
|
"""Tests the batched spaces are deterministic by using a copied version."""
|
||||||
|
# Copy the spaces and check that the np_random are not reference equal
|
||||||
|
space_a = space
|
||||||
|
space_a.seed(base_seed)
|
||||||
|
space_b = copy.deepcopy(space_a)
|
||||||
|
assert is_rng_equal(space_a.np_random, space_b.np_random)
|
||||||
|
assert space_a.np_random is not space_b.np_random
|
||||||
|
|
||||||
|
# Batch the spaces and check that the np_random are not reference equal
|
||||||
|
space_a_batched = batch_differing_spaces([space_a for _ in range(n)])
|
||||||
|
space_b_batched = batch_differing_spaces([space_b for _ in range(n)])
|
||||||
|
assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
|
||||||
|
assert space_a_batched.np_random is not space_b_batched.np_random
|
||||||
|
# Create that the batched space is not reference equal to the origin spaces
|
||||||
|
assert space_a.np_random is not space_a_batched.np_random
|
||||||
|
|
||||||
|
# Check that batched space a and b random number generator are not effected by the original space
|
||||||
|
space_a.sample()
|
||||||
|
space_a_batched_sample = space_a_batched.sample()
|
||||||
|
space_b_batched_sample = space_b_batched.sample()
|
||||||
|
for a_sample, b_sample in zip(
|
||||||
|
iterate(space_a_batched, space_a_batched_sample),
|
||||||
|
iterate(space_b_batched, space_b_batched_sample),
|
||||||
|
):
|
||||||
|
assert data_equivalence(a_sample, b_sample)
|
||||||
|
Reference in New Issue
Block a user