Add testing for Fundamental spaces with full coverage (#3048)

* initial commit

* Fix the multi-binary sample and upgrade multi-discrete sample

* Fix MultiBinary sample tests and pre-commit

* Adds coverage tests and updates test_utils.py to use the utils.py spaces. Fix bug in text.py

* pre-commit
This commit is contained in:
Mark Towers
2022-09-03 22:56:29 +01:00
committed by GitHub
parent 43b42d5280
commit 8e74fe3b62
17 changed files with 933 additions and 1456 deletions

View File

@@ -164,7 +164,9 @@ class Box(Space[np.ndarray]):
elif manner == "above":
return above
else:
raise ValueError("manner is not in {'below', 'above', 'both'}")
raise ValueError(
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
)
def sample(self, mask: None = None) -> np.ndarray:
r"""Generates a single random sample inside the Box.
@@ -223,7 +225,10 @@ class Box(Space[np.ndarray]):
"""Return boolean specifying if x is a valid member of this space."""
if not isinstance(x, np.ndarray):
logger.warn("Casting input x to numpy array.")
try:
x = np.asarray(x, dtype=self.dtype)
except (ValueError, TypeError):
return False
return bool(
np.can_cast(x.dtype, self.dtype)
@@ -236,7 +241,7 @@ class Box(Space[np.ndarray]):
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> List[np.ndarray]:
def from_jsonable(self, sample_n: Sequence[Union[float, int]]) -> List[np.ndarray]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample) for sample in sample_n]
@@ -252,10 +257,11 @@ class Box(Space[np.ndarray]):
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
def __eq__(self, other) -> bool:
"""Check whether `other` is equivalent to this instance."""
"""Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
return (
isinstance(other, Box)
and (self.shape == other.shape)
# and (self.dtype == other.dtype)
and np.allclose(self.low, other.low)
and np.allclose(self.high, other.high)
)

View File

@@ -85,18 +85,19 @@ class Discrete(Space[int]):
if isinstance(x, int):
as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and (
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
np.issubdtype(x.dtype, np.integer) and x.shape == ()
):
as_int = int(x) # type: ignore
else:
return False
return self.start <= as_int < self.start + self.n
def __repr__(self) -> str:
"""Gives a string representation of this space."""
if self.start != 0:
return "Discrete(%d, start=%d)" % (self.n, self.start)
return "Discrete(%d)" % self.n
return f"Discrete({self.n}, start={self.start})"
return f"Discrete({self.n})"
def __eq__(self, other) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
@@ -114,8 +115,6 @@ class Discrete(Space[int]):
Args:
state: The new state
"""
super().__setstate__(state)
# Don't mutate the original state
state = dict(state)
@@ -124,5 +123,4 @@ class Discrete(Space[int]):
if "start" not in state:
state["start"] = 0
# Update our state
self.__dict__.update(state)
super().__setstate__(state)

View File

@@ -6,7 +6,7 @@ import numpy as np
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import SAMPLE_MASK_TYPE, MultiDiscrete
from gym.spaces.multi_discrete import MultiDiscrete
from gym.spaces.space import Space
@@ -97,8 +97,8 @@ class Graph(Space):
self,
mask: Optional[
Tuple[
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
Optional[Union[np.ndarray, SAMPLE_MASK_TYPE]],
Optional[Union[np.ndarray, tuple]],
Optional[Union[np.ndarray, tuple]],
]
] = None,
num_nodes: int = 10,

View File

@@ -62,7 +62,8 @@ class MultiBinary(Space[np.ndarray]):
Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
Where mask == 0 then the samples will be 0.
For mask == 0 then the samples will be 0 and mask == 1 then random samples will be generated.
The expected mask shape is the space shape and mask dtype is `np.int8`.
Returns:
Sampled values from space
@@ -78,11 +79,13 @@ class MultiBinary(Space[np.ndarray]):
mask.shape == self.shape
), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}"
assert np.all(
np.logical_or(mask == 0, mask == 1)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
(mask == 0) | (mask == 1) | (mask == 2)
), f"All values of a mask should be 0, 1 or 2, actual values: {mask}"
return mask * self.np_random.integers(
low=0, high=2, size=self.n, dtype=self.dtype
return np.where(
mask == 2,
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
mask.astype(self.dtype),
)
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
@@ -91,9 +94,12 @@ class MultiBinary(Space[np.ndarray]):
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check
if self.shape != x.shape:
return False
return ((x == 0) | (x == 1)).all()
return bool(
isinstance(x, np.ndarray)
and self.shape == x.shape
and np.all((x == 0) | (x == 1))
)
def to_jsonable(self, sample_n) -> list:
"""Convert a batch of samples from this space to a JSONable data type."""
@@ -101,7 +107,7 @@ class MultiBinary(Space[np.ndarray]):
def from_jsonable(self, sample_n) -> list:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample) for sample in sample_n]
return [np.asarray(sample, self.dtype) for sample in sample_n]
def __repr__(self) -> str:
"""Gives a string representation of this space."""

View File

@@ -7,8 +7,6 @@ from gym import logger
from gym.spaces.discrete import Discrete
from gym.spaces.space import Space
SAMPLE_MASK_TYPE = Tuple[Union["SAMPLE_MASK_TYPE", np.ndarray], ...]
class MultiDiscrete(Space[np.ndarray]):
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
@@ -39,7 +37,7 @@ class MultiDiscrete(Space[np.ndarray]):
def __init__(
self,
nvec: Union[np.ndarray, List[int]],
nvec: Union[np.ndarray, list],
dtype=np.int64,
seed: Optional[Union[int, np.random.Generator]] = None,
):
@@ -68,7 +66,7 @@ class MultiDiscrete(Space[np.ndarray]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: Optional[SAMPLE_MASK_TYPE] = None) -> np.ndarray:
def sample(self, mask: Optional[tuple] = None) -> np.ndarray:
"""Generates a single random sample this space.
Args:
@@ -82,15 +80,30 @@ class MultiDiscrete(Space[np.ndarray]):
if mask is not None:
def _apply_mask(
sub_mask: SAMPLE_MASK_TYPE, sub_nvec: np.ndarray
sub_mask: Union[np.ndarray, tuple],
sub_nvec: Union[np.ndarray, np.integer],
) -> Union[int, List[int]]:
if isinstance(sub_mask, np.ndarray):
if isinstance(sub_nvec, np.ndarray):
assert isinstance(
sub_mask, tuple
), f"Expects the mask to be a tuple for sub_nvec ({sub_nvec}), actual type: {type(sub_mask)}"
assert len(sub_mask) == len(
sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
return [
_apply_mask(new_mask, new_nvec)
for new_mask, new_nvec in zip(sub_mask, sub_nvec)
]
else:
assert np.issubdtype(
type(sub_nvec), np.integer
), f"Expects the mask to be for an action, actual for {sub_nvec}"
), f"Expects the sub_nvec to be an action, actually: {sub_nvec}, {type(sub_nvec)}"
assert isinstance(
sub_mask, np.ndarray
), f"Expects the sub mask to be np.ndarray, actual type: {type(sub_mask)}"
assert (
len(sub_mask) == sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {sub_nvec}"
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, action: {sub_nvec}"
assert (
sub_mask.dtype == np.int8
), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}"
@@ -104,17 +117,6 @@ class MultiDiscrete(Space[np.ndarray]):
return self.np_random.choice(np.where(valid_action_mask)[0])
else:
return 0
else:
assert isinstance(
sub_mask, tuple
), f"Expects the mask to be a tuple or np.ndarray, actual type: {type(sub_mask)}"
assert len(sub_mask) == len(
sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
return [
_apply_mask(new_mask, new_nvec)
for new_mask, new_nvec in zip(sub_mask, sub_nvec)
]
return np.array(_apply_mask(mask, self.nvec), dtype=self.dtype)
@@ -124,9 +126,16 @@ class MultiDiscrete(Space[np.ndarray]):
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
# is within correct bounds for space dtype (even though x does not have to be unsigned)
return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all())
return bool(
isinstance(x, np.ndarray)
and x.shape == self.shape
and x.dtype != object
and np.all(0 <= x)
and np.all(x < self.nvec)
)
def to_jsonable(self, sample_n: Iterable[np.ndarray]):
"""Convert a batch of samples from this space to a JSONable data type."""
@@ -147,13 +156,18 @@ class MultiDiscrete(Space[np.ndarray]):
subspace = Discrete(nvec)
else:
subspace = MultiDiscrete(nvec, self.dtype) # type: ignore
# you don't need to deepcopy as np random generator call replaces the state not the data
subspace.np_random.bit_generator.state = self.np_random.bit_generator.state
return subspace
def __len__(self):
"""Gives the ``len`` of samples from this space."""
if self.nvec.ndim >= 2:
logger.warn("Get length of a multi-dimensional MultiDiscrete space.")
logger.warn(
"Getting the length of a multi-dimensional MultiDiscrete space."
)
return len(self.nvec)
def __eq__(self, other):

View File

@@ -1,5 +1,5 @@
"""Implementation of a space that represents textual strings."""
from typing import Any, FrozenSet, List, Optional, Set, Tuple, Union
from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
import numpy as np
@@ -27,7 +27,7 @@ class Text(Space[str]):
self,
max_length: int,
*,
min_length: int = 0,
min_length: int = 1,
charset: Union[Set[str], str] = alphanumeric,
seed: Optional[Union[int, np.random.Generator]] = None,
):
@@ -36,9 +36,9 @@ class Text(Space[str]):
Both bounds for text length are inclusive.
Args:
min_length (int): Minimum text length (in characters).
min_length (int): Minimum text length (in characters). Defaults to 1 to prevent empty strings.
max_length (int): Maximum text length (in characters).
charset (Union[set, SupportsIndex]): Character set, defaults to the lower and upper english alphabet plus latin digits.
charset (Union[set], str): Character set, defaults to the lower and upper english alphabet plus latin digits.
seed: The seed for sampling from the space.
"""
assert np.issubdtype(
@@ -56,9 +56,13 @@ class Text(Space[str]):
self.min_length: int = int(min_length)
self.max_length: int = int(max_length)
self.charset: FrozenSet[str] = frozenset(charset)
self._charlist: List[str] = list(charset)
self._charset_str: str = "".join(sorted(self._charlist))
self._char_set: FrozenSet[str] = frozenset(charset)
self._char_list: Tuple[str, ...] = tuple(charset)
self._char_index: Dict[str, np.int32] = {
val: np.int32(i) for i, val in enumerate(tuple(charset))
}
self._char_str: str = "".join(sorted(tuple(charset)))
# As the shape is dynamic (between min_length and max_length) then None
super().__init__(dtype=str, seed=seed)
@@ -71,20 +75,42 @@ class Text(Space[str]):
Args:
mask: An optional tuples of length and mask for the text.
The length is expected to be between the `min_length` and `max_length` otherwise a random integer between `min_length` and `max_length` is selected.
For the mask, we expect a numpy array of length of the charset passed with dtype == np.int8
For the mask, we expect a numpy array of length of the charset passed with `dtype == np.int8`.
If the charlist mask is all zero then an empty string is returned no matter the `min_length`
Returns:
A sampled string from the space
"""
if mask is not None:
assert isinstance(
mask, tuple
), f"Expects the mask type to be a tuple, actual type: {type(mask)}"
assert (
len(mask) == 2
), f"Expects the mask length to be two, actual length: {len(mask)}"
length, charlist_mask = mask
if length is not None:
assert self.min_length <= length <= self.max_length
assert np.issubdtype(
type(length), np.integer
), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
assert (
self.min_length <= length <= self.max_length
), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"
if charlist_mask is not None:
assert isinstance(charlist_mask, np.ndarray)
assert charlist_mask.dtype is np.int8
assert charlist_mask.shape == (len(self._charlist),)
assert isinstance(
charlist_mask, np.ndarray
), f"Expects the Text sample mask to be an np.ndarray, actual type: {type(charlist_mask)}"
assert (
charlist_mask.dtype == np.int8
), f"Expects the Text sample mask to be an np.ndarray, actual dtype: {charlist_mask.dtype}"
assert charlist_mask.shape == (
len(self.character_set),
), f"expects the Text sample mask to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"
assert np.all(
np.logical_or(charlist_mask == 0, charlist_mask == 1)
), f"Expects all masks values to 0 or 1, actual values: {charlist_mask}"
else:
length, charlist_mask = None, None
@@ -92,10 +118,23 @@ class Text(Space[str]):
length = self.np_random.integers(self.min_length, self.max_length + 1)
if charlist_mask is None:
string = self.np_random.choice(self._charlist, size=length)
string = self.np_random.choice(self.character_list, size=length)
else:
masked_charlist = self._charlist[np.where(mask)[0]]
string = self.np_random.choice(masked_charlist, size=length)
valid_mask = charlist_mask == 1
valid_indexes = np.where(valid_mask)[0]
if len(valid_indexes) == 0:
if self.min_length == 0:
string = ""
else:
# Otherwise the string will not be contained in the space
raise ValueError(
f"Trying to sample with a minimum length > 0 ({self.min_length}) but the character mask is all zero meaning that no character could be sampled."
)
else:
string = "".join(
self.character_list[index]
for index in self.np_random.choice(valid_indexes, size=length)
)
return "".join(string)
@@ -103,13 +142,13 @@ class Text(Space[str]):
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, str):
if self.min_length <= len(x) <= self.max_length:
return all(c in self.charset for c in x)
return all(c in self.character_set for c in x)
return False
def __repr__(self) -> str:
"""Gives a string representation of this space."""
return (
f"Text({self.min_length}, {self.max_length}, charset={self._charset_str})"
f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
)
def __eq__(self, other) -> bool:
@@ -118,5 +157,29 @@ class Text(Space[str]):
isinstance(other, Text)
and self.min_length == other.min_length
and self.max_length == other.max_length
and self.charset == other.charset
and self.character_set == other.character_set
)
@property
def character_set(self) -> FrozenSet[str]:
"""Returns the character set for the space."""
return self._char_set
@property
def character_list(self) -> Tuple[str, ...]:
"""Returns a tuple of characters in the space."""
return self._char_list
def character_index(self, char: str) -> np.int32:
"""Returns a unique index for each character in the space's character set."""
return self._char_index[char]
@property
def characters(self) -> str:
"""Returns a string with all Text characters."""
return self._char_str
@property
def is_np_flattenable(self) -> bool:
"""The flattened version is an integer array for each character, padded to the max character length."""
return True

View File

@@ -21,6 +21,7 @@ from gym.spaces import (
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
)
@@ -88,6 +89,11 @@ def _flatdim_dict(space: Dict) -> int:
)
@flatdim.register(Text)
def _flatdim_text(space: Text) -> int:
return space.max_length
T = TypeVar("T")
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
@@ -104,7 +110,9 @@ def flatten(space: Space[T], x: T) -> FlatType:
x: The value to flatten
Returns:
- The flattened ``x``, always returns a 1D array for non-graph spaces.
- For ``Box`` and ``MultiBinary``, this is a flattened array
- For ``Discrete`` and ``MultiDiscrete``, this is a flattened one-hot array of the sample
- For ``Tuple`` and ``Dict``, this is a concatenated array the subspaces (does not support graph subspaces)
- For graph spaces, returns `GraphInstance` where:
- `nodes` are n x k arrays
- `edges` are either:
@@ -179,6 +187,16 @@ def _flatten_graph(space, x) -> GraphInstance:
return GraphInstance(nodes, edges, x.edge_links)
@flatten.register(Text)
def _flatten_text(space: Text, x: str) -> np.ndarray:
arr = np.full(
shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
)
for i, val in enumerate(x):
arr[i] = space.character_index(val)
return arr
@flatten.register(Sequence)
def _flatten_sequence(space, x) -> tuple:
return tuple(flatten(space.feature_space, item) for item in x)
@@ -284,6 +302,13 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
return GraphInstance(nodes, edges, x.edge_links)
@unflatten.register(Text)
def _unflatten_text(space: Text, x: np.ndarray) -> str:
return "".join(
[space.character_list[val] for val in x if val < len(space.character_set)]
)
@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
return tuple(unflatten(space.feature_space, item) for item in x)
@@ -401,6 +426,13 @@ def _flatten_space_graph(space: Graph) -> Graph:
)
@flatten_space.register(Text)
def _flatten_space_text(space: Text) -> Box:
return Box(
low=0, high=len(space.character_set), shape=(space.max_length,), dtype=np.int32
)
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space))

View File

@@ -50,7 +50,9 @@ def data_equivalence(data_1, data_2) -> bool:
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
elif isinstance(data_1, np.ndarray):
return np.all(data_1 == data_2)
return data_1.shape == data_2.shape and np.allclose(
data_1, data_2, atol=0.00001
)
else:
return data_1 == data_2
else:

View File

@@ -4,26 +4,26 @@ import warnings
import numpy as np
import pytest
import gym.error
from gym.spaces import Box
# Todo, move Box unique tests from test_spaces.py to test_box.py
from gym.spaces.box import get_inf
@pytest.mark.parametrize(
"box,expected_shape",
[
(
Box(low=np.zeros(2), high=np.ones(2)),
( # Test with same 1-dim low and high shape
Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32),
(2,),
), # Test with same 1-dim low and high shape
(
Box(low=np.zeros((2, 1)), high=np.ones((2, 1))),
),
( # Test with same multi-dim low and high shape
Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32),
(2, 1),
), # Test with same multi-dim low and high shape
(
),
( # Test with scalar low high and different shape
Box(low=0, high=1, shape=(5, 2)),
(5, 2),
), # Test with scalar low high and different shape
),
(Box(low=0, high=1), (1,)), # Test with int and int
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
@@ -33,7 +33,8 @@ from gym.spaces import Box
(Box(low=np.zeros(3), high=1.0), (3,)), # Test with array and scalar
],
)
def test_box_shape_inference(box, expected_shape):
def test_shape_inference(box, expected_shape):
"""Test that the shape inference is as expected."""
assert box.shape == expected_shape
assert box.sample().shape == expected_shape
@@ -48,7 +49,7 @@ def test_box_shape_inference(box, expected_shape):
(np.zeros(2, dtype=np.float32), True),
(np.zeros((2, 2), dtype=np.float32), True),
(np.inf, True),
(np.nan, True), # This is a weird side
(np.nan, True), # This is a weird case that we allow
(True, False),
(np.bool8(True), False),
(1 + 1j, False),
@@ -56,7 +57,8 @@ def test_box_shape_inference(box, expected_shape):
("string", False),
],
)
def test_box_values(value, valid):
def test_low_high_values(value, valid: bool):
"""Test what `low` and `high` values are valid for `Box` space."""
if valid:
with warnings.catch_warnings(record=True) as caught_warnings:
Box(low=value, high=value)
@@ -66,7 +68,9 @@ def test_box_values(value, valid):
else:
with pytest.raises(
ValueError,
match=r"expect their types to be np\.ndarray, an integer or a float",
match=re.escape(
"expect their types to be np.ndarray, an integer or a float"
),
):
Box(low=value, high=value)
@@ -135,6 +139,178 @@ def test_box_values(value, valid):
),
],
)
def test_box_errors(low, high, kwargs, error, message):
def test_init_errors(low, high, kwargs, error, message):
"""Test all constructor errors."""
with pytest.raises(error, match=f"^{re.escape(message)}$"):
Box(low=low, high=high, **kwargs)
def test_dtype_check():
"""Tests the Box contains function with different dtypes."""
# Related Issues:
# https://github.com/openai/gym/issues/2357
# https://github.com/openai/gym/issues/2298
space = Box(0, 1, (), dtype=np.float32)
# casting will match the correct type
assert np.array(0.5, dtype=np.float32) in space
# float16 is in float32 space
assert np.array(0.5, dtype=np.float16) in space
# float64 is not in float32 space
assert np.array(0.5, dtype=np.float64) not in space
@pytest.mark.parametrize(
"space",
[
Box(low=0, high=np.inf, shape=(2,), dtype=np.int32),
Box(low=0, high=np.inf, shape=(2,), dtype=np.float32),
Box(low=0, high=np.inf, shape=(2,), dtype=np.int64),
Box(low=0, high=np.inf, shape=(2,), dtype=np.float64),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64),
Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64),
Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64),
Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64),
Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64),
Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64),
Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64),
],
)
def test_infinite_space(space):
"""
To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
are both modified within the init, we check for infinite when we know it's not 0
"""
assert np.all(
space.low < space.high
), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})"
space.seed(0)
sample = space.sample()
# check if space contains sample
assert (
sample in space
), f"Sample ({sample}) not inside space according to `space.contains()`"
# manually check that the sign of the sample is within the bounds
assert np.all(
np.sign(sample) <= np.sign(space.high)
), f"Sign of sample ({sample}) is less than space upper bound ({space.high})"
assert np.all(
np.sign(space.low) <= np.sign(sample)
), f"Sign of sample ({sample}) is more than space lower bound ({space.low})"
# check that int bounds are bounded for everything
# but floats are unbounded for infinite
if np.any(space.high != 0):
assert (
space.is_bounded("above") is False
), "inf upper bound supposed to be unbounded"
else:
assert (
space.is_bounded("above") is True
), "non-inf upper bound supposed to be bounded"
if np.any(space.low != 0):
assert (
space.is_bounded("below") is False
), "inf lower bound supposed to be unbounded"
else:
assert (
space.is_bounded("below") is True
), "non-inf lower bound supposed to be bounded"
if np.any(space.low != 0) or np.any(space.high != 0):
assert space.is_bounded("both") is False
else:
assert space.is_bounded("both") is True
# check for dtype
assert (
space.high.dtype == space.dtype
), f"High's dtype {space.high.dtype} doesn't match `space.dtype`'"
assert (
space.low.dtype == space.dtype
), f"Low's dtype {space.high.dtype} doesn't match `space.dtype`'"
with pytest.raises(
ValueError, match="manner is not in {'below', 'above', 'both'}, actual value:"
):
space.is_bounded("test")
def test_legacy_state_pickling():
legacy_state = {
"dtype": np.dtype("float32"),
"_shape": (5,),
"low": np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32),
"high": np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
"bounded_below": np.array([True, True, True, True, True]),
"bounded_above": np.array([True, True, True, True, True]),
"_np_random": None,
}
b = Box(-1, 1, ())
assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__
del b.__dict__["low_repr"]
del b.__dict__["high_repr"]
assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__
b.__setstate__(legacy_state)
assert b.low_repr == "0.0"
assert b.high_repr == "1.0"
def test_get_inf():
"""Tests that get inf function works as expected, primarily for coverage."""
assert get_inf(np.float32, "+") == np.inf
assert get_inf(np.float16, "-") == -np.inf
with pytest.raises(
TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
):
get_inf(np.float32, "*")
assert get_inf(np.int16, "+") == 32765
assert get_inf(np.int8, "-") == -126
with pytest.raises(
TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
):
get_inf(np.int32, "*")
with pytest.raises(
ValueError,
match=re.escape("Unknown dtype <class 'numpy.complex128'> for infinite bounds"),
):
get_inf(np.complex_, "+")
def test_sample_mask():
"""Box cannot have a mask applied."""
space = Box(0, 1)
with pytest.raises(
gym.error.Error,
match=re.escape("Box.sample cannot be provided a mask, actual value: "),
):
space.sample(mask=np.array([0, 1, 0], dtype=np.int8))

