diff --git a/gym/spaces/__init__.py b/gym/spaces/__init__.py index 833e0818c..0f6f4923c 100644 --- a/gym/spaces/__init__.py +++ b/gym/spaces/__init__.py @@ -14,6 +14,7 @@ from gym.spaces.discrete import Discrete from gym.spaces.graph import Graph, GraphInstance from gym.spaces.multi_binary import MultiBinary from gym.spaces.multi_discrete import MultiDiscrete +from gym.spaces.sequence import Sequence from gym.spaces.space import Space from gym.spaces.text import Text from gym.spaces.tuple import Tuple @@ -29,6 +30,7 @@ __all__ = [ "MultiDiscrete", "MultiBinary", "Tuple", + "Sequence", "Dict", "flatdim", "flatten_space", diff --git a/gym/spaces/box.py b/gym/spaces/box.py index ad2cc8ca0..f73deac5c 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -139,6 +139,11 @@ class Box(Space[np.ndarray]): """Has stricter type than gym.Space - never None.""" return self._shape + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return True + def is_bounded(self, manner: str = "both") -> bool: """Checks whether the box is bounded in some sense. 
diff --git a/gym/spaces/dict.py b/gym/spaces/dict.py index 50dc7f259..315cad1a6 100644 --- a/gym/spaces/dict.py +++ b/gym/spaces/dict.py @@ -100,6 +100,11 @@ class Dict(Space[TypingDict[str, Space]], Mapping): None, None, seed # type: ignore ) # None for shape and dtype, since it'll require special handling + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return all(space.is_np_flattenable for space in self.spaces.values()) + def seed(self, seed: Optional[Union[dict, int]] = None) -> list: """Seed the PRNG of this space and all subspaces.""" seeds = [] diff --git a/gym/spaces/discrete.py b/gym/spaces/discrete.py index 7f8f17dfa..a0f8656e6 100644 --- a/gym/spaces/discrete.py +++ b/gym/spaces/discrete.py @@ -40,6 +40,11 @@ class Discrete(Space[int]): self.start = int(start) super().__init__((), np.int64, seed) + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return True + def sample(self, mask: Optional[np.ndarray] = None) -> int: """Generates a single random sample from this space. diff --git a/gym/spaces/graph.py b/gym/spaces/graph.py index 40f3cfd87..4d18e03ca 100644 --- a/gym/spaces/graph.py +++ b/gym/spaces/graph.py @@ -18,7 +18,7 @@ class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"] nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes. (...) must adhere to the shape of the node space. - edges (np.ndarray): an (m x ...) sized array representing the features for m nodes. + edges (np.ndarray): an (m x ...) sized array representing the features for m edges. (...) must adhere to the shape of the edge space. edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects. 
@@ -68,6 +68,11 @@ class Graph(Space): super().__init__(None, None, seed) + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return False + def _generate_sample_space( self, base_space: Union[None, Box, Discrete], num: int ) -> Optional[Union[Box, MultiDiscrete]]: diff --git a/gym/spaces/multi_binary.py b/gym/spaces/multi_binary.py index 6662b9122..45bf55f69 100644 --- a/gym/spaces/multi_binary.py +++ b/gym/spaces/multi_binary.py @@ -51,6 +51,11 @@ class MultiBinary(Space[np.ndarray]): """Has stricter type than gym.Space - never None.""" return self._shape # type: ignore + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return True + def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray: """Generates a single random sample from this space. diff --git a/gym/spaces/multi_discrete.py b/gym/spaces/multi_discrete.py index 71111d4c9..74b04bb7a 100644 --- a/gym/spaces/multi_discrete.py +++ b/gym/spaces/multi_discrete.py @@ -64,6 +64,11 @@ class MultiDiscrete(Space[np.ndarray]): """Has stricter type than :class:`gym.Space` - never None.""" return self._shape # type: ignore + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return True + def sample(self, mask: Optional[SAMPLE_MASK_TYPE] = None) -> np.ndarray: """Generates a single random sample this space. 
diff --git a/gym/spaces/sequence.py b/gym/spaces/sequence.py new file mode 100644 index 000000000..399b4a969 --- /dev/null +++ b/gym/spaces/sequence.py @@ -0,0 +1,103 @@ +"""Implementation of a space that represents finite-length sequences.""" +from collections.abc import Sequence as CollectionSequence +from typing import Any, List, Optional, Tuple, Union + +import numpy as np + +from gym.spaces.space import Space +from gym.utils import seeding + + +class Sequence(Space[Tuple]): + r"""This space represents sets of finite-length sequences. + + This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong + to some space that is specified during initialization and the integer :math:`n` is not fixed. + + Example:: + >>> space = Sequence(Box(0, 1)) + >>> space.sample() + (array([0.0259352], dtype=float32),) + >>> space.sample() + (array([0.80977976], dtype=float32), array([0.80066574], dtype=float32), array([0.77165383], dtype=float32)) + """ + + def __init__( + self, + space: Space, + seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None, + ): + """Constructor of the :class:`Sequence` space. + + Args: + space: Elements in the sequences this space represents must belong to this space. + seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space. 
+ """ + self.feature_space = space + super().__init__( + None, None, seed # type: ignore + ) # None for shape and dtype, since it'll require special handling + + def seed(self, seed: Optional[int] = None) -> list: + """Seed the PRNG of this space and the feature space.""" + seeds = super().seed(seed) + seeds += self.feature_space.seed(seed) + return seeds + + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return False + + def sample( + self, mask: Optional[Tuple[Optional[np.ndarray], Any]] = None + ) -> Tuple[Any]: + """Generates a single random sample from this space. + + Args: + mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence. + If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask` + is either `None` if you do not want to specify any restrictions on the length of the sampled sequence (then, the + length will be randomly drawn from a geometric distribution), or a `np.ndarray` of integers, in which case the length of + the sampled sequence is randomly drawn from this array. The second element of the tuple, `sample_mask`, + specifies a mask that is applied when sampling elements from the base space. + + Returns: + A tuple of random length with random samples of elements from the :attr:`feature_space`. 
+ """ + if mask is not None: + length_mask, feature_mask = mask + else: + length_mask = None + feature_mask = None + if length_mask is not None: + length = self.np_random.choice(length_mask) + else: + length = self.np_random.geometric(0.25) + + return tuple( + self.feature_space.sample(mask=feature_mask) for _ in range(length) + ) + + def contains(self, x) -> bool: + """Return boolean specifying if x is a valid member of this space.""" + return isinstance(x, CollectionSequence) and all( + self.feature_space.contains(item) for item in x + ) + + def __repr__(self) -> str: + """Gives a string representation of this space.""" + return f"Sequence({self.feature_space})" + + def to_jsonable(self, sample_n: list) -> list: + """Convert a batch of samples from this space to a JSONable data type.""" + # serialize as dict-repr of vectors + return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n] + + def from_jsonable(self, sample_n: List[List[Any]]) -> list: + """Convert a JSONable data type to a batch of samples from this space.""" + return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n] + + def __eq__(self, other) -> bool: + """Check whether ``other`` is equivalent to this instance.""" + return isinstance(other, Sequence) and self.feature_space == other.feature_space diff --git a/gym/spaces/space.py b/gym/spaces/space.py index 5d7dea6f9..157cc5446 100644 --- a/gym/spaces/space.py +++ b/gym/spaces/space.py @@ -82,6 +82,11 @@ class Space(Generic[T_cov]): """Return the shape of the space as an immutable property.""" return self._shape + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + raise NotImplementedError + def sample(self, mask: Optional[Any] = None) -> T_cov: """Randomly sample an element of this space. 
diff --git a/gym/spaces/tuple.py b/gym/spaces/tuple.py index 10b4344ef..8a44a649a 100644 --- a/gym/spaces/tuple.py +++ b/gym/spaces/tuple.py @@ -40,6 +40,11 @@ class Tuple(Space[tuple], Sequence): ), "Elements of the tuple must be instances of gym.Space" super().__init__(None, None, seed) # type: ignore + @property + def is_np_flattenable(self): + """Checks whether this space can be flattened to a :class:`spaces.Box`.""" + return all(space.is_np_flattenable for space in self.spaces) + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list: """Seed the PRNG of this space and all subspaces.""" seeds = [] diff --git a/gym/spaces/utils.py b/gym/spaces/utils.py index 6ddccb2dc..55b2935e1 100644 --- a/gym/spaces/utils.py +++ b/gym/spaces/utils.py @@ -6,6 +6,7 @@ These functions mostly take care of flattening and unflattening elements of spac import operator as op from collections import OrderedDict from functools import reduce, singledispatch +from typing import Dict as TypingDict from typing import TypeVar, Union, cast import numpy as np @@ -18,6 +19,7 @@ from gym.spaces import ( GraphInstance, MultiBinary, MultiDiscrete, + Sequence, Space, Tuple, ) @@ -42,7 +44,13 @@ def flatdim(space: Space) -> int: Raises: NotImplementedError: if the space is not defined in ``gym.spaces``. 
+ ValueError: if the space cannot be flattened into a :class:`Box` """ + if not space.is_np_flattenable: + raise ValueError( + f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace" + ) + raise NotImplementedError(f"Unknown space: `{space}`") @@ -64,19 +72,28 @@ def _flatdim_multidiscrete(space: MultiDiscrete) -> int: @flatdim.register(Tuple) def _flatdim_tuple(space: Tuple) -> int: - return sum(flatdim(s) for s in space.spaces) + if space.is_np_flattenable: + return sum(flatdim(s) for s in space.spaces) + raise ValueError( + f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace" + ) @flatdim.register(Dict) def _flatdim_dict(space: Dict) -> int: - return sum(flatdim(s) for s in space.spaces.values()) + if space.is_np_flattenable: + return sum(flatdim(s) for s in space.spaces.values()) + raise ValueError( + f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace" + ) T = TypeVar("T") +FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance] @singledispatch -def flatten(space: Space[T], x: T) -> np.ndarray: +def flatten(space: Space[T], x: T) -> FlatType: """Flatten a data point from a space. This is useful when e.g. 
points from spaces must be passed to a neural @@ -127,17 +144,23 @@ def _flatten_multidiscrete(space, x) -> np.ndarray: @flatten.register(Tuple) -def _flatten_tuple(space, x) -> np.ndarray: - return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)]) +def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]: + if space.is_np_flattenable: + return np.concatenate( + [flatten(s, x_part) for x_part, s in zip(x, space.spaces)] + ) + return tuple((flatten(s, x_part) for x_part, s in zip(x, space.spaces))) @flatten.register(Dict) -def _flatten_dict(space, x) -> np.ndarray: - return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) +def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]: + if space.is_np_flattenable: + return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) + return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items()) @flatten.register(Graph) -def _flatten_graph(space, x) -> np.ndarray: +def _flatten_graph(space, x) -> GraphInstance: """We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring.""" def _graph_unflatten(space, x): @@ -156,8 +179,13 @@ def _flatten_graph(space, x) -> np.ndarray: return GraphInstance(nodes, edges, x.edge_links) +@flatten.register(Sequence) +def _flatten_sequence(space, x) -> tuple: + return tuple(flatten(space.feature_space, item) for item in x) + + @singledispatch -def unflatten(space: Space[T], x: np.ndarray) -> T: +def unflatten(space: Space[T], x: FlatType) -> T: """Unflatten a data point from a space. This reverses the transformation applied by :func:`flatten`. 
You must ensure @@ -199,24 +227,38 @@ def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray: @unflatten.register(Tuple) -def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple: - dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_) - list_flattened = np.split(x, np.cumsum(dims[:-1])) - return tuple( - unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces) - ) +def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple: + if space.is_np_flattenable: + assert isinstance( + x, np.ndarray + ), f"{space} is numpy-flattenable. Thus, you should only unflatten numpy arrays for this space. Got a {type(x)}" + dims = np.asarray([flatdim(s) for s in space.spaces], dtype=np.int_) + list_flattened = np.split(x, np.cumsum(dims[:-1])) + return tuple( + unflatten(s, flattened) + for flattened, s in zip(list_flattened, space.spaces) + ) + assert isinstance( + x, tuple + ), f"{space} is not numpy-flattenable. Thus, you should only unflatten tuples for this space. Got a {type(x)}" + return tuple(unflatten(s, flattened) for flattened, s in zip(x, space.spaces)) @unflatten.register(Dict) -def _unflatten_dict(space: Dict, x: np.ndarray) -> dict: - dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_) - list_flattened = np.split(x, np.cumsum(dims[:-1])) - return OrderedDict( - [ - (key, unflatten(s, flattened)) - for flattened, (key, s) in zip(list_flattened, space.spaces.items()) - ] - ) +def _unflatten_dict(space: Dict, x: Union[np.ndarray, TypingDict]) -> dict: + if space.is_np_flattenable: + dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_) + list_flattened = np.split(x, np.cumsum(dims[:-1])) + return OrderedDict( + [ + (key, unflatten(s, flattened)) + for flattened, (key, s) in zip(list_flattened, space.spaces.items()) + ] + ) + assert isinstance( + x, dict + ), f"{space} is not numpy-flattenable. 
Thus, you should only unflatten dictionary for this space. Got a {type(x)}" + return OrderedDict((key, unflatten(s, x[key])) for key, s in space.spaces.items()) @unflatten.register(Graph) @@ -242,10 +284,18 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance: return GraphInstance(nodes, edges, x.edge_links) -@singledispatch -def flatten_space(space: Space) -> Box: - """Flatten a space into a single ``Box``. +@unflatten.register(Sequence) +def _unflatten_sequence(space: Sequence, x: tuple) -> tuple: + return tuple(unflatten(space.feature_space, item) for item in x) + +@singledispatch +def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]: + """Flatten a space into a space that is as flat as possible. + + This function will attempt to flatten `space` into a single :class:`Box` space. + However, this might not be possible when `space` is an instance of :class:`Graph`, + :class:`Sequence` or a compound space that contains a :class:`Graph` or :class:`Sequence` space. This is equivalent to :func:`flatten`, but operates on the space itself. The result for non-graph spaces is always a `Box` with flat boundaries. 
While the result for graph spaces is always a `Graph` with `node_space` being a `Box` @@ -314,22 +364,30 @@ def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> @flatten_space.register(Tuple) -def _flatten_space_tuple(space: Tuple) -> Box: - space_list = [flatten_space(s) for s in space.spaces] - return Box( - low=np.concatenate([s.low for s in space_list]), - high=np.concatenate([s.high for s in space_list]), - dtype=np.result_type(*[s.dtype for s in space_list]), - ) +def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]: + if space.is_np_flattenable: + space_list = [flatten_space(s) for s in space.spaces] + return Box( + low=np.concatenate([s.low for s in space_list]), + high=np.concatenate([s.high for s in space_list]), + dtype=np.result_type(*[s.dtype for s in space_list]), + ) + return Tuple(spaces=[flatten_space(s) for s in space.spaces]) @flatten_space.register(Dict) -def _flatten_space_dict(space: Dict) -> Box: - space_list = [flatten_space(s) for s in space.spaces.values()] - return Box( - low=np.concatenate([s.low for s in space_list]), - high=np.concatenate([s.high for s in space_list]), - dtype=np.result_type(*[s.dtype for s in space_list]), +def _flatten_space_dict(space: Dict) -> Union[Box, Dict]: + if space.is_np_flattenable: + space_list = [flatten_space(s) for s in space.spaces.values()] + return Box( + low=np.concatenate([s.low for s in space_list]), + high=np.concatenate([s.high for s in space_list]), + dtype=np.result_type(*[s.dtype for s in space_list]), + ) + return Dict( + spaces=OrderedDict( + (key, flatten_space(space)) for key, space in space.spaces.items() + ) ) @@ -341,3 +399,8 @@ def _flatten_space_graph(space: Graph) -> Graph: if space.edge_space is not None else None, ) + + +@flatten_space.register(Sequence) +def _flatten_space_sequence(space: Sequence) -> Sequence: + return Sequence(flatten_space(space.feature_space)) diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 
b58ab6c79..47eb2834b 100644 --- a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -15,6 +15,7 @@ from gym.spaces import ( Graph, MultiBinary, MultiDiscrete, + Sequence, Space, Text, Tuple, @@ -55,6 +56,8 @@ from gym.spaces import ( Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)), Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))), Graph(node_space=Discrete(5), edge_space=None), + Sequence(Discrete(4)), + Sequence(Dict({"feature": Box(0, 1, (3,))})), Text(5), Text(min_length=1, max_length=10, charset=string.digits), ], @@ -114,6 +117,8 @@ def test_roundtripping(space): Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)), Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))), Graph(node_space=Discrete(5), edge_space=None), + Sequence(Discrete(4)), + Sequence(Dict({"feature": Box(0, 1, (3,))})), Text(5), Text(min_length=1, max_length=10, charset=string.digits), ], @@ -158,6 +163,11 @@ def test_equality(space): ), Graph(node_space=Discrete(5), edge_space=None), ), + ( + Sequence(Discrete(4)), + Sequence(Dict({"feature": Box(0, 1, (3,))})), + ), + (Sequence(Discrete(4)), Sequence(Discrete(4, start=-1))), ( Text(5), Text(min_length=1, max_length=10, charset=string.digits), @@ -448,6 +458,14 @@ def test_space_sample_mask(space, mask, n_trials: int = 100): ), None, ), + (Sequence(Discrete(2)), (None, np.array([0, 1], dtype=np.int8))), + ( + Sequence(Discrete(2)), + (np.array([2, 3, 4], dtype=np.int8), np.array([0, 1], dtype=np.int8)), + ), + (Sequence(Discrete(2)), (np.array([2, 3, 4], dtype=np.int8), None)), + (Sequence(Discrete(2)), (None, None)), + (Sequence(Discrete(2)), None), ], ) def test_composite_space_sample_mask(space, mask): @@ -484,6 +502,7 @@ def test_composite_space_sample_mask(space, mask): ), Graph(node_space=Discrete(5), edge_space=None), ), + (Sequence(Discrete(4)), Sequence(Discrete(3))), ( Text(5), Text(min_length=1, 
max_length=10, charset=string.digits), @@ -606,6 +625,8 @@ def test_box_dtype_check(): Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))), Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None), Graph(node_space=Discrete(5), edge_space=None), + Sequence(Discrete(4)), + Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})), Text(5), Text(min_length=1, max_length=10, charset=string.digits), ], @@ -671,6 +692,8 @@ def sample_equal(sample1, sample2): Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))), Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None), Graph(node_space=Discrete(5), edge_space=None), + Sequence(Discrete(4)), + Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})), Text(5), Text(min_length=1, max_length=10, charset=string.digits), ], @@ -986,6 +1009,8 @@ def test_box_legacy_state_pickling(): Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))), Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=None), Graph(node_space=Discrete(5), edge_space=None), + Sequence(Discrete(4)), + Sequence(Dict({"a": Box(0, 1), "b": Discrete(3)})), Text(5), Text(min_length=1, max_length=10, charset=string.digits), ], diff --git a/tests/spaces/test_utils.py b/tests/spaces/test_utils.py index 9443e711a..3def3a257 100644 --- a/tests/spaces/test_utils.py +++ b/tests/spaces/test_utils.py @@ -1,3 +1,4 @@ +import re from collections import OrderedDict import numpy as np @@ -8,8 +9,10 @@ from gym.spaces import ( Dict, Discrete, Graph, + GraphInstance, MultiBinary, MultiDiscrete, + Sequence, Tuple, utils, ) @@ -42,15 +45,57 @@ homogeneous_spaces = [ flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8] -graph_spaces = [ - Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)), - Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))), - Graph(node_space=Discrete(5), edge_space=None), +non_homogenous_spaces = [ 
+ Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), # + Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), # + Graph(node_space=Discrete(5), edge_space=None), # + Sequence(Discrete(4)), # + Sequence(Box(-10, 10, shape=(2, 2))), # + Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), # + Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), # + Dict( + a=Graph(node_space=Discrete(4), edge_space=Discrete(4)), + b=Box(-10, 10, shape=(2, 2)), + ), # + Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), # + Tuple( + [ + Graph(node_space=Discrete(4), edge_space=Discrete(4)), + Box(-10, 10, shape=(2, 2)), + ] + ), # + Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), # + Dict( + a=Dict( + a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2)) + ), + b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]), + ), # + Dict( + a=Dict( + a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None), + b=Box(-100, 100, shape=(2, 2)), + ), + b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]), + ), ] +@pytest.mark.parametrize("space", non_homogenous_spaces) +def test_non_flattenable(space): + assert space.is_np_flattenable is False + with pytest.raises( + ValueError, + match=re.escape( + "cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace" + ), + ): + utils.flatdim(space) + + @pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims)) def test_flatdim(space, flatdim): + assert space.is_np_flattenable dim = utils.flatdim(space) assert dim == flatdim, f"Expected {dim} to equal {flatdim}" @@ -64,7 +109,7 @@ def test_flatten_space_boxes(space): assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}" -@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces) +@pytest.mark.parametrize("space", homogeneous_spaces + 
non_homogenous_spaces) def test_flat_space_contains_flat_points(space): some_samples = [space.sample() for _ in range(10)] flattened_samples = [utils.flatten(space, sample) for sample in some_samples] @@ -83,7 +128,7 @@ def test_flatten_dim(space): assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}" -@pytest.mark.parametrize("space", homogeneous_spaces + graph_spaces) +@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces) def test_flatten_roundtripping(space): some_samples = [space.sample() for _ in range(10)] flattened_samples = [utils.flatten(space, sample) for sample in some_samples] @@ -96,11 +141,14 @@ def test_flatten_roundtripping(space): assert compare_nested( original, roundtripped ), f"Expected sample #{i} {original} to equal {roundtripped}" + assert space.contains(roundtripped) def compare_nested(left, right): - if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): - return np.allclose(left, right) + if type(left) != type(right): + return False + elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + return left.shape == right.shape and np.allclose(left, right) elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict): res = len(left) == len(right) for ((left_key, left_value), (right_key, right_value)) in zip( @@ -193,7 +241,7 @@ def compare_sample_types(original_space, original_sample, unflattened_sample): ) -samples = [ +homogeneous_samples = [ 2, np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32), np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16), @@ -210,7 +258,7 @@ samples = [ ] -expected_flattened_samples = [ +expected_flattened_hom_samples = [ np.array([0, 0, 1], dtype=np.int64), np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32), np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16), @@ -224,23 +272,297 @@ expected_flattened_samples = [ np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64), ] +non_homogenous_samples = [ + GraphInstance( + np.array([[[1, 2], [3, 4]], 
[[5, 6], [7, 8]]], dtype=np.float32), + np.array( + [ + 0, + ], + dtype=int, + ), + np.array([[0, 1]], dtype=int), + ), + GraphInstance( + np.array([0, 1], dtype=int), + np.array([[[1, 2], [3, 4]]], dtype=np.float32), + np.array([[0, 1]], dtype=int), + ), + GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)), + (0, 1, 2), + ( + np.array([[0, 1], [2, 3]], dtype=np.float32), + np.array([[4, 5], [6, 7]], dtype=np.float32), + ), + ( + (np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)), + (np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)), + ), + OrderedDict( + [("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))] + ), + OrderedDict( + [ + ( + "a", + GraphInstance( + np.array([1, 2], dtype=int), + np.array( + [ + 0, + ], + dtype=int, + ), + np.array([[0, 1]], dtype=int), + ), + ), + ("b", np.array([[0, 1], [2, 3]], dtype=np.float32)), + ] + ), + ((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)), + ( + GraphInstance( + np.array([1, 2], dtype=int), + np.array( + [ + 0, + ], + dtype=int, + ), + np.array([[0, 1]], dtype=int), + ), + np.array([[0, 1], [2, 3]], dtype=np.float32), + ), + ( + GraphInstance( + nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32), + edges=np.array([0], dtype=int), + edge_links=np.array([[0, 1]]), + ), + GraphInstance( + nodes=np.array( + [[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32 + ), + edges=np.array([1], dtype=int), + edge_links=np.array([[0, 1]]), + ), + ), + OrderedDict( + [ + ( + "a", + OrderedDict( + [ + ( + "a", + ( + np.array([[0, 1], [2, 3]], dtype=np.float32), + np.array([[4, 5], [6, 7]], dtype=np.float32), + ), + ), + ("b", np.array([[8, 9], [10, 11]], dtype=np.float32)), + ] + ), + ), + ( + "b", + ( + np.array([12, 13], dtype=np.float32), + np.array([14, 15], dtype=np.float32), + ), + ), + ] + ), + OrderedDict( + [ + ( + "a", + OrderedDict( + [ + ( + "a", + GraphInstance( + np.array( + [[[1, 2], 
[3, 4]], [[5, 6], [7, 8]]], + dtype=np.float32, + ), + None, + np.array([[0, 1]], dtype=int), + ), + ), + ("b", np.array([[8, 9], [10, 11]], dtype=np.float32)), + ] + ), + ), + ( + "b", + ( + np.array([12, 13], dtype=np.float32), + np.array([14, 15], dtype=np.float32), + ), + ), + ] + ), +] + + +expected_flattened_non_hom_samples = [ + GraphInstance( + np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32), + np.array([[1, 0, 0, 0, 0]], dtype=int), + np.array([[0, 1]], dtype=int), + ), + GraphInstance( + np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int), + np.array([[1, 2, 3, 4]], dtype=np.float32), + np.array([[0, 1]], dtype=int), + ), + GraphInstance( + np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int), + None, + np.array([[0, 1]], dtype=int), + ), + ( + np.array([1, 0, 0, 0], dtype=int), + np.array([0, 1, 0, 0], dtype=int), + np.array([0, 0, 1, 0], dtype=int), + ), + ( + np.array([0, 1, 2, 3], dtype=np.float32), + np.array([4, 5, 6, 7], dtype=np.float32), + ), + ( + np.array([0, 1, 2, 3], dtype=np.float32), + np.array([4, 5, 6, 7], dtype=np.float32), + ), + OrderedDict( + [ + ( + "a", + ( + np.array([1, 0, 0, 0], dtype=int), + np.array([0, 1, 0, 0], dtype=int), + np.array([0, 0, 1, 0], dtype=int), + ), + ), + ("b", np.array([0, 1, 2, 3], dtype=np.float32)), + ] + ), + OrderedDict( + [ + ( + "a", + GraphInstance( + np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int), + np.array([[1, 0, 0, 0]], dtype=int), + np.array([[0, 1]], dtype=int), + ), + ), + ("b", np.array([0, 1, 2, 3], dtype=np.float32)), + ] + ), + ( + ( + np.array([1, 0, 0, 0], dtype=int), + np.array([0, 1, 0, 0], dtype=int), + np.array([0, 0, 1, 0], dtype=int), + ), + np.array([0, 1, 2, 3], dtype=np.float32), + ), + ( + GraphInstance( + np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32), + np.array([[1, 0, 0, 0]], dtype=int), + np.array([[0, 1]], dtype=int), + ), + np.array([0, 1, 2, 3], dtype=np.float32), + ), + ( + GraphInstance( + np.array([[0, 1, 2, 3], [4, 5, 6, 7]], 
dtype=np.float32), + np.array([[1, 0, 0, 0]]), + np.array([[0, 1]]), + ), + GraphInstance( + np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32), + np.array([[0, 1, 0, 0]]), + np.array([[0, 1]]), + ), + ), + OrderedDict( + [ + ( + "a", + OrderedDict( + [ + ( + "a", + ( + np.array([0, 1, 2, 3], dtype=np.float32), + np.array([4, 5, 6, 7], dtype=np.float32), + ), + ), + ("b", np.array([8, 9, 10, 11], dtype=np.float32)), + ] + ), + ), + ("b", (np.array([12, 13, 14, 15], dtype=np.float32))), + ] + ), + OrderedDict( + [ + ( + "a", + OrderedDict( + [ + ( + "a", + GraphInstance( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32 + ), + None, + np.array([[0, 1]], dtype=int), + ), + ), + ("b", np.array([8, 9, 10, 11], dtype=np.float32)), + ] + ), + ), + ("b", (np.array([12, 13, 14, 15], dtype=np.float32))), + ] + ), +] + @pytest.mark.parametrize( ["space", "sample", "expected_flattened_sample"], - zip(homogeneous_spaces, samples, expected_flattened_samples), + zip( + homogeneous_spaces + non_homogenous_spaces, + homogeneous_samples + non_homogenous_samples, + expected_flattened_hom_samples + expected_flattened_non_hom_samples, + ), ) def test_flatten(space, sample, expected_flattened_sample): - assert sample in space - flattened_sample = utils.flatten(space, sample) - assert flattened_sample.shape == expected_flattened_sample.shape - assert flattened_sample.dtype == expected_flattened_sample.dtype - assert np.all(flattened_sample == expected_flattened_sample) + flat_space = utils.flatten_space(space) + + assert sample in space + assert flattened_sample in flat_space + + if space.is_np_flattenable: + assert isinstance(flattened_sample, np.ndarray) + assert flattened_sample.shape == expected_flattened_sample.shape + assert flattened_sample.dtype == expected_flattened_sample.dtype + assert np.all(flattened_sample == expected_flattened_sample) + else: + assert not isinstance(flattened_sample, np.ndarray) + assert compare_nested(flattened_sample, 
expected_flattened_sample) @pytest.mark.parametrize( ["space", "flattened_sample", "expected_sample"], - zip(homogeneous_spaces, expected_flattened_samples, samples), + zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples), ) def test_unflatten(space, flattened_sample, expected_sample): sample = utils.unflatten(space, flattened_sample)