mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Fix type hints errors in gymnasium/spaces (#327)
This commit is contained in:
@@ -4,12 +4,12 @@ from __future__ import annotations
|
||||
from typing import Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from gymnasium.spaces.space import MaskNDArray, Space
|
||||
|
||||
|
||||
class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
class MultiBinary(Space[NDArray[np.int8]]):
|
||||
"""An n-shape binary space.
|
||||
|
||||
Elements of this space are binary arrays of a shape that is fixed during construction.
|
||||
@@ -28,7 +28,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
|
||||
n: NDArray[np.integer[Any]] | Sequence[int] | int,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
"""Constructor of :class:`MultiBinary` space.
|
||||
@@ -58,7 +58,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
|
||||
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||
@@ -104,15 +104,11 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
and np.all(np.logical_or(x == 0, x == 1))
|
||||
)
|
||||
|
||||
def to_jsonable(
|
||||
self, sample_n: Sequence[npt.NDArray[np.int8]]
|
||||
) -> list[Sequence[int]]:
|
||||
def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(
|
||||
self, sample_n: list[Sequence[int]]
|
||||
) -> list[npt.NDArray[np.int8]]:
|
||||
def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
||||
|
||||
|
Reference in New Issue
Block a user