Add testing for Fundamental spaces with full coverage (#3048)

* initial commit

* Fix the multi-binary sample and upgrade multi-discrete sample

* Fix MultiBinary sample tests and pre-commit

* Adds coverage tests and updates test_utils.py to use the utils.py spaces. Fix bug in text.py

* pre-commit
This commit is contained in:
Mark Towers
2022-09-03 22:56:29 +01:00
committed by GitHub
parent 43b42d5280
commit 8e74fe3b62
17 changed files with 933 additions and 1456 deletions

View File

@@ -164,7 +164,9 @@ class Box(Space[np.ndarray]):
elif manner == "above": elif manner == "above":
return above return above
else: else:
raise ValueError("manner is not in {'below', 'above', 'both'}") raise ValueError(
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
)
def sample(self, mask: None = None) -> np.ndarray: def sample(self, mask: None = None) -> np.ndarray:
r"""Generates a single random sample inside the Box. r"""Generates a single random sample inside the Box.
@@ -223,7 +225,10 @@ class Box(Space[np.ndarray]):
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
if not isinstance(x, np.ndarray): if not isinstance(x, np.ndarray):
logger.warn("Casting input x to numpy array.") logger.warn("Casting input x to numpy array.")
x = np.asarray(x, dtype=self.dtype) try:
x = np.asarray(x, dtype=self.dtype)
except (ValueError, TypeError):
return False
return bool( return bool(
np.can_cast(x.dtype, self.dtype) np.can_cast(x.dtype, self.dtype)
@@ -236,7 +241,7 @@ class Box(Space[np.ndarray]):
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist() return np.array(sample_n).tolist()
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> List[np.ndarray]: def from_jsonable(self, sample_n: Sequence[Union[float, int]]) -> List[np.ndarray]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample) for sample in sample_n] return [np.asarray(sample) for sample in sample_n]
@@ -252,10 +257,11 @@ class Box(Space[np.ndarray]):
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})" return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
"""Check whether `other` is equivalent to this instance.""" """Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
return ( return (
isinstance(other, Box) isinstance(other, Box)
and (self.shape == other.shape) and (self.shape == other.shape)
# and (self.dtype == other.dtype)
and np.allclose(self.low, other.low) and np.allclose(self.low, other.low)
and np.allclose(self.high, other.high) and np.allclose(self.high, other.high)
) )

View File

@@ -85,18 +85,19 @@ class Discrete(Space[int]):
if isinstance(x, int): if isinstance(x, int):
as_int = x as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and ( elif isinstance(x, (np.generic, np.ndarray)) and (
x.dtype.char in np.typecodes["AllInteger"] and x.shape == () np.issubdtype(x.dtype, np.integer) and x.shape == ()
): ):
as_int = int(x) # type: ignore as_int = int(x) # type: ignore
else: else:
return False return False
return self.start <= as_int < self.start + self.n return self.start <= as_int < self.start + self.n
def __repr__(self) -> str: def __repr__(self) -> str:
"""Gives a string representation of this space.""" """Gives a string representation of this space."""
if self.start != 0: if self.start != 0:
return "Discrete(%d, start=%d)" % (self.n, self.start) return f"Discrete({self.n}, start={self.start})"
return "Discrete(%d)" % self.n return f"Discrete({self.n})"
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
"""Check whether ``other`` is equivalent to this instance.""" """Check whether ``other`` is equivalent to this instance."""
@@ -114,8 +115,6 @@ class Discrete(Space[int]):
Args: Args:
state: The new state state: The new state
""" """
super().__setstate__(state)
# Don't mutate the original state # Don't mutate the original state
state = dict(state) state = dict(state)
@@ -124,5 +123,4 @@ class Discrete(Space[int]):
if "start" not in state: if "start" not in state:
state["start"] = 0 state["start"] = 0
# Update our state super().__setstate__(state)
self.__dict__.update(state)

View File

@@ -6,7 +6,7 @@ import numpy as np
from gym.spaces.box import Box from gym.spaces.box import Box
from gym.spaces.discrete import Discrete from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import SAMPLE_MASK_TYPE, MultiDiscrete from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.space import Space from gym.spaces.space import Space
@@ -97,8 +97,8 @@ class Graph(Space):
self, self,
mask: Optional[ mask: Optional[
Tuple[ Tuple[
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]], Optional[Union[np.ndarray, tuple]],
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]], Optional[Union[np.ndarray, tuple]],
] ]
] = None, ] = None,
num_nodes: int = 10, num_nodes: int = 10,

View File

@@ -62,7 +62,8 @@ class MultiBinary(Space[np.ndarray]):
Args: Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``. mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
Where mask == 0 then the samples will be 0. For mask == 0 then the samples will be 0 and mask == 1 then random samples will be generated.
The expected mask shape is the space shape and mask dtype is `np.int8`.
Returns: Returns:
Sampled values from space Sampled values from space
@@ -78,11 +79,13 @@ class MultiBinary(Space[np.ndarray]):
mask.shape == self.shape mask.shape == self.shape
), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}" ), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}"
assert np.all( assert np.all(
np.logical_or(mask == 0, mask == 1) (mask == 0) | (mask == 1) | (mask == 2)
), f"All values of a mask should be 0 or 1, actual values: {mask}" ), f"All values of a mask should be 0, 1 or 2, actual values: {mask}"
return mask * self.np_random.integers( return np.where(
low=0, high=2, size=self.n, dtype=self.dtype mask == 2,
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
mask.astype(self.dtype),
) )
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype) return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
@@ -91,9 +94,12 @@ class MultiBinary(Space[np.ndarray]):
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
if isinstance(x, Sequence): if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check x = np.array(x) # Promote list to array for contains check
if self.shape != x.shape:
return False return bool(
return ((x == 0) | (x == 1)).all() isinstance(x, np.ndarray)
and self.shape == x.shape
and np.all((x == 0) | (x == 1))
)
def to_jsonable(self, sample_n) -> list: def to_jsonable(self, sample_n) -> list:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
@@ -101,7 +107,7 @@ class MultiBinary(Space[np.ndarray]):
def from_jsonable(self, sample_n) -> list: def from_jsonable(self, sample_n) -> list:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample) for sample in sample_n] return [np.asarray(sample, self.dtype) for sample in sample_n]
def __repr__(self) -> str: def __repr__(self) -> str:
"""Gives a string representation of this space.""" """Gives a string representation of this space."""

View File

@@ -7,8 +7,6 @@ from gym import logger
from gym.spaces.discrete import Discrete from gym.spaces.discrete import Discrete
from gym.spaces.space import Space from gym.spaces.space import Space
SAMPLE_MASK_TYPE = Tuple[Union["SAMPLE_MASK_TYPE", np.ndarray], ...]
class MultiDiscrete(Space[np.ndarray]): class MultiDiscrete(Space[np.ndarray]):
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces. """This represents the cartesian product of arbitrary :class:`Discrete` spaces.
@@ -39,7 +37,7 @@ class MultiDiscrete(Space[np.ndarray]):
def __init__( def __init__(
self, self,
nvec: Union[np.ndarray, List[int]], nvec: Union[np.ndarray, list],
dtype=np.int64, dtype=np.int64,
seed: Optional[Union[int, np.random.Generator]] = None, seed: Optional[Union[int, np.random.Generator]] = None,
): ):
@@ -68,7 +66,7 @@ class MultiDiscrete(Space[np.ndarray]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" """Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True return True
def sample(self, mask: Optional[SAMPLE_MASK_TYPE] = None) -> np.ndarray: def sample(self, mask: Optional[tuple] = None) -> np.ndarray:
"""Generates a single random sample this space. """Generates a single random sample this space.
Args: Args:
@@ -82,15 +80,30 @@ class MultiDiscrete(Space[np.ndarray]):
if mask is not None: if mask is not None:
def _apply_mask( def _apply_mask(
sub_mask: SAMPLE_MASK_TYPE, sub_nvec: np.ndarray sub_mask: Union[np.ndarray, tuple],
sub_nvec: Union[np.ndarray, np.integer],
) -> Union[int, List[int]]: ) -> Union[int, List[int]]:
if isinstance(sub_mask, np.ndarray): if isinstance(sub_nvec, np.ndarray):
assert isinstance(
sub_mask, tuple
), f"Expects the mask to be a tuple for sub_nvec ({sub_nvec}), actual type: {type(sub_mask)}"
assert len(sub_mask) == len(
sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
return [
_apply_mask(new_mask, new_nvec)
for new_mask, new_nvec in zip(sub_mask, sub_nvec)
]
else:
assert np.issubdtype( assert np.issubdtype(
type(sub_nvec), np.integer type(sub_nvec), np.integer
), f"Expects the mask to be for an action, actual for {sub_nvec}" ), f"Expects the sub_nvec to be an action, actually: {sub_nvec}, {type(sub_nvec)}"
assert isinstance(
sub_mask, np.ndarray
), f"Expects the sub mask to be np.ndarray, actual type: {type(sub_mask)}"
assert ( assert (
len(sub_mask) == sub_nvec len(sub_mask) == sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {sub_nvec}" ), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, action: {sub_nvec}"
assert ( assert (
sub_mask.dtype == np.int8 sub_mask.dtype == np.int8
), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}" ), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}"
@@ -104,17 +117,6 @@ class MultiDiscrete(Space[np.ndarray]):
return self.np_random.choice(np.where(valid_action_mask)[0]) return self.np_random.choice(np.where(valid_action_mask)[0])
else: else:
return 0 return 0
else:
assert isinstance(
sub_mask, tuple
), f"Expects the mask to be a tuple or np.ndarray, actual type: {type(sub_mask)}"
assert len(sub_mask) == len(
sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
return [
_apply_mask(new_mask, new_nvec)
for new_mask, new_nvec in zip(sub_mask, sub_nvec)
]
return np.array(_apply_mask(mask, self.nvec), dtype=self.dtype) return np.array(_apply_mask(mask, self.nvec), dtype=self.dtype)
@@ -124,9 +126,16 @@ class MultiDiscrete(Space[np.ndarray]):
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
if isinstance(x, Sequence): if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check x = np.array(x) # Promote list to array for contains check
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x # if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
# is within correct bounds for space dtype (even though x does not have to be unsigned) # is within correct bounds for space dtype (even though x does not have to be unsigned)
return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all()) return bool(
isinstance(x, np.ndarray)
and x.shape == self.shape
and x.dtype != object
and np.all(0 <= x)
and np.all(x < self.nvec)
)
def to_jsonable(self, sample_n: Iterable[np.ndarray]): def to_jsonable(self, sample_n: Iterable[np.ndarray]):
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
@@ -147,13 +156,18 @@ class MultiDiscrete(Space[np.ndarray]):
subspace = Discrete(nvec) subspace = Discrete(nvec)
else: else:
subspace = MultiDiscrete(nvec, self.dtype) # type: ignore subspace = MultiDiscrete(nvec, self.dtype) # type: ignore
# you don't need to deepcopy as np random generator call replaces the state not the data
subspace.np_random.bit_generator.state = self.np_random.bit_generator.state subspace.np_random.bit_generator.state = self.np_random.bit_generator.state
return subspace return subspace
def __len__(self): def __len__(self):
"""Gives the ``len`` of samples from this space.""" """Gives the ``len`` of samples from this space."""
if self.nvec.ndim >= 2: if self.nvec.ndim >= 2:
logger.warn("Get length of a multi-dimensional MultiDiscrete space.") logger.warn(
"Getting the length of a multi-dimensional MultiDiscrete space."
)
return len(self.nvec) return len(self.nvec)
def __eq__(self, other): def __eq__(self, other):

View File

@@ -1,5 +1,5 @@
"""Implementation of a space that represents textual strings.""" """Implementation of a space that represents textual strings."""
from typing import Any, FrozenSet, List, Optional, Set, Tuple, Union from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
import numpy as np import numpy as np
@@ -27,7 +27,7 @@ class Text(Space[str]):
self, self,
max_length: int, max_length: int,
*, *,
min_length: int = 0, min_length: int = 1,
charset: Union[Set[str], str] = alphanumeric, charset: Union[Set[str], str] = alphanumeric,
seed: Optional[Union[int, np.random.Generator]] = None, seed: Optional[Union[int, np.random.Generator]] = None,
): ):
@@ -36,9 +36,9 @@ class Text(Space[str]):
Both bounds for text length are inclusive. Both bounds for text length are inclusive.
Args: Args:
min_length (int): Minimum text length (in characters). min_length (int): Minimum text length (in characters). Defaults to 1 to prevent empty strings.
max_length (int): Maximum text length (in characters). max_length (int): Maximum text length (in characters).
charset (Union[set, SupportsIndex]): Character set, defaults to the lower and upper english alphabet plus latin digits. charset (Union[set], str): Character set, defaults to the lower and upper english alphabet plus latin digits.
seed: The seed for sampling from the space. seed: The seed for sampling from the space.
""" """
assert np.issubdtype( assert np.issubdtype(
@@ -56,9 +56,13 @@ class Text(Space[str]):
self.min_length: int = int(min_length) self.min_length: int = int(min_length)
self.max_length: int = int(max_length) self.max_length: int = int(max_length)
self.charset: FrozenSet[str] = frozenset(charset)
self._charlist: List[str] = list(charset) self._char_set: FrozenSet[str] = frozenset(charset)
self._charset_str: str = "".join(sorted(self._charlist)) self._char_list: Tuple[str, ...] = tuple(charset)
self._char_index: Dict[str, np.int32] = {
val: np.int32(i) for i, val in enumerate(tuple(charset))
}
self._char_str: str = "".join(sorted(tuple(charset)))
# As the shape is dynamic (between min_length and max_length) then None # As the shape is dynamic (between min_length and max_length) then None
super().__init__(dtype=str, seed=seed) super().__init__(dtype=str, seed=seed)
@@ -71,20 +75,42 @@ class Text(Space[str]):
Args: Args:
mask: An optional tuples of length and mask for the text. mask: An optional tuples of length and mask for the text.
The length is expected to be between the `min_length` and `max_length` otherwise a random integer between `min_length` and `max_length` is selected. The length is expected to be between the `min_length` and `max_length` otherwise a random integer between `min_length` and `max_length` is selected.
For the mask, we expect a numpy array of length of the charset passed with dtype == np.int8 For the mask, we expect a numpy array of length of the charset passed with `dtype == np.int8`.
If the charlist mask is all zero then an empty string is returned no matter the `min_length`
Returns: Returns:
A sampled string from the space A sampled string from the space
""" """
if mask is not None: if mask is not None:
assert isinstance(
mask, tuple
), f"Expects the mask type to be a tuple, actual type: {type(mask)}"
assert (
len(mask) == 2
), f"Expects the mask length to be two, actual length: {len(mask)}"
length, charlist_mask = mask length, charlist_mask = mask
if length is not None: if length is not None:
assert self.min_length <= length <= self.max_length assert np.issubdtype(
type(length), np.integer
), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
assert (
self.min_length <= length <= self.max_length
), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"
if charlist_mask is not None: if charlist_mask is not None:
assert isinstance(charlist_mask, np.ndarray) assert isinstance(
assert charlist_mask.dtype is np.int8 charlist_mask, np.ndarray
assert charlist_mask.shape == (len(self._charlist),) ), f"Expects the Text sample mask to be an np.ndarray, actual type: {type(charlist_mask)}"
assert (
charlist_mask.dtype == np.int8
), f"Expects the Text sample mask to be an np.ndarray, actual dtype: {charlist_mask.dtype}"
assert charlist_mask.shape == (
len(self.character_set),
), f"expects the Text sample mask to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"
assert np.all(
np.logical_or(charlist_mask == 0, charlist_mask == 1)
), f"Expects all masks values to 0 or 1, actual values: {charlist_mask}"
else: else:
length, charlist_mask = None, None length, charlist_mask = None, None
@@ -92,10 +118,23 @@ class Text(Space[str]):
length = self.np_random.integers(self.min_length, self.max_length + 1) length = self.np_random.integers(self.min_length, self.max_length + 1)
if charlist_mask is None: if charlist_mask is None:
string = self.np_random.choice(self._charlist, size=length) string = self.np_random.choice(self.character_list, size=length)
else: else:
masked_charlist = self._charlist[np.where(mask)[0]] valid_mask = charlist_mask == 1
string = self.np_random.choice(masked_charlist, size=length) valid_indexes = np.where(valid_mask)[0]
if len(valid_indexes) == 0:
if self.min_length == 0:
string = ""
else:
# Otherwise the string will not be contained in the space
raise ValueError(
f"Trying to sample with a minimum length > 0 ({self.min_length}) but the character mask is all zero meaning that no character could be sampled."
)
else:
string = "".join(
self.character_list[index]
for index in self.np_random.choice(valid_indexes, size=length)
)
return "".join(string) return "".join(string)
@@ -103,13 +142,13 @@ class Text(Space[str]):
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""
if isinstance(x, str): if isinstance(x, str):
if self.min_length <= len(x) <= self.max_length: if self.min_length <= len(x) <= self.max_length:
return all(c in self.charset for c in x) return all(c in self.character_set for c in x)
return False return False
def __repr__(self) -> str: def __repr__(self) -> str:
"""Gives a string representation of this space.""" """Gives a string representation of this space."""
return ( return (
f"Text({self.min_length}, {self.max_length}, charset={self._charset_str})" f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
) )
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
@@ -118,5 +157,29 @@ class Text(Space[str]):
isinstance(other, Text) isinstance(other, Text)
and self.min_length == other.min_length and self.min_length == other.min_length
and self.max_length == other.max_length and self.max_length == other.max_length
and self.charset == other.charset and self.character_set == other.character_set
) )
@property
def character_set(self) -> FrozenSet[str]:
    """Return the frozen set of characters stored on this space (``_char_set``)."""
    allowed_characters = self._char_set
    return allowed_characters
@property
def character_list(self) -> Tuple[str, ...]:
    """Return the tuple of characters stored on this space (``_char_list``)."""
    ordered_characters = self._char_list
    return ordered_characters
def character_index(self, char: str) -> np.int32:
    """Look up the index recorded for ``char`` in the space's character-to-index mapping.

    Args:
        char: The character to look up.

    Returns:
        The value stored under ``char`` in ``_char_index``.
    """
    index_by_character = self._char_index
    return index_by_character[char]
@property
def characters(self) -> str:
    """Return the string of all Text characters stored on this space (``_char_str``)."""
    all_characters = self._char_str
    return all_characters
@property
def is_np_flattenable(self) -> bool:
    """Report that this space can be flattened; always ``True``.

    The flattened version is an integer array with one entry per character,
    padded to the max character length.
    """
    return True

View File

@@ -21,6 +21,7 @@ from gym.spaces import (
MultiDiscrete, MultiDiscrete,
Sequence, Sequence,
Space, Space,
Text,
Tuple, Tuple,
) )
@@ -88,6 +89,11 @@ def _flatdim_dict(space: Dict) -> int:
) )
@flatdim.register(Text)
def _flatdim_text(space: Text) -> int:
    """Return the flattened dimension of a ``Text`` space, which is its ``max_length``."""
    flat_length = space.max_length
    return flat_length
