2022-01-24 23:22:11 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2022-03-31 12:50:38 -07:00
|
|
|
import operator as op
|
2020-05-08 23:19:55 +02:00
|
|
|
from collections import OrderedDict
|
2022-03-31 12:50:38 -07:00
|
|
|
from functools import reduce, singledispatch
|
2022-01-24 23:22:11 +01:00
|
|
|
from typing import TypeVar, Union
|
2022-03-31 12:50:38 -07:00
|
|
|
|
2019-03-24 19:29:43 +01:00
|
|
|
import numpy as np
|
|
|
|
|
2022-03-31 12:50:38 -07:00
|
|
|
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
2019-03-24 19:29:43 +01:00
|
|
|
|
|
|
|
|
2021-08-22 00:04:09 +02:00
|
|
|
@singledispatch
def flatdim(space: Space) -> int:
    """Return how many dimensions the flattened form of ``space`` occupies.

    This is the generic entry point; each supported ``gym.spaces`` type
    registers its own implementation via ``flatdim.register``. Reaching this
    fallback means nothing is registered for ``type(space)``.

    Raises ``NotImplementedError`` if the space is not defined in
    ``gym.spaces``.
    """
    message = f"Unknown space: `{space}`"
    raise NotImplementedError(message)
|
2019-03-24 19:29:43 +01:00
|
|
|
|
|
|
|
|
2021-08-22 00:04:09 +02:00
|
|
|
@flatdim.register(Box)
@flatdim.register(MultiBinary)
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
    """Return the element count of a ``Box`` or ``MultiBinary`` space.

    Both types flatten element-wise, so the flat size is the product of the
    entries of ``space.shape`` (1 for an empty shape).
    """
    shape = space.shape
    return reduce(op.mul, shape, 1)
|
|
|
|
|
|
|
|
|
|
|
|
@flatdim.register(Discrete)
def _flatdim_discrete(space: Discrete) -> int:
    """Return the one-hot width of a ``Discrete`` space, i.e. ``space.n``."""
    num_values = space.n
    return int(num_values)
|
|
|
|
|
|
|
|
|
|
|
|
@flatdim.register(MultiDiscrete)
def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
    """Return the total one-hot width of a ``MultiDiscrete`` space.

    Each sub-action is one-hot encoded separately, so the flat size is the
    sum of all entries of ``space.nvec``.
    """
    total = np.sum(space.nvec)
    return int(total)
|
|
|
|
|
|
|
|
|
|
|
|
@flatdim.register(Tuple)
def _flatdim_tuple(space: Tuple) -> int:
    """Return the summed flat dimensions of every sub-space of a ``Tuple``."""
    total = 0
    for subspace in space.spaces:
        total += flatdim(subspace)
    return total
|
2021-08-22 00:04:09 +02:00
|
|
|
|
|
|
|
|
|
|
|
@flatdim.register(Dict)
def _flatdim_dict(space: Dict) -> int:
    """Return the summed flat dimensions of every sub-space of a ``Dict``."""
    total = 0
    for subspace in space.spaces.values():
        total += flatdim(subspace)
    return total
|
2021-08-22 00:04:09 +02:00
|
|
|
|
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
# Generic type variable linking a space to the type of its points; it
# parameterizes the ``Space[T]`` annotations of ``flatten`` and ``unflatten``.
T = TypeVar("T")
|
|
|
|
|
|
|
|
|
2021-08-22 00:04:09 +02:00
|
|
|
@singledispatch
def flatten(space: Space[T], x: T) -> np.ndarray:
    """Convert a point ``x`` of ``space`` into a flat 1-D array.

    Flattening is useful when points must be handed to something that only
    understands flat numeric arrays, such as a neural network. Each supported
    ``gym.spaces`` type registers its own implementation via
    ``flatten.register``; every implementation returns a 1-D array.

    Raises ``NotImplementedError`` if the space is not defined in
    ``gym.spaces``.
    """
    message = f"Unknown space: `{space}`"
    raise NotImplementedError(message)
|
|
|
|
|
|
|
|
|
|
|
|
@flatten.register(Box)
@flatten.register(MultiBinary)
def _flatten_box_multibinary(space, x) -> np.ndarray:
    """Flatten a ``Box``/``MultiBinary`` point into a 1-D array.

    ``x`` is converted to an array of ``space.dtype`` first. ``flatten()`` is
    used (rather than ``ravel``) because it always returns a copy, so the
    result never aliases the caller's data.
    """
    as_array = np.asarray(x, dtype=space.dtype)
    return as_array.flatten()
|
|
|
|
|
|
|
|
|
|
|
|
@flatten.register(Discrete)
def _flatten_discrete(space, x) -> np.ndarray:
    """One-hot encode a ``Discrete`` point.

    Returns an array of length ``space.n`` and dtype ``space.dtype`` holding a
    single 1 at position ``x - space.start``.

    Raises ``IndexError`` if ``x`` lies outside
    ``[space.start, space.start + space.n)``.
    """
    index = x - space.start
    # A negative index would silently wrap around and mark the wrong slot, so
    # reject out-of-range values explicitly. Values above the range already
    # raised IndexError from numpy; keep that exception type for both ends.
    if not 0 <= index < space.n:
        raise IndexError(
            f"{x} is out of range for {space}: expected a value in "
            f"[{space.start}, {space.start + space.n})"
        )
    onehot = np.zeros(space.n, dtype=space.dtype)
    onehot[index] = 1
    return onehot