View File

@@ -0,0 +1,40 @@
import numpy as np
from gym.spaces import Discrete
def test_space_legacy_pickling():
"""Test the legacy pickle of Discrete that is missing the `start` parameter."""
legacy_state = {
"shape": (
1,
2,
3,
),
"dtype": np.int64,
"np_random": np.random.default_rng(),
"n": 3,
}
space = Discrete(1)
space.__setstate__(legacy_state)
assert space.shape == legacy_state["shape"]
assert space.np_random == legacy_state["np_random"]
assert space.n == 3
assert space.dtype == legacy_state["dtype"]
# Test that start is missing
assert "start" in space.__dict__
del space.__dict__["start"] # legacy did not include start param
assert "start" not in space.__dict__
space.__setstate__(legacy_state)
assert space.start == 0
def test_sample_mask():
space = Discrete(4, start=2)
assert 2 <= space.sample() < 6
assert space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3
assert space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2
assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5]

View File

@@ -0,0 +1,3 @@
def test_sample():
# todo
pass

View File

@@ -0,0 +1,66 @@
import pytest
from gym.spaces import Discrete, MultiDiscrete
from gym.utils.env_checker import data_equivalence
def test_multidiscrete_as_tuple():
# 1D multi-discrete
space = MultiDiscrete([3, 4, 5])
assert space.shape == (3,)
assert space[0] == Discrete(3)
assert space[0:1] == MultiDiscrete([3])
assert space[0:2] == MultiDiscrete([3, 4])
assert space[:] == space and space[:] is not space
# 2D multi-discrete
space = MultiDiscrete([[3, 4, 5], [6, 7, 8]])
assert space.shape == (2, 3)
assert space[0, 1] == Discrete(4)
assert space[0] == MultiDiscrete([3, 4, 5])
assert space[0:1] == MultiDiscrete([[3, 4, 5]])
assert space[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
assert space[:, 0:1] == MultiDiscrete([[3], [6]])
assert space[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
assert space[:] == space and space[:] is not space
assert space[:, :] == space and space[:, :] is not space
def test_multidiscrete_subspace_reproducibility():
# 1D multi-discrete
space = MultiDiscrete([100, 200, 300])
space.seed()
assert data_equivalence(space[0].sample(), space[0].sample())
assert data_equivalence(space[0:1].sample(), space[0:1].sample())
assert data_equivalence(space[0:2].sample(), space[0:2].sample())
assert data_equivalence(space[:].sample(), space[:].sample())
assert data_equivalence(space[:].sample(), space.sample())
# 2D multi-discrete
space = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
space.seed()
assert data_equivalence(space[0, 1].sample(), space[0, 1].sample())
assert data_equivalence(space[0].sample(), space[0].sample())
assert data_equivalence(space[0:1].sample(), space[0:1].sample())
assert data_equivalence(space[0:2, :].sample(), space[0:2, :].sample())
assert data_equivalence(space[:, 0:1].sample(), space[:, 0:1].sample())
assert data_equivalence(space[0:2, 0:2].sample(), space[0:2, 0:2].sample())
assert data_equivalence(space[:].sample(), space[:].sample())
assert data_equivalence(space[:, :].sample(), space[:, :].sample())
assert data_equivalence(space[:, :].sample(), space.sample())
def test_multidiscrete_length():
space = MultiDiscrete(nvec=[3, 2, 4])
assert len(space) == 3
space = MultiDiscrete(nvec=[[2, 3], [3, 2]])
with pytest.warns(
UserWarning,
match="Getting the length of a multi-dimensional MultiDiscrete space.",
):
assert len(space) == 2

View File

@@ -0,0 +1,24 @@
from functools import partial
import pytest
from gym import Space
from gym.spaces import utils
TESTING_SPACE = Space()
@pytest.mark.parametrize(
"func",
[
TESTING_SPACE.sample,
partial(TESTING_SPACE.contains, None),
partial(utils.flatdim, TESTING_SPACE),
partial(utils.flatten, TESTING_SPACE, None),
partial(utils.flatten_space, TESTING_SPACE),
partial(utils.unflatten, TESTING_SPACE, None),
],
)
def test_not_implemented_errors(func):
with pytest.raises(NotImplementedError):
func()

File diff suppressed because it is too large Load Diff

41
tests/spaces/test_text.py Normal file
View File

@@ -0,0 +1,41 @@
import re
import numpy as np
import pytest
from gym.spaces import Text
def test_sample_mask():
space = Text(min_length=1, max_length=5)
# Test the sample length
sample = space.sample(mask=(3, None))
assert sample in space
assert len(sample) == 3
sample = space.sample(mask=None)
assert sample in space
assert 1 <= len(sample) <= 5
with pytest.raises(
ValueError,
match=re.escape(
"Trying to sample with a minimum length > 0 (1) but the character mask is all zero meaning that no character could be sampled."
),
):
space.sample(mask=(3, np.zeros(len(space.character_set), dtype=np.int8)))
space = Text(min_length=0, max_length=5)
sample = space.sample(
mask=(None, np.zeros(len(space.character_set), dtype=np.int8))
)
assert sample in space
assert sample == ""
# Test the sample characters
space = Text(max_length=5, charset="abcd")
sample = space.sample(mask=(3, np.array([0, 1, 0, 0], dtype=np.int8)))
assert sample in space
assert sample == "bbb"

View File

@@ -1,601 +1,135 @@
import re
from collections import OrderedDict
from itertools import zip_longest
from typing import Optional
import numpy as np
import pytest
from gym.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Tuple,
utils,
import gym
from gym.spaces import Box, Graph, utils
from gym.utils.env_checker import data_equivalence
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
TESTING_SPACES_EXPECTED_FLATDIMS = [
# Discrete
3,
3,
# Box
1,
4,
2,
2,
2,
# Multi-discrete
4,
10,
# Multi-binary
8,
6,
# Text
6,
6,
6,
# # Tuple
# 9,
# 7,
# 10,
# 6,
# None,
# # Dict
# 7,
# 8,
# 17,
# None,
# # Graph
# None,
# None,
# None,
# # Sequence
# None,
# None,
# None,
]
@pytest.mark.parametrize(
["space", "flatdim"],
zip_longest(TESTING_SPACES, TESTING_SPACES_EXPECTED_FLATDIMS),
ids=TESTING_SPACES_IDS,
)
homogeneous_spaces = [
Discrete(3),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Box(low=0.0, high=np.inf, shape=(2, 2), dtype=np.float16),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
[
Discrete(5),
Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64),
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 10]),
MultiBinary(10),
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64
),
}
),
Discrete(3, start=2),
Discrete(8, start=-5),
]
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
non_homogenous_spaces = [
Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), #
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), #
Graph(node_space=Discrete(5), edge_space=None), #
Sequence(Discrete(4)), #
Sequence(Box(-10, 10, shape=(2, 2))), #
Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), #
Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), #
Dict(
a=Graph(node_space=Discrete(4), edge_space=Discrete(4)),
b=Box(-10, 10, shape=(2, 2)),
), #
Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), #
Tuple(
[
Graph(node_space=Discrete(4), edge_space=Discrete(4)),
Box(-10, 10, shape=(2, 2)),
]
), #
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), #
Dict(
a=Dict(
a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2))
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
), #
Dict(
a=Dict(
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
b=Box(-100, 100, shape=(2, 2)),
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
),
]
@pytest.mark.parametrize("space", non_homogenous_spaces)
def test_non_flattenable(space):
assert space.is_np_flattenable is False
def test_flatdim(space: gym.spaces.Space, flatdim: Optional[int]):
"""Checks that the flattened dims of the space is equal to an expected value."""
if space.is_np_flattenable:
dim = utils.flatdim(space)
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
else:
with pytest.raises(
ValueError,
match=re.escape(
"cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
),
):
utils.flatdim(space)
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
def test_flatdim(space, flatdim):
assert space.is_np_flattenable
dim = utils.flatdim(space)
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_space_boxes(space):
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flatten_space(space):
"""Test that the flattened spaces are a box and have the `flatdim` shape."""
flat_space = utils.flatten_space(space)
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
flatdim = utils.flatdim(space)
if space.is_np_flattenable:
assert isinstance(flat_space, Box)
(single_dim,) = flat_space.shape
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flat_space_contains_flat_points(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
flat_space = utils.flatten_space(space)
for i, flat_sample in enumerate(flattened_samples):
assert flat_space.contains(
flat_sample
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_dim(space):
sample = utils.flatten(space, space.sample())
(single_dim,) = sample.shape
flatdim = utils.flatdim(space)
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
assert single_dim == flatdim
elif isinstance(flat_space, Graph):
assert isinstance(space, Graph)
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flatten_roundtripping(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
roundtripped_samples = [
utils.unflatten(space, sample) for sample in flattened_samples
]
for i, (original, roundtripped) in enumerate(
zip(some_samples, roundtripped_samples)
):
assert compare_nested(
original, roundtripped
), f"Expected sample #{i} {original} to equal {roundtripped}"
assert space.contains(roundtripped)
(node_single_dim,) = flat_space.node_space.shape
node_flatdim = utils.flatdim(space.node_space)
assert node_single_dim == node_flatdim
def compare_nested(left, right):
if type(left) != type(right):
return False
elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
return left.shape == right.shape and np.allclose(left, right)
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
res = len(left) == len(right)
for ((left_key, left_value), (right_key, right_value)) in zip(
left.items(), right.items()
):
if not res:
return False
res = left_key == right_key and compare_nested(left_value, right_value)
return res
elif isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
res = len(left) == len(right)
for (x, y) in zip(left, right):
if not res:
return False
res = compare_nested(x, y)
return res
if flat_space.edge_space is not None:
(edge_single_dim,) = flat_space.edge_space.shape
edge_flatdim = utils.flatdim(space.edge_space)
assert edge_single_dim == edge_flatdim
else:
return left == right
"""
Expecteded flattened types are based off:
1. The type that the space is hardcoded as(ie. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
2. The type that the space is instantiated with(ie. box=np.float32 by default unless instantiated with a different type)
3. The smallest type that the composite space(tuple, dict) can be represented as. In flatten, this is determined
internally by numpy when np.concatenate is called.
"""
expected_flattened_dtypes = [
np.int64,
np.float32,
np.float16,
np.int64,
np.float64,
np.int64,
np.int64,
np.int8,
np.float64,
np.int64,
np.int64,
]
@pytest.mark.parametrize(
["original_space", "expected_flattened_dtype"],
zip(homogeneous_spaces, expected_flattened_dtypes),
)
def test_dtypes(original_space, expected_flattened_dtype):
flattened_space = utils.flatten_space(original_space)
original_sample = original_space.sample()
flattened_sample = utils.flatten(original_space, original_sample)
unflattened_sample = utils.unflatten(original_space, flattened_sample)
assert flattened_space.contains(
flattened_sample
), "Expected flattened_space to contain flattened_sample"
assert (
flattened_space.dtype == expected_flattened_dtype
), f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"
assert flattened_sample.dtype == flattened_space.dtype, (
"Expected flattened_space's dtype to equal " "flattened_sample's dtype "
)
compare_sample_types(original_space, original_sample, unflattened_sample)
def compare_sample_types(original_space, original_sample, unflattened_sample):
if isinstance(original_space, Discrete):
assert isinstance(unflattened_sample, int), (
"Expected unflattened_sample to be an int. unflattened_sample: "
"{} original_sample: {}".format(unflattened_sample, original_sample)
)
elif isinstance(original_space, Tuple):
for index in range(len(original_space)):
compare_sample_types(
original_space.spaces[index],
original_sample[index],
unflattened_sample[index],
)
elif isinstance(original_space, Dict):
for key, space in original_space.spaces.items():
compare_sample_types(space, original_sample[key], unflattened_sample[key])
else:
assert unflattened_sample.dtype == original_sample.dtype, (
"Expected unflattened_sample's dtype to equal "
"original_sample's dtype. unflattened_sample: "
"{} original_sample: {}".format(unflattened_sample, original_sample)
assert isinstance(
space, (gym.spaces.Tuple, gym.spaces.Dict, gym.spaces.Sequence)
)
homogeneous_samples = [
2,
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32),
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16),
(3, 7),
(2, np.array([0.5, 3.5], dtype=np.float32)),
(3, 0, 1),
np.array([0, 1, 7], dtype=np.int64),
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
OrderedDict(
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
),
3,
-2,
]
expected_flattened_hom_samples = [
np.array([0, 0, 1], dtype=np.int64),
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32),
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16),
np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
np.array([0, 0, 1, 0, 0, 0.5, 3.5], dtype=np.float64),
np.array([0, 0, 0, 1, 0, 1, 0, 0, 1], dtype=np.int64),
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
np.array([0, 1, 0], dtype=np.int64),
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
]
non_homogenous_samples = [
GraphInstance(
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([0, 1], dtype=int),
np.array([[[1, 2], [3, 4]]], dtype=np.float32),
np.array([[0, 1]], dtype=int),
),
GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)),
(0, 1, 2),
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
(
(np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)),
(np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)),
),
OrderedDict(
[("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[0, 1], [2, 3]], dtype=np.float32)),
]
),
((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)),
(
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
np.array([[0, 1], [2, 3]], dtype=np.float32),
),
(
GraphInstance(
nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32),
edges=np.array([0], dtype=int),
edge_links=np.array([[0, 1]]),
),
GraphInstance(
nodes=np.array(
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32
),
edges=np.array([1], dtype=int),
edge_links=np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=np.float32,
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
]
expected_flattened_non_hom_samples = [
GraphInstance(
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32),
np.array([[1, 0, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
np.array([[1, 2, 3, 4]], dtype=np.float32),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
None,
np.array([[0, 1]], dtype=int),
),
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
OrderedDict(
[
(
"a",
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
(
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.float32),
np.array([[1, 0, 0, 0]]),
np.array([[0, 1]]),
),
GraphInstance(
np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32),
np.array([[0, 1, 0, 0]]),
np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
]
@pytest.mark.parametrize(
["space", "sample", "expected_flattened_sample"],
zip(
homogeneous_spaces + non_homogenous_spaces,
homogeneous_samples + non_homogenous_samples,
expected_flattened_hom_samples + expected_flattened_non_hom_samples,
),
)
def test_flatten(space, sample, expected_flattened_sample):
flattened_sample = utils.flatten(space, sample)
flat_space = utils.flatten_space(space)
assert sample in space
assert flattened_sample in flat_space
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flatten(space):
"""Test that a flattened sample have the `flatdim` shape."""
flattened_sample = utils.flatten(space, space.sample())
if space.is_np_flattenable:
assert isinstance(flattened_sample, np.ndarray)
assert flattened_sample.shape == expected_flattened_sample.shape
assert flattened_sample.dtype == expected_flattened_sample.dtype
assert np.all(flattened_sample == expected_flattened_sample)
(single_dim,) = flattened_sample.shape
flatdim = utils.flatdim(space)
assert single_dim == flatdim
else:
assert not isinstance(flattened_sample, np.ndarray)
assert compare_nested(flattened_sample, expected_flattened_sample)
assert isinstance(flattened_sample, (tuple, dict, Graph))
@pytest.mark.parametrize(
["space", "flattened_sample", "expected_sample"],
zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples),
)
def test_unflatten(space, flattened_sample, expected_sample):
sample = utils.unflatten(space, flattened_sample)
assert compare_nested(sample, expected_sample)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flat_space_contains_flat_points(space):
"""Test that the flattened samples are contained within the flattened space."""
flattened_samples = [utils.flatten(space, space.sample()) for _ in range(10)]
flat_space = utils.flatten_space(space)
for flat_sample in flattened_samples:
assert flat_sample in flat_space
expected_flattened_spaces = [
Box(low=0, high=1, shape=(3,), dtype=np.int64),
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float32),
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float16),
Box(low=0, high=1, shape=(15,), dtype=np.int64),
Box(
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64,
),
Box(low=0, high=1, shape=(9,), dtype=np.int64),
Box(low=0, high=1, shape=(14,), dtype=np.int64),
Box(low=0, high=1, shape=(10,), dtype=np.int8),
Box(
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64,
),
Box(low=0, high=1, shape=(3,), dtype=np.int64),
Box(low=0, high=1, shape=(8,), dtype=np.int64),
]
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flatten_roundtripping(space):
"""Tests roundtripping with flattening and unflattening are equal to the original sample."""
samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in samples]
unflattened_samples = [
utils.unflatten(space, sample) for sample in flattened_samples
]
@pytest.mark.parametrize(
["space", "expected_flattened_space"],
zip(homogeneous_spaces, expected_flattened_spaces),
)
def test_flatten_space(space, expected_flattened_space):
flattened_space = utils.flatten_space(space)
assert flattened_space == expected_flattened_space
for original, roundtripped in zip(samples, unflattened_samples):
assert data_equivalence(original, roundtripped)

27
tests/spaces/utils.py Normal file
View File

@@ -0,0 +1,27 @@
from typing import List
import numpy as np
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space, Text
TESTING_FUNDAMENTAL_SPACES = [
Discrete(3),
Discrete(3, start=-1),
Box(low=0.0, high=1.0),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
Box(low=-np.inf, high=0.0, shape=(2, 1)),
Box(low=0.0, high=np.inf, shape=(2, 1)),
MultiDiscrete([2, 2]),
MultiDiscrete([[2, 3], [3, 2]]),
MultiBinary(8),
MultiBinary([2, 3]),
Text(6),
Text(min_length=3, max_length=6),
Text(6, charset="abcdef"),
]
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES # + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS # + TESTING_COMPOSITE_SPACES_IDS