T = TypeVar("T") T = TypeVar("T")
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance] FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
@@ -104,7 +110,9 @@ def flatten(space: Space[T], x: T) -> FlatType:
x: The value to flatten x: The value to flatten
Returns: Returns:
- The flattened ``x``, always returns a 1D array for non-graph spaces. - For ``Box`` and ``MultiBinary``, this is a flattened array
- For ``Discrete`` and ``MultiDiscrete``, this is a flattened one-hot array of the sample
- For ``Tuple`` and ``Dict``, this is a concatenated array the subspaces (does not support graph subspaces)
- For graph spaces, returns `GraphInstance` where: - For graph spaces, returns `GraphInstance` where:
- `nodes` are n x k arrays - `nodes` are n x k arrays
- `edges` are either: - `edges` are either:
@@ -179,6 +187,16 @@ def _flatten_graph(space, x) -> GraphInstance:
return GraphInstance(nodes, edges, x.edge_links) return GraphInstance(nodes, edges, x.edge_links)
@flatten.register(Text)
def _flatten_text(space: Text, x: str) -> np.ndarray:
    """Encode the string ``x`` as an ``np.int32`` array of length ``space.max_length``.

    Every character of ``x`` is replaced by ``space.character_index(char)``.
    Positions beyond the end of ``x`` keep the padding value
    ``len(space.character_set)``.
    """
    padding_value = len(space.character_set)
    encoded = np.full(
        shape=(space.max_length,), fill_value=padding_value, dtype=np.int32
    )
    for position, char in enumerate(x):
        encoded[position] = space.character_index(char)
    return encoded
