mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
typing in gym.spaces (#2541)
* typing in spaces.Box and spaces.Discrete * adds typing to dict and tuple spaces * Typecheck all spaces * Explicit regex to include all files under space folder * Style: use native types and __future__ annotations * Allow only specific strings for Box.is_bounded args * Add typing to changes from #2517 * Remove Literal as it's not supported by py3.7 * Use more recent version of pyright * Avoid name clash for type checker * Revert "Avoid name clash for type checker" This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1. * Ignore the error. It's reported as probable bug at https://github.com/microsoft/pyright/issues/2852 * rebase and add typing for `_short_repr`
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from typing import Iterable, List, Optional, Union
|
||||
import numpy as np
|
||||
from .space import Space
|
||||
|
||||
|
||||
class Tuple(Space):
|
||||
class Tuple(Space[tuple]):
|
||||
"""
|
||||
A tuple (i.e., product) of simpler spaces
|
||||
|
||||
@@ -10,16 +11,18 @@ class Tuple(Space):
|
||||
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
|
||||
"""
|
||||
|
||||
def __init__(self, spaces, seed=None):
|
||||
def __init__(
|
||||
self, spaces: Iterable[Space], seed: Optional[Union[int, List[int]]] = None
|
||||
):
|
||||
spaces = tuple(spaces)
|
||||
self.spaces = spaces
|
||||
for space in spaces:
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), "Elements of the tuple must be instances of gym.Space"
|
||||
super().__init__(None, None, seed)
|
||||
super().__init__(None, None, seed) # type: ignore
|
||||
|
||||
def seed(self, seed=None):
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
|
||||
seeds = []
|
||||
|
||||
if isinstance(seed, list):
|
||||
@@ -50,10 +53,10 @@ class Tuple(Space):
|
||||
|
||||
return seeds
|
||||
|
||||
def sample(self):
|
||||
def sample(self) -> tuple:
|
||||
return tuple(space.sample() for space in self.spaces)
|
||||
|
||||
def contains(self, x):
|
||||
def contains(self, x) -> bool:
|
||||
if isinstance(x, (list, np.ndarray)):
|
||||
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
||||
return (
|
||||
@@ -62,17 +65,17 @@ class Tuple(Space):
|
||||
and all(space.contains(part) for (space, part) in zip(self.spaces, x))
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
def to_jsonable(self, sample_n) -> list:
|
||||
# serialize as list-repr of tuple of vectors
|
||||
return [
|
||||
space.to_jsonable([sample[i] for sample in sample_n])
|
||||
for i, space in enumerate(self.spaces)
|
||||
]
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
def from_jsonable(self, sample_n) -> list:
|
||||
return [
|
||||
sample
|
||||
for sample in zip(
|
||||
@@ -83,11 +86,11 @@ class Tuple(Space):
|
||||
)
|
||||
]
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index: int) -> Space:
|
||||
return self.spaces[index]
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.spaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return isinstance(other, Tuple) and self.spaces == other.spaces
|
||||
|
Reference in New Issue
Block a user