Files
Gymnasium/tests/spaces/test_utils.py
Markus Krimmel 63ea5f2517 Add Sequence space, update flatten functions (#2968)
* Added Sequence space, updated flatten functions to work with Sequence, Graph. WIP.

* Small fixes, added Sequence space to tests

* Replace Optional[Any] by Any

* Added tests for flattening of non-numpy-flattenable spaces

* Return all seeds
2022-08-15 11:11:32 -04:00

602 lines
19 KiB
Python

import re
from collections import OrderedDict
import numpy as np
import pytest
from gym.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Tuple,
utils,
)
homogeneous_spaces = [
Discrete(3),
Box(low=0.0, high=np.inf, shape=(2, 2)),
Box(low=0.0, high=np.inf, shape=(2, 2), dtype=np.float16),
Tuple([Discrete(5), Discrete(10)]),
Tuple(
[
Discrete(5),
Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64),
]
),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 10]),
MultiBinary(10),
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0.0, 0.0]), high=np.array([1.0, 5.0]), dtype=np.float64
),
}
),
Discrete(3, start=2),
Discrete(8, start=-5),
]
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
non_homogenous_spaces = [
Graph(node_space=Box(low=-100, high=100, shape=(2, 2)), edge_space=Discrete(5)), #
Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(2, 2))), #
Graph(node_space=Discrete(5), edge_space=None), #
Sequence(Discrete(4)), #
Sequence(Box(-10, 10, shape=(2, 2))), #
Sequence(Tuple([Box(-10, 10, shape=(2,)), Box(-10, 10, shape=(2,))])), #
Dict(a=Sequence(Discrete(4)), b=Box(-10, 10, shape=(2, 2))), #
Dict(
a=Graph(node_space=Discrete(4), edge_space=Discrete(4)),
b=Box(-10, 10, shape=(2, 2)),
), #
Tuple([Sequence(Discrete(4)), Box(-10, 10, shape=(2, 2))]), #
Tuple(
[
Graph(node_space=Discrete(4), edge_space=Discrete(4)),
Box(-10, 10, shape=(2, 2)),
]
), #
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), #
Dict(
a=Dict(
a=Sequence(Box(-100, 100, shape=(2, 2))), b=Box(-100, 100, shape=(2, 2))
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
), #
Dict(
a=Dict(
a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
b=Box(-100, 100, shape=(2, 2)),
),
b=Tuple([Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,))]),
),
]
@pytest.mark.parametrize("space", non_homogenous_spaces)
def test_non_flattenable(space):
assert space.is_np_flattenable is False
with pytest.raises(
ValueError,
match=re.escape(
"cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
),
):
utils.flatdim(space)
@pytest.mark.parametrize(["space", "flatdim"], zip(homogeneous_spaces, flatdims))
def test_flatdim(space, flatdim):
assert space.is_np_flattenable
dim = utils.flatdim(space)
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_space_boxes(space):
flat_space = utils.flatten_space(space)
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
flatdim = utils.flatdim(space)
(single_dim,) = flat_space.shape
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flat_space_contains_flat_points(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
flat_space = utils.flatten_space(space)
for i, flat_sample in enumerate(flattened_samples):
assert flat_space.contains(
flat_sample
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
@pytest.mark.parametrize("space", homogeneous_spaces)
def test_flatten_dim(space):
sample = utils.flatten(space, space.sample())
(single_dim,) = sample.shape
flatdim = utils.flatdim(space)
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", homogeneous_spaces + non_homogenous_spaces)
def test_flatten_roundtripping(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
roundtripped_samples = [
utils.unflatten(space, sample) for sample in flattened_samples
]
for i, (original, roundtripped) in enumerate(
zip(some_samples, roundtripped_samples)
):
assert compare_nested(
original, roundtripped
), f"Expected sample #{i} {original} to equal {roundtripped}"
assert space.contains(roundtripped)
def compare_nested(left, right):
if type(left) != type(right):
return False
elif isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
return left.shape == right.shape and np.allclose(left, right)
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
res = len(left) == len(right)
for ((left_key, left_value), (right_key, right_value)) in zip(
left.items(), right.items()
):
if not res:
return False
res = left_key == right_key and compare_nested(left_value, right_value)
return res
elif isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
res = len(left) == len(right)
for (x, y) in zip(left, right):
if not res:
return False
res = compare_nested(x, y)
return res
else:
return left == right
"""
Expecteded flattened types are based off:
1. The type that the space is hardcoded as(ie. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
2. The type that the space is instantiated with(ie. box=np.float32 by default unless instantiated with a different type)
3. The smallest type that the composite space(tuple, dict) can be represented as. In flatten, this is determined
internally by numpy when np.concatenate is called.
"""
expected_flattened_dtypes = [
np.int64,
np.float32,
np.float16,
np.int64,
np.float64,
np.int64,
np.int64,
np.int8,
np.float64,
np.int64,
np.int64,
]
@pytest.mark.parametrize(
["original_space", "expected_flattened_dtype"],
zip(homogeneous_spaces, expected_flattened_dtypes),
)
def test_dtypes(original_space, expected_flattened_dtype):
flattened_space = utils.flatten_space(original_space)
original_sample = original_space.sample()
flattened_sample = utils.flatten(original_space, original_sample)
unflattened_sample = utils.unflatten(original_space, flattened_sample)
assert flattened_space.contains(
flattened_sample
), "Expected flattened_space to contain flattened_sample"
assert (
flattened_space.dtype == expected_flattened_dtype
), f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"
assert flattened_sample.dtype == flattened_space.dtype, (
"Expected flattened_space's dtype to equal " "flattened_sample's dtype "
)
compare_sample_types(original_space, original_sample, unflattened_sample)
def compare_sample_types(original_space, original_sample, unflattened_sample):
if isinstance(original_space, Discrete):
assert isinstance(unflattened_sample, int), (
"Expected unflattened_sample to be an int. unflattened_sample: "
"{} original_sample: {}".format(unflattened_sample, original_sample)
)
elif isinstance(original_space, Tuple):
for index in range(len(original_space)):
compare_sample_types(
original_space.spaces[index],
original_sample[index],
unflattened_sample[index],
)
elif isinstance(original_space, Dict):
for key, space in original_space.spaces.items():
compare_sample_types(space, original_sample[key], unflattened_sample[key])
else:
assert unflattened_sample.dtype == original_sample.dtype, (
"Expected unflattened_sample's dtype to equal "
"original_sample's dtype. unflattened_sample: "
"{} original_sample: {}".format(unflattened_sample, original_sample)
)
homogeneous_samples = [
2,
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float32),
np.array([[1.0, 3.0], [5.0, 8.0]], dtype=np.float16),
(3, 7),
(2, np.array([0.5, 3.5], dtype=np.float32)),
(3, 0, 1),
np.array([0, 1, 7], dtype=np.int64),
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
OrderedDict(
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
),
3,
-2,
]
expected_flattened_hom_samples = [
np.array([0, 0, 1], dtype=np.int64),
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float32),
np.array([1.0, 3.0, 5.0, 8.0], dtype=np.float16),
np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
np.array([0, 0, 1, 0, 0, 0.5, 3.5], dtype=np.float64),
np.array([0, 0, 0, 1, 0, 1, 0, 0, 1], dtype=np.int64),
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
np.array([0, 1, 0], dtype=np.int64),
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
]
non_homogenous_samples = [
GraphInstance(
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([0, 1], dtype=int),
np.array([[[1, 2], [3, 4]]], dtype=np.float32),
np.array([[0, 1]], dtype=int),
),
GraphInstance(np.array([0, 1], dtype=int), None, np.array([[0, 1]], dtype=int)),
(0, 1, 2),
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
(
(np.array([0, 1], dtype=np.float32), np.array([2, 3], dtype=np.float32)),
(np.array([4, 5], dtype=np.float32), np.array([6, 7], dtype=np.float32)),
),
OrderedDict(
[("a", (0, 1, 2)), ("b", np.array([[0, 1], [2, 3]], dtype=np.float32))]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[0, 1], [2, 3]], dtype=np.float32)),
]
),
((0, 1, 2), np.array([[0, 1], [2, 3]], dtype=np.float32)),
(
GraphInstance(
np.array([1, 2], dtype=np.int),
np.array(
[
0,
],
dtype=int,
),
np.array([[0, 1]], dtype=int),
),
np.array([[0, 1], [2, 3]], dtype=np.float32),
),
(
GraphInstance(
nodes=np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32),
edges=np.array([0], dtype=int),
edge_links=np.array([[0, 1]]),
),
GraphInstance(
nodes=np.array(
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]], dtype=np.float32
),
edges=np.array([1], dtype=int),
edge_links=np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([[0, 1], [2, 3]], dtype=np.float32),
np.array([[4, 5], [6, 7]], dtype=np.float32),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
dtype=np.float32,
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([[8, 9], [10, 11]], dtype=np.float32)),
]
),
),
(
"b",
(
np.array([12, 13], dtype=np.float32),
np.array([14, 15], dtype=np.float32),
),
),
]
),
]
expected_flattened_non_hom_samples = [
GraphInstance(
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32),
np.array([[1, 0, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
np.array([[1, 2, 3, 4]], dtype=np.float32),
np.array([[0, 1]], dtype=int),
),
GraphInstance(
np.array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], dtype=int),
None,
np.array([[0, 1]], dtype=int),
),
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
OrderedDict(
[
(
"a",
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
OrderedDict(
[
(
"a",
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=int),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([0, 1, 2, 3], dtype=np.float32)),
]
),
(
(
np.array([1, 0, 0, 0], dtype=int),
np.array([0, 1, 0, 0], dtype=int),
np.array([0, 0, 1, 0], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32),
np.array([[1, 0, 0, 0]], dtype=int),
np.array([[0, 1]], dtype=int),
),
np.array([0, 1, 2, 3], dtype=np.float32),
),
(
GraphInstance(
np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.float32),
np.array([[1, 0, 0, 0]]),
np.array([[0, 1]]),
),
GraphInstance(
np.array([[8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32),
np.array([[0, 1, 0, 0]]),
np.array([[0, 1]]),
),
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
(
np.array([0, 1, 2, 3], dtype=np.float32),
np.array([4, 5, 6, 7], dtype=np.float32),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
OrderedDict(
[
(
"a",
OrderedDict(
[
(
"a",
GraphInstance(
np.array(
[[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32
),
None,
np.array([[0, 1]], dtype=int),
),
),
("b", np.array([8, 9, 10, 11], dtype=np.float32)),
]
),
),
("b", (np.array([12, 13, 14, 15], dtype=np.float32))),
]
),
]
@pytest.mark.parametrize(
["space", "sample", "expected_flattened_sample"],
zip(
homogeneous_spaces + non_homogenous_spaces,
homogeneous_samples + non_homogenous_samples,
expected_flattened_hom_samples + expected_flattened_non_hom_samples,
),
)
def test_flatten(space, sample, expected_flattened_sample):
flattened_sample = utils.flatten(space, sample)
flat_space = utils.flatten_space(space)
assert sample in space
assert flattened_sample in flat_space
if space.is_np_flattenable:
assert isinstance(flattened_sample, np.ndarray)
assert flattened_sample.shape == expected_flattened_sample.shape
assert flattened_sample.dtype == expected_flattened_sample.dtype
assert np.all(flattened_sample == expected_flattened_sample)
else:
assert not isinstance(flattened_sample, np.ndarray)
assert compare_nested(flattened_sample, expected_flattened_sample)
@pytest.mark.parametrize(
["space", "flattened_sample", "expected_sample"],
zip(homogeneous_spaces, expected_flattened_hom_samples, homogeneous_samples),
)
def test_unflatten(space, flattened_sample, expected_sample):
sample = utils.unflatten(space, flattened_sample)
assert compare_nested(sample, expected_sample)
expected_flattened_spaces = [
Box(low=0, high=1, shape=(3,), dtype=np.int64),
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float32),
Box(low=0.0, high=np.inf, shape=(4,), dtype=np.float16),
Box(low=0, high=1, shape=(15,), dtype=np.int64),
Box(
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64,
),
Box(low=0, high=1, shape=(9,), dtype=np.int64),
Box(low=0, high=1, shape=(14,), dtype=np.int64),
Box(low=0, high=1, shape=(10,), dtype=np.int8),
Box(
low=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64),
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64,
),
Box(low=0, high=1, shape=(3,), dtype=np.int64),
Box(low=0, high=1, shape=(8,), dtype=np.int64),
]
@pytest.mark.parametrize(
["space", "expected_flattened_space"],
zip(homogeneous_spaces, expected_flattened_spaces),
)
def test_flatten_space(space, expected_flattened_space):
flattened_space = utils.flatten_space(space)
assert flattened_space == expected_flattened_space