@flatten.register(Sequence) @flatten.register(Sequence)
def _flatten_sequence(space, x) -> tuple: def _flatten_sequence(space, x) -> tuple:
return tuple(flatten(space.feature_space, item) for item in x) return tuple(flatten(space.feature_space, item) for item in x)
@@ -284,6 +302,13 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
return GraphInstance(nodes, edges, x.edge_links) return GraphInstance(nodes, edges, x.edge_links)
@unflatten.register(Text)
def _unflatten_text(space: Text, x: np.ndarray) -> str:
    """Decode an array of character indexes back into a string.

    Entries that are not smaller than ``len(space.character_set)`` are skipped;
    every remaining entry is used to index ``space.character_list``.
    """
    num_characters = len(space.character_set)
    characters = space.character_list
    return "".join(characters[index] for index in x if index < num_characters)
@unflatten.register(Sequence) @unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple: def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
return tuple(unflatten(space.feature_space, item) for item in x) return tuple(unflatten(space.feature_space, item) for item in x)
@@ -401,6 +426,13 @@ def _flatten_space_graph(space: Graph) -> Graph:
) )
@flatten_space.register(Text)
def _flatten_space_text(space: Text) -> Box:
    """Build the ``Box`` describing flattened ``Text`` samples.

    The box is ``np.int32`` with shape ``(space.max_length,)``, lower bound 0 and
    upper bound ``len(space.character_set)``.
    """
    upper_bound = len(space.character_set)
    flat_shape = (space.max_length,)
    return Box(low=0, high=upper_bound, shape=flat_shape, dtype=np.int32)
