fix for issue 1256 (Box(low=0, high=255, dtype='uint8').sample() returned zeros) (#1307)

This commit is contained in:
pzhokhov
2019-02-05 17:49:29 -08:00
committed by GitHub
parent 4ceff7dc09
commit 3067a0b890
2 changed files with 47 additions and 29 deletions

View File

@@ -1,4 +1,4 @@
import json # note: ujson fails this test due to float equality
import json # note: ujson fails this test due to float equality
from copy import copy
import numpy as np
@@ -8,14 +8,14 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
@pytest.mark.parametrize("space", [
Discrete(3),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 100]),
Dict({"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
])
Discrete(3),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 100]),
Dict({"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
])
def test_roundtripping(space):
sample_1 = space.sample()
sample_2 = space.sample()
@@ -37,16 +37,16 @@ def test_roundtripping(space):
@pytest.mark.parametrize("space", [
Discrete(3),
Box(low=np.array([-10, 0]),high=np.array([10, 10])),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 100]),
MultiBinary(6),
Dict({"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
])
Discrete(3),
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([2, 2, 100]),
MultiBinary(6),
Dict({"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
])
def test_equality(space):
space1 = space
space2 = copy(space)
@@ -54,15 +54,32 @@ def test_equality(space):
@pytest.mark.parametrize("spaces", [
(Discrete(3), Discrete(4)),
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
(MultiBinary(8), MultiBinary(7)),
(Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
Box(low=np.array([-10, 0]), high=np.array([10, 9]), dtype=np.float32)),
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
])
(Discrete(3), Discrete(4)),
(MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
(MultiBinary(8), MultiBinary(7)),
(Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
Box(low=np.array([-10, 0]), high=np.array([10, 9]), dtype=np.float32)),
(Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
(Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
(Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
])
def test_inequality(spaces):
    """Assert that the two spaces in the parametrized pair compare unequal.

    ``spaces`` is a 2-tuple; each pair supplied by the parametrization is
    expected to differ, so ``!=`` must evaluate to true for it.
    """
    first, second = spaces
    # Build the message up front so the assertion line stays short.
    failure_message = "Expected {} != {}".format(first, second)
    assert first != second, failure_message
@pytest.mark.parametrize("space", [
    Discrete(5),
    Box(low=0, high=255, shape=(2,), dtype='uint8'),
])
def test_sample(space):
    """Check that the empirical mean of seeded samples is near the space's centre.

    Regression test for the case where
    ``Box(low=0, high=255, dtype='uint8').sample()`` returned zeros: an
    all-zero sample set has zero standard deviation, so the tolerance
    collapses and the comparison against the expected mean fails.

    Raises:
        NotImplementedError: if ``space`` is neither a ``Box`` nor a
            ``Discrete`` (no expected mean is defined for it here).
    """
    space.seed(0)  # fixed seed keeps the drawn samples reproducible
    n_trials = 100
    samples = np.array([space.sample() for _ in range(n_trials)])
    if isinstance(space, Box):
        # Promote the bounds to float before adding: if they are stored in a
        # narrow integer dtype (e.g. uint8), ``high + low`` could wrap around.
        low = np.asarray(space.low, dtype=np.float64)
        high = np.asarray(space.high, dtype=np.float64)
        expected_mean = (low + high) / 2
    elif isinstance(space, Discrete):
        # Discrete(n) is documented to cover {0, ..., n - 1}, whose uniform
        # mean is (n - 1) / 2 rather than n / 2.
        expected_mean = (space.n - 1) / 2
    else:
        raise NotImplementedError
    # The tolerance is three sample standard deviations; it is deliberately
    # loose and mainly guards against degenerate (constant) sampling.
    np.testing.assert_allclose(expected_mean, samples.mean(), atol=3.0 * samples.std())