2022-05-10 17:18:06 +02:00
|
|
|
"""Implementation of a space that represents the cartesian product of other spaces."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2022-11-15 14:09:22 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import typing
|
2025-06-07 17:57:58 +01:00
|
|
|
from collections.abc import Iterable
|
|
|
|
from typing import Any
|
2022-03-31 12:50:38 -07:00
|
|
|
|
2019-03-25 00:47:16 +01:00
|
|
|
import numpy as np
|
2022-03-31 12:50:38 -07:00
|
|
|
|
2022-09-08 10:10:07 +01:00
|
|
|
from gymnasium.spaces.space import Space
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2019-01-30 22:39:55 +01:00
|
|
|
|
2025-06-07 17:57:58 +01:00
|
|
|
class Tuple(Space[tuple[Any, ...]], typing.Sequence[Any]):
|
2022-05-10 17:18:06 +02:00
|
|
|
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
|
|
|
|
|
|
|
|
Elements of this space are tuples of elements of the constituent spaces.
|
2016-06-11 23:10:58 -07:00
|
|
|
|
2023-01-23 11:30:00 +01:00
|
|
|
Example:
|
|
|
|
>>> from gymnasium.spaces import Tuple, Box, Discrete
|
2023-01-20 14:28:09 +01:00
|
|
|
>>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
|
2022-05-20 14:49:30 +01:00
|
|
|
>>> observation_space.sample()
|
2024-09-03 12:30:58 +01:00
|
|
|
(np.int64(0), array([-0.3991573 , 0.21649833], dtype=float32))
|
2016-04-27 08:00:58 -07:00
|
|
|
"""
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
def __init__(
    self,
    spaces: Iterable[Space[Any]],
    seed: int | typing.Sequence[int] | np.random.Generator | None = None,
):
    r"""Build the product space from an iterable of sub-spaces.

    The new instance stands for the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`.

    Args:
        spaces (Iterable[Space]): The spaces taking part in the cartesian product. The iterable is
            consumed once and stored as a tuple on ``self.spaces``.
        seed: Optional seed handed to the parent constructor, intended to make sampling reproducible.

    Raises:
        AssertionError: If any element of ``spaces`` is not a :class:`Space` instance.
    """
    # Materialise first so one-shot iterables (generators) can be validated and kept.
    members = tuple(spaces)
    self.spaces = members

    for space in members:
        is_space = isinstance(space, Space)
        assert (
            is_space
        ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"

    # A product space has no single shape or dtype, hence the two ``None`` arguments.
    super().__init__(None, None, seed)  # type: ignore
|
2019-01-30 22:39:55 +01:00
|
|
|
|
2022-08-15 17:11:32 +02:00
|
|
|
@property
def is_np_flattenable(self):
    """Tell whether this space can be flattened to a :class:`spaces.Box`.

    The product is flattenable only when every sub-space reports itself as flattenable;
    a product of zero sub-spaces is therefore reported as flattenable.
    """
    for subspace in self.spaces:
        if not subspace.is_np_flattenable:
            # A single non-flattenable member is enough to rule the whole product out.
            return False
    return True
|
|
|
|
|
2024-08-29 16:52:43 +01:00
|
|
|
def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]:
    """Seed the PRNG of this space and of every sub-space.

    How the sub-spaces are seeded depends on the type of ``seed``:

    * ``None`` - Every sub-space is seeded with ``None`` (a random initial seed).
    * ``Int`` - The integer seeds this :class:`Tuple` space, whose generator then draws one
      seed value per sub-space. Warning, the drawn values are not guaranteed to be unique.
    * ``List`` / ``Tuple`` - One entry per sub-space, passed through as-is, which allows
      seeding nested composite sub-spaces, e.g. ``[42, 54, ...]``.

    Args:
        seed: An optional int, or list/tuple of seeds, used to seed the (sub-)spaces.

    Returns:
        A tuple with the value returned by each sub-space's ``seed`` call.

    Raises:
        ValueError: If a list/tuple seed does not have one entry per sub-space.
        TypeError: If ``seed`` is not ``None``, an int, a list or a tuple.
    """
    if seed is None:
        return tuple([subspace.seed(None) for subspace in self.spaces])

    if isinstance(seed, int):
        # Seed this space first; its generator is the source of the per-sub-space seeds.
        super().seed(seed)
        drawn = self.np_random.integers(
            np.iinfo(np.int32).max, size=len(self.spaces)
        )
        results = []
        for subspace, value in zip(self.spaces, drawn):
            # Convert the NumPy scalar to a plain int before handing it on.
            results.append(subspace.seed(int(value)))
        return tuple(results)

    if isinstance(seed, (tuple, list)):
        if len(seed) != len(self.spaces):
            raise ValueError(
                f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
            )
        return tuple(
            [subspace.seed(value) for value, subspace in zip(seed, self.spaces)]
        )

    raise TypeError(
        f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
    )
