mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Update the spaces for complete type hinting (and updates numpy to 1.21) (#37)
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
"""Implementation of a space that represents closed boxes in euclidean space."""
|
||||
from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Mapping, Sequence, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import logger
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
|
||||
def _short_repr(arr: np.ndarray) -> str:
|
||||
def _short_repr(arr: NDArray[Any]) -> str:
|
||||
"""Create a shortened string representation of a numpy array.
|
||||
|
||||
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
||||
@@ -25,12 +27,12 @@ def _short_repr(arr: np.ndarray) -> str:
|
||||
return str(arr)
|
||||
|
||||
|
||||
def is_float_integer(var) -> bool:
|
||||
def is_float_integer(var: Any) -> bool:
|
||||
"""Checks if a variable is an integer or float."""
|
||||
return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating)
|
||||
|
||||
|
||||
class Box(Space[np.ndarray]):
|
||||
class Box(Space[NDArray[Any]]):
|
||||
r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.
|
||||
|
||||
Specifically, a Box represents the Cartesian product of n closed intervals.
|
||||
@@ -52,11 +54,11 @@ class Box(Space[np.ndarray]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
low: Union[SupportsFloat, np.ndarray],
|
||||
high: Union[SupportsFloat, np.ndarray],
|
||||
shape: Optional[Sequence[int]] = None,
|
||||
dtype: Type = np.float32,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
low: SupportsFloat | NDArray[Any],
|
||||
high: SupportsFloat | NDArray[Any],
|
||||
shape: Sequence[int] | None = None,
|
||||
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
r"""Constructor of :class:`Box`.
|
||||
|
||||
@@ -102,12 +104,12 @@ class Box(Space[np.ndarray]):
|
||||
|
||||
# Capture the boundedness information before replacing np.inf with get_inf
|
||||
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
|
||||
self.bounded_below = -np.inf < _low
|
||||
self.bounded_below: bool = -np.inf < _low
|
||||
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
||||
self.bounded_above = np.inf > _high
|
||||
self.bounded_above: bool = np.inf > _high
|
||||
|
||||
low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore
|
||||
high = _broadcast(high, dtype, shape, inf_sign="+") # type: ignore
|
||||
low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-")
|
||||
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+")
|
||||
|
||||
assert isinstance(low, np.ndarray)
|
||||
assert (
|
||||
@@ -118,13 +120,13 @@ class Box(Space[np.ndarray]):
|
||||
high.shape == shape
|
||||
), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"
|
||||
|
||||
self._shape: Tuple[int, ...] = shape
|
||||
self._shape: tuple[int, ...] = shape
|
||||
|
||||
low_precision = get_precision(low.dtype)
|
||||
high_precision = get_precision(high.dtype)
|
||||
dtype_precision = get_precision(self.dtype)
|
||||
if min(low_precision, high_precision) > dtype_precision: # type: ignore
|
||||
logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
|
||||
if min(low_precision, high_precision) > dtype_precision:
|
||||
gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
|
||||
self.low = low.astype(self.dtype)
|
||||
self.high = high.astype(self.dtype)
|
||||
|
||||
@@ -134,8 +136,8 @@ class Box(Space[np.ndarray]):
|
||||
super().__init__(self.shape, self.dtype, seed)
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Has stricter type than gymnasium.Space - never None."""
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
@@ -168,7 +170,7 @@ class Box(Space[np.ndarray]):
|
||||
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
|
||||
)
|
||||
|
||||
def sample(self, mask: None = None) -> np.ndarray:
|
||||
def sample(self, mask: None = None) -> NDArray[Any]:
|
||||
r"""Generates a single random sample inside the Box.
|
||||
|
||||
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
|
||||
@@ -193,8 +195,7 @@ class Box(Space[np.ndarray]):
|
||||
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
|
||||
sample = np.empty(self.shape)
|
||||
|
||||
# Masking arrays which classify the coordinates according to interval
|
||||
# type
|
||||
# Masking arrays which classify the coordinates according to interval type
|
||||
unbounded = ~self.bounded_below & ~self.bounded_above
|
||||
upp_bounded = ~self.bounded_below & self.bounded_above
|
||||
low_bounded = self.bounded_below & ~self.bounded_above
|
||||
@@ -221,10 +222,10 @@ class Box(Space[np.ndarray]):
|
||||
|
||||
return sample.astype(self.dtype)
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if not isinstance(x, np.ndarray):
|
||||
logger.warn("Casting input x to numpy array.")
|
||||
gym.logger.warn("Casting input x to numpy array.")
|
||||
try:
|
||||
x = np.asarray(x, dtype=self.dtype)
|
||||
except (ValueError, TypeError):
|
||||
@@ -237,11 +238,11 @@ class Box(Space[np.ndarray]):
|
||||
and np.all(x <= self.high)
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[NDArray[Any]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(self, sample_n: Sequence[Union[float, int]]) -> List[np.ndarray]:
|
||||
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
|
||||
@@ -256,7 +257,7 @@ class Box(Space[np.ndarray]):
|
||||
"""
|
||||
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
|
||||
return (
|
||||
isinstance(other, Box)
|
||||
@@ -266,7 +267,7 @@ class Box(Space[np.ndarray]):
|
||||
and np.allclose(self.high, other.high)
|
||||
)
|
||||
|
||||
def __setstate__(self, state: Dict):
|
||||
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
|
||||
"""Sets the state of the box for unpickling a box with legacy support."""
|
||||
super().__setstate__(state)
|
||||
|
||||
@@ -278,7 +279,7 @@ class Box(Space[np.ndarray]):
|
||||
self.high_repr = _short_repr(self.high)
|
||||
|
||||
|
||||
def get_inf(dtype, sign: str) -> SupportsFloat:
|
||||
def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat:
|
||||
"""Returns an infinite that doesn't break things.
|
||||
|
||||
Args:
|
||||
@@ -310,7 +311,7 @@ def get_inf(dtype, sign: str) -> SupportsFloat:
|
||||
raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
|
||||
|
||||
|
||||
def get_precision(dtype) -> SupportsFloat:
|
||||
def get_precision(dtype: np.dtype) -> SupportsFloat:
|
||||
"""Get precision of a data type."""
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
return np.finfo(dtype).precision
|
||||
@@ -319,14 +320,14 @@ def get_precision(dtype) -> SupportsFloat:
|
||||
|
||||
|
||||
def _broadcast(
|
||||
value: Union[SupportsFloat, np.ndarray],
|
||||
dtype,
|
||||
shape: Tuple[int, ...],
|
||||
value: SupportsFloat | NDArray[Any],
|
||||
dtype: np.dtype,
|
||||
shape: tuple[int, ...],
|
||||
inf_sign: str,
|
||||
) -> np.ndarray:
|
||||
) -> NDArray[Any]:
|
||||
"""Handle infinite bounds and broadcast at the same time if needed."""
|
||||
if is_float_integer(value):
|
||||
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
|
||||
value = get_inf(dtype, inf_sign) if np.isinf(value) else value
|
||||
value = np.full(shape, value, dtype=dtype)
|
||||
else:
|
||||
assert isinstance(value, np.ndarray)
|
||||
|
@@ -1,18 +1,17 @@
|
||||
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Dict as TypingDict
|
||||
from typing import List, Optional
|
||||
from typing import Sequence as TypingSequence
|
||||
from typing import Tuple, Union
|
||||
from typing import Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
|
||||
class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||
"""A dictionary of :class:`Space` instances.
|
||||
|
||||
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
||||
@@ -53,13 +52,8 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spaces: Optional[
|
||||
Union[
|
||||
TypingDict[str, Space],
|
||||
TypingSequence[Tuple[str, Space]],
|
||||
]
|
||||
] = None,
|
||||
seed: Optional[Union[dict, int, np.random.Generator]] = None,
|
||||
spaces: None | dict[str, Space] | Sequence[tuple[str, Space]] = None,
|
||||
seed: dict | int | np.random.Generator | None = None,
|
||||
**spaces_kwargs: Space,
|
||||
):
|
||||
"""Constructor of :class:`Dict` space.
|
||||
@@ -82,7 +76,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
|
||||
"""
|
||||
# Convert the spaces into an OrderedDict
|
||||
if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict):
|
||||
if isinstance(spaces, collections.abc.Mapping) and not isinstance(
|
||||
spaces, OrderedDict
|
||||
):
|
||||
try:
|
||||
spaces = OrderedDict(sorted(spaces.items()))
|
||||
except TypeError:
|
||||
@@ -107,22 +103,21 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
f"Dict space keyword '{key}' already exists in the spaces dictionary."
|
||||
)
|
||||
|
||||
self.spaces = spaces
|
||||
self.spaces: dict[str, Space[Any]] = spaces
|
||||
for key, space in self.spaces.items():
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
||||
|
||||
super().__init__(
|
||||
None, None, seed # type: ignore
|
||||
) # None for shape and dtype, since it'll require special handling
|
||||
# None for shape and dtype, since it'll require special handling
|
||||
super().__init__(None, None, seed)
|
||||
|
||||
@property
|
||||
def is_np_flattenable(self):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return all(space.is_np_flattenable for space in self.spaces.values())
|
||||
|
||||
def seed(self, seed: Optional[Union[dict, int]] = None) -> list:
|
||||
def seed(self, seed: dict[str, Any] | int | None = None) -> list[int]:
|
||||
"""Seed the PRNG of this space and all subspaces.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
@@ -133,7 +128,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
Args:
|
||||
seed: An optional dict of seeds or an int to seed the (sub-)spaces.
|
||||
"""
|
||||
seeds = []
|
||||
seeds: list[int] = []
|
||||
|
||||
if isinstance(seed, dict):
|
||||
assert (
|
||||
@@ -159,7 +154,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
|
||||
return seeds
|
||||
|
||||
def sample(self, mask: Optional[TypingDict[str, Any]] = None) -> dict:
|
||||
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
The sample is an ordered dictionary of independent samples from the constituent spaces.
|
||||
@@ -183,17 +178,17 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
|
||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, dict) and x.keys() == self.spaces.keys():
|
||||
return all(x[key] in self.spaces[key] for key in self.spaces.keys())
|
||||
return False
|
||||
|
||||
def __getitem__(self, key: str) -> Space:
|
||||
def __getitem__(self, key: str) -> Space[Any]:
|
||||
"""Get the space that is associated to `key`."""
|
||||
return self.spaces[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Space):
|
||||
def __setitem__(self, key: str, value: Space[Any]):
|
||||
"""Set the space that is associated to `key`."""
|
||||
assert isinstance(
|
||||
value, Space
|
||||
@@ -214,7 +209,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
"Dict(" + ", ".join([f"{k!r}: {s}" for k, s in self.spaces.items()]) + ")"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether `other` is equivalent to this instance."""
|
||||
return (
|
||||
isinstance(other, Dict)
|
||||
@@ -222,7 +217,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
and self.spaces == other.spaces # OrderedDict.__eq__
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n: list) -> dict:
|
||||
def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# serialize as dict-repr of vectors
|
||||
return {
|
||||
@@ -230,9 +225,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||
for key, space in self.spaces.items()
|
||||
}
|
||||
|
||||
def from_jsonable(self, sample_n: TypingDict[str, list]) -> List[dict]:
|
||||
def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
dict_of_list: TypingDict[str, list] = {
|
||||
dict_of_list: dict[str, list[Any]] = {
|
||||
key: space.from_jsonable(sample_n[key])
|
||||
for key, space in self.spaces.items()
|
||||
}
|
||||
|
@@ -1,9 +1,11 @@
|
||||
"""Implementation of a space consisting of finitely many elements."""
|
||||
from typing import Optional, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.spaces.space import Space
|
||||
from gymnasium.spaces.space import MaskNDArray, Space
|
||||
|
||||
|
||||
class Discrete(Space[int]):
|
||||
@@ -20,7 +22,7 @@ class Discrete(Space[int]):
|
||||
def __init__(
|
||||
self,
|
||||
n: int,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
start: int = 0,
|
||||
):
|
||||
r"""Constructor of :class:`Discrete` space.
|
||||
@@ -44,7 +46,7 @@ class Discrete(Space[int]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: Optional[np.ndarray] = None) -> int:
|
||||
def sample(self, mask: MaskNDArray | None = None) -> int:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
A sample will be chosen uniformly at random with the mask if provided
|
||||
@@ -80,14 +82,14 @@ class Discrete(Space[int]):
|
||||
|
||||
return int(self.start + self.np_random.integers(self.n))
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, int):
|
||||
as_int = x
|
||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
||||
np.issubdtype(x.dtype, np.integer) and x.shape == ()
|
||||
):
|
||||
as_int = int(x) # type: ignore
|
||||
as_int = int(x)
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -99,7 +101,7 @@ class Discrete(Space[int]):
|
||||
return f"Discrete({self.n}, start={self.start})"
|
||||
return f"Discrete({self.n})"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
return (
|
||||
isinstance(other, Discrete)
|
||||
@@ -107,7 +109,7 @@ class Discrete(Space[int]):
|
||||
and self.start == other.start
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
|
||||
"""Used when loading a pickled space.
|
||||
|
||||
This method has to be implemented explicitly to allow for loading of legacy states.
|
||||
|
@@ -1,9 +1,12 @@
|
||||
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
|
||||
from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, NamedTuple, Sequence
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from gymnasium.logger import warn
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces.box import Box
|
||||
from gymnasium.spaces.discrete import Discrete
|
||||
from gymnasium.spaces.multi_discrete import MultiDiscrete
|
||||
@@ -18,25 +21,24 @@ class GraphInstance(NamedTuple):
|
||||
* edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the indices of the two nodes that each edge connects.
|
||||
"""
|
||||
|
||||
nodes: np.ndarray
|
||||
edges: Optional[np.ndarray]
|
||||
edge_links: Optional[np.ndarray]
|
||||
nodes: NDArray[Any]
|
||||
edges: NDArray[Any] | None
|
||||
edge_links: NDArray[Any] | None
|
||||
|
||||
|
||||
class Graph(Space):
|
||||
class Graph(Space[GraphInstance]):
|
||||
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> from gymnasium.spaces import Box, Discrete
|
||||
>>> Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3))
|
||||
self.observation_space = spaces.Graph(node_space=spaces.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_space: Union[Box, Discrete],
|
||||
edge_space: Union[None, Box, Discrete],
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
node_space: Box | Discrete,
|
||||
edge_space: None | Box | Discrete,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
r"""Constructor of :class:`Graph`.
|
||||
|
||||
@@ -70,8 +72,8 @@ class Graph(Space):
|
||||
return False
|
||||
|
||||
def _generate_sample_space(
|
||||
self, base_space: Union[None, Box, Discrete], num: int
|
||||
) -> Optional[Union[Box, MultiDiscrete]]:
|
||||
self, base_space: None | Box | Discrete, num: int
|
||||
) -> Box | MultiDiscrete | None:
|
||||
if num == 0 or base_space is None:
|
||||
return None
|
||||
|
||||
@@ -92,14 +94,15 @@ class Graph(Space):
|
||||
|
||||
def sample(
|
||||
self,
|
||||
mask: Optional[
|
||||
Tuple[
|
||||
Optional[Union[np.ndarray, tuple]],
|
||||
Optional[Union[np.ndarray, tuple]],
|
||||
mask: None
|
||||
| (
|
||||
tuple[
|
||||
NDArray[Any] | tuple[Any, ...] | None,
|
||||
NDArray[Any] | tuple[Any, ...] | None,
|
||||
]
|
||||
] = None,
|
||||
) = None,
|
||||
num_nodes: int = 10,
|
||||
num_edges: Optional[int] = None,
|
||||
num_edges: int | None = None,
|
||||
) -> GraphInstance:
|
||||
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
|
||||
|
||||
@@ -134,7 +137,7 @@ class Graph(Space):
|
||||
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
|
||||
else:
|
||||
if self.edge_space is None:
|
||||
warn(
|
||||
gym.logger.warn(
|
||||
f"The number of edges is set ({num_edges}) but the edge space is None."
|
||||
)
|
||||
assert (
|
||||
@@ -198,7 +201,7 @@ class Graph(Space):
|
||||
"""
|
||||
return f"Graph({self.node_space}, {self.edge_space})"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether `other` is equivalent to this instance."""
|
||||
return (
|
||||
isinstance(other, Graph)
|
||||
@@ -206,22 +209,24 @@ class Graph(Space):
|
||||
and (self.edge_space == other.edge_space)
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n: NamedTuple) -> list:
|
||||
def to_jsonable(
|
||||
self, sample_n: Sequence[GraphInstance]
|
||||
) -> list[dict[str, list[int] | list[float]]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# serialize as list of dicts
|
||||
ret_n = []
|
||||
ret_n: list[dict[str, list[int | float]]] = []
|
||||
for sample in sample_n:
|
||||
ret = {}
|
||||
ret["nodes"] = sample.nodes.tolist()
|
||||
if sample.edges is not None:
|
||||
ret = {"nodes": sample.nodes.tolist()}
|
||||
if sample.edges is not None and sample.edge_links is not None:
|
||||
ret["edges"] = sample.edges.tolist()
|
||||
ret["edge_links"] = sample.edge_links.tolist()
|
||||
ret_n.append(ret)
|
||||
return ret_n
|
||||
|
||||
def from_jsonable(self, sample_n: Sequence[dict]) -> list:
|
||||
def from_jsonable(
|
||||
self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
|
||||
) -> list[GraphInstance]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
ret = []
|
||||
ret: list[GraphInstance] = []
|
||||
for sample in sample_n:
|
||||
if "edges" in sample:
|
||||
ret_n = GraphInstance(
|
||||
|
@@ -1,12 +1,15 @@
|
||||
"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
|
||||
from typing import Optional, Sequence, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from gymnasium.spaces.space import Space
|
||||
from gymnasium.spaces.space import MaskNDArray, Space
|
||||
|
||||
|
||||
class MultiBinary(Space[np.ndarray]):
|
||||
class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
"""An n-shape binary space.
|
||||
|
||||
Elements of this space are binary arrays of a shape that is fixed during construction.
|
||||
@@ -25,8 +28,8 @@ class MultiBinary(Space[np.ndarray]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n: Union[np.ndarray, Sequence[int], int],
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
"""Constructor of :class:`MultiBinary` space.
|
||||
|
||||
@@ -46,8 +49,8 @@ class MultiBinary(Space[np.ndarray]):
|
||||
super().__init__(input_n, np.int8, seed)
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Has stricter type than gymnasium.Space - never None."""
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
"""Has stricter type than gym.Space - never None."""
|
||||
return self._shape # type: ignore
|
||||
|
||||
@property
|
||||
@@ -55,7 +58,7 @@ class MultiBinary(Space[np.ndarray]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
|
||||
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||
@@ -90,7 +93,7 @@ class MultiBinary(Space[np.ndarray]):
|
||||
|
||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, Sequence):
|
||||
x = np.array(x) # Promote list to array for contains check
|
||||
@@ -98,14 +101,18 @@ class MultiBinary(Space[np.ndarray]):
|
||||
return bool(
|
||||
isinstance(x, np.ndarray)
|
||||
and self.shape == x.shape
|
||||
and np.all((x == 0) | (x == 1))
|
||||
and np.all(np.logical_or(x == 0, x == 1))
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n) -> list:
|
||||
def to_jsonable(
|
||||
self, sample_n: Sequence[npt.NDArray[np.int8]]
|
||||
) -> list[Sequence[int]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return np.array(sample_n).tolist()
|
||||
|
||||
def from_jsonable(self, sample_n) -> list:
|
||||
def from_jsonable(
|
||||
self, sample_n: list[Sequence[int]]
|
||||
) -> list[npt.NDArray[np.int8]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
||||
|
||||
@@ -113,6 +120,6 @@ class MultiBinary(Space[np.ndarray]):
|
||||
"""Gives a string representation of this space."""
|
||||
return f"MultiBinary({self.n})"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether `other` is equivalent to this instance."""
|
||||
return isinstance(other, MultiBinary) and self.n == other.n
|
||||
|
@@ -1,14 +1,17 @@
|
||||
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
|
||||
from typing import Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from gymnasium import logger
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces.discrete import Discrete
|
||||
from gymnasium.spaces.space import Space
|
||||
from gymnasium.spaces.space import MaskNDArray, Space
|
||||
|
||||
|
||||
class MultiDiscrete(Space[np.ndarray]):
|
||||
class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
||||
|
||||
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
|
||||
@@ -29,17 +32,17 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
|
||||
Example::
|
||||
|
||||
>>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
|
||||
>>> d.sample()
|
||||
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
|
||||
>> d.sample()
|
||||
array([[0, 0],
|
||||
[2, 3]])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nvec: Union[np.ndarray, list],
|
||||
dtype=np.int64,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
nvec: npt.NDArray[np.integer[Any]] | list[int],
|
||||
dtype: str | type[np.integer[Any]] = np.int64,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
"""Constructor of :class:`MultiDiscrete` space.
|
||||
|
||||
@@ -57,8 +60,8 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
super().__init__(self.nvec.shape, dtype, seed)
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Has stricter type than :class:`gymnasium.Space` - never None."""
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
"""Has stricter type than :class:`gym.Space` - never None."""
|
||||
return self._shape # type: ignore
|
||||
|
||||
@property
|
||||
@@ -66,7 +69,9 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: Optional[tuple] = None) -> np.ndarray:
|
||||
def sample(
|
||||
self, mask: tuple[MaskNDArray, ...] | None = None
|
||||
) -> npt.NDArray[np.integer[Any]]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
Args:
|
||||
@@ -80,9 +85,9 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
if mask is not None:
|
||||
|
||||
def _apply_mask(
|
||||
sub_mask: Union[np.ndarray, tuple],
|
||||
sub_nvec: Union[np.ndarray, np.integer],
|
||||
) -> Union[int, List[int]]:
|
||||
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
|
||||
sub_nvec: MaskNDArray | np.integer[Any],
|
||||
) -> int | Sequence[int]:
|
||||
if isinstance(sub_nvec, np.ndarray):
|
||||
assert isinstance(
|
||||
sub_mask, tuple
|
||||
@@ -122,7 +127,7 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
|
||||
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, Sequence):
|
||||
x = np.array(x) # Promote list to array for contains check
|
||||
@@ -137,11 +142,15 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
and np.all(x < self.nvec)
|
||||
)
|
||||
|
||||
def to_jsonable(self, sample_n: Iterable[np.ndarray]):
|
||||
def to_jsonable(
|
||||
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]]
|
||||
) -> list[Sequence[int]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
return [sample.tolist() for sample in sample_n]
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
def from_jsonable(
|
||||
self, sample_n: list[Sequence[int]]
|
||||
) -> list[npt.NDArray[np.integer[Any]]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return np.array(sample_n)
|
||||
|
||||
@@ -149,7 +158,7 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
"""Gives a string representation of this space."""
|
||||
return f"MultiDiscrete({self.nvec})"
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index: int):
|
||||
"""Extract a subspace from this ``MultiDiscrete`` space."""
|
||||
nvec = self.nvec[index]
|
||||
if nvec.ndim == 0:
|
||||
@@ -165,11 +174,13 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
def __len__(self):
|
||||
"""Gives the ``len`` of samples from this space."""
|
||||
if self.nvec.ndim >= 2:
|
||||
logger.warn(
|
||||
gym.logger.warn(
|
||||
"Getting the length of a multi-dimensional MultiDiscrete space."
|
||||
)
|
||||
return len(self.nvec)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||
return bool(
|
||||
isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||
)
|
||||
|
@@ -1,20 +1,24 @@
|
||||
"""Implementation of a space that represents finite-length sequences."""
|
||||
from collections.abc import Sequence as CollectionSequence
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
|
||||
class Sequence(Space[Tuple]):
|
||||
class Sequence(Space[typing.Tuple[Any, ...]]):
|
||||
r"""This space represent sets of finite-length sequences.
|
||||
|
||||
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
|
||||
to some space that is specified during initialization and the integer :math:`n` is not fixed
|
||||
|
||||
Example::
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> space = Sequence(Box(0, 1))
|
||||
>>> space.sample()
|
||||
(array([0.0259352], dtype=float32),)
|
||||
@@ -24,8 +28,8 @@ class Sequence(Space[Tuple]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
space: Space,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
space: Space[Any],
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
"""Constructor of the :class:`Sequence` space.
|
||||
|
||||
@@ -34,14 +38,14 @@ class Sequence(Space[Tuple]):
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||
"""
|
||||
assert isinstance(
|
||||
space, gym.Space
|
||||
), f"Expects the feature space to be instance of a gymnasium Space, actual type: {type(space)}"
|
||||
space, Space
|
||||
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
|
||||
self.feature_space = space
|
||||
super().__init__(
|
||||
None, None, seed # type: ignore
|
||||
) # None for shape and dtype, since it'll require special handling
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> list:
|
||||
# None for shape and dtype, since it'll require special handling
|
||||
super().__init__(None, None, seed)
|
||||
|
||||
def seed(self, seed: int | None = None) -> list[int]:
|
||||
"""Seed the PRNG of this space and the feature space."""
|
||||
seeds = super().seed(seed)
|
||||
seeds += self.feature_space.seed(seed)
|
||||
@@ -54,8 +58,14 @@ class Sequence(Space[Tuple]):
|
||||
|
||||
def sample(
|
||||
self,
|
||||
mask: Optional[Tuple[Optional[Union[np.ndarray, int]], Optional[Any]]] = None,
|
||||
) -> Tuple[Any]:
|
||||
mask: None
|
||||
| (
|
||||
tuple[
|
||||
None | np.integer | npt.NDArray[np.integer],
|
||||
Any,
|
||||
]
|
||||
) = None,
|
||||
) -> tuple[Any]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
Args:
|
||||
@@ -89,6 +99,9 @@ class Sequence(Space[Tuple]):
|
||||
assert np.all(
|
||||
0 <= length_mask
|
||||
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
|
||||
assert np.issubdtype(
|
||||
length_mask.dtype, np.integer
|
||||
), f"Expects the length mask array to have dtype to be an numpy integer, actual type: {length_mask.dtype}"
|
||||
length = self.np_random.choice(length_mask)
|
||||
else:
|
||||
raise TypeError(
|
||||
@@ -102,9 +115,9 @@ class Sequence(Space[Tuple]):
|
||||
self.feature_space.sample(mask=feature_mask) for _ in range(length)
|
||||
)
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
return isinstance(x, CollectionSequence) and all(
|
||||
return isinstance(x, collections.abc.Sequence) and all(
|
||||
self.feature_space.contains(item) for item in x
|
||||
)
|
||||
|
||||
@@ -112,15 +125,17 @@ class Sequence(Space[Tuple]):
|
||||
"""Gives a string representation of this space."""
|
||||
return f"Sequence({self.feature_space})"
|
||||
|
||||
def to_jsonable(self, sample_n: list) -> list:
|
||||
def to_jsonable(
|
||||
self, sample_n: typing.Sequence[tuple[Any, ...]]
|
||||
) -> list[list[Any]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# serialize as dict-repr of vectors
|
||||
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]
|
||||
|
||||
def from_jsonable(self, sample_n: List[List[Any]]) -> list:
|
||||
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
return isinstance(other, Sequence) and self.feature_space == other.feature_space
|
||||
|
@@ -1,26 +1,19 @@
|
||||
"""Implementation of the `Space` metaclass."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from gymnasium.utils import seeding
|
||||
|
||||
T_cov = TypeVar("T_cov", covariant=True)
|
||||
|
||||
|
||||
MaskNDArray = npt.NDArray[np.int8]
|
||||
|
||||
|
||||
class Space(Generic[T_cov]):
|
||||
"""Superclass that is used to define observation and action spaces.
|
||||
|
||||
@@ -41,17 +34,17 @@ class Space(Generic[T_cov]):
|
||||
class. However, most use-cases should be covered by the existing space
|
||||
classes (e.g. :class:`Box`, :class:`Discrete`, etc...), and container classes (:class`Tuple` &
|
||||
:class:`Dict`). Note that parametrized probability distributions (through the
|
||||
:meth:`Space.sample()` method), and batching functions (in :class:`gymnasium.vector.VectorEnv`), are
|
||||
only well-defined for instances of spaces provided in gymnasium by default.
|
||||
:meth:`Space.sample()` method), and batching functions (in :class:`gym.vector.VectorEnv`), are
|
||||
only well-defined for instances of spaces provided in gym by default.
|
||||
Moreover, some implementations of Reinforcement Learning algorithms might
|
||||
not handle custom spaces properly. Use custom spaces with care.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape: Optional[Sequence[int]] = None,
|
||||
dtype: Optional[Union[Type, str, np.dtype]] = None,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
shape: Sequence[int] | None = None,
|
||||
dtype: npt.DTypeLike | None = None,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
"""Constructor of :class:`Space`.
|
||||
|
||||
@@ -71,23 +64,31 @@ class Space(Generic[T_cov]):
|
||||
|
||||
@property
|
||||
def np_random(self) -> np.random.Generator:
|
||||
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space."""
|
||||
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space.
|
||||
|
||||
As :meth:`seed` is not guaranteed to set the `_np_random` for particular seeds. We add a
|
||||
check after :meth:`seed` to set a new random number generator.
|
||||
"""
|
||||
if self._np_random is None:
|
||||
self.seed()
|
||||
|
||||
return self._np_random # type: ignore ## self.seed() call guarantees right type.
|
||||
# As `seed` is not guaranteed (in particular for composite spaces) to set the `_np_random` then we set it randomly.
|
||||
if self._np_random is None:
|
||||
self._np_random, _ = seeding.np_random()
|
||||
|
||||
return self._np_random
|
||||
|
||||
@property
|
||||
def shape(self) -> Optional[Tuple[int, ...]]:
|
||||
def shape(self) -> tuple[int, ...] | None:
|
||||
"""Return the shape of the space as an immutable property."""
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def is_np_flattenable(self):
|
||||
def is_np_flattenable(self) -> bool:
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self, mask: Optional[Any] = None) -> T_cov:
|
||||
def sample(self, mask: Any | None = None) -> T_cov:
|
||||
"""Randomly sample an element of this space.
|
||||
|
||||
Can be uniform or non-uniform sampling based on boundedness of space.
|
||||
@@ -100,20 +101,20 @@ class Space(Generic[T_cov]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> list:
|
||||
def seed(self, seed: int | None = None) -> list[int]:
|
||||
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __contains__(self, x) -> bool:
|
||||
def __contains__(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
return self.contains(x)
|
||||
|
||||
def __setstate__(self, state: Union[Iterable, Mapping]):
|
||||
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
|
||||
"""Used when loading a pickled space.
|
||||
|
||||
This method was implemented explicitly to allow for loading of legacy states.
|
||||
@@ -130,7 +131,7 @@ class Space(Generic[T_cov]):
|
||||
# https://github.com/openai/gym/pull/1913 -- np_random
|
||||
#
|
||||
if "shape" in state:
|
||||
state["_shape"] = state["shape"]
|
||||
state["_shape"] = state.get("shape")
|
||||
del state["shape"]
|
||||
if "np_random" in state:
|
||||
state["_np_random"] = state["np_random"]
|
||||
@@ -139,12 +140,12 @@ class Space(Generic[T_cov]):
|
||||
# Update our state
|
||||
self.__dict__.update(state)
|
||||
|
||||
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list:
|
||||
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list[Any]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# By default, assume identity is JSONable
|
||||
return list(sample_n)
|
||||
|
||||
def from_jsonable(self, sample_n: list) -> List[T_cov]:
|
||||
def from_jsonable(self, sample_n: list[Any]) -> list[T_cov]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
# By default, assume identity is JSONable
|
||||
return sample_n
|
||||
|
@@ -1,11 +1,14 @@
|
||||
"""Implementation of a space that represents textual strings."""
|
||||
from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
alphanumeric: FrozenSet[str] = frozenset(
|
||||
alphanumeric: frozenset[str] = frozenset(
|
||||
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
)
|
||||
|
||||
@@ -14,7 +17,6 @@ class Text(Space[str]):
|
||||
r"""A space representing a string comprised of characters from a given charset.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # {"", "B5", "hello", ...}
|
||||
>>> Text(5)
|
||||
>>> # {"0", "42", "0123456789", ...}
|
||||
@@ -29,8 +31,8 @@ class Text(Space[str]):
|
||||
max_length: int,
|
||||
*,
|
||||
min_length: int = 1,
|
||||
charset: Union[Set[str], str] = alphanumeric,
|
||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
||||
charset: set[str] | str = alphanumeric,
|
||||
seed: int | np.random.Generator | None = None,
|
||||
):
|
||||
r"""Constructor of :class:`Text` space.
|
||||
|
||||
@@ -58,9 +60,9 @@ class Text(Space[str]):
|
||||
self.min_length: int = int(min_length)
|
||||
self.max_length: int = int(max_length)
|
||||
|
||||
self._char_set: FrozenSet[str] = frozenset(charset)
|
||||
self._char_list: Tuple[str, ...] = tuple(charset)
|
||||
self._char_index: Dict[str, np.int32] = {
|
||||
self._char_set: frozenset[str] = frozenset(charset)
|
||||
self._char_list: tuple[str, ...] = tuple(charset)
|
||||
self._char_index: dict[str, np.int32] = {
|
||||
val: np.int32(i) for i, val in enumerate(tuple(charset))
|
||||
}
|
||||
self._char_str: str = "".join(sorted(tuple(charset)))
|
||||
@@ -69,7 +71,8 @@ class Text(Space[str]):
|
||||
super().__init__(dtype=str, seed=seed)
|
||||
|
||||
def sample(
|
||||
self, mask: Optional[Tuple[Optional[int], Optional[np.ndarray]]] = None
|
||||
self,
|
||||
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
|
||||
) -> str:
|
||||
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
|
||||
|
||||
@@ -152,7 +155,7 @@ class Text(Space[str]):
|
||||
f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
return (
|
||||
isinstance(other, Text)
|
||||
@@ -162,12 +165,12 @@ class Text(Space[str]):
|
||||
)
|
||||
|
||||
@property
|
||||
def character_set(self) -> FrozenSet[str]:
|
||||
def character_set(self) -> frozenset[str]:
|
||||
"""Returns the character set for the space."""
|
||||
return self._char_set
|
||||
|
||||
@property
|
||||
def character_list(self) -> Tuple[str, ...]:
|
||||
def character_list(self) -> tuple[str, ...]:
|
||||
"""Returns a tuple of characters in the space."""
|
||||
return self._char_list
|
||||
|
||||
|
@@ -1,16 +1,16 @@
|
||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||
from collections.abc import Sequence as CollectionSequence
|
||||
from typing import Iterable, Optional
|
||||
from typing import Sequence as TypingSequence
|
||||
from typing import Tuple as TypingTuple
|
||||
from typing import Union
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
|
||||
class Tuple(Space[tuple], CollectionSequence):
|
||||
class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
||||
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
|
||||
|
||||
Elements of this space are tuples of elements of the constituent spaces.
|
||||
@@ -25,8 +25,8 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spaces: Iterable[Space],
|
||||
seed: Optional[Union[int, TypingSequence[int], np.random.Generator]] = None,
|
||||
spaces: Iterable[Space[Any]],
|
||||
seed: int | typing.Sequence[int] | np.random.Generator | None = None,
|
||||
):
|
||||
r"""Constructor of :class:`Tuple` space.
|
||||
|
||||
@@ -40,7 +40,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
for space in self.spaces:
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), "Elements of the tuple must be instances of gymnasium.Space"
|
||||
), "Elements of the tuple must be instances of gym.Space"
|
||||
super().__init__(None, None, seed) # type: ignore
|
||||
|
||||
@property
|
||||
@@ -48,9 +48,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return all(space.is_np_flattenable for space in self.spaces)
|
||||
|
||||
def seed(
|
||||
self, seed: Optional[Union[int, TypingSequence[int]]] = None
|
||||
) -> TypingSequence[int]:
|
||||
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
|
||||
"""Seed the PRNG of this space and all subspaces.
|
||||
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
@@ -61,9 +59,9 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
Args:
|
||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||
"""
|
||||
seeds = []
|
||||
seeds: list[int] = []
|
||||
|
||||
if isinstance(seed, CollectionSequence):
|
||||
if isinstance(seed, collections.abc.Sequence):
|
||||
assert len(seed) == len(
|
||||
self.spaces
|
||||
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
|
||||
@@ -86,9 +84,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
|
||||
return seeds
|
||||
|
||||
def sample(
|
||||
self, mask: Optional[TypingTuple[Optional[np.ndarray], ...]] = None
|
||||
) -> tuple:
|
||||
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
|
||||
"""Generates a single random sample inside this space.
|
||||
|
||||
This method draws independent samples from the subspaces.
|
||||
@@ -115,10 +111,11 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
|
||||
return tuple(space.sample() for space in self.spaces)
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
if isinstance(x, (list, np.ndarray)):
|
||||
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
||||
|
||||
return (
|
||||
isinstance(x, tuple)
|
||||
and len(x) == len(self.spaces)
|
||||
@@ -129,7 +126,9 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
"""Gives a string representation of this space."""
|
||||
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
||||
|
||||
def to_jsonable(self, sample_n: CollectionSequence) -> list:
|
||||
def to_jsonable(
|
||||
self, sample_n: typing.Sequence[tuple[Any, ...]]
|
||||
) -> list[list[Any]]:
|
||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||
# serialize as list-repr of tuple of vectors
|
||||
return [
|
||||
@@ -137,7 +136,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
for i, space in enumerate(self.spaces)
|
||||
]
|
||||
|
||||
def from_jsonable(self, sample_n) -> list:
|
||||
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
|
||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||
return [
|
||||
sample
|
||||
@@ -149,7 +148,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
)
|
||||
]
|
||||
|
||||
def __getitem__(self, index: int) -> Space:
|
||||
def __getitem__(self, index: int) -> Space[Any]:
|
||||
"""Get the subspace at specific `index`."""
|
||||
return self.spaces[index]
|
||||
|
||||
@@ -157,6 +156,6 @@ class Tuple(Space[tuple], CollectionSequence):
|
||||
"""Get the number of subspaces that are involved in the cartesian product."""
|
||||
return len(self.spaces)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
return isinstance(other, Tuple) and self.spaces == other.spaces
|
||||
|
@@ -3,13 +3,16 @@
|
||||
These functions mostly take care of flattening and unflattening elements of spaces
|
||||
to facilitate their usage in learning code.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import operator as op
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from functools import reduce, singledispatch
|
||||
from typing import Dict as TypingDict
|
||||
from typing import TypeVar, Union, cast
|
||||
from typing import Any, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from gymnasium.spaces import (
|
||||
Box,
|
||||
@@ -27,12 +30,12 @@ from gymnasium.spaces import (
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatdim(space: Space) -> int:
|
||||
def flatdim(space: Space[Any]) -> int:
|
||||
"""Return the number of dimensions a flattened equivalent of this space would have.
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> from gymnasium.spaces import Discrete
|
||||
>>> from gymnasium.spaces import Discrete, Dict
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||
>>> flatdim(space)
|
||||
5
|
||||
@@ -47,7 +50,7 @@ def flatdim(space: Space) -> int:
|
||||
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||
ValueError: if the space cannot be flattened into a :class:`Box`
|
||||
"""
|
||||
if not space.is_np_flattenable:
|
||||
if space.is_np_flattenable is False:
|
||||
raise ValueError(
|
||||
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
|
||||
)
|
||||
@@ -57,7 +60,7 @@ def flatdim(space: Space) -> int:
|
||||
|
||||
@flatdim.register(Box)
|
||||
@flatdim.register(MultiBinary)
|
||||
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
|
||||
def _flatdim_box_multibinary(space: Box | MultiBinary) -> int:
|
||||
return reduce(op.mul, space.shape, 1)
|
||||
|
||||
|
||||
@@ -102,7 +105,9 @@ def _flatdim_text(space: Text) -> int:
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
|
||||
FlatType = Union[
|
||||
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
|
||||
]
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -112,6 +117,19 @@ def flatten(space: Space[T], x: T) -> FlatType:
|
||||
This is useful when e.g. points from spaces must be passed to a neural
|
||||
network, which only understands flat arrays of floats.
|
||||
|
||||
Example usage::
|
||||
>>> from gymnasium.spaces import Box, Discrete, Tuple
|
||||
>>> space = Box(0, 1, shape=(3, 5))
|
||||
>>> flatten(space, space.sample()).shape
|
||||
(15,)
|
||||
>>> space = Discrete(4)
|
||||
>>> flatten(space, 2)
|
||||
array([0, 0, 1, 0])
|
||||
>>> space = Tuple((Box(0, 1, shape=(2,)), Box(0, 1, shape=(3,)), Discrete(3)))
|
||||
>>> example = ((.5, .25), (1., 0., .2), 1)
|
||||
>>> flatten(space, example)
|
||||
array([0.5 , 0.25, 1. , 0. , 0.2 , 0. , 1. , 0. ])
|
||||
|
||||
Args:
|
||||
space: The space that ``x`` is flattened by
|
||||
x: The value to flatten
|
||||
@@ -137,19 +155,21 @@ def flatten(space: Space[T], x: T) -> FlatType:
|
||||
|
||||
@flatten.register(Box)
|
||||
@flatten.register(MultiBinary)
|
||||
def _flatten_box_multibinary(space, x) -> np.ndarray:
|
||||
def _flatten_box_multibinary(space: Box | MultiBinary, x: NDArray[Any]) -> NDArray[Any]:
|
||||
return np.asarray(x, dtype=space.dtype).flatten()
|
||||
|
||||
|
||||
@flatten.register(Discrete)
|
||||
def _flatten_discrete(space, x) -> np.ndarray:
|
||||
def _flatten_discrete(space: Discrete, x: int) -> NDArray[np.int64]:
|
||||
onehot = np.zeros(space.n, dtype=space.dtype)
|
||||
onehot[x - space.start] = 1
|
||||
return onehot
|
||||
|
||||
|
||||
@flatten.register(MultiDiscrete)
|
||||
def _flatten_multidiscrete(space, x) -> np.ndarray:
|
||||
def _flatten_multidiscrete(
|
||||
space: MultiDiscrete, x: NDArray[np.int64]
|
||||
) -> NDArray[np.int64]:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=np.int32)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
@@ -159,7 +179,7 @@ def _flatten_multidiscrete(space, x) -> np.ndarray:
|
||||
|
||||
|
||||
@flatten.register(Tuple)
|
||||
def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
|
||||
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
|
||||
if space.is_np_flattenable:
|
||||
return np.concatenate(
|
||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
||||
@@ -168,22 +188,26 @@ def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
|
||||
|
||||
|
||||
@flatten.register(Dict)
|
||||
def _flatten_dict(space, x) -> Union[dict, np.ndarray]:
|
||||
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
|
||||
if space.is_np_flattenable:
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
||||
|
||||
|
||||
@flatten.register(Graph)
|
||||
def _flatten_graph(space, x) -> GraphInstance:
|
||||
def _flatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
||||
|
||||
def _graph_unflatten(unflatten_space, unflatten_x):
|
||||
def _graph_unflatten(
|
||||
unflatten_space: Discrete | Box | None,
|
||||
unflatten_x: NDArray[Any] | None,
|
||||
) -> NDArray[Any] | None:
|
||||
ret = None
|
||||
if unflatten_space is not None and unflatten_x is not None:
|
||||
if isinstance(unflatten_space, Box):
|
||||
ret = unflatten_x.reshape(unflatten_x.shape[0], -1)
|
||||
elif isinstance(unflatten_space, Discrete):
|
||||
else:
|
||||
assert isinstance(unflatten_space, Discrete)
|
||||
ret = np.zeros(
|
||||
(unflatten_x.shape[0], unflatten_space.n - unflatten_space.start),
|
||||
dtype=unflatten_space.dtype,
|
||||
@@ -194,13 +218,14 @@ def _flatten_graph(space, x) -> GraphInstance:
|
||||
return ret
|
||||
|
||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||
assert nodes is not None
|
||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||
|
||||
return GraphInstance(nodes, edges, x.edge_links)
|
||||
|
||||
|
||||
@flatten.register(Text)
|
||||
def _flatten_text(space: Text, x: str) -> np.ndarray:
|
||||
def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
|
||||
arr = np.full(
|
||||
shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
|
||||
)
|
||||
@@ -210,7 +235,7 @@ def _flatten_text(space: Text, x: str) -> np.ndarray:
|
||||
|
||||
|
||||
@flatten.register(Sequence)
|
||||
def _flatten_sequence(space, x) -> tuple:
|
||||
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
return tuple(flatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@@ -237,18 +262,20 @@ def unflatten(space: Space[T], x: FlatType) -> T:
|
||||
@unflatten.register(Box)
|
||||
@unflatten.register(MultiBinary)
|
||||
def _unflatten_box_multibinary(
|
||||
space: Union[Box, MultiBinary], x: np.ndarray
|
||||
) -> np.ndarray:
|
||||
space: Box | MultiBinary, x: NDArray[Any]
|
||||
) -> NDArray[Any]:
|
||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||
|
||||
|
||||
@unflatten.register(Discrete)
|
||||
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
|
||||
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> int:
|
||||
return int(space.start + np.nonzero(x)[0][0])
|
||||
|
||||
|
||||
@unflatten.register(MultiDiscrete)
|
||||
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
|
||||
def _unflatten_multidiscrete(
|
||||
space: MultiDiscrete, x: NDArray[np.int32]
|
||||
) -> NDArray[np.int32]:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
@@ -257,7 +284,9 @@ def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
@unflatten.register(Tuple)
|
||||
def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
|
||||
def _unflatten_tuple(
|
||||
space: Tuple, x: NDArray[Any] | tuple[Any, ...]
|
||||
) -> tuple[Any, ...]:
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(
|
||||
x, np.ndarray
|
||||
@@ -275,7 +304,7 @@ def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
|
||||
|
||||
|
||||
@unflatten.register(Dict)
|
||||
def _unflatten_dict(space: Dict, x: Union[np.ndarray, TypingDict]) -> dict:
|
||||
def _unflatten_dict(space: Dict, x: NDArray[Any] | dict[str, Any]) -> dict[str, Any]:
|
||||
if space.is_np_flattenable:
|
||||
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
|
||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||
@@ -299,14 +328,14 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||
nodes and edges in the graph.
|
||||
"""
|
||||
|
||||
def _graph_unflatten(space, x):
|
||||
ret = None
|
||||
if space is not None and x is not None:
|
||||
if isinstance(space, Box):
|
||||
ret = x.reshape(-1, *space.shape)
|
||||
elif isinstance(space, Discrete):
|
||||
ret = np.asarray(np.nonzero(x))[-1, :]
|
||||
return ret
|
||||
def _graph_unflatten(unflatten_space, unflatten_x):
|
||||
result = None
|
||||
if unflatten_space is not None and unflatten_x is not None:
|
||||
if isinstance(unflatten_space, Box):
|
||||
result = unflatten_x.reshape(-1, *unflatten_space.shape)
|
||||
elif isinstance(unflatten_space, Discrete):
|
||||
result = np.asarray(np.nonzero(unflatten_x))[-1, :]
|
||||
return result
|
||||
|
||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||
@@ -315,19 +344,19 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||
|
||||
|
||||
@unflatten.register(Text)
|
||||
def _unflatten_text(space: Text, x: np.ndarray) -> str:
|
||||
def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
|
||||
return "".join(
|
||||
[space.character_list[val] for val in x if val < len(space.character_set)]
|
||||
)
|
||||
|
||||
|
||||
@unflatten.register(Sequence)
|
||||
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
|
||||
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
return tuple(unflatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
||||
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
"""Flatten a space into a space that is as flat as possible.
|
||||
|
||||
This function will attempt to flatten `space` into a single :class:`Box` space.
|
||||
@@ -342,7 +371,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
||||
space.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
|
||||
>>> box
|
||||
Box(3, 4, 5)
|
||||
@@ -352,7 +381,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
||||
True
|
||||
|
||||
Example that flattens a discrete space::
|
||||
|
||||
>>> from gymnasium.spaces import Discrete
|
||||
>>> discrete = Discrete(5)
|
||||
>>> flatten_space(discrete)
|
||||
Box(5,)
|
||||
@@ -360,7 +389,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
||||
True
|
||||
|
||||
Example that recursively flattens a dict::
|
||||
|
||||
>>> from gymnasium.spaces import Dict, Discrete, Box
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
|
||||
>>> flatten_space(space)
|
||||
Box(6,)
|
||||
@@ -383,7 +412,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
||||
A flattened Box
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the space is not defined in ``gymnasium.spaces``.
|
||||
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||
"""
|
||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||
|
||||
@@ -396,12 +425,12 @@ def _flatten_space_box(space: Box) -> Box:
|
||||
@flatten_space.register(Discrete)
|
||||
@flatten_space.register(MultiBinary)
|
||||
@flatten_space.register(MultiDiscrete)
|
||||
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
|
||||
def _flatten_space_binary(space: Discrete | MultiBinary | MultiDiscrete) -> Box:
|
||||
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
|
||||
|
||||
|
||||
@flatten_space.register(Tuple)
|
||||
def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
|
||||
def _flatten_space_tuple(space: Tuple) -> Box | Tuple:
|
||||
if space.is_np_flattenable:
|
||||
space_list = [flatten_space(s) for s in space.spaces]
|
||||
return Box(
|
||||
@@ -413,7 +442,7 @@ def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
|
||||
|
||||
|
||||
@flatten_space.register(Dict)
|
||||
def _flatten_space_dict(space: Dict) -> Union[Box, Dict]:
|
||||
def _flatten_space_dict(space: Dict) -> Box | Dict:
|
||||
if space.is_np_flattenable:
|
||||
space_list = [flatten_space(s) for s in space.spaces.values()]
|
||||
return Box(
|
||||
|
@@ -1,4 +1,4 @@
|
||||
numpy>=1.18.0
|
||||
numpy>=1.21.0
|
||||
cloudpickle>=1.2.0
|
||||
importlib_metadata>=4.8.0; python_version < '3.10'
|
||||
gymnasium_notices>=0.0.1
|
||||
|
2
setup.py
2
setup.py
@@ -85,7 +85,7 @@ setup(
|
||||
},
|
||||
include_package_data=True,
|
||||
install_requires=[
|
||||
"numpy >= 1.18.0",
|
||||
"numpy >= 1.21.0",
|
||||
"cloudpickle >= 1.2.0",
|
||||
"importlib_metadata >= 4.8.0; python_version < '3.10'",
|
||||
"gymnasium_notices >= 0.0.1",
|
||||
|
@@ -1,3 +1,5 @@
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
@@ -25,18 +27,18 @@ def test_dict_init():
|
||||
):
|
||||
Dict(a=Discrete(2), b="Box")
|
||||
|
||||
with pytest.warns(None) as warnings:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
|
||||
b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
|
||||
c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
|
||||
d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))
|
||||
|
||||
assert a == b == c == d
|
||||
assert len(warnings) == 0
|
||||
assert len(caught_warnings) == 0
|
||||
|
||||
with pytest.warns(None) as warnings:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
Dict({1: Discrete(2), "a": Discrete(3)})
|
||||
assert len(warnings) == 0
|
||||
assert len(caught_warnings) == 0
|
||||
|
||||
|
||||
DICT_SPACE = Dict(
|
||||
@@ -109,7 +111,12 @@ def test_none_seeding():
|
||||
|
||||
|
||||
def test_bad_seed():
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=re.escape(
|
||||
"Expected seed type: dict, int or None, actual type: <class 'str'>"
|
||||
),
|
||||
):
|
||||
DICT_SPACE.seed("a")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user