Fix VectorizeActionTransform for changing spaces (#1170)

This commit is contained in:
Mark Towers
2024-09-20 13:44:07 +01:00
committed by GitHub
parent 973f924d60
commit a6976e42d4
6 changed files with 95 additions and 32 deletions

View File

@@ -9,6 +9,7 @@
from __future__ import annotations
import typing
from copy import deepcopy
from functools import singledispatch
from typing import Any, Iterable, Iterator
@@ -44,17 +45,17 @@ __all__ = [
@singledispatch
def batch_space(space: Space[Any], n: int = 1) -> Space[Any]:
"""Create a (batched) space, containing multiple copies of a single space.
"""Batch spaces of size `n` optimized for neural networks.
Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
n: Number of environments in the vectorized environment.
space: Space (e.g. the observation space for a single environment in the vectorized environment).
n: Number of spaces to batch by (e.g. the number of environments in a vectorized environment).
Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Batched space of size `n`.
Raises:
ValueError: Cannot batch space does not have a registered function.
ValueError: Cannot batch spaces that does not have a registered function.
Example:
@@ -147,8 +148,21 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
@singledispatch
def batch_differing_spaces(spaces: list[Space]):
"""Batch a Sequence of spaces that allows the subspaces to contain minor differences."""
def batch_differing_spaces(spaces: typing.Sequence[Space]) -> Space:
"""Batch a Sequence of spaces where subspaces to contain minor differences.
Args:
spaces: A sequence of Spaces with minor differences (the same space type but different parameters).
Returns:
A batched space
Example:
>>> from gymnasium.spaces import Discrete
>>> spaces = [Discrete(3), Discrete(5), Discrete(4), Discrete(8)]
>>> batch_differing_spaces(spaces)
MultiDiscrete([3 5 4 8])
"""
assert len(spaces) > 0, "Expects a non-empty list of spaces"
assert all(
isinstance(space, type(spaces[0])) for space in spaces
@@ -257,19 +271,12 @@ def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
@singledispatch
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
def iterate(space: Space[T_cov], items: T_cov) -> Iterator:
"""Iterate over the elements of a (batched) space.
Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not an instance of :class:`gymnasium.Space`
space: (batched) space (e.g. `action_space` or `observation_space` from vectorized environment).
items: Batched samples to be iterated over (e.g. sample from the space).
Example:
>>> from gymnasium.spaces import Box, Dict
@@ -353,15 +360,15 @@ def concatenate(
"""Concatenate multiple samples from space into a single object.
Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
space: Space of each item (e.g. `single_action_space` from vectorized environment)
items: Samples to be concatenated (e.g. all sample should be an element of the `space`).
out: The output object (e.g. generated from `create_empty_array`)
Returns:
The output object. This object is a (possibly nested) numpy array.
The output object, can be the same object `out`.
Raises:
ValueError: Space
ValueError: Space is not a valid :class:`gymnasium.Space` instance
Example:
>>> from gymnasium.spaces import Box
@@ -423,7 +430,7 @@ def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any,
def create_empty_array(
space: Space, n: int = 1, fn: callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
"""Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
In most cases, the array will be contained within the batched space, however, this is not guaranteed.