|
2021-09-13 20:08:01 +02:00
|
|
|
|
2025-02-21 13:39:23 +00:00
|
|
|
def sample(
    self,
    mask: tuple[Any | None, ...] | None = None,
    probability: tuple[Any | None, ...] | None = None,
) -> tuple[Any, ...]:
    """Draw one random element of this space.

    Each sub-space is sampled independently and the results are collected in order.

    Args:
        mask: An optional tuple holding one optional mask per sub-space;
            its length must equal the number of sub-spaces.
        probability: An optional tuple holding one optional probability mask per sub-space;
            its length must equal the number of sub-spaces.

    Returns:
        Tuple of the sub-spaces' samples.

    Raises:
        ValueError: If both ``mask`` and ``probability`` are given.
        AssertionError: If the given ``mask`` / ``probability`` is not a tuple or has the wrong length.
    """
    # Unconstrained sampling: nothing to validate.
    if mask is None and probability is None:
        return tuple([subspace.sample() for subspace in self.spaces])

    # The two kinds of constraint are mutually exclusive.
    if mask is not None and probability is not None:
        raise ValueError(
            f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
        )

    if mask is not None:
        assert isinstance(
            mask, tuple
        ), f"Expected type of `mask` to be tuple, actual type: {type(mask)}"
        assert len(mask) == len(
            self.spaces
        ), f"Expected length of `mask` to be {len(self.spaces)}, actual length: {len(mask)}"

        masked = []
        for subspace, sub_mask in zip(self.spaces, mask):
            masked.append(subspace.sample(mask=sub_mask))
        return tuple(masked)

    # Only ``probability`` can be set at this point.
    assert isinstance(
        probability, tuple
    ), f"Expected type of `probability` to be tuple, actual type: {type(probability)}"
    assert len(probability) == len(
        self.spaces
    ), f"Expected length of `probability` to be {len(self.spaces)}, actual length: {len(probability)}"

    weighted = []
    for subspace, sub_probability in zip(self.spaces, probability):
        weighted.append(subspace.sample(probability=sub_probability))
    return tuple(weighted)
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2022-11-15 14:09:22 +00:00
|
|
|
def contains(self, x: Any) -> bool:
    """Return boolean specifying if x is a valid member of this space.

    Lists and NumPy arrays are promoted to tuples before checking. ``x`` is a member when it
    is (or was promoted to) a tuple with exactly one entry per sub-space and every entry is
    contained in the sub-space at the same position.

    Args:
        x: The candidate element.

    Returns:
        ``True`` if ``x`` belongs to the space, ``False`` otherwise.
    """
    if isinstance(x, np.ndarray):
        if x.ndim == 0:
            # A zero-dimensional array cannot be iterated (``tuple(x)`` raises TypeError)
            # and can never be a tuple of sub-space elements, so reject it outright.
            return False
        x = tuple(x)  # Promote ndarray to tuple for contains check
    elif isinstance(x, list):
        x = tuple(x)  # Promote list to tuple for contains check

    return (
        isinstance(x, tuple)
        and len(x) == len(self.spaces)
        and all(space.contains(part) for (space, part) in zip(self.spaces, x))
    )
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
def __repr__(self) -> str:
    """Give a string representation of this space.

    The result is ``Tuple(...)`` with the ``str`` of each sub-space, separated by ``", "``.
    """
    inner = ", ".join(map(str, self.spaces))
    return f"Tuple({inner})"
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2022-11-15 14:09:22 +00:00
|
|
|
def to_jsonable(
    self, sample_n: typing.Sequence[tuple[Any, ...]]
) -> list[list[Any]]:
    """Convert a batch of samples from this space to a JSONable data type.

    The batch is regrouped per sub-space: entry ``i`` of the result is what sub-space ``i``
    returns from ``to_jsonable`` for the ``i``-th component of every sample.
    """
    # serialize as list-repr of tuple of vectors
    per_space = []
    for position, subspace in enumerate(self.spaces):
        component = [entry[position] for entry in sample_n]
        per_space.append(subspace.to_jsonable(component))
    return per_space
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2022-11-15 14:09:22 +00:00
|
|
|
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
    """Convert a JSONable data type to a batch of samples from this space.

    ``sample_n[i]`` is decoded by sub-space ``i``; the decoded per-sub-space batches are then
    zipped back together so that each returned tuple is one sample of the product space.
    """
    decoded = []
    for position, subspace in enumerate(self.spaces):
        decoded.append(subspace.from_jsonable(sample_n[position]))
    # Transpose from "one batch per sub-space" to "one tuple per sample".
    return list(zip(*decoded))
|
2018-09-24 20:11:03 +02:00
|
|
|
|
2022-11-15 14:09:22 +00:00
|
|
|
def __getitem__(self, index: int) -> Space[Any]:
    """Get the subspace stored at position ``index`` of ``self.spaces``."""
    subspace = self.spaces[index]
    return subspace
|
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
def __len__(self) -> int:
    """Get the number of subspaces that make up the cartesian product."""
    subspace_count = len(self.spaces)
    return subspace_count
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2022-11-15 14:09:22 +00:00
|
|
|
def __eq__(self, other: Any) -> bool:
    """Check whether ``other`` is equivalent to this instance.

    Two instances are equivalent when ``other`` is also a :class:`Tuple` and both hold
    equal ``spaces`` tuples.
    """
    if not isinstance(other, Tuple):
        return False
    return self.spaces == other.spaces
|