Add Stack to Sequence space - #217 (#284)

Co-authored-by: StringTheory <mark.m.towers@gmail.com>
This commit is contained in:
Kevlyn Kadamala
2023-01-30 18:27:32 +05:30
committed by GitHub
parent 33b9884ab6
commit 8031da71e6
8 changed files with 111 additions and 32 deletions

View File

@@ -14,6 +14,7 @@ from typing import Any, TypeVar, Union, cast
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium.spaces import (
Box,
Dict,
@@ -236,8 +237,21 @@ def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
@flatten.register(Sequence)
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(flatten(space.feature_space, item) for item in x)
def _flatten_sequence(
space: Sequence, x: tuple[Any, ...] | Any
) -> tuple[Any, ...] | Any:
if space.stack:
samples_iters = gym.vector.utils.iterate(space.batched_feature_space, x)
flattened_samples = [
flatten(space.feature_space, sample) for sample in samples_iters
]
flattened_space = flatten_space(space.feature_space)
out = gym.vector.utils.create_empty_array(
flattened_space, n=len(flattened_samples)
)
return gym.vector.utils.concatenate(flattened_space, flattened_samples, out)
else:
return tuple(flatten(space.feature_space, item) for item in x)
@singledispatch
@@ -363,8 +377,21 @@ def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(unflatten(space.feature_space, item) for item in x)
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] | Any:
if space.stack:
flattened_space = flatten_space(space.feature_space)
flatten_iters = gym.vector.utils.iterate(flattened_space, x)
unflattened_samples = [
unflatten(space.feature_space, sample) for sample in flatten_iters
]
out = gym.vector.utils.create_empty_array(
space.feature_space, len(unflattened_samples)
)
return gym.vector.utils.concatenate(
space.feature_space, unflattened_samples, out
)
else:
return tuple(unflatten(space.feature_space, item) for item in x)
@singledispatch
@@ -493,4 +520,4 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space))
return Sequence(flatten_space(space.feature_space), stack=space.stack)