Files
Gymnasium/gym/spaces/tuple.py
Markus Krimmel 63ea5f2517 Add Sequence space, update flatten functions (#2968)
* Added Sequence space, updated flatten functions to work with Sequence, Graph. WIP.

* Small fixes, added Sequence space to tests

* Replace Optional[Any] by Any

* Added tests for flattening of non-numpy-flattenable spaces

* Return all seeds
2022-08-15 11:11:32 -04:00

152 lines
5.6 KiB
Python

"""Implementation of a space that represents the cartesian product of other spaces."""
from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class Tuple(Space[tuple], Sequence):
    """A tuple (more precisely: the cartesian product) of :class:`Space` instances.

    Elements of this space are tuples of elements of the constituent spaces.

    Example usage::

        >>> from gym.spaces import Box, Discrete
        >>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))))
        >>> observation_space.sample()
        (0, array([0.03633198, 0.42370757], dtype=float32))
    """

    def __init__(
        self,
        spaces: Iterable[Space],
        seed: Optional[Union[int, List[int], seeding.RandomNumberGenerator]] = None,
    ):
        r"""Constructor of :class:`Tuple` space.

        The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`.

        Args:
            spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
                The iterable is materialised into a tuple and stored as ``self.spaces``.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.

        Raises:
            AssertionError: If any element of ``spaces`` is not a :class:`Space` instance.
        """
        # Materialise first: ``spaces`` may be a one-shot iterable, and it is
        # traversed again below and by every other method of this class.
        self.spaces = tuple(spaces)
        for space in self.spaces:
            assert isinstance(
                space, Space
            ), "Elements of the tuple must be instances of gym.Space"
        # NOTE(review): the two ``None`` arguments are presumably the shape and
        # dtype of the base space, and the base constructor presumably calls
        # ``self.seed(seed)`` (which needs ``self.spaces``) -- confirm against
        # ``Space.__init__``.
        super().__init__(None, None, seed)  # type: ignore

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`.

        Returns:
            ``True`` only if every subspace reports ``is_np_flattenable``.
        """
        return all(space.is_np_flattenable for space in self.spaces)

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
        """Seed the PRNG of this space and all subspaces.

        Args:
            seed: Controls how the subspaces are seeded.

                * ``list``: ``seed[i]`` is passed to the ``i``-th subspace. Entries
                  beyond the number of subspaces are never read.
                * ``int``: this space is seeded with the value, then one subseed
                  per subspace is drawn from this space's own ``np_random``.
                * ``None``: ``None`` is forwarded to every subspace.

        Returns:
            A list of the seeds that were used. For ``list`` and ``None`` input every
            seed returned by each subspace is included. For ``int`` input the list
            starts with the seeds returned by the base class, followed by only the
            first seed returned by each subspace.

        Raises:
            TypeError: If ``seed`` is not a list, an int or ``None``.
            IndexError: If ``seed`` is a list with fewer entries than there are subspaces.
        """
        seeds = []

        if isinstance(seed, list):
            for i, space in enumerate(self.spaces):
                seeds += space.seed(seed[i])
        elif isinstance(seed, int):
            seeds = super().seed(seed)
            try:
                subseeds = self.np_random.choice(
                    np.iinfo(int).max,
                    size=len(self.spaces),
                    replace=False,  # unique subseed for each subspace
                )
            except ValueError:
                subseeds = self.np_random.choice(
                    np.iinfo(int).max,
                    size=len(self.spaces),
                    replace=True,  # we get more than INT_MAX subspaces
                )

            for subspace, subseed in zip(self.spaces, subseeds):
                # Only the first seed reported by each subspace is recorded here.
                seeds.append(subspace.seed(int(subseed))[0])
        elif seed is None:
            for space in self.spaces:
                seeds += space.seed(seed)
        else:
            raise TypeError("Passed seed not of an expected type: list or int or None")

        return seeds

    def sample(
        self, mask: Optional[Tuple[Optional[np.ndarray], ...]] = None
    ) -> tuple:
        """Generates a single random sample inside this space.

        This method draws independent samples from the subspaces.

        Args:
            mask: An optional tuple of optional masks for each of the subspace's samples,
                expects the same number of masks as spaces

        Returns:
            Tuple of the subspace's samples

        Raises:
            AssertionError: If ``mask`` is given but is not a tuple, or its length
                differs from the number of subspaces.
        """
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expected type of mask is tuple, actual type: {type(mask)}"
            assert len(mask) == len(
                self.spaces
            ), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"

            # Each subspace receives the mask at the matching position.
            return tuple(
                space.sample(mask=sub_mask)
                for space, sub_mask in zip(self.spaces, mask)
            )

        return tuple(space.sample() for space in self.spaces)

    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Lists and numpy arrays are promoted to tuples before the check, so they are
        accepted when their elements are members of the respective subspaces.

        Args:
            x: The candidate element.

        Returns:
            ``True`` if ``x`` has one entry per subspace and every entry is contained
            in the subspace at the same position, ``False`` otherwise.
        """
        if isinstance(x, np.ndarray) and x.ndim == 0:
            # A 0-d array cannot be iterated (``tuple(x)`` raises TypeError),
            # so it can never be a member of a product space.
            return False
        if isinstance(x, (list, np.ndarray)):
            x = tuple(x)  # Promote list and ndarray to tuple for contains check

        return (
            isinstance(x, tuple)
            and len(x) == len(self.spaces)
            and all(space.contains(part) for (space, part) in zip(self.spaces, x))
        )

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return "Tuple(" + ", ".join(str(s) for s in self.spaces) + ")"

    def to_jsonable(self, sample_n: Sequence) -> list:
        """Convert a batch of samples from this space to a JSONable data type.

        Args:
            sample_n: A batch of samples, each indexable by subspace position.

        Returns:
            A list with one entry per subspace; entry ``i`` is the result of the
            ``i``-th subspace's ``to_jsonable`` applied to the ``i``-th component of
            every sample.
        """
        # serialize as list-repr of tuple of vectors
        return [
            space.to_jsonable([sample[i] for sample in sample_n])
            for i, space in enumerate(self.spaces)
        ]

    def from_jsonable(self, sample_n) -> list:
        """Convert a JSONable data type to a batch of samples from this space.

        Args:
            sample_n: A sequence with one entry per subspace, as produced by
                :meth:`to_jsonable`.

        Returns:
            A list of tuples, one tuple per sample in the batch.
        """
        per_space = [
            space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)
        ]
        # Transpose from "per subspace" back to "per sample".
        return list(zip(*per_space))

    def __getitem__(self, index: int) -> Space:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of subspaces that are involved in the cartesian product."""
        return len(self.spaces)

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, Tuple) and self.spaces == other.spaces