mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 21:06:59 +00:00
Co-authored-by: StringTheory <mark.m.towers@gmail.com>
This commit is contained in:
@@ -238,9 +238,9 @@ class Box(Space[NDArray[Any]]):
|
||||
and np.all(x <= self.high)
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[NDArray[Any]]:
|
||||
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return np.array(sample_n).tolist()
|
||||
return [sample.tolist() for sample in sample_n]
|
||||
|
||||
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
|
@@ -153,7 +153,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
self, sample_n: list[Sequence[int]]
|
||||
) -> list[npt.NDArray[np.integer[Any]]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return np.array(sample_n)
|
||||
return [np.array(sample) for sample in sample_n]
|
||||
|
||||
def __repr__(self):
|
||||
"""Gives a string representation of this space."""
|
||||
|
@@ -1,17 +1,17 @@
|
||||
"""Implementation of a space that represents finite-length sequences."""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
|
||||
class Sequence(Space[typing.Tuple[Any, ...]]):
|
||||
class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
r"""This space represents sets of finite-length sequences.
|
||||
|
||||
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
|
||||
@@ -31,17 +31,24 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
|
||||
self,
|
||||
space: Space[Any],
|
||||
seed: int | np.random.Generator | None = None,
|
||||
stack: bool = False,
|
||||
):
|
||||
"""Constructor of the :class:`Sequence` space.
|
||||
|
||||
Args:
|
||||
space: Elements in the sequences this space represents must belong to this space.
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||
stack: If `True`, the sampled elements are stacked into a single batched sample instead of being returned as a tuple.
|
||||
"""
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
|
||||
self.feature_space = space
|
||||
self.stack = stack
|
||||
if self.stack:
|
||||
self.batched_feature_space: Space = gym.vector.utils.batch_space(
|
||||
self.feature_space, 1
|
||||
)
|
||||
|
||||
# None for shape and dtype, since it'll require special handling
|
||||
super().__init__(None, None, seed)
|
||||
@@ -114,32 +121,59 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
|
||||
# The choice of 0.25 is arbitrary
|
||||
length = self.np_random.geometric(0.25)
|
||||
|
||||
return tuple(
|
||||
# Generate sample values from feature_space.
|
||||
sampled_values = tuple(
|
||||
self.feature_space.sample(mask=feature_mask) for _ in range(length)
|
||||
)
|
||||
|
||||
if self.stack:
|
||||
# Concatenate values if stacked.
|
||||
out = gym.vector.utils.create_empty_array(
|
||||
self.feature_space, len(sampled_values)
|
||||
)
|
||||
return gym.vector.utils.concatenate(self.feature_space, sampled_values, out)
|
||||
|
||||
return sampled_values
|
||||
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
# by definition, any sequence is an iterable
|
||||
return isinstance(x, collections.abc.Iterable) and all(
|
||||
if self.stack:
|
||||
return all(
|
||||
item in self.feature_space
|
||||
for item in gym.vector.utils.iterate(self.batched_feature_space, x)
|
||||
)
|
||||
else:
|
||||
return isinstance(x, tuple) and all(
|
||||
self.feature_space.contains(item) for item in x
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Gives a string representation of this space."""
|
||||
return f"Sequence({self.feature_space})"
|
||||
return f"Sequence({self.feature_space}, stack={self.stack})"
|
||||
|
||||
def to_jsonable(
|
||||
self, sample_n: typing.Sequence[tuple[Any, ...]]
|
||||
self, sample_n: typing.Sequence[tuple[Any, ...] | Any]
|
||||
) -> list[list[Any]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# serialize as dict-repr of vectors
|
||||
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]
|
||||
if self.stack:
|
||||
return self.batched_feature_space.to_jsonable(sample_n)
|
||||
else:
|
||||
return [self.feature_space.to_jsonable(sample) for sample in sample_n]
|
||||
|
||||
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
|
||||
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...] | Any]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]
|
||||
if self.stack:
|
||||
return self.batched_feature_space.from_jsonable(sample_n)
|
||||
else:
|
||||
return [
|
||||
tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n
|
||||
]
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
return isinstance(other, Sequence) and self.feature_space == other.feature_space
|
||||
return (
|
||||
isinstance(other, Sequence)
|
||||
and self.feature_space == other.feature_space
|
||||
and self.stack == other.stack
|
||||
)
|
||||
|
@@ -14,6 +14,7 @@ from typing import Any, TypeVar, Union, cast
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
@@ -236,7 +237,20 @@ def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
|
||||
|
||||
|
||||
@flatten.register(Sequence)
|
||||
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
def _flatten_sequence(
|
||||
space: Sequence, x: tuple[Any, ...] | Any
|
||||
) -> tuple[Any, ...] | Any:
|
||||
if space.stack:
|
||||
samples_iters = gym.vector.utils.iterate(space.batched_feature_space, x)
|
||||
flattened_samples = [
|
||||
flatten(space.feature_space, sample) for sample in samples_iters
|
||||
]
|
||||
flattened_space = flatten_space(space.feature_space)
|
||||
out = gym.vector.utils.create_empty_array(
|
||||
flattened_space, n=len(flattened_samples)
|
||||
)
|
||||
return gym.vector.utils.concatenate(flattened_space, flattened_samples, out)
|
||||
else:
|
||||
return tuple(flatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@@ -363,7 +377,20 @@ def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
|
||||
|
||||
|
||||
@unflatten.register(Sequence)
|
||||
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] | Any:
|
||||
if space.stack:
|
||||
flattened_space = flatten_space(space.feature_space)
|
||||
flatten_iters = gym.vector.utils.iterate(flattened_space, x)
|
||||
unflattened_samples = [
|
||||
unflatten(space.feature_space, sample) for sample in flatten_iters
|
||||
]
|
||||
out = gym.vector.utils.create_empty_array(
|
||||
space.feature_space, len(unflattened_samples)
|
||||
)
|
||||
return gym.vector.utils.concatenate(
|
||||
space.feature_space, unflattened_samples, out
|
||||
)
|
||||
else:
|
||||
return tuple(unflatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@@ -493,4 +520,4 @@ def _flatten_space_text(space: Text) -> Box:
|
||||
|
||||
@flatten_space.register(Sequence)
|
||||
def _flatten_space_sequence(space: Sequence) -> Sequence:
|
||||
return Sequence(flatten_space(space.feature_space))
|
||||
return Sequence(flatten_space(space.feature_space), stack=space.stack)
|
||||
|
@@ -3,12 +3,20 @@ from typing import Callable, Iterable, List, Optional, Union
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import Env
|
||||
from gymnasium.vector import utils
|
||||
from gymnasium.vector.async_vector_env import AsyncVectorEnv
|
||||
from gymnasium.vector.sync_vector_env import SyncVectorEnv
|
||||
from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper
|
||||
|
||||
|
||||
__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
|
||||
__all__ = [
|
||||
"AsyncVectorEnv",
|
||||
"SyncVectorEnv",
|
||||
"VectorEnv",
|
||||
"VectorEnvWrapper",
|
||||
"make",
|
||||
"utils",
|
||||
]
|
||||
|
||||
|
||||
def make(
|
||||
|
@@ -6,13 +6,19 @@ import pytest
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
def test_stacked_box():
|
||||
"""Tests that sequence with a feature space of Box allows stacked np arrays."""
|
||||
space = gym.spaces.Sequence(gym.spaces.Box(0, 1, shape=(3,)))
|
||||
sample = np.float32(np.random.rand(5, 3))
|
||||
assert space.contains(
|
||||
sample
|
||||
), "Something went wrong, should be able to accept stacked np arrays for Box feature space."
|
||||
def test_stacked_sequence():
|
||||
"""Tests that a stacked sequence with a feature space of Box returns stacked values."""
|
||||
# Box
|
||||
space = gym.spaces.Sequence(gym.spaces.Box(0, 1, shape=(3,)), stack=True)
|
||||
sample = space.sample()
|
||||
# Check if the sample is in 2d format
|
||||
assert len(sample.shape) == 2
|
||||
|
||||
# Discrete
|
||||
space = gym.spaces.Sequence(gym.spaces.Discrete(n=3), stack=True)
|
||||
sample = space.sample()
|
||||
# Check if the sample is a `np.ndarray` as opposed to a tuple
|
||||
assert type(sample) is np.ndarray
|
||||
|
||||
|
||||
def test_sample():
|
||||
|
@@ -5,7 +5,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces import Box, Graph, utils
|
||||
from gymnasium.spaces import Box, Graph, Sequence, utils
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
||||
|
||||
@@ -110,7 +110,9 @@ def test_flatten(space):
|
||||
|
||||
assert single_dim == flatdim
|
||||
else:
|
||||
assert isinstance(flattened_sample, (tuple, dict, Graph))
|
||||
assert isinstance(space, Sequence) or isinstance(
|
||||
flattened_sample, (tuple, dict, Graph)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
|
@@ -97,6 +97,8 @@ TESTING_COMPOSITE_SPACES = [
|
||||
Sequence(Discrete(4)),
|
||||
Sequence(Dict({"feature": Box(0, 1, (3,))})),
|
||||
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
|
||||
Sequence(Box(low=0.0, high=1.0), stack=True),
|
||||
Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True),
|
||||
]
|
||||
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]
|
||||
|
||||
|
Reference in New Issue
Block a user