mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-31 18:12:53 +00:00
Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)
Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com> Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Space-based utility functions for vector environments.
|
||||
|
||||
- ``batch_space``: Create a (batched) space, containing multiple copies of a single space.
|
||||
- ``batch_space``: Create a (batched) space containing multiple copies of a single space.
|
||||
- ``batch_differing_spaces``: Create a (batched) space containing copies of different compatible spaces (share a common dtype and shape)
|
||||
- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
|
||||
- ``iterate``: Iterate over the elements of a (batched) space and items.
|
||||
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
|
||||
@@ -32,7 +33,13 @@ from gymnasium.spaces import (
|
||||
from gymnasium.spaces.space import T_cov
|
||||
|
||||
|
||||
__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"]
|
||||
__all__ = [
|
||||
"batch_space",
|
||||
"batch_differing_spaces",
|
||||
"iterate",
|
||||
"concatenate",
|
||||
"create_empty_array",
|
||||
]
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -139,6 +146,116 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
|
||||
return batched_space
|
||||
|
||||
|
||||
@singledispatch
def batch_differing_spaces(spaces: list[Space]):
    """Batch a list of same-typed spaces whose sub-spaces may contain minor differences.

    The arguments are validated here and the call is then forwarded to the
    implementation registered for the type of the first space.

    Args:
        spaces: A non-empty list of spaces, each an instance of ``type(spaces[0])``.

    Returns:
        Whatever the implementation registered for ``type(spaces[0])`` returns.

    Raises:
        AssertionError: If ``spaces`` is empty, if the spaces are not all instances
            of the type of the first space, or if no implementation has been
            registered for that exact type.
    """
    assert len(spaces) > 0, "Expects a non-empty list of spaces"
    assert all(
        isinstance(space, type(spaces[0])) for space in spaces
    ), f"Expects all spaces to be the same type, actual types: {[type(space) for space in spaces]}"
    assert (
        type(spaces[0]) in batch_differing_spaces.registry
    ), f"Requires the Space type to have a registered `batch_differing_space`, current list: {batch_differing_spaces.registry}"

    # ``singledispatch`` picks an implementation from the type of the first
    # argument, which here is the list itself, so the implementation is looked
    # up by hand from the type of the first element instead.
    return batch_differing_spaces.dispatch(type(spaces[0]))(spaces)
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Box)
def _batch_differing_spaces_box(spaces: list[Box]):
    """Batch ``Box`` spaces that share a dtype and bound shapes but may differ in bounds.

    The ``low`` and ``high`` arrays of every space are collected into new arrays
    with one leading entry per space. The result takes its dtype from the first
    space and is seeded with a deep copy of the first space's ``np_random``.

    Raises:
        AssertionError: If the dtypes, ``low`` shapes or ``high`` shapes differ.
    """
    first = spaces[0]

    assert all(
        first.dtype == box.dtype for box in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        first.low.shape == box.low.shape for box in spaces
    ), f"Expected all Box.low shape to be equal, actually {[space.low.shape for space in spaces]}"
    assert all(
        first.high.shape == box.high.shape for box in spaces
    ), f"Expected all Box.high shape to be equal, actually {[space.high.shape for space in spaces]}"

    batched_low = np.array([box.low for box in spaces])
    batched_high = np.array([box.high for box in spaces])

    return Box(
        low=batched_low,
        high=batched_high,
        dtype=first.dtype,
        seed=deepcopy(first.np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Discrete)
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
    """Batch ``Discrete`` spaces into one ``MultiDiscrete`` with an entry per space.

    Entry ``i`` of the result uses ``spaces[i].n`` as its ``nvec`` value and
    ``spaces[i].start`` as its ``start`` value. The result is seeded with a deep
    copy of the first space's ``np_random``.
    """
    sizes = np.array([discrete.n for discrete in spaces])
    offsets = np.array([discrete.start for discrete in spaces])
    rng = deepcopy(spaces[0].np_random)

    return MultiDiscrete(nvec=sizes, start=offsets, seed=rng)
|
||||
|
||||
|
||||
@batch_differing_spaces.register(MultiDiscrete)
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
    """Batch ``MultiDiscrete`` spaces that share a dtype and shape into a single ``Box``.

    Each space contributes one leading entry to the bounds of the result: its
    ``start`` array as the lower bound and ``start + nvec - 1`` as the upper
    bound. The result takes its dtype from the first space and is seeded with a
    deep copy of the first space's ``np_random``.

    Raises:
        AssertionError: If the dtypes, ``nvec`` shapes or ``start`` shapes differ.
    """
    assert all(
        spaces[0].dtype == space.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        spaces[0].nvec.shape == space.nvec.shape for space in spaces
    ), f"Expected all MultiDiscrete.nvec shape to be equal, actually {[space.nvec.shape for space in spaces]}"
    assert all(
        spaces[0].start.shape == space.start.shape for space in spaces
    ), f"Expected all MultiDiscrete.start shape to be equal, actually {[space.start.shape for space in spaces]}"

    return Box(
        low=np.array([space.start for space in spaces]),
        # The largest value of each entry is ``start + nvec - 1``.
        high=np.array([space.start + space.nvec for space in spaces]) - 1,
        dtype=spaces[0].dtype,
        seed=deepcopy(spaces[0].np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(MultiBinary)
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
    """Batch ``MultiBinary`` spaces of identical shape into a single ``Box``.

    The result has bounds ``0`` and ``1`` and shape
    ``(len(spaces),) + spaces[0].shape``. It takes its dtype from the first
    space and is seeded with a deep copy of the first space's ``np_random``.

    Raises:
        AssertionError: If the spaces do not all have the same shape.
    """
    assert all(
        spaces[0].shape == space.shape for space in spaces
    ), f"Expected all MultiBinary shapes to be equal, actually {[space.shape for space in spaces]}"

    return Box(
        low=0,
        high=1,
        shape=(len(spaces),) + spaces[0].shape,
        dtype=spaces[0].dtype,
        seed=deepcopy(spaces[0].np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Tuple)
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
    """Batch ``Tuple`` spaces position by position.

    The subspaces found at the same position of every input are batched together
    with ``batch_differing_spaces``. The results form a new ``Tuple`` seeded with
    a deep copy of the first space's ``np_random``.

    Raises:
        AssertionError: If the spaces do not all contain the same number of
            subspaces, or if batching the subspaces at some position fails.
    """
    # ``zip`` stops at the shortest input, so without this check any extra
    # subspaces of longer tuples would be dropped silently.
    assert all(
        len(spaces[0].spaces) == len(space.spaces) for space in spaces
    ), f"Expected all Tuple spaces to have the same number of subspaces, actually {[len(space.spaces) for space in spaces]}"

    return Tuple(
        tuple(
            batch_differing_spaces(subspaces)
            for subspaces in zip(*[space.spaces for space in spaces])
        ),
        seed=deepcopy(spaces[0].np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Dict)
def _batch_differing_spaces_dict(spaces: list[Dict]):
    """Batch ``Dict`` spaces key by key.

    For every key of the first space, the subspaces stored under that key in
    each input are batched together with ``batch_differing_spaces``. The results
    form a new ``Dict`` seeded with a deep copy of the first space's
    ``np_random``.

    Raises:
        AssertionError: If the spaces do not all have the same keys, or if
            batching the subspaces of some key fails.
    """
    assert all(
        spaces[0].keys() == space.keys() for space in spaces
    ), f"Expected all Dict spaces to have the same keys, actually {[list(space.keys()) for space in spaces]}"

    return Dict(
        {
            key: batch_differing_spaces([space[key] for space in spaces])
            for key in spaces[0].keys()
        },
        seed=deepcopy(spaces[0].np_random),
    )
|
||||
|
||||
|
||||
@batch_differing_spaces.register(Graph)
@batch_differing_spaces.register(Text)
@batch_differing_spaces.register(Sequence)
@batch_differing_spaces.register(OneOf)
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
    """Batch ``Graph``, ``Text``, ``Sequence`` or ``OneOf`` spaces as a ``Tuple`` of copies.

    Every input space is deep-copied and placed, in order, in a ``Tuple`` that is
    seeded with a deep copy of the first space's ``np_random``.
    """
    copied_spaces = [deepcopy(space) for space in spaces]
    rng = deepcopy(spaces[0].np_random)

    return Tuple(copied_spaces, seed=rng)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
|
||||
"""Iterate over the elements of a (batched) space.
|
||||
|
Reference in New Issue
Block a user