mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 15:11:51 +00:00
Fix type hints errors in gymnasium/spaces (#327)
This commit is contained in:
@@ -69,8 +69,8 @@ class Box(Space[NDArray[Any]]):
|
||||
this value across all dimensions.
|
||||
|
||||
Args:
|
||||
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
|
||||
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
|
||||
low (SupportsFloat | np.ndarray): Lower bounds of the intervals.
|
||||
high (SupportsFloat | np.ndarray]): Upper bounds of the intervals.
|
||||
shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
|
||||
`low` and `high` scalars defaulting to a shape of (1,)
|
||||
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
|
||||
@@ -104,12 +104,13 @@ class Box(Space[NDArray[Any]]):
|
||||
|
||||
# Capture the boundedness information before replacing np.inf with get_inf
|
||||
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
|
||||
self.bounded_below: bool = -np.inf < _low
|
||||
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
||||
self.bounded_above: bool = np.inf > _high
|
||||
self.bounded_below: NDArray[np.bool_] = -np.inf < _low
|
||||
|
||||
low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-")
|
||||
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+")
|
||||
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
||||
self.bounded_above: NDArray[np.bool_] = np.inf > _high
|
||||
|
||||
low = _broadcast(low, self.dtype, shape, inf_sign="-")
|
||||
high = _broadcast(high, self.dtype, shape, inf_sign="+")
|
||||
|
||||
assert isinstance(low, np.ndarray)
|
||||
assert (
|
||||
@@ -280,7 +281,7 @@ class Box(Space[NDArray[Any]]):
|
||||
self.high_repr = _short_repr(self.high)
|
||||
|
||||
|
||||
def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat:
|
||||
def get_inf(dtype: np.dtype, sign: str) -> int | float:
|
||||
"""Returns an infinite that doesn't break things.
|
||||
|
||||
Args:
|
||||
|
@@ -110,7 +110,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
||||
|
||||
# None for shape and dtype, since it'll require special handling
|
||||
super().__init__(None, None, seed)
|
||||
super().__init__(None, None, seed) # type: ignore
|
||||
|
||||
@property
|
||||
def is_np_flattenable(self):
|
||||
@@ -226,7 +226,9 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||
for key, space in self.spaces.items()
|
||||
}
|
||||
|
||||
def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
|
||||
def from_jsonable(
|
||||
self, sample_n: dict[str, list[Any]]
|
||||
) -> list[OrderedDict[str, Any]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
dict_of_list: dict[str, list[Any]] = {
|
||||
key: space.from_jsonable(sample_n[key])
|
||||
|
@@ -55,7 +55,7 @@ class Discrete(Space[np.int64]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: MaskNDArray | None = None) -> int:
|
||||
def sample(self, mask: MaskNDArray | None = None) -> np.int64:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
A sample will be chosen uniformly at random with the mask if provided
|
||||
|
@@ -230,9 +230,9 @@ class Graph(Space[GraphInstance]):
|
||||
|
||||
def to_jsonable(
|
||||
self, sample_n: Sequence[GraphInstance]
|
||||
) -> list[dict[str, list[int] | list[float]]]:
|
||||
) -> list[dict[str, list[int | float]]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
ret_n: list[dict[str, list[int | float]]] = []
|
||||
ret_n = []
|
||||
for sample in sample_n:
|
||||
ret = {"nodes": sample.nodes.tolist()}
|
||||
if sample.edges is not None and sample.edge_links is not None:
|
||||
|
@@ -4,12 +4,12 @@ from __future__ import annotations
|
||||
from typing import Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from gymnasium.spaces.space import MaskNDArray, Space
|
||||
|
||||
|
||||
class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
class MultiBinary(Space[NDArray[np.int8]]):
|
||||
"""An n-shape binary space.
|
||||
|
||||
Elements of this space are binary arrays of a shape that is fixed during construction.
|
||||
@@ -28,7 +28,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
|
||||
n: NDArray[np.integer[Any]] | Sequence[int] | int,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
"""Constructor of :class:`MultiBinary` space.
|
||||
@@ -58,7 +58,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
|
||||
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||
@@ -104,15 +104,11 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
and np.all(np.logical_or(x == 0, x == 1))
|
||||
)
|
||||
|
||||
def to_jsonable(
|
||||
self, sample_n: Sequence[npt.NDArray[np.int8]]
|
||||
) -> list[Sequence[int]]:
|
||||
def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(
|
||||
self, sample_n: list[Sequence[int]]
|
||||
) -> list[npt.NDArray[np.int8]]:
|
||||
def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
||||
|
||||
|
@@ -4,14 +4,14 @@ from __future__ import annotations
|
||||
from typing import Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces.discrete import Discrete
|
||||
from gymnasium.spaces.space import MaskNDArray, Space
|
||||
|
||||
|
||||
class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
class MultiDiscrete(Space[NDArray[np.integer]]):
|
||||
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
||||
|
||||
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
|
||||
@@ -41,7 +41,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nvec: npt.NDArray[np.integer[Any]] | list[int],
|
||||
nvec: NDArray[np.integer[Any]] | list[int],
|
||||
dtype: str | type[np.integer[Any]] = np.int64,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
@@ -72,7 +72,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
|
||||
def sample(
|
||||
self, mask: tuple[MaskNDArray, ...] | None = None
|
||||
) -> npt.NDArray[np.integer[Any]]:
|
||||
) -> NDArray[np.integer[Any]]:
|
||||
"""Generates a single random sample this space.
|
||||
|
||||
Args:
|
||||
@@ -88,7 +88,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
def _apply_mask(
|
||||
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
|
||||
sub_nvec: MaskNDArray | np.integer[Any],
|
||||
) -> int | Sequence[int]:
|
||||
) -> int | list[Any]:
|
||||
if isinstance(sub_nvec, np.ndarray):
|
||||
assert isinstance(
|
||||
sub_mask, tuple
|
||||
@@ -144,14 +144,14 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
)
|
||||
|
||||
def to_jsonable(
|
||||
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]]
|
||||
self, sample_n: Sequence[NDArray[np.integer[Any]]]
|
||||
) -> list[Sequence[int]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return [sample.tolist() for sample in sample_n]
|
||||
|
||||
def from_jsonable(
|
||||
self, sample_n: list[Sequence[int]]
|
||||
) -> list[npt.NDArray[np.integer[Any]]]:
|
||||
) -> list[NDArray[np.integer[Any]]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.array(sample) for sample in sample_n]
|
||||
|
||||
|
@@ -5,7 +5,7 @@ import typing
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces.space import Space
|
||||
@@ -69,11 +69,11 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
mask: None
|
||||
| (
|
||||
tuple[
|
||||
None | np.integer | npt.NDArray[np.integer],
|
||||
None | np.integer | NDArray[np.integer],
|
||||
Any,
|
||||
]
|
||||
) = None,
|
||||
) -> tuple[Any]:
|
||||
) -> tuple[Any] | Any:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
Args:
|
||||
|
@@ -104,8 +104,8 @@ class Space(Generic[T_cov]):
|
||||
|
||||
def seed(self, seed: int | None = None) -> list[int]:
|
||||
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
self._np_random, np_random_seed = seeding.np_random(seed)
|
||||
return [np_random_seed]
|
||||
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
@@ -35,7 +35,7 @@ class Text(Space[str]):
|
||||
max_length: int,
|
||||
*,
|
||||
min_length: int = 1,
|
||||
charset: set[str] | str = alphanumeric,
|
||||
charset: frozenset[str] | str = alphanumeric,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
r"""Constructor of :class:`Text` space.
|
||||
@@ -76,7 +76,7 @@ class Text(Space[str]):
|
||||
|
||||
def sample(
|
||||
self,
|
||||
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
|
||||
mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
|
||||
) -> str:
|
||||
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
|
||||
|
||||
|
@@ -184,7 +184,7 @@ def _flatten_multidiscrete(
|
||||
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
|
||||
if space.is_np_flattenable:
|
||||
return np.concatenate(
|
||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
||||
[np.array(flatten(s, x_part)) for x_part, s in zip(x, space.spaces)]
|
||||
)
|
||||
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))
|
||||
|
||||
@@ -192,7 +192,9 @@ def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArra
|
||||
@flatten.register(Dict)
|
||||
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
|
||||
if space.is_np_flattenable:
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
return np.concatenate(
|
||||
[np.array(flatten(s, x[key])) for key, s in space.spaces.items()]
|
||||
)
|
||||
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user