Fix type hints errors in gymnasium/spaces (#327)

This commit is contained in:
Valentin
2023-02-13 18:18:40 +01:00
committed by GitHub
parent d101d389dc
commit f6d41e85f9
10 changed files with 41 additions and 40 deletions

View File

@@ -4,12 +4,12 @@ from __future__ import annotations
from typing import Any, Sequence
import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray
from gymnasium.spaces.space import MaskNDArray, Space
class MultiBinary(Space[npt.NDArray[np.int8]]):
class MultiBinary(Space[NDArray[np.int8]]):
"""An n-shape binary space.
Elements of this space are binary arrays of a shape that is fixed during construction.
@@ -28,7 +28,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
def __init__(
self,
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
n: NDArray[np.integer[Any]] | Sequence[int] | int,
seed: int | np.random.Generator | None = None,
):
"""Constructor of :class:`MultiBinary` space.
@@ -58,7 +58,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
"""Generates a single random sample from this space.
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
@@ -104,15 +104,11 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
and np.all(np.logical_or(x == 0, x == 1))
)
def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.int8]]
) -> list[Sequence[int]]:
def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()
def from_jsonable(
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.int8]]:
def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample, self.dtype) for sample in sample_n]