@flatten_space.register(Sequence) @flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence: def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space)) return Sequence(flatten_space(space.feature_space))

View File

@@ -50,7 +50,9 @@ def data_equivalence(data_1, data_2) -> bool:
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2) data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
) )
elif isinstance(data_1, np.ndarray): elif isinstance(data_1, np.ndarray):
return np.all(data_1 == data_2) return data_1.shape == data_2.shape and np.allclose(
data_1, data_2, atol=0.00001
)
else: else:
return data_1 == data_2 return data_1 == data_2
else: else:

View File

@@ -4,26 +4,26 @@ import warnings
import numpy as np import numpy as np
import pytest import pytest
import gym.error
from gym.spaces import Box from gym.spaces import Box
from gym.spaces.box import get_inf
# Todo, move Box unique tests from test_spaces.py to test_box.py
@pytest.mark.parametrize( @pytest.mark.parametrize(
"box,expected_shape", "box,expected_shape",
[ [
( ( # Test with same 1-dim low and high shape
Box(low=np.zeros(2), high=np.ones(2)), Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32),
(2,), (2,),
), # Test with same 1-dim low and high shape ),
( ( # Test with same multi-dim low and high shape
Box(low=np.zeros((2, 1)), high=np.ones((2, 1))), Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32),
(2, 1), (2, 1),
), # Test with same multi-dim low and high shape ),
( ( # Test with scalar low high and different shape
Box(low=0, high=1, shape=(5, 2)), Box(low=0, high=1, shape=(5, 2)),
(5, 2), (5, 2),
), # Test with scalar low high and different shape ),
(Box(low=0, high=1), (1,)), # Test with int and int (Box(low=0, high=1), (1,)), # Test with int and int
(Box(low=0.0, high=1.0), (1,)), # Test with float and float (Box(low=0.0, high=1.0), (1,)), # Test with float and float
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)), (Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
@@ -33,7 +33,8 @@ from gym.spaces import Box
(Box(low=np.zeros(3), high=1.0), (3,)), # Test with array and scalar (Box(low=np.zeros(3), high=1.0), (3,)), # Test with array and scalar
], ],
) )
def test_box_shape_inference(box, expected_shape): def test_shape_inference(box, expected_shape):
"""Test that the shape inference is as expected."""
assert box.shape == expected_shape assert box.shape == expected_shape
assert box.sample().shape == expected_shape assert box.sample().shape == expected_shape
@@ -48,7 +49,7 @@ def test_box_shape_inference(box, expected_shape):
(np.zeros(2, dtype=np.float32), True), (np.zeros(2, dtype=np.float32), True),
(np.zeros((2, 2), dtype=np.float32), True), (np.zeros((2, 2), dtype=np.float32), True),
(np.inf, True), (np.inf, True),
(np.nan, True), # This is a weird side (np.nan, True), # This is a weird case that we allow
(True, False), (True, False),
(np.bool8(True), False), (np.bool8(True), False),
(1 + 1j, False), (1 + 1j, False),
@@ -56,7 +57,8 @@ def test_box_shape_inference(box, expected_shape):
("string", False), ("string", False),
], ],
) )
def test_box_values(value, valid): def test_low_high_values(value, valid: bool):
"""Test what `low` and `high` values are valid for `Box` space."""
if valid: if valid:
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
Box(low=value, high=value) Box(low=value, high=value)
@@ -66,7 +68,9 @@ def test_box_values(value, valid):
else: else:
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=r"expect their types to be np\.ndarray, an integer or a float", match=re.escape(
"expect their types to be np.ndarray, an integer or a float"
),
): ):
Box(low=value, high=value) Box(low=value, high=value)
@@ -135,6 +139,178 @@ def test_box_values(value, valid):
), ),
], ],
) )
def test_box_errors(low, high, kwargs, error, message): def test_init_errors(low, high, kwargs, error, message):
"""Test all constructor errors."""
with pytest.raises(error, match=f"^{re.escape(message)}$"): with pytest.raises(error, match=f"^{re.escape(message)}$"):
Box(low=low, high=high, **kwargs) Box(low=low, high=high, **kwargs)
def test_dtype_check():
    """Check ``Box`` membership for samples of different float dtypes.

    Related issues:
        https://github.com/openai/gym/issues/2357
        https://github.com/openai/gym/issues/2298
    """
    float32_space = Box(0, 1, (), dtype=np.float32)

    same_dtype_sample = np.array(0.5, dtype=np.float32)
    narrower_sample = np.array(0.5, dtype=np.float16)
    wider_sample = np.array(0.5, dtype=np.float64)

    # A sample with the space's own dtype is a member.
    assert same_dtype_sample in float32_space
    # float16 is in float32 space.
    assert narrower_sample in float32_space
    # float64 is not in float32 space.
    assert wider_sample not in float32_space