|
|
|
|
|
|
|
|
|
|
|
|
@flatten.register(MultiDiscrete)
def _flatten_multidiscrete(space, x) -> np.ndarray:
    """One-hot encode each component of a ``MultiDiscrete`` point and concatenate.

    Component ``i`` gets a segment of length ``nvec[i]`` (in the row-major
    order of ``space.nvec``), and a single 1 is set at ``x[i]`` within it. The
    result has length ``sum(space.nvec)`` and dtype ``space.dtype``. ``x`` may
    be any array-like with one entry per component.
    """
    nvec = space.nvec.flatten()
    # Segment offsets use the platform index type. Computing the cumulative
    # sum in ``space.dtype`` could overflow for narrow dtypes (e.g. int8)
    # and produce a negative or wrong one-hot length.
    offsets = np.zeros((nvec.size + 1,), dtype=np.intp)
    offsets[1:] = np.cumsum(nvec)

    onehot = np.zeros((offsets[-1],), dtype=space.dtype)
    onehot[offsets[:-1] + np.asarray(x).flatten()] = 1
    return onehot
|
|
|
|
|
|
|
|
|
|
|
|
@flatten.register(Tuple)
def _flatten_tuple(space, x) -> np.ndarray:
    """Flatten each element of a ``Tuple`` point and join the pieces.

    Element ``i`` of ``x`` is flattened with sub-space ``i``. The resulting
    1-D arrays are concatenated in order.
    """
    pieces = []
    for item, subspace in zip(x, space.spaces):
        pieces.append(flatten(subspace, item))
    return np.concatenate(pieces)
|
|
|
|
|
|
|
|
|
|
|
|
@flatten.register(Dict)
def _flatten_dict(space, x) -> np.ndarray:
    """Flatten a ``Dict`` point by concatenating its flattened entries.

    Sub-spaces are visited in the iteration order of ``space.spaces``, and
    each is applied to ``x[key]``. Keys of ``x`` that the space does not
    define are never read.
    """
    pieces = []
    for key, subspace in space.spaces.items():
        pieces.append(flatten(subspace, x[key]))
    return np.concatenate(pieces)
|
2019-03-24 19:29:43 +01:00
|
|
|
|
|
|
|
|
2021-08-22 00:04:09 +02:00
|
|
|
@singledispatch
def unflatten(space: Space[T], x: np.ndarray) -> T:
    """Rebuild a point of ``space`` from its flattened array ``x``.

    This inverts ``flatten()``; ``space`` must be the same space that was
    used to flatten ``x``. The returned point has the structure of
    ``space``. Each supported ``gym.spaces`` type registers its own
    implementation via ``unflatten.register``.

    Raises ``NotImplementedError`` if the space is not defined in
    ``gym.spaces``.
    """
    message = f"Unknown space: `{space}`"
    raise NotImplementedError(message)
|
|
|
|
|
|
|
|
|
|
|
|
@unflatten.register(Box)
@unflatten.register(MultiBinary)
def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.ndarray:
    """Restore a flattened ``Box``/``MultiBinary`` point to its original shape.

    ``x`` is converted to ``space.dtype`` and reshaped to ``space.shape``.
    """
    as_array = np.asarray(x, dtype=space.dtype)
    return as_array.reshape(space.shape)
|
|
|
|
|
|
|
|
|
|
|
|
@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
    """Decode a one-hot vector back into a ``Discrete`` value.

    The first non-zero position of ``x`` is taken as the hot index and
    shifted by ``space.start``.
    """
    hot_positions = np.nonzero(x)[0]
    return int(space.start + hot_positions[0])
|
2021-08-22 00:04:09 +02:00
|
|
|
|
|
|
|
|
|
|
|
@unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
    """Decode concatenated one-hot segments back into a ``MultiDiscrete`` point.

    ``x`` should be 1-D, laid out as produced by ``flatten``, and contain
    exactly one non-zero entry per component. Each hot position is turned
    back into a component value by subtracting its segment's start offset.
    The values are then cast to ``space.dtype`` and reshaped to
    ``space.shape``.
    """
    nvec = space.nvec.flatten()
    # Segment offsets use the platform index type. Computing the cumulative
    # sum in ``space.dtype`` could overflow for narrow dtypes (e.g. int8).
    offsets = np.zeros((nvec.size + 1,), dtype=np.intp)
    offsets[1:] = np.cumsum(nvec)

    (indices,) = np.nonzero(x)
    return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)
