mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Update the spaces for complete type hinting (and updates numpy to 1.21) (#37)
This commit is contained in:
@@ -1,14 +1,16 @@
|
|||||||
"""Implementation of a space that represents closed boxes in euclidean space."""
|
"""Implementation of a space that represents closed boxes in euclidean space."""
|
||||||
from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Iterable, Mapping, Sequence, SupportsFloat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import logger
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import Space
|
||||||
|
|
||||||
|
|
||||||
def _short_repr(arr: np.ndarray) -> str:
|
def _short_repr(arr: NDArray[Any]) -> str:
|
||||||
"""Create a shortened string representation of a numpy array.
|
"""Create a shortened string representation of a numpy array.
|
||||||
|
|
||||||
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
||||||
@@ -25,12 +27,12 @@ def _short_repr(arr: np.ndarray) -> str:
|
|||||||
return str(arr)
|
return str(arr)
|
||||||
|
|
||||||
|
|
||||||
def is_float_integer(var) -> bool:
|
def is_float_integer(var: Any) -> bool:
|
||||||
"""Checks if a variable is an integer or float."""
|
"""Checks if a variable is an integer or float."""
|
||||||
return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating)
|
return np.issubdtype(type(var), np.integer) or np.issubdtype(type(var), np.floating)
|
||||||
|
|
||||||
|
|
||||||
class Box(Space[np.ndarray]):
|
class Box(Space[NDArray[Any]]):
|
||||||
r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.
|
r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.
|
||||||
|
|
||||||
Specifically, a Box represents the Cartesian product of n closed intervals.
|
Specifically, a Box represents the Cartesian product of n closed intervals.
|
||||||
@@ -52,11 +54,11 @@ class Box(Space[np.ndarray]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
low: Union[SupportsFloat, np.ndarray],
|
low: SupportsFloat | NDArray[Any],
|
||||||
high: Union[SupportsFloat, np.ndarray],
|
high: SupportsFloat | NDArray[Any],
|
||||||
shape: Optional[Sequence[int]] = None,
|
shape: Sequence[int] | None = None,
|
||||||
dtype: Type = np.float32,
|
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Box`.
|
r"""Constructor of :class:`Box`.
|
||||||
|
|
||||||
@@ -102,12 +104,12 @@ class Box(Space[np.ndarray]):
|
|||||||
|
|
||||||
# Capture the boundedness information before replacing np.inf with get_inf
|
# Capture the boundedness information before replacing np.inf with get_inf
|
||||||
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
|
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
|
||||||
self.bounded_below = -np.inf < _low
|
self.bounded_below: bool = -np.inf < _low
|
||||||
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
|
||||||
self.bounded_above = np.inf > _high
|
self.bounded_above: bool = np.inf > _high
|
||||||
|
|
||||||
low = _broadcast(low, dtype, shape, inf_sign="-") # type: ignore
|
low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-")
|
||||||
high = _broadcast(high, dtype, shape, inf_sign="+") # type: ignore
|
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+")
|
||||||
|
|
||||||
assert isinstance(low, np.ndarray)
|
assert isinstance(low, np.ndarray)
|
||||||
assert (
|
assert (
|
||||||
@@ -118,13 +120,13 @@ class Box(Space[np.ndarray]):
|
|||||||
high.shape == shape
|
high.shape == shape
|
||||||
), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"
|
), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}"
|
||||||
|
|
||||||
self._shape: Tuple[int, ...] = shape
|
self._shape: tuple[int, ...] = shape
|
||||||
|
|
||||||
low_precision = get_precision(low.dtype)
|
low_precision = get_precision(low.dtype)
|
||||||
high_precision = get_precision(high.dtype)
|
high_precision = get_precision(high.dtype)
|
||||||
dtype_precision = get_precision(self.dtype)
|
dtype_precision = get_precision(self.dtype)
|
||||||
if min(low_precision, high_precision) > dtype_precision: # type: ignore
|
if min(low_precision, high_precision) > dtype_precision:
|
||||||
logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
|
gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
|
||||||
self.low = low.astype(self.dtype)
|
self.low = low.astype(self.dtype)
|
||||||
self.high = high.astype(self.dtype)
|
self.high = high.astype(self.dtype)
|
||||||
|
|
||||||
@@ -134,8 +136,8 @@ class Box(Space[np.ndarray]):
|
|||||||
super().__init__(self.shape, self.dtype, seed)
|
super().__init__(self.shape, self.dtype, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
"""Has stricter type than gymnasium.Space - never None."""
|
"""Has stricter type than gym.Space - never None."""
|
||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -168,7 +170,7 @@ class Box(Space[np.ndarray]):
|
|||||||
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
|
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample(self, mask: None = None) -> np.ndarray:
|
def sample(self, mask: None = None) -> NDArray[Any]:
|
||||||
r"""Generates a single random sample inside the Box.
|
r"""Generates a single random sample inside the Box.
|
||||||
|
|
||||||
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
|
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
|
||||||
@@ -193,8 +195,7 @@ class Box(Space[np.ndarray]):
|
|||||||
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
|
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
|
||||||
sample = np.empty(self.shape)
|
sample = np.empty(self.shape)
|
||||||
|
|
||||||
# Masking arrays which classify the coordinates according to interval
|
# Masking arrays which classify the coordinates according to interval type
|
||||||
# type
|
|
||||||
unbounded = ~self.bounded_below & ~self.bounded_above
|
unbounded = ~self.bounded_below & ~self.bounded_above
|
||||||
upp_bounded = ~self.bounded_below & self.bounded_above
|
upp_bounded = ~self.bounded_below & self.bounded_above
|
||||||
low_bounded = self.bounded_below & ~self.bounded_above
|
low_bounded = self.bounded_below & ~self.bounded_above
|
||||||
@@ -221,10 +222,10 @@ class Box(Space[np.ndarray]):
|
|||||||
|
|
||||||
return sample.astype(self.dtype)
|
return sample.astype(self.dtype)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if not isinstance(x, np.ndarray):
|
if not isinstance(x, np.ndarray):
|
||||||
logger.warn("Casting input x to numpy array.")
|
gym.logger.warn("Casting input x to numpy array.")
|
||||||
try:
|
try:
|
||||||
x = np.asarray(x, dtype=self.dtype)
|
x = np.asarray(x, dtype=self.dtype)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
@@ -237,11 +238,11 @@ class Box(Space[np.ndarray]):
|
|||||||
and np.all(x <= self.high)
|
and np.all(x <= self.high)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n):
|
def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[NDArray[Any]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return np.array(sample_n).tolist()
|
return np.array(sample_n).tolist()
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: Sequence[Union[float, int]]) -> List[np.ndarray]:
|
def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [np.asarray(sample) for sample in sample_n]
|
return [np.asarray(sample) for sample in sample_n]
|
||||||
|
|
||||||
@@ -256,7 +257,7 @@ class Box(Space[np.ndarray]):
|
|||||||
"""
|
"""
|
||||||
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
|
"""Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
|
||||||
return (
|
return (
|
||||||
isinstance(other, Box)
|
isinstance(other, Box)
|
||||||
@@ -266,7 +267,7 @@ class Box(Space[np.ndarray]):
|
|||||||
and np.allclose(self.high, other.high)
|
and np.allclose(self.high, other.high)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __setstate__(self, state: Dict):
|
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
|
||||||
"""Sets the state of the box for unpickling a box with legacy support."""
|
"""Sets the state of the box for unpickling a box with legacy support."""
|
||||||
super().__setstate__(state)
|
super().__setstate__(state)
|
||||||
|
|
||||||
@@ -278,7 +279,7 @@ class Box(Space[np.ndarray]):
|
|||||||
self.high_repr = _short_repr(self.high)
|
self.high_repr = _short_repr(self.high)
|
||||||
|
|
||||||
|
|
||||||
def get_inf(dtype, sign: str) -> SupportsFloat:
|
def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat:
|
||||||
"""Returns an infinite that doesn't break things.
|
"""Returns an infinite that doesn't break things.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -310,7 +311,7 @@ def get_inf(dtype, sign: str) -> SupportsFloat:
|
|||||||
raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
|
raise ValueError(f"Unknown dtype {dtype} for infinite bounds")
|
||||||
|
|
||||||
|
|
||||||
def get_precision(dtype) -> SupportsFloat:
|
def get_precision(dtype: np.dtype) -> SupportsFloat:
|
||||||
"""Get precision of a data type."""
|
"""Get precision of a data type."""
|
||||||
if np.issubdtype(dtype, np.floating):
|
if np.issubdtype(dtype, np.floating):
|
||||||
return np.finfo(dtype).precision
|
return np.finfo(dtype).precision
|
||||||
@@ -319,14 +320,14 @@ def get_precision(dtype) -> SupportsFloat:
|
|||||||
|
|
||||||
|
|
||||||
def _broadcast(
|
def _broadcast(
|
||||||
value: Union[SupportsFloat, np.ndarray],
|
value: SupportsFloat | NDArray[Any],
|
||||||
dtype,
|
dtype: np.dtype,
|
||||||
shape: Tuple[int, ...],
|
shape: tuple[int, ...],
|
||||||
inf_sign: str,
|
inf_sign: str,
|
||||||
) -> np.ndarray:
|
) -> NDArray[Any]:
|
||||||
"""Handle infinite bounds and broadcast at the same time if needed."""
|
"""Handle infinite bounds and broadcast at the same time if needed."""
|
||||||
if is_float_integer(value):
|
if is_float_integer(value):
|
||||||
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
|
value = get_inf(dtype, inf_sign) if np.isinf(value) else value
|
||||||
value = np.full(shape, value, dtype=dtype)
|
value = np.full(shape, value, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
assert isinstance(value, np.ndarray)
|
assert isinstance(value, np.ndarray)
|
||||||
|
@@ -1,18 +1,17 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
|
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import collections.abc
|
||||||
|
import typing
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Mapping, Sequence
|
from typing import Any, Sequence
|
||||||
from typing import Any
|
|
||||||
from typing import Dict as TypingDict
|
|
||||||
from typing import List, Optional
|
|
||||||
from typing import Sequence as TypingSequence
|
|
||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import Space
|
||||||
|
|
||||||
|
|
||||||
class Dict(Space[TypingDict[str, Space]], Mapping):
|
class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||||
"""A dictionary of :class:`Space` instances.
|
"""A dictionary of :class:`Space` instances.
|
||||||
|
|
||||||
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
||||||
@@ -53,13 +52,8 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spaces: Optional[
|
spaces: None | dict[str, Space] | Sequence[tuple[str, Space]] = None,
|
||||||
Union[
|
seed: dict | int | np.random.Generator | None = None,
|
||||||
TypingDict[str, Space],
|
|
||||||
TypingSequence[Tuple[str, Space]],
|
|
||||||
]
|
|
||||||
] = None,
|
|
||||||
seed: Optional[Union[dict, int, np.random.Generator]] = None,
|
|
||||||
**spaces_kwargs: Space,
|
**spaces_kwargs: Space,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`Dict` space.
|
"""Constructor of :class:`Dict` space.
|
||||||
@@ -82,7 +76,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
|
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
|
||||||
"""
|
"""
|
||||||
# Convert the spaces into an OrderedDict
|
# Convert the spaces into an OrderedDict
|
||||||
if isinstance(spaces, Mapping) and not isinstance(spaces, OrderedDict):
|
if isinstance(spaces, collections.abc.Mapping) and not isinstance(
|
||||||
|
spaces, OrderedDict
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
spaces = OrderedDict(sorted(spaces.items()))
|
spaces = OrderedDict(sorted(spaces.items()))
|
||||||
except TypeError:
|
except TypeError:
|
||||||
@@ -107,22 +103,21 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
f"Dict space keyword '{key}' already exists in the spaces dictionary."
|
f"Dict space keyword '{key}' already exists in the spaces dictionary."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.spaces = spaces
|
self.spaces: dict[str, Space[Any]] = spaces
|
||||||
for key, space in self.spaces.items():
|
for key, space in self.spaces.items():
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
space, Space
|
space, Space
|
||||||
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
||||||
|
|
||||||
super().__init__(
|
# None for shape and dtype, since it'll require special handling
|
||||||
None, None, seed # type: ignore
|
super().__init__(None, None, seed)
|
||||||
) # None for shape and dtype, since it'll require special handling
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_np_flattenable(self):
|
def is_np_flattenable(self):
|
||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return all(space.is_np_flattenable for space in self.spaces.values())
|
return all(space.is_np_flattenable for space in self.spaces.values())
|
||||||
|
|
||||||
def seed(self, seed: Optional[Union[dict, int]] = None) -> list:
|
def seed(self, seed: dict[str, Any] | int | None = None) -> list[int]:
|
||||||
"""Seed the PRNG of this space and all subspaces.
|
"""Seed the PRNG of this space and all subspaces.
|
||||||
|
|
||||||
Depending on the type of seed, the subspaces will be seeded differently
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
@@ -133,7 +128,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
Args:
|
Args:
|
||||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||||
"""
|
"""
|
||||||
seeds = []
|
seeds: list[int] = []
|
||||||
|
|
||||||
if isinstance(seed, dict):
|
if isinstance(seed, dict):
|
||||||
assert (
|
assert (
|
||||||
@@ -159,7 +154,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
return seeds
|
return seeds
|
||||||
|
|
||||||
def sample(self, mask: Optional[TypingDict[str, Any]] = None) -> dict:
|
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
The sample is an ordered dictionary of independent samples from the constituent spaces.
|
The sample is an ordered dictionary of independent samples from the constituent spaces.
|
||||||
@@ -183,17 +178,17 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, dict) and x.keys() == self.spaces.keys():
|
if isinstance(x, dict) and x.keys() == self.spaces.keys():
|
||||||
return all(x[key] in self.spaces[key] for key in self.spaces.keys())
|
return all(x[key] in self.spaces[key] for key in self.spaces.keys())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> Space:
|
def __getitem__(self, key: str) -> Space[Any]:
|
||||||
"""Get the space that is associated to `key`."""
|
"""Get the space that is associated to `key`."""
|
||||||
return self.spaces[key]
|
return self.spaces[key]
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: Space):
|
def __setitem__(self, key: str, value: Space[Any]):
|
||||||
"""Set the space that is associated to `key`."""
|
"""Set the space that is associated to `key`."""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
value, Space
|
value, Space
|
||||||
@@ -214,7 +209,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
"Dict(" + ", ".join([f"{k!r}: {s}" for k, s in self.spaces.items()]) + ")"
|
"Dict(" + ", ".join([f"{k!r}: {s}" for k, s in self.spaces.items()]) + ")"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether `other` is equivalent to this instance."""
|
"""Check whether `other` is equivalent to this instance."""
|
||||||
return (
|
return (
|
||||||
isinstance(other, Dict)
|
isinstance(other, Dict)
|
||||||
@@ -222,7 +217,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
and self.spaces == other.spaces # OrderedDict.__eq__
|
and self.spaces == other.spaces # OrderedDict.__eq__
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: list) -> dict:
|
def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# serialize as dict-repr of vectors
|
# serialize as dict-repr of vectors
|
||||||
return {
|
return {
|
||||||
@@ -230,9 +225,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
for key, space in self.spaces.items()
|
for key, space in self.spaces.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: TypingDict[str, list]) -> List[dict]:
|
def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
dict_of_list: TypingDict[str, list] = {
|
dict_of_list: dict[str, list[Any]] = {
|
||||||
key: space.from_jsonable(sample_n[key])
|
key: space.from_jsonable(sample_n[key])
|
||||||
for key, space in self.spaces.items()
|
for key, space in self.spaces.items()
|
||||||
}
|
}
|
||||||
|
@@ -1,9 +1,11 @@
|
|||||||
"""Implementation of a space consisting of finitely many elements."""
|
"""Implementation of a space consisting of finitely many elements."""
|
||||||
from typing import Optional, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Iterable, Mapping
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import MaskNDArray, Space
|
||||||
|
|
||||||
|
|
||||||
class Discrete(Space[int]):
|
class Discrete(Space[int]):
|
||||||
@@ -20,7 +22,7 @@ class Discrete(Space[int]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: int,
|
n: int,
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
start: int = 0,
|
start: int = 0,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Discrete` space.
|
r"""Constructor of :class:`Discrete` space.
|
||||||
@@ -44,7 +46,7 @@ class Discrete(Space[int]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def sample(self, mask: Optional[np.ndarray] = None) -> int:
|
def sample(self, mask: MaskNDArray | None = None) -> int:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
A sample will be chosen uniformly at random with the mask if provided
|
A sample will be chosen uniformly at random with the mask if provided
|
||||||
@@ -80,14 +82,14 @@ class Discrete(Space[int]):
|
|||||||
|
|
||||||
return int(self.start + self.np_random.integers(self.n))
|
return int(self.start + self.np_random.integers(self.n))
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, int):
|
if isinstance(x, int):
|
||||||
as_int = x
|
as_int = x
|
||||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
elif isinstance(x, (np.generic, np.ndarray)) and (
|
||||||
np.issubdtype(x.dtype, np.integer) and x.shape == ()
|
np.issubdtype(x.dtype, np.integer) and x.shape == ()
|
||||||
):
|
):
|
||||||
as_int = int(x) # type: ignore
|
as_int = int(x)
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -99,7 +101,7 @@ class Discrete(Space[int]):
|
|||||||
return f"Discrete({self.n}, start={self.start})"
|
return f"Discrete({self.n}, start={self.start})"
|
||||||
return f"Discrete({self.n})"
|
return f"Discrete({self.n})"
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether ``other`` is equivalent to this instance."""
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return (
|
return (
|
||||||
isinstance(other, Discrete)
|
isinstance(other, Discrete)
|
||||||
@@ -107,7 +109,7 @@ class Discrete(Space[int]):
|
|||||||
and self.start == other.start
|
and self.start == other.start
|
||||||
)
|
)
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
|
||||||
"""Used when loading a pickled space.
|
"""Used when loading a pickled space.
|
||||||
|
|
||||||
This method has to be implemented explicitly to allow for loading of legacy states.
|
This method has to be implemented explicitly to allow for loading of legacy states.
|
||||||
|
@@ -1,9 +1,12 @@
|
|||||||
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
|
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
|
||||||
from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, NamedTuple, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from gymnasium.logger import warn
|
import gymnasium as gym
|
||||||
from gymnasium.spaces.box import Box
|
from gymnasium.spaces.box import Box
|
||||||
from gymnasium.spaces.discrete import Discrete
|
from gymnasium.spaces.discrete import Discrete
|
||||||
from gymnasium.spaces.multi_discrete import MultiDiscrete
|
from gymnasium.spaces.multi_discrete import MultiDiscrete
|
||||||
@@ -18,25 +21,24 @@ class GraphInstance(NamedTuple):
|
|||||||
* edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the indices of the two nodes that each edge connects.
|
* edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the indices of the two nodes that each edge connects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nodes: np.ndarray
|
nodes: NDArray[Any]
|
||||||
edges: Optional[np.ndarray]
|
edges: NDArray[Any] | None
|
||||||
edge_links: Optional[np.ndarray]
|
edge_links: NDArray[Any] | None
|
||||||
|
|
||||||
|
|
||||||
class Graph(Space):
|
class Graph(Space[GraphInstance]):
|
||||||
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
>>> from gymnasium.spaces import Box, Discrete
|
self.observation_space = spaces.Graph(node_space=space.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
|
||||||
>>> Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3))
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
node_space: Union[Box, Discrete],
|
node_space: Box | Discrete,
|
||||||
edge_space: Union[None, Box, Discrete],
|
edge_space: None | Box | Discrete,
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Graph`.
|
r"""Constructor of :class:`Graph`.
|
||||||
|
|
||||||
@@ -70,8 +72,8 @@ class Graph(Space):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _generate_sample_space(
|
def _generate_sample_space(
|
||||||
self, base_space: Union[None, Box, Discrete], num: int
|
self, base_space: None | Box | Discrete, num: int
|
||||||
) -> Optional[Union[Box, MultiDiscrete]]:
|
) -> Box | MultiDiscrete | None:
|
||||||
if num == 0 or base_space is None:
|
if num == 0 or base_space is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -92,14 +94,15 @@ class Graph(Space):
|
|||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
mask: Optional[
|
mask: None
|
||||||
Tuple[
|
| (
|
||||||
Optional[Union[np.ndarray, tuple]],
|
tuple[
|
||||||
Optional[Union[np.ndarray, tuple]],
|
NDArray[Any] | tuple[Any, ...] | None,
|
||||||
|
NDArray[Any] | tuple[Any, ...] | None,
|
||||||
]
|
]
|
||||||
] = None,
|
) = None,
|
||||||
num_nodes: int = 10,
|
num_nodes: int = 10,
|
||||||
num_edges: Optional[int] = None,
|
num_edges: int | None = None,
|
||||||
) -> GraphInstance:
|
) -> GraphInstance:
|
||||||
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
|
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
|
||||||
|
|
||||||
@@ -134,7 +137,7 @@ class Graph(Space):
|
|||||||
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
|
edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
|
||||||
else:
|
else:
|
||||||
if self.edge_space is None:
|
if self.edge_space is None:
|
||||||
warn(
|
gym.logger.warn(
|
||||||
f"The number of edges is set ({num_edges}) but the edge space is None."
|
f"The number of edges is set ({num_edges}) but the edge space is None."
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
@@ -198,7 +201,7 @@ class Graph(Space):
|
|||||||
"""
|
"""
|
||||||
return f"Graph({self.node_space}, {self.edge_space})"
|
return f"Graph({self.node_space}, {self.edge_space})"
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether `other` is equivalent to this instance."""
|
"""Check whether `other` is equivalent to this instance."""
|
||||||
return (
|
return (
|
||||||
isinstance(other, Graph)
|
isinstance(other, Graph)
|
||||||
@@ -206,22 +209,24 @@ class Graph(Space):
|
|||||||
and (self.edge_space == other.edge_space)
|
and (self.edge_space == other.edge_space)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: NamedTuple) -> list:
|
def to_jsonable(
|
||||||
|
self, sample_n: Sequence[GraphInstance]
|
||||||
|
) -> list[dict[str, list[int] | list[float]]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# serialize as list of dicts
|
ret_n: list[dict[str, list[int | float]]] = []
|
||||||
ret_n = []
|
|
||||||
for sample in sample_n:
|
for sample in sample_n:
|
||||||
ret = {}
|
ret = {"nodes": sample.nodes.tolist()}
|
||||||
ret["nodes"] = sample.nodes.tolist()
|
if sample.edges is not None and sample.edge_links is not None:
|
||||||
if sample.edges is not None:
|
|
||||||
ret["edges"] = sample.edges.tolist()
|
ret["edges"] = sample.edges.tolist()
|
||||||
ret["edge_links"] = sample.edge_links.tolist()
|
ret["edge_links"] = sample.edge_links.tolist()
|
||||||
ret_n.append(ret)
|
ret_n.append(ret)
|
||||||
return ret_n
|
return ret_n
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: Sequence[dict]) -> list:
|
def from_jsonable(
|
||||||
|
self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
|
||||||
|
) -> list[GraphInstance]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
ret = []
|
ret: list[GraphInstance] = []
|
||||||
for sample in sample_n:
|
for sample in sample_n:
|
||||||
if "edges" in sample:
|
if "edges" in sample:
|
||||||
ret_n = GraphInstance(
|
ret_n = GraphInstance(
|
||||||
|
@@ -1,12 +1,15 @@
|
|||||||
"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
|
"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
|
||||||
from typing import Optional, Sequence, Tuple, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import MaskNDArray, Space
|
||||||
|
|
||||||
|
|
||||||
class MultiBinary(Space[np.ndarray]):
|
class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||||
"""An n-shape binary space.
|
"""An n-shape binary space.
|
||||||
|
|
||||||
Elements of this space are binary arrays of a shape that is fixed during construction.
|
Elements of this space are binary arrays of a shape that is fixed during construction.
|
||||||
@@ -25,8 +28,8 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: Union[np.ndarray, Sequence[int], int],
|
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`MultiBinary` space.
|
"""Constructor of :class:`MultiBinary` space.
|
||||||
|
|
||||||
@@ -46,8 +49,8 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
super().__init__(input_n, np.int8, seed)
|
super().__init__(input_n, np.int8, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
"""Has stricter type than gymnasium.Space - never None."""
|
"""Has stricter type than gym.Space - never None."""
|
||||||
return self._shape # type: ignore
|
return self._shape # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -55,7 +58,7 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
|
def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||||
@@ -90,7 +93,7 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
|
|
||||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, Sequence):
|
if isinstance(x, Sequence):
|
||||||
x = np.array(x) # Promote list to array for contains check
|
x = np.array(x) # Promote list to array for contains check
|
||||||
@@ -98,14 +101,18 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
return bool(
|
return bool(
|
||||||
isinstance(x, np.ndarray)
|
isinstance(x, np.ndarray)
|
||||||
and self.shape == x.shape
|
and self.shape == x.shape
|
||||||
and np.all((x == 0) | (x == 1))
|
and np.all(np.logical_or(x == 0, x == 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n) -> list:
|
def to_jsonable(
|
||||||
|
self, sample_n: Sequence[npt.NDArray[np.int8]]
|
||||||
|
) -> list[Sequence[int]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return np.array(sample_n).tolist()
|
return np.array(sample_n).tolist()
|
||||||
|
|
||||||
def from_jsonable(self, sample_n) -> list:
|
def from_jsonable(
|
||||||
|
self, sample_n: list[Sequence[int]]
|
||||||
|
) -> list[npt.NDArray[np.int8]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
return [np.asarray(sample, self.dtype) for sample in sample_n]
|
||||||
|
|
||||||
@@ -113,6 +120,6 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
"""Gives a string representation of this space."""
|
"""Gives a string representation of this space."""
|
||||||
return f"MultiBinary({self.n})"
|
return f"MultiBinary({self.n})"
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether `other` is equivalent to this instance."""
|
"""Check whether `other` is equivalent to this instance."""
|
||||||
return isinstance(other, MultiBinary) and self.n == other.n
|
return isinstance(other, MultiBinary) and self.n == other.n
|
||||||
|
@@ -1,14 +1,17 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
|
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
|
||||||
from typing import Iterable, List, Optional, Sequence, Tuple, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
from gymnasium import logger
|
import gymnasium as gym
|
||||||
from gymnasium.spaces.discrete import Discrete
|
from gymnasium.spaces.discrete import Discrete
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import MaskNDArray, Space
|
||||||
|
|
||||||
|
|
||||||
class MultiDiscrete(Space[np.ndarray]):
|
class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||||
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
||||||
|
|
||||||
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
|
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
|
||||||
@@ -29,17 +32,17 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
|
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
|
||||||
>>> d.sample()
|
>> d.sample()
|
||||||
array([[0, 0],
|
array([[0, 0],
|
||||||
[2, 3]])
|
[2, 3]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
nvec: Union[np.ndarray, list],
|
nvec: npt.NDArray[np.integer[Any]] | list[int],
|
||||||
dtype=np.int64,
|
dtype: str | type[np.integer[Any]] = np.int64,
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`MultiDiscrete` space.
|
"""Constructor of :class:`MultiDiscrete` space.
|
||||||
|
|
||||||
@@ -57,8 +60,8 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
super().__init__(self.nvec.shape, dtype, seed)
|
super().__init__(self.nvec.shape, dtype, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
"""Has stricter type than :class:`gymnasium.Space` - never None."""
|
"""Has stricter type than :class:`gym.Space` - never None."""
|
||||||
return self._shape # type: ignore
|
return self._shape # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -66,7 +69,9 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def sample(self, mask: Optional[tuple] = None) -> np.ndarray:
|
def sample(
|
||||||
|
self, mask: tuple[MaskNDArray, ...] | None = None
|
||||||
|
) -> npt.NDArray[np.integer[Any]]:
|
||||||
"""Generates a single random sample this space.
|
"""Generates a single random sample this space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -80,9 +85,9 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
if mask is not None:
|
if mask is not None:
|
||||||
|
|
||||||
def _apply_mask(
|
def _apply_mask(
|
||||||
sub_mask: Union[np.ndarray, tuple],
|
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
|
||||||
sub_nvec: Union[np.ndarray, np.integer],
|
sub_nvec: MaskNDArray | np.integer[Any],
|
||||||
) -> Union[int, List[int]]:
|
) -> int | Sequence[int]:
|
||||||
if isinstance(sub_nvec, np.ndarray):
|
if isinstance(sub_nvec, np.ndarray):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
sub_mask, tuple
|
sub_mask, tuple
|
||||||
@@ -122,7 +127,7 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
|
|
||||||
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, Sequence):
|
if isinstance(x, Sequence):
|
||||||
x = np.array(x) # Promote list to array for contains check
|
x = np.array(x) # Promote list to array for contains check
|
||||||
@@ -137,11 +142,15 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
and np.all(x < self.nvec)
|
and np.all(x < self.nvec)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: Iterable[np.ndarray]):
|
def to_jsonable(
|
||||||
|
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]]
|
||||||
|
) -> list[Sequence[int]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return [sample.tolist() for sample in sample_n]
|
return [sample.tolist() for sample in sample_n]
|
||||||
|
|
||||||
def from_jsonable(self, sample_n):
|
def from_jsonable(
|
||||||
|
self, sample_n: list[Sequence[int]]
|
||||||
|
) -> list[npt.NDArray[np.integer[Any]]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return np.array(sample_n)
|
return np.array(sample_n)
|
||||||
|
|
||||||
@@ -149,7 +158,7 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
"""Gives a string representation of this space."""
|
"""Gives a string representation of this space."""
|
||||||
return f"MultiDiscrete({self.nvec})"
|
return f"MultiDiscrete({self.nvec})"
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index: int):
|
||||||
"""Extract a subspace from this ``MultiDiscrete`` space."""
|
"""Extract a subspace from this ``MultiDiscrete`` space."""
|
||||||
nvec = self.nvec[index]
|
nvec = self.nvec[index]
|
||||||
if nvec.ndim == 0:
|
if nvec.ndim == 0:
|
||||||
@@ -165,11 +174,13 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""Gives the ``len`` of samples from this space."""
|
"""Gives the ``len`` of samples from this space."""
|
||||||
if self.nvec.ndim >= 2:
|
if self.nvec.ndim >= 2:
|
||||||
logger.warn(
|
gym.logger.warn(
|
||||||
"Getting the length of a multi-dimensional MultiDiscrete space."
|
"Getting the length of a multi-dimensional MultiDiscrete space."
|
||||||
)
|
)
|
||||||
return len(self.nvec)
|
return len(self.nvec)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether ``other`` is equivalent to this instance."""
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
return bool(
|
||||||
|
isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||||
|
)
|
||||||
|
@@ -1,20 +1,24 @@
|
|||||||
"""Implementation of a space that represents finite-length sequences."""
|
"""Implementation of a space that represents finite-length sequences."""
|
||||||
from collections.abc import Sequence as CollectionSequence
|
from __future__ import annotations
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
|
||||||
|
import collections.abc
|
||||||
|
import typing
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import Space
|
||||||
|
|
||||||
|
|
||||||
class Sequence(Space[Tuple]):
|
class Sequence(Space[typing.Tuple[Any, ...]]):
|
||||||
r"""This space represent sets of finite-length sequences.
|
r"""This space represent sets of finite-length sequences.
|
||||||
|
|
||||||
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
|
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
|
||||||
to some space that is specified during initialization and the integer :math:`n` is not fixed
|
to some space that is specified during initialization and the integer :math:`n` is not fixed
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
>>> from gymnasium.spaces import Box
|
||||||
>>> space = Sequence(Box(0, 1))
|
>>> space = Sequence(Box(0, 1))
|
||||||
>>> space.sample()
|
>>> space.sample()
|
||||||
(array([0.0259352], dtype=float32),)
|
(array([0.0259352], dtype=float32),)
|
||||||
@@ -24,8 +28,8 @@ class Sequence(Space[Tuple]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
space: Space,
|
space: Space[Any],
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
"""Constructor of the :class:`Sequence` space.
|
"""Constructor of the :class:`Sequence` space.
|
||||||
|
|
||||||
@@ -34,14 +38,14 @@ class Sequence(Space[Tuple]):
|
|||||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||||
"""
|
"""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
space, gym.Space
|
space, Space
|
||||||
), f"Expects the feature space to be instance of a gymnasium Space, actual type: {type(space)}"
|
), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
|
||||||
self.feature_space = space
|
self.feature_space = space
|
||||||
super().__init__(
|
|
||||||
None, None, seed # type: ignore
|
|
||||||
) # None for shape and dtype, since it'll require special handling
|
|
||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> list:
|
# None for shape and dtype, since it'll require special handling
|
||||||
|
super().__init__(None, None, seed)
|
||||||
|
|
||||||
|
def seed(self, seed: int | None = None) -> list[int]:
|
||||||
"""Seed the PRNG of this space and the feature space."""
|
"""Seed the PRNG of this space and the feature space."""
|
||||||
seeds = super().seed(seed)
|
seeds = super().seed(seed)
|
||||||
seeds += self.feature_space.seed(seed)
|
seeds += self.feature_space.seed(seed)
|
||||||
@@ -54,8 +58,14 @@ class Sequence(Space[Tuple]):
|
|||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
mask: Optional[Tuple[Optional[Union[np.ndarray, int]], Optional[Any]]] = None,
|
mask: None
|
||||||
) -> Tuple[Any]:
|
| (
|
||||||
|
tuple[
|
||||||
|
None | np.integer | npt.NDArray[np.integer],
|
||||||
|
Any,
|
||||||
|
]
|
||||||
|
) = None,
|
||||||
|
) -> tuple[Any]:
|
||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -89,6 +99,9 @@ class Sequence(Space[Tuple]):
|
|||||||
assert np.all(
|
assert np.all(
|
||||||
0 <= length_mask
|
0 <= length_mask
|
||||||
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
|
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
|
||||||
|
assert np.issubdtype(
|
||||||
|
length_mask.dtype, np.integer
|
||||||
|
), f"Expects the length mask array to have dtype to be an numpy integer, actual type: {length_mask.dtype}"
|
||||||
length = self.np_random.choice(length_mask)
|
length = self.np_random.choice(length_mask)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@@ -102,9 +115,9 @@ class Sequence(Space[Tuple]):
|
|||||||
self.feature_space.sample(mask=feature_mask) for _ in range(length)
|
self.feature_space.sample(mask=feature_mask) for _ in range(length)
|
||||||
)
|
)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
return isinstance(x, CollectionSequence) and all(
|
return isinstance(x, collections.abc.Sequence) and all(
|
||||||
self.feature_space.contains(item) for item in x
|
self.feature_space.contains(item) for item in x
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -112,15 +125,17 @@ class Sequence(Space[Tuple]):
|
|||||||
"""Gives a string representation of this space."""
|
"""Gives a string representation of this space."""
|
||||||
return f"Sequence({self.feature_space})"
|
return f"Sequence({self.feature_space})"
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: list) -> list:
|
def to_jsonable(
|
||||||
|
self, sample_n: typing.Sequence[tuple[Any, ...]]
|
||||||
|
) -> list[list[Any]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# serialize as dict-repr of vectors
|
# serialize as dict-repr of vectors
|
||||||
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]
|
return [self.feature_space.to_jsonable(list(sample)) for sample in sample_n]
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: List[List[Any]]) -> list:
|
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]
|
return [tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n]
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether ``other`` is equivalent to this instance."""
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return isinstance(other, Sequence) and self.feature_space == other.feature_space
|
return isinstance(other, Sequence) and self.feature_space == other.feature_space
|
||||||
|
@@ -1,26 +1,19 @@
|
|||||||
"""Implementation of the `Space` metaclass."""
|
"""Implementation of the `Space` metaclass."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import (
|
from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar
|
||||||
Any,
|
|
||||||
Generic,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
from gymnasium.utils import seeding
|
from gymnasium.utils import seeding
|
||||||
|
|
||||||
T_cov = TypeVar("T_cov", covariant=True)
|
T_cov = TypeVar("T_cov", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
MaskNDArray = npt.NDArray[np.int8]
|
||||||
|
|
||||||
|
|
||||||
class Space(Generic[T_cov]):
|
class Space(Generic[T_cov]):
|
||||||
"""Superclass that is used to define observation and action spaces.
|
"""Superclass that is used to define observation and action spaces.
|
||||||
|
|
||||||
@@ -41,17 +34,17 @@ class Space(Generic[T_cov]):
|
|||||||
class. However, most use-cases should be covered by the existing space
|
class. However, most use-cases should be covered by the existing space
|
||||||
classes (e.g. :class:`Box`, :class:`Discrete`, etc...), and container classes (:class`Tuple` &
|
classes (e.g. :class:`Box`, :class:`Discrete`, etc...), and container classes (:class`Tuple` &
|
||||||
:class:`Dict`). Note that parametrized probability distributions (through the
|
:class:`Dict`). Note that parametrized probability distributions (through the
|
||||||
:meth:`Space.sample()` method), and batching functions (in :class:`gymnasium.vector.VectorEnv`), are
|
:meth:`Space.sample()` method), and batching functions (in :class:`gym.vector.VectorEnv`), are
|
||||||
only well-defined for instances of spaces provided in gymnasium by default.
|
only well-defined for instances of spaces provided in gym by default.
|
||||||
Moreover, some implementations of Reinforcement Learning algorithms might
|
Moreover, some implementations of Reinforcement Learning algorithms might
|
||||||
not handle custom spaces properly. Use custom spaces with care.
|
not handle custom spaces properly. Use custom spaces with care.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
shape: Optional[Sequence[int]] = None,
|
shape: Sequence[int] | None = None,
|
||||||
dtype: Optional[Union[Type, str, np.dtype]] = None,
|
dtype: npt.DTypeLike | None = None,
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`Space`.
|
"""Constructor of :class:`Space`.
|
||||||
|
|
||||||
@@ -71,23 +64,31 @@ class Space(Generic[T_cov]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def np_random(self) -> np.random.Generator:
|
def np_random(self) -> np.random.Generator:
|
||||||
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space."""
|
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space.
|
||||||
|
|
||||||
|
As :meth:`seed` is not guaranteed to set the `_np_random` for particular seeds. We add a
|
||||||
|
check after :meth:`seed` to set a new random number generator.
|
||||||
|
"""
|
||||||
if self._np_random is None:
|
if self._np_random is None:
|
||||||
self.seed()
|
self.seed()
|
||||||
|
|
||||||
return self._np_random # type: ignore ## self.seed() call guarantees right type.
|
# As `seed` is not guaranteed (in particular for composite spaces) to set the `_np_random` then we set it randomly.
|
||||||
|
if self._np_random is None:
|
||||||
|
self._np_random, _ = seeding.np_random()
|
||||||
|
|
||||||
|
return self._np_random
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Optional[Tuple[int, ...]]:
|
def shape(self) -> tuple[int, ...] | None:
|
||||||
"""Return the shape of the space as an immutable property."""
|
"""Return the shape of the space as an immutable property."""
|
||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_np_flattenable(self):
|
def is_np_flattenable(self) -> bool:
|
||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def sample(self, mask: Optional[Any] = None) -> T_cov:
|
def sample(self, mask: Any | None = None) -> T_cov:
|
||||||
"""Randomly sample an element of this space.
|
"""Randomly sample an element of this space.
|
||||||
|
|
||||||
Can be uniform or non-uniform sampling based on boundedness of space.
|
Can be uniform or non-uniform sampling based on boundedness of space.
|
||||||
@@ -100,20 +101,20 @@ class Space(Generic[T_cov]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> list:
|
def seed(self, seed: int | None = None) -> list[int]:
|
||||||
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
||||||
self._np_random, seed = seeding.np_random(seed)
|
self._np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __contains__(self, x) -> bool:
|
def __contains__(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
return self.contains(x)
|
return self.contains(x)
|
||||||
|
|
||||||
def __setstate__(self, state: Union[Iterable, Mapping]):
|
def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
|
||||||
"""Used when loading a pickled space.
|
"""Used when loading a pickled space.
|
||||||
|
|
||||||
This method was implemented explicitly to allow for loading of legacy states.
|
This method was implemented explicitly to allow for loading of legacy states.
|
||||||
@@ -130,7 +131,7 @@ class Space(Generic[T_cov]):
|
|||||||
# https://github.com/openai/gym/pull/1913 -- np_random
|
# https://github.com/openai/gym/pull/1913 -- np_random
|
||||||
#
|
#
|
||||||
if "shape" in state:
|
if "shape" in state:
|
||||||
state["_shape"] = state["shape"]
|
state["_shape"] = state.get("shape")
|
||||||
del state["shape"]
|
del state["shape"]
|
||||||
if "np_random" in state:
|
if "np_random" in state:
|
||||||
state["_np_random"] = state["np_random"]
|
state["_np_random"] = state["np_random"]
|
||||||
@@ -139,12 +140,12 @@ class Space(Generic[T_cov]):
|
|||||||
# Update our state
|
# Update our state
|
||||||
self.__dict__.update(state)
|
self.__dict__.update(state)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list:
|
def to_jsonable(self, sample_n: Sequence[T_cov]) -> list[Any]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# By default, assume identity is JSONable
|
# By default, assume identity is JSONable
|
||||||
return list(sample_n)
|
return list(sample_n)
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: list) -> List[T_cov]:
|
def from_jsonable(self, sample_n: list[Any]) -> list[T_cov]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
# By default, assume identity is JSONable
|
# By default, assume identity is JSONable
|
||||||
return sample_n
|
return sample_n
|
||||||
|
@@ -1,11 +1,14 @@
|
|||||||
"""Implementation of a space that represents textual strings."""
|
"""Implementation of a space that represents textual strings."""
|
||||||
from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import Space
|
||||||
|
|
||||||
alphanumeric: FrozenSet[str] = frozenset(
|
alphanumeric: frozenset[str] = frozenset(
|
||||||
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,7 +17,6 @@ class Text(Space[str]):
|
|||||||
r"""A space representing a string comprised of characters from a given charset.
|
r"""A space representing a string comprised of characters from a given charset.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> # {"", "B5", "hello", ...}
|
>>> # {"", "B5", "hello", ...}
|
||||||
>>> Text(5)
|
>>> Text(5)
|
||||||
>>> # {"0", "42", "0123456789", ...}
|
>>> # {"0", "42", "0123456789", ...}
|
||||||
@@ -29,8 +31,8 @@ class Text(Space[str]):
|
|||||||
max_length: int,
|
max_length: int,
|
||||||
*,
|
*,
|
||||||
min_length: int = 1,
|
min_length: int = 1,
|
||||||
charset: Union[Set[str], str] = alphanumeric,
|
charset: set[str] | str = alphanumeric,
|
||||||
seed: Optional[Union[int, np.random.Generator]] = None,
|
seed: int | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Text` space.
|
r"""Constructor of :class:`Text` space.
|
||||||
|
|
||||||
@@ -58,9 +60,9 @@ class Text(Space[str]):
|
|||||||
self.min_length: int = int(min_length)
|
self.min_length: int = int(min_length)
|
||||||
self.max_length: int = int(max_length)
|
self.max_length: int = int(max_length)
|
||||||
|
|
||||||
self._char_set: FrozenSet[str] = frozenset(charset)
|
self._char_set: frozenset[str] = frozenset(charset)
|
||||||
self._char_list: Tuple[str, ...] = tuple(charset)
|
self._char_list: tuple[str, ...] = tuple(charset)
|
||||||
self._char_index: Dict[str, np.int32] = {
|
self._char_index: dict[str, np.int32] = {
|
||||||
val: np.int32(i) for i, val in enumerate(tuple(charset))
|
val: np.int32(i) for i, val in enumerate(tuple(charset))
|
||||||
}
|
}
|
||||||
self._char_str: str = "".join(sorted(tuple(charset)))
|
self._char_str: str = "".join(sorted(tuple(charset)))
|
||||||
@@ -69,7 +71,8 @@ class Text(Space[str]):
|
|||||||
super().__init__(dtype=str, seed=seed)
|
super().__init__(dtype=str, seed=seed)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self, mask: Optional[Tuple[Optional[int], Optional[np.ndarray]]] = None
|
self,
|
||||||
|
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
|
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
|
||||||
|
|
||||||
@@ -152,7 +155,7 @@ class Text(Space[str]):
|
|||||||
f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
|
f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether ``other`` is equivalent to this instance."""
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return (
|
return (
|
||||||
isinstance(other, Text)
|
isinstance(other, Text)
|
||||||
@@ -162,12 +165,12 @@ class Text(Space[str]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def character_set(self) -> FrozenSet[str]:
|
def character_set(self) -> frozenset[str]:
|
||||||
"""Returns the character set for the space."""
|
"""Returns the character set for the space."""
|
||||||
return self._char_set
|
return self._char_set
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def character_list(self) -> Tuple[str, ...]:
|
def character_list(self) -> tuple[str, ...]:
|
||||||
"""Returns a tuple of characters in the space."""
|
"""Returns a tuple of characters in the space."""
|
||||||
return self._char_list
|
return self._char_list
|
||||||
|
|
||||||
|
@@ -1,16 +1,16 @@
|
|||||||
"""Implementation of a space that represents the cartesian product of other spaces."""
|
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||||
from collections.abc import Sequence as CollectionSequence
|
from __future__ import annotations
|
||||||
from typing import Iterable, Optional
|
|
||||||
from typing import Sequence as TypingSequence
|
import collections.abc
|
||||||
from typing import Tuple as TypingTuple
|
import typing
|
||||||
from typing import Union
|
from typing import Any, Iterable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium.spaces.space import Space
|
from gymnasium.spaces.space import Space
|
||||||
|
|
||||||
|
|
||||||
class Tuple(Space[tuple], CollectionSequence):
|
class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
||||||
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
|
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
|
||||||
|
|
||||||
Elements of this space are tuples of elements of the constituent spaces.
|
Elements of this space are tuples of elements of the constituent spaces.
|
||||||
@@ -25,8 +25,8 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spaces: Iterable[Space],
|
spaces: Iterable[Space[Any]],
|
||||||
seed: Optional[Union[int, TypingSequence[int], np.random.Generator]] = None,
|
seed: int | typing.Sequence[int] | np.random.Generator | None = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Tuple` space.
|
r"""Constructor of :class:`Tuple` space.
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
for space in self.spaces:
|
for space in self.spaces:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
space, Space
|
space, Space
|
||||||
), "Elements of the tuple must be instances of gymnasium.Space"
|
), "Elements of the tuple must be instances of gym.Space"
|
||||||
super().__init__(None, None, seed) # type: ignore
|
super().__init__(None, None, seed) # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -48,9 +48,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||||
return all(space.is_np_flattenable for space in self.spaces)
|
return all(space.is_np_flattenable for space in self.spaces)
|
||||||
|
|
||||||
def seed(
|
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
|
||||||
self, seed: Optional[Union[int, TypingSequence[int]]] = None
|
|
||||||
) -> TypingSequence[int]:
|
|
||||||
"""Seed the PRNG of this space and all subspaces.
|
"""Seed the PRNG of this space and all subspaces.
|
||||||
|
|
||||||
Depending on the type of seed, the subspaces will be seeded differently
|
Depending on the type of seed, the subspaces will be seeded differently
|
||||||
@@ -61,9 +59,9 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
Args:
|
Args:
|
||||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||||
"""
|
"""
|
||||||
seeds = []
|
seeds: list[int] = []
|
||||||
|
|
||||||
if isinstance(seed, CollectionSequence):
|
if isinstance(seed, collections.abc.Sequence):
|
||||||
assert len(seed) == len(
|
assert len(seed) == len(
|
||||||
self.spaces
|
self.spaces
|
||||||
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
|
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seeds)}, length of subspaces: {len(self.spaces)}"
|
||||||
@@ -86,9 +84,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
|
|
||||||
return seeds
|
return seeds
|
||||||
|
|
||||||
def sample(
|
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
|
||||||
self, mask: Optional[TypingTuple[Optional[np.ndarray], ...]] = None
|
|
||||||
) -> tuple:
|
|
||||||
"""Generates a single random sample inside this space.
|
"""Generates a single random sample inside this space.
|
||||||
|
|
||||||
This method draws independent samples from the subspaces.
|
This method draws independent samples from the subspaces.
|
||||||
@@ -115,10 +111,11 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
|
|
||||||
return tuple(space.sample() for space in self.spaces)
|
return tuple(space.sample() for space in self.spaces)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x: Any) -> bool:
|
||||||
"""Return boolean specifying if x is a valid member of this space."""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, (list, np.ndarray)):
|
if isinstance(x, (list, np.ndarray)):
|
||||||
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
||||||
|
|
||||||
return (
|
return (
|
||||||
isinstance(x, tuple)
|
isinstance(x, tuple)
|
||||||
and len(x) == len(self.spaces)
|
and len(x) == len(self.spaces)
|
||||||
@@ -129,7 +126,9 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
"""Gives a string representation of this space."""
|
"""Gives a string representation of this space."""
|
||||||
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: CollectionSequence) -> list:
|
def to_jsonable(
|
||||||
|
self, sample_n: typing.Sequence[tuple[Any, ...]]
|
||||||
|
) -> list[list[Any]]:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# serialize as list-repr of tuple of vectors
|
# serialize as list-repr of tuple of vectors
|
||||||
return [
|
return [
|
||||||
@@ -137,7 +136,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
for i, space in enumerate(self.spaces)
|
for i, space in enumerate(self.spaces)
|
||||||
]
|
]
|
||||||
|
|
||||||
def from_jsonable(self, sample_n) -> list:
|
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
|
||||||
"""Convert a JSONable data type to a batch of samples from this space."""
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [
|
return [
|
||||||
sample
|
sample
|
||||||
@@ -149,7 +148,7 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Space:
|
def __getitem__(self, index: int) -> Space[Any]:
|
||||||
"""Get the subspace at specific `index`."""
|
"""Get the subspace at specific `index`."""
|
||||||
return self.spaces[index]
|
return self.spaces[index]
|
||||||
|
|
||||||
@@ -157,6 +156,6 @@ class Tuple(Space[tuple], CollectionSequence):
|
|||||||
"""Get the number of subspaces that are involved in the cartesian product."""
|
"""Get the number of subspaces that are involved in the cartesian product."""
|
||||||
return len(self.spaces)
|
return len(self.spaces)
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check whether ``other`` is equivalent to this instance."""
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return isinstance(other, Tuple) and self.spaces == other.spaces
|
return isinstance(other, Tuple) and self.spaces == other.spaces
|
||||||
|
@@ -3,13 +3,16 @@
|
|||||||
These functions mostly take care of flattening and unflattening elements of spaces
|
These functions mostly take care of flattening and unflattening elements of spaces
|
||||||
to facilitate their usage in learning code.
|
to facilitate their usage in learning code.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import operator as op
|
import operator as op
|
||||||
|
import typing
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import reduce, singledispatch
|
from functools import reduce, singledispatch
|
||||||
from typing import Dict as TypingDict
|
from typing import Any, TypeVar, Union, cast
|
||||||
from typing import TypeVar, Union, cast
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from gymnasium.spaces import (
|
from gymnasium.spaces import (
|
||||||
Box,
|
Box,
|
||||||
@@ -27,12 +30,12 @@ from gymnasium.spaces import (
|
|||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def flatdim(space: Space) -> int:
|
def flatdim(space: Space[Any]) -> int:
|
||||||
"""Return the number of dimensions a flattened equivalent of this space would have.
|
"""Return the number of dimensions a flattened equivalent of this space would have.
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
>>> from gymnasium.spaces import Discrete
|
>>> from gymnasium.spaces import Discrete, Dict
|
||||||
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||||
>>> flatdim(space)
|
>>> flatdim(space)
|
||||||
5
|
5
|
||||||
@@ -47,7 +50,7 @@ def flatdim(space: Space) -> int:
|
|||||||
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||||
ValueError: if the space cannot be flattened into a :class:`Box`
|
ValueError: if the space cannot be flattened into a :class:`Box`
|
||||||
"""
|
"""
|
||||||
if not space.is_np_flattenable:
|
if space.is_np_flattenable is False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
|
f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
|
||||||
)
|
)
|
||||||
@@ -57,7 +60,7 @@ def flatdim(space: Space) -> int:
|
|||||||
|
|
||||||
@flatdim.register(Box)
|
@flatdim.register(Box)
|
||||||
@flatdim.register(MultiBinary)
|
@flatdim.register(MultiBinary)
|
||||||
def _flatdim_box_multibinary(space: Union[Box, MultiBinary]) -> int:
|
def _flatdim_box_multibinary(space: Box | MultiBinary) -> int:
|
||||||
return reduce(op.mul, space.shape, 1)
|
return reduce(op.mul, space.shape, 1)
|
||||||
|
|
||||||
|
|
||||||
@@ -102,7 +105,9 @@ def _flatdim_text(space: Text) -> int:
|
|||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
FlatType = Union[np.ndarray, TypingDict, tuple, GraphInstance]
|
FlatType = Union[
|
||||||
|
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
@@ -112,6 +117,19 @@ def flatten(space: Space[T], x: T) -> FlatType:
|
|||||||
This is useful when e.g. points from spaces must be passed to a neural
|
This is useful when e.g. points from spaces must be passed to a neural
|
||||||
network, which only understands flat arrays of floats.
|
network, which only understands flat arrays of floats.
|
||||||
|
|
||||||
|
Example usage::
|
||||||
|
>>> from gymnasium.spaces import Box, Discrete, Tuple
|
||||||
|
>>> space = Box(0, 1, shape=(3, 5))
|
||||||
|
>>> flatten(space, space.sample()).shape
|
||||||
|
(15,)
|
||||||
|
>>> space = Discrete(4)
|
||||||
|
>>> flatten(space, 2)
|
||||||
|
array([0, 0, 1, 0])
|
||||||
|
>>> space = Tuple((Box(0, 1, shape=(2,)), Box(0, 1, shape=(3,)), Discrete(3)))
|
||||||
|
>>> example = ((.5, .25), (1., 0., .2), 1)
|
||||||
|
>>> flatten(space, example)
|
||||||
|
array([0.5 , 0.25, 1. , 0. , 0.2 , 0. , 1. , 0. ])
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
space: The space that ``x`` is flattened by
|
space: The space that ``x`` is flattened by
|
||||||
x: The value to flatten
|
x: The value to flatten
|
||||||
@@ -137,19 +155,21 @@ def flatten(space: Space[T], x: T) -> FlatType:
|
|||||||
|
|
||||||
@flatten.register(Box)
|
@flatten.register(Box)
|
||||||
@flatten.register(MultiBinary)
|
@flatten.register(MultiBinary)
|
||||||
def _flatten_box_multibinary(space, x) -> np.ndarray:
|
def _flatten_box_multibinary(space: Box | MultiBinary, x: NDArray[Any]) -> NDArray[Any]:
|
||||||
return np.asarray(x, dtype=space.dtype).flatten()
|
return np.asarray(x, dtype=space.dtype).flatten()
|
||||||
|
|
||||||
|
|
||||||
@flatten.register(Discrete)
|
@flatten.register(Discrete)
|
||||||
def _flatten_discrete(space, x) -> np.ndarray:
|
def _flatten_discrete(space: Discrete, x: int) -> NDArray[np.int64]:
|
||||||
onehot = np.zeros(space.n, dtype=space.dtype)
|
onehot = np.zeros(space.n, dtype=space.dtype)
|
||||||
onehot[x - space.start] = 1
|
onehot[x - space.start] = 1
|
||||||
return onehot
|
return onehot
|
||||||
|
|
||||||
|
|
||||||
@flatten.register(MultiDiscrete)
|
@flatten.register(MultiDiscrete)
|
||||||
def _flatten_multidiscrete(space, x) -> np.ndarray:
|
def _flatten_multidiscrete(
|
||||||
|
space: MultiDiscrete, x: NDArray[np.int64]
|
||||||
|
) -> NDArray[np.int64]:
|
||||||
offsets = np.zeros((space.nvec.size + 1,), dtype=np.int32)
|
offsets = np.zeros((space.nvec.size + 1,), dtype=np.int32)
|
||||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||||
|
|
||||||
@@ -159,7 +179,7 @@ def _flatten_multidiscrete(space, x) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
@flatten.register(Tuple)
|
@flatten.register(Tuple)
|
||||||
def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
|
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
return np.concatenate(
|
return np.concatenate(
|
||||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
||||||
@@ -168,22 +188,26 @@ def _flatten_tuple(space, x) -> Union[tuple, np.ndarray]:
|
|||||||
|
|
||||||
|
|
||||||
@flatten.register(Dict)
|
@flatten.register(Dict)
|
||||||
def _flatten_dict(space, x) -> Union[dict, np.ndarray]:
|
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||||
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())
|
||||||
|
|
||||||
|
|
||||||
@flatten.register(Graph)
|
@flatten.register(Graph)
|
||||||
def _flatten_graph(space, x) -> GraphInstance:
|
def _flatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
||||||
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
"""We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space, see `.flatten` docstring."""
|
||||||
|
|
||||||
def _graph_unflatten(unflatten_space, unflatten_x):
|
def _graph_unflatten(
|
||||||
|
unflatten_space: Discrete | Box | None,
|
||||||
|
unflatten_x: NDArray[Any] | None,
|
||||||
|
) -> NDArray[Any] | None:
|
||||||
ret = None
|
ret = None
|
||||||
if unflatten_space is not None and unflatten_x is not None:
|
if unflatten_space is not None and unflatten_x is not None:
|
||||||
if isinstance(unflatten_space, Box):
|
if isinstance(unflatten_space, Box):
|
||||||
ret = unflatten_x.reshape(unflatten_x.shape[0], -1)
|
ret = unflatten_x.reshape(unflatten_x.shape[0], -1)
|
||||||
elif isinstance(unflatten_space, Discrete):
|
else:
|
||||||
|
assert isinstance(unflatten_space, Discrete)
|
||||||
ret = np.zeros(
|
ret = np.zeros(
|
||||||
(unflatten_x.shape[0], unflatten_space.n - unflatten_space.start),
|
(unflatten_x.shape[0], unflatten_space.n - unflatten_space.start),
|
||||||
dtype=unflatten_space.dtype,
|
dtype=unflatten_space.dtype,
|
||||||
@@ -194,13 +218,14 @@ def _flatten_graph(space, x) -> GraphInstance:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||||
|
assert nodes is not None
|
||||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||||
|
|
||||||
return GraphInstance(nodes, edges, x.edge_links)
|
return GraphInstance(nodes, edges, x.edge_links)
|
||||||
|
|
||||||
|
|
||||||
@flatten.register(Text)
|
@flatten.register(Text)
|
||||||
def _flatten_text(space: Text, x: str) -> np.ndarray:
|
def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
|
||||||
arr = np.full(
|
arr = np.full(
|
||||||
shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
|
shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
|
||||||
)
|
)
|
||||||
@@ -210,7 +235,7 @@ def _flatten_text(space: Text, x: str) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
@flatten.register(Sequence)
|
@flatten.register(Sequence)
|
||||||
def _flatten_sequence(space, x) -> tuple:
|
def _flatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||||
return tuple(flatten(space.feature_space, item) for item in x)
|
return tuple(flatten(space.feature_space, item) for item in x)
|
||||||
|
|
||||||
|
|
||||||
@@ -237,18 +262,20 @@ def unflatten(space: Space[T], x: FlatType) -> T:
|
|||||||
@unflatten.register(Box)
|
@unflatten.register(Box)
|
||||||
@unflatten.register(MultiBinary)
|
@unflatten.register(MultiBinary)
|
||||||
def _unflatten_box_multibinary(
|
def _unflatten_box_multibinary(
|
||||||
space: Union[Box, MultiBinary], x: np.ndarray
|
space: Box | MultiBinary, x: NDArray[Any]
|
||||||
) -> np.ndarray:
|
) -> NDArray[Any]:
|
||||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||||
|
|
||||||
|
|
||||||
@unflatten.register(Discrete)
|
@unflatten.register(Discrete)
|
||||||
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
|
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> int:
|
||||||
return int(space.start + np.nonzero(x)[0][0])
|
return int(space.start + np.nonzero(x)[0][0])
|
||||||
|
|
||||||
|
|
||||||
@unflatten.register(MultiDiscrete)
|
@unflatten.register(MultiDiscrete)
|
||||||
def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
|
def _unflatten_multidiscrete(
|
||||||
|
space: MultiDiscrete, x: NDArray[np.int32]
|
||||||
|
) -> NDArray[np.int32]:
|
||||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||||
|
|
||||||
@@ -257,7 +284,9 @@ def _unflatten_multidiscrete(space: MultiDiscrete, x: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
@unflatten.register(Tuple)
|
@unflatten.register(Tuple)
|
||||||
def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
|
def _unflatten_tuple(
|
||||||
|
space: Tuple, x: NDArray[Any] | tuple[Any, ...]
|
||||||
|
) -> tuple[Any, ...]:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
x, np.ndarray
|
x, np.ndarray
|
||||||
@@ -275,7 +304,7 @@ def _unflatten_tuple(space: Tuple, x: Union[np.ndarray, tuple]) -> tuple:
|
|||||||
|
|
||||||
|
|
||||||
@unflatten.register(Dict)
|
@unflatten.register(Dict)
|
||||||
def _unflatten_dict(space: Dict, x: Union[np.ndarray, TypingDict]) -> dict:
|
def _unflatten_dict(space: Dict, x: NDArray[Any] | dict[str, Any]) -> dict[str, Any]:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
|
dims = np.asarray([flatdim(s) for s in space.spaces.values()], dtype=np.int_)
|
||||||
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
list_flattened = np.split(x, np.cumsum(dims[:-1]))
|
||||||
@@ -299,14 +328,14 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
|||||||
nodes and edges in the graph.
|
nodes and edges in the graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _graph_unflatten(space, x):
|
def _graph_unflatten(unflatten_space, unflatten_x):
|
||||||
ret = None
|
result = None
|
||||||
if space is not None and x is not None:
|
if unflatten_space is not None and unflatten_x is not None:
|
||||||
if isinstance(space, Box):
|
if isinstance(unflatten_space, Box):
|
||||||
ret = x.reshape(-1, *space.shape)
|
result = unflatten_x.reshape(-1, *unflatten_space.shape)
|
||||||
elif isinstance(space, Discrete):
|
elif isinstance(unflatten_space, Discrete):
|
||||||
ret = np.asarray(np.nonzero(x))[-1, :]
|
result = np.asarray(np.nonzero(unflatten_x))[-1, :]
|
||||||
return ret
|
return result
|
||||||
|
|
||||||
nodes = _graph_unflatten(space.node_space, x.nodes)
|
nodes = _graph_unflatten(space.node_space, x.nodes)
|
||||||
edges = _graph_unflatten(space.edge_space, x.edges)
|
edges = _graph_unflatten(space.edge_space, x.edges)
|
||||||
@@ -315,19 +344,19 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
|
|||||||
|
|
||||||
|
|
||||||
@unflatten.register(Text)
|
@unflatten.register(Text)
|
||||||
def _unflatten_text(space: Text, x: np.ndarray) -> str:
|
def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
|
||||||
return "".join(
|
return "".join(
|
||||||
[space.character_list[val] for val in x if val < len(space.character_set)]
|
[space.character_list[val] for val in x if val < len(space.character_set)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@unflatten.register(Sequence)
|
@unflatten.register(Sequence)
|
||||||
def _unflatten_sequence(space: Sequence, x: tuple) -> tuple:
|
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||||
return tuple(unflatten(space.feature_space, item) for item in x)
|
return tuple(unflatten(space.feature_space, item) for item in x)
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||||
"""Flatten a space into a space that is as flat as possible.
|
"""Flatten a space into a space that is as flat as possible.
|
||||||
|
|
||||||
This function will attempt to flatten `space` into a single :class:`Box` space.
|
This function will attempt to flatten `space` into a single :class:`Box` space.
|
||||||
@@ -342,7 +371,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
|||||||
space.
|
space.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
>>> from gymnasium.spaces import Box
|
||||||
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
|
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
|
||||||
>>> box
|
>>> box
|
||||||
Box(3, 4, 5)
|
Box(3, 4, 5)
|
||||||
@@ -352,7 +381,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
|||||||
True
|
True
|
||||||
|
|
||||||
Example that flattens a discrete space::
|
Example that flattens a discrete space::
|
||||||
|
>>> from gymnasium.spaces import Discrete
|
||||||
>>> discrete = Discrete(5)
|
>>> discrete = Discrete(5)
|
||||||
>>> flatten_space(discrete)
|
>>> flatten_space(discrete)
|
||||||
Box(5,)
|
Box(5,)
|
||||||
@@ -360,7 +389,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
|||||||
True
|
True
|
||||||
|
|
||||||
Example that recursively flattens a dict::
|
Example that recursively flattens a dict::
|
||||||
|
>>> from gymnasium.spaces import Dict, Discrete, Box
|
||||||
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
|
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
|
||||||
>>> flatten_space(space)
|
>>> flatten_space(space)
|
||||||
Box(6,)
|
Box(6,)
|
||||||
@@ -383,7 +412,7 @@ def flatten_space(space: Space) -> Union[Dict, Sequence, Tuple, Graph]:
|
|||||||
A flattened Box
|
A flattened Box
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotImplementedError: if the space is not defined in ``gymnasium.spaces``.
|
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||||
|
|
||||||
@@ -396,12 +425,12 @@ def _flatten_space_box(space: Box) -> Box:
|
|||||||
@flatten_space.register(Discrete)
|
@flatten_space.register(Discrete)
|
||||||
@flatten_space.register(MultiBinary)
|
@flatten_space.register(MultiBinary)
|
||||||
@flatten_space.register(MultiDiscrete)
|
@flatten_space.register(MultiDiscrete)
|
||||||
def _flatten_space_binary(space: Union[Discrete, MultiBinary, MultiDiscrete]) -> Box:
|
def _flatten_space_binary(space: Discrete | MultiBinary | MultiDiscrete) -> Box:
|
||||||
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
|
return Box(low=0, high=1, shape=(flatdim(space),), dtype=space.dtype)
|
||||||
|
|
||||||
|
|
||||||
@flatten_space.register(Tuple)
|
@flatten_space.register(Tuple)
|
||||||
def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
|
def _flatten_space_tuple(space: Tuple) -> Box | Tuple:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
space_list = [flatten_space(s) for s in space.spaces]
|
space_list = [flatten_space(s) for s in space.spaces]
|
||||||
return Box(
|
return Box(
|
||||||
@@ -413,7 +442,7 @@ def _flatten_space_tuple(space: Tuple) -> Union[Box, Tuple]:
|
|||||||
|
|
||||||
|
|
||||||
@flatten_space.register(Dict)
|
@flatten_space.register(Dict)
|
||||||
def _flatten_space_dict(space: Dict) -> Union[Box, Dict]:
|
def _flatten_space_dict(space: Dict) -> Box | Dict:
|
||||||
if space.is_np_flattenable:
|
if space.is_np_flattenable:
|
||||||
space_list = [flatten_space(s) for s in space.spaces.values()]
|
space_list = [flatten_space(s) for s in space.spaces.values()]
|
||||||
return Box(
|
return Box(
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
numpy>=1.18.0
|
numpy>=1.21.0
|
||||||
cloudpickle>=1.2.0
|
cloudpickle>=1.2.0
|
||||||
importlib_metadata>=4.8.0; python_version < '3.10'
|
importlib_metadata>=4.8.0; python_version < '3.10'
|
||||||
gymnasium_notices>=0.0.1
|
gymnasium_notices>=0.0.1
|
||||||
|
2
setup.py
2
setup.py
@@ -85,7 +85,7 @@ setup(
|
|||||||
},
|
},
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"numpy >= 1.18.0",
|
"numpy >= 1.21.0",
|
||||||
"cloudpickle >= 1.2.0",
|
"cloudpickle >= 1.2.0",
|
||||||
"importlib_metadata >= 4.8.0; python_version < '3.10'",
|
"importlib_metadata >= 4.8.0; python_version < '3.10'",
|
||||||
"gymnasium_notices >= 0.0.1",
|
"gymnasium_notices >= 0.0.1",
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
import re
|
||||||
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -25,18 +27,18 @@ def test_dict_init():
|
|||||||
):
|
):
|
||||||
Dict(a=Discrete(2), b="Box")
|
Dict(a=Discrete(2), b="Box")
|
||||||
|
|
||||||
with pytest.warns(None) as warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
|
a = Dict({"a": Discrete(2), "b": Box(low=0.0, high=1.0)})
|
||||||
b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
|
b = Dict(OrderedDict(a=Discrete(2), b=Box(low=0.0, high=1.0)))
|
||||||
c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
|
c = Dict((("a", Discrete(2)), ("b", Box(low=0.0, high=1.0))))
|
||||||
d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))
|
d = Dict(a=Discrete(2), b=Box(low=0.0, high=1.0))
|
||||||
|
|
||||||
assert a == b == c == d
|
assert a == b == c == d
|
||||||
assert len(warnings) == 0
|
assert len(caught_warnings) == 0
|
||||||
|
|
||||||
with pytest.warns(None) as warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
Dict({1: Discrete(2), "a": Discrete(3)})
|
Dict({1: Discrete(2), "a": Discrete(3)})
|
||||||
assert len(warnings) == 0
|
assert len(caught_warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
DICT_SPACE = Dict(
|
DICT_SPACE = Dict(
|
||||||
@@ -109,7 +111,12 @@ def test_none_seeding():
|
|||||||
|
|
||||||
|
|
||||||
def test_bad_seed():
|
def test_bad_seed():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(
|
||||||
|
TypeError,
|
||||||
|
match=re.escape(
|
||||||
|
"Expected seed type: dict, int or None, actual type: <class 'str'>"
|
||||||
|
),
|
||||||
|
):
|
||||||
DICT_SPACE.seed("a")
|
DICT_SPACE.seed("a")
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user