import re import numpy as np import pytest from gym.spaces import Discrete, Graph, GraphInstance def test_node_space_sample(): space = Graph(node_space=Discrete(3), edge_space=None) sample = space.sample( mask=(tuple(np.array([0, 1, 0], dtype=np.int8) for _ in range(5)), None), num_nodes=5, ) assert sample in space assert np.all(sample.nodes == 1) sample = space.sample( ( (np.array([1, 0, 0], dtype=np.int8), np.array([0, 1, 0], dtype=np.int8)), None, ), num_nodes=2, ) assert sample in space assert np.all(sample.nodes == np.array([0, 1])) with pytest.warns( UserWarning, match=re.escape("The number of edges is set (5) but the edge space is None."), ): sample = space.sample(num_edges=5) assert sample in space # Change the node_space or edge_space to a non-Box or discrete space. # This should not happen, test is primarily to increase coverage. with pytest.raises( TypeError, match=re.escape( "Expects base space to be Box and Discrete, actual space: " ), ): space.node_space = "abc" space.sample() def test_edge_space_sample(): space = Graph(node_space=Discrete(3), edge_space=Discrete(3)) # When num_nodes>1 then num_edges is set to 0 assert space.sample(num_nodes=1).edges is None assert 0 <= len(space.sample(num_edges=3).edges) < 6 sample = space.sample(mask=(None, np.array([0, 1, 0], dtype=np.int8))) assert np.all(sample.edges == 1) sample = space.sample( mask=( None, ( np.array([1, 0, 0], dtype=np.int8), np.array([0, 1, 0], dtype=np.int8), np.array([0, 0, 1], dtype=np.int8), ), ), num_edges=3, ) assert np.all(sample.edges == np.array([0, 1, 2])) with pytest.raises( AssertionError, match="Expects the number of edges to be greater than 0, actual value: -1", ): space.sample(num_edges=-1) space = Graph(node_space=Discrete(3), edge_space=None) with pytest.warns( UserWarning, match=re.escape( "\x1b[33mWARN: The number of edges is set (5) but the edge space is None.\x1b[0m" ), ): sample = space.sample(num_edges=5) assert sample.edges is None @pytest.mark.parametrize( "sample", [ "abc", GraphInstance( nodes=None, edges=np.array([0, 1]), edge_links=np.array([[0, 1], [1, 0]]) ), GraphInstance( nodes=np.array([10, 1, 0]), edges=np.array([0, 1]), edge_links=np.array([[0, 1], [1, 0]]), ), GraphInstance( nodes=np.array([0, 1]), edges=None, edge_links=np.array([[0, 1], [1, 0]]) ), GraphInstance(nodes=np.array([0, 1]), edges=np.array([0, 1]), edge_links=None), GraphInstance( nodes=np.array([1, 2]), edges=np.array([10, 1]), edge_links=np.array([[0, 1], [1, 0]]), ), GraphInstance( nodes=np.array([1, 2]), edges=np.array([0, 1]), edge_links=np.array([[0.5, 1.0], [2.0, 1.0]]), ), GraphInstance( nodes=np.array([1, 2]), edges=np.array([10, 1]), edge_links=np.array([0, 1]) ), GraphInstance( nodes=np.array([1, 2]), edges=np.array([0, 1]), edge_links=np.array([[[0], [1]], [[0], [0]]]), ), GraphInstance( nodes=np.array([1, 2]), edges=np.array([0, 1]), edge_links=np.array([[10, 1], [0, 0]]), ), GraphInstance( nodes=np.array([1, 2]), edges=np.array([0, 1]), edge_links=np.array([[-10, 1], [0, 0]]), ), ], ) def test_not_contains(sample): space = Graph(node_space=Discrete(2), edge_space=Discrete(2)) assert sample not in space