Files
Gymnasium/gym/spaces/multi_binary.py
Mark Towers 024b0f5160 Added Action masking for Space.sample() (#2906)
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9117.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Add support for action masking in Space.sample(mask=...)

* Fix action mask

* Fix action_mask

* Fix action_mask

* Added docstrings, fixed bugs and added taxi examples

* Fixed bugs

* Add tests for sample

* Add docstrings and test space sample mask Discrete and MultiBinary

* Add MultiDiscrete sampling and tests

* Remove sample mask from graph

* Update gym/spaces/multi_discrete.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Updates based on Marcus28 and jjshoots for Graph.py

* Updates based on Marcus28 and jjshoots for Graph.py

* jjshoot review

* jjshoot review

* Update assert check

* Update type hints

Co-authored-by: Markus Krimmel <montcyril@gmail.com>
2022-06-26 18:23:15 -04:00

109 lines
4.1 KiB
Python

"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
from typing import Optional, Sequence, Tuple, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class MultiBinary(Space[np.ndarray]):
"""An n-shape binary space.
Elements of this space are binary arrays of a shape that is fixed during construction.
Example Usage::
>>> observation_space = MultiBinary(5)
>>> observation_space.sample()
array([0, 1, 0, 1, 0], dtype=int8)
>>> observation_space = MultiBinary([3, 2])
>>> observation_space.sample()
array([[0, 0],
[0, 1],
[1, 1]], dtype=int8)
"""
def __init__(
self,
n: Union[np.ndarray, Sequence[int], int],
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
):
"""Constructor of :class:`MultiBinary` space.
Args:
n: This will fix the shape of elements of the space. It can either be an integer (if the space is flat)
or some sort of sequence (tuple, list or np.ndarray) if there are multiple axes.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
"""
if isinstance(n, (Sequence, np.ndarray)):
self.n = input_n = tuple(int(i) for i in n)
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
else:
self.n = n = int(n)
input_n = (n,)
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
super().__init__(input_n, np.int8, seed)
@property
def shape(self) -> Tuple[int, ...]:
"""Has stricter type than gym.Space - never None."""
return self._shape # type: ignore
def sample(self, mask: Optional[np.ndarray] = None) -> np.ndarray:
"""Generates a single random sample from this space.
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
Where mask == 0 then the samples will be 0.
Returns:
Sampled values from space
"""
if mask is not None:
assert isinstance(
mask, np.ndarray
), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
assert (
mask.dtype == np.int8
), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
assert (
mask.shape == self.shape
), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}"
assert np.all(
np.logical_or(mask == 0, mask == 1)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
return mask * self.np_random.integers(
low=0, high=2, size=self.n, dtype=self.dtype
)
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check
if self.shape != x.shape:
return False
return ((x == 0) | (x == 1)).all()
def to_jsonable(self, sample_n) -> list:
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()
def from_jsonable(self, sample_n) -> list:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample) for sample in sample_n]
def __repr__(self) -> str:
"""Gives a string representation of this space."""
return f"MultiBinary({self.n})"
def __eq__(self, other) -> bool:
"""Check whether `other` is equivalent to this instance."""
return isinstance(other, MultiBinary) and self.n == other.n