Add Stack to Sequence space - #217 (#284)

Co-authored-by: StringTheory <mark.m.towers@gmail.com>
This commit is contained in:
Kevlyn Kadamala
2023-01-30 18:27:32 +05:30
committed by GitHub
parent 33b9884ab6
commit 8031da71e6
8 changed files with 111 additions and 32 deletions

View File

@@ -238,9 +238,9 @@ class Box(Space[NDArray[Any]]):
and np.all(x <= self.high) and np.all(x <= self.high)
) )
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[NDArray[Any]]: def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist() return [sample.tolist() for sample in sample_n]
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]: def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""

View File

@@ -153,7 +153,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
self, sample_n: list[Sequence[int]] self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.integer[Any]]]: ) -> list[npt.NDArray[np.integer[Any]]]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
return np.array(sample_n) return [np.array(sample) for sample in sample_n]
def __repr__(self): def __repr__(self):
"""Gives a string representation of this space.""" """Gives a string representation of this space."""

View File

@@ -1,17 +1,17 @@
"""Implementation of a space that represents finite-length sequences.""" """Implementation of a space that represents finite-length sequences."""
from __future__ import annotations from __future__ import annotations
import collections.abc
import typing import typing
from typing import Any from typing import Any, Union
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import gymnasium as gym
from gymnasium.spaces.space import Space from gymnasium.spaces.space import Space
class Sequence(Space[typing.Tuple[Any, ...]]): class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
r"""This space represent sets of finite-length sequences. r"""This space represent sets of finite-length sequences.
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
@@ -31,17 +31,24 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
self, self,
space: Space[Any], space: Space[Any],
seed: int | np.random.Generator | None = None, seed: int | np.random.Generator | None = None,
stack: bool = False,
): ):
"""Constructor of the :class:`Sequence` space. """Constructor of the :class:`Sequence` space.
Args: Args:
space: Elements in the sequences this space represent must belong to this space. space: Elements in the sequences this space represent must belong to this space.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space. seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
stack: If `True` then the resulting samples would be stacked.
""" """
assert isinstance( assert isinstance(
space, Space space, Space
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}" ), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
self.feature_space = space self.feature_space = space
self.stack = stack
if self.stack:
self.batched_feature_space: Space = gym.vector.utils.batch_space(
self.feature_space, 1
)
# None for shape and dtype, since it'll require special handling # None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed) super().__init__(None, None, seed)
@@ -114,32 +121,59 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
# The choice of 0.25 is arbitrary # The choice of 0.25 is arbitrary
length = self.np_random.geometric(0.25) length = self.np_random.geometric(0.25)
return tuple( # Generate sample values from feature_space.
sampled_values = tuple(
self.feature_space.sample(mask=feature_mask) for _ in range(length) self.feature_space.sample(mask=feature_mask) for _ in range(length)
) )
if self.stack:
# Concatenate values if stacked.
out = gym.vector.utils.create_empty_array(
self.feature_space, len(sampled_values)
)
return gym.vector.utils.concatenate(self.feature_space, sampled_values, out)
return sampled_values
def contains(self, x: Any) -> bool: def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
# by definition, any sequence is an iterable # by definition, any sequence is an iterable
return isinstance(x, collections.abc.Iterable) and all( if self.stack:
self.feature_space.contains(item) for item in x return all(
) item in self.feature_space
for item in gym.vector.utils.iterate(self.batched_feature_space, x)
)
else:
return isinstance(x, tuple) and all(
self.feature_space.contains(item) for item in x
)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Gives a string representation of this space.""" """Gives a string representation of this space."""
return f"Sequence({self.feature_space})" return f"Sequence({self.feature_space}, stack={self.stack})"
def to_jsonable( def to_jsonable(
self, sample_n: typing.Sequence[tuple[Any, ...]] self, sample_n: typing.Sequence[tuple[Any, ...] | Any]
) -> list[list[Any]]: ) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
# serialize as dict-repr of vectors if self.stack:
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n] return self.batched_feature_space.to_jsonable(sample_n)
else:
return [self.feature_space.to_jsonable(sample) for sample in sample_n]
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]: def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...] | Any]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n] if self.stack:
return self.batched_feature_space.from_jsonable(sample_n)
else:
return [
tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n
]
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance.""" """Check whether ``other`` is equivalent to this instance."""
return isinstance(other, Sequence) and self.feature_space == other.feature_space return (
isinstance(other, Sequence)
and self.feature_space == other.feature_space
and self.stack == other.stack
)

View File

@@ -14,6 +14,7 @@ from typing import Any, TypeVar, Union, cast
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
import gymnasium as gym
from gymnasium.spaces import ( from gymnasium.spaces import (
Box, Box,
Dict, Dict,
@@ -236,8 +237,21 @@ def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
@flatten.register(Sequence) @flatten.register(Sequence)
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]: def _flatten_sequence(
return tuple(flatten(space.feature_space, item) for item in x) space: Sequence, x: tuple[Any, ...] | Any
) -> tuple[Any, ...] | Any:
if space.stack:
samples_iters = gym.vector.utils.iterate(space.batched_feature_space, x)
flattened_samples = [
flatten(space.feature_space, sample) for sample in samples_iters
]
flattened_space = flatten_space(space.feature_space)
out = gym.vector.utils.create_empty_array(
flattened_space, n=len(flattened_samples)
)
return gym.vector.utils.concatenate(flattened_space, flattened_samples, out)
else:
return tuple(flatten(space.feature_space, item) for item in x)
@singledispatch @singledispatch
@@ -363,8 +377,21 @@ def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
@unflatten.register(Sequence) @unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]: def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] | Any:
return tuple(unflatten(space.feature_space, item) for item in x) if space.stack:
flattened_space = flatten_space(space.feature_space)
flatten_iters = gym.vector.utils.iterate(flattened_space, x)
unflattened_samples = [
unflatten(space.feature_space, sample) for sample in flatten_iters
]
out = gym.vector.utils.create_empty_array(
space.feature_space, len(unflattened_samples)
)
return gym.vector.utils.concatenate(
space.feature_space, unflattened_samples, out
)
else:
return tuple(unflatten(space.feature_space, item) for item in x)
@singledispatch @singledispatch
@@ -493,4 +520,4 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence) @flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence: def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space)) return Sequence(flatten_space(space.feature_space), stack=space.stack)

