Fix type hints errors in gymnasium/spaces (#327)

This commit is contained in:
Valentin
2023-02-13 18:18:40 +01:00
committed by GitHub
parent d101d389dc
commit f6d41e85f9
10 changed files with 41 additions and 40 deletions

View File

@@ -69,8 +69,8 @@ class Box(Space[NDArray[Any]]):
this value across all dimensions. this value across all dimensions.
Args: Args:
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals. low (SupportsFloat | np.ndarray): Lower bounds of the intervals.
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals. high (SupportsFloat | np.ndarray]): Upper bounds of the intervals.
shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
`low` and `high` scalars defaulting to a shape of (1,) `low` and `high` scalars defaulting to a shape of (1,)
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space. dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
@@ -104,12 +104,13 @@ class Box(Space[NDArray[Any]]):
# Capture the boundedness information before replacing np.inf with get_inf # Capture the boundedness information before replacing np.inf with get_inf
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low _low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
self.bounded_below: bool = -np.inf < _low self.bounded_below: NDArray[np.bool_] = -np.inf < _low
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
self.bounded_above: bool = np.inf > _high
low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-") _high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+") self.bounded_above: NDArray[np.bool_] = np.inf > _high
low = _broadcast(low, self.dtype, shape, inf_sign="-")
high = _broadcast(high, self.dtype, shape, inf_sign="+")
assert isinstance(low, np.ndarray) assert isinstance(low, np.ndarray)
assert ( assert (
@@ -280,7 +281,7 @@ class Box(Space[NDArray[Any]]):
self.high_repr = _short_repr(self.high) self.high_repr = _short_repr(self.high)
def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat: def get_inf(dtype: np.dtype, sign: str) -> int | float:
"""Returns an infinite that doesn't break things. """Returns an infinite that doesn't break things.
Args: Args:

View File

@@ -110,7 +110,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
), f"Dict space element is not an instance of Space: key='{key}', space={space}" ), f"Dict space element is not an instance of Space: key='{key}', space={space}"
# None for shape and dtype, since it'll require special handling # None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed) super().__init__(None, None, seed) # type: ignore
@property @property
def is_np_flattenable(self): def is_np_flattenable(self):
@@ -226,7 +226,9 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
for key, space in self.spaces.items() for key, space in self.spaces.items()
} }
def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]: def from_jsonable(
self, sample_n: dict[str, list[Any]]
) -> list[OrderedDict[str, Any]]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
dict_of_list: dict[str, list[Any]] = { dict_of_list: dict[str, list[Any]] = {
key: space.from_jsonable(sample_n[key]) key: space.from_jsonable(sample_n[key])

View File

@@ -55,7 +55,7 @@ class Discrete(Space[np.int64]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" """Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True return True
def sample(self, mask: MaskNDArray | None = None) -> int: def sample(self, mask: MaskNDArray | None = None) -> np.int64:
"""Generates a single random sample from this space. """Generates a single random sample from this space.
A sample will be chosen uniformly at random with the mask if provided A sample will be chosen uniformly at random with the mask if provided

View File

@@ -230,9 +230,9 @@ class Graph(Space[GraphInstance]):
def to_jsonable( def to_jsonable(
self, sample_n: Sequence[GraphInstance] self, sample_n: Sequence[GraphInstance]
) -> list[dict[str, list[int] | list[float]]]: ) -> list[dict[str, list[int | float]]]:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
ret_n: list[dict[str, list[int | float]]] = [] ret_n = []
for sample in sample_n: for sample in sample_n:
ret = {"nodes": sample.nodes.tolist()} ret = {"nodes": sample.nodes.tolist()}
if sample.edges is not None and sample.edge_links is not None: if sample.edges is not None and sample.edge_links is not None:

View File

@@ -4,12 +4,12 @@ from __future__ import annotations
from typing import Any, Sequence from typing import Any, Sequence
import numpy as np import numpy as np
import numpy.typing as npt from numpy.typing import NDArray
from gymnasium.spaces.space import MaskNDArray, Space from gymnasium.spaces.space import MaskNDArray, Space
class MultiBinary(Space[npt.NDArray[np.int8]]): class MultiBinary(Space[NDArray[np.int8]]):
"""An n-shape binary space. """An n-shape binary space.
Elements of this space are binary arrays of a shape that is fixed during construction. Elements of this space are binary arrays of a shape that is fixed during construction.
@@ -28,7 +28,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
def __init__( def __init__(
self, self,
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int, n: NDArray[np.integer[Any]] | Sequence[int] | int,
seed: int | np.random.Generator | None = None, seed: int | np.random.Generator | None = None,
): ):
"""Constructor of :class:`MultiBinary` space. """Constructor of :class:`MultiBinary` space.
@@ -58,7 +58,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" """Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True return True
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]: def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
"""Generates a single random sample from this space. """Generates a single random sample from this space.
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space). A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
@@ -104,15 +104,11 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
and np.all(np.logical_or(x == 0, x == 1)) and np.all(np.logical_or(x == 0, x == 1))
) )
def to_jsonable( def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
self, sample_n: Sequence[npt.NDArray[np.int8]]
) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist() return np.array(sample_n).tolist()
def from_jsonable( def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.int8]]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample, self.dtype) for sample in sample_n] return [np.asarray(sample, self.dtype) for sample in sample_n]

View File

