Added Action masking for Space.sample() (#2906)

* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9117.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Add support for action masking in Space.sample(mask=...)

* Fix action mask

* Fix action_mask

* Fix action_mask

* Added docstrings, fixed bugs and added taxi examples

* Fixed bugs

* Add tests for sample

* Add docstrings and test space sample mask Discrete and MultiBinary

* Add MultiDiscrete sampling and tests

* Remove sample mask from graph

* Update gym/spaces/multi_discrete.py

Co-authored-by: Markus Krimmel <montcyril@gmail.com>

* Updates based on Marcus28 and jjshoots for Graph.py

* Updates based on Marcus28 and jjshoots for Graph.py

* jjshoot review

* jjshoot review

* Update assert check

* Update type hints

Co-authored-by: Markus Krimmel <montcyril@gmail.com>
This commit is contained in:
Mark Towers
2022-06-26 23:23:15 +01:00
committed by GitHub
parent d750eb8df0
commit 024b0f5160
11 changed files with 562 additions and 71 deletions

View File

@@ -1,6 +1,7 @@
"""Implementation of the `Space` metaclass."""
from typing import (
Any,
Generic,
Iterable,
List,
@@ -81,8 +82,17 @@ class Space(Generic[T_cov]):
"""Return the shape of the space as an immutable property."""
return self._shape
def sample(self) -> T_cov:
"""Randomly sample an element of this space. Can be uniform or non-uniform sampling based on boundedness of space."""
def sample(self, mask: Optional[Any] = None) -> T_cov:
"""Randomly sample an element of this space.
Can be uniform or non-uniform sampling based on boundedness of space.
Args:
mask: A mask used for sampling, expected ``dtype=np.int8`` and see sample implementation for expected shape.
Returns:
A sampled actions from the space
"""
raise NotImplementedError
def seed(self, seed: Optional[int] = None) -> list: