mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-06 11:52:18 +00:00
Add Sequence
space, update flatten
functions (#2968)
* Added Sequence space, updated flatten functions to work with Sequence, Graph. WIP.
* Small fixes, added Sequence space to tests
* Replace Optional[Any] by Any
* Added tests for flattening of non-numpy-flattenable spaces
* Return all seeds
This commit is contained in:
@@ -14,6 +14,7 @@ from gym.spaces.discrete import Discrete
|
||||
from gym.spaces.graph import Graph, GraphInstance
|
||||
from gym.spaces.multi_binary import MultiBinary
|
||||
from gym.spaces.multi_discrete import MultiDiscrete
|
||||
from gym.spaces.sequence import Sequence
|
||||
from gym.spaces.space import Space
|
||||
from gym.spaces.text import Text
|
||||
from gym.spaces.tuple import Tuple
|
||||
@@ -29,6 +30,7 @@ __all__ = [
|
||||
"MultiDiscrete",
|
||||
"MultiBinary",
|
||||
"Tuple",
|
||||
"Sequence",
|
||||
"Dict",
|
||||
"flatdim",
|
||||
"flatten_space",
|
||||
|
@@ -139,6 +139,11 @@ class Box(Space[np.ndarray]):
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
return self._shape
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`."""
    # Always flattenable: no instance state is consulted.
    return True
|
||||
|
||||
def is_bounded(self, manner: str = "both") -> bool:
|
||||
"""Checks whether the box is bounded in some sense.
|
||||
|
||||
|
@@ -100,6 +100,11 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
None, None, seed # type: ignore
|
||||
) # None for shape and dtype, since it'll require special handling
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`.

    Returns:
        ``True`` only if every subspace stored in ``self.spaces`` reports
        itself as numpy-flattenable (vacuously ``True`` when there are none).
    """
    return all(subspace.is_np_flattenable for subspace in self.spaces.values())
|
||||
|
||||
def seed(self, seed: Optional[Union[dict, int]] = None) -> list:
|
||||
"""Seed the PRNG of this space and all subspaces."""
|
||||
seeds = []
|
||||
|
@@ -40,6 +40,11 @@ class Discrete(Space[int]):
|
||||
self.start = int(start)
|
||||
super().__init__((), np.int64, seed)
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`."""
    # Always flattenable: no instance state is consulted.
    return True
|
||||
|
||||
def sample(self, mask: Optional[np.ndarray] = None) -> int:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
|
@@ -18,7 +18,7 @@ class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"]
|
||||
nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes.
|
||||
(...) must adhere to the shape of the node space.
|
||||
|
||||
edges (np.ndarray): an (m x ...) sized array representing the features for m nodes.
|
||||
edges (np.ndarray): an (m x ...) sized array representing the features for m edges.
|
||||
(...) must adhere to the shape of the edge space.
|
||||
|
||||
edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.
|
||||
@@ -68,6 +68,11 @@ class Graph(Space):
|
||||
|
||||
super().__init__(None, None, seed)
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`."""
    # Never flattenable to a single array: the answer is a constant.
    return False
|
||||
|
||||
def _generate_sample_space(
|
||||
self, base_space: Union[None, Box, Discrete], num: int
|
||||
) -> Optional[Union[Box, MultiDiscrete]]:
|
||||
|
@@ -51,6 +51,11 @@ class MultiBinary(Space[np.ndarray]):
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
return self._shape # type: ignore
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`."""
    # Always flattenable: no instance state is consulted.
    return True
|
||||
|
||||
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
|
@@ -64,6 +64,11 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
"""Has stricter type than :class:`gym.Space` - never None."""
|
||||
return self._shape # type: ignore
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`."""
    # Always flattenable: no instance state is consulted.
    return True
|
||||
|
||||
def sample(self, mask: Optional[SAMPLE_MASK_TYPE] = None) -> np.ndarray:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
|
103
gym/spaces/sequence.py
Normal file
103
gym/spaces/sequence.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Implementation of a space that represents finite-length sequences."""
|
||||
from collections.abc import Sequence as CollectionSequence
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces.space import Space
|
||||
from gym.utils import seeding
|
||||
|
||||
|
||||
class Sequence(Space[Tuple]):
    r"""This space represents sets of finite-length sequences.

    This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
    to some space that is specified during initialization and the integer :math:`n` is not fixed.

    Example::
        >>> space = Sequence(Box(0, 1))
        >>> space.sample()
        (array([0.0259352], dtype=float32),)
        >>> space.sample()
        (array([0.80977976], dtype=float32), array([0.80066574], dtype=float32), array([0.77165383], dtype=float32))
    """

    def __init__(
        self,
        space: Space,
        seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None,
    ):
        """Constructor of the :class:`Sequence` space.

        Args:
            space: Elements in the sequences this space represents must belong to this space.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
        """
        # Must be set before the base constructor runs: the overridden
        # ``seed`` below reads ``self.feature_space``.
        # NOTE(review): presumably the base constructor calls ``seed`` when a seed is given - confirm.
        self.feature_space = space
        super().__init__(
            None, None, seed  # type: ignore
        )  # None for shape and dtype, since it'll require special handling

    def seed(self, seed: Optional[int] = None) -> list:
        """Seed the PRNG of this space and the feature space.

        Returns:
            The seeds reported by this space followed by those reported by the feature space.
        """
        seeds = super().seed(seed)
        seeds += self.feature_space.seed(seed)
        return seeds

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return False

    def sample(
        self, mask: Optional[Tuple[Optional[np.ndarray], Any]] = None
    ) -> Tuple[Any, ...]:
        """Generates a single random sample from this space.

        Args:
            mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
                If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask`
                is either `None` if you do not want to specify any restrictions on the length of the sampled sequence (then, the
                length will be randomly drawn from a geometric distribution), or a `np.ndarray` of integers, in which case the length of
                the sampled sequence is randomly drawn from this array. The second element of the tuple, `sample_mask`,
                specifies a mask that is applied when sampling elements from the base space.

        Returns:
            A tuple of random length with random samples of elements from the :attr:`feature_space`.
        """
        if mask is not None:
            length_mask, feature_mask = mask
        else:
            length_mask = None
            feature_mask = None

        if length_mask is not None:
            # The length is drawn from the candidate lengths supplied by the caller.
            length = self.np_random.choice(length_mask)
        else:
            # No restriction given: draw the length from a geometric distribution (p=0.25).
            length = self.np_random.geometric(0.25)

        return tuple(
            self.feature_space.sample(mask=feature_mask) for _ in range(length)
        )

    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        return isinstance(x, CollectionSequence) and all(
            self.feature_space.contains(item) for item in x
        )

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"Sequence({self.feature_space})"

    def to_jsonable(self, sample_n: list) -> list:
        """Convert a batch of samples from this space to a JSONable data type."""
        # Each sequence is handed to the feature space as a batch of its own elements.
        return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]

    def from_jsonable(self, sample_n: List[List[Any]]) -> list:
        """Convert a JSONable data type to a batch of samples from this space."""
        return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, Sequence) and self.feature_space == other.feature_space
|
@@ -82,6 +82,11 @@ class Space(Generic[T_cov]):
|
||||
"""Return the shape of the space as an immutable property."""
|
||||
return self._shape
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`.

    Raises:
        NotImplementedError: always; this implementation provides no answer.
    """
    raise NotImplementedError
|
||||
|
||||
def sample(self, mask: Optional[Any] = None) -> T_cov:
|
||||
"""Randomly sample an element of this space.
|
||||
|
||||
|
@@ -40,6 +40,11 @@ class Tuple(Space[tuple], Sequence):
|
||||
), "Elements of the tuple must be instances of gym.Space"
|
||||
super().__init__(None, None, seed) # type: ignore
|
||||
|
||||
@property
def is_np_flattenable(self) -> bool:
    """Checks whether this space can be flattened to a :class:`spaces.Box`.

    Returns:
        ``True`` only if every subspace in ``self.spaces`` reports itself as
        numpy-flattenable (vacuously ``True`` when there are none).
    """
    return all(subspace.is_np_flattenable for subspace in self.spaces)
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
|
||||
"""Seed the PRNG of this space and all subspaces."""
|
||||
seeds = []
|
||||
|
@@ -6,6 +6,7 @@ These functions mostly take care of flattening and unflattening elements of spac
|
||||
import operator as op
|
||||
from collections import OrderedDict
|
||||
from functools import reduce, singledispatch
|
||||
from typing import Dict as TypingDict
|
||||
from typing import TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
@@ -18,6 +19,7 @@ from gym.spaces import (
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Space,
|
||||
Tuple,
|
||||
)
|
||||
@@ -42,7 +44,13 @@ def flatdim(space: Space) -> int:
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||
ValueError: if the space cannot be flattened into a :class:`Box`
|
||||
"""
|
||||
if not space.is_np_flattenable:
|
||||
raise ValueError(
|
||||
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
|
||||
)
|
||||
|
||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||
|
||||
|
||||
@@ -64,19 +72,28 @@ def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
|
||||
|
||||
@flatdim.register(Tuple)
def _flatdim_tuple(space: Tuple) -> int:
    """Return the flattened dimension of a :class:`Tuple` space.

    Returns:
        The sum of ``flatdim`` over all subspaces.

    Raises:
        ValueError: if ``space`` is not numpy-flattenable.
    """
    if space.is_np_flattenable:
        return sum(flatdim(s) for s in space.spaces)
    raise ValueError(
        f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
    )
|
||||
|
||||
|
||||
@flatdim.register(Dict)
def _flatdim_dict(space: Dict) -> int:
    """Return the flattened dimension of a :class:`Dict` space.

    Returns:
        The sum of ``flatdim`` over all subspace values.

    Raises:
        ValueError: if ``space`` is not numpy-flattenable.
    """
    if space.is_np_flattenable:
        return sum(flatdim(s) for s in space.spaces.values())
    raise ValueError(
        f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
    )
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten(space: Space[T], x: T) -> np.ndarray:
|
||||
def flatten(space: Space[T], x: T) -> FlatType:
|
||||
"""Flatten a data point from a space.
|
||||
|
||||
This is useful when e.g. points from spaces must be passed to a neural
|
||||
@@ -127,17 +144,23 @@ def _flatten_multidiscrete(space, x) -> np.ndarray:
|
||||
|
||||
|
||||
@flatten.register(Tuple)
|
||||
def _flatten_tuple(space, x) -> np.ndarray:
|
||||
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
|
||||
def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
|
||||
if space.is_np_flattenable:
|
||||
return np.concatenate(
|
||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
||||
)
|
||||
return tuple((flatten(s, x_part) for x_part, s in zip(x, space.spaces)))
|
||||
|
||||
|
||||
@flatten.register(Dict)
|
||||
def _flatten_dict(space, x) -> np.ndarray:
|
||||
def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]:
|
||||
if space.is_np_flattenable:
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
||||
|
||||
|
||||
@flatten.register(Graph)
|
||||
def _flatten_graph(space, x) -> np.ndarray:
|
||||
def _flatten_graph(space, x) -> GraphInstance:
|
||||
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
||||
|
||||
def _graph_unflatten(space, x):
|
||||
@@ -156,8 +179,13 @@ def _flatten_graph(space, x) -> np.ndarray:
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@flatten.register(Sequence)
def _flatten_sequence(space, x) -> tuple:
    """Flatten each element of a sequence sample with the feature space.

    The result stays a tuple with one flattened entry per element of ``x``;
    the elements are not concatenated into a single array.
    """
    return tuple(flatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def unflatten(space: Space[T], x: np.ndarray) -> T:
|
||||
def unflatten(space: Space[T], x: FlatType) -> T:
|
||||
"""Unflatten a data point from a space.
|
||||
|
||||
This reverses the transformation applied by :func:`flatten`. You must ensure
|
||||
@@ -199,16 +227,26 @@ def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
@unflatten.register(Tuple)
|
||||
def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple:
|
||||
def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(
|
||||
x, np.ndarray
|
||||
), f"{space} is numpy-flattenable. Thus, you should only unflatten numpy arrays for this space. Got a {type(x)}"
|
||||
dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
|
||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||
return tuple(
|
||||
unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)
|
||||
unflatten(s, flattened)
|
||||
for flattened, s in zip(list_flattened, space.spaces)
|
||||
)
|
||||
assert isinstance(
|
||||
x, tuple
|
||||
), f"{space} is not numpy-flattenable. Thus, you should only unflatten tuples for this space. Got a {type(x)}"
|
||||
return tuple(unflatten(s, flattened) for flattened, s in zip(x, space.spaces))
|
||||
|
||||
|
||||
@unflatten.register(Dict)
|
||||
def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
|
||||
def _unflatten_dict(space: Dict, x: Union[np.ndarray, TypingDict]) -> dict:
|
||||
if space.is_np_flattenable:
|
||||
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
|
||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||
return OrderedDict(
|
||||
@@ -217,6 +255,10 @@ def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
|
||||
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
|
||||
]
|
||||
)
|
||||
assert isinstance(
|
||||
x, dict
|
||||
), f"{space} is not numpy-flattenable. Thus, you should only unflatten dictionary for this space. Got a {type(x)}"
|
||||
return OrderedDict((key, unflatten(s, x[key])) for key, s in space.spaces.items())
|
||||
|
||||
|
||||
@unflatten.register(Graph)
|
||||
@@ -242,10 +284,18 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space: Space) -> Box:
|
||||
"""Flatten a space into a single ``Box``.
|
||||
@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
    """Unflatten each element of a flattened sequence sample with the feature space.

    Returns:
        A tuple with one unflattened entry per element of ``x``.
    """
    return tuple(unflatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
||||
"""Flatten a space into a space that is as flat as possible.
|
||||
|
||||
This function will attempt to flatten `space` into a single :class:`Box` space.
|
||||
However, this might not be possible when `space` is an instance of :class:`Graph`,
|
||||
:class:`Sequence` or a compound space that contains a :class:`Graph` or :class:`Sequence` space.
|
||||
This is equivalent to :func:`flatten`, but operates on the space itself. The
|
||||
result for non-graph spaces is always a `Box` with flat boundaries. While
|
||||
the result for graph spaces is always a `Graph` with `node_space` being a `Box`
|
||||
@@ -314,23 +364,31 @@ def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) ->
|
||||
|
||||
|
||||
@flatten_space.register(Tuple)
|
||||
def _flatten_space_tuple(space: Tuple) -> Box:
|
||||
def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
|
||||
if space.is_np_flattenable:
|
||||
space_list = [flatten_space(s) for s in space.spaces]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space_list]),
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
return Tuple(spaces=[flatten_space(s) for s in space.spaces])
|
||||
|
||||
|
||||
@flatten_space.register(Dict)
|
||||
def _flatten_space_dict(space: Dict) -> Box:
|
||||
def _flatten_space_dict(space: Dict) -> Union[Box, Dict]:
|
||||
if space.is_np_flattenable:
|
||||
space_list = [flatten_space(s) for s in space.spaces.values()]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space_list]),
|
||||
high=np.concatenate([s.high for s in space_list]),
|
||||
dtype=np.result_type(*[s.dtype for s in space_list]),
|
||||
)
|
||||
return Dict(
|
||||
spaces=OrderedDict(
|
||||
(key, flatten_space(space)) for key, space in space.spaces.items()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Graph)
|
||||
@@ -341,3 +399,8 @@ def _flatten_space_graph(space: Graph) -> Graph:
|
||||
if space.edge_space is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
    """Flatten a :class:`Sequence` space by flattening its feature space.

    Returns:
        A new :class:`Sequence` whose feature space is ``flatten_space(space.feature_space)``.
    """
    return Sequence(flatten_space(space.feature_space))
|
||||
|
@@ -15,6 +15,7 @@ from gym.spaces import (
|
||||
Graph,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
Tuple,
|
||||
@@ -55,6 +56,8 @@ from gym.spaces import (
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
Sequence(Discrete(4)),
|
||||
Sequence(Dict({"feature": Box(0, 1, (3,))})),
|
||||
Text(5),
|
||||
Text(min_length=1, max_length=10, charset=string.digits),
|
||||
],
|
||||
@@ -114,6 +117,8 @@ def test_roundtripping(space):
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
Sequence(Discrete(4)),
|
||||
Sequence(Dict({"feature": Box(0, 1, (3,))})),
|
||||
Text(5),
|
||||
Text(min_length=1, max_length=10, charset=string.digits),
|
||||
],
|
||||
@@ -158,6 +163,11 @@ def test_equality(space):
|
||||
),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
),
|
||||
(
|
||||
Sequence(Discrete(4)),
|
||||
Sequence(Dict({"feature": Box(0, 1, (3,))})),
|
||||
),
|
||||
(Sequence(Discrete(4)), Sequence(Discrete(4, start=-1))),
|
||||
(
|
||||
Text(5),
|
||||
Text(min_length=1, max_length=10, charset=string.digits),
|
||||
@@ -448,6 +458,14 @@ def test_space_sample_mask(space, mask, n_trials: int = 100):
|
||||
),
|
||||
None,
|
||||
),
|
||||
(Sequence(Discrete(2)), (None, np.array([0, 1], dtype=np.int8))),
|
||||
(
|
||||
Sequence(Discrete(2)),
|
||||
(np.array([2, 3, 4], dtype=np.int8), np.array([0, 1], dtype=np.int8)),
|
||||
),
|
||||
(Sequence(Discrete(2)), (np.array([2, 3, 4], dtype=np.int8), None)),
|
||||
(Sequence(Discrete(2)), (None, None)),
|
||||
(Sequence(Discrete(2)), None),
|
||||
],
|
||||
)
|
||||
def test_composite_space_sample_mask(space, mask):
|
||||
@@ -484,6 +502,7 @@ def test_composite_space_sample_mask(space, mask):
|
||||
),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
),
|
||||
(Sequence(Discrete(4)), Sequence(Discrete(3))),
|
||||
(
|
||||
Text(5),
|
||||
Text(min_length=1, max_length=10, charset=string.digits),
|
||||
@@ -606,6 +625,8 @@ def test_box_dtype_check():
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
Sequence(Discrete(4)),
|
||||
Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})),
|
||||
Text(5),
|
||||
Text(min_length=1, max_length=10, charset=string.digits),
|
||||
],
|
||||
@@ -671,6 +692,8 @@ def sample_equal(sample1, sample2):
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
Sequence(Discrete(4)),
|
||||
Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})),
|
||||
Text(5),
|
||||
Text(min_length=1, max_length=10, charset=string.digits),
|
||||
],
|
||||
@@ -986,6 +1009,8 @@ def test_box_legacy_state_pickling():
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
Sequence(Discrete(4)),
|
||||
Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})),
|
||||
Text(5),
|
||||
Text(min_length=1, max_length=10, charset=string.digits),
|
||||
],
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
@@ -8,8 +9,10 @@ from gym.spaces import (
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Tuple,
|
||||
utils,
|
||||
)
|
||||
@@ -42,15 +45,57 @@ homogeneous_spaces = [
|
||||
|
||||
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
|
||||
|
||||
graph_spaces = [
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||
Graph(node_space=Discrete(5), edge_space=None),
|
||||
non_homogenous_spaces = [
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), #
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), #
|
||||
Graph(node_space=Discrete(5), edge_space=None), #
|
||||
Sequence(Discrete(4)), #
|
||||
Sequence(Box(-10, 10, shape=(2, 2))), #
|
||||
Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), #
|
||||
Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), #
|
||||
Dict(
|
||||
a=Graph(node_space=Discrete(4), edge_space=Discrete(4)),
|
||||
b=Box(-10, 10, shape=(2, 2)),
|
||||
), #
|
||||
Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), #
|
||||
Tuple(
|
||||
[
|
||||
Graph(node_space=Discrete(4), edge_space=Discrete(4)),
|
||||
Box(-10, 10, shape=(2, 2)),
|
||||
]
|
||||
), #
|
||||
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), #
|
||||
Dict(
|
||||
a=Dict(
|
||||
a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2))
|
||||
),
|
||||
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
|
||||
), #
|
||||
Dict(
|
||||
a=Dict(
|
||||
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
|
||||
b=Box(-100, 100, shape=(2, 2)),
|
||||
),
|
||||
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", non_homogenous_spaces)
def test_non_flattenable(space):
    """Non-homogeneous spaces report ``is_np_flattenable`` as False and ``flatdim`` rejects them."""
    assert space.is_np_flattenable is False
    # re.escape: the expected message contains backticks and must be matched literally.
    with pytest.raises(
        ValueError,
        match=re.escape(
            "cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
        ),
    ):
        utils.flatdim(space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
def test_flatdim(space, flatdim):
    """Homogeneous spaces are numpy-flattenable and ``utils.flatdim`` matches the expected size."""
    assert space.is_np_flattenable
    dim = utils.flatdim(space)
    assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
|
||||
|
||||
@@ -64,7 +109,7 @@ def test_flatten_space_boxes(space):
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
|
||||
def test_flat_space_contains_flat_points(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
@@ -83,7 +128,7 @@ def test_flatten_dim(space):
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
|
||||
def test_flatten_roundtripping(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
@@ -96,11 +141,14 @@ def test_flatten_roundtripping(space):
|
||||
assert compare_nested(
|
||||
original, roundtripped
|
||||
), f"Expected sample #{i} {original} to equal {roundtripped}"
|
||||
assert space.contains(roundtripped)
|
||||
|
||||
|
||||
def compare_nested(left, right):
|
||||
if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
|
||||
return np.allclose(left, right)
|
||||
if type(left) != type(right):
|
||||
return False
|
||||
elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
|
||||
return left.shape == right.shape and np.allclose(left, right)
|
||||
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
|
||||
res = len(left) == len(right)
|
||||
for ((left_key, left_value), (right_key, right_value)) in zip(
|
||||
@@ -193,7 +241,7 @@ def compare_sample_types(original_space, original_sample, unflattened_sample):
|
||||
)
|
||||
|
||||
|
||||
samples = [
|
||||
homogeneous_samples = [
|
||||
2,
|
||||
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32),
|
||||
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16),
|
||||
@@ -210,7 +258,7 @@ samples = [
|
||||
]
|
||||
|
||||
|
||||
expected_flattened_samples = [
|
||||
expected_flattened_hom_samples = [
|
||||
np.array([0, 0, 1], dtype=np.int64),
|
||||
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32),
|
||||
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16),
|
||||
@@ -224,23 +272,297 @@ expected_flattened_samples = [
|
||||
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
|
||||
]
|
||||
|
||||
# One hand-written sample per entry of ``non_homogenous_spaces`` (same order).
# The builtin ``int`` is used as dtype throughout: the ``np.int`` alias was
# deprecated in NumPy 1.20 and removed in 1.24.
non_homogenous_samples = [
    GraphInstance(
        np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32),
        np.array(
            [
                0,
            ],
            dtype=int,
        ),
        np.array([[0, 1]], dtype=int),
    ),
    GraphInstance(
        np.array([0, 1], dtype=int),
        np.array([[[1, 2], [3, 4]]], dtype=np.float32),
        np.array([[0, 1]], dtype=int),
    ),
    GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)),
    (0, 1, 2),
    (
        np.array([[0, 1], [2, 3]], dtype=np.float32),
        np.array([[4, 5], [6, 7]], dtype=np.float32),
    ),
    (
        (np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)),
        (np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)),
    ),
    OrderedDict(
        [("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))]
    ),
    OrderedDict(
        [
            (
                "a",
                GraphInstance(
                    np.array([1, 2], dtype=int),
                    np.array(
                        [
                            0,
                        ],
                        dtype=int,
                    ),
                    np.array([[0, 1]], dtype=int),
                ),
            ),
            ("b", np.array([[0, 1], [2, 3]], dtype=np.float32)),
        ]
    ),
    ((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)),
    (
        GraphInstance(
            np.array([1, 2], dtype=int),
            np.array(
                [
                    0,
                ],
                dtype=int,
            ),
            np.array([[0, 1]], dtype=int),
        ),
        np.array([[0, 1], [2, 3]], dtype=np.float32),
    ),
    (
        GraphInstance(
            nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32),
            edges=np.array([0], dtype=int),
            edge_links=np.array([[0, 1]]),
        ),
        GraphInstance(
            nodes=np.array(
                [[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32
            ),
            edges=np.array([1], dtype=int),
            edge_links=np.array([[0, 1]]),
        ),
    ),
    OrderedDict(
        [
            (
                "a",
                OrderedDict(
                    [
                        (
                            "a",
                            (
                                np.array([[0, 1], [2, 3]], dtype=np.float32),
                                np.array([[4, 5], [6, 7]], dtype=np.float32),
                            ),
                        ),
                        ("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
                    ]
                ),
            ),
            (
                "b",
                (
                    np.array([12, 13], dtype=np.float32),
                    np.array([14, 15], dtype=np.float32),
                ),
            ),
        ]
    ),
    OrderedDict(
        [
            (
                "a",
                OrderedDict(
                    [
                        (
                            "a",
                            GraphInstance(
                                np.array(
                                    [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                                    dtype=np.float32,
                                ),
                                None,
                                np.array([[0, 1]], dtype=int),
                            ),
                        ),
                        ("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
                    ]
                ),
            ),
            (
                "b",
                (
                    np.array([12, 13], dtype=np.float32),
                    np.array([14, 15], dtype=np.float32),
                ),
            ),
        ]
    ),
]
|
||||
|
||||
|
||||
expected_flattened_non_hom_samples = [
|
||||
GraphInstance(
|
||||
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
|
||||
np.array([[1, 2, 3, 4]], dtype=np.float32),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int),
|
||||
np.array([[1, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
(
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0]]),
|
||||
np.array([[0, 1]]),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32),
|
||||
np.array([[0, 1, 0, 0]]),
|
||||
np.array([[0, 1]]),
|
||||
),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array(
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32
|
||||
),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "sample", "expected_flattened_sample"],
|
||||
zip(homogeneous_spaces, samples, expected_flattened_samples),
|
||||
zip(
|
||||
homogeneous_spaces + non_homogenous_spaces,
|
||||
homogeneous_samples + non_homogenous_samples,
|
||||
expected_flattened_hom_samples + expected_flattened_non_hom_samples,
|
||||
),
|
||||
)
|
||||
def test_flatten(space, sample, expected_flattened_sample):
|
||||
assert sample in space
|
||||
|
||||
flattened_sample = utils.flatten(space, sample)
|
||||
flat_space = utils.flatten_space(space)
|
||||
|
||||
assert sample in space
|
||||
assert flattened_sample in flat_space
|
||||
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(flattened_sample, np.ndarray)
|
||||
assert flattened_sample.shape == expected_flattened_sample.shape
|
||||
assert flattened_sample.dtype == expected_flattened_sample.dtype
|
||||
assert np.all(flattened_sample == expected_flattened_sample)
|
||||
else:
|
||||
assert not isinstance(flattened_sample, np.ndarray)
|
||||
assert compare_nested(flattened_sample, expected_flattened_sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "flattened_sample", "expected_sample"],
|
||||
zip(homogeneous_spaces, expected_flattened_samples, samples),
|
||||
zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples),
|
||||
)
|
||||
def test_unflatten(space, flattened_sample, expected_sample):
|
||||
sample = utils.unflatten(space, flattened_sample)
|
||||
|
Reference in New Issue
Block a user