@pytest.mark.parametrize(
    "space",
    [
        # Scalar bounds: every shape x (low, high) pair x dtype combination.
        Box(low=low, high=high, shape=shape, dtype=dtype)
        for shape in ((2,), (2, 3))
        for low, high in ((0, np.inf), (-np.inf, 0), (-np.inf, np.inf))
        for dtype in (np.int32, np.float32, np.int64, np.float64)
    ]
    + [
        # Array bounds mixing an infinite and a zero bound per element.
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=dtype)
        for dtype in (np.int32, np.float32, np.int64, np.float64)
    ],
)
def test_infinite_space(space):
    """Test ``Box`` spaces whose bounds are only 0 or infinite.

    Because ``space.high`` and ``space.low`` are both modified within the init,
    the test does not compare against infinity directly: a bound is treated as
    the infinite one when it is known not to be 0.
    """
    assert np.all(
        space.low < space.high
    ), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})"

    space.seed(0)
    sample = space.sample()

    # check if space contains sample
    assert (
        sample in space
    ), f"Sample ({sample}) not inside space according to `space.contains()`"

    # manually check that the sign of the sample is within the bounds
    assert np.all(
        np.sign(sample) <= np.sign(space.high)
    ), f"Sign of sample ({sample}) is greater than the sign of the space upper bound ({space.high})"
    assert np.all(
        np.sign(space.low) <= np.sign(sample)
    ), f"Sign of sample ({sample}) is less than the sign of the space lower bound ({space.low})"

    # A side with any non-zero bound was configured as infinite, so the space
    # must report that side as unbounded; an all-zero side must be bounded.
    if np.any(space.high != 0):
        assert (
            space.is_bounded("above") is False
        ), "inf upper bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("above") is True
        ), "non-inf upper bound supposed to be bounded"

    if np.any(space.low != 0):
        assert (
            space.is_bounded("below") is False
        ), "inf lower bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("below") is True
        ), "non-inf lower bound supposed to be bounded"

    if np.any(space.low != 0) or np.any(space.high != 0):
        assert space.is_bounded("both") is False
    else:
        assert space.is_bounded("both") is True

    # check for dtype
    assert (
        space.high.dtype == space.dtype
    ), f"High's dtype {space.high.dtype} doesn't match `space.dtype` ({space.dtype})"
    assert (
        space.low.dtype == space.dtype
    ), f"Low's dtype {space.low.dtype} doesn't match `space.dtype` ({space.dtype})"

    # An unknown `manner` must be rejected with a message that includes the value.
    with pytest.raises(
        ValueError, match="manner is not in {'below', 'above', 'both'}, actual value:"
    ):
        space.is_bounded("test")
def test_legacy_state_pickling():
    """`Box.__setstate__` must rebuild `low_repr`/`high_repr` from a state that lacks them."""
    repr_fields = ("low_repr", "high_repr")
    all_bounded = np.array([True, True, True, True, True])
    legacy_state = dict(
        dtype=np.dtype("float32"),
        _shape=(5,),
        low=np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32),
        high=np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
        bounded_below=all_bounded,
        bounded_above=all_bounded.copy(),
        _np_random=None,
    )

    box = Box(-1, 1, ())
    assert all(field in box.__dict__ for field in repr_fields)

    # Strip the repr fields so the instance looks like one from before they existed.
    for field in repr_fields:
        del box.__dict__[field]
    assert not any(field in box.__dict__ for field in repr_fields)

    box.__setstate__(legacy_state)
    assert box.low_repr == "0.0"
    assert box.high_repr == "1.0"
def test_get_inf():
    """Tests that get inf function works as expected, primarily for coverage."""
    # float dtypes
    assert get_inf(np.float32, "+") == np.inf
    assert get_inf(np.float16, "-") == -np.inf
    with pytest.raises(
        TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
    ):
        get_inf(np.float32, "*")

    # integer dtypes: the values asserted here are finite stand-ins for infinity
    assert get_inf(np.int16, "+") == 32765
    assert get_inf(np.int8, "-") == -126
    with pytest.raises(
        TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
    ):
        get_inf(np.int32, "*")

    # unsupported dtype; `np.complex128` is spelled out because the `np.complex_`
    # alias was removed in NumPy 2.0 (both name the type quoted in the message)
    with pytest.raises(
        ValueError,
        match=re.escape("Unknown dtype <class 'numpy.complex128'> for infinite bounds"),
    ):
        get_inf(np.complex128, "+")
def test_sample_mask():
    """Sampling a Box with a mask is rejected with a gym error."""
    box = Box(0, 1)
    expected_message = re.escape(
        "Box.sample cannot be provided a mask, actual value: "
    )

    with pytest.raises(gym.error.Error, match=expected_message):
        box.sample(mask=np.array([0, 1, 0], dtype=np.int8))

View File

@@ -0,0 +1,40 @@
import numpy as np
from gym.spaces import Discrete
def test_space_legacy_pickling():
    """Restore a Discrete from a legacy state dict that has no `start` entry."""
    legacy_rng = np.random.default_rng()
    legacy_shape = (1, 2, 3)
    legacy_state = {
        "shape": legacy_shape,
        "dtype": np.int64,
        "np_random": legacy_rng,
        "n": 3,
    }

    discrete = Discrete(1)
    discrete.__setstate__(legacy_state)

    # Every legacy entry must be readable back from the restored space.
    assert discrete.shape == legacy_state["shape"]
    assert discrete.np_random == legacy_state["np_random"]
    assert discrete.n == 3
    assert discrete.dtype == legacy_state["dtype"]

    # Remove `start` (legacy did not include start param), restore again and
    # check that the attribute comes back as 0.
    assert "start" in discrete.__dict__
    del discrete.__dict__["start"]
    assert "start" not in discrete.__dict__

    discrete.__setstate__(legacy_state)
    assert discrete.start == 0
def test_sample_mask():
    """Sample a `Discrete(4, start=2)` with no mask, one-hot, all-zero and two-hot masks."""
    discrete = Discrete(4, start=2)

    unmasked = discrete.sample()
    assert 2 <= unmasked < 6

    only_second = np.array([0, 1, 0, 0], dtype=np.int8)
    assert discrete.sample(mask=only_second) == 3

    # With nothing selectable, the expected sample equals the `start` value.
    nothing_allowed = np.array([0, 0, 0, 0], dtype=np.int8)
    assert discrete.sample(mask=nothing_allowed) == 2

    second_or_fourth = np.array([0, 1, 0, 1], dtype=np.int8)
    assert discrete.sample(mask=second_or_fourth) in [3, 5]

View File

@@ -0,0 +1,3 @@
def test_sample():
    """Placeholder test; it performs no checks yet."""
    # todo

View File

