mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 01:50:19 +00:00
Co-authored-by: StringTheory <mark.m.towers@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from typing import Any, TypeVar, Union, cast
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
@@ -236,8 +237,21 @@ def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
|
||||
|
||||
|
||||
@flatten.register(Sequence)
|
||||
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
return tuple(flatten(space.feature_space, item) for item in x)
|
||||
def _flatten_sequence(
|
||||
space: Sequence, x: tuple[Any, ...] | Any
|
||||
) -> tuple[Any, ...] | Any:
|
||||
if space.stack:
|
||||
samples_iters = gym.vector.utils.iterate(space.batched_feature_space, x)
|
||||
flattened_samples = [
|
||||
flatten(space.feature_space, sample) for sample in samples_iters
|
||||
]
|
||||
flattened_space = flatten_space(space.feature_space)
|
||||
out = gym.vector.utils.create_empty_array(
|
||||
flattened_space, n=len(flattened_samples)
|
||||
)
|
||||
return gym.vector.utils.concatenate(flattened_space, flattened_samples, out)
|
||||
else:
|
||||
return tuple(flatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -363,8 +377,21 @@ def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
|
||||
|
||||
|
||||
@unflatten.register(Sequence)
|
||||
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
return tuple(unflatten(space.feature_space, item) for item in x)
|
||||
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] | Any:
|
||||
if space.stack:
|
||||
flattened_space = flatten_space(space.feature_space)
|
||||
flatten_iters = gym.vector.utils.iterate(flattened_space, x)
|
||||
unflattened_samples = [
|
||||
unflatten(space.feature_space, sample) for sample in flatten_iters
|
||||
]
|
||||
out = gym.vector.utils.create_empty_array(
|
||||
space.feature_space, len(unflattened_samples)
|
||||
)
|
||||
return gym.vector.utils.concatenate(
|
||||
space.feature_space, unflattened_samples, out
|
||||
)
|
||||
else:
|
||||
return tuple(unflatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -493,4 +520,4 @@ def _flatten_space_text(space: Text) -> Box:
|
||||
|
||||
@flatten_space.register(Sequence)
|
||||
def _flatten_space_sequence(space: Sequence) -> Sequence:
|
||||
return Sequence(flatten_space(space.feature_space))
|
||||
return Sequence(flatten_space(space.feature_space), stack=space.stack)
|
||||
|
Reference in New Issue
Block a user