2022-05-10 17:18:06 +02:00
|
|
|
"""Implementation of a space consisting of finitely many elements."""
|
2022-11-15 14:09:22 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from typing import Any, Iterable, Mapping
|
2022-01-24 23:22:11 +01:00
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
import numpy as np
|
2022-03-31 12:50:38 -07:00
|
|
|
|
2022-11-15 14:09:22 +00:00
|
|
|
from gymnasium.spaces.space import MaskNDArray, Space
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2019-01-30 22:39:55 +01:00
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
class Discrete(Space[int]):
    r"""A space consisting of finitely many elements.

    This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.

    Example::

        >>> Discrete(2) # {0, 1}
        >>> Discrete(3, start=-1) # {-1, 0, 1}
    """

    def __init__(
        self,
        n: int,
        seed: int | np.random.Generator | None = None,
        start: int = 0,
    ) -> None:
        r"""Constructor of :class:`Discrete` space.

        This will construct the space :math:`\{\text{start}, ..., \text{start} + n - 1\}`.

        Args:
            n (int): The number of elements of this space. Must be a positive integer.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Discrete`` space.
            start (int): The smallest element of this space.

        Raises:
            AssertionError: If ``n`` or ``start`` is not a Python or NumPy integer, or if ``n`` is not positive.
        """
        assert isinstance(
            n, (int, np.integer)
        ), f"Expects `n` to be an integer, actual type: {type(n)}"
        assert n > 0, "n (counts) have to be positive"
        assert isinstance(
            start, (int, np.integer)
        ), f"Expects `start` to be an integer, actual type: {type(start)}"

        # Store plain Python ints even when NumPy integers were passed in.
        self.n = int(n)
        self.start = int(start)
        # Forward an empty tuple, ``np.int64`` and ``seed`` to the base constructor.
        super().__init__((), np.int64, seed)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return True

    def sample(self, mask: MaskNDArray | None = None) -> int:
        """Generates a single random sample from this space.

        A sample will be chosen uniformly at random, restricted to the entries allowed by the mask if one is provided.

        Args:
            mask: An optional mask for if an action can be selected.
                Expected `np.ndarray` of shape `(n,)` and dtype `np.int8` where `1` represents valid actions and `0` invalid / infeasible actions.
                Index ``i`` of the mask corresponds to the element ``start + i`` of the space.
                If there are no possible actions (i.e. `np.all(mask == 0)`) then `space.start` will be returned.

        Returns:
            A sampled integer from the space

        Raises:
            AssertionError: If ``mask`` is not an ``np.ndarray`` of dtype ``np.int8`` and shape ``(n,)``
                containing only zeros and ones.
        """
        if mask is not None:
            assert isinstance(
                mask, np.ndarray
            ), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
            assert (
                mask.dtype == np.int8
            ), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
            assert mask.shape == (
                self.n,
            ), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"
            valid_action_mask = mask == 1
            assert np.all(
                np.logical_or(mask == 0, valid_action_mask)
            ), f"All values of a mask should be 0 or 1, actual values: {mask}"
            if np.any(valid_action_mask):
                # Pick one of the allowed indices, then shift it by ``start``.
                return int(
                    self.start + self.np_random.choice(np.where(valid_action_mask)[0])
                )
            else:
                # Nothing is allowed by the mask: fall back to the smallest element.
                return self.start

        # Unmasked case: draw an offset in ``[0, n)`` and shift it by ``start``.
        return int(self.start + self.np_random.integers(self.n))

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Only Python ``int`` instances and zero-dimensional NumPy values with an integer dtype
        are considered; anything else is reported as not contained.
        """
        if isinstance(x, int):
            as_int = x
        elif isinstance(x, (np.generic, np.ndarray)) and (
            np.issubdtype(x.dtype, np.integer) and x.shape == ()
        ):
            as_int = int(x)
        else:
            return False

        return self.start <= as_int < self.start + self.n

    def __repr__(self) -> str:
        """Gives a string representation of this space.

        ``start`` is only included when it differs from the default of ``0``.
        """
        if self.start != 0:
            return f"Discrete({self.n}, start={self.start})"
        return f"Discrete({self.n})"

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance.

        Two :class:`Discrete` spaces are equal when both ``n`` and ``start`` match.
        """
        return (
            isinstance(other, Discrete)
            and self.n == other.n
            and self.start == other.start
        )

    def __setstate__(
        self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]
    ) -> None:
        """Used when loading a pickled space.

        This method has to be implemented explicitly to allow for loading of legacy states.

        Args:
            state: The new state
        """
        # Don't mutate the original state
        state = dict(state)

        # Allow for loading of legacy states.
        # See https://github.com/openai/gym/pull/2470
        if "start" not in state:
            state["start"] = 0

        super().__setstate__(state)
|