2022-05-20 14:49:30 +01:00
|
|
|
"""Numpy utility functions: concatenate space samples and create empty array."""
|
2022-03-31 12:50:38 -07:00
|
|
|
from collections import OrderedDict
from functools import singledispatch
from typing import Callable, Iterable, Optional, Union

import numpy as np

from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
2022-01-21 11:28:34 -05:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
# Public API of this module; the ``register``-decorated ``_``-prefixed helpers stay private.
__all__ = ["concatenate", "create_empty_array"]
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@singledispatch
def concatenate(
    space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
) -> Union[tuple, dict, np.ndarray]:
    """Concatenate multiple samples from ``space`` into a single batched object.

    This is the generic entry point: ``functools.singledispatch`` selects the
    concrete implementation from the type of ``space``. This body only runs
    when no implementation is registered for that type.

    Example::

        >>> from gymnasium.spaces import Box
        >>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
        >>> out = np.zeros((2, 3), dtype=np.float32)
        >>> items = [space.sample() for _ in range(2)]
        >>> concatenate(space, items, out)
        array([[0.6348213 , 0.28607962, 0.60760117],
               [0.87383074, 0.192658  , 0.2148103 ]], dtype=float32)

    Args:
        space: Observation space of a single environment in the vectorized environment.
        items: Samples to be concatenated.
        out: The output object. This object is a (possibly nested) numpy array.

    Returns:
        The concatenated samples, mirroring the (possibly nested) structure of ``out``.

    Raises:
        ValueError: ``space`` is not a valid :class:`gymnasium.Space` instance
            (no implementation is registered for its type).
    """
    message = f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
    raise ValueError(message)
|
|
|
|
|
|
|
|
|
|
|
|
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(space, items, out):
    """Stack samples of an array-like space along a new leading (batch) axis.

    Args:
        space: The single, unbatched space; only used for dispatch here.
        items: Iterable of samples, one per sub-environment.
        out: Pre-allocated array that receives the stacked result.

    Returns:
        ``out``, filled in place by :func:`numpy.stack`.
    """
    # ``np.stack`` only accepts sequences; NumPy >= 1.24 raises ``TypeError``
    # for generators and other one-shot iterables, which the ``Iterable``
    # contract of ``concatenate`` allows. Materialize those first.
    if not isinstance(items, (list, tuple, np.ndarray)):
        items = list(items)
    return np.stack(items, axis=0, out=out)
|
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@concatenate.register(Tuple)
def _concatenate_tuple(space, items, out):
    """Concatenate samples of a ``Tuple`` space, one sub-space at a time.

    For each sub-space ``i``, the ``i``-th entry of every sample is
    concatenated recursively into ``out[i]``.

    Args:
        space: The single, unbatched ``Tuple`` space.
        items: Iterable of tuple-like samples, one per sub-environment.
        out: Tuple of pre-allocated outputs, indexed like ``space.spaces``.

    Returns:
        A tuple holding one concatenated result per sub-space.
    """
    # ``items`` is traversed once per sub-space below; a one-shot iterator
    # would be exhausted after the first sub-space, leaving the rest empty.
    if not isinstance(items, (list, tuple)):
        items = list(items)
    return tuple(
        concatenate(subspace, [item[i] for item in items], out[i])
        for i, subspace in enumerate(space.spaces)
    )
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@concatenate.register(Dict)
def _concatenate_dict(space, items, out):
    """Concatenate samples of a ``Dict`` space, one key at a time.

    For each key of ``space.spaces``, the value under that key in every
    sample is concatenated recursively into ``out[key]``.

    Args:
        space: The single, unbatched ``Dict`` space.
        items: Iterable of mapping samples, one per sub-environment.
        out: Mapping of pre-allocated outputs keyed like ``space.spaces``.

    Returns:
        An ``OrderedDict`` keyed in ``space.spaces`` order.
    """
    # ``items`` is traversed once per key below; a one-shot iterator would be
    # exhausted after the first key, leaving every later key empty.
    if not isinstance(items, (list, tuple)):
        items = list(items)
    return OrderedDict(
        [
            (key, concatenate(subspace, [item[key] for item in items], out[key]))
            for key, subspace in space.spaces.items()
        ]
    )
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@concatenate.register(Space)
def _concatenate_custom(space, items, out):
    """Fallback for other ``Space`` subclasses: gather the samples into a tuple.

    ``out`` is not written to; the samples are returned as-is, in order.
    """
    gathered = tuple(items)
    return gathered
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@singledispatch
def create_empty_array(
    space: Space, n: Optional[int] = 1, fn: Callable = np.zeros
) -> Union[tuple, dict, np.ndarray]:
    """Create an empty (possibly nested) numpy array.

    This is the generic entry point: ``functools.singledispatch`` selects the
    concrete implementation from the type of ``space``. This body only runs
    when no implementation is registered for that type.

    Example::

        >>> from gymnasium.spaces import Box, Dict
        >>> space = Dict({
        ...     'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
        ...     'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
        >>> create_empty_array(space, n=2, fn=np.zeros)
        OrderedDict([('position', array([[0., 0., 0.],
                                         [0., 0., 0.]], dtype=float32)),
                     ('velocity', array([[0., 0.],
                                         [0., 0.]], dtype=float32))])

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment. If ``None``,
            no batch axis is added and an empty sample shaped like ``space`` is created.
        fn: Function used to allocate each array, called as ``fn(shape, dtype=...)``.
            Examples of such functions are ``np.empty`` or ``np.zeros``.

    Returns:
        The output object. This object is a (possibly nested) numpy array.

    Raises:
        ValueError: ``space`` is not a valid :class:`gymnasium.Space` instance
            (no implementation is registered for its type).
    """
    raise ValueError(
        f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
    )
|
|
|
|
|
|
|
|
|
|
|
|
@create_empty_array.register(Box)
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_base(space, n=1, fn=np.zeros):
    """Allocate an array shaped like ``space``, optionally with a batch axis.

    With ``n=None`` the array has exactly ``space.shape``; otherwise an axis of
    length ``n`` is prepended. The array is built by ``fn(shape, dtype=space.dtype)``.
    """
    if n is None:
        return fn(space.shape, dtype=space.dtype)
    return fn((n,) + space.shape, dtype=space.dtype)
|
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space, n=1, fn=np.zeros):
    """Allocate one empty array per sub-space of a ``Tuple`` space, in order."""
    members = [create_empty_array(sub, n=n, fn=fn) for sub in space.spaces]
    return tuple(members)
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@create_empty_array.register(Dict)
def _create_empty_array_dict(space, n=1, fn=np.zeros):
    """Allocate one empty array per key of a ``Dict`` space.

    Returns:
        An ``OrderedDict`` keyed in ``space.spaces`` order.
    """
    result = OrderedDict()
    for name, subspace in space.spaces.items():
        result[name] = create_empty_array(subspace, n=n, fn=fn)
    return result
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2020-09-21 22:38:51 +02:00
|
|
|
|
2022-01-21 11:28:34 -05:00
|
|
|
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
    """Fallback for other ``Space`` subclasses: nothing is preallocated.

    Returns:
        Always ``None``.
    """
    return
|