@@ -4,14 +4,14 @@ from __future__ import annotations
from typing import Any, Sequence from typing import Any, Sequence
import numpy as np import numpy as np
import numpy.typing as npt from numpy.typing import NDArray
import gymnasium as gym import gymnasium as gym
from gymnasium.spaces.discrete import Discrete from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.space import MaskNDArray, Space from gymnasium.spaces.space import MaskNDArray, Space
class MultiDiscrete(Space[npt.NDArray[np.integer]]): class MultiDiscrete(Space[NDArray[np.integer]]):
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces. """This represents the cartesian product of arbitrary :class:`Discrete` spaces.
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space. It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
@@ -41,7 +41,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
def __init__( def __init__(
self, self,
nvec: npt.NDArray[np.integer[Any]] | list[int], nvec: NDArray[np.integer[Any]] | list[int],
dtype: str | type[np.integer[Any]] = np.int64, dtype: str | type[np.integer[Any]] = np.int64,
seed: int | np.random.Generator | None = None, seed: int | np.random.Generator | None = None,
): ):
@@ -72,7 +72,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
def sample( def sample(
self, mask: tuple[MaskNDArray, ...] | None = None self, mask: tuple[MaskNDArray, ...] | None = None
) -> npt.NDArray[np.integer[Any]]: ) -> NDArray[np.integer[Any]]:
"""Generates a single random sample this space. """Generates a single random sample this space.
Args: Args:
@@ -88,7 +88,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
def _apply_mask( def _apply_mask(
sub_mask: MaskNDArray | tuple[MaskNDArray, ...], sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
sub_nvec: MaskNDArray | np.integer[Any], sub_nvec: MaskNDArray | np.integer[Any],
) -> int | Sequence[int]: ) -> int | list[Any]:
if isinstance(sub_nvec, np.ndarray): if isinstance(sub_nvec, np.ndarray):
assert isinstance( assert isinstance(
sub_mask, tuple sub_mask, tuple
@@ -144,14 +144,14 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
) )
def to_jsonable( def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]] self, sample_n: Sequence[NDArray[np.integer[Any]]]
) -> list[Sequence[int]]: ) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
return [sample.tolist() for sample in sample_n] return [sample.tolist() for sample in sample_n]
def from_jsonable( def from_jsonable(
self, sample_n: list[Sequence[int]] self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.integer[Any]]]: ) -> list[NDArray[np.integer[Any]]]:
"""Convert a JSONable data type to a batch of samples from this space.""" """Convert a JSONable data type to a batch of samples from this space."""
return [np.array(sample) for sample in sample_n] return [np.array(sample) for sample in sample_n]

View File

@@ -5,7 +5,7 @@ import typing
from typing import Any, Union from typing import Any, Union
import numpy as np import numpy as np
import numpy.typing as npt from numpy.typing import NDArray
import gymnasium as gym import gymnasium as gym
from gymnasium.spaces.space import Space from gymnasium.spaces.space import Space
@@ -69,11 +69,11 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
mask: None mask: None
| ( | (
tuple[ tuple[
None | np.integer | npt.NDArray[np.integer], None | np.integer | NDArray[np.integer],
Any, Any,
] ]
) = None, ) = None,
) -> tuple[Any]: ) -> tuple[Any] | Any:
"""Generates a single random sample from this space. """Generates a single random sample from this space.
Args: Args:

View File

@@ -104,8 +104,8 @@ class Space(Generic[T_cov]):
def seed(self, seed: int | None = None) -> list[int]: def seed(self, seed: int | None = None) -> list[int]:
"""Seed the PRNG of this space and possibly the PRNGs of subspaces.""" """Seed the PRNG of this space and possibly the PRNGs of subspaces."""
self._np_random, seed = seeding.np_random(seed) self._np_random, np_random_seed = seeding.np_random(seed)
return [seed] return [np_random_seed]
def contains(self, x: Any) -> bool: def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space.""" """Return boolean specifying if x is a valid member of this space."""

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any from typing import Any
import numpy as np import numpy as np
import numpy.typing as npt from numpy.typing import NDArray
from gymnasium.spaces.space import Space from gymnasium.spaces.space import Space
@@ -35,7 +35,7 @@ class Text(Space[str]):
max_length: int, max_length: int,
*, *,
min_length: int = 1, min_length: int = 1,
charset: set[str] | str = alphanumeric, charset: frozenset[str] | str = alphanumeric,
seed: int | np.random.Generator | None = None, seed: int | np.random.Generator | None = None,
): ):
r"""Constructor of :class:`Text` space. r"""Constructor of :class:`Text` space.
@@ -76,7 +76,7 @@ class Text(Space[str]):
def sample( def sample(
self, self,
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None, mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
) -> str: ) -> str:
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`. """Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.

View File

@@ -184,7 +184,7 @@ def _flatten_multidiscrete(
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]: def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
if space.is_np_flattenable: if space.is_np_flattenable:
return np.concatenate( return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)] [np.array(flatten(s, x_part)) for x_part, s in zip(x, space.spaces)]
) )
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces)) return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))
@@ -192,7 +192,9 @@ def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArra
@flatten.register(Dict) @flatten.register(Dict)
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]: def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
if space.is_np_flattenable: if space.is_np_flattenable:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) return np.concatenate(
[np.array(flatten(s, x[key])) for key, s in space.spaces.items()]
)
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items()) return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())