Update Composite spaces with full coverage (#3047)

* Update composite space and tests

* Pre-commit

* pyright

* Fix pyright

* retrigger actions

* Code review by Arjun

* Code review by Arjun

* Code review by Omar
This commit is contained in:
Mark Towers
2022-09-03 23:39:23 +01:00
committed by GitHub
parent 8e74fe3b62
commit f39747d6a2
14 changed files with 741 additions and 151 deletions

View File

@@ -3,7 +3,9 @@ from collections import OrderedDict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any from typing import Any
from typing import Dict as TypingDict from typing import Dict as TypingDict
from typing import Optional, Union from typing import List, Optional
from typing import Sequence as TypingSequence
from typing import Tuple, Union
import numpy as np import numpy as np
@@ -51,7 +53,12 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def __init__( def __init__(
self, self,
spaces: Optional[TypingDict[str, Space]] = None, spaces: Optional[
Union[
TypingDict[str, Space],
TypingSequence[Tuple[str, Space]],
]
] = None,
seed: Optional[Union[dict, int, np.random.Generator]] = None, seed: Optional[Union[dict, int, np.random.Generator]] = None,
**spaces_kwargs: Space, **spaces_kwargs: Space,
): ):
@@ -74,12 +81,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space. seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above. **spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
""" """
assert (spaces is None) or ( # Convert the spaces into an OrderedDict
not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
if spaces is None:
spaces = spaces_kwargs
if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict): if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict):
try: try:
spaces = OrderedDict(sorted(spaces.items())) spaces = OrderedDict(sorted(spaces.items()))
@@ -87,16 +89,30 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
# Incomparable types (e.g. `int` vs. `str`, or user-defined types) found. # Incomparable types (e.g. `int` vs. `str`, or user-defined types) found.
# The keys remain in the insertion order. # The keys remain in the insertion order.
spaces = OrderedDict(spaces.items()) spaces = OrderedDict(spaces.items())
if isinstance(spaces, Sequence): elif isinstance(spaces, Sequence):
spaces = OrderedDict(spaces) spaces = OrderedDict(spaces)
elif spaces is None:
spaces = OrderedDict()
else:
assert isinstance(
spaces, OrderedDict
), f"Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: {type(spaces)}"
assert isinstance(spaces, OrderedDict), "spaces must be a dictionary" # Add kwargs to spaces to allow both dictionary and keywords to be used
for key, space in spaces_kwargs.items():
if key not in spaces:
spaces[key] = space
else:
raise ValueError(
f"Dict space keyword '{key}' already exists in the spaces dictionary."
)
self.spaces = spaces self.spaces = spaces
for space in spaces.values(): for key, space in self.spaces.items():
assert isinstance( assert isinstance(
space, Space space, Space
), "Values of the dict should be instances of gym.Space" ), f"Dict space element is not an instance of Space: key='{key}', space={space}"
super().__init__( super().__init__(
None, None, seed # type: ignore None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling ) # None for shape and dtype, since it'll require special handling
@@ -120,27 +136,26 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
seeds = [] seeds = []
if isinstance(seed, dict): if isinstance(seed, dict):
for key, seed_key in zip(self.spaces, seed): assert (
assert key == seed_key, print( seed.keys() == self.spaces.keys()
"Key value", ), f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
seed_key, for key in seed.keys():
"in passed seed dict did not match key value", seeds += self.spaces[key].seed(seed[key])
key,
"in spaces Dict.",
)
seeds += self.spaces[key].seed(seed[seed_key])
elif isinstance(seed, int): elif isinstance(seed, int):
seeds = super().seed(seed) seeds = super().seed(seed)
# Drawing the subseeds from the full `np.int32` range makes the probability of two subspaces receiving the same seed extremely low, even for a large number of subspaces
subseeds = self.np_random.integers( subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces) np.iinfo(np.int32).max, size=len(self.spaces)
) )
for subspace, subseed in zip(self.spaces.values(), subseeds): for subspace, subseed in zip(self.spaces.values(), subseeds):
seeds.append(subspace.seed(int(subseed))[0]) seeds += subspace.seed(int(subseed))
elif seed is None: elif seed is None:
for space in self.spaces.values(): for space in self.spaces.values():
seeds += space.seed(seed) seeds += space.seed(None)
else: else:
raise TypeError("Passed seed not of an expected type: dict or int or None") raise TypeError(
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
)
return seeds return seeds
@@ -170,14 +185,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def contains(self, x) -> bool: def contains(self, x) -> bool:
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
if not isinstance(x, dict) or len(x) != len(self.spaces): if isinstance(x, dict) and x.keys() == self.spaces.keys():
return all(x[key] in self.spaces[key] for key in self.spaces.keys())
return False return False
for k, space in self.spaces.items():
if k not in x:
return False
if not space.contains(x[k]):
return False
return True
def __getitem__(self, key: str) -> Space: def __getitem__(self, key: str) -> Space:
"""Get the space that is associated to `key`.""" """Get the space that is associated to `key`."""
@@ -185,6 +195,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def __setitem__(self, key: str, value: Space): def __setitem__(self, key: str, value: Space):
"""Set the space that is associated to `key`.""" """Set the space that is associated to `key`."""
assert isinstance(
value, Space
), f"Trying to set {key} to Dict space with value that is not a gym space, actual type: {type(value)}"
self.spaces[key] = value self.spaces[key] = value
def __iter__(self): def __iter__(self):
@@ -217,16 +230,16 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
for key, space in self.spaces.items() for key, space in self.spaces.items()
} }
def from_jsonable(self, sample_n: TypingDict[str, list]) -> list: def from_jsonable(self, sample_n: TypingDict[str, list]) -> List[dict]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
dict_of_list: TypingDict[str, list] = {} dict_of_list: TypingDict[str, list] = {
for key, space in self.spaces.items(): key: space.from_jsonable(sample_n[key])
dict_of_list[key] = space.from_jsonable(sample_n[key]) for key, space in self.spaces.items()
ret = [] }
n_elements = len(next(iter(dict_of_list.values()))) n_elements = len(next(iter(dict_of_list.values())))
for i in range(n_elements): result = [
entry = {} OrderedDict({key: value[n] for key, value in dict_of_list.items()})
for key, value in dict_of_list.items(): for n in range(n_elements)
entry[key] = value[i] ]
ret.append(entry) return result
return ret

