mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Fix type hints errors in gymnasium/spaces (#327)
This commit is contained in:
@@ -69,8 +69,8 @@ class Box(Space[NDArray[Any]]):
|
|||||||
this value across all dimensions.
|
this value across all dimensions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
|
low (SupportsFloat | np.ndarray): Lower bounds of the intervals.
|
||||||
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
|
high (SupportsFloat | np.ndarray]): Upper bounds of the intervals.
|
||||||
shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
|
shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
|
||||||
`low` and `high` scalars defaulting to a shape of (1,)
|
`low` and `high` scalars defaulting to a shape of (1,)
|
||||||
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
|
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
|
||||||
@@ -104,12 +104,13 @@ class Box(Space[NDArray[Any]]):
|
|||||||
|
|
||||||
# Capture the boundedness information before replacing np.inf with get_inf
|
# Capture the boundedness information before replacing np.inf with get_inf
|
||||||
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
|
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
|
||||||
self.bounded_below: bool = -np.inf < _low
|
self.bounded_below: NDArray[np.bool_] = -np.inf < _low
|
||||||
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
|
||||||
self.bounded_above: bool = np.inf > _high
|
|
||||||
|
|
||||||
low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-")
|
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
||||||
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+")
|
self.bounded_above: NDArray[np.bool_] = np.inf > _high
|
||||||
|
|
||||||
|
low = _broadcast(low, self.dtype, shape, inf_sign="-")
|
||||||
|
high = _broadcast(high, self.dtype, shape, inf_sign="+")
|
||||||
|
|
||||||
assert isinstance(low, np.ndarray)
|
assert isinstance(low, np.ndarray)
|
||||||
assert (
|
assert (
|
||||||
@@ -280,7 +281,7 @@ class Box(Space[NDArray[Any]]):
|
|||||||
self.high_repr = _short_repr(self.high)
|
self.high_repr = _short_repr(self.high)
|
||||||
|
|
||||||
|
|
||||||
def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat:
|
def get_inf(dtype: np.dtype, sign: str) -> int | float:
|
||||||
"""Returns an infinite that doesn't break things.
|
"""Returns an infinite that doesn't break things.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -110,7 +110,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
|||||||
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
||||||
|
|
||||||
# None for shape and dtype, since it'll require special handling
|
# None for shape and dtype, since it'll require special handling
|
||||||
super().__init__(None, None, seed)
|
super().__init__(None, None, seed) # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_np_flattenable(self):
|
def is_np_flattenable(self):
|
||||||
@@ -226,7 +226,9 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
|||||||
for key, space in self.spaces.items()
|
for key, space in self.spaces.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
|
def from_jsonable(
|
||||||
|
self, sample_n: dict[str, list[Any]]
|
||||||
|
) -> list[OrderedDict[str, Any]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
dict_of_list: dict[str, list[Any]] = {
|
dict_of_list: dict[str, list[Any]] = {
|
||||||
key: space.from_jsonable(sample_n[key])
|
key: space.from_jsonable(sample_n[key])
|
||||||
|
@@ -55,7 +55,7 @@ class Discrete(Space[np.int64]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def sample(self, mask: MaskNDArray | None = None) -> int:
|
def sample(self, mask: MaskNDArray | None = None) -> np.int64:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
A sample will be chosen uniformly at random with the mask if provided
|
A sample will be chosen uniformly at random with the mask if provided
|
||||||
|
@@ -230,9 +230,9 @@ class Graph(Space[GraphInstance]):
|
|||||||
|
|
||||||
def to_jsonable(
|
def to_jsonable(
|
||||||
self, sample_n: Sequence[GraphInstance]
|
self, sample_n: Sequence[GraphInstance]
|
||||||
) -> list[dict[str, list[int] | list[float]]]:
|
) -> list[dict[str, list[int | float]]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
ret_n: list[dict[str, list[int | float]]] = []
|
ret_n = []
|
||||||
for sample in sample_n:
|
for sample in sample_n:
|
||||||
ret = {"nodes": sample.nodes.tolist()}
|
ret = {"nodes": sample.nodes.tolist()}
|
||||||
if sample.edges is not None and sample.edge_links is not None:
|
if sample.edges is not None and sample.edge_links is not None:
|
||||||
|
@@ -4,12 +4,12 @@ from __future__ import annotations
|
|||||||
from typing import Any, Sequence
|
from typing import Any, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from gymnasium.spaces.space import MaskNDArray, Space
|
from gymnasium.spaces.space import MaskNDArray, Space
|
||||||
|
|
||||||
|
|
||||||
class MultiBinary(Space[npt.NDArray[np.int8]]):
|
class MultiBinary(Space[NDArray[np.int8]]):
|
||||||
"""An n-shape binary space.
|
"""An n-shape binary space.
|
||||||
|
|
||||||
Elements of this space are binary arrays of a shape that is fixed during construction.
|
Elements of this space are binary arrays of a shape that is fixed during construction.
|
||||||
@@ -28,7 +28,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
|
n: NDArray[np.integer[Any]] | Sequence[int] | int,
|
||||||
seed: int | np.random.Generator | None = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`MultiBinary` space.
|
"""Constructor of :class:`MultiBinary` space.
|
||||||
@@ -58,7 +58,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
|
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||||
@@ -104,15 +104,11 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
|||||||
and np.all(np.logical_or(x == 0, x == 1))
|
and np.all(np.logical_or(x == 0, x == 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(
|
def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
|
||||||
self, sample_n: Sequence[npt.NDArray[np.int8]]
|
|
||||||
) -> list[Sequence[int]]:
|
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return np.array(sample_n).tolist()
|
return np.array(sample_n).tolist()
|
||||||
|
|
||||||
def from_jsonable(
|
def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
|
||||||
self, sample_n: list[Sequence[int]]
|
|
||||||
) -> list[npt.NDArray[np.int8]]:
|
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
||||||
|
|
||||||
|
@@ -4,14 +4,14 @@ from __future__ import annotations
|
|||||||
from typing import Any, Sequence
|
from typing import Any, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.spaces.discrete import Discrete
|
from gymnasium.spaces.discrete import Discrete
|
||||||
from gymnasium.spaces.space import MaskNDArray, Space
|
from gymnasium.spaces.space import MaskNDArray, Space
|
||||||
|
|
||||||
|
|
||||||
class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
class MultiDiscrete(Space[NDArray[np.integer]]):
|
||||||
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
||||||
|
|
||||||
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
|
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
|
||||||
@@ -41,7 +41,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
nvec: npt.NDArray[np.integer[Any]] | list[int],
|
nvec: NDArray[np.integer[Any]] | list[int],
|
||||||
dtype: str | type[np.integer[Any]] = np.int64,
|
dtype: str | type[np.integer[Any]] = np.int64,
|
||||||
seed: int | np.random.Generator | None = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
@@ -72,7 +72,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
|||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self, mask: tuple[MaskNDArray, ...] | None = None
|
self, mask: tuple[MaskNDArray, ...] | None = None
|
||||||
) -> npt.NDArray[np.integer[Any]]:
|
) -> NDArray[np.integer[Any]]:
|
||||||
"""Generates a single random sample this space.
|
"""Generates a single random sample this space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -88,7 +88,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
|||||||
def _apply_mask(
|
def _apply_mask(
|
||||||
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
|
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
|
||||||
sub_nvec: MaskNDArray | np.integer[Any],
|
sub_nvec: MaskNDArray | np.integer[Any],
|
||||||
) -> int | Sequence[int]:
|
) -> int | list[Any]:
|
||||||
if isinstance(sub_nvec, np.ndarray):
|
if isinstance(sub_nvec, np.ndarray):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
sub_mask, tuple
|
sub_mask, tuple
|
||||||
@@ -144,14 +144,14 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(
|
def to_jsonable(
|
||||||
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]]
|
self, sample_n: Sequence[NDArray[np.integer[Any]]]
|
||||||
) -> list[Sequence[int]]:
|
) -> list[Sequence[int]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return [sample.tolist() for sample in sample_n]
|
return [sample.tolist() for sample in sample_n]
|
||||||
|
|
||||||
def from_jsonable(
|
def from_jsonable(
|
||||||
self, sample_n: list[Sequence[int]]
|
self, sample_n: list[Sequence[int]]
|
||||||
) -> list[npt.NDArray[np.integer[Any]]]:
|
) -> list[NDArray[np.integer[Any]]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [np.array(sample) for sample in sample_n]
|
return [np.array(sample) for sample in sample_n]
|
||||||
|
|
||||||
|
@@ -5,7 +5,7 @@ import typing
|
|||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import Space
|
||||||
@@ -69,11 +69,11 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
|||||||
mask: None
|
mask: None
|
||||||
| (
|
| (
|
||||||
tuple[
|
tuple[
|
||||||
None | np.integer | npt.NDArray[np.integer],
|
None | np.integer | NDArray[np.integer],
|
||||||
Any,
|
Any,
|
||||||
]
|
]
|
||||||
) = None,
|
) = None,
|
||||||
) -> tuple[Any]:
|
) -> tuple[Any] | Any:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -104,8 +104,8 @@ class Space(Generic[T_cov]):
|
|||||||
|
|
||||||
def seed(self, seed: int | None = None) -> list[int]:
|
def seed(self, seed: int | None = None) -> list[int]:
|
||||||
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
||||||
self._np_random, seed = seeding.np_random(seed)
|
self._np_random, np_random_seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [np_random_seed]
|
||||||
|
|
||||||
def contains(self, x: Any) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import Space
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ class Text(Space[str]):
|
|||||||
max_length: int,
|
max_length: int,
|
||||||
*,
|
*,
|
||||||
min_length: int = 1,
|
min_length: int = 1,
|
||||||
charset: set[str] | str = alphanumeric,
|
charset: frozenset[str] | str = alphanumeric,
|
||||||
seed: int | np.random.Generator | None = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Text` space.
|
r"""Constructor of :class:`Text` space.
|
||||||
@@ -76,7 +76,7 @@ class Text(Space[str]):
|
|||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
|
mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
|
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
|
||||||
|
|
||||||
|
@@ -184,7 +184,7 @@ def _flatten_multidiscrete(
|
|||||||
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
|
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
return np.concatenate(
|
return np.concatenate(
|
||||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
[np.array(flatten(s, x_part)) for x_part, s in zip(x, space.spaces)]
|
||||||
)
|
)
|
||||||
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))
|
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))
|
||||||
|
|
||||||
@@ -192,7 +192,9 @@ def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArra
|
|||||||
@flatten.register(Dict)
|
@flatten.register(Dict)
|
||||||
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
|
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
return np.concatenate(
|
||||||
|
[np.array(flatten(s, x[key])) for key, s in space.spaces.items()]
|
||||||
|
)
|
||||||
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user