Add Stack to Sequence space - #217 (#284)

Co-authored-by: StringTheory <mark.m.towers@gmail.com>
This commit is contained in:
Kevlyn Kadamala
2023-01-30 18:27:32 +05:30
committed by GitHub
parent 33b9884ab6
commit 8031da71e6
8 changed files with 111 additions and 32 deletions

View File

@@ -238,9 +238,9 @@ class Box(Space[NDArray[Any]]):
and np.all(x <= self.high)
)
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[NDArray[Any]]:
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]:
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()
return [sample.tolist() for sample in sample_n]
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
"""Convert a JSONable data type to a batch of samples from this space."""

View File

@@ -153,7 +153,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.integer[Any]]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return np.array(sample_n)
return [np.array(sample) for sample in sample_n]
def __repr__(self):
"""Gives a string representation of this space."""

View File

@@ -1,17 +1,17 @@
"""Implementation of a space that represents finite-length sequences."""
from __future__ import annotations
import collections.abc
import typing
from typing import Any
from typing import Any, Union
import numpy as np
import numpy.typing as npt
import gymnasium as gym
from gymnasium.spaces.space import Space
class Sequence(Space[typing.Tuple[Any, ...]]):
class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
r"""This space represent sets of finite-length sequences.
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
@@ -31,17 +31,24 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
self,
space: Space[Any],
seed: int | np.random.Generator | None = None,
stack: bool = False,
):
"""Constructor of the :class:`Sequence` space.
Args:
space: Elements in the sequences this space represent must belong to this space.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
stack: If `True` then the resulting samples would be stacked.
"""
assert isinstance(
space, Space
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
self.feature_space = space
self.stack = stack
if self.stack:
self.batched_feature_space: Space = gym.vector.utils.batch_space(
self.feature_space, 1
)
# None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed)
@@ -114,32 +121,59 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
# The choice of 0.25 is arbitrary
length = self.np_random.geometric(0.25)
return tuple(
# Generate sample values from feature_space.
sampled_values = tuple(
self.feature_space.sample(mask=feature_mask) for _ in range(length)
)
if self.stack:
# Concatenate values if stacked.
out = gym.vector.utils.create_empty_array(
self.feature_space, len(sampled_values)
)
return gym.vector.utils.concatenate(self.feature_space, sampled_values, out)
return sampled_values
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
# by definition, any sequence is an iterable
return isinstance(x, collections.abc.Iterable) and all(
if self.stack:
return all(
item in self.feature_space
for item in gym.vector.utils.iterate(self.batched_feature_space, x)
)
else:
return isinstance(x, tuple) and all(
self.feature_space.contains(item) for item in x
)
def __repr__(self) -> str:
"""Gives a string representation of this space."""
return f"Sequence({self.feature_space})"
return f"Sequence({self.feature_space}, stack={self.stack})"
def to_jsonable(
self, sample_n: typing.Sequence[tuple[Any, ...]]
self, sample_n: typing.Sequence[tuple[Any, ...] | Any]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as dict-repr of vectors
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]
if self.stack:
return self.batched_feature_space.to_jsonable(sample_n)
else:
return [self.feature_space.to_jsonable(sample) for sample in sample_n]
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...] | Any]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]
if self.stack:
return self.batched_feature_space.from_jsonable(sample_n)
else:
return [
tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n
]
def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return isinstance(other, Sequence) and self.feature_space == other.feature_space
return (
isinstance(other, Sequence)
and self.feature_space == other.feature_space
and self.stack == other.stack
)

View File

@@ -14,6 +14,7 @@ from typing import Any, TypeVar, Union, cast
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium.spaces import (
Box,
Dict,
@@ -236,7 +237,20 @@ def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
@flatten.register(Sequence)
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
def _flatten_sequence(
space: Sequence, x: tuple[Any, ...] | Any
) -> tuple[Any, ...] | Any:
if space.stack:
samples_iters = gym.vector.utils.iterate(space.batched_feature_space, x)
flattened_samples = [
flatten(space.feature_space, sample) for sample in samples_iters
]
flattened_space = flatten_space(space.feature_space)
out = gym.vector.utils.create_empty_array(
flattened_space, n=len(flattened_samples)
)
return gym.vector.utils.concatenate(flattened_space, flattened_samples, out)
else:
return tuple(flatten(space.feature_space, item) for item in x)
@@ -363,7 +377,20 @@ def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] | Any:
if space.stack:
flattened_space = flatten_space(space.feature_space)
flatten_iters = gym.vector.utils.iterate(flattened_space, x)
unflattened_samples = [
unflatten(space.feature_space, sample) for sample in flatten_iters
]
out = gym.vector.utils.create_empty_array(
space.feature_space, len(unflattened_samples)
)
return gym.vector.utils.concatenate(
space.feature_space, unflattened_samples, out
)
else:
return tuple(unflatten(space.feature_space, item) for item in x)
@@ -493,4 +520,4 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space))
return Sequence(flatten_space(space.feature_space), stack=space.stack)

View File

@@ -3,12 +3,20 @@ from typing import Callable, Iterable, List, Optional, Union
import gymnasium as gym
from gymnasium.core import Env
from gymnasium.vector import utils
from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
__all__ = [
"AsyncVectorEnv",
"SyncVectorEnv",
"VectorEnv",
"VectorEnvWrapper",
"make",
"utils",
]
def make(

View File

@@ -6,13 +6,19 @@ import pytest
import gymnasium as gym
def test_stacked_box():
"""Tests that sequence with a feature space of Box allows stacked np arrays."""
space = gym.spaces.Sequence(gym.spaces.Box(0, 1, shape=(3,)))
sample = np.float32(np.random.rand(5, 3))
assert space.contains(
sample
), "Something went wrong, should be able to accept stacked np arrays for Box feature space."
def test_stacked_sequence():
"""Tests that a stacked sequence with a feature space of Box returns stacked values."""
# Box
space = gym.spaces.Sequence(gym.spaces.Box(0, 1, shape=(3,)), stack=True)
sample = space.sample()
# Check if the sample is in 2d format
assert len(sample.shape) == 2
# Discrete
space = gym.spaces.Sequence(gym.spaces.Discrete(n=3), stack=True)
sample = space.sample()
# Check if the sample is a `np.ndarray` as supposed to a tuple
assert type(sample) is np.ndarray
def test_sample():

View File

@@ -5,7 +5,7 @@ import numpy as np
import pytest
import gymnasium as gym
from gymnasium.spaces import Box, Graph, utils
from gymnasium.spaces import Box, Graph, Sequence, utils
from gymnasium.utils.env_checker import data_equivalence
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
@@ -110,7 +110,9 @@ def test_flatten(space):
assert single_dim == flatdim
else:
assert isinstance(flattened_sample, (tuple, dict, Graph))
assert isinstance(space, Sequence) or isinstance(
flattened_sample, (tuple, dict, Graph)
)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)

View File

@@ -97,6 +97,8 @@ TESTING_COMPOSITE_SPACES = [
Sequence(Discrete(4)),
Sequence(Dict({"feature": Box(0, 1, (3,))})),
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
Sequence(Box(low=0.0, high=1.0), stack=True),
Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True),
]
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]