Files
Gymnasium/gym/spaces/tuple.py
Alexis DUBURCQ 108f32c743 gym.spaces.Tuple inherits from collections.abc.Sequence (#2637)
* gym.spaces.Tuple inherits from collections.abc.Sequence

Following the PR I did a few months back (https://github.com/openai/gym/pull/2446), the tuple wrapper of gym should inherit from Python's abstract interface. This is important for type checks via `isinstance`, and it enables using such objects transparently with libraries such as [dmtree](https://github.com/deepmind/tree).

It will bring in a few helper methods along the way, but this cannot be avoided for interoperability: `__iter__`, `__reversed__`, `index`, and `count`.
Personally, I don't think this is an issue, since these are new features and they do not conflict with anything existing.

As with the previous PR, this patch does NOT remove any existing feature and should not break backward compatibility.

* Add unit test

* Fix unit tests

* Final fix.

* Remove irrelevant comment.

* Fix black formatter.
2022-03-04 10:25:19 -05:00

98 lines
3.1 KiB
Python

from collections.abc import Sequence
from typing import Iterable, List, Optional, Union
import numpy as np
from .space import Space
class Tuple(Space[tuple], Sequence):
    """
    A tuple (i.e., product) of simpler spaces.

    The sub-spaces are kept, in order, in ``self.spaces`` (a Python tuple).
    The class implements ``__getitem__`` and ``__len__`` and derives from
    ``collections.abc.Sequence``, so instances can be indexed, measured with
    ``len`` and recognised by ``isinstance(..., Sequence)`` checks.

    Example usage:
    self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
    """

    def __init__(
        self, spaces: Iterable[Space], seed: Optional[Union[int, List[int]]] = None
    ):
        """
        Build the product space from an iterable of sub-spaces.

        :param spaces: iterable of ``Space`` instances; materialised into a tuple.
        :param seed: optional seed, forwarded unchanged to the base constructor.
        :raises AssertionError: if any element of ``spaces`` is not a ``Space``.
        """
        spaces = tuple(spaces)
        # Stored before the base constructor runs.
        # NOTE(review): presumably Space.__init__ may call self.seed(), which
        # reads self.spaces -- confirm before reordering these statements.
        self.spaces = spaces
        for space in spaces:
            assert isinstance(
                space, Space
            ), "Elements of the tuple must be instances of gym.Space"
        super().__init__(None, None, seed)  # type: ignore

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
        """
        Seed this space and/or its sub-spaces and return the seeds used.

        :param seed:
            * ``list`` -- ``seed[i]`` is passed to the i-th sub-space; the
              lists returned by the sub-spaces are concatenated.
            * ``int`` -- the space itself is seeded through the base class,
              then one sub-seed per sub-space is drawn from ``self.np_random``
              and the first seed reported by each sub-space is appended.
            * ``None`` -- every sub-space is seeded with ``None``; the lists
              returned by the sub-spaces are concatenated.
        :return: list of seeds collected as described above.
        :raises TypeError: if ``seed`` is not a list, an int or ``None``.
        """
        seeds = []

        if isinstance(seed, list):
            # Indexing (rather than zip) keeps the IndexError raised when the
            # list is shorter than the number of sub-spaces.
            for i, space in enumerate(self.spaces):
                seeds += space.seed(seed[i])
        elif isinstance(seed, int):
            seeds = super().seed(seed)
            try:
                subseeds = self.np_random.choice(
                    np.iinfo(int).max,
                    size=len(self.spaces),
                    replace=False,  # unique subseed for each subspace
                )
            except ValueError:
                # Drawing without replacement was rejected (original note:
                # more than INT_MAX subspaces); allow repeated sub-seeds.
                subseeds = self.np_random.choice(
                    np.iinfo(int).max,
                    size=len(self.spaces),
                    replace=True,
                )

            for subspace, subseed in zip(self.spaces, subseeds):
                # int(...) turns the drawn numpy integer into a plain Python int.
                seeds.append(subspace.seed(int(subseed))[0])
        elif seed is None:
            for space in self.spaces:
                seeds += space.seed(seed)
        else:
            raise TypeError("Passed seed not of an expected type: list or int or None")

        return seeds

    def sample(self) -> tuple:
        """Return a tuple holding one sample from each sub-space, in order."""
        return tuple(space.sample() for space in self.spaces)

    def contains(self, x) -> bool:
        """
        Tell whether ``x`` is a member of this product space.

        Lists and numpy arrays are promoted to tuples first. ``x`` is a member
        when it is a tuple with exactly one entry per sub-space and every
        entry is contained in the corresponding sub-space.

        :param x: candidate element.
        :return: ``True`` if ``x`` belongs to the space, ``False`` otherwise.
        """
        if isinstance(x, np.ndarray) and x.ndim == 0:
            # A 0-d array cannot be iterated, so tuple(x) would raise
            # TypeError; it has no per-sub-space entries and is not a member.
            return False
        if isinstance(x, (list, np.ndarray)):
            x = tuple(x)  # Promote list and ndarray to tuple for contains check
        return (
            isinstance(x, tuple)
            and len(x) == len(self.spaces)
            and all(space.contains(part) for (space, part) in zip(self.spaces, x))
        )

    def __repr__(self) -> str:
        """Return ``Tuple(<sub-space>, ...)`` built from ``str`` of each sub-space."""
        return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"

    def to_jsonable(self, sample_n) -> list:
        """
        Convert a batch of samples into a JSON-friendly structure.

        :param sample_n: sequence of samples, each indexable by sub-space position.
        :return: one entry per sub-space, obtained by handing that sub-space's
            ``to_jsonable`` the i-th component of every sample.
        """
        # serialize as list-repr of tuple of vectors
        return [
            space.to_jsonable([sample[i] for sample in sample_n])
            for i, space in enumerate(self.spaces)
        ]

    def from_jsonable(self, sample_n) -> list:
        """
        Inverse of ``to_jsonable``: rebuild the list of per-sample tuples.

        :param sample_n: one entry per sub-space, indexable by sub-space position.
        :return: list of tuples, each grouping the i-th decoded value of every
            sub-space.
        """
        per_space = [
            space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)
        ]
        # zip(*...) regroups the per-sub-space columns into per-sample tuples.
        return list(zip(*per_space))

    def __getitem__(self, index: int) -> Space:
        """Return ``self.spaces[index]``."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Return the number of sub-spaces."""
        return len(self.spaces)

    def __eq__(self, other) -> bool:
        """Two ``Tuple`` spaces are equal when their ``spaces`` tuples compare equal."""
        return isinstance(other, Tuple) and self.spaces == other.spaces