Update the spaces for complete type hinting (and updates numpy to 1.21) (#37)

This commit is contained in:
Mark Towers
2022-11-15 14:09:22 +00:00
committed by GitHub
parent 3f611a8c2e
commit 37b4c0b0a8
14 changed files with 342 additions and 267 deletions

View File

@@ -1,14 +1,16 @@
"""Implementation of a space that represents closed boxes in euclidean space."""
from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
from __future__ import annotations
from typing import Any, Iterable, Mapping, Sequence, SupportsFloat
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium import logger
from gymnasium.spaces.space import Space
def _short_repr(arr: np.ndarray) -> str:
def _short_repr(arr: NDArray[Any]) -> str:
"""Create a shortened string representation of a numpy array.
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
@@ -25,12 +27,12 @@ def _short_repr(arr: np.ndarray) -> str:
return str(arr)
def is_float_integer(var) -> bool:
def is_float_integer(var: Any) -> bool:
"""Checks if a variable is an integer or float."""
return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating)
class Box(Space[np.ndarray]):
class Box(Space[NDArray[Any]]):
r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.
Specifically, a Box represents the Cartesian product of n closed intervals.
@@ -52,11 +54,11 @@ class Box(Space[np.ndarray]):
def __init__(
self,
low: Union[SupportsFloat, np.ndarray],
high: Union[SupportsFloat, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: Type = np.float32,
seed: Optional[Union[int, np.random.Generator]] = None,
low: SupportsFloat | NDArray[Any],
high: SupportsFloat | NDArray[Any],
shape: Sequence[int] | None = None,
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
seed: int | np.random.Generator | None = None,
):
r"""Constructor of :class:`Box`.
@@ -102,12 +104,12 @@ class Box(Space[np.ndarray]):
# Capture the boundedness information before replacing np.inf with get_inf
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
self.bounded_below = -np.inf < _low
self.bounded_below: bool = -np.inf < _low
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
self.bounded_above = np.inf > _high
self.bounded_above: bool = np.inf > _high
low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore
high = _broadcast(high, dtype, shape, inf_sign="+") # type: ignore
low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-")
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+")
assert isinstance(low, np.ndarray)
assert (
@@ -118,13 +120,13 @@ class Box(Space[np.ndarray]):
high.shape == shape
), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"
self._shape: Tuple[int, ...] = shape
self._shape: tuple[int, ...] = shape
low_precision = get_precision(low.dtype)
high_precision = get_precision(high.dtype)
dtype_precision = get_precision(self.dtype)
if min(low_precision, high_precision) > dtype_precision: # type: ignore
logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
if min(low_precision, high_precision) > dtype_precision:
gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
self.low = low.astype(self.dtype)
self.high = high.astype(self.dtype)
@@ -134,8 +136,8 @@ class Box(Space[np.ndarray]):
super().__init__(self.shape, self.dtype, seed)
@property
def shape(self) -> Tuple[int, ...]:
"""Has stricter type than gymnasium.Space - never None."""
def shape(self) -> tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape
@property
@@ -168,7 +170,7 @@ class Box(Space[np.ndarray]):
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
)
def sample(self, mask: None = None) -> np.ndarray:
def sample(self, mask: None = None) -> NDArray[Any]:
r"""Generates a single random sample inside the Box.
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
@@ -193,8 +195,7 @@ class Box(Space[np.ndarray]):
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
sample = np.empty(self.shape)
# Masking arrays which classify the coordinates according to interval
# type
# Masking arrays which classify the coordinates according to interval type
unbounded = ~self.bounded_below & ~self.bounded_above
upp_bounded = ~self.bounded_below & self.bounded_above
low_bounded = self.bounded_below & ~self.bounded_above
@@ -221,10 +222,10 @@ class Box(Space[np.ndarray]):
return sample.astype(self.dtype)
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if not isinstance(x, np.ndarray):
logger.warn("Casting input x to numpy array.")
gym.logger.warn("Casting input x to numpy array.")
try:
x = np.asarray(x, dtype=self.dtype)
except (ValueError, TypeError):
@@ -237,11 +238,11 @@ class Box(Space[np.ndarray]):
and np.all(x <= self.high)
)
def to_jsonable(self, sample_n):
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[NDArray[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n: Sequence[Union[float, int]]) -> List[np.ndarray]:
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample) for sample in sample_n]
@@ -256,7 +257,7 @@ class Box(Space[np.ndarray]):
"""
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
return (
isinstance(other, Box)
@@ -266,7 +267,7 @@ class Box(Space[np.ndarray]):
and np.allclose(self.high, other.high)
)
def __setstate__(self, state: Dict):
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
"""Sets the state of the box for unpickling a box with legacy support."""
super().__setstate__(state)
@@ -278,7 +279,7 @@ class Box(Space[np.ndarray]):
self.high_repr = _short_repr(self.high)
def get_inf(dtype, sign: str) -> SupportsFloat:
def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat:
"""Returns an infinite that doesn't break things.
Args:
@@ -310,7 +311,7 @@ def get_inf(dtype, sign: str) -> SupportsFloat:
raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
def get_precision(dtype) -> SupportsFloat:
def get_precision(dtype: np.dtype) -> SupportsFloat:
"""Get precision of a data type."""
if np.issubdtype(dtype, np.floating):
return np.finfo(dtype).precision
@@ -319,14 +320,14 @@ def get_precision(dtype) -> SupportsFloat:
def _broadcast(
value: Union[SupportsFloat, np.ndarray],
dtype,
shape: Tuple[int, ...],
value: SupportsFloat | NDArray[Any],
dtype: np.dtype,
shape: tuple[int, ...],
inf_sign: str,
) -> np.ndarray:
) -> NDArray[Any]:
"""Handle infinite bounds and broadcast at the same time if needed."""
if is_float_integer(value):
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
value = get_inf(dtype, inf_sign) if np.isinf(value) else value
value = np.full(shape, value, dtype=dtype)
else:
assert isinstance(value, np.ndarray)

View File

@@ -1,18 +1,17 @@
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
from __future__ import annotations
import collections.abc
import typing
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Dict as TypingDict
from typing import List, Optional
from typing import Sequence as TypingSequence
from typing import Tuple, Union
from typing import Any, Sequence
import numpy as np
from gymnasium.spaces.space import Space
class Dict(Space[TypingDict[str, Space]], Mapping):
class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
"""A dictionary of :class:`Space` instances.
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
@@ -53,13 +52,8 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
def __init__(
self,
spaces: Optional[
Union[
TypingDict[str, Space],
TypingSequence[Tuple[str, Space]],
]
] = None,
seed: Optional[Union[dict, int, np.random.Generator]] = None,
spaces: None | dict[str, Space] | Sequence[tuple[str, Space]] = None,
seed: dict | int | np.random.Generator | None = None,
**spaces_kwargs: Space,
):
"""Constructor of :class:`Dict` space.
@@ -82,7 +76,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
"""
# Convert the spaces into an OrderedDict
if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict):
if isinstance(spaces, collections.abc.Mapping) and not isinstance(
spaces, OrderedDict
):
try:
spaces = OrderedDict(sorted(spaces.items()))
except TypeError:
@@ -107,22 +103,21 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
f"Dict space keyword '{key}' already exists in the spaces dictionary."
)
self.spaces = spaces
self.spaces: dict[str, Space[Any]] = spaces
for key, space in self.spaces.items():
assert isinstance(
space, Space
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
super().__init__(
None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling
# None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed)
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces.values())
def seed(self, seed: Optional[Union[dict, int]] = None) -> list:
def seed(self, seed: dict[str, Any] | int | None = None) -> list[int]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
@@ -133,7 +128,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
"""
seeds = []
seeds: list[int] = []
if isinstance(seed, dict):
assert (
@@ -159,7 +154,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
return seeds
def sample(self, mask: Optional[TypingDict[str, Any]] = None) -> dict:
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
"""Generates a single random sample from this space.
The sample is an ordered dictionary of independent samples from the constituent spaces.
@@ -183,17 +178,17 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, dict) and x.keys() == self.spaces.keys():
return all(x[key] in self.spaces[key] for key in self.spaces.keys())
return False
def __getitem__(self, key: str) -> Space:
def __getitem__(self, key: str) -> Space[Any]:
"""Get the space that is associated to `key`."""
return self.spaces[key]
def __setitem__(self, key: str, value: Space):
def __setitem__(self, key: str, value: Space[Any]):
"""Set the space that is associated to `key`."""
assert isinstance(
value, Space
@@ -214,7 +209,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
"Dict(" + ", ".join([f"{k!r}: {s}" for k, s in self.spaces.items()]) + ")"
)
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether `other` is equivalent to this instance."""
return (
isinstance(other, Dict)
@@ -222,7 +217,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
and self.spaces == other.spaces # OrderedDict.__eq__
)
def to_jsonable(self, sample_n: list) -> dict:
def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as dict-repr of vectors
return {
@@ -230,9 +225,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
for key, space in self.spaces.items()
}
def from_jsonable(self, sample_n: TypingDict[str, list]) -> List[dict]:
def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
"""Convert a JSONable data type to a batch of samples from this space."""
dict_of_list: TypingDict[str, list] = {
dict_of_list: dict[str, list[Any]] = {
key: space.from_jsonable(sample_n[key])
for key, space in self.spaces.items()
}

View File

@@ -1,9 +1,11 @@
"""Implementation of a space consisting of finitely many elements."""
from typing import Optional, Union
from __future__ import annotations
from typing import Any, Iterable, Mapping
import numpy as np
from gymnasium.spaces.space import Space
from gymnasium.spaces.space import MaskNDArray, Space
class Discrete(Space[int]):
@@ -20,7 +22,7 @@ class Discrete(Space[int]):
def __init__(
self,
n: int,
seed: Optional[Union[int, np.random.Generator]] = None,
seed: int | np.random.Generator | None = None,
start: int = 0,
):
r"""Constructor of :class:`Discrete` space.
@@ -44,7 +46,7 @@ class Discrete(Space[int]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: Optional[np.ndarray] = None) -> int:
def sample(self, mask: MaskNDArray | None = None) -> int:
"""Generates a single random sample from this space.
A sample will be chosen uniformly at random with the mask if provided
@@ -80,14 +82,14 @@ class Discrete(Space[int]):
return int(self.start + self.np_random.integers(self.n))
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, int):
as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and (
np.issubdtype(x.dtype, np.integer) and x.shape == ()
):
as_int = int(x) # type: ignore
as_int = int(x)
else:
return False
@@ -99,7 +101,7 @@ class Discrete(Space[int]):
return f"Discrete({self.n}, start={self.start})"
return f"Discrete({self.n})"
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return (
isinstance(other, Discrete)
@@ -107,7 +109,7 @@ class Discrete(Space[int]):
and self.start == other.start
)
def __setstate__(self, state):
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
"""Used when loading a pickled space.
This method has to be implemented explicitly to allow for loading of legacy states.

