Add Sequence space, update flatten functions (#2968)

* Added Sequence space, updated flatten functions to work with Sequence, Graph. WIP.

* Small fixes, added Sequence space to tests

* Replace Optional[Any] by Any

* Added tests for flattening of non-numpy-flattenable spaces

* Return all seeds
This commit is contained in:
Markus Krimmel
2022-08-15 17:11:32 +02:00
committed by GitHub
parent 8b744130bc
commit 63ea5f2517
13 changed files with 613 additions and 58 deletions

View File

@@ -14,6 +14,7 @@ from gym.spaces.discrete import Discrete
from gym.spaces.graph import Graph, GraphInstance
from gym.spaces.multi_binary import MultiBinary
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.sequence import Sequence
from gym.spaces.space import Space
from gym.spaces.text import Text
from gym.spaces.tuple import Tuple
@@ -29,6 +30,7 @@ __all__ = [
"MultiDiscrete",
"MultiBinary",
"Tuple",
"Sequence",
"Dict",
"flatdim",
"flatten_space",

View File

@@ -139,6 +139,11 @@ class Box(Space[np.ndarray]):
"""Has stricter type than gym.Space - never None."""
return self._shape
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def is_bounded(self, manner: str = "both") -> bool:
"""Checks whether the box is bounded in some sense.

View File

@@ -100,6 +100,11 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces.values())
def seed(self, seed: Optional[Union[dict, int]] = None) -> list:
"""Seed the PRNG of this space and all subspaces."""
seeds = []

View File

@@ -40,6 +40,11 @@ class Discrete(Space[int]):
self.start = int(start)
super().__init__((), np.int64, seed)
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: Optional[np.ndarray] = None) -> int:
"""Generates a single random sample from this space.

View File

@@ -18,7 +18,7 @@ class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"]
nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes.
(...) must adhere to the shape of the node space.
edges (np.ndarray): an (m x ...) sized array representing the features for m nodes.
edges (np.ndarray): an (m x ...) sized array representing the features for m edges.
(...) must adhere to the shape of the edge space.
edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.
@@ -68,6 +68,11 @@ class Graph(Space):
super().__init__(None, None, seed)
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return False
def _generate_sample_space(
self, base_space: Union[None, Box, Discrete], num: int
) -> Optional[Union[Box, MultiDiscrete]]:

View File

@@ -51,6 +51,11 @@ class MultiBinary(Space[np.ndarray]):
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
"""Generates a single random sample from this space.

View File

@@ -64,6 +64,11 @@ class MultiDiscrete(Space[np.ndarray]):
"""Has stricter type than :class:`gym.Space` - never None."""
return self._shape # type: ignore
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: Optional[SAMPLE_MASK_TYPE] = None) -> np.ndarray:
"""Generates a single random sample this space.

103
gym/spaces/sequence.py Normal file
View File

@@ -0,0 +1,103 @@
"""Implementation of a space that represents finite-length sequences."""
from collections.abc import Sequence as CollectionSequence
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class Sequence(Space[Tuple]):
r"""This space represent sets of finite-length sequences.
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
to some space that is specified during initialization and the integer :math:`n` is not fixed
Example::
>>> space = Sequence(Box(0, 1))
>>> space.sample()
(array([0.0259352], dtype=float32),)
>>> space.sample()
(array([0.80977976], dtype=float32), array([0.80066574], dtype=float32), array([0.77165383], dtype=float32))
"""
def __init__(
self,
space: Space,
seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None,
):
"""Constructor of the :class:`Sequence` space.
Args:
space: Elements in the sequences this space represent must belong to this space.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
"""
self.feature_space = space
super().__init__(
None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling
def seed(self, seed: Optional[int] = None) -> list:
"""Seed the PRNG of this space and the feature space."""
seeds = super().seed(seed)
seeds += self.feature_space.seed(seed)
return seeds
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return False
def sample(
self, mask: Optional[Tuple[Optional[np.ndarray], Any]] = None
) -> Tuple[Any]:
"""Generates a single random sample from this space.
Args:
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask`
is either `None` if you do not want to specify any restrictions on the length of the sampled sequence (then, the
length will be randomly drawn from a geometric distribution), or a `np.ndarray` of integers, in which case the length of
the sampled sequence is randomly drawn from this array. The second element of the tuple, `sample` mask
specifies a mask that is applied when sampling elements from the base space.
Returns:
A tuple of random length with random samples of elements from the :attr:`feature_space`.
"""
if mask is not None:
length_mask, feature_mask = mask
else:
length_mask = None
feature_mask = None
if length_mask is not None:
length = self.np_random.choice(length_mask)
else:
length = self.np_random.geometric(0.25)
return tuple(
self.feature_space.sample(mask=feature_mask) for _ in range(length)
)
def contains(self, x) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
return isinstance(x, CollectionSequence) and all(
self.feature_space.contains(item) for item in x
)
def __repr__(self) -> str:
"""Gives a string representation of this space."""
return f"Sequence({self.feature_space})"
def to_jsonable(self, sample_n: list) -> list:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as dict-repr of vectors
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]
def from_jsonable(self, sample_n: List[List[Any]]) -> list:
"""Convert a JSONable data type to a batch of samples from this space."""
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]
def __eq__(self, other) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return isinstance(other, Sequence) and self.feature_space == other.feature_space

View File

@@ -82,6 +82,11 @@ class Space(Generic[T_cov]):
"""Return the shape of the space as an immutable property."""
return self._shape
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
raise NotImplementedError
def sample(self, mask: Optional[Any] = None) -> T_cov:
"""Randomly sample an element of this space.

View File

@@ -40,6 +40,11 @@ class Tuple(Space[tuple], Sequence):
), "Elements of the tuple must be instances of gym.Space"
super().__init__(None, None, seed) # type: ignore
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
"""Seed the PRNG of this space and all subspaces."""
seeds = []

View File

@@ -6,6 +6,7 @@ These functions mostly take care of flattening and unflattening elements of spac
import operator as op
from collections import OrderedDict
from functools import reduce, singledispatch
from typing import Dict as TypingDict
from typing import TypeVar, Union, cast
import numpy as np
@@ -18,6 +19,7 @@ from gym.spaces import (
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Tuple,
)
@@ -42,7 +44,13 @@ def flatdim(space: Space) -> int:
Raises:
NotImplementedError: if the space is not defined in ``gym.spaces``.
ValueError: if the space cannot be flattened into a :class:`Box`
"""
if not space.is_np_flattenable:
raise ValueError(
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
)
raise NotImplementedError(f"Unknown space: `{space}`")
@@ -64,19 +72,28 @@ def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
@flatdim.register(Tuple)
def _flatdim_tuple(space: Tuple) -> int:
return sum(flatdim(s) for s in space.spaces)
if space.is_np_flattenable:
return sum(flatdim(s) for s in space.spaces)
raise ValueError(
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
)
@flatdim.register(Dict)
def _flatdim_dict(space: Dict) -> int:
return sum(flatdim(s) for s in space.spaces.values())
if space.is_np_flattenable:
return sum(flatdim(s) for s in space.spaces.values())
raise ValueError(
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
)
T = TypeVar("T")
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
@singledispatch
def flatten(space: Space[T], x: T) -> np.ndarray:
def flatten(space: Space[T], x: T) -> FlatType:
"""Flatten a data point from a space.
This is useful when e.g. points from spaces must be passed to a neural
@@ -127,17 +144,23 @@ def _flatten_multidiscrete(space, x) -> np.ndarray:
@flatten.register(Tuple)
def _flatten_tuple(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
if space.is_np_flattenable:
return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
)
return tuple((flatten(s, x_part) for x_part, s in zip(x, space.spaces)))
@flatten.register(Dict)
def _flatten_dict(space, x) -> np.ndarray:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]:
if space.is_np_flattenable:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
@flatten.register(Graph)
def _flatten_graph(space, x) -> np.ndarray:
def _flatten_graph(space, x) -> GraphInstance:
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
def _graph_unflatten(space, x):
@@ -156,8 +179,13 @@ def _flatten_graph(space, x) -> np.ndarray:
return GraphInstance(nodes, edges, x.edge_links)
@flatten.register(Sequence)
def _flatten_sequence(space, x) -> tuple:
return tuple(flatten(space.feature_space, item) for item in x)
@singledispatch
def unflatten(space: Space[T], x: np.ndarray) -> T:
def unflatten(space: Space[T], x: FlatType) -> T:
"""Unflatten a data point from a space.
This reverses the transformation applied by :func:`flatten`. You must ensure
@@ -199,24 +227,38 @@ def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
@unflatten.register(Tuple)
def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple:
dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1]))
return tuple(
unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)
)
def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
if space.is_np_flattenable:
assert isinstance(
x, np.ndarray
), f"{space} is numpy-flattenable. Thus, you should only unflatten numpy arrays for this space. Got a {type(x)}"
dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1]))
return tuple(
unflatten(s, flattened)
for flattened, s in zip(list_flattened, space.spaces)
)
assert isinstance(
x, tuple
), f"{space} is not numpy-flattenable. Thus, you should only unflatten tuples for this space. Got a {type(x)}"
return tuple(unflatten(s, flattened) for flattened, s in zip(x, space.spaces))
@unflatten.register(Dict)
def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1]))
return OrderedDict(
[
(key, unflatten(s, flattened))
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
]
)
def _unflatten_dict(space: Dict, x: Union[np.ndarray, TypingDict]) -> dict:
if space.is_np_flattenable:
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1]))
return OrderedDict(
[
(key, unflatten(s, flattened))
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
]
)
assert isinstance(
x, dict
), f"{space} is not numpy-flattenable. Thus, you should only unflatten dictionary for this space. Got a {type(x)}"
return OrderedDict((key, unflatten(s, x[key])) for key, s in space.spaces.items())
@unflatten.register(Graph)
@@ -242,10 +284,18 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
return GraphInstance(nodes, edges, x.edge_links)
@singledispatch
def flatten_space(space: Space) -> Box:
"""Flatten a space into a single ``Box``.
@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
return tuple(unflatten(space.feature_space, item) for item in x)
@singledispatch
def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
"""Flatten a space into a space that is as flat as possible.
This function will attempt to flatten `space` into a single :class:`Box` space.
However, this might not be possible when `space` is an instance of :class:`Graph`,
:class:`Sequence` or a compound space that contains a :class:`Graph` or :class:`Sequence`space.
This is equivalent to :func:`flatten`, but operates on the space itself. The
result for non-graph spaces is always a `Box` with flat boundaries. While
the result for graph spaces is always a `Graph` with `node_space` being a `Box`
@@ -314,22 +364,30 @@ def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) ->
@flatten_space.register(Tuple)
def _flatten_space_tuple(space: Tuple) -> Box:
space_list = [flatten_space(s) for s in space.spaces]
return Box(
low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
)
def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
if space.is_np_flattenable:
space_list = [flatten_space(s) for s in space.spaces]
return Box(
low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
)
return Tuple(spaces=[flatten_space(s) for s in space.spaces])
@flatten_space.register(Dict)
def _flatten_space_dict(space: Dict) -> Box:
space_list = [flatten_space(s) for s in space.spaces.values()]
return Box(
low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
def _flatten_space_dict(space: Dict) -> Union[Box, Dict]:
if space.is_np_flattenable:
space_list = [flatten_space(s) for s in space.spaces.values()]
return Box(
low=np.concatenate([s.low for s in space_list]),
high=np.concatenate([s.high for s in space_list]),
dtype=np.result_type(*[s.dtype for s in space_list]),
)
return Dict(
spaces=OrderedDict(
(key, flatten_space(space)) for key, space in space.spaces.items()
)
)
@@ -341,3 +399,8 @@ def _flatten_space_graph(space: Graph) -> Graph:
if space.edge_space is not None
else None,
)
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space))

