mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 21:06:59 +00:00
Add testing for Fundamental spaces with full coverage (#3048)
* initial commit * Fix the multi-binary sample and upgrade multi-discrete sample * Fix MultiBinary sample tests and pre-commit * Adds coverage tests and updates test_utils.py to use the utils.py spaces. Fix bug in text.py * pre-commit
This commit is contained in:
@@ -164,7 +164,9 @@ class Box(Space[np.ndarray]):
|
||||
elif manner == "above":
|
||||
return above
|
||||
else:
|
||||
raise ValueError("manner is not in {'below', 'above', 'both'}")
|
||||
raise ValueError(
|
||||
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
|
||||
)
|
||||
|
||||
def sample(self, mask: None = None) -> np.ndarray:
|
||||
r"""Generates a single random sample inside the Box.
|
||||
@@ -223,7 +225,10 @@ class Box(Space[np.ndarray]):
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if not isinstance(x, np.ndarray):
|
||||
logger.warn("Casting input x to numpy array.")
|
||||
try:
|
||||
x = np.asarray(x, dtype=self.dtype)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
return bool(
|
||||
np.can_cast(x.dtype, self.dtype)
|
||||
@@ -236,7 +241,7 @@ class Box(Space[np.ndarray]):
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> List[np.ndarray]:
|
||||
def from_jsonable(self, sample_n: Sequence[Union[float, int]]) -> List[np.ndarray]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
|
||||
@@ -252,10 +257,11 @@ class Box(Space[np.ndarray]):
|
||||
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""Check whether `other` is equivalent to this instance."""
|
||||
"""Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
|
||||
return (
|
||||
isinstance(other, Box)
|
||||
and (self.shape == other.shape)
|
||||
# and (self.dtype == other.dtype)
|
||||
and np.allclose(self.low, other.low)
|
||||
and np.allclose(self.high, other.high)
|
||||
)
|
||||
|
@@ -85,18 +85,19 @@ class Discrete(Space[int]):
|
||||
if isinstance(x, int):
|
||||
as_int = x
|
||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
||||
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
|
||||
np.issubdtype(x.dtype, np.integer) and x.shape == ()
|
||||
):
|
||||
as_int = int(x) # type: ignore
|
||||
else:
|
||||
return False
|
||||
|
||||
return self.start <= as_int < self.start + self.n
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Gives a string representation of this space."""
|
||||
if self.start != 0:
|
||||
return "Discrete(%d, start=%d)" % (self.n, self.start)
|
||||
return "Discrete(%d)" % self.n
|
||||
return f"Discrete({self.n}, start={self.start})"
|
||||
return f"Discrete({self.n})"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
@@ -114,8 +115,6 @@ class Discrete(Space[int]):
|
||||
Args:
|
||||
state: The new state
|
||||
"""
|
||||
super().__setstate__(state)
|
||||
|
||||
# Don't mutate the original state
|
||||
state = dict(state)
|
||||
|
||||
@@ -124,5 +123,4 @@ class Discrete(Space[int]):
|
||||
if "start" not in state:
|
||||
state["start"] = 0
|
||||
|
||||
# Update our state
|
||||
self.__dict__.update(state)
|
||||
super().__setstate__(state)
|
||||
|
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from gym.spaces.box import Box
|
||||
from gym.spaces.discrete import Discrete
|
||||
from gym.spaces.multi_discrete import SAMPLE_MASK_TYPE, MultiDiscrete
|
||||
from gym.spaces.multi_discrete import MultiDiscrete
|
||||
from gym.spaces.space import Space
|
||||
|
||||
|
||||
@@ -97,8 +97,8 @@ class Graph(Space):
|
||||
self,
|
||||
mask: Optional[
|
||||
Tuple[
|
||||
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
|
||||
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
|
||||
Optional[Union[np.ndarray, tuple]],
|
||||
Optional[Union[np.ndarray, tuple]],
|
||||
]
|
||||
] = None,
|
||||
num_nodes: int = 10,
|
||||
|
@@ -62,7 +62,8 @@ class MultiBinary(Space[np.ndarray]):
|
||||
|
||||
Args:
|
||||
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
|
||||
Where mask == 0 then the samples will be 0.
|
||||
For mask == 0 then the samples will be 0 and mask == 1 then random samples will be generated.
|
||||
The expected mask shape is the space shape and mask dtype is `np.int8`.
|
||||
|
||||
Returns:
|
||||
Sampled values from space
|
||||
@@ -78,11 +79,13 @@ class MultiBinary(Space[np.ndarray]):
|
||||
mask.shape == self.shape
|
||||
), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}"
|
||||
assert np.all(
|
||||
np.logical_or(mask == 0, mask == 1)
|
||||
), f"All values of a mask should be 0 or 1, actual values: {mask}"
|
||||
(mask == 0) | (mask == 1) | (mask == 2)
|
||||
), f"All values of a mask should be 0, 1 or 2, actual values: {mask}"
|
||||
|
||||
return mask * self.np_random.integers(
|
||||
low=0, high=2, size=self.n, dtype=self.dtype
|
||||
return np.where(
|
||||
mask == 2,
|
||||
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
|
||||
mask.astype(self.dtype),
|
||||
)
|
||||
|
||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
@@ -91,9 +94,12 @@ class MultiBinary(Space[np.ndarray]):
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, Sequence):
|
||||
x = np.array(x) # Promote list to array for contains check
|
||||
if self.shape != x.shape:
|
||||
return False
|
||||
return ((x == 0) | (x == 1)).all()
|
||||
|
||||
return bool(
|
||||
isinstance(x, np.ndarray)
|
||||
and self.shape == x.shape
|
||||
and np.all((x == 0) | (x == 1))
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n) -> list:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
@@ -101,7 +107,7 @@ class MultiBinary(Space[np.ndarray]):
|
||||
|
||||
def from_jsonable(self, sample_n) -> list:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Gives a string representation of this space."""
|
||||
|
@@ -7,8 +7,6 @@ from gym import logger
|
||||
from gym.spaces.discrete import Discrete
|
||||
from gym.spaces.space import Space
|
||||
|
||||
SAMPLE_MASK_TYPE = Tuple[Union["SAMPLE_MASK_TYPE", np.ndarray], ...]
|
||||
|
||||
|
||||
class MultiDiscrete(Space[np.ndarray]):
|
||||
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
||||
@@ -39,7 +37,7 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nvec: Union[np.ndarray, List[int]],
|
||||
nvec: Union[np.ndarray, list],
|
||||
dtype=np.int64,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
):
|
||||
@@ -68,7 +66,7 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: Optional[SAMPLE_MASK_TYPE] = None) -> np.ndarray:
|
||||
def sample(self, mask: Optional[tuple] = None) -> np.ndarray:
|
||||
"""Generates a single random sample this space.
|
||||
|
||||
Args:
|
||||
@@ -82,15 +80,30 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
if mask is not None:
|
||||
|
||||
def _apply_mask(
|
||||
sub_mask: SAMPLE_MASK_TYPE, sub_nvec: np.ndarray
|
||||
sub_mask: Union[np.ndarray, tuple],
|
||||
sub_nvec: Union[np.ndarray, np.integer],
|
||||
) -> Union[int, List[int]]:
|
||||
if isinstance(sub_mask, np.ndarray):
|
||||
if isinstance(sub_nvec, np.ndarray):
|
||||
assert isinstance(
|
||||
sub_mask, tuple
|
||||
), f"Expects the mask to be a tuple for sub_nvec ({sub_nvec}), actual type: {type(sub_mask)}"
|
||||
assert len(sub_mask) == len(
|
||||
sub_nvec
|
||||
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
|
||||
return [
|
||||
_apply_mask(new_mask, new_nvec)
|
||||
for new_mask, new_nvec in zip(sub_mask, sub_nvec)
|
||||
]
|
||||
else:
|
||||
assert np.issubdtype(
|
||||
type(sub_nvec), np.integer
|
||||
), f"Expects the mask to be for an action, actual for {sub_nvec}"
|
||||
), f"Expects the sub_nvec to be an action, actually: {sub_nvec}, {type(sub_nvec)}"
|
||||
assert isinstance(
|
||||
sub_mask, np.ndarray
|
||||
), f"Expects the sub mask to be np.ndarray, actual type: {type(sub_mask)}"
|
||||
assert (
|
||||
len(sub_mask) == sub_nvec
|
||||
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {sub_nvec}"
|
||||
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, action: {sub_nvec}"
|
||||
assert (
|
||||
sub_mask.dtype == np.int8
|
||||
), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}"
|
||||
@@ -104,17 +117,6 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
return self.np_random.choice(np.where(valid_action_mask)[0])
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
assert isinstance(
|
||||
sub_mask, tuple
|
||||
), f"Expects the mask to be a tuple or np.ndarray, actual type: {type(sub_mask)}"
|
||||
assert len(sub_mask) == len(
|
||||
sub_nvec
|
||||
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
|
||||
return [
|
||||
_apply_mask(new_mask, new_nvec)
|
||||
for new_mask, new_nvec in zip(sub_mask, sub_nvec)
|
||||
]
|
||||
|
||||
return np.array(_apply_mask(mask, self.nvec), dtype=self.dtype)
|
||||
|
||||
@@ -124,9 +126,16 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, Sequence):
|
||||
x = np.array(x) # Promote list to array for contains check
|
||||
|
||||
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
|
||||
# is within correct bounds for space dtype (even though x does not have to be unsigned)
|
||||
return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all())
|
||||
return bool(
|
||||
isinstance(x, np.ndarray)
|
||||
and x.shape == self.shape
|
||||
and x.dtype != object
|
||||
and np.all(0 <= x)
|
||||
and np.all(x < self.nvec)
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n: Iterable[np.ndarray]):
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
@@ -147,13 +156,18 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
subspace = Discrete(nvec)
|
||||
else:
|
||||
subspace = MultiDiscrete(nvec, self.dtype) # type: ignore
|
||||
|
||||
# you don't need to deepcopy as np random generator call replaces the state not the data
|
||||
subspace.np_random.bit_generator.state = self.np_random.bit_generator.state
|
||||
|
||||
return subspace
|
||||
|
||||
def __len__(self):
|
||||
"""Gives the ``len`` of samples from this space."""
|
||||
if self.nvec.ndim >= 2:
|
||||
logger.warn("Get length of a multi-dimensional MultiDiscrete space.")
|
||||
logger.warn(
|
||||
"Getting the length of a multi-dimensional MultiDiscrete space."
|
||||
)
|
||||
return len(self.nvec)
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@@ -1,5 +1,5 @@
|
||||
"""Implementation of a space that represents textual strings."""
|
||||
from typing import Any, FrozenSet, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -27,7 +27,7 @@ class Text(Space[str]):
|
||||
self,
|
||||
max_length: int,
|
||||
*,
|
||||
min_length: int = 0,
|
||||
min_length: int = 1,
|
||||
charset: Union[Set[str], str] = alphanumeric,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
):
|
||||
@@ -36,9 +36,9 @@ class Text(Space[str]):
|
||||
Both bounds for text length are inclusive.
|
||||
|
||||
Args:
|
||||
min_length (int): Minimum text length (in characters).
|
||||
min_length (int): Minimum text length (in characters). Defaults to 1 to prevent empty strings.
|
||||
max_length (int): Maximum text length (in characters).
|
||||
charset (Union[set, SupportsIndex]): Character set, defaults to the lower and upper english alphabet plus latin digits.
|
||||
charset (Union[set], str): Character set, defaults to the lower and upper english alphabet plus latin digits.
|
||||
seed: The seed for sampling from the space.
|
||||
"""
|
||||
assert np.issubdtype(
|
||||
@@ -56,9 +56,13 @@ class Text(Space[str]):
|
||||
|
||||
self.min_length: int = int(min_length)
|
||||
self.max_length: int = int(max_length)
|
||||
self.charset: FrozenSet[str] = frozenset(charset)
|
||||
self._charlist: List[str] = list(charset)
|
||||
self._charset_str: str = "".join(sorted(self._charlist))
|
||||
|
||||
self._char_set: FrozenSet[str] = frozenset(charset)
|
||||
self._char_list: Tuple[str, ...] = tuple(charset)
|
||||
self._char_index: Dict[str, np.int32] = {
|
||||
val: np.int32(i) for i, val in enumerate(tuple(charset))
|
||||
}
|
||||
self._char_str: str = "".join(sorted(tuple(charset)))
|
||||
|
||||
# As the shape is dynamic (between min_length and max_length) then None
|
||||
super().__init__(dtype=str, seed=seed)
|
||||
@@ -71,20 +75,42 @@ class Text(Space[str]):
|
||||
Args:
|
||||
mask: An optional tuples of length and mask for the text.
|
||||
The length is expected to be between the `min_length` and `max_length` otherwise a random integer between `min_length` and `max_length` is selected.
|
||||
For the mask, we expect a numpy array of length of the charset passed with dtype == np.int8
|
||||
For the mask, we expect a numpy array of length of the charset passed with `dtype == np.int8`.
|
||||
If the charlist mask is all zero then an empty string is returned no matter the `min_length`
|
||||
|
||||
Returns:
|
||||
A sampled string from the space
|
||||
"""
|
||||
if mask is not None:
|
||||
assert isinstance(
|
||||
mask, tuple
|
||||
), f"Expects the mask type to be a tuple, actual type: {type(mask)}"
|
||||
assert (
|
||||
len(mask) == 2
|
||||
), f"Expects the mask length to be two, actual length: {len(mask)}"
|
||||
length, charlist_mask = mask
|
||||
|
||||
if length is not None:
|
||||
assert self.min_length <= length <= self.max_length
|
||||
assert np.issubdtype(
|
||||
type(length), np.integer
|
||||
), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
|
||||
assert (
|
||||
self.min_length <= length <= self.max_length
|
||||
), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"
|
||||
|
||||
if charlist_mask is not None:
|
||||
assert isinstance(charlist_mask, np.ndarray)
|
||||
assert charlist_mask.dtype is np.int8
|
||||
assert charlist_mask.shape == (len(self._charlist),)
|
||||
assert isinstance(
|
||||
charlist_mask, np.ndarray
|
||||
), f"Expects the Text sample mask to be an np.ndarray, actual type: {type(charlist_mask)}"
|
||||
assert (
|
||||
charlist_mask.dtype == np.int8
|
||||
), f"Expects the Text sample mask to be an np.ndarray, actual dtype: {charlist_mask.dtype}"
|
||||
assert charlist_mask.shape == (
|
||||
len(self.character_set),
|
||||
), f"expects the Text sample mask to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"
|
||||
assert np.all(
|
||||
np.logical_or(charlist_mask == 0, charlist_mask == 1)
|
||||
), f"Expects all masks values to 0 or 1, actual values: {charlist_mask}"
|
||||
else:
|
||||
length, charlist_mask = None, None
|
||||
|
||||
@@ -92,10 +118,23 @@ class Text(Space[str]):
|
||||
length = self.np_random.integers(self.min_length, self.max_length + 1)
|
||||
|
||||
if charlist_mask is None:
|
||||
string = self.np_random.choice(self._charlist, size=length)
|
||||
string = self.np_random.choice(self.character_list, size=length)
|
||||
else:
|
||||
masked_charlist = self._charlist[np.where(mask)[0]]
|
||||
string = self.np_random.choice(masked_charlist, size=length)
|
||||
valid_mask = charlist_mask == 1
|
||||
valid_indexes = np.where(valid_mask)[0]
|
||||
if len(valid_indexes) == 0:
|
||||
if self.min_length == 0:
|
||||
string = ""
|
||||
else:
|
||||
# Otherwise the string will not be contained in the space
|
||||
raise ValueError(
|
||||
f"Trying to sample with a minimum length > 0 ({self.min_length}) but the character mask is all zero meaning that no character could be sampled."
|
||||
)
|
||||
else:
|
||||
string = "".join(
|
||||
self.character_list[index]
|
||||
for index in self.np_random.choice(valid_indexes, size=length)
|
||||
)
|
||||
|
||||
return "".join(string)
|
||||
|
||||
@@ -103,13 +142,13 @@ class Text(Space[str]):
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, str):
|
||||
if self.min_length <= len(x) <= self.max_length:
|
||||
return all(c in self.charset for c in x)
|
||||
return all(c in self.character_set for c in x)
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Gives a string representation of this space."""
|
||||
return (
|
||||
f"Text({self.min_length}, {self.max_length}, charset={self._charset_str})"
|
||||
f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
@@ -118,5 +157,29 @@ class Text(Space[str]):
|
||||
isinstance(other, Text)
|
||||
and self.min_length == other.min_length
|
||||
and self.max_length == other.max_length
|
||||
and self.charset == other.charset
|
||||
and self.character_set == other.character_set
|
||||
)
|
||||
|
||||
@property
|
||||
def character_set(self) -> FrozenSet[str]:
|
||||
"""Returns the character set for the space."""
|
||||
return self._char_set
|
||||
|
||||
@property
|
||||
def character_list(self) -> Tuple[str, ...]:
|
||||
"""Returns a tuple of characters in the space."""
|
||||
return self._char_list
|
||||
|
||||
def character_index(self, char: str) -> np.int32:
|
||||
"""Returns a unique index for each character in the space's character set."""
|
||||
return self._char_index[char]
|
||||
|
||||
@property
|
||||
def characters(self) -> str:
|
||||
"""Returns a string with all Text characters."""
|
||||
return self._char_str
|
||||
|
||||
@property
|
||||
def is_np_flattenable(self) -> bool:
|
||||
"""The flattened version is an integer array for each character, padded to the max character length."""
|
||||
return True
|
||||
|
@@ -21,6 +21,7 @@ from gym.spaces import (
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
@@ -88,6 +89,11 @@ def _flatdim_dict(space: Dict) -> int:
|
||||
)
|
||||
|
||||
|
||||
@flatdim.register(Text)
|
||||
def _flatdim_text(space: Text) -> int:
|
||||
return space.max_length
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
|
||||
|
||||
@@ -104,7 +110,9 @@ def flatten(space: Space[T], x: T) -> FlatType:
|
||||
x: The value to flatten
|
||||
|
||||
Returns:
|
||||
- The flattened ``x``, always returns a 1D array for non-graph spaces.
|
||||
- For ``Box`` and ``MultiBinary``, this is a flattened array
|
||||
- For ``Discrete`` and ``MultiDiscrete``, this is a flattened one-hot array of the sample
|
||||
- For ``Tuple`` and ``Dict``, this is a concatenated array the subspaces (does not support graph subspaces)
|
||||
- For graph spaces, returns `GraphInstance` where:
|
||||
- `nodes` are n x k arrays
|
||||
- `edges` are either:
|
||||
@@ -179,6 +187,16 @@ def _flatten_graph(space, x) -> GraphInstance:
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@flatten.register(Text)
|
||||
def _flatten_text(space: Text, x: str) -> np.ndarray:
|
||||
arr = np.full(
|
||||
shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
|
||||
)
|
||||
for i, val in enumerate(x):
|
||||
arr[i] = space.character_index(val)
|
||||
return arr
|
||||
|
||||
|
||||
@flatten.register(Sequence)
|
||||
def _flatten_sequence(space, x) -> tuple:
|
||||
return tuple(flatten(space.feature_space, item) for item in x)
|
||||
@@ -284,6 +302,13 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@unflatten.register(Text)
|
||||
def _unflatten_text(space: Text, x: np.ndarray) -> str:
|
||||
return "".join(
|
||||
[space.character_list[val] for val in x if val < len(space.character_set)]
|
||||
)
|
||||
|
||||
|
||||
@unflatten.register(Sequence)
|
||||
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
|
||||
return tuple(unflatten(space.feature_space, item) for item in x)
|
||||
@@ -401,6 +426,13 @@ def _flatten_space_graph(space: Graph) -> Graph:
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Text)
|
||||
def _flatten_space_text(space: Text) -> Box:
|
||||
return Box(
|
||||
low=0, high=len(space.character_set), shape=(space.max_length,), dtype=np.int32
|
||||
)
|
||||
|
||||
|
||||
@flatten_space.register(Sequence)
|
||||
def _flatten_space_sequence(space: Sequence) -> Sequence:
|
||||
return Sequence(flatten_space(space.feature_space))
|
||||
|
@@ -50,7 +50,9 @@ def data_equivalence(data_1, data_2) -> bool:
|
||||
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
|
||||
)
|
||||
elif isinstance(data_1, np.ndarray):
|
||||
return np.all(data_1 == data_2)
|
||||
return data_1.shape == data_2.shape and np.allclose(
|
||||
data_1, data_2, atol=0.00001
|
||||
)
|
||||
else:
|
||||
return data_1 == data_2
|
||||
else:
|
||||
|
@@ -4,26 +4,26 @@ import warnings
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gym.error
|
||||
from gym.spaces import Box
|
||||
|
||||
# Todo, move Box unique tests from test_spaces.py to test_box.py
|
||||
from gym.spaces.box import get_inf
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"box,expected_shape",
|
||||
[
|
||||
(
|
||||
Box(low=np.zeros(2), high=np.ones(2)),
|
||||
( # Test with same 1-dim low and high shape
|
||||
Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32),
|
||||
(2,),
|
||||
), # Test with same 1-dim low and high shape
|
||||
(
|
||||
Box(low=np.zeros((2, 1)), high=np.ones((2, 1))),
|
||||
),
|
||||
( # Test with same multi-dim low and high shape
|
||||
Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32),
|
||||
(2, 1),
|
||||
), # Test with same multi-dim low and high shape
|
||||
(
|
||||
),
|
||||
( # Test with scalar low high and different shape
|
||||
Box(low=0, high=1, shape=(5, 2)),
|
||||
(5, 2),
|
||||
), # Test with scalar low high and different shape
|
||||
),
|
||||
(Box(low=0, high=1), (1,)), # Test with int and int
|
||||
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
|
||||
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
|
||||
@@ -33,7 +33,8 @@ from gym.spaces import Box
|
||||
(Box(low=np.zeros(3), high=1.0), (3,)), # Test with array and scalar
|
||||
],
|
||||
)
|
||||
def test_box_shape_inference(box, expected_shape):
|
||||
def test_shape_inference(box, expected_shape):
|
||||
"""Test that the shape inference is as expected."""
|
||||
assert box.shape == expected_shape
|
||||
assert box.sample().shape == expected_shape
|
||||
|
||||
@@ -48,7 +49,7 @@ def test_box_shape_inference(box, expected_shape):
|
||||
(np.zeros(2, dtype=np.float32), True),
|
||||
(np.zeros((2, 2), dtype=np.float32), True),
|
||||
(np.inf, True),
|
||||
(np.nan, True), # This is a weird side
|
||||
(np.nan, True), # This is a weird case that we allow
|
||||
(True, False),
|
||||
(np.bool8(True), False),
|
||||
(1 + 1j, False),
|
||||
@@ -56,7 +57,8 @@ def test_box_shape_inference(box, expected_shape):
|
||||
("string", False),
|
||||
],
|
||||
)
|
||||
def test_box_values(value, valid):
|
||||
def test_low_high_values(value, valid: bool):
|
||||
"""Test what `low` and `high` values are valid for `Box` space."""
|
||||
if valid:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
Box(low=value, high=value)
|
||||
@@ -66,7 +68,9 @@ def test_box_values(value, valid):
|
||||
else:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"expect their types to be np\.ndarray, an integer or a float",
|
||||
match=re.escape(
|
||||
"expect their types to be np.ndarray, an integer or a float"
|
||||
),
|
||||
):
|
||||
Box(low=value, high=value)
|
||||
|
||||
@@ -135,6 +139,178 @@ def test_box_values(value, valid):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_box_errors(low, high, kwargs, error, message):
|
||||
def test_init_errors(low, high, kwargs, error, message):
|
||||
"""Test all constructor errors."""
|
||||
with pytest.raises(error, match=f"^{re.escape(message)}$"):
|
||||
Box(low=low, high=high, **kwargs)
|
||||
|
||||
|
||||
def test_dtype_check():
|
||||
"""Tests the Box contains function with different dtypes."""
|
||||
# Related Issues:
|
||||
# https://github.com/openai/gym/issues/2357
|
||||
# https://github.com/openai/gym/issues/2298
|
||||
|
||||
space = Box(0, 1, (), dtype=np.float32)
|
||||
|
||||
# casting will match the correct type
|
||||
assert np.array(0.5, dtype=np.float32) in space
|
||||
|
||||
# float16 is in float32 space
|
||||
assert np.array(0.5, dtype=np.float16) in space
|
||||
|
||||
# float64 is not in float32 space
|
||||
assert np.array(0.5, dtype=np.float64) not in space
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space",
|
||||
[
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.int32),
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.float32),
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.int64),
|
||||
Box(low=0, high=np.inf, shape=(2,), dtype=np.float64),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64),
|
||||
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64),
|
||||
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64),
|
||||
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64),
|
||||
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64),
|
||||
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64),
|
||||
],
|
||||
)
|
||||
def test_infinite_space(space):
|
||||
"""
|
||||
To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
|
||||
are both modified within the init, we check for infinite when we know it's not 0
|
||||
"""
|
||||
|
||||
assert np.all(
|
||||
space.low < space.high
|
||||
), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})"
|
||||
|
||||
space.seed(0)
|
||||
sample = space.sample()
|
||||
|
||||
# check if space contains sample
|
||||
assert (
|
||||
sample in space
|
||||
), f"Sample ({sample}) not inside space according to `space.contains()`"
|
||||
|
||||
# manually check that the sign of the sample is within the bounds
|
||||
assert np.all(
|
||||
np.sign(sample) <= np.sign(space.high)
|
||||
), f"Sign of sample ({sample}) is less than space upper bound ({space.high})"
|
||||
assert np.all(
|
||||
np.sign(space.low) <= np.sign(sample)
|
||||
), f"Sign of sample ({sample}) is more than space lower bound ({space.low})"
|
||||
|
||||
# check that int bounds are bounded for everything
|
||||
# but floats are unbounded for infinite
|
||||
if np.any(space.high != 0):
|
||||
assert (
|
||||
space.is_bounded("above") is False
|
||||
), "inf upper bound supposed to be unbounded"
|
||||
else:
|
||||
assert (
|
||||
space.is_bounded("above") is True
|
||||
), "non-inf upper bound supposed to be bounded"
|
||||
|
||||
if np.any(space.low != 0):
|
||||
assert (
|
||||
space.is_bounded("below") is False
|
||||
), "inf lower bound supposed to be unbounded"
|
||||
else:
|
||||
assert (
|
||||
space.is_bounded("below") is True
|
||||
), "non-inf lower bound supposed to be bounded"
|
||||
|
||||
if np.any(space.low != 0) or np.any(space.high != 0):
|
||||
assert space.is_bounded("both") is False
|
||||
else:
|
||||
assert space.is_bounded("both") is True
|
||||
|
||||
# check for dtype
|
||||
assert (
|
||||
space.high.dtype == space.dtype
|
||||
), f"High's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
||||
assert (
|
||||
space.low.dtype == space.dtype
|
||||
), f"Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="manner is not in {'below', 'above', 'both'}, actual value:"
|
||||
):
|
||||
space.is_bounded("test")
|
||||
|
||||
|
||||
def test_legacy_state_pickling():
|
||||
legacy_state = {
|
||||
"dtype": np.dtype("float32"),
|
||||
"_shape": (5,),
|
||||
"low": np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"high": np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
|
||||
"bounded_below": np.array([True, True, True, True, True]),
|
||||
"bounded_above": np.array([True, True, True, True, True]),
|
||||
"_np_random": None,
|
||||
}
|
||||
|
||||
b = Box(-1, 1, ())
|
||||
assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__
|
||||
del b.__dict__["low_repr"]
|
||||
del b.__dict__["high_repr"]
|
||||
assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__
|
||||
|
||||
b.__setstate__(legacy_state)
|
||||
assert b.low_repr == "0.0"
|
||||
assert b.high_repr == "1.0"
|
||||
|
||||
|
||||
def test_get_inf():
|
||||
"""Tests that get inf function works as expected, primarily for coverage."""
|
||||
assert get_inf(np.float32, "+") == np.inf
|
||||
assert get_inf(np.float16, "-") == -np.inf
|
||||
with pytest.raises(
|
||||
TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
|
||||
):
|
||||
get_inf(np.float32, "*")
|
||||
|
||||
assert get_inf(np.int16, "+") == 32765
|
||||
assert get_inf(np.int8, "-") == -126
|
||||
with pytest.raises(
|
||||
TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
|
||||
):
|
||||
get_inf(np.int32, "*")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape("Unknown dtype <class 'numpy.complex128'> for infinite bounds"),
|
||||
):
|
||||
get_inf(np.complex_, "+")
|
||||
|
||||
|
||||
def test_sample_mask():
|
||||
"""Box cannot have a mask applied."""
|
||||
space = Box(0, 1)
|
||||
with pytest.raises(
|
||||
gym.error.Error,
|
||||
match=re.escape("Box.sample cannot be provided a mask, actual value: "),
|
||||
):
|
||||
space.sample(mask=np.array([0, 1, 0], dtype=np.int8))
|
||||
|
40
tests/spaces/test_discrete.py
Normal file
40
tests/spaces/test_discrete.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Discrete
|
||||
|
||||
|
||||
def test_space_legacy_pickling():
|
||||
"""Test the legacy pickle of Discrete that is missing the `start` parameter."""
|
||||
legacy_state = {
|
||||
"shape": (
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
),
|
||||
"dtype": np.int64,
|
||||
"np_random": np.random.default_rng(),
|
||||
"n": 3,
|
||||
}
|
||||
space = Discrete(1)
|
||||
space.__setstate__(legacy_state)
|
||||
|
||||
assert space.shape == legacy_state["shape"]
|
||||
assert space.np_random == legacy_state["np_random"]
|
||||
assert space.n == 3
|
||||
assert space.dtype == legacy_state["dtype"]
|
||||
|
||||
# Test that start is missing
|
||||
assert "start" in space.__dict__
|
||||
del space.__dict__["start"] # legacy did not include start param
|
||||
assert "start" not in space.__dict__
|
||||
|
||||
space.__setstate__(legacy_state)
|
||||
assert space.start == 0
|
||||
|
||||
|
||||
def test_sample_mask():
|
||||
space = Discrete(4, start=2)
|
||||
assert 2 <= space.sample() < 6
|
||||
assert space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3
|
||||
assert space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2
|
||||
assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5]
|
3
tests/spaces/test_multibinary.py
Normal file
3
tests/spaces/test_multibinary.py
Normal file
@@ -0,0 +1,3 @@
|
||||
def test_sample():
|
||||
# todo
|
||||
pass
|
66
tests/spaces/test_multidiscrete.py
Normal file
66
tests/spaces/test_multidiscrete.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import pytest
|
||||
|
||||
from gym.spaces import Discrete, MultiDiscrete
|
||||
from gym.utils.env_checker import data_equivalence
|
||||
|
||||
|
||||
def test_multidiscrete_as_tuple():
|
||||
# 1D multi-discrete
|
||||
space = MultiDiscrete([3, 4, 5])
|
||||
|
||||
assert space.shape == (3,)
|
||||
assert space[0] == Discrete(3)
|
||||
assert space[0:1] == MultiDiscrete([3])
|
||||
assert space[0:2] == MultiDiscrete([3, 4])
|
||||
assert space[:] == space and space[:] is not space
|
||||
|
||||
# 2D multi-discrete
|
||||
space = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
|
||||
|
||||
assert space.shape == (2, 3)
|
||||
assert space[0, 1] == Discrete(4)
|
||||
assert space[0] == MultiDiscrete([3, 4, 5])
|
||||
assert space[0:1] == MultiDiscrete([[3, 4, 5]])
|
||||
assert space[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
|
||||
assert space[:, 0:1] == MultiDiscrete([[3], [6]])
|
||||
assert space[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
|
||||
assert space[:] == space and space[:] is not space
|
||||
assert space[:, :] == space and space[:, :] is not space
|
||||
|
||||
|
||||
def test_multidiscrete_subspace_reproducibility():
|
||||
# 1D multi-discrete
|
||||
space = MultiDiscrete([100, 200, 300])
|
||||
space.seed()
|
||||
|
||||
assert data_equivalence(space[0].sample(), space[0].sample())
|
||||
assert data_equivalence(space[0:1].sample(), space[0:1].sample())
|
||||
assert data_equivalence(space[0:2].sample(), space[0:2].sample())
|
||||
assert data_equivalence(space[:].sample(), space[:].sample())
|
||||
assert data_equivalence(space[:].sample(), space.sample())
|
||||
|
||||
# 2D multi-discrete
|
||||
space = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
|
||||
space.seed()
|
||||
|
||||
assert data_equivalence(space[0, 1].sample(), space[0, 1].sample())
|
||||
assert data_equivalence(space[0].sample(), space[0].sample())
|
||||
assert data_equivalence(space[0:1].sample(), space[0:1].sample())
|
||||
assert data_equivalence(space[0:2, :].sample(), space[0:2, :].sample())
|
||||
assert data_equivalence(space[:, 0:1].sample(), space[:, 0:1].sample())
|
||||
assert data_equivalence(space[0:2, 0:2].sample(), space[0:2, 0:2].sample())
|
||||
assert data_equivalence(space[:].sample(), space[:].sample())
|
||||
assert data_equivalence(space[:, :].sample(), space[:, :].sample())
|
||||
assert data_equivalence(space[:, :].sample(), space.sample())
|
||||
|
||||
|
||||
def test_multidiscrete_length():
|
||||
space = MultiDiscrete(nvec=[3, 2, 4])
|
||||
assert len(space) == 3
|
||||
|
||||
space = MultiDiscrete(nvec=[[2, 3], [3, 2]])
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match="Getting the length of a multi-dimensional MultiDiscrete space.",
|
||||
):
|
||||
assert len(space) == 2
|
24
tests/spaces/test_space.py
Normal file
24
tests/spaces/test_space.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
|
||||
from gym import Space
|
||||
from gym.spaces import utils
|
||||
|
||||
TESTING_SPACE = Space()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
TESTING_SPACE.sample,
|
||||
partial(TESTING_SPACE.contains, None),
|
||||
partial(utils.flatdim, TESTING_SPACE),
|
||||
partial(utils.flatten, TESTING_SPACE, None),
|
||||
partial(utils.flatten_space, TESTING_SPACE),
|
||||
partial(utils.unflatten, TESTING_SPACE, None),
|
||||
],
|
||||
)
|
||||
def test_not_implemented_errors(func):
|
||||
with pytest.raises(NotImplementedError):
|
||||
func()
|
File diff suppressed because it is too large
Load Diff
41
tests/spaces/test_text.py
Normal file
41
tests/spaces/test_text.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import Text
|
||||
|
||||
|
||||
def test_sample_mask():
|
||||
space = Text(min_length=1, max_length=5)
|
||||
|
||||
# Test the sample length
|
||||
sample = space.sample(mask=(3, None))
|
||||
assert sample in space
|
||||
assert len(sample) == 3
|
||||
|
||||
sample = space.sample(mask=None)
|
||||
assert sample in space
|
||||
assert 1 <= len(sample) <= 5
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"Trying to sample with a minimum length > 0 (1) but the character mask is all zero meaning that no character could be sampled."
|
||||
),
|
||||
):
|
||||
space.sample(mask=(3, np.zeros(len(space.character_set), dtype=np.int8)))
|
||||
|
||||
space = Text(min_length=0, max_length=5)
|
||||
sample = space.sample(
|
||||
mask=(None, np.zeros(len(space.character_set), dtype=np.int8))
|
||||
)
|
||||
assert sample in space
|
||||
assert sample == ""
|
||||
|
||||
# Test the sample characters
|
||||
space = Text(max_length=5, charset="abcd")
|
||||
|
||||
sample = space.sample(mask=(3, np.array([0, 1, 0, 0], dtype=np.int8)))
|
||||
assert sample in space
|
||||
assert sample == "bbb"
|
@@ -1,601 +1,135 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from itertools import zip_longest
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Tuple,
|
||||
utils,
|
||||
import gym
|
||||
from gym.spaces import Box, Graph, utils
|
||||
from gym.utils.env_checker import data_equivalence
|
||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
||||
|
||||
TESTING_SPACES_EXPECTED_FLATDIMS = [
|
||||
# Discrete
|
||||
3,
|
||||
3,
|
||||
# Box
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
# Multi-discrete
|
||||
4,
|
||||
10,
|
||||
# Multi-binary
|
||||
8,
|
||||
6,
|
||||
# Text
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
# # Tuple
|
||||
# 9,
|
||||
# 7,
|
||||
# 10,
|
||||
# 6,
|
||||
# None,
|
||||
# # Dict
|
||||
# 7,
|
||||
# 8,
|
||||
# 17,
|
||||
# None,
|
||||
# # Graph
|
||||
# None,
|
||||
# None,
|
||||
# None,
|
||||
# # Sequence
|
||||
# None,
|
||||
# None,
|
||||
# None,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "flatdim"],
|
||||
zip_longest(TESTING_SPACES, TESTING_SPACES_EXPECTED_FLATDIMS),
|
||||
ids=TESTING_SPACES_IDS,
|
||||
)
|
||||
|
||||
homogeneous_spaces = [
|
||||
Discrete(3),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2), dtype=np.float16),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64),
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 10]),
|
||||
MultiBinary(10),
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64
|
||||
),
|
||||
}
|
||||
),
|
||||
Discrete(3, start=2),
|
||||
Discrete(8, start=-5),
|
||||
]
|
||||
|
||||
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
|
||||
|
||||
non_homogenous_spaces = [
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), #
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), #
|
||||
Graph(node_space=Discrete(5), edge_space=None), #
|
||||
Sequence(Discrete(4)), #
|
||||
Sequence(Box(-10, 10, shape=(2, 2))), #
|
||||
Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), #
|
||||
Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), #
|
||||
Dict(
|
||||
a=Graph(node_space=Discrete(4), edge_space=Discrete(4)),
|
||||
b=Box(-10, 10, shape=(2, 2)),
|
||||
), #
|
||||
Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), #
|
||||
Tuple(
|
||||
[
|
||||
Graph(node_space=Discrete(4), edge_space=Discrete(4)),
|
||||
Box(-10, 10, shape=(2, 2)),
|
||||
]
|
||||
), #
|
||||
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), #
|
||||
Dict(
|
||||
a=Dict(
|
||||
a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2))
|
||||
),
|
||||
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
|
||||
), #
|
||||
Dict(
|
||||
a=Dict(
|
||||
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
|
||||
b=Box(-100, 100, shape=(2, 2)),
|
||||
),
|
||||
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", non_homogenous_spaces)
|
||||
def test_non_flattenable(space):
|
||||
assert space.is_np_flattenable is False
|
||||
def test_flatdim(space: gym.spaces.Space, flatdim: Optional[int]):
|
||||
"""Checks that the flattened dims of the space is equal to an expected value."""
|
||||
if space.is_np_flattenable:
|
||||
dim = utils.flatdim(space)
|
||||
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
|
||||
else:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
|
||||
),
|
||||
):
|
||||
utils.flatdim(space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
|
||||
def test_flatdim(space, flatdim):
|
||||
assert space.is_np_flattenable
|
||||
dim = utils.flatdim(space)
|
||||
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces)
|
||||
def test_flatten_space_boxes(space):
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flatten_space(space):
|
||||
"""Test that the flattened spaces are a box and have the `flatdim` shape."""
|
||||
flat_space = utils.flatten_space(space)
|
||||
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
|
||||
flatdim = utils.flatdim(space)
|
||||
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(flat_space, Box)
|
||||
(single_dim,) = flat_space.shape
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
|
||||
def test_flat_space_contains_flat_points(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
flat_space = utils.flatten_space(space)
|
||||
for i, flat_sample in enumerate(flattened_samples):
|
||||
assert flat_space.contains(
|
||||
flat_sample
|
||||
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces)
|
||||
def test_flatten_dim(space):
|
||||
sample = utils.flatten(space, space.sample())
|
||||
(single_dim,) = sample.shape
|
||||
flatdim = utils.flatdim(space)
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
assert single_dim == flatdim
|
||||
elif isinstance(flat_space, Graph):
|
||||
assert isinstance(space, Graph)
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
|
||||
def test_flatten_roundtripping(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
roundtripped_samples = [
|
||||
utils.unflatten(space, sample) for sample in flattened_samples
|
||||
]
|
||||
for i, (original, roundtripped) in enumerate(
|
||||
zip(some_samples, roundtripped_samples)
|
||||
):
|
||||
assert compare_nested(
|
||||
original, roundtripped
|
||||
), f"Expected sample #{i} {original} to equal {roundtripped}"
|
||||
assert space.contains(roundtripped)
|
||||
(node_single_dim,) = flat_space.node_space.shape
|
||||
node_flatdim = utils.flatdim(space.node_space)
|
||||
assert node_single_dim == node_flatdim
|
||||
|
||||
|
||||
def compare_nested(left, right):
|
||||
if type(left) != type(right):
|
||||
return False
|
||||
elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
|
||||
return left.shape == right.shape and np.allclose(left, right)
|
||||
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
|
||||
res = len(left) == len(right)
|
||||
for ((left_key, left_value), (right_key, right_value)) in zip(
|
||||
left.items(), right.items()
|
||||
):
|
||||
if not res:
|
||||
return False
|
||||
res = left_key == right_key and compare_nested(left_value, right_value)
|
||||
return res
|
||||
elif isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
|
||||
res = len(left) == len(right)
|
||||
for (x, y) in zip(left, right):
|
||||
if not res:
|
||||
return False
|
||||
res = compare_nested(x, y)
|
||||
return res
|
||||
if flat_space.edge_space is not None:
|
||||
(edge_single_dim,) = flat_space.edge_space.shape
|
||||
edge_flatdim = utils.flatdim(space.edge_space)
|
||||
assert edge_single_dim == edge_flatdim
|
||||
else:
|
||||
return left == right
|
||||
|
||||
|
||||
"""
|
||||
Expecteded flattened types are based off:
|
||||
1. The type that the space is hardcoded as(ie. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
|
||||
2. The type that the space is instantiated with(ie. box=np.float32 by default unless instantiated with a different type)
|
||||
3. The smallest type that the composite space(tuple, dict) can be represented as. In flatten, this is determined
|
||||
internally by numpy when np.concatenate is called.
|
||||
"""
|
||||
|
||||
expected_flattened_dtypes = [
|
||||
np.int64,
|
||||
np.float32,
|
||||
np.float16,
|
||||
np.int64,
|
||||
np.float64,
|
||||
np.int64,
|
||||
np.int64,
|
||||
np.int8,
|
||||
np.float64,
|
||||
np.int64,
|
||||
np.int64,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["original_space", "expected_flattened_dtype"],
|
||||
zip(homogeneous_spaces, expected_flattened_dtypes),
|
||||
)
|
||||
def test_dtypes(original_space, expected_flattened_dtype):
|
||||
flattened_space = utils.flatten_space(original_space)
|
||||
|
||||
original_sample = original_space.sample()
|
||||
flattened_sample = utils.flatten(original_space, original_sample)
|
||||
unflattened_sample = utils.unflatten(original_space, flattened_sample)
|
||||
|
||||
assert flattened_space.contains(
|
||||
flattened_sample
|
||||
), "Expected flattened_space to contain flattened_sample"
|
||||
assert (
|
||||
flattened_space.dtype == expected_flattened_dtype
|
||||
), f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"
|
||||
|
||||
assert flattened_sample.dtype == flattened_space.dtype, (
|
||||
"Expected flattened_space's dtype to equal " "flattened_sample's dtype "
|
||||
)
|
||||
|
||||
compare_sample_types(original_space, original_sample, unflattened_sample)
|
||||
|
||||
|
||||
def compare_sample_types(original_space, original_sample, unflattened_sample):
|
||||
if isinstance(original_space, Discrete):
|
||||
assert isinstance(unflattened_sample, int), (
|
||||
"Expected unflattened_sample to be an int. unflattened_sample: "
|
||||
"{} original_sample: {}".format(unflattened_sample, original_sample)
|
||||
)
|
||||
elif isinstance(original_space, Tuple):
|
||||
for index in range(len(original_space)):
|
||||
compare_sample_types(
|
||||
original_space.spaces[index],
|
||||
original_sample[index],
|
||||
unflattened_sample[index],
|
||||
)
|
||||
elif isinstance(original_space, Dict):
|
||||
for key, space in original_space.spaces.items():
|
||||
compare_sample_types(space, original_sample[key], unflattened_sample[key])
|
||||
else:
|
||||
assert unflattened_sample.dtype == original_sample.dtype, (
|
||||
"Expected unflattened_sample's dtype to equal "
|
||||
"original_sample's dtype. unflattened_sample: "
|
||||
"{} original_sample: {}".format(unflattened_sample, original_sample)
|
||||
assert isinstance(
|
||||
space, (gym.spaces.Tuple, gym.spaces.Dict, gym.spaces.Sequence)
|
||||
)
|
||||
|
||||
|
||||
homogeneous_samples = [
|
||||
2,
|
||||
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32),
|
||||
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16),
|
||||
(3, 7),
|
||||
(2, np.array([0.5, 3.5], dtype=np.float32)),
|
||||
(3, 0, 1),
|
||||
np.array([0, 1, 7], dtype=np.int64),
|
||||
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
|
||||
OrderedDict(
|
||||
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
|
||||
),
|
||||
3,
|
||||
-2,
|
||||
]
|
||||
|
||||
|
||||
expected_flattened_hom_samples = [
|
||||
np.array([0, 0, 1], dtype=np.int64),
|
||||
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32),
|
||||
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16),
|
||||
np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
|
||||
np.array([0, 0, 1, 0, 0, 0.5, 3.5], dtype=np.float64),
|
||||
np.array([0, 0, 0, 1, 0, 1, 0, 0, 1], dtype=np.int64),
|
||||
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
|
||||
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
|
||||
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
|
||||
np.array([0, 1, 0], dtype=np.int64),
|
||||
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
|
||||
]
|
||||
|
||||
non_homogenous_samples = [
|
||||
GraphInstance(
|
||||
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32),
|
||||
np.array(
|
||||
[
|
||||
0,
|
||||
],
|
||||
dtype=int,
|
||||
),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([0, 1], dtype=int),
|
||||
np.array([[[1, 2], [3, 4]]], dtype=np.float32),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)),
|
||||
(0, 1, 2),
|
||||
(
|
||||
np.array([[0, 1], [2, 3]], dtype=np.float32),
|
||||
np.array([[4, 5], [6, 7]], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
(np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)),
|
||||
(np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)),
|
||||
),
|
||||
OrderedDict(
|
||||
[("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array([1, 2], dtype=np.int),
|
||||
np.array(
|
||||
[
|
||||
0,
|
||||
],
|
||||
dtype=int,
|
||||
),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([[0, 1], [2, 3]], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([1, 2], dtype=np.int),
|
||||
np.array(
|
||||
[
|
||||
0,
|
||||
],
|
||||
dtype=int,
|
||||
),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
np.array([[0, 1], [2, 3]], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32),
|
||||
edges=np.array([0], dtype=int),
|
||||
edge_links=np.array([[0, 1]]),
|
||||
),
|
||||
GraphInstance(
|
||||
nodes=np.array(
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32
|
||||
),
|
||||
edges=np.array([1], dtype=int),
|
||||
edge_links=np.array([[0, 1]]),
|
||||
),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([[0, 1], [2, 3]], dtype=np.float32),
|
||||
np.array([[4, 5], [6, 7]], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
"b",
|
||||
(
|
||||
np.array([12, 13], dtype=np.float32),
|
||||
np.array([14, 15], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array(
|
||||
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
|
||||
dtype=np.float32,
|
||||
),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
"b",
|
||||
(
|
||||
np.array([12, 13], dtype=np.float32),
|
||||
np.array([14, 15], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
expected_flattened_non_hom_samples = [
|
||||
GraphInstance(
|
||||
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
|
||||
np.array([[1, 2, 3, 4]], dtype=np.float32),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int),
|
||||
np.array([[1, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
(
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0]]),
|
||||
np.array([[0, 1]]),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32),
|
||||
np.array([[0, 1, 0, 0]]),
|
||||
np.array([[0, 1]]),
|
||||
),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array(
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32
|
||||
),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "sample", "expected_flattened_sample"],
|
||||
zip(
|
||||
homogeneous_spaces + non_homogenous_spaces,
|
||||
homogeneous_samples + non_homogenous_samples,
|
||||
expected_flattened_hom_samples + expected_flattened_non_hom_samples,
|
||||
),
|
||||
)
|
||||
def test_flatten(space, sample, expected_flattened_sample):
|
||||
flattened_sample = utils.flatten(space, sample)
|
||||
flat_space = utils.flatten_space(space)
|
||||
|
||||
assert sample in space
|
||||
assert flattened_sample in flat_space
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flatten(space):
|
||||
"""Test that a flattened sample have the `flatdim` shape."""
|
||||
flattened_sample = utils.flatten(space, space.sample())
|
||||
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(flattened_sample, np.ndarray)
|
||||
assert flattened_sample.shape == expected_flattened_sample.shape
|
||||
assert flattened_sample.dtype == expected_flattened_sample.dtype
|
||||
assert np.all(flattened_sample == expected_flattened_sample)
|
||||
(single_dim,) = flattened_sample.shape
|
||||
flatdim = utils.flatdim(space)
|
||||
|
||||
assert single_dim == flatdim
|
||||
else:
|
||||
assert not isinstance(flattened_sample, np.ndarray)
|
||||
assert compare_nested(flattened_sample, expected_flattened_sample)
|
||||
assert isinstance(flattened_sample, (tuple, dict, Graph))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "flattened_sample", "expected_sample"],
|
||||
zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples),
|
||||
)
|
||||
def test_unflatten(space, flattened_sample, expected_sample):
|
||||
sample = utils.unflatten(space, flattened_sample)
|
||||
assert compare_nested(sample, expected_sample)
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flat_space_contains_flat_points(space):
|
||||
"""Test that the flattened samples are contained within the flattened space."""
|
||||
flattened_samples = [utils.flatten(space, space.sample()) for _ in range(10)]
|
||||
flat_space = utils.flatten_space(space)
|
||||
|
||||
for flat_sample in flattened_samples:
|
||||
assert flat_sample in flat_space
|
||||
|
||||
|
||||
expected_flattened_spaces = [
|
||||
Box(low=0, high=1, shape=(3,), dtype=np.int64),
|
||||
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float32),
|
||||
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float16),
|
||||
Box(low=0, high=1, shape=(15,), dtype=np.int64),
|
||||
Box(
|
||||
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
|
||||
dtype=np.float64,
|
||||
),
|
||||
Box(low=0, high=1, shape=(9,), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(14,), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(10,), dtype=np.int8),
|
||||
Box(
|
||||
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
|
||||
dtype=np.float64,
|
||||
),
|
||||
Box(low=0, high=1, shape=(3,), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(8,), dtype=np.int64),
|
||||
]
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flatten_roundtripping(space):
|
||||
"""Tests roundtripping with flattening and unflattening are equal to the original sample."""
|
||||
samples = [space.sample() for _ in range(10)]
|
||||
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in samples]
|
||||
unflattened_samples = [
|
||||
utils.unflatten(space, sample) for sample in flattened_samples
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "expected_flattened_space"],
|
||||
zip(homogeneous_spaces, expected_flattened_spaces),
|
||||
)
|
||||
def test_flatten_space(space, expected_flattened_space):
|
||||
flattened_space = utils.flatten_space(space)
|
||||
assert flattened_space == expected_flattened_space
|
||||
for original, roundtripped in zip(samples, unflattened_samples):
|
||||
assert data_equivalence(original, roundtripped)
|
||||
|
27
tests/spaces/utils.py
Normal file
27
tests/spaces/utils.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
|
||||
|
||||
TESTING_FUNDAMENTAL_SPACES = [
|
||||
Discrete(3),
|
||||
Discrete(3, start=-1),
|
||||
Box(low=0.0, high=1.0),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
|
||||
Box(low=-np.inf, high=0.0, shape=(2, 1)),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 1)),
|
||||
MultiDiscrete([2, 2]),
|
||||
MultiDiscrete([[2, 3], [3, 2]]),
|
||||
MultiBinary(8),
|
||||
MultiBinary([2, 3]),
|
||||
Text(6),
|
||||
Text(min_length=3, max_length=6),
|
||||
Text(6, charset="abcdef"),
|
||||
]
|
||||
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]
|
||||
|
||||
|
||||
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES # + TESTING_COMPOSITE_SPACES
|
||||
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS # + TESTING_COMPOSITE_SPACES_IDS
|
Reference in New Issue
Block a user