2020-05-08 23:19:55 +02:00
|
|
|
from collections import OrderedDict
|
2019-03-24 19:29:43 +01:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from gym.spaces import Box
|
|
|
|
from gym.spaces import Discrete
|
|
|
|
from gym.spaces import MultiDiscrete
|
|
|
|
from gym.spaces import MultiBinary
|
|
|
|
from gym.spaces import Tuple
|
|
|
|
from gym.spaces import Dict
|
|
|
|
|
|
|
|
|
|
|
|
def flatdim(space):
    """Return the number of dimensions a flattened equivalent of this space
    would have.

    Accepts a space and returns an integer. Raises ``NotImplementedError`` if
    the space is not defined in ``gym.spaces``.

    The result matches the length of the 1D array produced by ``flatten`` for
    a point of ``space``. ``unflatten`` relies on this to split Tuple/Dict
    points back into their parts.
    """
    if isinstance(space, Box):
        return int(np.prod(space.shape))
    elif isinstance(space, Discrete):
        # Discrete points are one-hot encoded: one slot per possible value.
        return int(space.n)
    elif isinstance(space, Tuple):
        return int(sum(flatdim(s) for s in space.spaces))
    elif isinstance(space, Dict):
        return int(sum(flatdim(s) for s in space.spaces.values()))
    elif isinstance(space, (MultiBinary, MultiDiscrete)):
        # ``flatten`` ravels these into one entry per element of
        # ``space.shape``. Counting via the shape (instead of
        # ``int(space.n)``) also works if a MultiBinary's ``n`` is a shape
        # tuple rather than a plain integer.
        return int(np.prod(space.shape))
    else:
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
def flatten(space, x):
    """Flatten a data point from a space.

    Handy whenever points must be handed to something that only accepts flat
    numeric arrays, such as a neural network.

    Accepts a space and a point from that space. Always returns a 1D array.
    Raises ``NotImplementedError`` if the space is not defined in
    ``gym.spaces``.

    Encoding per space type:
      * ``Box`` / ``MultiBinary`` / ``MultiDiscrete``: the point is cast to
        ``space.dtype`` and raveled.
      * ``Discrete``: a one-hot vector of length ``space.n``.
      * ``Tuple`` / ``Dict``: the flattened sub-points, concatenated in the
        order of ``space.spaces``.
    """
    if isinstance(space, Box):
        return np.asarray(x, dtype=space.dtype).flatten()

    if isinstance(space, Discrete):
        encoded = np.zeros(space.n, dtype=space.dtype)
        encoded[x] = 1
        return encoded

    if isinstance(space, Tuple):
        pieces = [flatten(sub, item) for item, sub in zip(x, space.spaces)]
        return np.concatenate(pieces)

    if isinstance(space, Dict):
        pieces = [flatten(sub, x[name]) for name, sub in space.spaces.items()]
        return np.concatenate(pieces)

    if isinstance(space, (MultiBinary, MultiDiscrete)):
        return np.asarray(x, dtype=space.dtype).flatten()

    raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
def unflatten(space, x):
    """Unflatten a data point from a space.

    Inverse of ``flatten()``. The ``space`` passed here must be the same one
    that was used for the ``flatten()`` call.

    Accepts a space and a flattened point. Returns a point shaped like the
    space: an array for ``Box``/``MultiBinary``/``MultiDiscrete``, an ``int``
    for ``Discrete``, a ``tuple`` for ``Tuple`` and an ``OrderedDict`` for
    ``Dict``. Raises ``NotImplementedError`` if the space is not defined in
    ``gym.spaces``.
    """
    if isinstance(space, Box):
        return np.asarray(x, dtype=space.dtype).reshape(space.shape)

    if isinstance(space, Discrete):
        # Index of the first non-zero entry of the one-hot vector.
        hot_positions = np.nonzero(x)[0]
        return int(hot_positions[0])

    if isinstance(space, Tuple):
        subspaces = space.spaces
        # Split points are the running totals of the sub-space sizes.
        boundaries = np.cumsum([flatdim(sub) for sub in subspaces])[:-1]
        chunks = np.split(x, boundaries)
        restored = [unflatten(sub, chunk) for chunk, sub in zip(chunks, subspaces)]
        return tuple(restored)

    if isinstance(space, Dict):
        named_subspaces = list(space.spaces.items())
        boundaries = np.cumsum([flatdim(sub) for _, sub in named_subspaces])[:-1]
        chunks = np.split(x, boundaries)
        restored = [
            (name, unflatten(sub, chunk))
            for chunk, (name, sub) in zip(chunks, named_subspaces)
        ]
        return OrderedDict(restored)

    if isinstance(space, (MultiBinary, MultiDiscrete)):
        return np.asarray(x, dtype=space.dtype).reshape(space.shape)

    raise NotImplementedError
|
2020-05-08 23:19:55 +02:00
|
|
|
|
|
|
|
|
|
|
|
def flatten_space(space):
    """Flatten a space into a single ``Box``.

    This is equivalent to ``flatten()``, but operates on the space itself. The
    result is always a ``Box`` with flat (1D) boundaries. The box has exactly
    ``flatdim(space)`` dimensions. Flattening a sample of the original space
    has the same effect as taking a sample of the flattened space.

    Raises ``NotImplementedError`` if the space is not defined in
    ``gym.spaces``.

    Example::

        >>> box = Box(0.0, 1.0, shape=(3, 4, 5))
        >>> box
        Box(3, 4, 5)
        >>> flatten_space(box)
        Box(60,)
        >>> flatten(box, box.sample()) in flatten_space(box)
        True

    Example that flattens a discrete space::

        >>> discrete = Discrete(5)
        >>> flatten_space(discrete)
        Box(5,)
        >>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
        True

    Example that recursively flattens a dict::

        >>> space = Dict({"position": Discrete(2),
        ...               "velocity": Box(0, 1, shape=(2, 2))})
        >>> flatten_space(space)
        Box(6,)
        >>> flatten(space, space.sample()) in flatten_space(space)
        True
    """
    if isinstance(space, Box):
        return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
    if isinstance(space, Discrete):
        # One-hot encoding: each of the n slots is either 0 or 1.
        return Box(low=0, high=1, shape=(space.n,), dtype=space.dtype)
    if isinstance(space, (Tuple, Dict)):
        subspaces = space.spaces.values() if isinstance(space, Dict) else space.spaces
        # Recursively flatten each sub-space, then concatenate bounds in the
        # same order ``flatten`` concatenates the sub-points.
        flat_boxes = [flatten_space(s) for s in subspaces]
        return Box(
            low=np.concatenate([b.low for b in flat_boxes]),
            high=np.concatenate([b.high for b in flat_boxes]),
            dtype=np.result_type(*[b.dtype for b in flat_boxes]),
        )
    if isinstance(space, MultiBinary):
        # Size from ``space.shape`` (not ``space.n``) so the box stays 1D even
        # if ``n`` is a shape tuple; ``flatten`` ravels the point either way.
        return Box(
            low=0, high=1, shape=(int(np.prod(space.shape)),), dtype=space.dtype
        )
    if isinstance(space, MultiDiscrete):
        # Ravel ``nvec`` so a multi-dimensional MultiDiscrete also yields flat
        # bounds, matching the 1D array produced by ``flatten``.
        nvec = np.asarray(space.nvec)
        return Box(
            low=np.zeros_like(nvec).flatten(),
            high=nvec.flatten(),
            dtype=space.dtype,
        )
    raise NotImplementedError
|