View File

@@ -15,6 +15,7 @@ from gym.spaces import (
Graph,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
@@ -55,6 +56,8 @@ from gym.spaces import (
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(5), edge_space=None),
Sequence(Discrete(4)),
Sequence(Dict({"feature": Box(0, 1, (3,))})),
Text(5),
Text(min_length=1, max_length=10, charset=string.digits),
],
@@ -114,6 +117,8 @@ def test_roundtripping(space):
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(5), edge_space=None),
Sequence(Discrete(4)),
Sequence(Dict({"feature": Box(0, 1, (3,))})),
Text(5),
Text(min_length=1, max_length=10, charset=string.digits),
],
@@ -158,6 +163,11 @@ def test_equality(space):
),
Graph(node_space=Discrete(5), edge_space=None),
),
(
Sequence(Discrete(4)),
Sequence(Dict({"feature": Box(0, 1, (3,))})),
),
(Sequence(Discrete(4)), Sequence(Discrete(4, start=-1))),
(
Text(5),
Text(min_length=1, max_length=10, charset=string.digits),
@@ -448,6 +458,14 @@ def test_space_sample_mask(space, mask, n_trials: int = 100):
),
None,
),
(Sequence(Discrete(2)), (None, np.array([0, 1], dtype=np.int8))),
(
Sequence(Discrete(2)),
(np.array([2, 3, 4], dtype=np.int8), np.array([0, 1], dtype=np.int8)),
),
(Sequence(Discrete(2)), (np.array([2, 3, 4], dtype=np.int8), None)),
(Sequence(Discrete(2)), (None, None)),
(Sequence(Discrete(2)), None),
],
)
def test_composite_space_sample_mask(space, mask):
@@ -484,6 +502,7 @@ def test_composite_space_sample_mask(space, mask):
),
Graph(node_space=Discrete(5), edge_space=None),
),
(Sequence(Discrete(4)), Sequence(Discrete(3))),
(
Text(5),
Text(min_length=1, max_length=10, charset=string.digits),
@@ -606,6 +625,8 @@ def test_box_dtype_check():
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
Graph(node_space=Discrete(5), edge_space=None),
Sequence(Discrete(4)),
Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})),
Text(5),
Text(min_length=1, max_length=10, charset=string.digits),
],
@@ -671,6 +692,8 @@ def sample_equal(sample1, sample2):
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
Graph(node_space=Discrete(5), edge_space=None),
Sequence(Discrete(4)),
Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})),
Text(5),
Text(min_length=1, max_length=10, charset=string.digits),
],
@@ -986,6 +1009,8 @@ def test_box_legacy_state_pickling():
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None),
Graph(node_space=Discrete(5), edge_space=None),
Sequence(Discrete(4)),
Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})),
Text(5),
Text(min_length=1, max_length=10, charset=string.digits),
],