@@ -0,0 +1,66 @@
import pytest
from gym.spaces import Discrete, MultiDiscrete
from gym.utils.env_checker import data_equivalence
def test_multidiscrete_as_tuple():
    """Indexing or slicing a MultiDiscrete yields the matching Discrete/MultiDiscrete."""
    everything = slice(None)

    # 1D multi-discrete
    flat = MultiDiscrete([3, 4, 5])
    assert flat.shape == (3,)
    assert flat[0] == Discrete(3)
    for index, nvec in (
        (slice(0, 1), [3]),
        (slice(0, 2), [3, 4]),
    ):
        assert flat[index] == MultiDiscrete(nvec)
    # A full slice is an equal but distinct space object.
    assert flat[:] == flat and flat[:] is not flat

    # 2D multi-discrete
    grid = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert grid.shape == (2, 3)
    assert grid[0, 1] == Discrete(4)
    for index, nvec in (
        (0, [3, 4, 5]),
        (slice(0, 1), [[3, 4, 5]]),
        ((slice(0, 2), everything), [[3, 4, 5], [6, 7, 8]]),
        ((everything, slice(0, 1)), [[3], [6]]),
        ((slice(0, 2), slice(0, 2)), [[3, 4], [6, 7]]),
    ):
        assert grid[index] == MultiDiscrete(nvec)
    assert grid[:] == grid and grid[:] is not grid
    assert grid[:, :] == grid and grid[:, :] is not grid
def test_multidiscrete_subspace_reproducibility():
    """Two subspaces taken with the same index from a seeded space sample identically."""
    everything = slice(None)

    # 1D multi-discrete
    flat = MultiDiscrete([100, 200, 300])
    flat.seed()
    for index in (0, slice(0, 1), slice(0, 2), everything):
        first, second = flat[index].sample(), flat[index].sample()
        assert data_equivalence(first, second)
    # Sampling the parent comes last, mirroring the original call order.
    assert data_equivalence(flat[:].sample(), flat.sample())

    # 2D multi-discrete
    grid = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    grid.seed()
    for index in (
        (0, 1),
        0,
        slice(0, 1),
        (slice(0, 2), everything),
        (everything, slice(0, 1)),
        (slice(0, 2), slice(0, 2)),
        everything,
        (everything, everything),
    ):
        first, second = grid[index].sample(), grid[index].sample()
        assert data_equivalence(first, second)
    assert data_equivalence(grid[:, :].sample(), grid.sample())
def test_multidiscrete_length():
    """`len()` works on a 1D MultiDiscrete and warns on a multi-dimensional one."""
    flat = MultiDiscrete(nvec=[3, 2, 4])
    assert len(flat) == 3

    grid = MultiDiscrete(nvec=[[2, 3], [3, 2]])
    expected_warning = "Getting the length of a multi-dimensional MultiDiscrete space."
    with pytest.warns(UserWarning, match=expected_warning):
        assert len(grid) == 2

View File

@@ -0,0 +1,24 @@
from functools import partial
import pytest
from gym import Space
from gym.spaces import utils
TESTING_SPACE = Space()

# Calls on the bare `Space` instance that the test below expects to raise
# `NotImplementedError`.
_NOT_IMPLEMENTED_CALLS = [
    TESTING_SPACE.sample,
    partial(TESTING_SPACE.contains, None),
    partial(utils.flatdim, TESTING_SPACE),
    partial(utils.flatten, TESTING_SPACE, None),
    partial(utils.flatten_space, TESTING_SPACE),
    partial(utils.unflatten, TESTING_SPACE, None),
]


@pytest.mark.parametrize("func", _NOT_IMPLEMENTED_CALLS)
def test_not_implemented_errors(func):
    """Each zero-argument callable must raise `NotImplementedError` when invoked."""
    with pytest.raises(NotImplementedError):
        func()

File diff suppressed because it is too large Load Diff

41
tests/spaces/test_text.py Normal file
View File

@@ -0,0 +1,41 @@
import re
import numpy as np
import pytest
from gym.spaces import Text
def test_sample_mask():
    """Exercise `Text.sample` with length masks and character masks."""
    text = Text(min_length=1, max_length=5)

    # Test the sample length
    fixed_length = text.sample(mask=(3, None))
    assert fixed_length in text
    assert len(fixed_length) == 3

    unmasked = text.sample(mask=None)
    assert unmasked in text
    assert 1 <= len(unmasked) <= 5

    # An all-zero character mask with a positive length must be rejected.
    all_zero_error = re.escape(
        "Trying to sample with a minimum length > 0 (1) but the character mask is all zero meaning that no character could be sampled."
    )
    with pytest.raises(ValueError, match=all_zero_error):
        text.sample(mask=(3, np.zeros(len(text.character_set), dtype=np.int8)))

    # With `min_length=0` the same character mask gives an empty string.
    text = Text(min_length=0, max_length=5)
    no_characters = np.zeros(len(text.character_set), dtype=np.int8)
    empty = text.sample(mask=(None, no_characters))
    assert empty in text
    assert empty == ""

    # Test the sample characters
    text = Text(max_length=5, charset="abcd")
    only_b = np.array([0, 1, 0, 0], dtype=np.int8)
    repeated = text.sample(mask=(3, only_b))
    assert repeated in text
    assert repeated == "bbb"

View File

@@ -1,601 +1,135 @@
import re from itertools import zip_longest
from collections import OrderedDict from typing import Optional
import numpy as np import numpy as np
import pytest import pytest
from gym.spaces import ( import gym
Box, from gym.spaces import Box, Graph, utils
Dict, from gym.utils.env_checker import data_equivalence
Discrete, from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Tuple,
utils,
)
homogeneous_spaces = [ TESTING_SPACES_EXPECTED_FLATDIMS = [
Discrete(3), # Discrete
Box(low=0.0, high=np.inf, shape=(2, 2)),
Box(low=0.0, high=np.inf, shape=(2, 2), dtype=np.float16),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
[
Discrete(5),
Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64),
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 10]),
MultiBinary(10),
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64
),
}
),
Discrete(3, start=2),
Discrete(8, start=-5),
]
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
non_homogenous_spaces = [
Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), #
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), #
Graph(node_space=Discrete(5), edge_space=None), #
Sequence(Discrete(4)), #
Sequence(Box(-10, 10, shape=(2, 2))), #
Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), #
Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), #
Dict(
a=Graph(node_space=Discrete(4), edge_space=Discrete(4)),
b=Box(-10, 10, shape=(2, 2)),
), #
Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), #
Tuple(
[
Graph(node_space=Discrete(4), edge_space=Discrete(4)),
Box(-10, 10, shape=(2, 2)),
]
), #
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), #
Dict(
a=Dict(
a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2))
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
), #
Dict(
a=Dict(
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
b=Box(-100, 100, shape=(2, 2)),
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
),
]
@pytest.mark.parametrize("space", non_homogenous_spaces)
def test_non_flattenable(space):
    """Spaces flagged as not numpy-flattenable must make `flatdim` raise."""
    assert space.is_np_flattenable is False

    expected_message = re.escape(
        "cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
    )
    with pytest.raises(ValueError, match=expected_message):
        utils.flatdim(space)
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
def test_flatdim(space, flatdim):
    """`utils.flatdim` must return the expected flat dimension for each space."""
    assert space.is_np_flattenable

    computed = utils.flatdim(space)
    assert computed == flatdim, f"Expected {computed} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_space_boxes(space):
    """A flattened space is a one-dimensional Box whose length equals `flatdim`."""
    flattened = utils.flatten_space(space)
    assert isinstance(flattened, Box), f"Expected {type(flattened)} to equal {Box}"

    expected_length = utils.flatdim(space)
    # Unpacking fails unless the flattened shape has exactly one axis.
    [actual_length] = flattened.shape
    assert (
        actual_length == expected_length
    ), f"Expected {actual_length} to equal {expected_length}"
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flat_space_contains_flat_points(space):
    """Every flattened sample must be contained in the flattened space."""
    # Draw all samples first, then flatten, keeping the original call order.
    drawn = [space.sample() for _ in range(10)]
    flat_points = [utils.flatten(space, point) for point in drawn]
    flat_space = utils.flatten_space(space)

    for i, flat_sample in enumerate(flat_points):
        is_member = flat_space.contains(flat_sample)
        assert is_member, f"Expected sample #{i} {flat_sample} to be in {flat_space}"
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_dim(space):
    """A flattened sample is one-dimensional with length equal to `flatdim`."""
    flat_sample = utils.flatten(space, space.sample())
    # Unpacking fails unless the flattened sample has exactly one axis.
    [actual_length] = flat_sample.shape

    expected_length = utils.flatdim(space)
    assert (
        actual_length == expected_length
    ), f"Expected {actual_length} to equal {expected_length}"
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flatten_roundtripping(space):
    """Flattening then unflattening a sample must give back an equal member of the space."""
    originals = [space.sample() for _ in range(10)]
    flattened = [utils.flatten(space, item) for item in originals]
    restored = [utils.unflatten(space, item) for item in flattened]

    for i, pair in enumerate(zip(originals, restored)):
        original, roundtripped = pair
        same = compare_nested(original, roundtripped)
        assert same, f"Expected sample #{i} {original} to equal {roundtripped}"
        assert space.contains(roundtripped)
