Added Action masking for Space.sample() (#2906)

* Allows a new RNG to be generated with seed=-1 and updates env_checker to fix a bug when the environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9117.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Add support for action masking in Space.sample(mask=...)

* Fix action mask

* Fix action_mask

* Fix action_mask

* Added docstrings, fixed bugs and added taxi examples

* Fixed bugs

* Add tests for sample

* Add docstrings and test space sample mask Discrete and MultiBinary

* Add MultiDiscrete sampling and tests

* Remove sample mask from graph

* Update gym/spaces/multi_discrete.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Updates based on Marcus28 and jjshoots for Graph.py

* Updates based on Marcus28 and jjshoots for Graph.py

* jjshoots review

* jjshoots review

* Update assert check

* Update type hints

Co-authored-by: Markus Krimmel <montcyril@gmail.com>
This commit is contained in:
Mark Towers
2022-06-26 23:23:15 +01:00
committed by GitHub
parent d750eb8df0
commit 024b0f5160
11 changed files with 562 additions and 71 deletions

View File

@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Sequence, SupportsFloat, Tuple, Type, U
import numpy as np
import gym.error
from gym import logger
from gym.spaces.space import Space
from gym.utils import seeding
@@ -146,7 +147,7 @@ class Box(Space[np.ndarray]):
else:
raise ValueError("manner is not in {'below', 'above', 'both'}")
def sample(self) -> np.ndarray:
def sample(self, mask: None = None) -> np.ndarray:
r"""Generates a single random sample inside the Box.
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
@@ -157,9 +158,17 @@ class Box(Space[np.ndarray]):
* :math:`(-\infty, b]` : shifted negative exponential distribution
* :math:`(-\infty, \infty)` : normal distribution
Args:
mask: A mask for sampling values from the Box space, currently unsupported.
Returns:
A sampled value from the Box
"""
if mask is not None:
raise gym.error.Error(
f"Box.sample cannot be provided a mask, actual value: {mask}"
)
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
sample = np.empty(self.shape)