View File

@@ -1,3 +1,4 @@
import re
from collections import OrderedDict
import numpy as np
@@ -8,8 +9,10 @@ from gym.spaces import (
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Tuple,
utils,
)
@@ -42,15 +45,57 @@ homogeneous_spaces = [
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
graph_spaces = [
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(5), edge_space=None),
non_homogenous_spaces = [
Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), #
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), #
Graph(node_space=Discrete(5), edge_space=None), #
Sequence(Discrete(4)), #
Sequence(Box(-10, 10, shape=(2, 2))), #
Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), #
Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), #
Dict(
a=Graph(node_space=Discrete(4), edge_space=Discrete(4)),
b=Box(-10, 10, shape=(2, 2)),
), #
Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), #
Tuple(
[
Graph(node_space=Discrete(4), edge_space=Discrete(4)),
Box(-10, 10, shape=(2, 2)),
]
), #
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), #
Dict(
a=Dict(
a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2))
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
), #
Dict(
a=Dict(
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
b=Box(-100, 100, shape=(2, 2)),
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
),
]
@pytest.mark.parametrize("space", non_homogenous_spaces)
def test_non_flattenable(space):
assert space.is_np_flattenable is False
with pytest.raises(
ValueError,
match=re.escape(
"cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
),
):
utils.flatdim(space)
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
def test_flatdim(space, flatdim):
assert space.is_np_flattenable
dim = utils.flatdim(space)
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
@@ -64,7 +109,7 @@ def test_flatten_space_boxes(space):
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flat_space_contains_flat_points(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
@@ -83,7 +128,7 @@ def test_flatten_dim(space):
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces)
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flatten_roundtripping(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
@@ -96,11 +141,14 @@ def test_flatten_roundtripping(space):
assert compare_nested(
original, roundtripped
), f"Expected sample #{i} {original} to equal {roundtripped}"
assert space.contains(roundtripped)
def compare_nested(left, right):
if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
return np.allclose(left, right)
if type(left) != type(right):
return False
elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
return left.shape == right.shape and np.allclose(left, right)
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
res = len(left) == len(right)
for ((left_key, left_value), (right_key, right_value)) in zip(
@@ -193,7 +241,7 @@ def compare_sample_types(original_space, original_sample, unflattened_sample):
)
samples = [
homogeneous_samples = [
2,
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32),
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16),
@@ -210,7 +258,7 @@ samples = [
]
expected_flattened_samples = [
expected_flattened_hom_samples = [
np.array([0, 0, 1], dtype=np.int64),
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32),
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16),
@@ -224,23 +272,297 @@ expected_flattened_samples = [
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
]
non_homogenous_samples = [
GraphInstance(
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([0, 1], dtype=int),
np.array([[[1, 2], [3, 4]]], dtype=np.float32),
np.array([[0, 1]], dtype=int),
),
GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)),
(0, 1, 2),
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
(
(np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)),
(np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)),
),
OrderedDict(
[("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[0, 1], [2, 3]], dtype=np.float32)),
]
),
((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)),
(
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
np.array([[0, 1], [2, 3]], dtype=np.float32),
),
(
GraphInstance(
nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32),
edges=np.array([0], dtype=int),
edge_links=np.array([[0, 1]]),
),
GraphInstance(
nodes=np.array(
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32
),
edges=np.array([1], dtype=int),
edge_links=np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=np.float32,
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
]
expected_flattened_non_hom_samples = [
GraphInstance(
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32),
np.array([[1, 0, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
np.array([[1, 2, 3, 4]], dtype=np.float32),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
None,
np.array([[0, 1]], dtype=int),
),
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
OrderedDict(
[
(
"a",
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
(
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.float32),
np.array([[1, 0, 0, 0]]),
np.array([[0, 1]]),
),
GraphInstance(
np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32),
np.array([[0, 1, 0, 0]]),
np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
]
@pytest.mark.parametrize(
["space", "sample", "expected_flattened_sample"],
zip(homogeneous_spaces, samples, expected_flattened_samples),
zip(
homogeneous_spaces + non_homogenous_spaces,
homogeneous_samples + non_homogenous_samples,
expected_flattened_hom_samples + expected_flattened_non_hom_samples,
),
)
def test_flatten(space, sample, expected_flattened_sample):
assert sample in space
flattened_sample = utils.flatten(space, sample)
assert flattened_sample.shape == expected_flattened_sample.shape
assert flattened_sample.dtype == expected_flattened_sample.dtype
assert np.all(flattened_sample == expected_flattened_sample)
flat_space = utils.flatten_space(space)
assert sample in space
assert flattened_sample in flat_space
if space.is_np_flattenable:
assert isinstance(flattened_sample, np.ndarray)
assert flattened_sample.shape == expected_flattened_sample.shape
assert flattened_sample.dtype == expected_flattened_sample.dtype
assert np.all(flattened_sample == expected_flattened_sample)
else:
assert not isinstance(flattened_sample, np.ndarray)
assert compare_nested(flattened_sample, expected_flattened_sample)
@pytest.mark.parametrize(
["space", "flattened_sample", "expected_sample"],
zip(homogeneous_spaces, expected_flattened_samples, samples),
zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples),
)
def test_unflatten(space, flattened_sample, expected_sample):
sample = utils.unflatten(space, flattened_sample)