Files
Gymnasium/gym/spaces/tuple.py
Mark Towers 273e3f22ce Updated docstrings using darglint (#2827)
* Updated docstrings using darglint (ignoring 402 and 202) and split overflowing lines into multiple lines

* Remove abstract method decorators, for a future PR

* Add `from __future__ import annotations` for Python 3.7+ notation

* Added missing bracket

* Fix minor docstring tables
2022-05-25 09:46:41 -04:00

133 lines
4.7 KiB
Python

"""Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations
from typing import Iterable, Optional, Sequence
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class Tuple(Space[tuple], Sequence):
    """A tuple (more precisely: the cartesian product) of :class:`Space` instances.

    Elements of this space are tuples of elements of the constituent spaces.

    Example usage::

        >>> from gym.spaces import Box, Discrete
        >>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))))
        >>> observation_space.sample()
        (0, array([0.03633198, 0.42370757], dtype=float32))
    """

    def __init__(
        self,
        spaces: Iterable[Space],
        seed: Optional[int | list[int] | seeding.RandomNumberGenerator] = None,
    ):
        r"""Constructor of :class:`Tuple` space.

        The generated instance will represent the cartesian product
        :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`.

        Args:
            spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces``
                to ensure reproducible sampling.
        """
        # Materialise the iterable once so that a generator argument is not
        # exhausted and the subspaces can be indexed and iterated repeatedly.
        spaces = tuple(spaces)
        # Assigned before the base constructor runs: ``self.seed`` iterates
        # ``self.spaces`` and is presumably invoked by ``Space.__init__`` when
        # a seed is given (base class not visible here - confirm).
        self.spaces = spaces
        for space in spaces:
            assert isinstance(
                space, Space
            ), "Elements of the tuple must be instances of gym.Space"
        super().__init__(None, None, seed)  # type: ignore

    def seed(self, seed: Optional[int | list[int] | tuple[int, ...]] = None) -> list:
        """Seed the PRNG of this space and all subspaces.

        How the subspaces are seeded depends on the type of ``seed``:

        * ``list`` / ``tuple`` - one entry per subspace, passed unchanged to that
          subspace's ``seed`` (entries may themselves be lists for composite subspaces).
        * ``int`` - seeds this space's own PRNG, which then draws one subseed per subspace.
        * ``None`` - every subspace is seeded with ``None``.

        Args:
            seed: An optional int, or a list/tuple holding one seed per subspace.

        Returns:
            The concatenation of the seed lists returned by this space (int case)
            and by each of its subspaces.

        Raises:
            ValueError: If a list/tuple ``seed`` does not contain exactly one entry
                per subspace.
            TypeError: If ``seed`` is not a list, tuple, int or None.
        """
        seeds: list = []
        if isinstance(seed, (list, tuple)):
            # Validate before seeding anything so a bad argument leaves every
            # subspace untouched, and extra entries are not silently dropped.
            if len(seed) != len(self.spaces):
                raise ValueError(
                    f"Expected one seed per subspace ({len(self.spaces)} subspaces), "
                    f"got {len(seed)} seeds"
                )
            for space, subseed in zip(self.spaces, seed):
                seeds += space.seed(subseed)
        elif isinstance(seed, int):
            # Copy so that extending the result never mutates a list owned by the base class.
            seeds = list(super().seed(seed))
            try:
                subseeds = self.np_random.choice(
                    np.iinfo(int).max,
                    size=len(self.spaces),
                    replace=False,  # unique subseed for each subspace
                )
            except ValueError:
                # Fallback: sample with replacement if sampling unique values is
                # rejected (e.g. more subspaces than candidate seed values).
                subseeds = self.np_random.choice(
                    np.iinfo(int).max,
                    size=len(self.spaces),
                    replace=True,
                )
            for subspace, subseed in zip(self.spaces, subseeds):
                # Keep every seed a (possibly composite) subspace reports, matching
                # the list/tuple and None branches instead of only the first one.
                seeds += subspace.seed(int(subseed))
        elif seed is None:
            for space in self.spaces:
                seeds += space.seed(None)
        else:
            raise TypeError(
                "Passed seed not of an expected type: list, tuple, int or None"
            )
        return seeds

    def sample(self) -> tuple:
        """Generates a single random sample inside this space.

        This method draws independent samples from the subspaces.

        Returns:
            Tuple of the subspace's samples
        """
        return tuple(space.sample() for space in self.spaces)

    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Lists and NumPy arrays are converted to tuples before the check; ``x`` is
        a member when it has one element per subspace and each element is
        contained in the corresponding subspace.
        """
        if isinstance(x, (list, np.ndarray)):
            x = tuple(x)  # Promote list and ndarray to tuple for contains check
        return (
            isinstance(x, tuple)
            and len(x) == len(self.spaces)
            and all(space.contains(part) for (space, part) in zip(self.spaces, x))
        )

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"

    def to_jsonable(self, sample_n: Sequence) -> list:
        """Convert a batch of samples from this space to a JSONable data type.

        The batch is transposed: the result holds one entry per subspace, namely
        that subspace's ``to_jsonable`` applied to the i-th component of every sample.
        """
        # serialize as list-repr of tuple of vectors
        return [
            space.to_jsonable([sample[i] for sample in sample_n])
            for i, space in enumerate(self.spaces)
        ]

    def from_jsonable(self, sample_n) -> list:
        """Convert a JSONable data type to a batch of samples from this space.

        Inverse of :meth:`to_jsonable`: ``sample_n[i]`` is decoded by the i-th
        subspace, and the per-subspace results are zipped back into a list of
        tuples, one tuple per sample.
        """
        return list(
            zip(
                *[
                    space.from_jsonable(sample_n[i])
                    for i, space in enumerate(self.spaces)
                ]
            )
        )

    def __getitem__(self, index: int) -> Space:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of subspaces that are involved in the cartesian product."""
        return len(self.spaces)

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, Tuple) and self.spaces == other.spaces