mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
Add (A)syncVectorEnv support for sub-envs with different observation spaces (#1140)
Co-authored-by: Reggie <72816837+reginald-mclean@users.noreply.github.com> Co-authored-by: Reggie McLean <reginald.mclean@ryerson.ca>
This commit is contained in:
@@ -573,3 +573,102 @@ def _flatten_space_oneof(space: OneOf) -> Box:
|
||||
|
||||
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
|
||||
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool:
|
||||
"""Returns if two spaces share a common dtype and shape (plus any critical variables).
|
||||
|
||||
This function is primarily used to check for compatibility of different spaces in a vector environment.
|
||||
|
||||
Args:
|
||||
space_1: A Gymnasium space
|
||||
space_2: A Gymnasium space
|
||||
|
||||
Returns:
|
||||
If the two spaces share a common dtype and shape (plus any critical variables).
|
||||
"""
|
||||
if isinstance(space_1, Space) and isinstance(space_2, Space):
|
||||
raise NotImplementedError(
|
||||
"`check_dtype_shape_equivalence` doesn't support Generic Gymnasium Spaces, "
|
||||
)
|
||||
else:
|
||||
raise TypeError()
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Box)
|
||||
@is_space_dtype_shape_equiv.register(Discrete)
|
||||
@is_space_dtype_shape_equiv.register(MultiDiscrete)
|
||||
@is_space_dtype_shape_equiv.register(MultiBinary)
|
||||
def _is_space_fundamental_dtype_shape_equiv(space_1, space_2):
|
||||
return (
|
||||
# this check is necessary as singledispatch only checks the first variable and there are many options
|
||||
type(space_1) is type(space_2)
|
||||
and space_1.shape == space_2.shape
|
||||
and space_1.dtype == space_2.dtype
|
||||
)
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Text)
|
||||
def _is_space_text_dtype_shape_equiv(space_1: Text, space_2):
|
||||
return (
|
||||
isinstance(space_2, Text)
|
||||
and space_1.max_length == space_2.max_length
|
||||
and space_1.character_set == space_2.character_set
|
||||
)
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Dict)
|
||||
def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2):
|
||||
return (
|
||||
isinstance(space_2, Dict)
|
||||
and space_1.keys() == space_2.keys()
|
||||
and all(
|
||||
is_space_dtype_shape_equiv(space_1[key], space_2[key])
|
||||
for key in space_1.keys()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Tuple)
|
||||
def _is_space_tuple_dtype_shape_equiv(space_1, space_2):
|
||||
return isinstance(space_2, Tuple) and all(
|
||||
is_space_dtype_shape_equiv(space_1[i], space_2[i]) for i in range(len(space_1))
|
||||
)
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Graph)
|
||||
def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2):
|
||||
return (
|
||||
isinstance(space_2, Graph)
|
||||
and is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space)
|
||||
and (
|
||||
(space_1.edge_space is None and space_2.edge_space is None)
|
||||
or (
|
||||
space_1.edge_space is not None
|
||||
and space_2.edge_space is not None
|
||||
and is_space_dtype_shape_equiv(space_1.edge_space, space_2.edge_space)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(OneOf)
|
||||
def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2):
|
||||
return (
|
||||
isinstance(space_2, OneOf)
|
||||
and len(space_1) == len(space_2)
|
||||
and all(
|
||||
is_space_dtype_shape_equiv(space_1[i], space_2[i])
|
||||
for i in range(len(space_1))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@is_space_dtype_shape_equiv.register(Sequence)
|
||||
def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2):
|
||||
return (
|
||||
isinstance(space_2, Sequence)
|
||||
and space_1.stack is space_2.stack
|
||||
and is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space)
|
||||
)
|
||||
|
Reference in New Issue
Block a user