mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-19 07:26:35 +00:00
Added Action masking for Space.sample() (#2906)
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset
* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"
This reverts commit 519dfd9117.
* Remove bad pushed commits
* Fixed spelling in core.py
* Pins pytest to the last py 3.6 version
* Add support for action masking in Space.sample(mask=...)
* Fix action mask
* Fix action_mask
* Fix action_mask
* Added docstrings, fixed bugs and added taxi examples
* Fixed bugs
* Add tests for sample
* Add docstrings and test space sample mask Discrete and MultiBinary
* Add MultiDiscrete sampling and tests
* Remove sample mask from graph
* Update gym/spaces/multi_discrete.py
Co-authored-by: Markus Krimmel <montcyril@gmail.com>
* Updates based on Marcus28 and jjshoots for Graph.py
* Updates based on Marcus28 and jjshoots for Graph.py
* jjshoot review
* jjshoot review
* Update assert check
* Update type hints
Co-authored-by: Markus Krimmel <montcyril@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Implementation of the `Space` metaclass."""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
@@ -81,8 +82,17 @@ class Space(Generic[T_cov]):
|
||||
"""Return the shape of the space as an immutable property."""
|
||||
return self._shape
|
||||
|
||||
def sample(self) -> T_cov:
|
||||
"""Randomly sample an element of this space. Can be uniform or non-uniform sampling based on boundedness of space."""
|
||||
def sample(self, mask: Optional[Any] = None) -> T_cov:
|
||||
"""Randomly sample an element of this space.
|
||||
|
||||
Can be uniform or non-uniform sampling based on boundedness of space.
|
||||
|
||||
Args:
|
||||
mask: A mask used for sampling; the expected dtype is ``np.int8``. See the sample implementation for the expected shape.
|
||||
|
||||
Returns:
|
||||
A sampled action from the space
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> list:
|
||||
|
Reference in New Issue
Block a user