View File

@@ -3,12 +3,20 @@ from typing import Callable, Iterable, List, Optional, Union
import gymnasium as gym import gymnasium as gym
from gymnasium.core import Env from gymnasium.core import Env
from gymnasium.vector import utils
from gymnasium.vector.async_vector_env import AsyncVectorEnv from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"] __all__ = [
"AsyncVectorEnv",
"SyncVectorEnv",
"VectorEnv",
"VectorEnvWrapper",
"make",
"utils",
]
def make( def make(

View File

@@ -6,13 +6,19 @@ import pytest
import gymnasium as gym import gymnasium as gym
def test_stacked_box(): def test_stacked_sequence():
"""Tests that sequence with a feature space of Box allows stacked np arrays.""" """Tests that a stacked sequence with a feature space of Box returns stacked values."""
space = gym.spaces.Sequence(gym.spaces.Box(0, 1, shape=(3,))) # Box
sample = np.float32(np.random.rand(5, 3)) space = gym.spaces.Sequence(gym.spaces.Box(0, 1, shape=(3,)), stack=True)
assert space.contains( sample = space.sample()
sample # Check if the sample is in 2d format
), "Something went wrong, should be able to accept stacked np arrays for Box feature space." assert len(sample.shape) == 2
# Discrete
space = gym.spaces.Sequence(gym.spaces.Discrete(n=3), stack=True)
sample = space.sample()
# Check if the sample is a `np.ndarray` as supposed to a tuple
assert type(sample) is np.ndarray
def test_sample(): def test_sample():

View File

@@ -5,7 +5,7 @@ import numpy as np
import pytest import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium.spaces import Box, Graph, utils from gymnasium.spaces import Box, Graph, Sequence, utils
from gymnasium.utils.env_checker import data_equivalence from gymnasium.utils.env_checker import data_equivalence
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
@@ -110,7 +110,9 @@ def test_flatten(space):
assert single_dim == flatdim assert single_dim == flatdim
else: else:
assert isinstance(flattened_sample, (tuple, dict, Graph)) assert isinstance(space, Sequence) or isinstance(
flattened_sample, (tuple, dict, Graph)
)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)

View File

@@ -97,6 +97,8 @@ TESTING_COMPOSITE_SPACES = [
Sequence(Discrete(4)), Sequence(Discrete(4)),
Sequence(Dict({"feature": Box(0, 1, (3,))})), Sequence(Dict({"feature": Box(0, 1, (3,))})),
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
Sequence(Box(low=0.0, high=1.0), stack=True),
Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True),
] ]
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES] TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]