Files
Gymnasium/gym/spaces/discrete.py
Mark Towers 024b0f5160 Added Action masking for Space.sample() (#2906)
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9117.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Add support for action masking in Space.sample(mask=...)

* Fix action mask

* Fix action_mask

* Fix action_mask

* Added docstrings, fixed bugs and added taxi examples

* Fixed bugs

* Add tests for sample

* Add docstrings and test space sample mask Discrete and MultiBinary

* Add MultiDiscrete sampling and tests

* Remove sample mask from graph

* Update gym/spaces/multi_discrete.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Updates based on Marcus28 and jjshoots for Graph.py

* Updates based on Marcus28 and jjshoots for Graph.py

* jjshoot review

* jjshoot review

* Update assert check

* Update type hints

Co-authored-by: Markus Krimmel <montcyril@gmail.com>
2022-06-26 18:23:15 -04:00

125 lines
4.3 KiB
Python

"""Implementation of a space consisting of finitely many elements."""
from typing import Optional, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class Discrete(Space[int]):
    r"""A space consisting of finitely many elements.

    This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.

    Example::

        >>> Discrete(2)            # {0, 1}
        >>> Discrete(3, start=-1)  # {-1, 0, 1}
    """

    def __init__(
        self,
        n: int,
        seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
        start: int = 0,
    ):
        r"""Constructor of :class:`Discrete` space.

        This will construct the space :math:`\{\text{start}, ..., \text{start} + n - 1\}`.

        Args:
            n (int): The number of elements of this space.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Discrete`` space.
            start (int): The smallest element of this space.

        Raises:
            AssertionError: If ``n`` or ``start`` is not an integer, or if ``n`` is not positive.
        """
        assert isinstance(
            n, (int, np.integer)
        ), f"The expected type of n is an integer, actual type: {type(n)}"
        assert n > 0, "n (counts) have to be positive"
        assert isinstance(
            start, (int, np.integer)
        ), f"The expected type of start is an integer, actual type: {type(start)}"
        # Normalise NumPy integer scalars to plain Python ints.
        self.n = int(n)
        self.start = int(start)
        super().__init__((), np.int64, seed)

    def sample(self, mask: Optional[np.ndarray] = None) -> int:
        """Generates a single random sample from this space.

        A sample will be chosen uniformly at random with the mask if provided

        Args:
            mask: An optional mask for if an action can be selected.
                Expected `np.ndarray` of shape `(n,)` and dtype `np.int8` where `1` represents valid actions and `0` invalid / infeasible actions.
                If there are no possible actions (i.e. `np.all(mask == 0)`) then `space.start` will be returned.

        Returns:
            A sampled integer from the space

        Raises:
            AssertionError: If ``mask`` is not an ``np.ndarray`` of dtype ``np.int8`` and shape ``(n,)``
                containing only zeros and ones.
        """
        if mask is not None:
            assert isinstance(
                mask, np.ndarray
            ), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
            assert (
                mask.dtype == np.int8
            ), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
            assert mask.shape == (
                self.n,
            ), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"
            valid_action_mask = mask == 1
            assert np.all(
                np.logical_or(mask == 0, valid_action_mask)
            ), f"All values of a mask should be 0 or 1, actual values: {mask}"
            if np.any(valid_action_mask):
                # Indices of the valid entries are offsets relative to ``start``.
                return int(
                    self.start + self.np_random.choice(np.where(valid_action_mask)[0])
                )
            # Nothing is selectable: fall back to the smallest element.
            return self.start

        return int(self.start + self.np_random.integers(self.n))

    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Only Python ``int`` instances and zero-dimensional NumPy values with an
        integer dtype are considered; anything else is reported as not contained.
        """
        if isinstance(x, int):
            as_int = x
        elif isinstance(x, (np.generic, np.ndarray)) and (
            x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
        ):
            as_int = int(x)  # type: ignore
        else:
            return False
        return self.start <= as_int < self.start + self.n

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        # ``start`` is only shown when it differs from the default of 0.
        if self.start != 0:
            return "Discrete(%d, start=%d)" % (self.n, self.start)
        return "Discrete(%d)" % self.n

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        # NOTE(review): defining ``__eq__`` without ``__hash__`` in this class body
        # makes instances unhashable per the Python data model - confirm intended.
        return (
            isinstance(other, Discrete)
            and self.n == other.n
            and self.start == other.start
        )

    def __setstate__(self, state):
        """Used when loading a pickled space.

        This method has to be implemented explicitly to allow for loading of legacy states.

        Args:
            state: The new state
        """
        # Let the base class process the state first.
        super().__setstate__(state)

        # Don't mutate the original state
        state = dict(state)

        # Allow for loading of legacy states.
        # See https://github.com/openai/gym/pull/2470
        if "start" not in state:
            state["start"] = 0

        # Update our state
        self.__dict__.update(state)