mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-27 00:37:19 +00:00
Pydocstyle spaces docstring (#2798)
* Added docstrings for spaces, WIP * Formatting changes * Use raw docstring for Box.sample * Formatting fix * Formatting fix * Use :class:, :meth:, formatting fixes, resolve TODO, use Optional
This commit is contained in:
@@ -1,3 +1,13 @@
|
|||||||
|
"""This module implements various spaces.
|
||||||
|
|
||||||
|
Spaces describe mathematical sets and are used in Gym to specify valid actions and observations.
|
||||||
|
Every Gym environment must have the attributes ``action_space`` and ``observation_space``.
|
||||||
|
If, for instance, three possible actions (0,1,2) can be performed in your environment and observations
|
||||||
|
are vectors in the two-dimensional unit cube, the environment code may contain the following two lines::
|
||||||
|
|
||||||
|
self.action_space = spaces.Discrete(3)
|
||||||
|
self.observation_space = spaces.Box(0, 1, shape=(2,))
|
||||||
|
"""
|
||||||
from gym.spaces.box import Box
|
from gym.spaces.box import Box
|
||||||
from gym.spaces.dict import Dict
|
from gym.spaces.dict import Dict
|
||||||
from gym.spaces.discrete import Discrete
|
from gym.spaces.discrete import Discrete
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Implementation of a space that represents closed boxes in euclidean space."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
from typing import Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
||||||
@@ -21,10 +22,11 @@ def _short_repr(arr: np.ndarray) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class Box(Space[np.ndarray]):
|
class Box(Space[np.ndarray]):
|
||||||
"""
|
r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.
|
||||||
A (possibly unbounded) box in R^n. Specifically, a Box represents the
|
|
||||||
Cartesian product of n closed intervals. Each interval has the form of one
|
Specifically, a Box represents the Cartesian product of n closed intervals.
|
||||||
of [a, b], (-oo, b], [a, oo), or (-oo, oo).
|
Each interval has the form of one of :math:`[a, b]`, :math:`(-\infty, b]`,
|
||||||
|
:math:`[a, \infty)`, or :math:`(-\infty, \infty)`.
|
||||||
|
|
||||||
There are two common use cases:
|
There are two common use cases:
|
||||||
|
|
||||||
@@ -37,7 +39,6 @@ class Box(Space[np.ndarray]):
|
|||||||
|
|
||||||
>>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
|
>>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
|
||||||
Box(2,)
|
Box(2,)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -48,6 +49,23 @@ class Box(Space[np.ndarray]):
|
|||||||
dtype: Type = np.float32,
|
dtype: Type = np.float32,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
||||||
):
|
):
|
||||||
|
r"""Constructor of :class:`Box`.
|
||||||
|
|
||||||
|
The argument ``low`` specifies the lower bound of each dimension and ``high`` specifies the upper bounds.
|
||||||
|
I.e., the space that is constructed will be the product of the intervals :math:`[\text{low}[i], \text{high}[i]]`.
|
||||||
|
|
||||||
|
If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
|
||||||
|
this value across all dimensions.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
|
||||||
|
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
|
||||||
|
shape (Optional[Sequence[int]]): This only needs to be specified if both ``low`` and ``high`` are scalars and determines the shape of the space.
|
||||||
|
Otherwise, the shape is inferred from the shape of ``low`` or ``high``.
|
||||||
|
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
|
||||||
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||||
|
"""
|
||||||
assert dtype is not None, "dtype must be explicitly provided. "
|
assert dtype is not None, "dtype must be explicitly provided. "
|
||||||
self.dtype = np.dtype(dtype)
|
self.dtype = np.dtype(dtype)
|
||||||
|
|
||||||
@@ -99,6 +117,14 @@ class Box(Space[np.ndarray]):
|
|||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
def is_bounded(self, manner: str = "both") -> bool:
|
def is_bounded(self, manner: str = "both") -> bool:
|
||||||
|
"""Checks whether the box is bounded in some sense.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
manner (str): One of ``"both"``, ``"below"``, ``"above"``.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `manner` is neither `"both"` nor `"below"`or `"above"`
|
||||||
|
"""
|
||||||
below = bool(np.all(self.bounded_below))
|
below = bool(np.all(self.bounded_below))
|
||||||
above = bool(np.all(self.bounded_above))
|
above = bool(np.all(self.bounded_above))
|
||||||
if manner == "both":
|
if manner == "both":
|
||||||
@@ -111,16 +137,15 @@ class Box(Space[np.ndarray]):
|
|||||||
raise ValueError("manner is not in {'below', 'above', 'both'}")
|
raise ValueError("manner is not in {'below', 'above', 'both'}")
|
||||||
|
|
||||||
def sample(self) -> np.ndarray:
|
def sample(self) -> np.ndarray:
|
||||||
"""
|
r"""Generates a single random sample inside the Box.
|
||||||
Generates a single random sample inside of the Box.
|
|
||||||
|
|
||||||
In creating a sample of the box, each coordinate is sampled according to
|
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
|
||||||
the form of the interval:
|
that is chosen according to the form of the interval:
|
||||||
|
|
||||||
* [a, b] : uniform distribution
|
* :math:`[a, b]` : uniform distribution
|
||||||
* [a, oo) : shifted exponential distribution
|
* :math:`[a, \infty)` : shifted exponential distribution
|
||||||
* (-oo, b] : shifted negative exponential distribution
|
* :math:`(-\infty, b]` : shifted negative exponential distribution
|
||||||
* (-oo, oo) : normal distribution
|
* :math:`(-\infty, \infty)` : normal distribution
|
||||||
"""
|
"""
|
||||||
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
|
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
|
||||||
sample = np.empty(self.shape)
|
sample = np.empty(self.shape)
|
||||||
@@ -154,6 +179,7 @@ class Box(Space[np.ndarray]):
|
|||||||
return sample.astype(self.dtype)
|
return sample.astype(self.dtype)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if not isinstance(x, np.ndarray):
|
if not isinstance(x, np.ndarray):
|
||||||
logger.warn("Casting input x to numpy array.")
|
logger.warn("Casting input x to numpy array.")
|
||||||
x = np.asarray(x, dtype=self.dtype)
|
x = np.asarray(x, dtype=self.dtype)
|
||||||
@@ -166,15 +192,23 @@ class Box(Space[np.ndarray]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n):
|
def to_jsonable(self, sample_n):
|
||||||
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return np.array(sample_n).tolist()
|
return np.array(sample_n).tolist()
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> list[np.ndarray]:
|
def from_jsonable(self, sample_n: Sequence[SupportsFloat]) -> list[np.ndarray]:
|
||||||
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [np.asarray(sample) for sample in sample_n]
|
return [np.asarray(sample) for sample in sample_n]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
"""A string representation of this space.
|
||||||
|
|
||||||
|
The representation will include bounds, shape and dtype.
|
||||||
|
If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.
|
||||||
|
"""
|
||||||
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
|
"""Check whether `other` is equivalent to this instance."""
|
||||||
return (
|
return (
|
||||||
isinstance(other, Box)
|
isinstance(other, Box)
|
||||||
and (self.shape == other.shape)
|
and (self.shape == other.shape)
|
||||||
@@ -185,8 +219,10 @@ class Box(Space[np.ndarray]):
|
|||||||
|
|
||||||
def get_inf(dtype, sign: str) -> SupportsFloat:
|
def get_inf(dtype, sign: str) -> SupportsFloat:
|
||||||
"""Returns an infinite that doesn't break things.
|
"""Returns an infinite that doesn't break things.
|
||||||
`dtype` must be an `np.dtype`
|
|
||||||
`bound` must be either `min` or `max`
|
Args:
|
||||||
|
dtype: An `np.dtype`
|
||||||
|
sign (str): must be either `"+"` or `"-"`
|
||||||
"""
|
"""
|
||||||
if np.dtype(dtype).kind == "f":
|
if np.dtype(dtype).kind == "f":
|
||||||
if sign == "+":
|
if sign == "+":
|
||||||
@@ -207,6 +243,7 @@ def get_inf(dtype, sign: str) -> SupportsFloat:
|
|||||||
|
|
||||||
|
|
||||||
def get_precision(dtype) -> SupportsFloat:
|
def get_precision(dtype) -> SupportsFloat:
|
||||||
|
"""Get precision of a data type."""
|
||||||
if np.issubdtype(dtype, np.floating):
|
if np.issubdtype(dtype, np.floating):
|
||||||
return np.finfo(dtype).precision
|
return np.finfo(dtype).precision
|
||||||
else:
|
else:
|
||||||
@@ -219,7 +256,7 @@ def _broadcast(
|
|||||||
shape: tuple[int, ...],
|
shape: tuple[int, ...],
|
||||||
inf_sign: str,
|
inf_sign: str,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""handle infinite bounds and broadcast at the same time if needed"""
|
"""Handle infinite bounds and broadcast at the same time if needed."""
|
||||||
if np.isscalar(value):
|
if np.isscalar(value):
|
||||||
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
|
value = get_inf(dtype, inf_sign) if np.isinf(value) else value # type: ignore
|
||||||
value = np.full(shape, value, dtype=dtype)
|
value = np.full(shape, value, dtype=dtype)
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@@ -12,43 +13,66 @@ from gym.utils import seeding
|
|||||||
|
|
||||||
|
|
||||||
class Dict(Space[TypingDict[str, Space]], Mapping):
|
class Dict(Space[TypingDict[str, Space]], Mapping):
|
||||||
"""
|
"""A dictionary of :class:`Space` instances.
|
||||||
A dictionary of simpler spaces.
|
|
||||||
|
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
self.observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
|
>>> observation_space = spaces.Dict({"position": spaces.Discrete(2), "velocity": spaces.Discrete(3)})
|
||||||
|
>>> observation_space.sample()
|
||||||
|
OrderedDict([('position', 1), ('velocity', 2)])
|
||||||
|
|
||||||
Example usage [nested]::
|
Example usage [nested]::
|
||||||
|
|
||||||
self.nested_observation_space = spaces.Dict({
|
>>> spaces.Dict(
|
||||||
'sensors': spaces.Dict({
|
... {
|
||||||
'position': spaces.Box(low=-100, high=100, shape=(3,)),
|
... "ext_controller": spaces.MultiDiscrete((5, 2, 2)),
|
||||||
'velocity': spaces.Box(low=-1, high=1, shape=(3,)),
|
... "inner_state": spaces.Dict(
|
||||||
'front_cam': spaces.Tuple((
|
... {
|
||||||
spaces.Box(low=0, high=1, shape=(10, 10, 3)),
|
... "charge": spaces.Discrete(100),
|
||||||
spaces.Box(low=0, high=1, shape=(10, 10, 3))
|
... "system_checks": spaces.MultiBinary(10),
|
||||||
)),
|
... "job_status": spaces.Dict(
|
||||||
'rear_cam': spaces.Box(low=0, high=1, shape=(10, 10, 3)),
|
... {
|
||||||
}),
|
... "task": spaces.Discrete(5),
|
||||||
'ext_controller': spaces.MultiDiscrete((5, 2, 2)),
|
... "progress": spaces.Box(low=0, high=100, shape=()),
|
||||||
'inner_state':spaces.Dict({
|
... }
|
||||||
'charge': spaces.Discrete(100),
|
... ),
|
||||||
'system_checks': spaces.MultiBinary(10),
|
... }
|
||||||
'job_status': spaces.Dict({
|
... ),
|
||||||
'task': spaces.Discrete(5),
|
... }
|
||||||
'progress': spaces.Box(low=0, high=100, shape=()),
|
... )
|
||||||
})
|
|
||||||
})
|
It can be convenient to use :class:`Dict` spaces if you want to make complex observations or actions more human-readable.
|
||||||
})
|
Usually, it will be not be possible to use elements of this space directly in learning code. However, you can easily
|
||||||
|
convert `Dict` observations to flat arrays by using a :class:`gym.wrappers.FlattenObservation` wrapper. Similar wrappers can be
|
||||||
|
implemented to deal with :class:`Dict` actions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spaces: dict[str, Space] | None = None,
|
spaces: Optional[dict[str, Space]] = None,
|
||||||
seed: Optional[dict | int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[dict | int | seeding.RandomNumberGenerator] = None,
|
||||||
**spaces_kwargs: Space,
|
**spaces_kwargs: Space,
|
||||||
):
|
):
|
||||||
|
"""Constructor of :class:`Dict` space.
|
||||||
|
|
||||||
|
This space can be instantiated in one of two ways: Either you pass a dictionary
|
||||||
|
of spaces to :meth:`__init__` via the ``spaces`` argument, or you pass the spaces as separate
|
||||||
|
keyword arguments (where you will need to avoid the keys ``spaces`` and ``seed``)
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> spaces.Dict({"position": spaces.Box(-1, 1, shape=(2,)), "color": spaces.Discrete(3)})
|
||||||
|
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
||||||
|
>>> spaces.Dict(position=spaces.Box(-1, 1, shape=(2,)), color=spaces.Discrete(3))
|
||||||
|
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spaces: A dictionary of spaces. This specifies the structure of the :class:`Dict` space
|
||||||
|
seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
|
||||||
|
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
|
||||||
|
"""
|
||||||
assert (spaces is None) or (
|
assert (spaces is None) or (
|
||||||
not spaces_kwargs
|
not spaces_kwargs
|
||||||
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||||
@@ -75,6 +99,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
) # None for shape and dtype, since it'll require special handling
|
) # None for shape and dtype, since it'll require special handling
|
||||||
|
|
||||||
def seed(self, seed: Optional[dict | int] = None) -> list:
|
def seed(self, seed: Optional[dict | int] = None) -> list:
|
||||||
|
"""Seed the PRNG of this space and all subspaces."""
|
||||||
seeds = []
|
seeds = []
|
||||||
if isinstance(seed, dict):
|
if isinstance(seed, dict):
|
||||||
for key, seed_key in zip(self.spaces, seed):
|
for key, seed_key in zip(self.spaces, seed):
|
||||||
@@ -112,9 +137,14 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
return seeds
|
return seeds
|
||||||
|
|
||||||
def sample(self) -> dict:
|
def sample(self) -> dict:
|
||||||
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
|
The sample is an ordered dictionary of independent samples from the constituent spaces.
|
||||||
|
"""
|
||||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if not isinstance(x, dict) or len(x) != len(self.spaces):
|
if not isinstance(x, dict) or len(x) != len(self.spaces):
|
||||||
return False
|
return False
|
||||||
for k, space in self.spaces.items():
|
for k, space in self.spaces.items():
|
||||||
@@ -125,18 +155,23 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
"""Get the space that is associated to `key`."""
|
||||||
return self.spaces[key]
|
return self.spaces[key]
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
|
"""Set the space that is associated to `key`."""
|
||||||
self.spaces[key] = value
|
self.spaces[key] = value
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
"""Iterator through the keys of the subspaces."""
|
||||||
yield from self.spaces
|
yield from self.spaces
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
"""Gives the number of simpler spaces that make up the `Dict` space."""
|
||||||
return len(self.spaces)
|
return len(self.spaces)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
"""Gives a string representation of this space."""
|
||||||
return (
|
return (
|
||||||
"Dict("
|
"Dict("
|
||||||
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
|
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
|
||||||
@@ -144,6 +179,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: list) -> dict:
|
def to_jsonable(self, sample_n: list) -> dict:
|
||||||
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# serialize as dict-repr of vectors
|
# serialize as dict-repr of vectors
|
||||||
return {
|
return {
|
||||||
key: space.to_jsonable([sample[key] for sample in sample_n])
|
key: space.to_jsonable([sample[key] for sample in sample_n])
|
||||||
@@ -151,6 +187,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def from_jsonable(self, sample_n: dict[str, list]) -> list:
|
def from_jsonable(self, sample_n: dict[str, list]) -> list:
|
||||||
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
dict_of_list: dict[str, list] = {}
|
dict_of_list: dict[str, list] = {}
|
||||||
for key, space in self.spaces.items():
|
for key, space in self.spaces.items():
|
||||||
dict_of_list[key] = space.from_jsonable(sample_n[key])
|
dict_of_list[key] = space.from_jsonable(sample_n[key])
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Implementation of a space consisting of finitely many elements."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -9,16 +10,14 @@ from gym.utils import seeding
|
|||||||
|
|
||||||
|
|
||||||
class Discrete(Space[int]):
|
class Discrete(Space[int]):
|
||||||
r"""A discrete space in :math:`\{ 0, 1, \dots, n-1 \}`.
|
r"""A space consisting of finitely many elements.
|
||||||
|
|
||||||
A start value can be optionally specified to shift the range
|
This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.
|
||||||
to :math:`\{ a, a+1, \dots, a+n-1 \}`.
|
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> Discrete(2) # {0, 1}
|
>>> Discrete(2) # {0, 1}
|
||||||
>>> Discrete(3, start=-1) # {-1, 0, 1}
|
>>> Discrete(3, start=-1) # {-1, 0, 1}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -27,6 +26,15 @@ class Discrete(Space[int]):
|
|||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
||||||
start: int = 0,
|
start: int = 0,
|
||||||
):
|
):
|
||||||
|
r"""Constructor of :class:`Discrete` space.
|
||||||
|
|
||||||
|
This will construct the space :math:`\{\text{start}, ..., \text{start} + n - 1\}`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number of elements of this space.
|
||||||
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
|
||||||
|
start (int): The smallest element of this space.
|
||||||
|
"""
|
||||||
assert n > 0, "n (counts) have to be positive"
|
assert n > 0, "n (counts) have to be positive"
|
||||||
assert isinstance(start, (int, np.integer))
|
assert isinstance(start, (int, np.integer))
|
||||||
self.n = int(n)
|
self.n = int(n)
|
||||||
@@ -34,9 +42,14 @@ class Discrete(Space[int]):
|
|||||||
super().__init__((), np.int64, seed)
|
super().__init__((), np.int64, seed)
|
||||||
|
|
||||||
def sample(self) -> int:
|
def sample(self) -> int:
|
||||||
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
|
A sample will be chosen uniformly at random.
|
||||||
|
"""
|
||||||
return int(self.start + self.np_random.integers(self.n))
|
return int(self.start + self.np_random.integers(self.n))
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, int):
|
if isinstance(x, int):
|
||||||
as_int = x
|
as_int = x
|
||||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
elif isinstance(x, (np.generic, np.ndarray)) and (
|
||||||
@@ -48,11 +61,13 @@ class Discrete(Space[int]):
|
|||||||
return self.start <= as_int < self.start + self.n
|
return self.start <= as_int < self.start + self.n
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
"""Gives a string representation of this space."""
|
||||||
if self.start != 0:
|
if self.start != 0:
|
||||||
return "Discrete(%d, start=%d)" % (self.n, self.start)
|
return "Discrete(%d, start=%d)" % (self.n, self.start)
|
||||||
return "Discrete(%d)" % self.n
|
return "Discrete(%d)" % self.n
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return (
|
return (
|
||||||
isinstance(other, Discrete)
|
isinstance(other, Discrete)
|
||||||
and self.n == other.n
|
and self.n == other.n
|
||||||
@@ -60,6 +75,10 @@ class Discrete(Space[int]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
|
"""Used when loading a pickled space.
|
||||||
|
|
||||||
|
This method has to be implemented explicitly to allow for loading of legacy states.
|
||||||
|
"""
|
||||||
super().__setstate__(state)
|
super().__setstate__(state)
|
||||||
|
|
||||||
# Don't mutate the original state
|
# Don't mutate the original state
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Optional, Sequence, Union
|
||||||
@@ -9,10 +10,9 @@ from gym.utils import seeding
|
|||||||
|
|
||||||
|
|
||||||
class MultiBinary(Space[np.ndarray]):
|
class MultiBinary(Space[np.ndarray]):
|
||||||
"""
|
"""An n-shape binary space.
|
||||||
An n-shape binary space.
|
|
||||||
|
|
||||||
The argument to MultiBinary defines n, which could be a number or a ``list`` of numbers.
|
Elements of this space are binary arrays of a shape that is fixed during construction.
|
||||||
|
|
||||||
Example Usage::
|
Example Usage::
|
||||||
|
|
||||||
@@ -24,7 +24,6 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
array([[0, 0],
|
array([[0, 0],
|
||||||
[0, 1],
|
[0, 1],
|
||||||
[1, 1]], dtype=int8)
|
[1, 1]], dtype=int8)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -32,6 +31,13 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
n: Union[np.ndarray, Sequence[int], int],
|
n: Union[np.ndarray, Sequence[int], int],
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
||||||
):
|
):
|
||||||
|
"""Constructor of :class:`MultiBinary` space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: This will fix the shape of elements of the space. It can either be an integer (if the space is flat)
|
||||||
|
or some sort of sequence (tuple, list or np.ndarray) if there are multiple axes.
|
||||||
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||||
|
"""
|
||||||
if isinstance(n, (Sequence, np.ndarray)):
|
if isinstance(n, (Sequence, np.ndarray)):
|
||||||
self.n = input_n = tuple(int(i) for i in n)
|
self.n = input_n = tuple(int(i) for i in n)
|
||||||
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
|
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
|
||||||
@@ -48,9 +54,14 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
return self._shape # type: ignore
|
return self._shape # type: ignore
|
||||||
|
|
||||||
def sample(self) -> np.ndarray:
|
def sample(self) -> np.ndarray:
|
||||||
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
|
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||||
|
"""
|
||||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, Sequence):
|
if isinstance(x, Sequence):
|
||||||
x = np.array(x) # Promote list to array for contains check
|
x = np.array(x) # Promote list to array for contains check
|
||||||
if self.shape != x.shape:
|
if self.shape != x.shape:
|
||||||
@@ -58,13 +69,17 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
return ((x == 0) | (x == 1)).all()
|
return ((x == 0) | (x == 1)).all()
|
||||||
|
|
||||||
def to_jsonable(self, sample_n) -> list:
|
def to_jsonable(self, sample_n) -> list:
|
||||||
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return np.array(sample_n).tolist()
|
return np.array(sample_n).tolist()
|
||||||
|
|
||||||
def from_jsonable(self, sample_n) -> list:
|
def from_jsonable(self, sample_n) -> list:
|
||||||
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [np.asarray(sample) for sample in sample_n]
|
return [np.asarray(sample) for sample in sample_n]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
"""Gives a string representation of this space."""
|
||||||
return f"MultiBinary({self.n})"
|
return f"MultiBinary({self.n})"
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
|
"""Check whether `other` is equivalent to this instance."""
|
||||||
return isinstance(other, MultiBinary) and self.n == other.n
|
return isinstance(other, MultiBinary) and self.n == other.n
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
|
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable, Optional, Sequence
|
from typing import Iterable, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -11,11 +12,11 @@ from gym.utils import seeding
|
|||||||
|
|
||||||
|
|
||||||
class MultiDiscrete(Space[np.ndarray]):
|
class MultiDiscrete(Space[np.ndarray]):
|
||||||
"""
|
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
||||||
The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each. It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space. It is parametrized by passing an array of positive integers specifying number of actions for each discrete action space.
|
|
||||||
|
It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
|
|
||||||
Some environment wrappers assume a value of 0 always represents the NOOP action.
|
Some environment wrappers assume a value of 0 always represents the NOOP action.
|
||||||
|
|
||||||
e.g. Nintendo Game Controller - Can be conceptualized as 3 discrete action spaces:
|
e.g. Nintendo Game Controller - Can be conceptualized as 3 discrete action spaces:
|
||||||
@@ -30,12 +31,29 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
nvec: list[int],
|
nvec: Union[np.ndarray, list[int]],
|
||||||
dtype=np.int64,
|
dtype=np.int64,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""Constructor of :class:`MultiDiscrete` space.
|
||||||
nvec: vector of counts of each categorical variable
|
|
||||||
|
The argument ``nvec`` will determine the number of values each categorical variable can take.
|
||||||
|
|
||||||
|
Although this feature is rarely used, :class:`MultiDiscrete` spaces may also have several axes
|
||||||
|
if ``nvec`` has several axes:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
|
||||||
|
>> d.sample()
|
||||||
|
array([[0, 0],
|
||||||
|
[2, 3]])
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nvec: vector of counts of each categorical variable. This will usually be a list of integers. However,
|
||||||
|
you may also pass a more complicated numpy array if you'd like the space to have several axes.
|
||||||
|
dtype: This should be some kind of integer type.
|
||||||
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||||
"""
|
"""
|
||||||
self.nvec = np.array(nvec, dtype=dtype, copy=True)
|
self.nvec = np.array(nvec, dtype=dtype, copy=True)
|
||||||
assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
|
assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
|
||||||
@@ -44,13 +62,15 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
"""Has stricter type than gym.Space - never None."""
|
"""Has stricter type than :class:`gym.Space` - never None."""
|
||||||
return self._shape # type: ignore
|
return self._shape # type: ignore
|
||||||
|
|
||||||
def sample(self) -> np.ndarray:
|
def sample(self) -> np.ndarray:
|
||||||
|
"""Generates a single random sample this space."""
|
||||||
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, Sequence):
|
if isinstance(x, Sequence):
|
||||||
x = np.array(x) # Promote list to array for contains check
|
x = np.array(x) # Promote list to array for contains check
|
||||||
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
|
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
|
||||||
@@ -58,15 +78,19 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all())
|
return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all())
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: Iterable[np.ndarray]):
|
def to_jsonable(self, sample_n: Iterable[np.ndarray]):
|
||||||
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
return [sample.tolist() for sample in sample_n]
|
return [sample.tolist() for sample in sample_n]
|
||||||
|
|
||||||
def from_jsonable(self, sample_n):
|
def from_jsonable(self, sample_n):
|
||||||
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return np.array(sample_n)
|
return np.array(sample_n)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
"""Gives a string representation of this space."""
|
||||||
return f"MultiDiscrete({self.nvec})"
|
return f"MultiDiscrete({self.nvec})"
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
"""Extract a subspace from this ``MultiDiscrete`` space."""
|
||||||
nvec = self.nvec[index]
|
nvec = self.nvec[index]
|
||||||
if nvec.ndim == 0:
|
if nvec.ndim == 0:
|
||||||
subspace = Discrete(nvec)
|
subspace = Discrete(nvec)
|
||||||
@@ -76,9 +100,11 @@ class MultiDiscrete(Space[np.ndarray]):
|
|||||||
return subspace
|
return subspace
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
"""Gives the ``len`` of samples from this space."""
|
||||||
if self.nvec.ndim >= 2:
|
if self.nvec.ndim >= 2:
|
||||||
logger.warn("Get length of a multi-dimensional MultiDiscrete space.")
|
logger.warn("Get length of a multi-dimensional MultiDiscrete space.")
|
||||||
return len(self.nvec)
|
return len(self.nvec)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Implementation of the `Space` metaclass."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Generic, Iterable, Mapping, Optional, Sequence, Type, TypeVar
|
from typing import Generic, Iterable, Mapping, Optional, Sequence, Type, TypeVar
|
||||||
@@ -10,15 +11,24 @@ T_cov = TypeVar("T_cov", covariant=True)
|
|||||||
|
|
||||||
|
|
||||||
class Space(Generic[T_cov]):
|
class Space(Generic[T_cov]):
|
||||||
"""Defines the observation and action spaces, so you can write generic
|
"""Superclass that is used to define observation and action spaces.
|
||||||
code that applies to any Env. For example, you can choose a random
|
|
||||||
action.
|
|
||||||
|
|
||||||
WARNING - Custom observation & action spaces can inherit from the `Space`
|
Spaces are crucially used in Gym to define the format of valid actions and observations.
|
||||||
|
They serve various purposes:
|
||||||
|
|
||||||
|
* They clearly define how to interact with environments, i.e. they specify what actions need to look like and what observations will look like
|
||||||
|
* They allow us to work with highly structured data (e.g. in the form of elements of :class:`Dict` spaces) and painlessly transform them into flat arrays that can be used in learning code
|
||||||
|
* They provide a method to sample random elements. This is especially useful for exploration and debugging.
|
||||||
|
|
||||||
|
Different spaces can be combined hierarchically via container spaces (:class:`Tuple` and :class:`Dict`) to build a
|
||||||
|
more expressive space
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
Custom observation & action spaces can inherit from the ``Space``
|
||||||
class. However, most use-cases should be covered by the existing space
|
class. However, most use-cases should be covered by the existing space
|
||||||
classes (e.g. `Box`, `Discrete`, etc...), and container classes (`Tuple` &
|
classes (e.g. :class:`Box`, :class:`Discrete`, etc...), and container classes (:class`Tuple` &
|
||||||
`Dict`). Note that parametrized probability distributions (through the
|
:class:`Dict`). Note that parametrized probability distributions (through the
|
||||||
`sample()` method), and batching functions (in `gym.vector.VectorEnv`), are
|
:meth:`Space.sample()` method), and batching functions (in :class:`gym.vector.VectorEnv`), are
|
||||||
only well-defined for instances of spaces provided in gym by default.
|
only well-defined for instances of spaces provided in gym by default.
|
||||||
Moreover, some implementations of Reinforcement Learning algorithms might
|
Moreover, some implementations of Reinforcement Learning algorithms might
|
||||||
not handle custom spaces properly. Use custom spaces with care.
|
not handle custom spaces properly. Use custom spaces with care.
|
||||||
@@ -30,6 +40,13 @@ class Space(Generic[T_cov]):
|
|||||||
dtype: Optional[Type | str] = None,
|
dtype: Optional[Type | str] = None,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
||||||
):
|
):
|
||||||
|
"""Constructor of :class:`Space`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape (Optional[Sequence[int]]): If elements of the space are numpy arrays, this should specify their shape.
|
||||||
|
dtype (Optional[Type | str]): If elements of the space are numpy arrays, this should specify their dtype.
|
||||||
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space
|
||||||
|
"""
|
||||||
self._shape = None if shape is None else tuple(shape)
|
self._shape = None if shape is None else tuple(shape)
|
||||||
self.dtype = None if dtype is None else np.dtype(dtype)
|
self.dtype = None if dtype is None else np.dtype(dtype)
|
||||||
self._np_random = None
|
self._np_random = None
|
||||||
@@ -41,9 +58,7 @@ class Space(Generic[T_cov]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def np_random(self) -> seeding.RandomNumberGenerator:
|
def np_random(self) -> seeding.RandomNumberGenerator:
|
||||||
"""Lazily seed the rng since this is expensive and only needed if
|
"""Lazily seed the PRNG since this is expensive and only needed if sampling from this space."""
|
||||||
sampling from this space.
|
|
||||||
"""
|
|
||||||
if self._np_random is None:
|
if self._np_random is None:
|
||||||
self.seed()
|
self.seed()
|
||||||
|
|
||||||
@@ -51,30 +66,31 @@ class Space(Generic[T_cov]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Optional[tuple[int, ...]]:
|
def shape(self) -> Optional[tuple[int, ...]]:
|
||||||
"""Return the shape of the space as an immutable property"""
|
"""Return the shape of the space as an immutable property."""
|
||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
def sample(self) -> T_cov:
|
def sample(self) -> T_cov:
|
||||||
"""Randomly sample an element of this space. Can be
|
"""Randomly sample an element of this space. Can be uniform or non-uniform sampling based on boundedness of space."""
|
||||||
uniform or non-uniform sampling based on boundedness of space."""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> list:
|
def seed(self, seed: Optional[int] = None) -> list:
|
||||||
"""Seed the PRNG of this space."""
|
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
|
||||||
self._np_random, seed = seeding.np_random(seed)
|
self._np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
"""
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
Return boolean specifying if x is a valid
|
|
||||||
member of this space
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __contains__(self, x) -> bool:
|
def __contains__(self, x) -> bool:
|
||||||
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
return self.contains(x)
|
return self.contains(x)
|
||||||
|
|
||||||
def __setstate__(self, state: Iterable | Mapping):
|
def __setstate__(self, state: Iterable | Mapping):
|
||||||
|
"""Used when loading a pickled space.
|
||||||
|
|
||||||
|
This method was implemented explicitly to allow for loading of legacy states.
|
||||||
|
"""
|
||||||
# Don't mutate the original state
|
# Don't mutate the original state
|
||||||
state = dict(state)
|
state = dict(state)
|
||||||
|
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Implementation of a space that represents the cartesian product of other spaces."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable, Optional, Sequence
|
from typing import Iterable, Optional, Sequence
|
||||||
@@ -9,12 +10,15 @@ from gym.utils import seeding
|
|||||||
|
|
||||||
|
|
||||||
class Tuple(Space[tuple], Sequence):
|
class Tuple(Space[tuple], Sequence):
|
||||||
"""
|
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
|
||||||
A tuple (i.e., product) of simpler spaces
|
|
||||||
|
Elements of this space are tuples of elements of the constituent spaces.
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
|
>> observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Box(-1, 1, shape=(2,))))
|
||||||
|
>> observation_space.sample()
|
||||||
|
(0, array([0.03633198, 0.42370757], dtype=float32))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -22,6 +26,14 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
spaces: Iterable[Space],
|
spaces: Iterable[Space],
|
||||||
seed: Optional[int | list[int] | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | list[int] | seeding.RandomNumberGenerator] = None,
|
||||||
):
|
):
|
||||||
|
r"""Constructor of :class:`Tuple`` space.
|
||||||
|
|
||||||
|
The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
|
||||||
|
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
|
||||||
|
"""
|
||||||
spaces = tuple(spaces)
|
spaces = tuple(spaces)
|
||||||
self.spaces = spaces
|
self.spaces = spaces
|
||||||
for space in spaces:
|
for space in spaces:
|
||||||
@@ -31,6 +43,7 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
super().__init__(None, None, seed) # type: ignore
|
super().__init__(None, None, seed) # type: ignore
|
||||||
|
|
||||||
def seed(self, seed: Optional[int | list[int]] = None) -> list:
|
def seed(self, seed: Optional[int | list[int]] = None) -> list:
|
||||||
|
"""Seed the PRNG of this space and all subspaces."""
|
||||||
seeds = []
|
seeds = []
|
||||||
|
|
||||||
if isinstance(seed, list):
|
if isinstance(seed, list):
|
||||||
@@ -62,9 +75,14 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
return seeds
|
return seeds
|
||||||
|
|
||||||
def sample(self) -> tuple:
|
def sample(self) -> tuple:
|
||||||
|
"""Generates a single random sample inside this space.
|
||||||
|
|
||||||
|
This method draws independent samples from the subspaces.
|
||||||
|
"""
|
||||||
return tuple(space.sample() for space in self.spaces)
|
return tuple(space.sample() for space in self.spaces)
|
||||||
|
|
||||||
def contains(self, x) -> bool:
|
def contains(self, x) -> bool:
|
||||||
|
"""Return boolean specifying if x is a valid member of this space."""
|
||||||
if isinstance(x, (list, np.ndarray)):
|
if isinstance(x, (list, np.ndarray)):
|
||||||
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
x = tuple(x) # Promote list and ndarray to tuple for contains check
|
||||||
return (
|
return (
|
||||||
@@ -74,9 +92,11 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
"""Gives a string representation of this space."""
|
||||||
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: Sequence) -> list:
|
def to_jsonable(self, sample_n: Sequence) -> list:
|
||||||
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
# serialize as list-repr of tuple of vectors
|
# serialize as list-repr of tuple of vectors
|
||||||
return [
|
return [
|
||||||
space.to_jsonable([sample[i] for sample in sample_n])
|
space.to_jsonable([sample[i] for sample in sample_n])
|
||||||
@@ -84,6 +104,7 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def from_jsonable(self, sample_n) -> list:
|
def from_jsonable(self, sample_n) -> list:
|
||||||
|
"""Convert a JSONable data type to a batch of samples from this space."""
|
||||||
return [
|
return [
|
||||||
sample
|
sample
|
||||||
for sample in zip(
|
for sample in zip(
|
||||||
@@ -95,10 +116,13 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Space:
|
def __getitem__(self, index: int) -> Space:
|
||||||
|
"""Get the subspace at specific `index`."""
|
||||||
return self.spaces[index]
|
return self.spaces[index]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
"""Get the number of subspaces that are involved in the cartesian product."""
|
||||||
return len(self.spaces)
|
return len(self.spaces)
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
|
"""Check whether ``other`` is equivalent to this instance."""
|
||||||
return isinstance(other, Tuple) and self.spaces == other.spaces
|
return isinstance(other, Tuple) and self.spaces == other.spaces
|
||||||
|
@@ -1,3 +1,7 @@
|
|||||||
|
"""Implementation of utility functions that can be applied to spaces.
|
||||||
|
|
||||||
|
These functions mostly take care of flattening and unflattening elements of spaces to facilitate their usage in learning code.
|
||||||
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import operator as op
|
import operator as op
|
||||||
@@ -12,11 +16,12 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, T
|
|||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def flatdim(space: Space) -> int:
|
def flatdim(space: Space) -> int:
|
||||||
"""Return the number of dimensions a flattened equivalent of this space
|
"""Return the number of dimensions a flattened equivalent of this space would have.
|
||||||
would have.
|
|
||||||
|
|
||||||
Accepts a space and returns an integer. Raises ``NotImplementedError`` if
|
Accepts a space and returns an integer.
|
||||||
the space is not defined in ``gym.spaces``.
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user