Update Composite spaces with full coverage (#3047)

* Update composite space and tests

* Pre-commit

* pyright

* Fix pyright

* retrigger actions

* Code review by Arjun

* Code review by Arjun

* Code review by Omar
This commit is contained in:
Mark Towers
2022-09-03 23:39:23 +01:00
committed by GitHub
parent 8e74fe3b62
commit f39747d6a2
14 changed files with 741 additions and 151 deletions

View File

@@ -3,7 +3,9 @@ from collections import OrderedDict
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Dict as TypingDict
from typing import Optional, Union
from typing import List, Optional
from typing import Sequence as TypingSequence
from typing import Tuple, Union
import numpy as np
@@ -51,7 +53,12 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def __init__(
self,
spaces: Optional[TypingDict[str, Space]] = None,
spaces: Optional[
Union[
TypingDict[str, Space],
TypingSequence[Tuple[str, Space]],
]
] = None,
seed: Optional[Union[dict, int, np.random.Generator]] = None,
**spaces_kwargs: Space,
):
@@ -74,12 +81,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
"""
assert (spaces is None) or (
not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
if spaces is None:
spaces = spaces_kwargs
# Convert the spaces into an OrderedDict
if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict):
try:
spaces = OrderedDict(sorted(spaces.items()))
@@ -87,16 +89,30 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
# Incomparable types (e.g. `int` vs. `str`, or user-defined types) found.
# The keys remain in the insertion order.
spaces = OrderedDict(spaces.items())
if isinstance(spaces, Sequence):
elif isinstance(spaces, Sequence):
spaces = OrderedDict(spaces)
elif spaces is None:
spaces = OrderedDict()
else:
assert isinstance(
spaces, OrderedDict
), f"Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: {type(spaces)}"
assert isinstance(spaces, OrderedDict), "spaces must be a dictionary"
# Add kwargs to spaces to allow both dictionary and keywords to be used
for key, space in spaces_kwargs.items():
if key not in spaces:
spaces[key] = space
else:
raise ValueError(
f"Dict space keyword '{key}' already exists in the spaces dictionary."
)
self.spaces = spaces
for space in spaces.values():
for key, space in self.spaces.items():
assert isinstance(
space, Space
), "Values of the dict should be instances of gym.Space"
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
super().__init__(
None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling
@@ -120,27 +136,26 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
seeds = []
if isinstance(seed, dict):
for key, seed_key in zip(self.spaces, seed):
assert key == seed_key, print(
"Key value",
seed_key,
"in passed seed dict did not match key value",
key,
"in spaces Dict.",
)
seeds += self.spaces[key].seed(seed[seed_key])
assert (
seed.keys() == self.spaces.keys()
), f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
for key in seed.keys():
seeds += self.spaces[key].seed(seed[key])
elif isinstance(seed, int):
seeds = super().seed(seed)
# Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces
subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces)
)
for subspace, subseed in zip(self.spaces.values(), subseeds):
seeds.append(subspace.seed(int(subseed))[0])
seeds += subspace.seed(int(subseed))
elif seed is None:
for space in self.spaces.values():
seeds += space.seed(seed)
seeds += space.seed(None)
else:
raise TypeError("Passed seed not of an expected type: dict or int or None")
raise TypeError(
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
)
return seeds
@@ -170,14 +185,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def contains(self, x) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if not isinstance(x, dict) or len(x) != len(self.spaces):
if isinstance(x, dict) and x.keys() == self.spaces.keys():
return all(x[key] in self.spaces[key] for key in self.spaces.keys())
return False
for k, space in self.spaces.items():
if k not in x:
return False
if not space.contains(x[k]):
return False
return True
def __getitem__(self, key: str) -> Space:
"""Get the space that is associated to `key`."""
@@ -185,6 +195,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def __setitem__(self, key: str, value: Space):
"""Set the space that is associated to `key`."""
assert isinstance(
value, Space
), f"Trying to set {key} to Dict space with value that is not a gym space, actual type: {type(value)}"
self.spaces[key] = value
def __iter__(self):
@@ -217,16 +230,16 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
for key, space in self.spaces.items()
}
def from_jsonable(self, sample_n: TypingDict[str, list]) -> list:
def from_jsonable(self, sample_n: TypingDict[str, list]) -> List[dict]:
"""Convert a JSONable data type to a batch of samples from this space."""
dict_of_list: TypingDict[str, list] = {}
for key, space in self.spaces.items():
dict_of_list[key] = space.from_jsonable(sample_n[key])
ret = []
dict_of_list: TypingDict[str, list] = {
key: space.from_jsonable(sample_n[key])
for key, space in self.spaces.items()
}
n_elements = len(next(iter(dict_of_list.values())))
for i in range(n_elements):
entry = {}
for key, value in dict_of_list.items():
entry[key] = value[i]
ret.append(entry)
return ret
result = [
OrderedDict({key: value[n] for key, value in dict_of_list.items()})
for n in range(n_elements)
]
return result

View File

@@ -1,31 +1,27 @@
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from collections import namedtuple
from typing import NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from gym.logger import warn
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.space import Space
class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"])):
r"""Returns a NamedTuple representing a graph object.
class GraphInstance(NamedTuple):
"""A Graph space instance.
Args:
nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes.
(...) must adhere to the shape of the node space.
edges (np.ndarray): an (m x ...) sized array representing the features for m edges.
(...) must adhere to the shape of the edge space.
edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.
Returns:
A NamedTuple representing a graph with `.nodes`, `.edges`, and `.edge_links`.
* nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes, (...) must adhere to the shape of the node space.
* edges (Optional[np.ndarray]): an (m x ...) sized array representing the features for m nodes, (...) must adhere to the shape of the edge space.
* edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the two nodes that each edge connects.
"""
nodes: np.ndarray
edges: Optional[np.ndarray]
edge_links: Optional[np.ndarray]
class Graph(Space):
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
@@ -89,7 +85,7 @@ class Graph(Space):
elif isinstance(base_space, Discrete):
return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
else:
raise AssertionError(
raise TypeError(
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
)
@@ -103,7 +99,7 @@ class Graph(Space):
] = None,
num_nodes: int = 10,
num_edges: Optional[int] = None,
) -> NamedTuple:
) -> GraphInstance:
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
Args:
@@ -132,12 +128,17 @@ class Graph(Space):
num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
else:
num_edges = 0
if edge_space_mask is not None:
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
else:
if self.edge_space is None:
warn(
f"The number of edges is set ({num_edges}) but the edge space is None."
)
assert (
num_edges >= 0
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"
), f"Expects the number of edges to be greater than 0, actual value: {num_edges}"
assert num_edges is not None
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
@@ -160,38 +161,31 @@ class Graph(Space):
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
def contains(self, x: GraphInstance) -> bool:
"""Return boolean specifying if x is a valid member of this space.
Returns False when:
- any node in nodes is not contained in Graph.node_space
- edge_links is not of dtype int
- len(edge_links) != len(edges)
- has edges but Graph.edge_space is None
- edge_links has index less than 0
- edge_links has index more than number of nodes
- any edge in edges is not contained in Graph.edge_space
"""
if not isinstance(x, GraphInstance):
return False
if x.edges is not None:
if not np.issubdtype(x.edge_links.dtype, np.integer):
return False
if x.edge_links.shape[-1] != 2:
return False
if self.edge_space is None:
return False
if x.edge_links.min() < 0:
return False
if x.edge_links.max() >= len(x.nodes):
return False
if len(x.edges) != len(x.edge_links):
return False
if any(edge not in self.edge_space for edge in x.edges):
return False
if any(node not in self.node_space for node in x.nodes):
return False
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, GraphInstance):
# Checks the nodes
if isinstance(x.nodes, np.ndarray):
if all(node in self.node_space for node in x.nodes):
# Check the edges and edge links which are optional
if isinstance(x.edges, np.ndarray) and isinstance(
x.edge_links, np.ndarray
):
assert x.edges is not None
assert x.edge_links is not None
if self.edge_space is not None:
if all(edge in self.edge_space for edge in x.edges):
if np.issubdtype(x.edge_links.dtype, np.integer):
if x.edge_links.shape == (len(x.edges), 2):
if np.all(
np.logical_and(
x.edge_links >= 0,
x.edge_links < len(x.nodes),
)
):
return True
else:
return x.edges is None and x.edge_links is None
return False
def __repr__(self) -> str:
"""A string representation of this space.

View File

@@ -4,8 +4,8 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import gym
from gym.spaces.space import Space
from gym.utils import seeding
class Sequence(Space[Tuple]):
@@ -25,7 +25,7 @@ class Sequence(Space[Tuple]):
def __init__(
self,
space: Space,
seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
):
"""Constructor of the :class:`Sequence` space.
@@ -33,6 +33,9 @@ class Sequence(Space[Tuple]):
space: Elements in the sequences this space represent must belong to this space.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
"""
assert isinstance(
space, gym.Space
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
self.feature_space = space
super().__init__(
None, None, seed # type: ignore
@@ -50,17 +53,20 @@ class Sequence(Space[Tuple]):
return False
def sample(
self, mask: Optional[Tuple[Optional[np.ndarray], Any]] = None
self,
mask: Optional[Tuple[Optional[Union[np.ndarray, int]], Optional[Any]]] = None,
) -> Tuple[Any]:
"""Generates a single random sample from this space.
Args:
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask`
is either `None` if you do not want to specify any restrictions on the length of the sampled sequence (then, the
length will be randomly drawn from a geometric distribution), or a `np.ndarray` of integers, in which case the length of
the sampled sequence is randomly drawn from this array. The second element of the tuple, `sample` mask
specifies a mask that is applied when sampling elements from the base space.
is
- `None` The length will be randomly drawn from a geometric distribution
- `np.ndarray` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
- `int` for a fixed length sample
The second element of the mask tuple `sample` mask specifies a mask that is applied when
sampling elements from the base space. The mask is applied for each feature space sample.
Returns:
A tuple of random length with random samples of elements from the :attr:`feature_space`.
@@ -68,11 +74,28 @@ class Sequence(Space[Tuple]):
if mask is not None:
length_mask, feature_mask = mask
else:
length_mask = None
feature_mask = None
length_mask, feature_mask = None, None
if length_mask is not None:
if np.issubdtype(type(length_mask), np.integer):
assert (
0 <= length_mask
), f"Expects the length mask to be greater than or equal to zero, actual value: {length_mask}"
length = length_mask
elif isinstance(length_mask, np.ndarray):
assert (
len(length_mask.shape) == 1
), f"Expects the shape of the length mask to be 1-dimensional, actual shape: {length_mask.shape}"
assert np.all(
0 <= length_mask
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
length = self.np_random.choice(length_mask)
else:
raise TypeError(
f"Expects the type of length_mask to an integer or a np.ndarray, actual type: {type(length_mask)}"
)
else:
# The choice of 0.25 is arbitrary
length = self.np_random.geometric(0.25)
return tuple(

View File

@@ -1,12 +1,16 @@
"""Implementation of a space that represents the cartesian product of other spaces."""
from typing import Iterable, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence as CollectionSequence
from typing import Iterable, Optional
from typing import Sequence as TypingSequence
from typing import Tuple as TypingTuple
from typing import Union
import numpy as np
from gym.spaces.space import Space
class Tuple(Space[tuple], Sequence):
class Tuple(Space[tuple], CollectionSequence):
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
Elements of this space are tuples of elements of the constituent spaces.
@@ -22,7 +26,7 @@ class Tuple(Space[tuple], Sequence):
def __init__(
self,
spaces: Iterable[Space],
seed: Optional[Union[int, List[int], np.random.Generator]] = None,
seed: Optional[Union[int, TypingSequence[int], np.random.Generator]] = None,
):
r"""Constructor of :class:`Tuple` space.
@@ -44,7 +48,9 @@ class Tuple(Space[tuple], Sequence):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
def seed(
self, seed: Optional[Union[int, TypingSequence[int]]] = None
) -> TypingSequence[int]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
@@ -57,25 +63,32 @@ class Tuple(Space[tuple], Sequence):
"""
seeds = []
if isinstance(seed, list):
for i, space in enumerate(self.spaces):
seeds += space.seed(seed[i])
if isinstance(seed, CollectionSequence):
assert len(seed) == len(
self.spaces
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
for subseed, space in zip(seed, self.spaces):
seeds += space.seed(subseed)
elif isinstance(seed, int):
seeds = super().seed(seed)
subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces)
)
for subspace, subseed in zip(self.spaces, subseeds):
seeds.append(subspace.seed(int(subseed))[0])
seeds += subspace.seed(int(subseed))
elif seed is None:
for space in self.spaces:
seeds += space.seed(seed)
else:
raise TypeError("Passed seed not of an expected type: list or int or None")
raise TypeError(
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
)
return seeds
def sample(self, mask: Optional[Tuple[Optional[np.ndarray]]] = None) -> tuple:
def sample(
self, mask: Optional[TypingTuple[Optional[np.ndarray], ...]] = None
) -> tuple:
"""Generates a single random sample inside this space.
This method draws independent samples from the subspaces.
@@ -116,7 +129,7 @@ class Tuple(Space[tuple], Sequence):
"""Gives a string representation of this space."""
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
def to_jsonable(self, sample_n: Sequence) -> list:
def to_jsonable(self, sample_n: CollectionSequence) -> list:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as list-repr of tuple of vectors
return [

View File

@@ -89,6 +89,13 @@ def _flatdim_dict(space: Dict) -> int:
)
@flatdim.register(Graph)
def _flatdim_graph(space: Graph):
raise ValueError(
"Cannot get flattened size as the Graph Space in Gym has a dynamic size."
)
@flatdim.register(Text)
def _flatdim_text(space: Text) -> int:
return space.max_length
@@ -157,11 +164,11 @@ def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
)
return tuple((flatten(s, x_part) for x_part, s in zip(x, space.spaces)))
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))
@flatten.register(Dict)
def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]:
def _flatten_dict(space, x) -> Union[dict, np.ndarray]:
if space.is_np_flattenable:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
@@ -171,14 +178,19 @@ def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]:
def _flatten_graph(space, x) -> GraphInstance:
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
def _graph_unflatten(space, x):
def _graph_unflatten(unflatten_space, unflatten_x):
ret = None
if space is not None and x is not None:
if isinstance(space, Box):
ret = x.reshape(x.shape[0], -1)
elif isinstance(space, Discrete):
ret = np.zeros((x.shape[0], space.n - space.start), dtype=space.dtype)
ret[np.arange(x.shape[0]), x - space.start] = 1
if unflatten_space is not None and unflatten_x is not None:
if isinstance(unflatten_space, Box):
ret = unflatten_x.reshape(unflatten_x.shape[0], -1)
elif isinstance(unflatten_space, Discrete):
ret = np.zeros(
(unflatten_x.shape[0], unflatten_space.n - unflatten_space.start),
dtype=unflatten_space.dtype,
)
ret[
np.arange(unflatten_x.shape[0]), unflatten_x - unflatten_space.start
] = 1
return ret
nodes = _graph_unflatten(space.node_space, x.nodes)

140
tests/spaces/test_dict.py Normal file
View File

@@ -0,0 +1,140 @@
from collections import OrderedDict
import numpy as np
import pytest
from gym.spaces import Box, Dict, Discrete
def test_dict_init():
with pytest.raises(
AssertionError,
match=r"^Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: ",
):
Dict(Discrete(2))
with pytest.raises(
ValueError,
match="Dict space keyword 'a' already exists in the spaces dictionary",
):
Dict({"a": Discrete(3)}, a=Box(0, 1))
with pytest.raises(
AssertionError,
match="Dict space element is not an instance of Space: key='b', space=Box",
):
Dict(a=Discrete(2), b="Box")
with pytest.warns(None) as warnings:
a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))
assert a == b == c == d
assert len(warnings) == 0
with pytest.warns(None) as warnings:
Dict({1: Discrete(2), "a": Discrete(3)})
assert len(warnings) == 0
DICT_SPACE = Dict(
{
"a": Box(low=0, high=1, shape=(3, 3)),
"b": Dict(
{
"b_1": Box(low=-100, high=100, shape=(2,)),
"b_2": Box(low=-1, high=1, shape=(2,)),
}
),
"c": Discrete(5),
}
)
def test_dict_seeding():
seeds = DICT_SPACE.seed(
{
"a": 0,
"b": {
"b_1": 1,
"b_2": 2,
},
"c": 3,
}
)
assert all(isinstance(seed, int) for seed in seeds)
# "Unpack" the dict sub-spaces into individual spaces
a = Box(low=0, high=1, shape=(3, 3), seed=0)
b_1 = Box(low=-100, high=100, shape=(2,), seed=1)
b_2 = Box(low=-1, high=1, shape=(2,), seed=2)
c = Discrete(5, seed=3)
for i in range(10):
dict_sample = DICT_SPACE.sample()
assert np.all(dict_sample["a"] == a.sample())
assert np.all(dict_sample["b"]["b_1"] == b_1.sample())
assert np.all(dict_sample["b"]["b_2"] == b_2.sample())
assert dict_sample["c"] == c.sample()
def test_int_seeding():
seeds = DICT_SPACE.seed(1)
assert all(isinstance(seed, int) for seed in seeds)
# rng, seeds = seeding.np_random(1)
# subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False)
# b_rng, b_seeds = seeding.np_random(int(subseeds[1]))
# b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False)
# "Unpack" the dict sub-spaces into individual spaces
a = Box(low=0, high=1, shape=(3, 3), seed=seeds[1])
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds[3])
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds[4])
c = Discrete(5, seed=seeds[5])
for i in range(10):
dict_sample = DICT_SPACE.sample()
assert np.all(dict_sample["a"] == a.sample())
assert np.all(dict_sample["b"]["b_1"] == b_1.sample())
assert np.all(dict_sample["b"]["b_2"] == b_2.sample())
assert dict_sample["c"] == c.sample()
def test_none_seeding():
seeds = DICT_SPACE.seed(None)
assert len(seeds) == 4 and all(isinstance(seed, int) for seed in seeds)
def test_bad_seed():
with pytest.raises(TypeError):
DICT_SPACE.seed("a")
def test_mapping():
"""The Gym Dict space inherits from Mapping that allows it to appear like a standard python Dictionary."""
assert len(DICT_SPACE) == 3
a = DICT_SPACE["a"]
b = Discrete(5)
assert a != b
DICT_SPACE["a"] = b
assert DICT_SPACE["a"] == b
with pytest.raises(
AssertionError,
match="Trying to set a to Dict space with value that is not a gym space, actual type: <class 'int'>",
):
DICT_SPACE["a"] = 5
DICT_SPACE["a"] = a
def test_iterator():
"""Tests the Dict `__iter__` function correctly returns keys in the subspaces"""
for key in DICT_SPACE:
assert key in DICT_SPACE.spaces
assert {key for key in DICT_SPACE} == DICT_SPACE.spaces.keys()

135
tests/spaces/test_graph.py Normal file
View File

@@ -0,0 +1,135 @@
import re
import numpy as np
import pytest
from gym.spaces import Discrete, Graph, GraphInstance
def test_node_space_sample():
space = Graph(node_space=Discrete(3), edge_space=None)
sample = space.sample(
mask=(tuple(np.array([0, 1, 0], dtype=np.int8) for _ in range(5)), None),
num_nodes=5,
)
assert sample in space
assert np.all(sample.nodes == 1)
sample = space.sample(
(
(np.array([1, 0, 0], dtype=np.int8), np.array([0, 1, 0], dtype=np.int8)),
None,
),
num_nodes=2,
)
assert sample in space
assert np.all(sample.nodes == np.array([0, 1]))
with pytest.warns(
UserWarning,
match=re.escape("The number of edges is set (5) but the edge space is None."),
):
sample = space.sample(num_edges=5)
assert sample in space
# Change the node_space or edge_space to a non-Box or discrete space.
# This should not happen, test is primarily to increase coverage.
with pytest.raises(
TypeError,
match=re.escape(
"Expects base space to be Box and Discrete, actual space: <class 'str'>"
),
):
space.node_space = "abc"
space.sample()
def test_edge_space_sample():
space = Graph(node_space=Discrete(3), edge_space=Discrete(3))
# When num_nodes>1 then num_edges is set to 0
assert space.sample(num_nodes=1).edges is None
assert 0 <= len(space.sample(num_edges=3).edges) < 6
sample = space.sample(mask=(None, np.array([0, 1, 0], dtype=np.int8)))
assert np.all(sample.edges == 1)
sample = space.sample(
mask=(
None,
(
np.array([1, 0, 0], dtype=np.int8),
np.array([0, 1, 0], dtype=np.int8),
np.array([0, 0, 1], dtype=np.int8),
),
),
num_edges=3,
)
assert np.all(sample.edges == np.array([0, 1, 2]))
with pytest.raises(
AssertionError,
match="Expects the number of edges to be greater than 0, actual value: -1",
):
space.sample(num_edges=-1)
space = Graph(node_space=Discrete(3), edge_space=None)
with pytest.warns(
UserWarning,
match=re.escape(
"\x1b[33mWARN: The number of edges is set (5) but the edge space is None.\x1b[0m"
),
):
sample = space.sample(num_edges=5)
assert sample.edges is None
@pytest.mark.parametrize(
"sample",
[
"abc",
GraphInstance(
nodes=None, edges=np.array([0, 1]), edge_links=np.array([[0, 1], [1, 0]])
),
GraphInstance(
nodes=np.array([10, 1, 0]),
edges=np.array([0, 1]),
edge_links=np.array([[0, 1], [1, 0]]),
),
GraphInstance(
nodes=np.array([0, 1]), edges=None, edge_links=np.array([[0, 1], [1, 0]])
),
GraphInstance(nodes=np.array([0, 1]), edges=np.array([0, 1]), edge_links=None),
GraphInstance(
nodes=np.array([1, 2]),
edges=np.array([10, 1]),
edge_links=np.array([[0, 1], [1, 0]]),
),
GraphInstance(
nodes=np.array([1, 2]),
edges=np.array([0, 1]),
edge_links=np.array([[0.5, 1.0], [2.0, 1.0]]),
),
GraphInstance(
nodes=np.array([1, 2]), edges=np.array([10, 1]), edge_links=np.array([0, 1])
),
GraphInstance(
nodes=np.array([1, 2]),
edges=np.array([0, 1]),
edge_links=np.array([[[0], [1]], [[0], [0]]]),
),
GraphInstance(
nodes=np.array([1, 2]),
edges=np.array([0, 1]),
edge_links=np.array([[10, 1], [0, 0]]),
),
GraphInstance(
nodes=np.array([1, 2]),
edges=np.array([0, 1]),
edge_links=np.array([[-10, 1], [0, 0]]),
),
],
)
def test_not_contains(sample):
space = Graph(node_space=Discrete(2), edge_space=Discrete(2))
assert sample not in space

View File

@@ -1,3 +1,19 @@
import numpy as np
from gym.spaces import MultiBinary
def test_sample():
# todo
pass
space = MultiBinary(4)
sample = space.sample(mask=np.array([0, 0, 1, 1], dtype=np.int8))
assert np.all(sample == [0, 0, 1, 1])
sample = space.sample(mask=np.array([0, 1, 2, 2], dtype=np.int8))
assert sample[0] == 0 and sample[1] == 1
assert sample[2] == 0 or sample[2] == 1
assert sample[3] == 0 or sample[3] == 1
space = MultiBinary(np.array([2, 3]))
sample = space.sample(mask=np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int8))
assert np.all(sample == [[0, 0, 0], [1, 1, 1]]), sample

View File

@@ -0,0 +1,59 @@
import re
import numpy as np
import pytest
import gym.spaces
def test_sample():
"""Tests the sequence sampling works as expects and the errors are correctly raised."""
space = gym.spaces.Sequence(gym.spaces.Box(0, 1))
# Test integer mask length
for length in range(4):
sample = space.sample(mask=(length, None))
assert sample in space
assert len(sample) == length
with pytest.raises(
AssertionError,
match=re.escape(
"Expects the length mask to be greater than or equal to zero, actual value: -1"
),
):
space.sample(mask=(-1, None))
# Test np.array mask length
sample = space.sample(mask=(np.array([5]), None))
assert sample in space
assert len(sample) == 5
sample = space.sample(mask=(np.array([3, 4, 5]), None))
assert sample in space
assert len(sample) in [3, 4, 5]
with pytest.raises(
AssertionError,
match=re.escape(
"Expects the shape of the length mask to be 1-dimensional, actual shape: (2, 2)"
),
):
space.sample(mask=(np.array([[2, 2], [2, 2]]), None))
with pytest.raises(
AssertionError,
match=re.escape(
"Expects all values in the length_mask to be greater than or equal to zero, actual values: [ 1 2 -1]"
),
):
space.sample(mask=(np.array([1, 2, -1]), None))
# Test with an invalid length
with pytest.raises(
TypeError,
match=re.escape(
"Expects the type of length_mask to an integer or a np.ndarray, actual type: <class 'str'>"
),
):
space.sample(mask=("abc", None))

View File

@@ -409,10 +409,10 @@ SPACE_KWARGS = [
{"nvec": [3, 2]}, # MultiDiscrete
{"n": 2}, # MultiBinary
{"max_length": 5}, # Text
# {"spaces": (Discrete(3), Discrete(2))}, # Tuple
# {"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
# {"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
# {"space": Discrete(4)}, # Sequence
{"spaces": (Discrete(3), Discrete(2))}, # Tuple
{"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
{"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
{"space": Discrete(4)}, # Sequence
]
assert len(SPACE_CLS) == len(SPACE_KWARGS)

109
tests/spaces/test_tuple.py Normal file
View File

@@ -0,0 +1,109 @@
import numpy as np
import pytest
import gym.spaces
from gym.spaces import Box, Dict, Discrete, MultiBinary, Tuple
from gym.utils.env_checker import data_equivalence
def test_sequence_inheritance():
"""The gym Tuple space inherits from abc.Sequences, this test checks all functions work"""
spaces = [Discrete(5), Discrete(10), Discrete(5)]
tuple_space = Tuple(spaces)
assert len(tuple_space) == len(spaces)
# Test indexing
for i in range(len(tuple_space)):
assert tuple_space[i] == spaces[i]
# Test iterable
for space in tuple_space:
assert space in spaces
# Test count
assert tuple_space.count(Discrete(5)) == 2
assert tuple_space.count(Discrete(6)) == 0
assert tuple_space.count(MultiBinary(2)) == 0
# Test index
assert tuple_space.index(Discrete(5)) == 0
assert tuple_space.index(Discrete(5), 1) == 2
# Test errors
with pytest.raises(ValueError):
tuple_space.index(Discrete(10), 0, 1)
with pytest.raises(IndexError):
assert tuple_space[4]
@pytest.mark.parametrize(
"space, seed, expected_len",
[
(Tuple([Discrete(5), Discrete(4)]), None, 2),
(Tuple([Discrete(5), Discrete(4)]), 123, 3),
(Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
(
Tuple(
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
),
(123, (456, 789)),
3,
),
(
Tuple(
(
Discrete(3),
Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)),
)
),
(123, {"position": 456, "velocity": 789}),
3,
),
],
)
def test_seeds(space, seed, expected_len):
seeds = space.seed(seed)
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
assert len(seeds) == expected_len
sample1 = space.sample()
seeds2 = space.seed(seed)
sample2 = space.sample()
data_equivalence(seeds, seeds2)
data_equivalence(sample1, sample2)
@pytest.mark.parametrize(
"space_fn",
[
lambda: Tuple(["abc"]),
lambda: Tuple([gym.spaces.Box(0, 1), "abc"]),
lambda: Tuple("abc"),
],
)
def test_bad_space_calls(space_fn):
with pytest.raises(AssertionError):
space_fn()
def test_contains_promotion():
space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (2,))))
assert (
np.array([0.0], dtype=np.float32),
np.array([0.0, 0.0], dtype=np.float32),
) in space
space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (1,))))
assert np.array([[0.0], [0.0]], dtype=np.float32) in space
def test_bad_seed():
space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(0, 1)))
with pytest.raises(
TypeError,
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
):
space.seed(0.0)

View File

@@ -29,25 +29,25 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
6,
6,
6,
# # Tuple
# 9,
# 7,
# 10,
# 6,
# None,
# # Dict
# 7,
# 8,
# 17,
# None,
# # Graph
# None,
# None,
# None,
# # Sequence
# None,
# None,
# None,
# Tuple
9,
7,
10,
6,
None,
# Dict
7,
8,
17,
None,
# Graph
None,
None,
None,
# Sequence
None,
None,
None,
]

View File

@@ -2,7 +2,18 @@ from typing import List
import numpy as np
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
from gym.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
)
TESTING_FUNDAMENTAL_SPACES = [
Discrete(3),
@@ -23,5 +34,70 @@ TESTING_FUNDAMENTAL_SPACES = [
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES # + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS # + TESTING_COMPOSITE_SPACES_IDS
TESTING_COMPOSITE_SPACES = [
# Tuple spaces
Tuple([Discrete(5), Discrete(4)]),
Tuple(
(
Discrete(5),
Box(
low=np.array([0.0, 0.0]),
high=np.array([1.0, 5.0]),
dtype=np.float64,
),
)
),
Tuple((Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))),
Tuple((Discrete(3), Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)))),
Tuple((Graph(node_space=Box(-1, 1, shape=(2, 1)), edge_space=None), Discrete(2))),
# Dict spaces
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0.0, 0.0]),
high=np.array([1.0, 5.0]),
dtype=np.float64,
),
}
),
Dict(
position=Discrete(6),
velocity=Box(
low=np.array([0.0, 0.0]),
high=np.array([1.0, 5.0]),
dtype=np.float64,
),
),
Dict(
{
"a": Box(low=0, high=1, shape=(3, 3)),
"b": Dict(
{
"b_1": Box(low=-100, high=100, shape=(2,)),
"b_2": Box(low=-1, high=1, shape=(2,)),
}
),
"c": Discrete(4),
}
),
Dict(
a=Dict(
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
b=Box(-100, 100, shape=(2, 2)),
),
b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
),
# Graph spaces
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
Graph(node_space=Discrete(3), edge_space=Discrete(4)),
# Sequence spaces
Sequence(Discrete(4)),
Sequence(Dict({"feature": Box(0, 1, (3,))})),
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
]
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS + TESTING_COMPOSITE_SPACES_IDS