mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 15:11:51 +00:00
Improve observation space of FlattenObservation wrapper (#1884)
* Add tests for gym.spaces.utils. * Add docstrings to gym.spaces.utils. * Remove some trailing whitespace. * Add gym.spaces.utils.flatten_space. The new function is also re-exported as gym.spaces.flatten_space. It improves the determination of observation_space in gym.wrappers.FlattenObservation. * Produce OrderedDict instead of dict in gym.spaces.unflatten(). `gym.spaces.Dict` is very particular about producing its samples as `OrderedDict` in order to preserve the order of its items. Hence, `unflatten()` should reproduce this behavior. * In test_utils.compare_nested, also verify the order of OrderedDict items. * Add examples to flatten_space() docstring. * Document ``flatten(space, space.sample()) in flatten_space(space)``. Co-authored-by: Nico Madysa <nico.madysa@tu-dresden.de>
This commit is contained in:
@@ -7,7 +7,8 @@ from gym.spaces.tuple import Tuple
|
|||||||
from gym.spaces.dict import Dict
|
from gym.spaces.dict import Dict
|
||||||
|
|
||||||
from gym.spaces.utils import flatdim
|
from gym.spaces.utils import flatdim
|
||||||
|
from gym.spaces.utils import flatten_space
|
||||||
from gym.spaces.utils import flatten
|
from gym.spaces.utils import flatten
|
||||||
from gym.spaces.utils import unflatten
|
from gym.spaces.utils import unflatten
|
||||||
|
|
||||||
__all__ = ["Space", "Box", "Discrete", "MultiDiscrete", "MultiBinary", "Tuple", "Dict", "flatdim", "flatten", "unflatten"]
|
__all__ = ["Space", "Box", "Discrete", "MultiDiscrete", "MultiBinary", "Tuple", "Dict", "flatdim", "flatten_space", "flatten", "unflatten"]
|
||||||
|
120
gym/spaces/tests/test_utils.py
Normal file
120
gym/spaces/tests/test_utils.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gym.spaces import utils
|
||||||
|
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ["space", "flatdim"],
    [
        (Discrete(3), 3),
        (Box(low=0., high=np.inf, shape=(2, 2)), 4),
        (Tuple([Discrete(5), Discrete(10)]), 15),
        (
            Tuple([
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]),
            7,
        ),
        (Tuple((Discrete(5), Discrete(2), Discrete(2))), 9),
        (MultiDiscrete([2, 2, 100]), 3),
        (MultiBinary(10), 10),
        (
            Dict({
                "position": Discrete(5),
                "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            }),
            7,
        ),
    ],
)
def test_flatdim(space, flatdim):
    """Check that ``utils.flatdim(space)`` equals the expected ``flatdim``."""
    actual = utils.flatdim(space)
    assert actual == flatdim, "Expected {} to equal {}".format(actual, flatdim)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0., high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple([
            Discrete(5),
            Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        ]),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict({
            "position": Discrete(5),
            "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        }),
    ],
)
def test_flatten_space_boxes(space):
    """Check that ``flatten_space`` yields a 1D ``Box`` of ``flatdim(space)`` entries."""
    flattened = utils.flatten_space(space)
    assert isinstance(flattened, Box), \
        "Expected {} to equal {}".format(type(flattened), Box)

    expected_dim = utils.flatdim(space)
    # Unpacking a one-element sequence fails unless the shape is exactly 1D.
    [actual_dim] = flattened.shape
    assert actual_dim == expected_dim, \
        "Expected {} to equal {}".format(actual_dim, expected_dim)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0., high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple([
            Discrete(5),
            Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        ]),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict({
            "position": Discrete(5),
            "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        }),
    ],
)
def test_flat_space_contains_flat_points(space):
    """Check that flattened samples of ``space`` lie in ``flatten_space(space)``."""
    drawn = [space.sample() for _ in range(10)]
    flat_points = [utils.flatten(space, point) for point in drawn]
    flat_space = utils.flatten_space(space)

    for index, flat_point in enumerate(flat_points):
        assert flat_point in flat_space, \
            'Expected sample #{} {} to be in {}'.format(index, flat_point, flat_space)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0., high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple([
            Discrete(5),
            Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        ]),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict({
            "position": Discrete(5),
            "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        }),
    ],
)
def test_flatten_dim(space):
    """Check that a flattened sample is 1D with ``flatdim(space)`` entries."""
    flat_point = utils.flatten(space, space.sample())
    # Unpacking a one-element sequence fails unless the shape is exactly 1D.
    [actual_dim] = flat_point.shape
    expected_dim = utils.flatdim(space)
    assert actual_dim == expected_dim, \
        "Expected {} to equal {}".format(actual_dim, expected_dim)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "space",
    [
        Discrete(3),
        Box(low=0., high=np.inf, shape=(2, 2)),
        Tuple([Discrete(5), Discrete(10)]),
        Tuple([
            Discrete(5),
            Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        ]),
        Tuple((Discrete(5), Discrete(2), Discrete(2))),
        MultiDiscrete([2, 2, 100]),
        MultiBinary(10),
        Dict({
            "position": Discrete(5),
            "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        }),
    ],
)
def test_flatten_roundtripping(space):
    """Check that ``unflatten`` reverses ``flatten`` for samples of ``space``."""
    originals = [space.sample() for _ in range(10)]
    flat_points = [utils.flatten(space, point) for point in originals]
    restored = [utils.unflatten(space, flat_point) for flat_point in flat_points]

    for index, pair in enumerate(zip(originals, restored)):
        original, roundtripped = pair
        assert compare_nested(original, roundtripped), \
            'Expected sample #{} {} to equal {}'.format(index, original, roundtripped)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_nested(left, right):
    """Recursively compare two (possibly nested) sample structures.

    Comparison rules, chosen by the types of both arguments:

    * two ``np.ndarray`` objects: equal shapes and ``np.allclose`` values;
    * two ``OrderedDict`` objects: same length, same keys in the same order,
      and recursively equal values;
    * two tuples/lists (in any combination): same length and recursively
      equal items;
    * anything else: the result of ``left == right``.

    Returns a truthy value if the structures match, a falsy one otherwise.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        # np.allclose broadcasts its arguments, so arrays of different shapes
        # could otherwise compare as equal (or raise); check shapes first.
        return left.shape == right.shape and bool(np.allclose(left, right))
    if isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
        # Zipping the items compares keys positionally, which also verifies
        # that both dicts store their items in the same order.
        return len(left) == len(right) and all(
            left_key == right_key and compare_nested(left_value, right_value)
            for (left_key, left_value), (right_key, right_value)
            in zip(left.items(), right.items())
        )
    if isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
        return len(left) == len(right) and all(
            compare_nested(x, y) for x, y in zip(left, right)
        )
    return left == right
|
@@ -1,3 +1,4 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym.spaces import Box
|
from gym.spaces import Box
|
||||||
@@ -9,6 +10,12 @@ from gym.spaces import Dict
|
|||||||
|
|
||||||
|
|
||||||
def flatdim(space):
|
def flatdim(space):
|
||||||
|
"""Return the number of dimensions a flattened equivalent of this space
|
||||||
|
would have.
|
||||||
|
|
||||||
|
Accepts a space and returns an integer. Raises ``NotImplementedError`` if
|
||||||
|
the space is not defined in ``gym.spaces``.
|
||||||
|
"""
|
||||||
if isinstance(space, Box):
|
if isinstance(space, Box):
|
||||||
return int(np.prod(space.shape))
|
return int(np.prod(space.shape))
|
||||||
elif isinstance(space, Discrete):
|
elif isinstance(space, Discrete):
|
||||||
@@ -26,6 +33,15 @@ def flatdim(space):
|
|||||||
|
|
||||||
|
|
||||||
def flatten(space, x):
|
def flatten(space, x):
|
||||||
|
"""Flatten a data point from a space.
|
||||||
|
|
||||||
|
This is useful when e.g. points from spaces must be passed to a neural
|
||||||
|
network, which only understands flat arrays of floats.
|
||||||
|
|
||||||
|
Accepts a space and a point from that space. Always returns a 1D array.
|
||||||
|
Raises ``NotImplementedError`` if the space is not defined in
|
||||||
|
``gym.spaces``.
|
||||||
|
"""
|
||||||
if isinstance(space, Box):
|
if isinstance(space, Box):
|
||||||
return np.asarray(x, dtype=np.float32).flatten()
|
return np.asarray(x, dtype=np.float32).flatten()
|
||||||
elif isinstance(space, Discrete):
|
elif isinstance(space, Discrete):
|
||||||
@@ -33,9 +49,11 @@ def flatten(space, x):
|
|||||||
onehot[x] = 1.0
|
onehot[x] = 1.0
|
||||||
return onehot
|
return onehot
|
||||||
elif isinstance(space, Tuple):
|
elif isinstance(space, Tuple):
|
||||||
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
|
return np.concatenate(
|
||||||
|
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
|
||||||
elif isinstance(space, Dict):
|
elif isinstance(space, Dict):
|
||||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
return np.concatenate(
|
||||||
|
[flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||||
elif isinstance(space, MultiBinary):
|
elif isinstance(space, MultiBinary):
|
||||||
return np.asarray(x).flatten()
|
return np.asarray(x).flatten()
|
||||||
elif isinstance(space, MultiDiscrete):
|
elif isinstance(space, MultiDiscrete):
|
||||||
@@ -45,6 +63,15 @@ def flatten(space, x):
|
|||||||
|
|
||||||
|
|
||||||
def unflatten(space, x):
|
def unflatten(space, x):
|
||||||
|
"""Unflatten a data point from a space.
|
||||||
|
|
||||||
|
This reverses the transformation applied by ``flatten()``. You must ensure
|
||||||
|
that the ``space`` argument is the same as for the ``flatten()`` call.
|
||||||
|
|
||||||
|
Accepts a space and a flattened point. Returns a point with a structure
|
||||||
|
that matches the space. Raises ``NotImplementedError`` if the space is not
|
||||||
|
defined in ``gym.spaces``.
|
||||||
|
"""
|
||||||
if isinstance(space, Box):
|
if isinstance(space, Box):
|
||||||
return np.asarray(x, dtype=np.float32).reshape(space.shape)
|
return np.asarray(x, dtype=np.float32).reshape(space.shape)
|
||||||
elif isinstance(space, Discrete):
|
elif isinstance(space, Discrete):
|
||||||
@@ -52,18 +79,87 @@ def unflatten(space, x):
|
|||||||
elif isinstance(space, Tuple):
|
elif isinstance(space, Tuple):
|
||||||
dims = [flatdim(s) for s in space.spaces]
|
dims = [flatdim(s) for s in space.spaces]
|
||||||
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
||||||
list_unflattened = [unflatten(s, flattened)
|
list_unflattened = [
|
||||||
for flattened, s in zip(list_flattened, space.spaces)]
|
unflatten(s, flattened)
|
||||||
|
for flattened, s in zip(list_flattened, space.spaces)
|
||||||
|
]
|
||||||
return tuple(list_unflattened)
|
return tuple(list_unflattened)
|
||||||
elif isinstance(space, Dict):
|
elif isinstance(space, Dict):
|
||||||
dims = [flatdim(s) for s in space.spaces.values()]
|
dims = [flatdim(s) for s in space.spaces.values()]
|
||||||
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
||||||
list_unflattened = [(key, unflatten(s, flattened))
|
list_unflattened = [
|
||||||
for flattened, (key, s) in zip(list_flattened, space.spaces.items())]
|
(key, unflatten(s, flattened))
|
||||||
return dict(list_unflattened)
|
for flattened, (key,
|
||||||
|
s) in zip(list_flattened, space.spaces.items())
|
||||||
|
]
|
||||||
|
return OrderedDict(list_unflattened)
|
||||||
elif isinstance(space, MultiBinary):
|
elif isinstance(space, MultiBinary):
|
||||||
return np.asarray(x).reshape(space.shape)
|
return np.asarray(x).reshape(space.shape)
|
||||||
elif isinstance(space, MultiDiscrete):
|
elif isinstance(space, MultiDiscrete):
|
||||||
return np.asarray(x).reshape(space.shape)
|
return np.asarray(x).reshape(space.shape)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_space(space):
    """Flatten a space into a single ``Box``.

    This is the space-level counterpart of ``flatten()``: it operates on the
    space itself instead of on a point from it. The result always is a
    ``Box`` with flat (1D) boundaries and exactly ``flatdim(space)``
    dimensions. Every flattened sample of the original space is contained in
    the flattened space, i.e.
    ``flatten(space, space.sample()) in flatten_space(space)``.

    Accepts a space and returns a ``Box``. Raises ``NotImplementedError`` if
    the space is not defined in ``gym.spaces``.

    Example::

        >>> box = Box(0.0, 1.0, shape=(3, 4, 5))
        >>> box
        Box(3, 4, 5)
        >>> flatten_space(box)
        Box(60,)
        >>> flatten(box, box.sample()) in flatten_space(box)
        True

    Example that flattens a discrete space::

        >>> discrete = Discrete(5)
        >>> flatten_space(discrete)
        Box(5,)
        >>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
        True

    Example that recursively flattens a dict::

        >>> space = Dict({"position": Discrete(2),
        ...               "velocity": Box(0, 1, shape=(2, 2))})
        >>> flatten_space(space)
        Box(6,)
        >>> flatten(space, space.sample()) in flatten_space(space)
        True
    """
    if isinstance(space, Box):
        return Box(space.low.flatten(), space.high.flatten())
    if isinstance(space, (Discrete, MultiBinary)):
        # Both kinds map to ``space.n`` entries bounded by 0 and 1.
        return Box(low=0, high=1, shape=(space.n, ))
    if isinstance(space, (Tuple, Dict)):
        # Flatten every sub-space recursively and join their bounds in the
        # container's own iteration order (the same order ``flatten()`` uses).
        if isinstance(space, Dict):
            subspaces = space.spaces.values()
        else:
            subspaces = space.spaces
        flat_subspaces = [flatten_space(s) for s in subspaces]
        return Box(
            low=np.concatenate([s.low for s in flat_subspaces]),
            high=np.concatenate([s.high for s in flat_subspaces]),
        )
    if isinstance(space, MultiDiscrete):
        # Flatten ``nvec`` so the resulting bounds are 1D whatever its shape,
        # matching the 1D arrays returned by ``flatten()``.
        nvec = np.asarray(space.nvec).flatten()
        return Box(low=np.zeros_like(nvec), high=nvec)
    raise NotImplementedError
|
||||||
|
@@ -1,4 +1,3 @@
|
|||||||
import numpy as np
|
|
||||||
import gym.spaces as spaces
|
import gym.spaces as spaces
|
||||||
from gym import ObservationWrapper
|
from gym import ObservationWrapper
|
||||||
|
|
||||||
@@ -7,9 +6,7 @@ class FlattenObservation(ObservationWrapper):
|
|||||||
r"""Observation wrapper that flattens the observation."""
|
r"""Observation wrapper that flattens the observation."""
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super(FlattenObservation, self).__init__(env)
|
super(FlattenObservation, self).__init__(env)
|
||||||
|
self.observation_space = spaces.flatten_space(env.observation_space)
|
||||||
flatdim = spaces.flatdim(env.observation_space)
|
|
||||||
self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)
|
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
return spaces.flatten(self.env.observation_space, observation)
|
return spaces.flatten(self.env.observation_space, observation)
|
||||||
|
Reference in New Issue
Block a user