mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-31 18:12:53 +00:00
Fix VectorizeActionTransform
for changing spaces (#1170)
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Any, Iterable, Iterator
|
||||
@@ -44,17 +45,17 @@ __all__ = [
|
||||
|
||||
@singledispatch
|
||||
def batch_space(space: Space[Any], n: int = 1) -> Space[Any]:
|
||||
"""Create a (batched) space, containing multiple copies of a single space.
|
||||
"""Batch spaces of size `n` optimized for neural networks.
|
||||
|
||||
Args:
|
||||
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment.
|
||||
space: Space (e.g. the observation space for a single environment in the vectorized environment).
|
||||
n: Number of spaces to batch by (e.g. the number of environments in a vectorized environment).
|
||||
|
||||
Returns:
|
||||
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
|
||||
Batched space of size `n`.
|
||||
|
||||
Raises:
|
||||
ValueError: Cannot batch space does not have a registered function.
|
||||
ValueError: Cannot batch spaces that does not have a registered function.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -147,8 +148,21 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def batch_differing_spaces(spaces: list[Space]):
|
||||
"""Batch a Sequence of spaces that allows the subspaces to contain minor differences."""
|
||||
def batch_differing_spaces(spaces: typing.Sequence[Space]) -> Space:
|
||||
"""Batch a Sequence of spaces where subspaces to contain minor differences.
|
||||
|
||||
Args:
|
||||
spaces: A sequence of Spaces with minor differences (the same space type but different parameters).
|
||||
|
||||
Returns:
|
||||
A batched space
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Discrete
|
||||
>>> spaces = [Discrete(3), Discrete(5), Discrete(4), Discrete(8)]
|
||||
>>> batch_differing_spaces(spaces)
|
||||
MultiDiscrete([3 5 4 8])
|
||||
"""
|
||||
assert len(spaces) > 0, "Expects a non-empty list of spaces"
|
||||
assert all(
|
||||
isinstance(space, type(spaces[0])) for space in spaces
|
||||
@@ -257,19 +271,12 @@ def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
|
||||
|
||||
|
||||
@singledispatch
|
||||
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
|
||||
def iterate(space: Space[T_cov], items: T_cov) -> Iterator:
|
||||
"""Iterate over the elements of a (batched) space.
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
items: Samples to be concatenated.
|
||||
out: The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Returns:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not an instance of :class:`gymnasium.Space`
|
||||
space: (batched) space (e.g. `action_space` or `observation_space` from vectorized environment).
|
||||
items: Batched samples to be iterated over (e.g. sample from the space).
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
@@ -353,15 +360,15 @@ def concatenate(
|
||||
"""Concatenate multiple samples from space into a single object.
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
items: Samples to be concatenated.
|
||||
out: The output object. This object is a (possibly nested) numpy array.
|
||||
space: Space of each item (e.g. `single_action_space` from vectorized environment)
|
||||
items: Samples to be concatenated (e.g. all sample should be an element of the `space`).
|
||||
out: The output object (e.g. generated from `create_empty_array`)
|
||||
|
||||
Returns:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
The output object, can be the same object `out`.
|
||||
|
||||
Raises:
|
||||
ValueError: Space
|
||||
ValueError: Space is not a valid :class:`gymnasium.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box
|
||||
@@ -423,7 +430,7 @@ def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any,
|
||||
def create_empty_array(
|
||||
space: Space, n: int = 1, fn: callable = np.zeros
|
||||
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
|
||||
"""Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
|
||||
"""Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
|
||||
|
||||
In most cases, the array will be contained within the batched space, however, this is not guaranteed.
|
||||
|
||||
|
Reference in New Issue
Block a user