mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Improve observation space of FlattenObservation wrapper (#1884)
* Add tests for gym.spaces.utils. * Add docstrings to gym.spaces.utils. * Remove some trailing whitespace. * Add gym.spaces.utils.flatten_space. The new function also is reexported as gym.spaces.flatten_space. It improves the determination of observation_space in gym.wrappers.FlattenObservation. * Produce OrderedDict instead of dict in gym.spaces.unflatten(). `gym.spaces.Dict` is very particular about producing its samples as `OrderedDict` in order preserve the order of its items. Hence, `unflatten()` should reproduce this behavior. * In test_utils.compare_nested, also verify order of OrderedDict items. * Add examples to flatten_space() docstring. * Document ``flatten(space, space.sample()) in flatten_space(space)``. Co-authored-by: Nico Madysa <nico.madysa@tu-dresden.de>
This commit is contained in:
120
gym/spaces/tests/test_utils.py
Normal file
120
gym/spaces/tests/test_utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import utils
|
||||
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], [
|
||||
(Discrete(3), 3),
|
||||
(Box(low=0., high=np.inf, shape=(2, 2)), 4),
|
||||
(Tuple([Discrete(5), Discrete(10)]), 15),
|
||||
(Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]), 7),
|
||||
(Tuple((Discrete(5), Discrete(2), Discrete(2))), 9),
|
||||
(MultiDiscrete([2, 2, 100]), 3),
|
||||
(MultiBinary(10), 10),
|
||||
(Dict({"position": Discrete(5),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}), 7),
|
||||
])
|
||||
def test_flatdim(space, flatdim):
|
||||
dim = utils.flatdim(space)
|
||||
assert dim == flatdim, "Expected {} to equal {}".format(dim, flatdim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [
|
||||
Discrete(3),
|
||||
Box(low=0., high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict({"position": Discrete(5),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
||||
])
|
||||
def test_flatten_space_boxes(space):
|
||||
flat_space = utils.flatten_space(space)
|
||||
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(type(flat_space), Box)
|
||||
flatdim = utils.flatdim(space)
|
||||
(single_dim, ) = flat_space.shape
|
||||
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [
|
||||
Discrete(3),
|
||||
Box(low=0., high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict({"position": Discrete(5),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
||||
])
|
||||
def test_flat_space_contains_flat_points(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
flat_space = utils.flatten_space(space)
|
||||
for i, flat_sample in enumerate(flattened_samples):
|
||||
assert flat_sample in flat_space,\
|
||||
'Expected sample #{} {} to be in {}'.format(i, flat_sample, flat_space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [
|
||||
Discrete(3),
|
||||
Box(low=0., high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict({"position": Discrete(5),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
||||
])
|
||||
def test_flatten_dim(space):
|
||||
sample = utils.flatten(space, space.sample())
|
||||
(single_dim, ) = sample.shape
|
||||
flatdim = utils.flatdim(space)
|
||||
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [
|
||||
Discrete(3),
|
||||
Box(low=0., high=np.inf, shape=(2, 2)),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([2, 2, 100]),
|
||||
MultiBinary(10),
|
||||
Dict({"position": Discrete(5),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)}),
|
||||
])
|
||||
def test_flatten_roundtripping(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
roundtripped_samples = [utils.unflatten(space, sample) for sample in flattened_samples]
|
||||
for i, (original, roundtripped) in enumerate(zip(some_samples, roundtripped_samples)):
|
||||
assert compare_nested(original, roundtripped), \
|
||||
'Expected sample #{} {} to equal {}'.format(i, original, roundtripped)
|
||||
|
||||
|
||||
def compare_nested(left, right):
|
||||
if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
|
||||
return np.allclose(left, right)
|
||||
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
|
||||
res = len(left) == len(right)
|
||||
for ((left_key, left_value), (right_key, right_value)) in zip(left.items(), right.items()):
|
||||
if not res:
|
||||
return False
|
||||
res = left_key == right_key and compare_nested(left_value, right_value)
|
||||
return res
|
||||
elif isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
|
||||
res = len(left) == len(right)
|
||||
for (x, y) in zip(left, right):
|
||||
if not res:
|
||||
return False
|
||||
res = compare_nested(x, y)
|
||||
return res
|
||||
else:
|
||||
return left == right
|
Reference in New Issue
Block a user