import json # note: ujson fails this test due to float equality from copy import copy import numpy as np import pytest from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict @pytest.mark.parametrize("space", [ Discrete(3), Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(5), Box(low=np.array([0, 0]),high=np.array([1, 5]))]), Tuple((Discrete(5), Discrete(2), Discrete(2))), MultiDiscrete([2, 2, 100]), Dict({"position": Discrete(5), "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]))}), ]) def test_roundtripping(space): sample_1 = space.sample() sample_2 = space.sample() assert space.contains(sample_1) assert space.contains(sample_2) json_rep = space.to_jsonable([sample_1, sample_2]) json_roundtripped = json.loads(json.dumps(json_rep)) samples_after_roundtrip = space.from_jsonable(json_roundtripped) sample_1_prime, sample_2_prime = samples_after_roundtrip s1 = space.to_jsonable([sample_1]) s1p = space.to_jsonable([sample_1_prime]) s2 = space.to_jsonable([sample_2]) s2p = space.to_jsonable([sample_2_prime]) assert s1 == s1p, "Expected {} to equal {}".format(s1, s1p) assert s2 == s2p, "Expected {} to equal {}".format(s2, s2p) @pytest.mark.parametrize("space", [ Discrete(3), Box(low=np.array([-10, 0]),high=np.array([10, 10])), Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(5), Box(low=np.array([0, 0]),high=np.array([1, 5]))]), Tuple((Discrete(5), Discrete(2), Discrete(2))), MultiDiscrete([2, 2, 100]), MultiBinary(6), Dict({"position": Discrete(5), "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]))}), ]) def test_equality(space): space1 = space space2 = copy(space) assert space1 == space2, "Expected {} to equal {}".format(space1, space2) @pytest.mark.parametrize("spaces", [ (Discrete(3), Discrete(4)), (MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])), (MultiBinary(8), MultiBinary(7)), (Box(low=np.array([-10, 0]),high=np.array([10, 10])), Box(low=np.array([-10, 0]),high=np.array([10, 9]))), (Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])), (Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})), (Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})), ]) def test_inequality(spaces): space1, space2 = spaces assert space1 != space2, "Expected {} != {}".format(space1, space2)