typing in gym.spaces (#2541)

* typing in spaces.Box and spaces.Discrete

* adds typing to dict and tuple spaces

* Typecheck all spaces

* Explicit regex to include all files under space folder

* Style: use native types and __future__ annotations

* Allow only specific strings for Box.is_bounded args

* Add typing to changes from #2517

* Remove Literal as it's not supported by py3.7

* Use more recent version of pyright

* Avoid name clash for type checker

* Revert "Avoid name clash for type checker"

This reverts commit 1aaf3e0e0328171623a17a997b65fe734bc0afb1.

* Ignore the error. It's reported as probable bug at https://github.com/microsoft/pyright/issues/2852

* rebase and add typing for `_short_repr`
This commit is contained in:
Ilya Kamen
2022-01-24 23:22:11 +01:00
committed by GitHub
parent fcbff7de12
commit ad79b0ad0f
11 changed files with 199 additions and 146 deletions

View File

@@ -1,8 +1,9 @@
from typing import Iterable, List, Optional, Union
import numpy as np
from .space import Space
class Tuple(Space):
class Tuple(Space[tuple]):
"""
A tuple (i.e., product) of simpler spaces
@@ -10,16 +11,18 @@ class Tuple(Space):
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
"""
def __init__(self, spaces, seed=None):
def __init__(
self, spaces: Iterable[Space], seed: Optional[Union[int, List[int]]] = None
):
spaces = tuple(spaces)
self.spaces = spaces
for space in spaces:
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gym.Space"
super().__init__(None, None, seed)
super().__init__(None, None, seed) # type: ignore
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> list:
seeds = []
if isinstance(seed, list):
@@ -50,10 +53,10 @@ class Tuple(Space):
return seeds
def sample(self):
def sample(self) -> tuple:
return tuple(space.sample() for space in self.spaces)
def contains(self, x):
def contains(self, x) -> bool:
if isinstance(x, (list, np.ndarray)):
x = tuple(x) # Promote list and ndarray to tuple for contains check
return (
@@ -62,17 +65,17 @@ class Tuple(Space):
and all(space.contains(part) for (space, part) in zip(self.spaces, x))
)
def __repr__(self):
def __repr__(self) -> str:
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
def to_jsonable(self, sample_n):
def to_jsonable(self, sample_n) -> list:
# serialize as list-repr of tuple of vectors
return [
space.to_jsonable([sample[i] for sample in sample_n])
for i, space in enumerate(self.spaces)
]
def from_jsonable(self, sample_n):
def from_jsonable(self, sample_n) -> list:
return [
sample
for sample in zip(
@@ -83,11 +86,11 @@ class Tuple(Space):
)
]
def __getitem__(self, index):
def __getitem__(self, index: int) -> Space:
return self.spaces[index]
def __len__(self):
def __len__(self) -> int:
return len(self.spaces)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return isinstance(other, Tuple) and self.spaces == other.spaces