Files
Gymnasium/gym/spaces/multi_binary.py
Markus Krimmel 745e7059e7 Pydocstyle spaces docstring (#2798)
* Added docstrings for spaces, WIP

* Formatting changes

* Use raw docstring for Box.sample

* Formatting fix

* Formatting fix

* Use :class:, :meth:, formatting fixes, resolve TODO, use Optional
2022-05-10 11:18:06 -04:00

86 lines
3.1 KiB
Python

"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
from __future__ import annotations
from typing import Optional, Sequence, Union
import numpy as np
from gym.spaces.space import Space
from gym.utils import seeding
class MultiBinary(Space[np.ndarray]):
    """An n-shape binary space.

    Elements of this space are binary arrays of a shape that is fixed during construction.

    Example Usage::

        >>> self.observation_space = spaces.MultiBinary(5)
        >>> self.observation_space.sample()
            array([0, 1, 0, 1, 0], dtype=int8)
        >>> self.observation_space = spaces.MultiBinary([3, 2])
        >>> self.observation_space.sample()
            array([[0, 0],
                [0, 1],
                [1, 1]], dtype=int8)
    """

    def __init__(
        self,
        n: Union[np.ndarray, Sequence[int], int],
        seed: Optional[int | seeding.RandomNumberGenerator] = None,
    ):
        """Constructor of :class:`MultiBinary` space.

        Args:
            n: This will fix the shape of elements of the space. It can either be an integer (if the space is flat)
                or some sort of sequence (tuple, list or np.ndarray) if there are multiple axes.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.

        Raises:
            AssertionError: If any entry of ``n`` is not strictly positive.
        """
        if isinstance(n, (Sequence, np.ndarray)):
            # Multi-axis space: ``self.n`` is stored as a tuple of plain ints.
            self.n = input_n = tuple(int(i) for i in n)
        else:
            # Flat space: ``self.n`` stays a plain int, the shape is ``(n,)``.
            self.n = n = int(n)
            input_n = (n,)

        # Single validation shared by both branches (was duplicated before).
        assert (
            np.asarray(input_n) > 0
        ).all(), "n (counts) have to be positive"

        super().__init__(input_n, np.int8, seed)

    @property
    def shape(self) -> tuple[int, ...]:
        """Has stricter type than gym.Space - never None."""
        return self._shape  # type: ignore

    def sample(self) -> np.ndarray:
        """Generates a single random sample from this space.

        A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).

        Returns:
            An array of zeros and ones with dtype ``self.dtype`` and size ``self.n``.
        """
        # ``high`` is exclusive, so the drawn values are 0 or 1.
        return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)

    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Args:
            x: Candidate element. Sequences (e.g. lists, tuples) are promoted to
                ``np.ndarray`` before checking.

        Returns:
            ``True`` if ``x`` is an array of this space's shape whose entries are all
            0 or 1; ``False`` otherwise, including when ``x`` is not array-like.
        """
        if isinstance(x, Sequence):
            x = np.array(x)  # Promote list to array for contains check

        # Anything that is not an array by now (scalars, None, dicts, ...) has no
        # usable ``shape``; reject it instead of raising AttributeError.
        if not isinstance(x, np.ndarray) or self.shape != x.shape:
            return False

        # Wrap in ``bool`` so a Python bool is returned rather than ``numpy.bool_``.
        return bool(((x == 0) | (x == 1)).all())

    def to_jsonable(self, sample_n) -> list:
        """Convert a batch of samples from this space to a JSONable data type."""
        return np.array(sample_n).tolist()

    def from_jsonable(self, sample_n) -> list:
        """Convert a JSONable data type to a batch of samples from this space."""
        return [np.asarray(sample) for sample in sample_n]

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"MultiBinary({self.n})"

    def __eq__(self, other) -> bool:
        """Check whether `other` is equivalent to this instance."""
        return isinstance(other, MultiBinary) and self.n == other.n