mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 09:55:39 +00:00
Update Composite spaces with full coverage (#3047)
* Update composite space and tests
* Pre-commit
* pyright
* Fix pyright
* retrigger actions
* Code review by Arjun
* Code review by Arjun
* Code review by Omar
This commit is contained in:
@@ -3,7 +3,9 @@ from collections import OrderedDict
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict as TypingDict
|
from typing import Dict as TypingDict
|
||||||
from typing import Optional, Union
|
from typing import List, Optional
|
||||||
|
from typing import Sequence as TypingSequence
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -51,7 +53,12 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spaces: Optional[TypingDict[str, Space]] = None,
|
spaces: Optional[
|
||||||
|
Union[
|
||||||
|
TypingDict[str, Space],
|
||||||
|
TypingSequence[Tuple[str, Space]],
|
||||||
|
]
|
||||||
|
] = None,
|
||||||
seed: Optional[Union[dict, int, np.random.Generator]] = None,
|
seed: Optional[Union[dict, int, np.random.Generator]] = None,
|
||||||
**spaces_kwargs: Space,
|
**spaces_kwargs: Space,
|
||||||
):
|
):
|
||||||
@@ -74,12 +81,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
|
seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
|
||||||
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
|
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
|
||||||
"""
|
"""
|
||||||
assert (spaces is None) or (
|
# Convert the spaces into an OrderedDict
|
||||||
not spaces_kwargs
|
|
||||||
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
|
||||||
|
|
||||||
if spaces is None:
|
|
||||||
spaces = spaces_kwargs
|
|
||||||
if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict):
|
if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict):
|
||||||
try:
|
try:
|
||||||
spaces = OrderedDict(sorted(spaces.items()))
|
spaces = OrderedDict(sorted(spaces.items()))
|
||||||
@@ -87,16 +89,30 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
# Incomparable types (e.g. `int` vs. `str`, or user-defined types) found.
|
# Incomparable types (e.g. `int` vs. `str`, or user-defined types) found.
|
||||||
# The keys remain in the insertion order.
|
# The keys remain in the insertion order.
|
||||||
spaces = OrderedDict(spaces.items())
|
spaces = OrderedDict(spaces.items())
|
||||||
if isinstance(spaces, Sequence):
|
elif isinstance(spaces, Sequence):
|
||||||
spaces = OrderedDict(spaces)
|
spaces = OrderedDict(spaces)
|
||||||
|
elif spaces is None:
|
||||||
|
spaces = OrderedDict()
|
||||||
|
else:
|
||||||
|
assert isinstance(
|
||||||
|
spaces, OrderedDict
|
||||||
|
), f"Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: {type(spaces)}"
|
||||||
|
|
||||||
assert isinstance(spaces, OrderedDict), "spaces must be a dictionary"
|
# Add kwargs to spaces to allow both dictionary and keywords to be used
|
||||||
|
for key, space in spaces_kwargs.items():
|
||||||
|
if key not in spaces:
|
||||||
|
spaces[key] = space
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dict space keyword '{key}' already exists in the spaces dictionary."
|
||||||
|
)
|
||||||
|
|
||||||
self.spaces = spaces
|
self.spaces = spaces
|
||||||
for space in spaces.values():
|
for key, space in self.spaces.items():
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
space, Space
|
space, Space
|
||||||
), "Values of the dict should be instances of gym.Space"
|
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
None, None, seed # type: ignore
|
None, None, seed # type: ignore
|
||||||
) # None for shape and dtype, since it'll require special handling
|
) # None for shape and dtype, since it'll require special handling
|
||||||
@@ -120,27 +136,26 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
seeds = []
|
seeds = []
|
||||||
|
|
||||||
if isinstance(seed, dict):
|
if isinstance(seed, dict):
|
||||||
for key, seed_key in zip(self.spaces, seed):
|
assert (
|
||||||
assert key == seed_key, print(
|
seed.keys() == self.spaces.keys()
|
||||||
"Key value",
|
), f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
|
||||||
seed_key,
|
for key in seed.keys():
|
||||||
"in passed seed dict did not match key value",
|
seeds += self.spaces[key].seed(seed[key])
|
||||||
key,
|
|
||||||
"in spaces Dict.",
|
|
||||||
)
|
|
||||||
seeds += self.spaces[key].seed(seed[seed_key])
|
|
||||||
elif isinstance(seed, int):
|
elif isinstance(seed, int):
|
||||||
seeds = super().seed(seed)
|
seeds = super().seed(seed)
|
||||||
|
# Using `np.int32` means that the probability of the same key occurring is extremely low, even for large subspaces
|
||||||
subseeds = self.np_random.integers(
|
subseeds = self.np_random.integers(
|
||||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||||
)
|
)
|
||||||
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
||||||
seeds.append(subspace.seed(int(subseed))[0])
|
seeds += subspace.seed(int(subseed))
|
||||||
elif seed is None:
|
elif seed is None:
|
||||||
for space in self.spaces.values():
|
for space in self.spaces.values():
|
||||||
seeds += space.seed(seed)
|
seeds += space.seed(None)
|
||||||
else:
|
else:
|
||||||
raise TypeError("Passed seed not of an expected type: dict or int or None")
|
raise TypeError(
|
||||||
|
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
|
||||||
|
)
|
||||||
|
|
||||||
return seeds
|
return seeds
|
||||||
|
|
||||||
@@ -170,14 +185,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if not isinstance(x, dict) or len(x) != len(self.spaces):
|
if isinstance(x, dict) and x.keys() == self.spaces.keys():
|
||||||
return False
|
return all(x[key] in self.spaces[key] for key in self.spaces.keys())
|
||||||
for k, space in self.spaces.items():
|
return False
|
||||||
if k not in x:
|
|
||||||
return False
|
|
||||||
if not space.contains(x[k]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> Space:
|
def __getitem__(self, key: str) -> Space:
|
||||||
"""Get the space that is associated to `key`."""
|
"""Get the space that is associated to `key`."""
|
||||||
@@ -185,6 +195,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
def __setitem__(self, key: str, value: Space):
|
def __setitem__(self, key: str, value: Space):
|
||||||
"""Set the space that is associated to `key`."""
|
"""Set the space that is associated to `key`."""
|
||||||
|
assert isinstance(
|
||||||
|
value, Space
|
||||||
|
), f"Trying to set {key} to Dict space with value that is not a gym space, actual type: {type(value)}"
|
||||||
self.spaces[key] = value
|
self.spaces[key] = value
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
@@ -217,16 +230,16 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
for key, space in self.spaces.items()
|
for key, space in self.spaces.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: TypingDict[str, list]) -> list:
|
def from_jsonable(self, sample_n: TypingDict[str, list]) -> List[dict]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
dict_of_list: TypingDict[str, list] = {}
|
dict_of_list: TypingDict[str, list] = {
|
||||||
for key, space in self.spaces.items():
|
key: space.from_jsonable(sample_n[key])
|
||||||
dict_of_list[key] = space.from_jsonable(sample_n[key])
|
for key, space in self.spaces.items()
|
||||||
ret = []
|
}
|
||||||
|
|
||||||
n_elements = len(next(iter(dict_of_list.values())))
|
n_elements = len(next(iter(dict_of_list.values())))
|
||||||
for i in range(n_elements):
|
result = [
|
||||||
entry = {}
|
OrderedDict({key: value[n] for key, value in dict_of_list.items()})
|
||||||
for key, value in dict_of_list.items():
|
for n in range(n_elements)
|
||||||
entry[key] = value[i]
|
]
|
||||||
ret.append(entry)
|
return result
|
||||||
return ret
|
|
||||||
|
@@ -1,31 +1,27 @@
|
|||||||
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
|
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
|
||||||
from collections import namedtuple
|
|
||||||
from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from gym.logger import warn
|
||||||
from gym.spaces.box import Box
|
from gym.spaces.box import Box
|
||||||
from gym.spaces.discrete import Discrete
|
from gym.spaces.discrete import Discrete
|
||||||
from gym.spaces.multi_discrete import MultiDiscrete
|
from gym.spaces.multi_discrete import MultiDiscrete
|
||||||
from gym.spaces.space import Space
|
from gym.spaces.space import Space
|
||||||
|
|
||||||
|
|
||||||
class GraphInstance(namedtuple("GraphInstance", ["nodes", "edges", "edge_links"])):
|
class GraphInstance(NamedTuple):
|
||||||
r"""Returns a NamedTuple representing a graph object.
|
"""A Graph space instance.
|
||||||
|
|
||||||
Args:
|
* nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes, (...) must adhere to the shape of the node space.
|
||||||
nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes.
|
* edges (Optional[np.ndarray]): an (m x ...) sized array representing the features for m edges, (...) must adhere to the shape of the edge space.
|
||||||
(...) must adhere to the shape of the node space.
|
* edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the two nodes that each edge connects.
|
||||||
|
|
||||||
edges (np.ndarray): an (m x ...) sized array representing the features for m edges.
|
|
||||||
(...) must adhere to the shape of the edge space.
|
|
||||||
|
|
||||||
edge_links (np.ndarray): an (m x 2) sized array of ints representing the two nodes that each edge connects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A NamedTuple representing a graph with `.nodes`, `.edges`, and `.edge_links`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
nodes: np.ndarray
|
||||||
|
edges: Optional[np.ndarray]
|
||||||
|
edge_links: Optional[np.ndarray]
|
||||||
|
|
||||||
|
|
||||||
class Graph(Space):
|
class Graph(Space):
|
||||||
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
||||||
@@ -89,7 +85,7 @@ class Graph(Space):
|
|||||||
elif isinstance(base_space, Discrete):
|
elif isinstance(base_space, Discrete):
|
||||||
return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
|
return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
|
||||||
else:
|
else:
|
||||||
raise AssertionError(
|
raise TypeError(
|
||||||
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
|
f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -103,7 +99,7 @@ class Graph(Space):
|
|||||||
] = None,
|
] = None,
|
||||||
num_nodes: int = 10,
|
num_nodes: int = 10,
|
||||||
num_edges: Optional[int] = None,
|
num_edges: Optional[int] = None,
|
||||||
) -> NamedTuple:
|
) -> GraphInstance:
|
||||||
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
|
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -132,12 +128,17 @@ class Graph(Space):
|
|||||||
num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
|
num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
|
||||||
else:
|
else:
|
||||||
num_edges = 0
|
num_edges = 0
|
||||||
|
|
||||||
if edge_space_mask is not None:
|
if edge_space_mask is not None:
|
||||||
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
|
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
|
||||||
else:
|
else:
|
||||||
|
if self.edge_space is None:
|
||||||
|
warn(
|
||||||
|
f"The number of edges is set ({num_edges}) but the edge space is None."
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
num_edges >= 0
|
num_edges >= 0
|
||||||
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"
|
), f"Expects the number of edges to be greater than 0, actual value: {num_edges}"
|
||||||
assert num_edges is not None
|
assert num_edges is not None
|
||||||
|
|
||||||
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
|
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
|
||||||
@@ -160,38 +161,31 @@ class Graph(Space):
|
|||||||
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
|
return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
|
||||||
|
|
||||||
def contains(self, x: GraphInstance) -> bool:
|
def contains(self, x: GraphInstance) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space.
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
|
if isinstance(x, GraphInstance):
|
||||||
Returns False when:
|
# Checks the nodes
|
||||||
- any node in nodes is not contained in Graph.node_space
|
if isinstance(x.nodes, np.ndarray):
|
||||||
- edge_links is not of dtype int
|
if all(node in self.node_space for node in x.nodes):
|
||||||
- len(edge_links) != len(edges)
|
# Check the edges and edge links which are optional
|
||||||
- has edges but Graph.edge_space is None
|
if isinstance(x.edges, np.ndarray) and isinstance(
|
||||||
- edge_links has index less than 0
|
x.edge_links, np.ndarray
|
||||||
- edge_links has index more than number of nodes
|
):
|
||||||
- any edge in edges is not contained in Graph.edge_space
|
assert x.edges is not None
|
||||||
"""
|
assert x.edge_links is not None
|
||||||
if not isinstance(x, GraphInstance):
|
if self.edge_space is not None:
|
||||||
return False
|
if all(edge in self.edge_space for edge in x.edges):
|
||||||
if x.edges is not None:
|
if np.issubdtype(x.edge_links.dtype, np.integer):
|
||||||
if not np.issubdtype(x.edge_links.dtype, np.integer):
|
if x.edge_links.shape == (len(x.edges), 2):
|
||||||
return False
|
if np.all(
|
||||||
if x.edge_links.shape[-1] != 2:
|
np.logical_and(
|
||||||
return False
|
x.edge_links >= 0,
|
||||||
if self.edge_space is None:
|
x.edge_links < len(x.nodes),
|
||||||
return False
|
)
|
||||||
if x.edge_links.min() < 0:
|
):
|
||||||
return False
|
return True
|
||||||
if x.edge_links.max() >= len(x.nodes):
|
else:
|
||||||
return False
|
return x.edges is None and x.edge_links is None
|
||||||
if len(x.edges) != len(x.edge_links):
|
return False
|
||||||
return False
|
|
||||||
if any(edge not in self.edge_space for edge in x.edges):
|
|
||||||
return False
|
|
||||||
if any(node not in self.node_space for node in x.nodes):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""A string representation of this space.
|
"""A string representation of this space.
|
||||||
|
@@ -4,8 +4,8 @@ from typing import Any, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import gym
|
||||||
from gym.spaces.space import Space
|
from gym.spaces.space import Space
|
||||||
from gym.utils import seeding
|
|
||||||
|
|
||||||
|
|
||||||
class Sequence(Space[Tuple]):
|
class Sequence(Space[Tuple]):
|
||||||
@@ -25,7 +25,7 @@ class Sequence(Space[Tuple]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
space: Space,
|
space: Space,
|
||||||
seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None,
|
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||||
):
|
):
|
||||||
"""Constructor of the :class:`Sequence` space.
|
"""Constructor of the :class:`Sequence` space.
|
||||||
|
|
||||||
@@ -33,6 +33,9 @@ class Sequence(Space[Tuple]):
|
|||||||
space: Elements in the sequences this space represent must belong to this space.
|
space: Elements in the sequences this space represent must belong to this space.
|
||||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||||
"""
|
"""
|
||||||
|
assert isinstance(
|
||||||
|
space, gym.Space
|
||||||
|
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
|
||||||
self.feature_space = space
|
self.feature_space = space
|
||||||
super().__init__(
|
super().__init__(
|
||||||
None, None, seed # type: ignore
|
None, None, seed # type: ignore
|
||||||
@@ -50,17 +53,20 @@ class Sequence(Space[Tuple]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self, mask: Optional[Tuple[Optional[np.ndarray], Any]] = None
|
self,
|
||||||
|
mask: Optional[Tuple[Optional[Union[np.ndarray, int]], Optional[Any]]] = None,
|
||||||
) -> Tuple[Any]:
|
) -> Tuple[Any]:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
|
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
|
||||||
If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask`
|
If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask`
|
||||||
is either `None` if you do not want to specify any restrictions on the length of the sampled sequence (then, the
|
is
|
||||||
length will be randomly drawn from a geometric distribution), or a `np.ndarray` of integers, in which case the length of
|
- `None`, in which case the length is randomly drawn from a geometric distribution.
|
||||||
the sampled sequence is randomly drawn from this array. The second element of the tuple, `sample` mask
|
- `np.ndarray` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
|
||||||
specifies a mask that is applied when sampling elements from the base space.
|
- `int` for a fixed length sample
|
||||||
|
The second element of the mask tuple, `sample_mask`, specifies a mask that is applied when
|
||||||
|
sampling elements from the base space. The mask is applied for each feature space sample.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of random length with random samples of elements from the :attr:`feature_space`.
|
A tuple of random length with random samples of elements from the :attr:`feature_space`.
|
||||||
@@ -68,11 +74,28 @@ class Sequence(Space[Tuple]):
|
|||||||
if mask is not None:
|
if mask is not None:
|
||||||
length_mask, feature_mask = mask
|
length_mask, feature_mask = mask
|
||||||
else:
|
else:
|
||||||
length_mask = None
|
length_mask, feature_mask = None, None
|
||||||
feature_mask = None
|
|
||||||
if length_mask is not None:
|
if length_mask is not None:
|
||||||
length = self.np_random.choice(length_mask)
|
if np.issubdtype(type(length_mask), np.integer):
|
||||||
|
assert (
|
||||||
|
0 <= length_mask
|
||||||
|
), f"Expects the length mask to be greater than or equal to zero, actual value: {length_mask}"
|
||||||
|
length = length_mask
|
||||||
|
elif isinstance(length_mask, np.ndarray):
|
||||||
|
assert (
|
||||||
|
len(length_mask.shape) == 1
|
||||||
|
), f"Expects the shape of the length mask to be 1-dimensional, actual shape: {length_mask.shape}"
|
||||||
|
assert np.all(
|
||||||
|
0 <= length_mask
|
||||||
|
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
|
||||||
|
length = self.np_random.choice(length_mask)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Expects the type of length_mask to an integer or a np.ndarray, actual type: {type(length_mask)}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
|
# The choice of 0.25 is arbitrary
|
||||||
length = self.np_random.geometric(0.25)
|
length = self.np_random.geometric(0.25)
|
||||||
|
|
||||||
return tuple(
|
return tuple(
|
||||||
|
@@ -1,12 +1,16 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||||
from typing import Iterable, List, Optional, Sequence, Tuple, Union
|
from collections.abc import Sequence as CollectionSequence
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
from typing import Sequence as TypingSequence
|
||||||
|
from typing import Tuple as TypingTuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym.spaces.space import Space
|
from gym.spaces.space import Space
|
||||||
|
|
||||||
|
|
||||||
class Tuple(Space[tuple], Sequence):
|
class Tuple(Space[tuple], CollectionSequence):
|
||||||
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
|
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
|
||||||
|
|
||||||
Elements of this space are tuples of elements of the constituent spaces.
|
Elements of this space are tuples of elements of the constituent spaces.
|
||||||
@@ -22,7 +26,7 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spaces: Iterable[Space],
|
spaces: Iterable[Space],
|
||||||
seed: Optional[Union[int, List[int], np.random.Generator]] = None,
|
seed: Optional[Union[int, TypingSequence[int], np.random.Generator]] = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Tuple` space.
|
r"""Constructor of :class:`Tuple` space.
|
||||||
|
|
||||||
@@ -44,7 +48,9 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return all(space.is_np_flattenable for space in self.spaces)
|
return all(space.is_np_flattenable for space in self.spaces)
|
||||||
|
|
||||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
|
def seed(
|
||||||
|
self, seed: Optional[Union[int, TypingSequence[int]]] = None
|
||||||
|
) -> TypingSequence[int]:
|
||||||
"""Seed the PRNG of this space and all subspaces.
|
"""Seed the PRNG of this space and all subspaces.
|
||||||
|
|
||||||
Depending on the type of seed, the subspaces will be seeded differently
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
@@ -57,25 +63,32 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
"""
|
"""
|
||||||
seeds = []
|
seeds = []
|
||||||
|
|
||||||
if isinstance(seed, list):
|
if isinstance(seed, CollectionSequence):
|
||||||
for i, space in enumerate(self.spaces):
|
assert len(seed) == len(
|
||||||
seeds += space.seed(seed[i])
|
self.spaces
|
||||||
|
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
|
||||||
|
for subseed, space in zip(seed, self.spaces):
|
||||||
|
seeds += space.seed(subseed)
|
||||||
elif isinstance(seed, int):
|
elif isinstance(seed, int):
|
||||||
seeds = super().seed(seed)
|
seeds = super().seed(seed)
|
||||||
subseeds = self.np_random.integers(
|
subseeds = self.np_random.integers(
|
||||||
np.iinfo(np.int32).max, size=len(self.spaces)
|
np.iinfo(np.int32).max, size=len(self.spaces)
|
||||||
)
|
)
|
||||||
for subspace, subseed in zip(self.spaces, subseeds):
|
for subspace, subseed in zip(self.spaces, subseeds):
|
||||||
seeds.append(subspace.seed(int(subseed))[0])
|
seeds += subspace.seed(int(subseed))
|
||||||
elif seed is None:
|
elif seed is None:
|
||||||
for space in self.spaces:
|
for space in self.spaces:
|
||||||
seeds += space.seed(seed)
|
seeds += space.seed(seed)
|
||||||
else:
|
else:
|
||||||
raise TypeError("Passed seed not of an expected type: list or int or None")
|
raise TypeError(
|
||||||
|
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
|
||||||
|
)
|
||||||
|
|
||||||
return seeds
|
return seeds
|
||||||
|
|
||||||
def sample(self, mask: Optional[Tuple[Optional[np.ndarray]]] = None) -> tuple:
|
def sample(
|
||||||
|
self, mask: Optional[TypingTuple[Optional[np.ndarray], ...]] = None
|
||||||
|
) -> tuple:
|
||||||
"""Generates a single random sample inside this space.
|
"""Generates a single random sample inside this space.
|
||||||
|
|
||||||
This method draws independent samples from the subspaces.
|
This method draws independent samples from the subspaces.
|
||||||
@@ -116,7 +129,7 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
"""Gives a string representation of this space."""
|
"""Gives a string representation of this space."""
|
||||||
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: Sequence) -> list:
|
def to_jsonable(self, sample_n: CollectionSequence) -> list:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# serialize as list-repr of tuple of vectors
|
# serialize as list-repr of tuple of vectors
|
||||||
return [
|
return [
|
||||||
|
@@ -89,6 +89,13 @@ def _flatdim_dict(space: Dict) -> int:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@flatdim.register(Graph)
|
||||||
|
def _flatdim_graph(space: Graph):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot get flattened size as the Graph Space in Gym has a dynamic size."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@flatdim.register(Text)
|
@flatdim.register(Text)
|
||||||
def _flatdim_text(space: Text) -> int:
|
def _flatdim_text(space: Text) -> int:
|
||||||
return space.max_length
|
return space.max_length
|
||||||
@@ -157,11 +164,11 @@ def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
|
|||||||
return np.concatenate(
|
return np.concatenate(
|
||||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
||||||
)
|
)
|
||||||
return tuple((flatten(s, x_part) for x_part, s in zip(x, space.spaces)))
|
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))
|
||||||
|
|
||||||
|
|
||||||
@flatten.register(Dict)
|
@flatten.register(Dict)
|
||||||
def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]:
|
def _flatten_dict(space, x) -> Union[dict, np.ndarray]:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||||
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
||||||
@@ -171,14 +178,19 @@ def _flatten_dict(space, x) -> Union[TypingDict, np.ndarray]:
|
|||||||
def _flatten_graph(space, x) -> GraphInstance:
|
def _flatten_graph(space, x) -> GraphInstance:
|
||||||
"""We're not using `.unflatten()` for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
"""We're not using `.unflatten()` for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
||||||
|
|
||||||
def _graph_unflatten(space, x):
|
def _graph_unflatten(unflatten_space, unflatten_x):
|
||||||
ret = None
|
ret = None
|
||||||
if space is not None and x is not None:
|
if unflatten_space is not None and unflatten_x is not None:
|
||||||
if isinstance(space, Box):
|
if isinstance(unflatten_space, Box):
|
||||||
ret = x.reshape(x.shape[0], -1)
|
ret = unflatten_x.reshape(unflatten_x.shape[0], -1)
|
||||||
elif isinstance(space, Discrete):
|
elif isinstance(unflatten_space, Discrete):
|
||||||
ret = np.zeros((x.shape[0], space.n - space.start), dtype=space.dtype)
|
ret = np.zeros(
|
||||||
ret[np.arange(x.shape[0]), x - space.start] = 1
|
(unflatten_x.shape[0], unflatten_space.n - unflatten_space.start),
|
||||||
|
dtype=unflatten_space.dtype,
|
||||||
|
)
|
||||||
|
ret[
|
||||||
|
np.arange(unflatten_x.shape[0]), unflatten_x - unflatten_space.start
|
||||||
|
] = 1
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||||
|
140
tests/spaces/test_dict.py
Normal file
140
tests/spaces/test_dict.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gym.spaces import Box, Dict, Discrete
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_init():
|
||||||
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match=r"^Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: ",
|
||||||
|
):
|
||||||
|
Dict(Discrete(2))
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Dict space keyword 'a' already exists in the spaces dictionary",
|
||||||
|
):
|
||||||
|
Dict({"a": Discrete(3)}, a=Box(0, 1))
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match="Dict space element is not an instance of Space: key='b', space=Box",
|
||||||
|
):
|
||||||
|
Dict(a=Discrete(2), b="Box")
|
||||||
|
|
||||||
|
with pytest.warns(None) as warnings:
|
||||||
|
a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
|
||||||
|
b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
|
||||||
|
c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
|
||||||
|
d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))
|
||||||
|
|
||||||
|
assert a == b == c == d
|
||||||
|
assert len(warnings) == 0
|
||||||
|
|
||||||
|
with pytest.warns(None) as warnings:
|
||||||
|
Dict({1: Discrete(2), "a": Discrete(3)})
|
||||||
|
assert len(warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
|
DICT_SPACE = Dict(
|
||||||
|
{
|
||||||
|
"a": Box(low=0, high=1, shape=(3, 3)),
|
||||||
|
"b": Dict(
|
||||||
|
{
|
||||||
|
"b_1": Box(low=-100, high=100, shape=(2,)),
|
||||||
|
"b_2": Box(low=-1, high=1, shape=(2,)),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"c": Discrete(5),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_seeding():
|
||||||
|
seeds = DICT_SPACE.seed(
|
||||||
|
{
|
||||||
|
"a": 0,
|
||||||
|
"b": {
|
||||||
|
"b_1": 1,
|
||||||
|
"b_2": 2,
|
||||||
|
},
|
||||||
|
"c": 3,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert all(isinstance(seed, int) for seed in seeds)
|
||||||
|
|
||||||
|
# "Unpack" the dict sub-spaces into individual spaces
|
||||||
|
a = Box(low=0, high=1, shape=(3, 3), seed=0)
|
||||||
|
b_1 = Box(low=-100, high=100, shape=(2,), seed=1)
|
||||||
|
b_2 = Box(low=-1, high=1, shape=(2,), seed=2)
|
||||||
|
c = Discrete(5, seed=3)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
dict_sample = DICT_SPACE.sample()
|
||||||
|
assert np.all(dict_sample["a"] == a.sample())
|
||||||
|
assert np.all(dict_sample["b"]["b_1"] == b_1.sample())
|
||||||
|
assert np.all(dict_sample["b"]["b_2"] == b_2.sample())
|
||||||
|
assert dict_sample["c"] == c.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def test_int_seeding():
|
||||||
|
seeds = DICT_SPACE.seed(1)
|
||||||
|
assert all(isinstance(seed, int) for seed in seeds)
|
||||||
|
|
||||||
|
# rng, seeds = seeding.np_random(1)
|
||||||
|
# subseeds = rng.choice(np.iinfo(int).max, size=3, replace=False)
|
||||||
|
# b_rng, b_seeds = seeding.np_random(int(subseeds[1]))
|
||||||
|
# b_subseeds = b_rng.choice(np.iinfo(int).max, size=2, replace=False)
|
||||||
|
|
||||||
|
# "Unpack" the dict sub-spaces into individual spaces
|
||||||
|
a = Box(low=0, high=1, shape=(3, 3), seed=seeds[1])
|
||||||
|
b_1 = Box(low=-100, high=100, shape=(2,), seed=seeds[3])
|
||||||
|
b_2 = Box(low=-1, high=1, shape=(2,), seed=seeds[4])
|
||||||
|
c = Discrete(5, seed=seeds[5])
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
dict_sample = DICT_SPACE.sample()
|
||||||
|
assert np.all(dict_sample["a"] == a.sample())
|
||||||
|
assert np.all(dict_sample["b"]["b_1"] == b_1.sample())
|
||||||
|
assert np.all(dict_sample["b"]["b_2"] == b_2.sample())
|
||||||
|
assert dict_sample["c"] == c.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def test_none_seeding():
|
||||||
|
seeds = DICT_SPACE.seed(None)
|
||||||
|
assert len(seeds) == 4 and all(isinstance(seed, int) for seed in seeds)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_seed():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
DICT_SPACE.seed("a")
|
||||||
|
|
||||||
|
|
||||||
|
def test_mapping():
|
||||||
|
"""The Gym Dict space inherits from Mapping that allows it to appear like a standard python Dictionary."""
|
||||||
|
assert len(DICT_SPACE) == 3
|
||||||
|
|
||||||
|
a = DICT_SPACE["a"]
|
||||||
|
b = Discrete(5)
|
||||||
|
assert a != b
|
||||||
|
DICT_SPACE["a"] = b
|
||||||
|
assert DICT_SPACE["a"] == b
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match="Trying to set a to Dict space with value that is not a gym space, actual type: <class 'int'>",
|
||||||
|
):
|
||||||
|
DICT_SPACE["a"] = 5
|
||||||
|
|
||||||
|
DICT_SPACE["a"] = a
|
||||||
|
|
||||||
|
|
||||||
|
def test_iterator():
|
||||||
|
"""Tests the Dict `__iter__` function correctly returns keys in the subspaces"""
|
||||||
|
for key in DICT_SPACE:
|
||||||
|
assert key in DICT_SPACE.spaces
|
||||||
|
|
||||||
|
assert {key for key in DICT_SPACE} == DICT_SPACE.spaces.keys()
|
135
tests/spaces/test_graph.py
Normal file
135
tests/spaces/test_graph.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gym.spaces import Discrete, Graph, GraphInstance
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_space_sample():
|
||||||
|
space = Graph(node_space=Discrete(3), edge_space=None)
|
||||||
|
|
||||||
|
sample = space.sample(
|
||||||
|
mask=(tuple(np.array([0, 1, 0], dtype=np.int8) for _ in range(5)), None),
|
||||||
|
num_nodes=5,
|
||||||
|
)
|
||||||
|
assert sample in space
|
||||||
|
assert np.all(sample.nodes == 1)
|
||||||
|
|
||||||
|
sample = space.sample(
|
||||||
|
(
|
||||||
|
(np.array([1, 0, 0], dtype=np.int8), np.array([0, 1, 0], dtype=np.int8)),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
num_nodes=2,
|
||||||
|
)
|
||||||
|
assert sample in space
|
||||||
|
assert np.all(sample.nodes == np.array([0, 1]))
|
||||||
|
|
||||||
|
with pytest.warns(
|
||||||
|
UserWarning,
|
||||||
|
match=re.escape("The number of edges is set (5) but the edge space is None."),
|
||||||
|
):
|
||||||
|
sample = space.sample(num_edges=5)
|
||||||
|
assert sample in space
|
||||||
|
|
||||||
|
# Change the node_space or edge_space to a space that is neither Box nor Discrete.
|
||||||
|
# This should not happen, test is primarily to increase coverage.
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError,
|
||||||
|
match=re.escape(
|
||||||
|
"Expects base space to be Box and Discrete, actual space: <class 'str'>"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
space.node_space = "abc"
|
||||||
|
space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def test_edge_space_sample():
|
||||||
|
space = Graph(node_space=Discrete(3), edge_space=Discrete(3))
|
||||||
|
# When num_nodes is 1, no edges can exist, so the sampled edges are None
|
||||||
|
assert space.sample(num_nodes=1).edges is None
|
||||||
|
assert 0 <= len(space.sample(num_edges=3).edges) < 6
|
||||||
|
|
||||||
|
sample = space.sample(mask=(None, np.array([0, 1, 0], dtype=np.int8)))
|
||||||
|
assert np.all(sample.edges == 1)
|
||||||
|
|
||||||
|
sample = space.sample(
|
||||||
|
mask=(
|
||||||
|
None,
|
||||||
|
(
|
||||||
|
np.array([1, 0, 0], dtype=np.int8),
|
||||||
|
np.array([0, 1, 0], dtype=np.int8),
|
||||||
|
np.array([0, 0, 1], dtype=np.int8),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
num_edges=3,
|
||||||
|
)
|
||||||
|
assert np.all(sample.edges == np.array([0, 1, 2]))
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match="Expects the number of edges to be greater than 0, actual value: -1",
|
||||||
|
):
|
||||||
|
space.sample(num_edges=-1)
|
||||||
|
|
||||||
|
space = Graph(node_space=Discrete(3), edge_space=None)
|
||||||
|
with pytest.warns(
|
||||||
|
UserWarning,
|
||||||
|
match=re.escape(
|
||||||
|
"\x1b[33mWARN: The number of edges is set (5) but the edge space is None.\x1b[0m"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
sample = space.sample(num_edges=5)
|
||||||
|
assert sample.edges is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample",
|
||||||
|
[
|
||||||
|
"abc",
|
||||||
|
GraphInstance(
|
||||||
|
nodes=None, edges=np.array([0, 1]), edge_links=np.array([[0, 1], [1, 0]])
|
||||||
|
),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([10, 1, 0]),
|
||||||
|
edges=np.array([0, 1]),
|
||||||
|
edge_links=np.array([[0, 1], [1, 0]]),
|
||||||
|
),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([0, 1]), edges=None, edge_links=np.array([[0, 1], [1, 0]])
|
||||||
|
),
|
||||||
|
GraphInstance(nodes=np.array([0, 1]), edges=np.array([0, 1]), edge_links=None),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([1, 2]),
|
||||||
|
edges=np.array([10, 1]),
|
||||||
|
edge_links=np.array([[0, 1], [1, 0]]),
|
||||||
|
),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([1, 2]),
|
||||||
|
edges=np.array([0, 1]),
|
||||||
|
edge_links=np.array([[0.5, 1.0], [2.0, 1.0]]),
|
||||||
|
),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([1, 2]), edges=np.array([10, 1]), edge_links=np.array([0, 1])
|
||||||
|
),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([1, 2]),
|
||||||
|
edges=np.array([0, 1]),
|
||||||
|
edge_links=np.array([[[0], [1]], [[0], [0]]]),
|
||||||
|
),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([1, 2]),
|
||||||
|
edges=np.array([0, 1]),
|
||||||
|
edge_links=np.array([[10, 1], [0, 0]]),
|
||||||
|
),
|
||||||
|
GraphInstance(
|
||||||
|
nodes=np.array([1, 2]),
|
||||||
|
edges=np.array([0, 1]),
|
||||||
|
edge_links=np.array([[-10, 1], [0, 0]]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_not_contains(sample):
|
||||||
|
space = Graph(node_space=Discrete(2), edge_space=Discrete(2))
|
||||||
|
assert sample not in space
|
@@ -1,3 +1,19 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gym.spaces import MultiBinary
|
||||||
|
|
||||||
|
|
||||||
def test_sample():
|
def test_sample():
|
||||||
# todo
|
space = MultiBinary(4)
|
||||||
pass
|
|
||||||
|
sample = space.sample(mask=np.array([0, 0, 1, 1], dtype=np.int8))
|
||||||
|
assert np.all(sample == [0, 0, 1, 1])
|
||||||
|
|
||||||
|
sample = space.sample(mask=np.array([0, 1, 2, 2], dtype=np.int8))
|
||||||
|
assert sample[0] == 0 and sample[1] == 1
|
||||||
|
assert sample[2] == 0 or sample[2] == 1
|
||||||
|
assert sample[3] == 0 or sample[3] == 1
|
||||||
|
|
||||||
|
space = MultiBinary(np.array([2, 3]))
|
||||||
|
sample = space.sample(mask=np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int8))
|
||||||
|
assert np.all(sample == [[0, 0, 0], [1, 1, 1]]), sample
|
||||||
|
59
tests/spaces/test_sequence.py
Normal file
59
tests/spaces/test_sequence.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gym.spaces
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample():
|
||||||
|
"""Tests the sequence sampling works as expects and the errors are correctly raised."""
|
||||||
|
space = gym.spaces.Sequence(gym.spaces.Box(0, 1))
|
||||||
|
|
||||||
|
# Test integer mask length
|
||||||
|
for length in range(4):
|
||||||
|
sample = space.sample(mask=(length, None))
|
||||||
|
assert sample in space
|
||||||
|
assert len(sample) == length
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match=re.escape(
|
||||||
|
"Expects the length mask to be greater than or equal to zero, actual value: -1"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
space.sample(mask=(-1, None))
|
||||||
|
|
||||||
|
# Test np.array mask length
|
||||||
|
sample = space.sample(mask=(np.array([5]), None))
|
||||||
|
assert sample in space
|
||||||
|
assert len(sample) == 5
|
||||||
|
|
||||||
|
sample = space.sample(mask=(np.array([3, 4, 5]), None))
|
||||||
|
assert sample in space
|
||||||
|
assert len(sample) in [3, 4, 5]
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match=re.escape(
|
||||||
|
"Expects the shape of the length mask to be 1-dimensional, actual shape: (2, 2)"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
space.sample(mask=(np.array([[2, 2], [2, 2]]), None))
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
AssertionError,
|
||||||
|
match=re.escape(
|
||||||
|
"Expects all values in the length_mask to be greater than or equal to zero, actual values: [ 1 2 -1]"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
space.sample(mask=(np.array([1, 2, -1]), None))
|
||||||
|
|
||||||
|
# Test with an invalid length
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError,
|
||||||
|
match=re.escape(
|
||||||
|
"Expects the type of length_mask to an integer or a np.ndarray, actual type: <class 'str'>"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
space.sample(mask=("abc", None))
|
@@ -409,10 +409,10 @@ SPACE_KWARGS = [
|
|||||||
{"nvec": [3, 2]}, # MultiDiscrete
|
{"nvec": [3, 2]}, # MultiDiscrete
|
||||||
{"n": 2}, # MultiBinary
|
{"n": 2}, # MultiBinary
|
||||||
{"max_length": 5}, # Text
|
{"max_length": 5}, # Text
|
||||||
# {"spaces": (Discrete(3), Discrete(2))}, # Tuple
|
{"spaces": (Discrete(3), Discrete(2))}, # Tuple
|
||||||
# {"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
|
{"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
|
||||||
# {"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
|
{"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
|
||||||
# {"space": Discrete(4)}, # Sequence
|
{"space": Discrete(4)}, # Sequence
|
||||||
]
|
]
|
||||||
assert len(SPACE_CLS) == len(SPACE_KWARGS)
|
assert len(SPACE_CLS) == len(SPACE_KWARGS)
|
||||||
|
|
||||||
|
109
tests/spaces/test_tuple.py
Normal file
109
tests/spaces/test_tuple.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gym.spaces
|
||||||
|
from gym.spaces import Box, Dict, Discrete, MultiBinary, Tuple
|
||||||
|
from gym.utils.env_checker import data_equivalence
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequence_inheritance():
|
||||||
|
"""The gym Tuple space inherits from abc.Sequences, this test checks all functions work"""
|
||||||
|
spaces = [Discrete(5), Discrete(10), Discrete(5)]
|
||||||
|
tuple_space = Tuple(spaces)
|
||||||
|
|
||||||
|
assert len(tuple_space) == len(spaces)
|
||||||
|
# Test indexing
|
||||||
|
for i in range(len(tuple_space)):
|
||||||
|
assert tuple_space[i] == spaces[i]
|
||||||
|
|
||||||
|
# Test iterable
|
||||||
|
for space in tuple_space:
|
||||||
|
assert space in spaces
|
||||||
|
|
||||||
|
# Test count
|
||||||
|
assert tuple_space.count(Discrete(5)) == 2
|
||||||
|
assert tuple_space.count(Discrete(6)) == 0
|
||||||
|
assert tuple_space.count(MultiBinary(2)) == 0
|
||||||
|
|
||||||
|
# Test index
|
||||||
|
assert tuple_space.index(Discrete(5)) == 0
|
||||||
|
assert tuple_space.index(Discrete(5), 1) == 2
|
||||||
|
|
||||||
|
# Test errors
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
tuple_space.index(Discrete(10), 0, 1)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
assert tuple_space[4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"space, seed, expected_len",
|
||||||
|
[
|
||||||
|
(Tuple([Discrete(5), Discrete(4)]), None, 2),
|
||||||
|
(Tuple([Discrete(5), Discrete(4)]), 123, 3),
|
||||||
|
(Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
|
||||||
|
(
|
||||||
|
Tuple(
|
||||||
|
(Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
|
||||||
|
),
|
||||||
|
(123, (456, 789)),
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
Tuple(
|
||||||
|
(
|
||||||
|
Discrete(3),
|
||||||
|
Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(123, {"position": 456, "velocity": 789}),
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_seeds(space, seed, expected_len):
|
||||||
|
seeds = space.seed(seed)
|
||||||
|
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
|
||||||
|
assert len(seeds) == expected_len
|
||||||
|
|
||||||
|
sample1 = space.sample()
|
||||||
|
|
||||||
|
seeds2 = space.seed(seed)
|
||||||
|
sample2 = space.sample()
|
||||||
|
|
||||||
|
data_equivalence(seeds, seeds2)
|
||||||
|
data_equivalence(sample1, sample2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"space_fn",
|
||||||
|
[
|
||||||
|
lambda: Tuple(["abc"]),
|
||||||
|
lambda: Tuple([gym.spaces.Box(0, 1), "abc"]),
|
||||||
|
lambda: Tuple("abc"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_bad_space_calls(space_fn):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
space_fn()
|
||||||
|
|
||||||
|
|
||||||
|
def test_contains_promotion():
|
||||||
|
space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (2,))))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
np.array([0.0], dtype=np.float32),
|
||||||
|
np.array([0.0, 0.0], dtype=np.float32),
|
||||||
|
) in space
|
||||||
|
|
||||||
|
space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (1,))))
|
||||||
|
assert np.array([[0.0], [0.0]], dtype=np.float32) in space
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_seed():
|
||||||
|
space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(0, 1)))
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError,
|
||||||
|
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
|
||||||
|
):
|
||||||
|
space.seed(0.0)
|
@@ -29,25 +29,25 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
|
|||||||
6,
|
6,
|
||||||
6,
|
6,
|
||||||
6,
|
6,
|
||||||
# # Tuple
|
# Tuple
|
||||||
# 9,
|
9,
|
||||||
# 7,
|
7,
|
||||||
# 10,
|
10,
|
||||||
# 6,
|
6,
|
||||||
# None,
|
None,
|
||||||
# # Dict
|
# Dict
|
||||||
# 7,
|
7,
|
||||||
# 8,
|
8,
|
||||||
# 17,
|
17,
|
||||||
# None,
|
None,
|
||||||
# # Graph
|
# Graph
|
||||||
# None,
|
None,
|
||||||
# None,
|
None,
|
||||||
# None,
|
None,
|
||||||
# # Sequence
|
# Sequence
|
||||||
# None,
|
None,
|
||||||
# None,
|
None,
|
||||||
# None,
|
None,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2,7 +2,18 @@ from typing import List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
|
from gym.spaces import (
|
||||||
|
Box,
|
||||||
|
Dict,
|
||||||
|
Discrete,
|
||||||
|
Graph,
|
||||||
|
MultiBinary,
|
||||||
|
MultiDiscrete,
|
||||||
|
Sequence,
|
||||||
|
Space,
|
||||||
|
Text,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
TESTING_FUNDAMENTAL_SPACES = [
|
TESTING_FUNDAMENTAL_SPACES = [
|
||||||
Discrete(3),
|
Discrete(3),
|
||||||
@@ -23,5 +34,70 @@ TESTING_FUNDAMENTAL_SPACES = [
|
|||||||
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]
|
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]
|
||||||
|
|
||||||
|
|
||||||
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES # + TESTING_COMPOSITE_SPACES
|
TESTING_COMPOSITE_SPACES = [
|
||||||
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS # + TESTING_COMPOSITE_SPACES_IDS
|
# Tuple spaces
|
||||||
|
Tuple([Discrete(5), Discrete(4)]),
|
||||||
|
Tuple(
|
||||||
|
(
|
||||||
|
Discrete(5),
|
||||||
|
Box(
|
||||||
|
low=np.array([0.0, 0.0]),
|
||||||
|
high=np.array([1.0, 5.0]),
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
Tuple((Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))),
|
||||||
|
Tuple((Discrete(3), Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)))),
|
||||||
|
Tuple((Graph(node_space=Box(-1, 1, shape=(2, 1)), edge_space=None), Discrete(2))),
|
||||||
|
# Dict spaces
|
||||||
|
Dict(
|
||||||
|
{
|
||||||
|
"position": Discrete(5),
|
||||||
|
"velocity": Box(
|
||||||
|
low=np.array([0.0, 0.0]),
|
||||||
|
high=np.array([1.0, 5.0]),
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Dict(
|
||||||
|
position=Discrete(6),
|
||||||
|
velocity=Box(
|
||||||
|
low=np.array([0.0, 0.0]),
|
||||||
|
high=np.array([1.0, 5.0]),
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Dict(
|
||||||
|
{
|
||||||
|
"a": Box(low=0, high=1, shape=(3, 3)),
|
||||||
|
"b": Dict(
|
||||||
|
{
|
||||||
|
"b_1": Box(low=-100, high=100, shape=(2,)),
|
||||||
|
"b_2": Box(low=-1, high=1, shape=(2,)),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"c": Discrete(4),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Dict(
|
||||||
|
a=Dict(
|
||||||
|
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
|
||||||
|
b=Box(-100, 100, shape=(2, 2)),
|
||||||
|
),
|
||||||
|
b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
|
||||||
|
),
|
||||||
|
# Graph spaces
|
||||||
|
Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
|
||||||
|
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
|
||||||
|
Graph(node_space=Discrete(3), edge_space=Discrete(4)),
|
||||||
|
# Sequence spaces
|
||||||
|
Sequence(Discrete(4)),
|
||||||
|
Sequence(Dict({"feature": Box(0, 1, (3,))})),
|
||||||
|
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
|
||||||
|
]
|
||||||
|
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]
|
||||||
|
|
||||||
|
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES + TESTING_COMPOSITE_SPACES
|
||||||
|
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS + TESTING_COMPOSITE_SPACES_IDS
|
||||||
|
@@ -23,7 +23,7 @@ def test_transform_reward(env_id):
|
|||||||
_, wrapped_reward, _, _, _ = wrapped_env.step(action)
|
_, wrapped_reward, _, _, _ = wrapped_env.step(action)
|
||||||
|
|
||||||
assert wrapped_reward == scale * reward
|
assert wrapped_reward == scale * reward
|
||||||
del env, wrapped_env
|
del env, wrapped_env
|
||||||
|
|
||||||
# use case #2: clip
|
# use case #2: clip
|
||||||
min_r = -0.0005
|
min_r = -0.0005
|
||||||
|
Reference in New Issue
Block a user