View File

@@ -1,31 +1,27 @@
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space.""" """Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from collections import namedtuple
from typing import NamedTuple, Optional, Sequence, Tuple, Union from typing import NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
from gym.logger import warn
from gym.spaces.box import Box from gym.spaces.box import Box
from gym.spaces.discrete import Discrete from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import MultiDiscrete from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.space import Space from gym.spaces.space import Space
class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"])): class GraphInstance(NamedTuple):
r"""Returns a NamedTuple representing a graph object. """A Graph space instance.
Args: * nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes, (...) must adhere to the shape of the node space.
nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes. * edges (Optional[np.ndarray]): an (m x ...) sized array representing the features for m edges, (...) must adhere to the shape of the edge space.
(...) must adhere to the shape of the node space. * edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the two nodes that each edge connects.
edges (np.ndarray): an (m x ...) sized array representing the features for m edges.
(...) must adhere to the shape of the edge space.
edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.
Returns:
A NamedTuple representing a graph with `.nodes`, `.edges`, and `.edge_links`.
""" """
nodes: np.ndarray
edges: Optional[np.ndarray]
edge_links: Optional[np.ndarray]
class Graph(Space): class Graph(Space):
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`. r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
@@ -89,7 +85,7 @@ class Graph(Space):
elif isinstance(base_space, Discrete): elif isinstance(base_space, Discrete):
return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random) return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
else: else:
raise AssertionError( raise TypeError(
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}." f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
) )
@@ -103,7 +99,7 @@ class Graph(Space):
] = None, ] = None,
num_nodes: int = 10, num_nodes: int = 10,
num_edges: Optional[int] = None, num_edges: Optional[int] = None,
) -> NamedTuple: ) -> GraphInstance:
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph. """Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
Args: Args:
@@ -132,12 +128,17 @@ class Graph(Space):
num_edges = self.np_random.integers(num_nodes * (num_nodes - 1)) num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
else: else:
num_edges = 0 num_edges = 0
if edge_space_mask is not None: if edge_space_mask is not None:
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges)) edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
else: else:
if self.edge_space is None:
warn(
f"The number of edges is set ({num_edges}) but the edge space is None."
)
assert ( assert (
num_edges >= 0 num_edges >= 0
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}" ), f"Expects the number of edges to be greater than 0, actual value: {num_edges}"
assert num_edges is not None assert num_edges is not None
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes) sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
@@ -160,38 +161,31 @@ class Graph(Space):
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links) return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
def contains(self, x: GraphInstance) -> bool: def contains(self, x: GraphInstance) -> bool:
"""Return boolean specifying if x is a valid member of this space. """Return boolean specifying if x is a valid member of this space."""
if isinstance(x, GraphInstance):
Returns False when: # Checks the nodes
- any node in nodes is not contained in Graph.node_space if isinstance(x.nodes, np.ndarray):
- edge_links is not of dtype int if all(node in self.node_space for node in x.nodes):
- len(edge_links) != len(edges) # Check the edges and edge links which are optional
- has edges but Graph.edge_space is None if isinstance(x.edges, np.ndarray) and isinstance(
- edge_links has index less than 0 x.edge_links, np.ndarray
- edge_links has index more than number of nodes ):
- any edge in edges is not contained in Graph.edge_space assert x.edges is not None
""" assert x.edge_links is not None
if not isinstance(x, GraphInstance): if self.edge_space is not None:
return False if all(edge in self.edge_space for edge in x.edges):
if x.edges is not None: if np.issubdtype(x.edge_links.dtype, np.integer):
if not np.issubdtype(x.edge_links.dtype, np.integer): if x.edge_links.shape == (len(x.edges), 2):
return False if np.all(
if x.edge_links.shape[-1] != 2: np.logical_and(
return False x.edge_links >= 0,
if self.edge_space is None: x.edge_links < len(x.nodes),
return False )
if x.edge_links.min() < 0: ):
return False
if x.edge_links.max() >= len(x.nodes):
return False
if len(x.edges) != len(x.edge_links):
return False
if any(edge not in self.edge_space for edge in x.edges):
return False
if any(node not in self.node_space for node in x.nodes):
return False
return True return True
else:
return x.edges is None and x.edge_links is None
return False
def __repr__(self) -> str: def __repr__(self) -> str:
"""A string representation of this space. """A string representation of this space.

View File

@@ -4,8 +4,8 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np import numpy as np
import gym
from gym.spaces.space import Space from gym.spaces.space import Space
from gym.utils import seeding
class Sequence(Space[Tuple]): class Sequence(Space[Tuple]):
@@ -25,7 +25,7 @@ class Sequence(Space[Tuple]):
def __init__( def __init__(
self, self,
space: Space, space: Space,
seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None, seed: Optional[Union[int, np.random.Generator]] = None,
): ):
"""Constructor of the :class:`Sequence` space. """Constructor of the :class:`Sequence` space.
@@ -33,6 +33,9 @@ class Sequence(Space[Tuple]):
space: Elements in the sequences this space represent must belong to this space. space: Elements in the sequences this space represent must belong to this space.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space. seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
""" """
assert isinstance(
space, gym.Space
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
self.feature_space = space self.feature_space = space
super().__init__( super().__init__(
None, None, seed # type: ignore None, None, seed # type: ignore
@@ -50,17 +53,20 @@ class Sequence(Space[Tuple]):
return False return False
def sample( def sample(
self, mask: Optional[Tuple[Optional[np.ndarray], Any]] = None self,
mask: Optional[Tuple[Optional[Union[np.ndarray, int]], Optional[Any]]] = None,
) -> Tuple[Any]: ) -> Tuple[Any]:
"""Generates a single random sample from this space. """Generates a single random sample from this space.
Args: Args:
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence. mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask` If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask`
is either `None` if you do not want to specify any restrictions on the length of the sampled sequence (then, the is
length will be randomly drawn from a geometric distribution), or a `np.ndarray` of integers, in which case the length of - `None` The length will be randomly drawn from a geometric distribution
the sampled sequence is randomly drawn from this array. The second element of the tuple, `sample` mask - `np.ndarray` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
specifies a mask that is applied when sampling elements from the base space. - `int` for a fixed length sample
The second element of the mask tuple, `sample_mask`, specifies a mask that is applied when
sampling elements from the base space. The mask is applied for each feature space sample.
Returns: Returns:
A tuple of random length with random samples of elements from the :attr:`feature_space`. A tuple of random length with random samples of elements from the :attr:`feature_space`.
@@ -68,11 +74,28 @@ class Sequence(Space[Tuple]):
if mask is not None: if mask is not None:
length_mask, feature_mask = mask length_mask, feature_mask = mask
else: else:
length_mask = None length_mask, feature_mask = None, None
feature_mask = None
if length_mask is not None: if length_mask is not None:
if np.issubdtype(type(length_mask), np.integer):
assert (
0 <= length_mask
), f"Expects the length mask to be greater than or equal to zero, actual value: {length_mask}"
length = length_mask
elif isinstance(length_mask, np.ndarray):
assert (
len(length_mask.shape) == 1
), f"Expects the shape of the length mask to be 1-dimensional, actual shape: {length_mask.shape}"
assert np.all(
0 <= length_mask
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
length = self.np_random.choice(length_mask) length = self.np_random.choice(length_mask)
else: else:
raise TypeError(
f"Expects the type of length_mask to an integer or a np.ndarray, actual type: {type(length_mask)}"
)
else:
# The choice of 0.25 is arbitrary
length = self.np_random.geometric(0.25) length = self.np_random.geometric(0.25)
return tuple( return tuple(

View File

@@ -1,12 +1,16 @@
"""Implementation of a space that represents the cartesian product of other spaces.""" """Implementation of a space that represents the cartesian product of other spaces."""
from typing import Iterable, List, Optional, Sequence, Tuple, Union from collections.abc import Sequence as CollectionSequence
from typing import Iterable, Optional
from typing import Sequence as TypingSequence
from typing import Tuple as TypingTuple
from typing import Union
import numpy as np import numpy as np
from gym.spaces.space import Space from gym.spaces.space import Space
class Tuple(Space[tuple], Sequence): class Tuple(Space[tuple], CollectionSequence):
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances. """A tuple (more precisely: the cartesian product) of :class:`Space` instances.
Elements of this space are tuples of elements of the constituent spaces. Elements of this space are tuples of elements of the constituent spaces.
@@ -22,7 +26,7 @@ class Tuple(Space[tuple], Sequence):
def __init__( def __init__(
self, self,
spaces: Iterable[Space], spaces: Iterable[Space],
seed: Optional[Union[int, List[int], np.random.Generator]] = None, seed: Optional[Union[int, TypingSequence[int], np.random.Generator]] = None,
): ):
r"""Constructor of :class:`Tuple` space. r"""Constructor of :class:`Tuple` space.
@@ -44,7 +48,9 @@ class Tuple(Space[tuple], Sequence):
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" """Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces) return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list: def seed(
self, seed: Optional[Union[int, TypingSequence[int]]] = None
) -> TypingSequence[int]:
"""Seed the PRNG of this space and all subspaces. """Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently Depending on the type of seed, the subspaces will be seeded differently
@@ -57,25 +63,32 @@ class Tuple(Space[tuple], Sequence):
""" """
seeds = [] seeds = []
if isinstance(seed, list): if isinstance(seed, CollectionSequence):
for i, space in enumerate(self.spaces): assert len(seed) == len(
seeds += space.seed(seed[i]) self.spaces
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
for subseed, space in zip(seed, self.spaces):
seeds += space.seed(subseed)
elif isinstance(seed, int): elif isinstance(seed, int):
seeds = super().seed(seed) seeds = super().seed(seed)
subseeds = self.np_random.integers( subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces) np.iinfo(np.int32).max, size=len(self.spaces)
) )
for subspace, subseed in zip(self.spaces, subseeds): for subspace, subseed in zip(self.spaces, subseeds):
seeds.append(subspace.seed(int(subseed))[0]) seeds += subspace.seed(int(subseed))
elif seed is None: elif seed is None:
for space in self.spaces: for space in self.spaces:
seeds += space.seed(seed) seeds += space.seed(seed)
else: else:
raise TypeError("Passed seed not of an expected type: list or int or None") raise TypeError(
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
)
return seeds return seeds
def sample(self, mask: Optional[Tuple[Optional[np.ndarray]]] = None) -> tuple: def sample(
self, mask: Optional[TypingTuple[Optional[np.ndarray], ...]] = None
) -> tuple:
"""Generates a single random sample inside this space. """Generates a single random sample inside this space.
This method draws independent samples from the subspaces. This method draws independent samples from the subspaces.
@@ -116,7 +129,7 @@ class Tuple(Space[tuple], Sequence):
"""Gives a string representation of this space.""" """Gives a string representation of this space."""
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")" return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
def to_jsonable(self, sample_n: Sequence) -> list: def to_jsonable(self, sample_n: CollectionSequence) -> list:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
# serialize as list-repr of tuple of vectors # serialize as list-repr of tuple of vectors
return [ return [

View File

@@ -89,6 +89,13 @@ def _flatdim_dict(space: Dict) -> int:
) )
@flatdim.register(Graph)
def _flatdim_graph(space: Graph):
    """Reject :class:`Graph` spaces when asked for a flattened dimension.

    Args:
        space: The graph space that was passed to ``flatdim``.

    Raises:
        ValueError: Always, as the error message explains that a Graph space
            has a dynamic size.
    """
    raise ValueError(
        "Cannot get flattened size as the Graph Space in Gym has a dynamic size."
    )
@flatdim.register(Text) @flatdim.register(Text)
def _flatdim_text(space: Text) -> int: def _flatdim_text(space: Text) -> int:
return space.max_length return space.max_length
@@ -157,11 +164,11 @@ def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
return np.concatenate( return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)] [flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
) )
return tuple((flatten(s, x_part) for x_part, s in zip(x, space.spaces))) return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))
@flatten.register(Dict) @flatten.register(Dict)
def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]: def _flatten_dict(space, x) -> Union[dict, np.ndarray]:
if space.is_np_flattenable: if space.is_np_flattenable:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items()) return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
@@ -171,14 +178,19 @@ def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]:
def _flatten_graph(space, x) -> GraphInstance: def _flatten_graph(space, x) -> GraphInstance:
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring.""" """We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
def _graph_unflatten(space, x): def _graph_unflatten(unflatten_space, unflatten_x):
ret = None ret = None
if space is not None and x is not None: if unflatten_space is not None and unflatten_x is not None:
if isinstance(space, Box): if isinstance(unflatten_space, Box):
ret = x.reshape(x.shape[0], -1) ret = unflatten_x.reshape(unflatten_x.shape[0], -1)
elif isinstance(space, Discrete): elif isinstance(unflatten_space, Discrete):
ret = np.zeros((x.shape[0], space.n - space.start), dtype=space.dtype) ret = np.zeros(
ret[np.arange(x.shape[0]), x - space.start] = 1 (unflatten_x.shape[0], unflatten_space.n - unflatten_space.start),
dtype=unflatten_space.dtype,
)
ret[
np.arange(unflatten_x.shape[0]), unflatten_x - unflatten_space.start
] = 1
return ret return ret
nodes = _graph_unflatten(space.node_space, x.nodes) nodes = _graph_unflatten(space.node_space, x.nodes)

140
tests/spaces/test_dict.py Normal file
View File

@@ -0,0 +1,140 @@
from collections import OrderedDict
import numpy as np
import pytest
from gym.spaces import Box, Dict, Discrete
def test_dict_init():
    """Checks the validation and the accepted input forms of the ``Dict`` constructor.

    Covers three rejected inputs (a non-mapping/non-sequence argument, a keyword
    that duplicates a dictionary key, and a value that is not a ``Space``) and
    verifies that the four supported construction styles yield equal spaces
    without emitting any warning.
    """
    # Imported locally so this test does not depend on the module's import block.
    import warnings

    # A bare Space is neither a mapping nor a sequence of (key, space) pairs
    with pytest.raises(
        AssertionError,
        match=r"^Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: ",
    ):
        Dict(Discrete(2))

    # The same key may not be supplied both in the dictionary and as a keyword
    with pytest.raises(
        ValueError,
        match="Dict space keyword 'a' already exists in the spaces dictionary",
    ):
        Dict({"a": Discrete(3)}, a=Box(0, 1))

    # Every value has to be a Space instance
    with pytest.raises(
        AssertionError,
        match="Dict space element is not an instance of Space: key='b', space=Box",
    ):
        Dict(a=Discrete(2), b="Box")

    # `pytest.warns(None)` is deprecated since pytest 7 and an error in pytest 8,
    # so the "no warnings were raised" check records warnings with the stdlib.
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
        b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
        c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
        d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))

        assert a == b == c == d
    assert len(caught_warnings) == 0

    # Keys of mixed types (int and str) must be accepted silently as well
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        Dict({1: Discrete(2), "a": Discrete(3)})
    assert len(caught_warnings) == 0
# Module-level nested Dict space shared by the tests below: a 3x3 Box ("a"),
# a nested Dict of two 2-element Boxes ("b") and a Discrete(5) ("c").
# NOTE: the tests re-seed this space and `test_mapping` temporarily replaces
# its "a" entry, so it is shared mutable state.
DICT_SPACE = Dict(
    {
        "a": Box(low=0, high=1, shape=(3, 3)),
        "b": Dict(
            {
                "b_1": Box(low=-100, high=100, shape=(2,)),
                "b_2": Box(low=-1, high=1, shape=(2,)),
            }
        ),
        "c": Discrete(5),
    }
)
def test_dict_seeding():
    """Seeding with a nested dict of ints seeds each leaf subspace with its own value.

    After seeding, samples drawn from ``DICT_SPACE`` must match samples drawn
    from standalone spaces that were constructed with the same per-key seeds.
    """
    nested_seed = {
        "a": 0,
        "b": {
            "b_1": 1,
            "b_2": 2,
        },
        "c": 3,
    }
    returned_seeds = DICT_SPACE.seed(nested_seed)
    assert all(isinstance(returned, int) for returned in returned_seeds)

    # Standalone equivalents of the leaf spaces, seeded with the same values
    expected_a = Box(low=0, high=1, shape=(3, 3), seed=0)
    expected_b_1 = Box(low=-100, high=100, shape=(2,), seed=1)
    expected_b_2 = Box(low=-1, high=1, shape=(2,), seed=2)
    expected_c = Discrete(5, seed=3)

    for _ in range(10):
        drawn = DICT_SPACE.sample()
        assert np.all(drawn["a"] == expected_a.sample())
        assert np.all(drawn["b"]["b_1"] == expected_b_1.sample())
        assert np.all(drawn["b"]["b_2"] == expected_b_2.sample())
        assert drawn["c"] == expected_c.sample()
def test_int_seeding():
    """Seeding with a single int derives reproducible seeds for every subspace.

    The seeds returned by ``DICT_SPACE.seed(1)`` are fed into standalone copies
    of the leaf spaces, whose samples must then match those of ``DICT_SPACE``.
    """
    returned_seeds = DICT_SPACE.seed(1)
    assert all(isinstance(returned, int) for returned in returned_seeds)

    # The indices used below skip positions 0 and 2 of the returned list.
    # Presumably those hold the seeds of the outer Dict and of the nested "b"
    # Dict themselves (a composite space reports its own seed before those of
    # its subspaces) - verify against `Dict.seed` if the layout changes.
    expected_a = Box(low=0, high=1, shape=(3, 3), seed=returned_seeds[1])
    expected_b_1 = Box(low=-100, high=100, shape=(2,), seed=returned_seeds[3])
    expected_b_2 = Box(low=-1, high=1, shape=(2,), seed=returned_seeds[4])
    expected_c = Discrete(5, seed=returned_seeds[5])

    for _ in range(10):
        drawn = DICT_SPACE.sample()
        assert np.all(drawn["a"] == expected_a.sample())
        assert np.all(drawn["b"]["b_1"] == expected_b_1.sample())
        assert np.all(drawn["b"]["b_2"] == expected_b_2.sample())
        assert drawn["c"] == expected_c.sample()
def test_none_seeding():
    """Seeding with ``None`` returns four integer seeds for ``DICT_SPACE``."""
    returned_seeds = DICT_SPACE.seed(None)
    seeds_are_ints = all(isinstance(returned, int) for returned in returned_seeds)
    assert len(returned_seeds) == 4 and seeds_are_ints
def test_bad_seed():
    """Seeding with an unsupported type (here a string) raises a ``TypeError``."""
    unsupported_seed = "a"
    with pytest.raises(TypeError):
        DICT_SPACE.seed(unsupported_seed)
def test_mapping():
    """The Gym Dict space inherits from Mapping that allows it to appear like a standard python Dictionary.

    Exercises ``__len__``, ``__getitem__`` and ``__setitem__``, including the
    rejection of a value that is not a ``Space``.
    """
    assert len(DICT_SPACE) == 3

    original_a = DICT_SPACE["a"]
    replacement = Discrete(5)
    assert original_a != replacement

    # DICT_SPACE is module-level state shared with the other tests, so the
    # original entry is restored even when one of the assertions below fails.
    try:
        DICT_SPACE["a"] = replacement
        assert DICT_SPACE["a"] == replacement

        with pytest.raises(
            AssertionError,
            match="Trying to set a to Dict space with value that is not a gym space, actual type: <class 'int'>",
        ):
            DICT_SPACE["a"] = 5
    finally:
        DICT_SPACE["a"] = original_a
def test_iterator():
    """Tests that the Dict `__iter__` function correctly returns the keys of the subspaces."""
    for subspace_key in DICT_SPACE:
        assert subspace_key in DICT_SPACE.spaces

    iterated_keys = set(DICT_SPACE)
    assert iterated_keys == DICT_SPACE.spaces.keys()

135
tests/spaces/test_graph.py Normal file
View File

@@ -0,0 +1,135 @@
import re
import numpy as np
import pytest
from gym.spaces import Discrete, Graph, GraphInstance
def test_node_space_sample():
    """Samples a node-only ``Graph`` with node masks, an edge count and an invalid node space."""
    space = Graph(node_space=Discrete(3), edge_space=None)

    # One mask per node, each of which only allows the value 1
    only_value_one = tuple(np.array([0, 1, 0], dtype=np.int8) for _ in range(5))
    sample = space.sample(mask=(only_value_one, None), num_nodes=5)
    assert sample in space
    assert np.all(sample.nodes == 1)

    # A different mask for each node: the first forces 0, the second forces 1
    per_node_masks = (
        np.array([1, 0, 0], dtype=np.int8),
        np.array([0, 1, 0], dtype=np.int8),
    )
    sample = space.sample((per_node_masks, None), num_nodes=2)
    assert sample in space
    assert np.all(sample.nodes == np.array([0, 1]))

    # Requesting edges without an edge space only warns; the sample stays valid
    expected_warning = re.escape(
        "The number of edges is set (5) but the edge space is None."
    )
    with pytest.warns(UserWarning, match=expected_warning):
        sample = space.sample(num_edges=5)
    assert sample in space

    # Change the node_space or edge_space to a non-Box or discrete space.
    # This should not happen, test is primarily to increase coverage.
    expected_error = re.escape(
        "Expects base space to be Box and Discrete, actual space: <class 'str'>"
    )
    with pytest.raises(TypeError, match=expected_error):
        space.node_space = "abc"
        space.sample()
def test_edge_space_sample():
    """Samples a ``Graph`` with an edge space using edge masks and explicit edge counts."""
    space = Graph(node_space=Discrete(3), edge_space=Discrete(3))

    # With a single node (num_nodes == 1) no edges are sampled
    single_node_sample = space.sample(num_nodes=1)
    assert single_node_sample.edges is None

    sampled_edge_count = len(space.sample(num_edges=3).edges)
    assert 0 <= sampled_edge_count < 6

    # A single edge mask is applied to every sampled edge
    shared_edge_mask = np.array([0, 1, 0], dtype=np.int8)
    sample = space.sample(mask=(None, shared_edge_mask))
    assert np.all(sample.edges == 1)

    # One mask per edge, forcing the edge values 0, 1 and 2 respectively
    per_edge_masks = (
        np.array([1, 0, 0], dtype=np.int8),
        np.array([0, 1, 0], dtype=np.int8),
        np.array([0, 0, 1], dtype=np.int8),
    )
    sample = space.sample(mask=(None, per_edge_masks), num_edges=3)
    assert np.all(sample.edges == np.array([0, 1, 2]))

    # A negative number of edges is rejected
    with pytest.raises(
        AssertionError,
        match="Expects the number of edges to be greater than 0, actual value: -1",
    ):
        space.sample(num_edges=-1)

    # Without an edge space, asking for edges warns and yields no edges
    space = Graph(node_space=Discrete(3), edge_space=None)
    expected_warning = re.escape(
        "\x1b[33mWARN: The number of edges is set (5) but the edge space is None.\x1b[0m"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        sample = space.sample(num_edges=5)
    assert sample.edges is None
@pytest.mark.parametrize(
    "sample",
    [
        # Not a GraphInstance at all
        "abc",
        # Nodes are not a numpy array
        GraphInstance(
            nodes=None, edges=np.array([0, 1]), edge_links=np.array([[0, 1], [1, 0]])
        ),
        # A node value (10) lies outside the node space
        GraphInstance(
            nodes=np.array([10, 1, 0]),
            edges=np.array([0, 1]),
            edge_links=np.array([[0, 1], [1, 0]]),
        ),
        # Edge links are given without edges
        GraphInstance(
            nodes=np.array([0, 1]), edges=None, edge_links=np.array([[0, 1], [1, 0]])
        ),
        # Edges are given without edge links
        GraphInstance(nodes=np.array([0, 1]), edges=np.array([0, 1]), edge_links=None),
        # An edge value (10) lies outside the edge space
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([10, 1]),
            edge_links=np.array([[0, 1], [1, 0]]),
        ),
        # Edge links do not have an integer dtype
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[0.5, 1.0], [2.0, 1.0]]),
        ),
        # Edge links are 1-dimensional instead of (num_edges, 2)
        GraphInstance(
            nodes=np.array([0, 1]), edges=np.array([0, 1]), edge_links=np.array([0, 1])
        ),
        # Edge links are 3-dimensional instead of (num_edges, 2)
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[[0], [1]], [[0], [0]]]),
        ),
        # An edge link index (10) is not smaller than the number of nodes
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[10, 1], [0, 0]]),
        ),
        # An edge link index (-10) is negative
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[-10, 1], [0, 0]]),
        ),
    ],
)
def test_not_contains(sample):
    """Each parametrized sample violates exactly one membership rule of the ``Graph`` space.

    Apart from the cases that target the nodes themselves, every sample uses
    nodes that are valid for ``Discrete(2)`` (values 0 and 1). A node value of
    2 would already be rejected by the node check, so the edge and edge-link
    checks that the later cases are meant to cover would never be reached.
    """
    space = Graph(node_space=Discrete(2), edge_space=Discrete(2))
    assert sample not in space

View File

@@ -1,3 +1,19 @@
import numpy as np
from gym.spaces import MultiBinary
def test_sample():
    """Check that ``MultiBinary.sample`` honours a sampling mask.

    As asserted below, a mask entry of 0 or 1 is expected to force that value,
    while a mask entry of 2 leaves the bit free to be either 0 or 1.
    """
    space = MultiBinary(4)

    # Fully determined mask: the sample must equal the mask
    sample = space.sample(mask=np.array([0, 0, 1, 1], dtype=np.int8))
    assert np.all(sample == [0, 0, 1, 1])

    # Partially determined mask: the last two bits are unconstrained
    sample = space.sample(mask=np.array([0, 1, 2, 2], dtype=np.int8))
    assert sample[0] == 0 and sample[1] == 1
    assert sample[2] == 0 or sample[2] == 1
    assert sample[3] == 0 or sample[3] == 1

    # Multi-dimensional space with a fully determined mask of matching shape
    space = MultiBinary(np.array([2, 3]))
    sample = space.sample(mask=np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int8))
    assert np.all(sample == [[0, 0, 0], [1, 1, 1]]), sample

View File

@@ -0,0 +1,59 @@
import re
import numpy as np
import pytest
import gym.spaces
def test_sample():
    """Test that sequence sampling works as expected and that errors are raised for bad length masks."""
    sequence_space = gym.spaces.Sequence(gym.spaces.Box(0, 1))

    # Integer length mask: the sample is expected to have exactly that length
    for expected_length in range(4):
        drawn = sequence_space.sample(mask=(expected_length, None))
        assert drawn in sequence_space
        assert len(drawn) == expected_length

    negative_int_pattern = re.escape(
        "Expects the length mask to be greater than or equal to zero, actual value: -1"
    )
    with pytest.raises(AssertionError, match=negative_int_pattern):
        sequence_space.sample(mask=(-1, None))

    # Array length mask: the sample length is expected to be one of the listed values
    drawn = sequence_space.sample(mask=(np.array([5]), None))
    assert drawn in sequence_space
    assert len(drawn) == 5

    drawn = sequence_space.sample(mask=(np.array([3, 4, 5]), None))
    assert drawn in sequence_space
    assert len(drawn) in [3, 4, 5]

    bad_shape_pattern = re.escape(
        "Expects the shape of the length mask to be 1-dimensional, actual shape: (2, 2)"
    )
    with pytest.raises(AssertionError, match=bad_shape_pattern):
        sequence_space.sample(mask=(np.array([[2, 2], [2, 2]]), None))

    negative_array_pattern = re.escape(
        "Expects all values in the length_mask to be greater than or equal to zero, actual values: [ 1  2 -1]"
    )
    with pytest.raises(AssertionError, match=negative_array_pattern):
        sequence_space.sample(mask=(np.array([1, 2, -1]), None))

    # A length mask that is neither an integer nor an array
    bad_type_pattern = re.escape(
        "Expects the type of length_mask to an integer or a np.ndarray, actual type: <class 'str'>"
    )
    with pytest.raises(TypeError, match=bad_type_pattern):
        sequence_space.sample(mask=("abc", None))

View File

@@ -409,10 +409,10 @@ SPACE_KWARGS = [
{"nvec": [3, 2]}, # MultiDiscrete {"nvec": [3, 2]}, # MultiDiscrete
{"n": 2}, # MultiBinary {"n": 2}, # MultiBinary
{"max_length": 5}, # Text {"max_length": 5}, # Text
# {"spaces": (Discrete(3), Discrete(2))}, # Tuple {"spaces": (Discrete(3), Discrete(2))}, # Tuple
# {"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict {"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
# {"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph {"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
# {"space": Discrete(4)}, # Sequence {"space": Discrete(4)}, # Sequence
] ]
assert len(SPACE_CLS) == len(SPACE_KWARGS) assert len(SPACE_CLS) == len(SPACE_KWARGS)

109
tests/spaces/test_tuple.py Normal file
View File

@@ -0,0 +1,109 @@
import numpy as np
import pytest
import gym.spaces
from gym.spaces import Box, Dict, Discrete, MultiBinary, Tuple
from gym.utils.env_checker import data_equivalence
def test_sequence_inheritance():
    """Check that the gym ``Tuple`` space supports the sequence operations used below."""
    members = [Discrete(5), Discrete(10), Discrete(5)]
    tuple_space = Tuple(members)

    assert len(tuple_space) == len(members)

    # Indexing returns the subspace stored at each position
    for position, expected in enumerate(members):
        assert tuple_space[position] == expected

    # Iteration yields only the subspaces the tuple was built from
    for member in tuple_space:
        assert member in members

    # count() compares subspaces by equality
    assert tuple_space.count(Discrete(5)) == 2
    assert tuple_space.count(Discrete(6)) == 0
    assert tuple_space.count(MultiBinary(2)) == 0

    # index() finds the first match, optionally from a start offset
    assert tuple_space.index(Discrete(5)) == 0
    assert tuple_space.index(Discrete(5), 1) == 2

    # Searching a slice that lacks the value, and indexing past the end, must raise
    with pytest.raises(ValueError):
        tuple_space.index(Discrete(10), 0, 1)

    with pytest.raises(IndexError):
        assert tuple_space[4]
@pytest.mark.parametrize(
    "space, seed, expected_len",
    [
        (Tuple([Discrete(5), Discrete(4)]), None, 2),
        (Tuple([Discrete(5), Discrete(4)]), 123, 3),
        (Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
        (
            Tuple(
                (Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
            ),
            (123, (456, 789)),
            3,
        ),
        (
            Tuple(
                (
                    Discrete(3),
                    Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)),
                )
            ),
            (123, {"position": 456, "velocity": 789}),
            3,
        ),
    ],
)
def test_seeds(space, seed, expected_len):
    """Check the seeds returned by ``Tuple.seed`` and that seeding is reproducible.

    Args:
        space: The ``Tuple`` space under test.
        seed: The seed passed to ``space.seed`` (``None``, an int, or a nested tuple/dict).
        expected_len: The expected number of seeds returned by ``space.seed``.
    """
    seeds = space.seed(seed)
    assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
    assert len(seeds) == expected_len

    sample1 = space.sample()

    seeds2 = space.seed(seed)
    sample2 = space.sample()

    # ``data_equivalence`` reports its verdict as a return value, so it has to be
    # asserted; previously the result was discarded and nothing was checked.
    # With ``seed=None`` the re-seed is not expected to reproduce the first run,
    # so reproducibility is only checked for explicit seeds.
    if seed is not None:
        assert data_equivalence(seeds, seeds2)
        assert data_equivalence(sample1, sample2)
@pytest.mark.parametrize(
    "space_fn",
    [
        lambda: Tuple(["abc"]),
        lambda: Tuple([Box(0, 1), "abc"]),
        lambda: Tuple("abc"),
    ],
)
def test_bad_space_calls(space_fn):
    """Check that building a ``Tuple`` from anything other than spaces raises ``AssertionError``.

    Construction is deferred through ``space_fn`` so that the failure happens
    inside the ``pytest.raises`` context rather than at collection time.
    """
    with pytest.raises(AssertionError):
        space_fn()
def test_contains_promotion():
    """Check membership tests of a ``Tuple`` space for a tuple of arrays and for a stacked array."""
    # A plain tuple with one array per subspace
    mixed_shape_space = Tuple((Box(0, 1), Box(-1, 0, (2,))))
    per_subspace_sample = (
        np.array([0.0], dtype=np.float32),
        np.array([0.0, 0.0], dtype=np.float32),
    )
    assert per_subspace_sample in mixed_shape_space

    # A single (2, 1) array is expected to be accepted in place of a tuple of two (1,) arrays
    same_shape_space = Tuple((Box(0, 1), Box(-1, 0, (1,))))
    stacked_sample = np.array([[0.0], [0.0]], dtype=np.float32)
    assert stacked_sample in same_shape_space
def test_bad_seed():
    """Check that seeding a ``Tuple`` space with a float raises a ``TypeError``."""
    box_pair = Tuple((Box(0, 1), Box(0, 1)))
    expected_message = (
        "Expected seed type: list, tuple, int or None, actual type: <class 'float'>"
    )
    with pytest.raises(TypeError, match=expected_message):
        box_pair.seed(0.0)

View File

@@ -29,25 +29,25 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
6, 6,
6, 6,
6, 6,
# # Tuple # Tuple
# 9, 9,
# 7, 7,
# 10, 10,
# 6, 6,
# None, None,
# # Dict # Dict
# 7, 7,
# 8, 8,
# 17, 17,
# None, None,
# # Graph # Graph
# None, None,
# None, None,
# None, None,
# # Sequence # Sequence
# None, None,
# None, None,
# None, None,
] ]

View File

@@ -2,7 +2,18 @@ from typing import List
import numpy as np import numpy as np
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text from gym.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
)
TESTING_FUNDAMENTAL_SPACES = [ TESTING_FUNDAMENTAL_SPACES = [
Discrete(3), Discrete(3),
@@ -23,5 +34,70 @@ TESTING_FUNDAMENTAL_SPACES = [
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES] TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]
# Composite spaces (spaces built from other spaces) used to parametrize the space tests.
TESTING_COMPOSITE_SPACES = [
    # Tuple spaces
    Tuple([Discrete(5), Discrete(4)]),
    Tuple(
        (
            Discrete(5),
            Box(
                low=np.array([0.0, 0.0]),
                high=np.array([1.0, 5.0]),
                dtype=np.float64,
            ),
        )
    ),
    Tuple((Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))),
    Tuple((Discrete(3), Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)))),
    Tuple((Graph(node_space=Box(-1, 1, shape=(2, 1)), edge_space=None), Discrete(2))),
    # Dict spaces
    Dict(
        {
            "position": Discrete(5),
            "velocity": Box(
                low=np.array([0.0, 0.0]),
                high=np.array([1.0, 5.0]),
                dtype=np.float64,
            ),
        }
    ),
    Dict(
        position=Discrete(6),
        velocity=Box(
            low=np.array([0.0, 0.0]),
            high=np.array([1.0, 5.0]),
            dtype=np.float64,
        ),
    ),
    Dict(
        {
            "a": Box(low=0, high=1, shape=(3, 3)),
            "b": Dict(
                {
                    "b_1": Box(low=-100, high=100, shape=(2,)),
                    "b_2": Box(low=-1, high=1, shape=(2,)),
                }
            ),
            "c": Discrete(4),
        }
    ),
    Dict(
        a=Dict(
            a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
            b=Box(-100, 100, shape=(2, 2)),
        ),
        b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
    ),
    # Graph spaces
    Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
    Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
    Graph(node_space=Discrete(3), edge_space=Discrete(4)),
    # Sequence spaces
    Sequence(Discrete(4)),
    Sequence(Dict({"feature": Box(0, 1, (3,))})),
    Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
]
# Human-readable test ids, one per composite space, in the same order.
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]

# All spaces under test: fundamental spaces first, then composite spaces.
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS + TESTING_COMPOSITE_SPACES_IDS