def compare_nested(left, right):
    """Recursively compare two (possibly nested) samples for equality.

    Values of different types never match. Numpy arrays match when their shapes
    agree and `np.allclose` holds. `OrderedDict`s and tuples/lists are compared
    pairwise in iteration order after a length check. Anything else falls back
    to `==`.
    """
    if type(left) is not type(right):
        return False

    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        if left.shape != right.shape:
            return False
        return np.allclose(left, right)

    if isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
        matched = len(left) == len(right)
        for left_item, right_item in zip(left.items(), right.items()):
            # Stop as soon as the length check or an earlier pair failed.
            if not matched:
                return False
            (left_key, left_value), (right_key, right_value) = left_item, right_item
            matched = left_key == right_key and compare_nested(
                left_value, right_value
            )
        return matched

    if isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
        matched = len(left) == len(right)
        for left_item, right_item in zip(left, right):
            if not matched:
                return False
            matched = compare_nested(left_item, right_item)
        return matched

    return left == right
"""
Expected flattened dtypes are based on:
1. The type that the space is hardcoded as (i.e. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
2. The type that the space is instantiated with (i.e. box=np.float32 by default unless instantiated with a different type)
3. The smallest type that the composite space (tuple, dict) can be represented as. In flatten, this is determined
internally by numpy when np.concatenate is called.
"""
# Expected dtype of each flattened space; `test_dtypes` zips this list with
# `homogeneous_spaces`, so entries are matched by position.
expected_flattened_dtypes = [
    np.int64,
    np.float32,
    np.float16,
    np.int64,
    np.float64,
    np.int64,
    np.int64,
    np.int8,
    np.float64,
    np.int64,
    np.int64,
]
@pytest.mark.parametrize(
    ["original_space", "expected_flattened_dtype"],
    zip(homogeneous_spaces, expected_flattened_dtypes),
)
def test_dtypes(original_space, expected_flattened_dtype):
    """Check dtypes of the flattened space, the flattened sample and the round-tripped sample."""
    flat_space = utils.flatten_space(original_space)

    sample = original_space.sample()
    flat_sample = utils.flatten(original_space, sample)
    restored_sample = utils.unflatten(original_space, flat_sample)

    is_member = flat_space.contains(flat_sample)
    assert is_member, "Expected flattened_space to contain flattened_sample"

    dtype_as_expected = flat_space.dtype == expected_flattened_dtype
    assert (
        dtype_as_expected
    ), f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"

    assert flat_sample.dtype == flat_space.dtype, (
        "Expected flattened_space's dtype to equal " "flattened_sample's dtype "
    )

    # Recurse through the space to compare the original and restored sample types.
    compare_sample_types(original_space, sample, restored_sample)
def compare_sample_types(original_space, original_sample, unflattened_sample):
    """Assert that an unflattened sample has the same type layout as the original sample.

    Discrete samples must come back as `int`. Tuple and Dict spaces are walked
    recursively. For every other space the two samples' dtypes must agree.
    """
    if isinstance(original_space, Discrete):
        assert isinstance(unflattened_sample, int), (
            "Expected unflattened_sample to be an int. unflattened_sample: "
            f"{unflattened_sample} original_sample: {original_sample}"
        )
        return

    if isinstance(original_space, Tuple):
        for position in range(len(original_space)):
            subspace = original_space.spaces[position]
            compare_sample_types(
                subspace, original_sample[position], unflattened_sample[position]
            )
        return

    if isinstance(original_space, Dict):
        for name, subspace in original_space.spaces.items():
            compare_sample_types(
                subspace, original_sample[name], unflattened_sample[name]
            )
        return

    assert unflattened_sample.dtype == original_sample.dtype, (
        "Expected unflattened_sample's dtype to equal "
        "original_sample's dtype. unflattened_sample: "
        f"{unflattened_sample} original_sample: {original_sample}"
    )
