Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)

Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com>
Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
Mark Towers
2024-08-29 16:52:43 +01:00
committed by GitHub
parent ad8734d89b
commit 64fac8e80b
10 changed files with 584 additions and 55 deletions

View File

@@ -1,6 +1,7 @@
"""Space-based utility functions for vector environments.
- ``batch_space``: Create a (batched) space, containing multiple copies of a single space.
- ``batch_space``: Create a (batched) space containing multiple copies of a single space.
- ``batch_differing_spaces``: Create a (batched) space containing copies of different compatible spaces (share a common dtype and shape)
- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
- ``iterate``: Iterate over the elements of a (batched) space and items.
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
@@ -32,7 +33,13 @@ from gymnasium.spaces import (
from gymnasium.spaces.space import T_cov
__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"]
# Names exported by ``from <this module> import *``; ``batch_differing_spaces``
# is listed together with the other space utilities.
__all__ = [
"batch_space",
"batch_differing_spaces",
"iterate",
"concatenate",
"create_empty_array",
]
@singledispatch
@@ -139,6 +146,116 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
return batched_space
@singledispatch
def batch_differing_spaces(spaces: list[Space]):
    """Batch a list of same-typed spaces whose parameters may differ.

    The type of the first space selects the implementation registered on this
    function (via ``batch_differing_spaces.register``); that implementation is
    then called with the full list.

    Args:
        spaces: Non-empty list of spaces. Every element must be an instance of
            the type of the first element.

    Returns:
        Whatever the implementation registered for ``type(spaces[0])`` returns.

    Raises:
        AssertionError: If ``spaces`` is empty, if any element is not an
            instance of the first element's type, or if the first element's
            type has no registered implementation.
    """
    assert len(spaces) > 0, "Expects a non-empty list of spaces"
    assert all(
        isinstance(space, type(spaces[0])) for space in spaces
    ), f"Expects all spaces to be the same type, actual types: {[type(space) for space in spaces]}"
    # The element type must be registered explicitly: a type with no registered
    # implementation would make `dispatch` fall back to this default function,
    # which would then call itself again.
    assert (
        type(spaces[0]) in batch_differing_spaces.registry
    ), f"Requires the Space type to have a registered `batch_differing_space`, current list: {batch_differing_spaces.registry}"

    # `singledispatch` keys on the type of `spaces` itself (a list), so the
    # dispatch on the element type is done by hand here.
    return batch_differing_spaces.dispatch(type(spaces[0]))(spaces)
@batch_differing_spaces.register(Box)
def _batch_differing_spaces_box(spaces: list[Box]):
    """Stack the bounds of several ``Box`` spaces into one batched ``Box``.

    Every space must share the dtype of the first space, and the shapes of
    ``low`` and ``high`` must match across spaces; the bound values themselves
    may differ. The result is seeded with a deep copy of the first space's
    ``np_random``.
    """
    first = spaces[0]

    assert all(
        space.dtype == first.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        space.low.shape == first.low.shape for space in spaces
    ), f"Expected all Box.low shape to be equal, actually {[space.low.shape for space in spaces]}"
    assert all(
        space.high.shape == first.high.shape for space in spaces
    ), f"Expected all Box.high shape to be equal, actually {[space.high.shape for space in spaces]}"

    # One row of bounds per input space, in input order.
    stacked_low = np.array([space.low for space in spaces])
    stacked_high = np.array([space.high for space in spaces])

    return Box(
        low=stacked_low,
        high=stacked_high,
        dtype=first.dtype,
        seed=deepcopy(first.np_random),
    )
@batch_differing_spaces.register(Discrete)
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
    """Combine several ``Discrete`` spaces into one ``MultiDiscrete``.

    Entry ``i`` of the result takes its ``nvec`` value from ``spaces[i].n`` and
    its ``start`` value from ``spaces[i].start``. The result is seeded with a
    deep copy of the first space's ``np_random``.
    """
    sizes = np.array([space.n for space in spaces])
    offsets = np.array([space.start for space in spaces])
    rng = deepcopy(spaces[0].np_random)
    return MultiDiscrete(nvec=sizes, start=offsets, seed=rng)
@batch_differing_spaces.register(MultiDiscrete)
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
    """Combine several ``MultiDiscrete`` spaces into one batched ``Box``.

    Every space must share the dtype of the first space, and the shapes of
    ``nvec`` and ``start`` must match across spaces. Row ``i`` of the resulting
    ``Box`` has ``low = start`` and ``high = start + nvec - 1`` taken from
    ``spaces[i]``. The result is seeded with a deep copy of the first space's
    ``np_random``.
    """
    first = spaces[0]

    assert all(
        space.dtype == first.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        space.nvec.shape == first.nvec.shape for space in spaces
    ), f"Expects all MultiDiscrete.nvec shape, actually {[space.nvec.shape for space in spaces]}"
    assert all(
        space.start.shape == first.start.shape for space in spaces
    ), f"Expects all MultiDiscrete.start shape, actually {[space.start.shape for space in spaces]}"

    lower = np.array([space.start for space in spaces])
    # `start + nvec` is one past the largest value, hence the `- 1`.
    upper = np.array([space.start + space.nvec for space in spaces]) - 1

    return Box(
        low=lower,
        high=upper,
        dtype=first.dtype,
        seed=deepcopy(first.np_random),
    )
@batch_differing_spaces.register(MultiBinary)
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
    """Combine several ``MultiBinary`` spaces into one batched ``Box``.

    All spaces must have the same shape. The result is a ``Box`` with bounds
    0 and 1 and shape ``(len(spaces),) + spaces[0].shape``; it uses the first
    space's dtype and a deep copy of the first space's ``np_random`` as seed.
    """
    first = spaces[0]
    assert all(space.shape == first.shape for space in spaces)

    # A leading batch axis with one entry per input space.
    batched_shape = (len(spaces),) + first.shape
    return Box(
        low=0,
        high=1,
        shape=batched_shape,
        dtype=first.dtype,
        seed=deepcopy(first.np_random),
    )
@batch_differing_spaces.register(Tuple)
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
    """Batch several ``Tuple`` spaces position by position.

    The subspaces found at position ``i`` of every input space are batched
    together with ``batch_differing_spaces``; the results form a new ``Tuple``
    seeded with a deep copy of the first space's ``np_random``.

    Raises:
        AssertionError: If the input spaces do not all contain the same number
            of subspaces.
    """
    # `zip` stops at the shortest input, so without this check the extra
    # subspaces of longer Tuple spaces would be dropped silently.
    assert all(
        len(spaces[0].spaces) == len(space.spaces) for space in spaces
    ), f"Expected all Tuple spaces to have the same number of subspaces, actually {[len(space.spaces) for space in spaces]}"

    return Tuple(
        tuple(
            batch_differing_spaces(subspaces)
            for subspaces in zip(*[space.spaces for space in spaces])
        ),
        seed=deepcopy(spaces[0].np_random),
    )
@batch_differing_spaces.register(Dict)
def _batch_differing_spaces_dict(spaces: list[Dict]):
    """Batch several ``Dict`` spaces key by key.

    All spaces must expose the same keys. For each key of the first space, the
    subspaces stored under that key are batched with
    ``batch_differing_spaces``. The resulting ``Dict`` is seeded with a deep
    copy of the first space's ``np_random``.
    """
    first = spaces[0]
    assert all(space.keys() == first.keys() for space in spaces)

    batched_subspaces = {}
    for key in first.keys():
        batched_subspaces[key] = batch_differing_spaces(
            [space[key] for space in spaces]
        )

    return Dict(batched_subspaces, seed=deepcopy(first.np_random))
@batch_differing_spaces.register(Graph)
@batch_differing_spaces.register(Text)
@batch_differing_spaces.register(Sequence)
@batch_differing_spaces.register(OneOf)
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
    """Wrap ``Graph``/``Text``/``Sequence``/``OneOf`` spaces in a ``Tuple``.

    The inputs are not merged into one space: the result is a ``Tuple`` holding
    a deep copy of each input space in order, seeded with a deep copy of the
    first space's ``np_random``.
    """
    copied_spaces = [deepcopy(space) for space in spaces]
    rng = deepcopy(spaces[0].np_random)
    return Tuple(copied_spaces, seed=rng)
@singledispatch
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
"""Iterate over the elements of a (batched) space.