mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
Add testing for Fundamental spaces with full coverage (#3048)
* initial commit * Fix the multi-binary sample and upgrade multi-discrete sample * Fix MultiBinary sample tests and pre-commit * Adds coverage tests and updates test_utils.py to use the utils.py spaces. Fix bug in text.py * pre-commit
This commit is contained in:
@@ -1,601 +1,135 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from itertools import zip_longest
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Tuple,
|
||||
utils,
|
||||
)
|
||||
import gym
|
||||
from gym.spaces import Box, Graph, utils
|
||||
from gym.utils.env_checker import data_equivalence
|
||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
||||
|
||||
homogeneous_spaces = [
|
||||
Discrete(3),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2)),
|
||||
Box(low=0.0, high=np.inf, shape=(2, 2), dtype=np.float16),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple(
|
||||
[
|
||||
Discrete(5),
|
||||
Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64),
|
||||
]
|
||||
),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 10]),
|
||||
MultiBinary(10),
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64
|
||||
),
|
||||
}
|
||||
),
|
||||
Discrete(3, start=2),
|
||||
Discrete(8, start=-5),
|
||||
]
|
||||
|
||||
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
|
||||
|
||||
non_homogenous_spaces = [
|
||||
Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), #
|
||||
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), #
|
||||
Graph(node_space=Discrete(5), edge_space=None), #
|
||||
Sequence(Discrete(4)), #
|
||||
Sequence(Box(-10, 10, shape=(2, 2))), #
|
||||
Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), #
|
||||
Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), #
|
||||
Dict(
|
||||
a=Graph(node_space=Discrete(4), edge_space=Discrete(4)),
|
||||
b=Box(-10, 10, shape=(2, 2)),
|
||||
), #
|
||||
Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), #
|
||||
Tuple(
|
||||
[
|
||||
Graph(node_space=Discrete(4), edge_space=Discrete(4)),
|
||||
Box(-10, 10, shape=(2, 2)),
|
||||
]
|
||||
), #
|
||||
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), #
|
||||
Dict(
|
||||
a=Dict(
|
||||
a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2))
|
||||
),
|
||||
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
|
||||
), #
|
||||
Dict(
|
||||
a=Dict(
|
||||
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
|
||||
b=Box(-100, 100, shape=(2, 2)),
|
||||
),
|
||||
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", non_homogenous_spaces)
|
||||
def test_non_flattenable(space):
|
||||
assert space.is_np_flattenable is False
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
|
||||
),
|
||||
):
|
||||
utils.flatdim(space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
|
||||
def test_flatdim(space, flatdim):
|
||||
assert space.is_np_flattenable
|
||||
dim = utils.flatdim(space)
|
||||
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces)
|
||||
def test_flatten_space_boxes(space):
|
||||
flat_space = utils.flatten_space(space)
|
||||
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
|
||||
flatdim = utils.flatdim(space)
|
||||
(single_dim,) = flat_space.shape
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
|
||||
def test_flat_space_contains_flat_points(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
flat_space = utils.flatten_space(space)
|
||||
for i, flat_sample in enumerate(flattened_samples):
|
||||
assert flat_space.contains(
|
||||
flat_sample
|
||||
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces)
|
||||
def test_flatten_dim(space):
|
||||
sample = utils.flatten(space, space.sample())
|
||||
(single_dim,) = sample.shape
|
||||
flatdim = utils.flatdim(space)
|
||||
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
|
||||
def test_flatten_roundtripping(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
roundtripped_samples = [
|
||||
utils.unflatten(space, sample) for sample in flattened_samples
|
||||
]
|
||||
for i, (original, roundtripped) in enumerate(
|
||||
zip(some_samples, roundtripped_samples)
|
||||
):
|
||||
assert compare_nested(
|
||||
original, roundtripped
|
||||
), f"Expected sample #{i} {original} to equal {roundtripped}"
|
||||
assert space.contains(roundtripped)
|
||||
|
||||
|
||||
def compare_nested(left, right):
|
||||
if type(left) != type(right):
|
||||
return False
|
||||
elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
|
||||
return left.shape == right.shape and np.allclose(left, right)
|
||||
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
|
||||
res = len(left) == len(right)
|
||||
for ((left_key, left_value), (right_key, right_value)) in zip(
|
||||
left.items(), right.items()
|
||||
):
|
||||
if not res:
|
||||
return False
|
||||
res = left_key == right_key and compare_nested(left_value, right_value)
|
||||
return res
|
||||
elif isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
|
||||
res = len(left) == len(right)
|
||||
for (x, y) in zip(left, right):
|
||||
if not res:
|
||||
return False
|
||||
res = compare_nested(x, y)
|
||||
return res
|
||||
else:
|
||||
return left == right
|
||||
|
||||
|
||||
"""
|
||||
Expecteded flattened types are based off:
|
||||
1. The type that the space is hardcoded as(ie. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
|
||||
2. The type that the space is instantiated with(ie. box=np.float32 by default unless instantiated with a different type)
|
||||
3. The smallest type that the composite space(tuple, dict) can be represented as. In flatten, this is determined
|
||||
internally by numpy when np.concatenate is called.
|
||||
"""
|
||||
|
||||
expected_flattened_dtypes = [
|
||||
np.int64,
|
||||
np.float32,
|
||||
np.float16,
|
||||
np.int64,
|
||||
np.float64,
|
||||
np.int64,
|
||||
np.int64,
|
||||
np.int8,
|
||||
np.float64,
|
||||
np.int64,
|
||||
np.int64,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["original_space", "expected_flattened_dtype"],
|
||||
zip(homogeneous_spaces, expected_flattened_dtypes),
|
||||
)
|
||||
def test_dtypes(original_space, expected_flattened_dtype):
|
||||
flattened_space = utils.flatten_space(original_space)
|
||||
|
||||
original_sample = original_space.sample()
|
||||
flattened_sample = utils.flatten(original_space, original_sample)
|
||||
unflattened_sample = utils.unflatten(original_space, flattened_sample)
|
||||
|
||||
assert flattened_space.contains(
|
||||
flattened_sample
|
||||
), "Expected flattened_space to contain flattened_sample"
|
||||
assert (
|
||||
flattened_space.dtype == expected_flattened_dtype
|
||||
), f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"
|
||||
|
||||
assert flattened_sample.dtype == flattened_space.dtype, (
|
||||
"Expected flattened_space's dtype to equal " "flattened_sample's dtype "
|
||||
)
|
||||
|
||||
compare_sample_types(original_space, original_sample, unflattened_sample)
|
||||
|
||||
|
||||
def compare_sample_types(original_space, original_sample, unflattened_sample):
|
||||
if isinstance(original_space, Discrete):
|
||||
assert isinstance(unflattened_sample, int), (
|
||||
"Expected unflattened_sample to be an int. unflattened_sample: "
|
||||
"{} original_sample: {}".format(unflattened_sample, original_sample)
|
||||
)
|
||||
elif isinstance(original_space, Tuple):
|
||||
for index in range(len(original_space)):
|
||||
compare_sample_types(
|
||||
original_space.spaces[index],
|
||||
original_sample[index],
|
||||
unflattened_sample[index],
|
||||
)
|
||||
elif isinstance(original_space, Dict):
|
||||
for key, space in original_space.spaces.items():
|
||||
compare_sample_types(space, original_sample[key], unflattened_sample[key])
|
||||
else:
|
||||
assert unflattened_sample.dtype == original_sample.dtype, (
|
||||
"Expected unflattened_sample's dtype to equal "
|
||||
"original_sample's dtype. unflattened_sample: "
|
||||
"{} original_sample: {}".format(unflattened_sample, original_sample)
|
||||
)
|
||||
|
||||
|
||||
homogeneous_samples = [
|
||||
2,
|
||||
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32),
|
||||
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16),
|
||||
(3, 7),
|
||||
(2, np.array([0.5, 3.5], dtype=np.float32)),
|
||||
(3, 0, 1),
|
||||
np.array([0, 1, 7], dtype=np.int64),
|
||||
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
|
||||
OrderedDict(
|
||||
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
|
||||
),
|
||||
TESTING_SPACES_EXPECTED_FLATDIMS = [
|
||||
# Discrete
|
||||
3,
|
||||
-2,
|
||||
]
|
||||
|
||||
|
||||
expected_flattened_hom_samples = [
|
||||
np.array([0, 0, 1], dtype=np.int64),
|
||||
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32),
|
||||
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16),
|
||||
np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
|
||||
np.array([0, 0, 1, 0, 0, 0.5, 3.5], dtype=np.float64),
|
||||
np.array([0, 0, 0, 1, 0, 1, 0, 0, 1], dtype=np.int64),
|
||||
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
|
||||
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
|
||||
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
|
||||
np.array([0, 1, 0], dtype=np.int64),
|
||||
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
|
||||
]
|
||||
|
||||
non_homogenous_samples = [
|
||||
GraphInstance(
|
||||
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32),
|
||||
np.array(
|
||||
[
|
||||
0,
|
||||
],
|
||||
dtype=int,
|
||||
),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([0, 1], dtype=int),
|
||||
np.array([[[1, 2], [3, 4]]], dtype=np.float32),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)),
|
||||
(0, 1, 2),
|
||||
(
|
||||
np.array([[0, 1], [2, 3]], dtype=np.float32),
|
||||
np.array([[4, 5], [6, 7]], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
(np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)),
|
||||
(np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)),
|
||||
),
|
||||
OrderedDict(
|
||||
[("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array([1, 2], dtype=np.int),
|
||||
np.array(
|
||||
[
|
||||
0,
|
||||
],
|
||||
dtype=int,
|
||||
),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([[0, 1], [2, 3]], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([1, 2], dtype=np.int),
|
||||
np.array(
|
||||
[
|
||||
0,
|
||||
],
|
||||
dtype=int,
|
||||
),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
np.array([[0, 1], [2, 3]], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32),
|
||||
edges=np.array([0], dtype=int),
|
||||
edge_links=np.array([[0, 1]]),
|
||||
),
|
||||
GraphInstance(
|
||||
nodes=np.array(
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32
|
||||
),
|
||||
edges=np.array([1], dtype=int),
|
||||
edge_links=np.array([[0, 1]]),
|
||||
),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([[0, 1], [2, 3]], dtype=np.float32),
|
||||
np.array([[4, 5], [6, 7]], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
"b",
|
||||
(
|
||||
np.array([12, 13], dtype=np.float32),
|
||||
np.array([14, 15], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array(
|
||||
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
|
||||
dtype=np.float32,
|
||||
),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
"b",
|
||||
(
|
||||
np.array([12, 13], dtype=np.float32),
|
||||
np.array([14, 15], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
expected_flattened_non_hom_samples = [
|
||||
GraphInstance(
|
||||
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
|
||||
np.array([[1, 2, 3, 4]], dtype=np.float32),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int),
|
||||
np.array([[1, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
(
|
||||
(
|
||||
np.array([1, 0, 0, 0], dtype=int),
|
||||
np.array([0, 1, 0, 0], dtype=int),
|
||||
np.array([0, 0, 1, 0], dtype=int),
|
||||
),
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0]], dtype=int),
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
GraphInstance(
|
||||
np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.float32),
|
||||
np.array([[1, 0, 0, 0]]),
|
||||
np.array([[0, 1]]),
|
||||
),
|
||||
GraphInstance(
|
||||
np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32),
|
||||
np.array([[0, 1, 0, 0]]),
|
||||
np.array([[0, 1]]),
|
||||
),
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
(
|
||||
np.array([0, 1, 2, 3], dtype=np.float32),
|
||||
np.array([4, 5, 6, 7], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
|
||||
]
|
||||
),
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
"a",
|
||||
GraphInstance(
|
||||
np.array(
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32
|
||||
),
|
||||
None,
|
||||
np.array([[0, 1]], dtype=int),
|
||||
),
|
||||
),
|
||||
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
|
||||
]
|
||||
),
|
||||
),
|
||||
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
|
||||
]
|
||||
),
|
||||
3,
|
||||
# Box
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
# Multi-discrete
|
||||
4,
|
||||
10,
|
||||
# Multi-binary
|
||||
8,
|
||||
6,
|
||||
# Text
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
# # Tuple
|
||||
# 9,
|
||||
# 7,
|
||||
# 10,
|
||||
# 6,
|
||||
# None,
|
||||
# # Dict
|
||||
# 7,
|
||||
# 8,
|
||||
# 17,
|
||||
# None,
|
||||
# # Graph
|
||||
# None,
|
||||
# None,
|
||||
# None,
|
||||
# # Sequence
|
||||
# None,
|
||||
# None,
|
||||
# None,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "sample", "expected_flattened_sample"],
|
||||
zip(
|
||||
homogeneous_spaces + non_homogenous_spaces,
|
||||
homogeneous_samples + non_homogenous_samples,
|
||||
expected_flattened_hom_samples + expected_flattened_non_hom_samples,
|
||||
),
|
||||
["space", "flatdim"],
|
||||
zip_longest(TESTING_SPACES, TESTING_SPACES_EXPECTED_FLATDIMS),
|
||||
ids=TESTING_SPACES_IDS,
|
||||
)
|
||||
def test_flatten(space, sample, expected_flattened_sample):
|
||||
flattened_sample = utils.flatten(space, sample)
|
||||
def test_flatdim(space: gym.spaces.Space, flatdim: Optional[int]):
|
||||
"""Checks that the flattened dims of the space is equal to an expected value."""
|
||||
if space.is_np_flattenable:
|
||||
dim = utils.flatdim(space)
|
||||
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
|
||||
else:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
):
|
||||
utils.flatdim(space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flatten_space(space):
|
||||
"""Test that the flattened spaces are a box and have the `flatdim` shape."""
|
||||
flat_space = utils.flatten_space(space)
|
||||
|
||||
assert sample in space
|
||||
assert flattened_sample in flat_space
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(flat_space, Box)
|
||||
(single_dim,) = flat_space.shape
|
||||
flatdim = utils.flatdim(space)
|
||||
|
||||
assert single_dim == flatdim
|
||||
elif isinstance(flat_space, Graph):
|
||||
assert isinstance(space, Graph)
|
||||
|
||||
(node_single_dim,) = flat_space.node_space.shape
|
||||
node_flatdim = utils.flatdim(space.node_space)
|
||||
assert node_single_dim == node_flatdim
|
||||
|
||||
if flat_space.edge_space is not None:
|
||||
(edge_single_dim,) = flat_space.edge_space.shape
|
||||
edge_flatdim = utils.flatdim(space.edge_space)
|
||||
assert edge_single_dim == edge_flatdim
|
||||
else:
|
||||
assert isinstance(
|
||||
space, (gym.spaces.Tuple, gym.spaces.Dict, gym.spaces.Sequence)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flatten(space):
|
||||
"""Test that a flattened sample have the `flatdim` shape."""
|
||||
flattened_sample = utils.flatten(space, space.sample())
|
||||
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(flattened_sample, np.ndarray)
|
||||
assert flattened_sample.shape == expected_flattened_sample.shape
|
||||
assert flattened_sample.dtype == expected_flattened_sample.dtype
|
||||
assert np.all(flattened_sample == expected_flattened_sample)
|
||||
(single_dim,) = flattened_sample.shape
|
||||
flatdim = utils.flatdim(space)
|
||||
|
||||
assert single_dim == flatdim
|
||||
else:
|
||||
assert not isinstance(flattened_sample, np.ndarray)
|
||||
assert compare_nested(flattened_sample, expected_flattened_sample)
|
||||
assert isinstance(flattened_sample, (tuple, dict, Graph))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "flattened_sample", "expected_sample"],
|
||||
zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples),
|
||||
)
|
||||
def test_unflatten(space, flattened_sample, expected_sample):
|
||||
sample = utils.unflatten(space, flattened_sample)
|
||||
assert compare_nested(sample, expected_sample)
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flat_space_contains_flat_points(space):
|
||||
"""Test that the flattened samples are contained within the flattened space."""
|
||||
flattened_samples = [utils.flatten(space, space.sample()) for _ in range(10)]
|
||||
flat_space = utils.flatten_space(space)
|
||||
|
||||
for flat_sample in flattened_samples:
|
||||
assert flat_sample in flat_space
|
||||
|
||||
|
||||
expected_flattened_spaces = [
|
||||
Box(low=0, high=1, shape=(3,), dtype=np.int64),
|
||||
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float32),
|
||||
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float16),
|
||||
Box(low=0, high=1, shape=(15,), dtype=np.int64),
|
||||
Box(
|
||||
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
|
||||
dtype=np.float64,
|
||||
),
|
||||
Box(low=0, high=1, shape=(9,), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(14,), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(10,), dtype=np.int8),
|
||||
Box(
|
||||
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
|
||||
dtype=np.float64,
|
||||
),
|
||||
Box(low=0, high=1, shape=(3,), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(8,), dtype=np.int64),
|
||||
]
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flatten_roundtripping(space):
|
||||
"""Tests roundtripping with flattening and unflattening are equal to the original sample."""
|
||||
samples = [space.sample() for _ in range(10)]
|
||||
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in samples]
|
||||
unflattened_samples = [
|
||||
utils.unflatten(space, sample) for sample in flattened_samples
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["space", "expected_flattened_space"],
|
||||
zip(homogeneous_spaces, expected_flattened_spaces),
|
||||
)
|
||||
def test_flatten_space(space, expected_flattened_space):
|
||||
flattened_space = utils.flatten_space(space)
|
||||
assert flattened_space == expected_flattened_space
|
||||
for original, roundtripped in zip(samples, unflattened_samples):
|
||||
assert data_equivalence(original, roundtripped)
|
||||
|
Reference in New Issue
Block a user