mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 23:12:46 +00:00
fix for issue 1256 (Box(low=0, high=255, dtype='uint8').sample() returned zeros) (#1307)
This commit is contained in:
@@ -41,7 +41,8 @@ class Box(Space):
|
|||||||
self.np_random.seed(seed)
|
self.np_random.seed(seed)
|
||||||
|
|
||||||
def sample(self):
    """Draw one uniformly distributed sample from the box.

    For float dtypes the draw comes straight from ``uniform(low, high)``.
    For every other dtype kind the upper bound is widened by one and the
    draw is floored, so that every integer in ``[low, high]`` can be
    produced, including ``high`` itself.

    Returns:
        An array with shape ``self.low.shape`` and dtype ``self.dtype``.
    """
    import numpy as np  # local import: the module's import block is not visible here

    if self.dtype.kind == 'f':
        return self.np_random.uniform(low=self.low, high=self.high, size=self.low.shape).astype(self.dtype)

    # Widen to int64 before adding 1: with high == 255 and dtype uint8 the
    # addition would otherwise wrap around to 0 and every sample would be 0.
    high = self.high.astype('int64') + 1
    draw = self.np_random.uniform(low=self.low, high=high, size=self.low.shape)
    # astype() truncates toward zero, which maps (-1, 0) to 0 and skews
    # negative ranges; flooring keeps each integer bucket the same width.
    return np.floor(draw).astype(self.dtype)
|
||||||
|
|
||||||
def contains(self, x):
    """Return whether ``x`` lies inside the box.

    Args:
        x: Candidate point. Arrays are used as-is; anything else (list,
            tuple, scalar) is converted with ``numpy.asarray`` first.

    Returns:
        ``True`` when ``x`` has the same shape as the box and every element
        is within ``[low, high]`` (bounds inclusive), otherwise ``False``.
    """
    import numpy as np  # local import: the module's import block is not visible here

    # Plain sequences have no ``.shape``; coerce them instead of failing
    # with AttributeError.
    if not isinstance(x, np.ndarray):
        x = np.asarray(x)
    if x.shape != self.shape:
        return False
    # bool() so callers always get a Python bool rather than numpy.bool_.
    return bool((x >= self.low).all() and (x <= self.high).all())
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
import json # note: ujson fails this test due to float equality
|
import json # note: ujson fails this test due to float equality
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -8,14 +8,14 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("space", [
|
@pytest.mark.parametrize("space", [
|
||||||
Discrete(3),
|
Discrete(3),
|
||||||
Tuple([Discrete(5), Discrete(10)]),
|
Tuple([Discrete(5), Discrete(10)]),
|
||||||
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
||||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||||
MultiDiscrete([2, 2, 100]),
|
MultiDiscrete([2, 2, 100]),
|
||||||
Dict({"position": Discrete(5),
|
Dict({"position": Discrete(5),
|
||||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
||||||
])
|
])
|
||||||
def test_roundtripping(space):
|
def test_roundtripping(space):
|
||||||
sample_1 = space.sample()
|
sample_1 = space.sample()
|
||||||
sample_2 = space.sample()
|
sample_2 = space.sample()
|
||||||
@@ -37,16 +37,16 @@ def test_roundtripping(space):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("space", [
|
@pytest.mark.parametrize("space", [
|
||||||
Discrete(3),
|
Discrete(3),
|
||||||
Box(low=np.array([-10, 0]),high=np.array([10, 10])),
|
Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
|
||||||
Tuple([Discrete(5), Discrete(10)]),
|
Tuple([Discrete(5), Discrete(10)]),
|
||||||
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
||||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||||
MultiDiscrete([2, 2, 100]),
|
MultiDiscrete([2, 2, 100]),
|
||||||
MultiBinary(6),
|
MultiBinary(6),
|
||||||
Dict({"position": Discrete(5),
|
Dict({"position": Discrete(5),
|
||||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
||||||
])
|
])
|
||||||
def test_equality(space):
|
def test_equality(space):
|
||||||
space1 = space
|
space1 = space
|
||||||
space2 = copy(space)
|
space2 = copy(space)
|
||||||
@@ -54,15 +54,32 @@ def test_equality(space):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("spaces", [
    (Discrete(3), Discrete(4)),
    (MultiDiscrete([2, 2, 100]), MultiDiscrete([2, 2, 8])),
    (MultiBinary(8), MultiBinary(7)),
    (Box(low=np.array([-10, 0]), high=np.array([10, 10]), dtype=np.float32),
     Box(low=np.array([-10, 0]), high=np.array([10, 9]), dtype=np.float32)),
    (Tuple([Discrete(5), Discrete(10)]), Tuple([Discrete(1), Discrete(10)])),
    (Dict({"position": Discrete(5)}), Dict({"position": Discrete(4)})),
    (Dict({"position": Discrete(5)}), Dict({"speed": Discrete(5)})),
])
def test_inequality(spaces):
    """Each parametrized pair holds two different spaces that must compare unequal."""
    left, right = spaces
    assert left != right, "Expected {} != {}".format(left, right)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("space", [
    Discrete(5),
    Box(low=0, high=255, shape=(2,), dtype='uint8'),
])
def test_sample(space):
    """Check that the empirical mean of seeded samples is near the analytic mean.

    Regression test for ``Box(low=0, high=255, dtype='uint8').sample()``
    returning only zeros: in that case the sample standard deviation, and
    therefore the absolute tolerance, collapses to zero and the check fails.
    """
    space.seed(0)
    n_trials = 100
    samples = np.array([space.sample() for _ in range(n_trials)])

    if isinstance(space, Box):
        # Promote the bounds to float before adding them; summing two
        # narrow-integer arrays (e.g. uint8) can wrap around.
        low = np.asarray(space.low, dtype='float64')
        high = np.asarray(space.high, dtype='float64')
        expected_mean = (low + high) / 2
    elif isinstance(space, Discrete):
        # Discrete(n) draws from {0, ..., n - 1}, so the mean is (n - 1) / 2.
        expected_mean = (space.n - 1) / 2.0
    else:
        raise NotImplementedError

    # Reduce over the trial axis only, so each box dimension is compared
    # with its own expected mean.
    observed_mean = samples.mean(axis=0)
    # NOTE: three sample standard deviations is a deliberately loose bound;
    # it catches degenerate samplers rather than subtle bias. atol must be a
    # scalar, so take the widest per-dimension spread.
    tolerance = 3.0 * float(np.max(samples.std(axis=0)))
    np.testing.assert_allclose(observed_mean, expected_mean, atol=tolerance)
|
||||||
|
Reference in New Issue
Block a user