homogeneous_samples = [
2,
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32),
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16),
(3, 7),
(2, np.array([0.5, 3.5], dtype=np.float32)),
(3, 0, 1),
np.array([0, 1, 7], dtype=np.int64),
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
OrderedDict(
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
),
3, 3,
-2, 3,
] # Box
1,
4,
expected_flattened_hom_samples = [ 2,
np.array([0, 0, 1], dtype=np.int64), 2,
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32), 2,
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16), # Multi-discrete
np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64), 4,
np.array([0, 0, 1, 0, 0, 0.5, 3.5], dtype=np.float64), 10,
np.array([0, 0, 0, 1, 0, 1, 0, 0, 1], dtype=np.int64), # Multi-binary
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64), 8,
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8), 6,
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64), # Text
np.array([0, 1, 0], dtype=np.int64), 6,
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64), 6,
] 6,
# # Tuple
non_homogenous_samples = [ # 9,
GraphInstance( # 7,
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32), # 10,
np.array( # 6,
[ # None,
0, # # Dict
], # 7,
dtype=int, # 8,
), # 17,
np.array([[0, 1]], dtype=int), # None,
), # # Graph
GraphInstance( # None,
np.array([0, 1], dtype=int), # None,
np.array([[[1, 2], [3, 4]]], dtype=np.float32), # None,
np.array([[0, 1]], dtype=int), # # Sequence
), # None,
GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)), # None,
(0, 1, 2), # None,
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
(
(np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)),
(np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)),
),
OrderedDict(
[("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[0, 1], [2, 3]], dtype=np.float32)),
]
),
((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)),
(
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
np.array([[0, 1], [2, 3]], dtype=np.float32),
),
(
GraphInstance(
nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32),
edges=np.array([0], dtype=int),
edge_links=np.array([[0, 1]]),
),
GraphInstance(
nodes=np.array(
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32
),
edges=np.array([1], dtype=int),
edge_links=np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=np.float32,
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
]
expected_flattened_non_hom_samples = [
GraphInstance(
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32),
np.array([[1, 0, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
np.array([[1, 2, 3, 4]], dtype=np.float32),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
None,
np.array([[0, 1]], dtype=int),
),
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
OrderedDict(
[
(
"a",
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
(
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.float32),
np.array([[1, 0, 0, 0]]),
np.array([[0, 1]]),
),
GraphInstance(
np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32),
np.array([[0, 1, 0, 0]]),
np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
] ]
@pytest.mark.parametrize( @pytest.mark.parametrize(
["space", "sample", "expected_flattened_sample"], ["space", "flatdim"],
zip( zip_longest(TESTING_SPACES, TESTING_SPACES_EXPECTED_FLATDIMS),
homogeneous_spaces + non_homogenous_spaces, ids=TESTING_SPACES_IDS,
homogeneous_samples + non_homogenous_samples,
expected_flattened_hom_samples + expected_flattened_non_hom_samples,
),
) )
def test_flatten(space, sample, expected_flattened_sample): def test_flatdim(space: gym.spaces.Space, flatdim: Optional[int]):
flattened_sample = utils.flatten(space, sample) """Checks that the flattened dims of the space is equal to an expected value."""
if space.is_np_flattenable:
dim = utils.flatdim(space)
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
else:
with pytest.raises(
ValueError,
):
utils.flatdim(space)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flatten_space(space):
"""Test that the flattened spaces are a box and have the `flatdim` shape."""
flat_space = utils.flatten_space(space) flat_space = utils.flatten_space(space)
assert sample in space if space.is_np_flattenable:
assert flattened_sample in flat_space assert isinstance(flat_space, Box)
(single_dim,) = flat_space.shape
flatdim = utils.flatdim(space)
assert single_dim == flatdim
elif isinstance(flat_space, Graph):
assert isinstance(space, Graph)
(node_single_dim,) = flat_space.node_space.shape
node_flatdim = utils.flatdim(space.node_space)
assert node_single_dim == node_flatdim
if flat_space.edge_space is not None:
(edge_single_dim,) = flat_space.edge_space.shape
edge_flatdim = utils.flatdim(space.edge_space)
assert edge_single_dim == edge_flatdim
else:
assert isinstance(
space, (gym.spaces.Tuple, gym.spaces.Dict, gym.spaces.Sequence)
)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flatten(space):
"""Test that a flattened sample have the `flatdim` shape."""
flattened_sample = utils.flatten(space, space.sample())
if space.is_np_flattenable: if space.is_np_flattenable:
assert isinstance(flattened_sample, np.ndarray) assert isinstance(flattened_sample, np.ndarray)
assert flattened_sample.shape == expected_flattened_sample.shape (single_dim,) = flattened_sample.shape
assert flattened_sample.dtype == expected_flattened_sample.dtype flatdim = utils.flatdim(space)
assert np.all(flattened_sample == expected_flattened_sample)
assert single_dim == flatdim
else: else:
assert not isinstance(flattened_sample, np.ndarray) assert isinstance(flattened_sample, (tuple, dict, Graph))
assert compare_nested(flattened_sample, expected_flattened_sample)
@pytest.mark.parametrize( @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
["space", "flattened_sample", "expected_sample"], def test_flat_space_contains_flat_points(space):
zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples), """Test that the flattened samples are contained within the flattened space."""
) flattened_samples = [utils.flatten(space, space.sample()) for _ in range(10)]
def test_unflatten(space, flattened_sample, expected_sample): flat_space = utils.flatten_space(space)
sample = utils.unflatten(space, flattened_sample)
assert compare_nested(sample, expected_sample) for flat_sample in flattened_samples:
assert flat_sample in flat_space
expected_flattened_spaces = [ @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
Box(low=0, high=1, shape=(3,), dtype=np.int64), def test_flatten_roundtripping(space):
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float32), """Tests roundtripping with flattening and unflattening are equal to the original sample."""
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float16), samples = [space.sample() for _ in range(10)]
Box(low=0, high=1, shape=(15,), dtype=np.int64),
Box(
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64,
),
Box(low=0, high=1, shape=(9,), dtype=np.int64),
Box(low=0, high=1, shape=(14,), dtype=np.int64),
Box(low=0, high=1, shape=(10,), dtype=np.int8),
Box(
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64,
),
Box(low=0, high=1, shape=(3,), dtype=np.int64),
Box(low=0, high=1, shape=(8,), dtype=np.int64),
]
flattened_samples = [utils.flatten(space, sample) for sample in samples]
unflattened_samples = [
utils.unflatten(space, sample) for sample in flattened_samples
]
@pytest.mark.parametrize( for original, roundtripped in zip(samples, unflattened_samples):
["space", "expected_flattened_space"], assert data_equivalence(original, roundtripped)
zip(homogeneous_spaces, expected_flattened_spaces),
)
def test_flatten_space(space, expected_flattened_space):
flattened_space = utils.flatten_space(space)
assert flattened_space == expected_flattened_space

27
tests/spaces/utils.py Normal file
View File

@@ -0,0 +1,27 @@
from typing import List
import numpy as np
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
# Shared instances of the fundamental (non-composite) spaces used by the tests.
# NOTE(review): other test modules appear to pair this list positionally with
# lists of expected values — confirm before reordering or inserting entries.
TESTING_FUNDAMENTAL_SPACES = [
    # Discrete
    Discrete(3),
    Discrete(3, start=-1),
    # Box
    Box(low=0.0, high=1.0),
    Box(low=0.0, high=np.inf, shape=(2, 2)),
    Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
    Box(low=-np.inf, high=0.0, shape=(2, 1)),
    Box(low=0.0, high=np.inf, shape=(2, 1)),
    # MultiDiscrete
    MultiDiscrete([2, 2]),
    MultiDiscrete([[2, 3], [3, 2]]),
    # MultiBinary
    MultiBinary(8),
    MultiBinary([2, 3]),
    # Text
    Text(6),
    Text(min_length=3, max_length=6),
    Text(6, charset="abcdef"),
]
# One id per space (its formatted string form), in the same order as the list above.
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]
# Currently aliases of the fundamental lists; the trailing comments mark where
# composite spaces are meant to be appended.
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES  # + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS  # + TESTING_COMPOSITE_SPACES_IDS