View File

@@ -1,9 +1,12 @@
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from __future__ import annotations
from typing import Any, NamedTuple, Sequence
import numpy as np
from numpy.typing import NDArray
from gymnasium.logger import warn
import gymnasium as gym
from gymnasium.spaces.box import Box
from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.multi_discrete import MultiDiscrete
@@ -18,25 +21,24 @@ class GraphInstance(NamedTuple):
* edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the indices of the two nodes that each edge connects.
"""
nodes: np.ndarray
edges: Optional[np.ndarray]
edge_links: Optional[np.ndarray]
nodes: NDArray[Any]
edges: NDArray[Any] | None
edge_links: NDArray[Any] | None
class Graph(Space):
class Graph(Space[GraphInstance]):
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
Example usage::
>>> from gymnasium.spaces import Box, Discrete
>>> Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3))
self.observation_space = spaces.Graph(node_space=space.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
"""
def __init__(
self,
node_space: Union[Box, Discrete],
edge_space: Union[None, Box, Discrete],
seed: Optional[Union[int, np.random.Generator]] = None,
node_space: Box | Discrete,
edge_space: None | Box | Discrete,
seed: int | np.random.Generator | None = None,
):
r"""Constructor of :class:`Graph`.
@@ -70,8 +72,8 @@ class Graph(Space):
return False
def _generate_sample_space(
self, base_space: Union[None, Box, Discrete], num: int
) -> Optional[Union[Box, MultiDiscrete]]:
self, base_space: None | Box | Discrete, num: int
) -> Box | MultiDiscrete | None:
if num == 0 or base_space is None:
return None
@@ -92,14 +94,15 @@ class Graph(Space):
def sample(
self,
mask: Optional[
Tuple[
Optional[Union[np.ndarray, tuple]],
Optional[Union[np.ndarray, tuple]],
mask: None
| (
tuple[
NDArray[Any] | tuple[Any, ...] | None,
NDArray[Any] | tuple[Any, ...] | None,
]
] = None,
) = None,
num_nodes: int = 10,
num_edges: Optional[int] = None,
num_edges: int | None = None,
) -> GraphInstance:
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
@@ -134,7 +137,7 @@ class Graph(Space):
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
else:
if self.edge_space is None:
warn(
gym.logger.warn(
f"The number of edges is set ({num_edges}) but the edge space is None."
)
assert (
@@ -198,7 +201,7 @@ class Graph(Space):
"""
return f"Graph({self.node_space}, {self.edge_space})"
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether `other` is equivalent to this instance."""
return (
isinstance(other, Graph)
@@ -206,22 +209,24 @@ class Graph(Space):
and (self.edge_space == other.edge_space)
)
def to_jsonable(self, sample_n: NamedTuple) -> list:
def to_jsonable(
self, sample_n: Sequence[GraphInstance]
) -> list[dict[str, list[int] | list[float]]]:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as list of dicts
ret_n = []
ret_n: list[dict[str, list[int | float]]] = []
for sample in sample_n:
ret = {}
ret["nodes"] = sample.nodes.tolist()
if sample.edges is not None:
ret = {"nodes": sample.nodes.tolist()}
if sample.edges is not None and sample.edge_links is not None:
ret["edges"] = sample.edges.tolist()
ret["edge_links"] = sample.edge_links.tolist()
ret_n.append(ret)
return ret_n
def from_jsonable(self, sample_n: Sequence[dict]) -> list:
def from_jsonable(
self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
) -> list[GraphInstance]:
"""Convert a JSONable data type to a batch of samples from this space."""
ret = []
ret: list[GraphInstance] = []
for sample in sample_n:
if "edges" in sample:
ret_n = GraphInstance(

View File

@@ -1,12 +1,15 @@
"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
from typing import Optional, Sequence, Tuple, Union
from __future__ import annotations
from typing import Any, Sequence
import numpy as np
import numpy.typing as npt
from gymnasium.spaces.space import Space
from gymnasium.spaces.space import MaskNDArray, Space
class MultiBinary(Space[np.ndarray]):
class MultiBinary(Space[npt.NDArray[np.int8]]):
"""An n-shape binary space.
Elements of this space are binary arrays of a shape that is fixed during construction.
@@ -25,8 +28,8 @@ class MultiBinary(Space[np.ndarray]):
def __init__(
self,
n: Union[np.ndarray, Sequence[int], int],
seed: Optional[Union[int, np.random.Generator]] = None,
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
seed: int | np.random.Generator | None = None,
):
"""Constructor of :class:`MultiBinary` space.
@@ -46,8 +49,8 @@ class MultiBinary(Space[np.ndarray]):
super().__init__(input_n, np.int8, seed)
@property
def shape(self) -> Tuple[int, ...]:
"""Has stricter type than gymnasium.Space - never None."""
def shape(self) -> tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore
@property
@@ -55,7 +58,7 @@ class MultiBinary(Space[np.ndarray]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
"""Generates a single random sample from this space.
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
@@ -90,7 +93,7 @@ class MultiBinary(Space[np.ndarray]):
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check
@@ -98,14 +101,18 @@ class MultiBinary(Space[np.ndarray]):
return bool(
isinstance(x, np.ndarray)
and self.shape == x.shape
and np.all((x == 0) | (x == 1))
and np.all(np.logical_or(x == 0, x == 1))
)
def to_jsonable(self, sample_n) -> list:
def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.int8]]
) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n) -> list:
def from_jsonable(
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.int8]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample, self.dtype) for sample in sample_n]
@@ -113,6 +120,6 @@ class MultiBinary(Space[np.ndarray]):
"""Gives a string representation of this space."""
return f"MultiBinary({self.n})"
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether `other` is equivalent to this instance."""
return isinstance(other, MultiBinary) and self.n == other.n

View File

@@ -1,14 +1,17 @@
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
from typing import Iterable, List, Optional, Sequence, Tuple, Union
from __future__ import annotations
from typing import Any, Sequence
import numpy as np
import numpy.typing as npt
from gymnasium import logger
import gymnasium as gym
from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.space import Space
from gymnasium.spaces.space import MaskNDArray, Space
class MultiDiscrete(Space[np.ndarray]):
class MultiDiscrete(Space[npt.NDArray[np.integer]]):
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
@@ -29,17 +32,17 @@ class MultiDiscrete(Space[np.ndarray]):
Example::
>>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
>>> d.sample()
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
>> d.sample()
array([[0, 0],
[2, 3]])
"""
def __init__(
self,
nvec: Union[np.ndarray, list],
dtype=np.int64,
seed: Optional[Union[int, np.random.Generator]] = None,
nvec: npt.NDArray[np.integer[Any]] | list[int],
dtype: str | type[np.integer[Any]] = np.int64,
seed: int | np.random.Generator | None = None,
):
"""Constructor of :class:`MultiDiscrete` space.
@@ -57,8 +60,8 @@ class MultiDiscrete(Space[np.ndarray]):
super().__init__(self.nvec.shape, dtype, seed)
@property
def shape(self) -> Tuple[int, ...]:
"""Has stricter type than :class:`gymnasium.Space` - never None."""
def shape(self) -> tuple[int, ...]:
"""Has stricter type than :class:`gym.Space` - never None."""
return self._shape # type: ignore
@property
@@ -66,7 +69,9 @@ class MultiDiscrete(Space[np.ndarray]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: Optional[tuple] = None) -> np.ndarray:
def sample(
self, mask: tuple[MaskNDArray, ...] | None = None
) -> npt.NDArray[np.integer[Any]]:
"""Generates a single random sample this space.
Args:
@@ -80,9 +85,9 @@ class MultiDiscrete(Space[np.ndarray]):
if mask is not None:
def _apply_mask(
sub_mask: Union[np.ndarray, tuple],
sub_nvec: Union[np.ndarray, np.integer],
) -> Union[int, List[int]]:
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
sub_nvec: MaskNDArray | np.integer[Any],
) -> int | Sequence[int]:
if isinstance(sub_nvec, np.ndarray):
assert isinstance(
sub_mask, tuple
@@ -122,7 +127,7 @@ class MultiDiscrete(Space[np.ndarray]):
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check
@@ -137,11 +142,15 @@ class MultiDiscrete(Space[np.ndarray]):
and np.all(x < self.nvec)
)
def to_jsonable(self, sample_n: Iterable[np.ndarray]):
def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]]
) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return [sample.tolist() for sample in sample_n]
def from_jsonable(self, sample_n):
def from_jsonable(
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.integer[Any]]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return np.array(sample_n)
@@ -149,7 +158,7 @@ class MultiDiscrete(Space[np.ndarray]):
"""Gives a string representation of this space."""
return f"MultiDiscrete({self.nvec})"
def __getitem__(self, index):
def __getitem__(self, index: int):
"""Extract a subspace from this ``MultiDiscrete`` space."""
nvec = self.nvec[index]
if nvec.ndim == 0:
@@ -165,11 +174,13 @@ class MultiDiscrete(Space[np.ndarray]):
def __len__(self):
"""Gives the ``len`` of samples from this space."""
if self.nvec.ndim >= 2:
logger.warn(
gym.logger.warn(
"Getting the length of a multi-dimensional MultiDiscrete space."
)
return len(self.nvec)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
return bool(
isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
)

View File

@@ -1,20 +1,24 @@
"""Implementation of a space that represents finite-length sequences."""
from collections.abc import Sequence as CollectionSequence
from typing import Any, List, Optional, Tuple, Union
from __future__ import annotations
import collections.abc
import typing
from typing import Any
import numpy as np
import numpy.typing as npt
import gymnasium as gym
from gymnasium.spaces.space import Space
class Sequence(Space[Tuple]):
class Sequence(Space[typing.Tuple[Any, ...]]):
r"""This space represent sets of finite-length sequences.
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
to some space that is specified during initialization and the integer :math:`n` is not fixed
Example::
>>> from gymnasium.spaces import Box
>>> space = Sequence(Box(0, 1))
>>> space.sample()
(array([0.0259352], dtype=float32),)
@@ -24,8 +28,8 @@ class Sequence(Space[Tuple]):
def __init__(
self,
space: Space,
seed: Optional[Union[int, np.random.Generator]] = None,
space: Space[Any],
seed: int | np.random.Generator | None = None,
):
"""Constructor of the :class:`Sequence` space.
@@ -34,14 +38,14 @@ class Sequence(Space[Tuple]):
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
"""
assert isinstance(
space, gym.Space
), f"Expects the feature space to be instance of a gymnasium Space, actual type: {type(space)}"
space, Space
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
self.feature_space = space
super().__init__(
None, None, seed # type: ignore
) # None for shape and dtype, since it'll require special handling
def seed(self, seed: Optional[int] = None) -> list:
# None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed)
def seed(self, seed: int | None = None) -> list[int]:
"""Seed the PRNG of this space and the feature space."""
seeds = super().seed(seed)
seeds += self.feature_space.seed(seed)
@@ -54,8 +58,14 @@ class Sequence(Space[Tuple]):
def sample(
self,
mask: Optional[Tuple[Optional[Union[np.ndarray, int]], Optional[Any]]] = None,
) -> Tuple[Any]:
mask: None
| (
tuple[
None | np.integer | npt.NDArray[np.integer],
Any,
]
) = None,
) -> tuple[Any]:
"""Generates a single random sample from this space.
Args:
@@ -89,6 +99,9 @@ class Sequence(Space[Tuple]):
assert np.all(
0 <= length_mask
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
assert np.issubdtype(
length_mask.dtype, np.integer
), f"Expects the length mask array to have dtype to be an numpy integer, actual type: {length_mask.dtype}"
length = self.np_random.choice(length_mask)
else:
raise TypeError(
@@ -102,9 +115,9 @@ class Sequence(Space[Tuple]):
self.feature_space.sample(mask=feature_mask) for _ in range(length)
)
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
return isinstance(x, CollectionSequence) and all(
return isinstance(x, collections.abc.Sequence) and all(
self.feature_space.contains(item) for item in x
)
@@ -112,15 +125,17 @@ class Sequence(Space[Tuple]):
"""Gives a string representation of this space."""
return f"Sequence({self.feature_space})"
def to_jsonable(self, sample_n: list) -> list:
def to_jsonable(
self, sample_n: typing.Sequence[tuple[Any, ...]]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as dict-repr of vectors
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]
def from_jsonable(self, sample_n: List[List[Any]]) -> list:
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return isinstance(other, Sequence) and self.feature_space == other.feature_space

View File

@@ -1,26 +1,19 @@
"""Implementation of the `Space` metaclass."""
from __future__ import annotations
from typing import (
Any,
Generic,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar
import numpy as np
import numpy.typing as npt
from gymnasium.utils import seeding
T_cov = TypeVar("T_cov", covariant=True)
MaskNDArray = npt.NDArray[np.int8]
class Space(Generic[T_cov]):
"""Superclass that is used to define observation and action spaces.
@@ -41,17 +34,17 @@ class Space(Generic[T_cov]):
class. However, most use-cases should be covered by the existing space
classes (e.g. :class:`Box`, :class:`Discrete`, etc...), and container classes (:class`Tuple` &
:class:`Dict`). Note that parametrized probability distributions (through the
:meth:`Space.sample()` method), and batching functions (in :class:`gymnasium.vector.VectorEnv`), are
only well-defined for instances of spaces provided in gymnasium by default.
:meth:`Space.sample()` method), and batching functions (in :class:`gym.vector.VectorEnv`), are
only well-defined for instances of spaces provided in gym by default.
Moreover, some implementations of Reinforcement Learning algorithms might
not handle custom spaces properly. Use custom spaces with care.
"""
def __init__(
self,
shape: Optional[Sequence[int]] = None,
dtype: Optional[Union[Type, str, np.dtype]] = None,
seed: Optional[Union[int, np.random.Generator]] = None,
shape: Sequence[int] | None = None,
dtype: npt.DTypeLike | None = None,
seed: int | np.random.Generator | None = None,
):
"""Constructor of :class:`Space`.
@@ -71,23 +64,31 @@ class Space(Generic[T_cov]):
@property
def np_random(self) -> np.random.Generator:
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space."""
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space.
As :meth:`seed` is not guaranteed to set the `_np_random` for particular seeds. We add a
check after :meth:`seed` to set a new random number generator.
"""
if self._np_random is None:
self.seed()
return self._np_random # type: ignore ## self.seed() call guarantees right type.
# As `seed` is not guaranteed (in particular for composite spaces) to set the `_np_random` then we set it randomly.
if self._np_random is None:
self._np_random, _ = seeding.np_random()
return self._np_random
@property
def shape(self) -> Optional[Tuple[int, ...]]:
def shape(self) -> tuple[int, ...] | None:
"""Return the shape of the space as an immutable property."""
return self._shape
@property
def is_np_flattenable(self):
def is_np_flattenable(self) -> bool:
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
raise NotImplementedError
def sample(self, mask: Optional[Any] = None) -> T_cov:
def sample(self, mask: Any | None = None) -> T_cov:
"""Randomly sample an element of this space.
Can be uniform or non-uniform sampling based on boundedness of space.
@@ -100,20 +101,20 @@ class Space(Generic[T_cov]):
"""
raise NotImplementedError
def seed(self, seed: Optional[int] = None) -> list:
def seed(self, seed: int | None = None) -> list[int]:
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
self._np_random, seed = seeding.np_random(seed)
return [seed]
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
raise NotImplementedError
def __contains__(self, x) -> bool:
def __contains__(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
return self.contains(x)
def __setstate__(self, state: Union[Iterable, Mapping]):
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
"""Used when loading a pickled space.
This method was implemented explicitly to allow for loading of legacy states.
@@ -130,7 +131,7 @@ class Space(Generic[T_cov]):
# https://github.com/openai/gym/pull/1913 -- np_random
#
if "shape" in state:
state["_shape"] = state["shape"]
state["_shape"] = state.get("shape")
del state["shape"]
if "np_random" in state:
state["_np_random"] = state["np_random"]
@@ -139,12 +140,12 @@ class Space(Generic[T_cov]):
# Update our state
self.__dict__.update(state)
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list:
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list[Any]:
"""Convert a batch of samples from this space to a JSONable data type."""
# By default, assume identity is JSONable
return list(sample_n)
def from_jsonable(self, sample_n: list) -> List[T_cov]:
def from_jsonable(self, sample_n: list[Any]) -> list[T_cov]:
"""Convert a JSONable data type to a batch of samples from this space."""
# By default, assume identity is JSONable
return sample_n

View File

@@ -1,11 +1,14 @@
"""Implementation of a space that represents textual strings."""
from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
from __future__ import annotations
from typing import Any
import numpy as np
import numpy.typing as npt
from gymnasium.spaces.space import Space
alphanumeric: FrozenSet[str] = frozenset(
alphanumeric: frozenset[str] = frozenset(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
@@ -14,7 +17,6 @@ class Text(Space[str]):
r"""A space representing a string comprised of characters from a given charset.
Example::
>>> # {"", "B5", "hello", ...}
>>> Text(5)
>>> # {"0", "42", "0123456789", ...}
@@ -29,8 +31,8 @@ class Text(Space[str]):
max_length: int,
*,
min_length: int = 1,
charset: Union[Set[str], str] = alphanumeric,
seed: Optional[Union[int, np.random.Generator]] = None,
charset: set[str] | str = alphanumeric,
seed: int | np.random.Generator | None = None,
):
r"""Constructor of :class:`Text` space.
@@ -58,9 +60,9 @@ class Text(Space[str]):
self.min_length: int = int(min_length)
self.max_length: int = int(max_length)
self._char_set: FrozenSet[str] = frozenset(charset)
self._char_list: Tuple[str, ...] = tuple(charset)
self._char_index: Dict[str, np.int32] = {
self._char_set: frozenset[str] = frozenset(charset)
self._char_list: tuple[str, ...] = tuple(charset)
self._char_index: dict[str, np.int32] = {
val: np.int32(i) for i, val in enumerate(tuple(charset))
}
self._char_str: str = "".join(sorted(tuple(charset)))
@@ -69,7 +71,8 @@ class Text(Space[str]):
super().__init__(dtype=str, seed=seed)
def sample(
self, mask: Optional[Tuple[Optional[int], Optional[np.ndarray]]] = None
self,
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
) -> str:
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
@@ -152,7 +155,7 @@ class Text(Space[str]):
f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
)
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return (
isinstance(other, Text)
@@ -162,12 +165,12 @@ class Text(Space[str]):
)
@property
def character_set(self) -> FrozenSet[str]:
def character_set(self) -> frozenset[str]:
"""Returns the character set for the space."""
return self._char_set
@property
def character_list(self) -> Tuple[str, ...]:
def character_list(self) -> tuple[str, ...]:
"""Returns a tuple of characters in the space."""
return self._char_list

View File

@@ -1,16 +1,16 @@
"""Implementation of a space that represents the cartesian product of other spaces."""
from collections.abc import Sequence as CollectionSequence
from typing import Iterable, Optional
from typing import Sequence as TypingSequence
from typing import Tuple as TypingTuple
from typing import Union
from __future__ import annotations
import collections.abc
import typing
from typing import Any, Iterable
import numpy as np
from gymnasium.spaces.space import Space
class Tuple(Space[tuple], CollectionSequence):
class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
Elements of this space are tuples of elements of the constituent spaces.
@@ -25,8 +25,8 @@ class Tuple(Space[tuple], CollectionSequence):
def __init__(
self,
spaces: Iterable[Space],
seed: Optional[Union[int, TypingSequence[int], np.random.Generator]] = None,
spaces: Iterable[Space[Any]],
seed: int | typing.Sequence[int] | np.random.Generator | None = None,
):
r"""Constructor of :class:`Tuple` space.
@@ -40,7 +40,7 @@ class Tuple(Space[tuple], CollectionSequence):
for space in self.spaces:
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gymnasium.Space"
), "Elements of the tuple must be instances of gym.Space"
super().__init__(None, None, seed) # type: ignore
@property
@@ -48,9 +48,7 @@ class Tuple(Space[tuple], CollectionSequence):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)
def seed(
self, seed: Optional[Union[int, TypingSequence[int]]] = None
) -> TypingSequence[int]:
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
@@ -61,9 +59,9 @@ class Tuple(Space[tuple], CollectionSequence):
Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
"""
seeds = []
seeds: list[int] = []
if isinstance(seed, CollectionSequence):
if isinstance(seed, collections.abc.Sequence):
assert len(seed) == len(
self.spaces
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
@@ -86,9 +84,7 @@ class Tuple(Space[tuple], CollectionSequence):
return seeds
def sample(
self, mask: Optional[TypingTuple[Optional[np.ndarray], ...]] = None
) -> tuple:
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
"""Generates a single random sample inside this space.
This method draws independent samples from the subspaces.
@@ -115,10 +111,11 @@ class Tuple(Space[tuple], CollectionSequence):
return tuple(space.sample() for space in self.spaces)
def contains(self, x) -> bool:
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, (list, np.ndarray)):
x = tuple(x) # Promote list and ndarray to tuple for contains check
return (
isinstance(x, tuple)
and len(x) == len(self.spaces)
@@ -129,7 +126,9 @@ class Tuple(Space[tuple], CollectionSequence):
"""Gives a string representation of this space."""
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
def to_jsonable(self, sample_n: CollectionSequence) -> list:
def to_jsonable(
self, sample_n: typing.Sequence[tuple[Any, ...]]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as list-repr of tuple of vectors
return [
@@ -137,7 +136,7 @@ class Tuple(Space[tuple], CollectionSequence):
for i, space in enumerate(self.spaces)
]
def from_jsonable(self, sample_n) -> list:
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [
sample
@@ -149,7 +148,7 @@ class Tuple(Space[tuple], CollectionSequence):
)
]
def __getitem__(self, index: int) -> Space:
def __getitem__(self, index: int) -> Space[Any]:
"""Get the subspace at specific `index`."""
return self.spaces[index]
@@ -157,6 +156,6 @@ class Tuple(Space[tuple], CollectionSequence):
"""Get the number of subspaces that are involved in the cartesian product."""
return len(self.spaces)
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return isinstance(other, Tuple) and self.spaces == other.spaces

View File

@@ -3,13 +3,16 @@
These functions mostly take care of flattening and unflattening elements of spaces
to facilitate their usage in learning code.
"""
from __future__ import annotations
import operator as op
import typing
from collections import OrderedDict
from functools import reduce, singledispatch
from typing import Dict as TypingDict
from typing import TypeVar, Union, cast
from typing import Any, TypeVar, Union, cast
import numpy as np
from numpy.typing import NDArray
from gymnasium.spaces import (
Box,
@@ -27,12 +30,12 @@ from gymnasium.spaces import (
@singledispatch
def flatdim(space: Space) -> int:
def flatdim(space: Space[Any]) -> int:
"""Return the number of dimensions a flattened equivalent of this space would have.
Example usage::
>>> from gymnasium.spaces import Discrete
>>> from gymnasium.spaces import Discrete, Dict
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
>>> flatdim(space)
5
@@ -47,7 +50,7 @@ def flatdim(space: Space) -> int:
NotImplementedError: if the space is not defined in ``gym.spaces``.
ValueError: if the space cannot be flattened into a :class:`Box`
"""
if not space.is_np_flattenable:
if space.is_np_flattenable is False:
raise ValueError(
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
)
@@ -57,7 +60,7 @@ def flatdim(space: Space) -> int:
@flatdim.register(Box)
@flatdim.register(MultiBinary)
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
def _flatdim_box_multibinary(space: Box | MultiBinary) -> int:
return reduce(op.mul, space.shape, 1)
@@ -102,7 +105,9 @@ def _flatdim_text(space: Text) -> int:
T = TypeVar("T")
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
FlatType = Union[
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
]
@singledispatch
@@ -112,6 +117,19 @@ def flatten(space: Space[T], x: T) -> FlatType:
This is useful when e.g. points from spaces must be passed to a neural
network, which only understands flat arrays of floats.
Example usage::
>>> from gymnasium.spaces import Box, Discrete, Tuple
>>> space = Box(0, 1, shape=(3, 5))
>>> flatten(space, space.sample()).shape
(15,)
>>> space = Discrete(4)
>>> flatten(space, 2)
array([0, 0, 1, 0])
>>> space = Tuple((Box(0, 1, shape=(2,)), Box(0, 1, shape=(3,)), Discrete(3)))
>>> example = ((.5, .25), (1., 0., .2), 1)
>>> flatten(space, example)
array([0.5 , 0.25, 1. , 0. , 0.2 , 0. , 1. , 0. ])
Args:
space: The space that ``x`` is flattened by
x: The value to flatten
@@ -137,19 +155,21 @@ def flatten(space: Space[T], x: T) -> FlatType:
@flatten.register(Box)
@flatten.register(MultiBinary)
def _flatten_box_multibinary(space, x) -> np.ndarray:
def _flatten_box_multibinary(space: Box | MultiBinary, x: NDArray[Any]) -> NDArray[Any]:
return np.asarray(x, dtype=space.dtype).flatten()
@flatten.register(Discrete)
def _flatten_discrete(space, x) -> np.ndarray:
def _flatten_discrete(space: Discrete, x: int) -> NDArray[np.int64]:
onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x - space.start] = 1
return onehot
@flatten.register(MultiDiscrete)
def _flatten_multidiscrete(space, x) -> np.ndarray:
def _flatten_multidiscrete(
space: MultiDiscrete, x: NDArray[np.int64]
) -> NDArray[np.int64]:
offsets = np.zeros((space.nvec.size + 1,), dtype=np.int32)
offsets[1:] = np.cumsum(space.nvec.flatten())
@@ -159,7 +179,7 @@ def _flatten_multidiscrete(space, x) -> np.ndarray:
@flatten.register(Tuple)
def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
if space.is_np_flattenable:
return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
@@ -168,22 +188,26 @@ def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
@flatten.register(Dict)
def _flatten_dict(space, x) -> Union[dict, np.ndarray]:
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
if space.is_np_flattenable:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
@flatten.register(Graph)
def _flatten_graph(space, x) -> GraphInstance:
def _flatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
def _graph_unflatten(unflatten_space, unflatten_x):
def _graph_unflatten(
unflatten_space: Discrete | Box | None,
unflatten_x: NDArray[Any] | None,
) -> NDArray[Any] | None:
ret = None
if unflatten_space is not None and unflatten_x is not None:
if isinstance(unflatten_space, Box):
ret = unflatten_x.reshape(unflatten_x.shape[0], -1)
elif isinstance(unflatten_space, Discrete):
else:
assert isinstance(unflatten_space, Discrete)
ret = np.zeros(
(unflatten_x.shape[0], unflatten_space.n - unflatten_space.start),
dtype=unflatten_space.dtype,
@@ -194,13 +218,14 @@ def _flatten_graph(space, x) -> GraphInstance:
return ret
nodes = _graph_unflatten(space.node_space, x.nodes)
assert nodes is not None
edges = _graph_unflatten(space.edge_space, x.edges)
return GraphInstance(nodes, edges, x.edge_links)
@flatten.register(Text)
def _flatten_text(space: Text, x: str) -> np.ndarray:
def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
arr = np.full(
shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
)
@@ -210,7 +235,7 @@ def _flatten_text(space: Text, x: str) -> np.ndarray:
@flatten.register(Sequence)
def _flatten_sequence(space, x) -> tuple:
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(flatten(space.feature_space, item) for item in x)
@@ -237,18 +262,20 @@ def unflatten(space: Space[T], x: FlatType) -> T:
@unflatten.register(Box)
@unflatten.register(MultiBinary)
def _unflatten_box_multibinary(
space: Union[Box, MultiBinary], x: np.ndarray
) -> np.ndarray:
space: Box | MultiBinary, x: NDArray[Any]
) -> NDArray[Any]:
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> int:
return int(space.start + np.nonzero(x)[0][0])
@unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
def _unflatten_multidiscrete(
space: MultiDiscrete, x: NDArray[np.int32]
) -> NDArray[np.int32]:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten())
@@ -257,7 +284,9 @@ def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
@unflatten.register(Tuple)
def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
def _unflatten_tuple(
space: Tuple, x: NDArray[Any] | tuple[Any, ...]
) -> tuple[Any, ...]:
if space.is_np_flattenable:
assert isinstance(
x, np.ndarray
@@ -275,7 +304,7 @@ def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
@unflatten.register(Dict)
def _unflatten_dict(space: Dict, x: Union[np.ndarray, TypingDict]) -> dict:
def _unflatten_dict(space: Dict, x: NDArray[Any] | dict[str, Any]) -> dict[str, Any]:
if space.is_np_flattenable:
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
list_flattened = np.split(x, np.cumsum(dims[:-1]))
@@ -299,14 +328,14 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
nodes and edges in the graph.
"""
def _graph_unflatten(space, x):
ret = None
if space is not None and x is not None:
if isinstance(space, Box):
ret = x.reshape(-1, *space.shape)
elif isinstance(space, Discrete):
ret = np.asarray(np.nonzero(x))[-1, :]
return ret
def _graph_unflatten(unflatten_space, unflatten_x):
result = None
if unflatten_space is not None and unflatten_x is not None:
if isinstance(unflatten_space, Box):
result = unflatten_x.reshape(-1, *unflatten_space.shape)
elif isinstance(unflatten_space, Discrete):
result = np.asarray(np.nonzero(unflatten_x))[-1, :]
return result
nodes = _graph_unflatten(space.node_space, x.nodes)
edges = _graph_unflatten(space.edge_space, x.edges)
@@ -315,19 +344,19 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
@unflatten.register(Text)
def _unflatten_text(space: Text, x: np.ndarray) -> str:
def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
return "".join(
[space.character_list[val] for val in x if val < len(space.character_set)]
)
@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(unflatten(space.feature_space, item) for item in x)
@singledispatch
def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
"""Flatten a space into a space that is as flat as possible.
This function will attempt to flatten `space` into a single :class:`Box` space.
@@ -342,7 +371,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
space.
Example::
>>> from gymnasium.spaces import Box
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
>>> box
Box(3, 4, 5)
@@ -352,7 +381,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
True
Example that flattens a discrete space::
>>> from gymnasium.spaces import Discrete
>>> discrete = Discrete(5)
>>> flatten_space(discrete)
Box(5,)
@@ -360,7 +389,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
True
Example that recursively flattens a dict::
>>> from gymnasium.spaces import Dict, Discrete, Box
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
>>> flatten_space(space)
Box(6,)
@@ -383,7 +412,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
A flattened Box
Raises:
NotImplementedError: if the space is not defined in ``gymnasium.spaces``.
NotImplementedError: if the space is not defined in ``gym.spaces``.
"""
raise NotImplementedError(f"Unknown space: `{space}`")
@@ -396,12 +425,12 @@ def _flatten_space_box(space: Box) -> Box:
@flatten_space.register(Discrete)
@flatten_space.register(MultiBinary)
@flatten_space.register(MultiDiscrete)
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
def _flatten_space_binary(space: Discrete | MultiBinary | MultiDiscrete) -> Box:
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
@flatten_space.register(Tuple)
def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
def _flatten_space_tuple(space: Tuple) -> Box | Tuple:
if space.is_np_flattenable:
space_list = [flatten_space(s) for s in space.spaces]
return Box(
@@ -413,7 +442,7 @@ def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
@flatten_space.register(Dict)
def _flatten_space_dict(space: Dict) -> Union[Box, Dict]:
def _flatten_space_dict(space: Dict) -> Box | Dict:
if space.is_np_flattenable:
space_list = [flatten_space(s) for s in space.spaces.values()]
return Box(

View File

@@ -1,4 +1,4 @@
numpy>=1.18.0
numpy>=1.21.0
cloudpickle>=1.2.0
importlib_metadata>=4.8.0; python_version < '3.10'
gymnasium_notices>=0.0.1

View File

@@ -85,7 +85,7 @@ setup(
},
include_package_data=True,
install_requires=[
"numpy >= 1.18.0",
"numpy >= 1.21.0",
"cloudpickle >= 1.2.0",
"importlib_metadata >= 4.8.0; python_version < '3.10'",
"gymnasium_notices >= 0.0.1",

View File

@@ -1,3 +1,5 @@
import re
import warnings
from collections import OrderedDict
import numpy as np
@@ -25,18 +27,18 @@ def test_dict_init():
):
Dict(a=Discrete(2), b="Box")
with pytest.warns(None) as warnings:
with warnings.catch_warnings(record=True) as caught_warnings:
a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))
assert a == b == c == d
assert len(warnings) == 0
assert len(caught_warnings) == 0
with pytest.warns(None) as warnings:
with warnings.catch_warnings(record=True) as caught_warnings:
Dict({1: Discrete(2), "a": Discrete(3)})
assert len(warnings) == 0
assert len(caught_warnings) == 0
DICT_SPACE = Dict(
@@ -109,7 +111,12 @@ def test_none_seeding():
def test_bad_seed():
with pytest.raises(TypeError):
with pytest.raises(
TypeError,
match=re.escape(
"Expected seed type: dict, int or None, actual type: <class 'str'>"
),
):
DICT_SPACE.seed("a")