Files
Gymnasium/tests/spaces/test_graph.py
pseudo-rnd-thoughts 640c509109 Rename to gymnasium
2022-09-08 10:10:07 +01:00

138 lines
4.0 KiB
Python

import re
import numpy as np
import pytest
from gymnasium.spaces import Discrete, Graph, GraphInstance
def test_node_space_sample():
    """Sample from a ``Graph`` that has a node space but no edge space.

    Exercises, in order: one identical node mask repeated for five nodes, a
    distinct mask per node passed positionally, the warning raised when edges
    are requested although ``edge_space`` is ``None``, and the ``TypeError``
    raised once ``node_space`` has been replaced by an unsupported object.
    """
    graph_space = Graph(node_space=Discrete(3), edge_space=None)
    graph_space.seed(0)

    # Same mask [0, 1, 0] for each of the five nodes; every sampled node is
    # expected to equal 1.
    repeated_mask = tuple(np.array([0, 1, 0], dtype=np.int8) for _ in range(5))
    drawn = graph_space.sample(mask=(repeated_mask, None), num_nodes=5)
    assert drawn in graph_space
    assert np.all(drawn.nodes == 1)

    # One mask per node, with the mask tuple handed over as a positional
    # argument; the expected node values are 0 and 1 respectively.
    per_node_masks = (
        np.array([1, 0, 0], dtype=np.int8),
        np.array([0, 1, 0], dtype=np.int8),
    )
    drawn = graph_space.sample((per_node_masks, None), num_nodes=2)
    assert drawn in graph_space
    assert np.all(drawn.nodes == np.array([0, 1]))

    # Requesting edges without an edge space must warn, and the resulting
    # sample must still belong to the space.
    warning_pattern = re.escape(
        "The number of edges is set (5) but the edge space is None."
    )
    with pytest.warns(UserWarning, match=warning_pattern):
        drawn = graph_space.sample(num_edges=5)
    assert drawn in graph_space

    # Swap the node space for something that is neither Box nor Discrete.
    # Not a realistic situation; this exists mainly for coverage.
    error_pattern = re.escape(
        "Expects base space to be Box and Discrete, actual space: <class 'str'>"
    )
    with pytest.raises(TypeError, match=error_pattern):
        graph_space.node_space = "abc"
        graph_space.sample()
def test_edge_space_sample():
    """Check edge sampling of ``Graph`` with and without an edge space.

    Covers a single-node sample, an explicit edge count, a single edge mask,
    one mask per edge, a negative edge count, and requesting edges from a
    space whose ``edge_space`` is ``None``.
    """
    space = Graph(node_space=Discrete(3), edge_space=Discrete(3))
    space.seed(0)

    # A sample drawn with a single node is expected to carry no edges.
    assert space.sample(num_nodes=1).edges is None

    assert 0 <= len(space.sample(num_edges=3).edges) < 6

    # The sample may come back without edges, so test for None first and only
    # then compare the edge values against the masked value 1.
    sample = space.sample(mask=(None, np.array([0, 1, 0], dtype=np.int8)))
    assert sample.edges is None or np.all(sample.edges == 1)

    # One mask per edge: the three edges are expected to be 0, 1 and 2.
    sample = space.sample(
        mask=(
            None,
            (
                np.array([1, 0, 0], dtype=np.int8),
                np.array([0, 1, 0], dtype=np.int8),
                np.array([0, 0, 1], dtype=np.int8),
            ),
        ),
        num_edges=3,
    )
    assert np.all(sample.edges == np.array([0, 1, 2]))

    # `match` is a regular expression, so escape the literal message like the
    # other patterns in this module do.
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Expects the number of edges to be greater than 0, actual value: -1"
        ),
    ):
        space.sample(num_edges=-1)

    space = Graph(node_space=Discrete(3), edge_space=None)
    # Match the plain message only. `pytest.warns` searches the warning text,
    # so this still matches when the message is wrapped in ANSI colour codes
    # and no longer breaks if that wrapping changes or is turned off.
    with pytest.warns(
        UserWarning,
        match=re.escape("The number of edges is set (5) but the edge space is None."),
    ):
        sample = space.sample(num_edges=5)
    assert sample.edges is None
@pytest.mark.parametrize(
    "sample",
    [
        # Not a GraphInstance at all.
        "abc",
        # Nodes are missing.
        GraphInstance(
            nodes=None, edges=np.array([0, 1]), edge_links=np.array([[0, 1], [1, 0]])
        ),
        # Node value 10 is outside Discrete(2), whose members are 0 and 1.
        GraphInstance(
            nodes=np.array([10, 1, 0]),
            edges=np.array([0, 1]),
            edge_links=np.array([[0, 1], [1, 0]]),
        ),
        # Edge links are given but the edges are missing.
        GraphInstance(
            nodes=np.array([0, 1]), edges=None, edge_links=np.array([[0, 1], [1, 0]])
        ),
        # Edges are given but the edge links are missing.
        GraphInstance(nodes=np.array([0, 1]), edges=np.array([0, 1]), edge_links=None),
        # NOTE: the remaining cases target edges / edge_links, so their nodes
        # are kept valid ([0, 1]).  With nodes such as [1, 2] the value 2 would
        # already fall outside Discrete(2) and the sample could be rejected
        # without the edge data ever being inspected.
        # Edge value 10 is outside Discrete(2).
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([10, 1]),
            edge_links=np.array([[0, 1], [1, 0]]),
        ),
        # Edge links have a floating-point dtype.
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[0.5, 1.0], [2.0, 1.0]]),
        ),
        # Edge links are 1-D rather than 2-D like the `[[0, 1], [1, 0]]` above;
        # edges are valid so the shape is the only fault.
        GraphInstance(
            nodes=np.array([0, 1]), edges=np.array([0, 1]), edge_links=np.array([0, 1])
        ),
        # Edge links are 3-D.
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[[0], [1]], [[0], [0]]]),
        ),
        # Edge link index 10 does not refer to any of the two nodes.
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[10, 1], [0, 0]]),
        ),
        # Edge link index is negative.
        GraphInstance(
            nodes=np.array([0, 1]),
            edges=np.array([0, 1]),
            edge_links=np.array([[-10, 1], [0, 0]]),
        ),
    ],
)
def test_not_contains(sample):
    """Check that each malformed ``sample`` is not a member of the graph space.

    The space uses ``Discrete(2)`` for both nodes and edges.  Each parametrized
    case is malformed in the way described by the comment above it.
    """
    space = Graph(node_space=Discrete(2), edge_space=Discrete(2))
    assert sample not in space