|
|
|
|
|
|
|
|
|
|
|
|
@unflatten.register(Tuple)
def _unflatten_tuple(space: Tuple, x: np.ndarray) -> tuple:
    """Split a flat array into per-sub-space chunks and unflatten each one.

    Chunk sizes come from ``flatdim`` of each sub-space, in order. The result
    is a tuple holding one unflattened point per sub-space.
    """
    sizes = np.asarray([flatdim(subspace) for subspace in space.spaces], dtype=np.int_)
    split_points = np.cumsum(sizes[:-1])
    chunks = np.split(x, split_points)

    decoded = []
    for chunk, subspace in zip(chunks, space.spaces):
        decoded.append(unflatten(subspace, chunk))
    return tuple(decoded)
|
|
|
|
|
|
|
|
|
|
|
|
@unflatten.register(Dict)
def _unflatten_dict(space: Dict, x: np.ndarray) -> dict:
    """Split a flat array per sub-space and rebuild a keyed point.

    Chunk sizes come from ``flatdim`` of each sub-space, in the iteration
    order of ``space.spaces``. The result is an ``OrderedDict`` mapping each
    key to its unflattened chunk.
    """
    subspaces = space.spaces
    sizes = np.asarray([flatdim(subspace) for subspace in subspaces.values()], dtype=np.int_)
    chunks = np.split(x, np.cumsum(sizes[:-1]))

    result = OrderedDict()
    for chunk, (key, subspace) in zip(chunks, subspaces.items()):
        result[key] = unflatten(subspace, chunk)
    return result
|
2020-05-08 23:19:55 +02:00
|
|
|
|
|
|
|
|
2021-08-22 00:04:09 +02:00
|
|
|
@singledispatch
def flatten_space(space: Space) -> Box:
    """Flatten a space into a single ``Box``.

    This is equivalent to ``flatten()``, but operates on the space itself. The
    result is always a ``Box`` with flat boundaries. The box has exactly
    ``flatdim(space)`` dimensions. Flattening any sample of the original
    space yields a point contained in the flattened space. The converse does
    not hold in general: for example, a sample of the flattened space of a
    ``Discrete`` is any 0/1 vector, not necessarily a one-hot one.

    Raises ``NotImplementedError`` if the space is not defined in
    ``gym.spaces``.

    Example::

        >>> box = Box(0.0, 1.0, shape=(3, 4, 5))
        >>> box
        Box(3, 4, 5)
        >>> flatten_space(box)
        Box(60,)
        >>> flatten(box, box.sample()) in flatten_space(box)
        True

    Example that flattens a discrete space::

        >>> discrete = Discrete(5)
        >>> flatten_space(discrete)
        Box(5,)
        >>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
        True

    Example that recursively flattens a dict::

        >>> space = Dict({"position": Discrete(2),
        ...               "velocity": Box(0, 1, shape=(2, 2))})
        >>> flatten_space(space)
        Box(6,)
        >>> flatten(space, space.sample()) in flatten_space(space)
        True
    """
    raise NotImplementedError(f"Unknown space: `{space}`")
|
|
|
|
|
|
|
|
|
|
|
|
@flatten_space.register(Box)
def _flatten_space_box(space: Box) -> Box:
    """Return a 1-D ``Box`` with the flattened bounds and the dtype of ``space``."""
    flat_low = space.low.flatten()
    flat_high = space.high.flatten()
    return Box(flat_low, flat_high, dtype=space.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
@flatten_space.register(Discrete)
@flatten_space.register(MultiBinary)
@flatten_space.register(MultiDiscrete)
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
    """Return the ``[0, 1]`` ``Box`` that holds flattened points of ``space``.

    The box has ``flatdim(space)`` entries, bounds ``low=0`` / ``high=1``,
    and the dtype of ``space``.
    """
    width = flatdim(space)
    return Box(low=0, high=1, shape=(width,), dtype=space.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
@flatten_space.register(Tuple)
def _flatten_space_tuple(space: Tuple) -> Box:
    """Concatenate the flattened boxes of every sub-space of a ``Tuple``.

    Bounds are joined in sub-space order. The dtype is the NumPy result type
    of all sub-box dtypes.
    """
    flat_boxes = [flatten_space(subspace) for subspace in space.spaces]
    lows = np.concatenate([flat.low for flat in flat_boxes])
    highs = np.concatenate([flat.high for flat in flat_boxes])
    common_dtype = np.result_type(*[flat.dtype for flat in flat_boxes])
    return Box(low=lows, high=highs, dtype=common_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
@flatten_space.register(Dict)
def _flatten_space_dict(space: Dict) -> Box:
    """Concatenate the flattened boxes of every sub-space of a ``Dict``.

    Bounds are joined in the iteration order of ``space.spaces``. The dtype
    is the NumPy result type of all sub-box dtypes.
    """
    flat_boxes = [flatten_space(subspace) for subspace in space.spaces.values()]
    lows = np.concatenate([flat.low for flat in flat_boxes])
    highs = np.concatenate([flat.high for flat in flat_boxes])
    common_dtype = np.result_type(*[flat.dtype for flat in flat_boxes])
    return Box(low=lows, high=highs, dtype=common_dtype)
|