Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)

Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com>
Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
Mark Towers
2024-08-29 16:52:43 +01:00
committed by GitHub
parent ad8734d89b
commit 64fac8e80b
10 changed files with 584 additions and 55 deletions

View File

@@ -573,3 +573,102 @@ def _flatten_space_oneof(space: OneOf) -> Box:
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
@singledispatch
def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool:
"""Returns if two spaces share a common dtype and shape (plus any critical variables).
This function is primarily used to check for compatibility of different spaces in a vector environment.
Args:
space_1: A Gymnasium space
space_2: A Gymnasium space
Returns:
If the two spaces share a common dtype and shape (plus any critical variables).
"""
if isinstance(space_1, Space) and isinstance(space_2, Space):
raise NotImplementedError(
"`check_dtype_shape_equivalence` doesn't support Generic Gymnasium Spaces, "
)
else:
raise TypeError()
@is_space_dtype_shape_equiv.register(Box)
@is_space_dtype_shape_equiv.register(Discrete)
@is_space_dtype_shape_equiv.register(MultiDiscrete)
@is_space_dtype_shape_equiv.register(MultiBinary)
def _is_space_fundamental_dtype_shape_equiv(space_1, space_2):
return (
# this check is necessary as singledispatch only checks the first variable and there are many options
type(space_1) is type(space_2)
and space_1.shape == space_2.shape
and space_1.dtype == space_2.dtype
)
@is_space_dtype_shape_equiv.register(Text)
def _is_space_text_dtype_shape_equiv(space_1: Text, space_2):
return (
isinstance(space_2, Text)
and space_1.max_length == space_2.max_length
and space_1.character_set == space_2.character_set
)
@is_space_dtype_shape_equiv.register(Dict)
def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2):
return (
isinstance(space_2, Dict)
and space_1.keys() == space_2.keys()
and all(
is_space_dtype_shape_equiv(space_1[key], space_2[key])
for key in space_1.keys()
)
)
@is_space_dtype_shape_equiv.register(Tuple)
def _is_space_tuple_dtype_shape_equiv(space_1, space_2):
return isinstance(space_2, Tuple) and all(
is_space_dtype_shape_equiv(space_1[i], space_2[i]) for i in range(len(space_1))
)
@is_space_dtype_shape_equiv.register(Graph)
def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2):
return (
isinstance(space_2, Graph)
and is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space)
and (
(space_1.edge_space is None and space_2.edge_space is None)
or (
space_1.edge_space is not None
and space_2.edge_space is not None
and is_space_dtype_shape_equiv(space_1.edge_space, space_2.edge_space)
)
)
)
@is_space_dtype_shape_equiv.register(OneOf)
def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2):
return (
isinstance(space_2, OneOf)
and len(space_1) == len(space_2)
and all(
is_space_dtype_shape_equiv(space_1[i], space_2[i])
for i in range(len(space_1))
)
)
@is_space_dtype_shape_equiv.register(Sequence)
def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2):
return (
isinstance(space_2, Sequence)
and space_1.stack is space_2.stack
and is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space)
)