mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-17 23:10:05 +00:00
Added Action masking for Space.sample() (#2906)
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset
* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"
This reverts commit 519dfd9117.
* Remove bad pushed commits
* Fixed spelling in core.py
* Pins pytest to the last py 3.6 version
* Add support for action masking in Space.sample(mask=...)
* Fix action mask
* Fix action_mask
* Fix action_mask
* Added docstrings, fixed bugs and added taxi examples
* Fixed bugs
* Add tests for sample
* Add docstrings and test space sample mask Discrete and MultiBinary
* Add MultiDiscrete sampling and tests
* Remove sample mask from graph
* Update gym/spaces/multi_discrete.py
Co-authored-by: Markus Krimmel <montcyril@gmail.com>
* Updates based on Marcus28 and jjshoots for Graph.py
* Updates based on Marcus28 and jjshoots for Graph.py
* jjshoot review
* jjshoot review
* Update assert check
* Update type hints
Co-authored-by: Markus Krimmel <montcyril@gmail.com>
This commit is contained in:
@@ -8,6 +8,8 @@ from gym.spaces.discrete import Discrete
|
||||
from gym.spaces.space import Space
|
||||
from gym.utils import seeding
|
||||
|
||||
# Recursive type of the ``mask`` argument accepted by ``MultiDiscrete.sample``:
# a tuple whose entries are each either an ``np.ndarray`` mask or another
# (nested) tuple of the same form.
SAMPLE_MASK_TYPE = Tuple[Union["SAMPLE_MASK_TYPE", np.ndarray], ...]
|
||||
|
||||
|
||||
class MultiDiscrete(Space[np.ndarray]):
|
||||
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.
|
||||
@@ -23,8 +25,17 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
2. Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
|
||||
3. Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
|
||||
|
||||
It can be initialized as ``MultiDiscrete([ 5, 2, 2 ])``
|
||||
It can be initialized as ``MultiDiscrete([ 5, 2, 2 ])`` such that a sample might be ``array([3, 1, 0])``.
|
||||
|
||||
Although this feature is rarely used, :class:`MultiDiscrete` spaces may also have several axes
|
||||
if ``nvec`` has several axes:
|
||||
|
||||
Example::
|
||||
|
||||
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
|
||||
>> d.sample()
|
||||
array([[0, 0],
|
||||
[2, 3]])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -37,16 +48,6 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
|
||||
The argument ``nvec`` will determine the number of values each categorical variable can take.
|
||||
|
||||
Although this feature is rarely used, :class:`MultiDiscrete` spaces may also have several axes
|
||||
if ``nvec`` has several axes:
|
||||
|
||||
Example::
|
||||
|
||||
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
|
||||
>> d.sample()
|
||||
array([[0, 0],
|
||||
[2, 3]])
|
||||
|
||||
Args:
|
||||
nvec: vector of counts of each categorical variable. This will usually be a list of integers. However,
|
||||
you may also pass a more complicated numpy array if you'd like the space to have several axes.
|
||||
@@ -63,8 +64,56 @@ class MultiDiscrete(Space[np.ndarray]):
|
||||
"""Has stricter type than :class:`gym.Space` - never None."""
|
||||
return self._shape # type: ignore
|
||||
|
||||
def sample(self) -> np.ndarray:
|
||||
"""Generates a single random sample this space."""
|
||||
def sample(self, mask: Optional[SAMPLE_MASK_TYPE] = None) -> np.ndarray:
    """Generates a single random sample from this space.

    Args:
        mask: An optional mask for multi-discrete, expects tuples with a `np.ndarray` mask in the position of each
            action with shape `(n,)` where `n` is the number of actions and `dtype=np.int8`.
            Only mask values == 1 are possible to sample unless all mask values for an action are 0 then the default action 0 is sampled.
            The nesting of the tuples must mirror the nesting of ``nvec``.

    Returns:
        An `np.ndarray` of shape `space.shape`

    Raises:
        AssertionError: If the mask structure does not mirror ``nvec``, or if a leaf mask has the wrong
            length, a dtype other than ``np.int8``, or values other than 0 and 1.
    """
    if mask is not None:

        def _apply_mask(
            sub_mask: SAMPLE_MASK_TYPE, sub_nvec: np.ndarray
        ) -> Union[int, List[int]]:
            # Leaf case: an ndarray mask selects among the actions of one categorical variable.
            if isinstance(sub_mask, np.ndarray):
                assert np.issubdtype(
                    type(sub_nvec), np.integer
                ), f"Expects the mask to be for an action, actual for {sub_nvec}"
                assert (
                    len(sub_mask) == sub_nvec
                ), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {sub_nvec}"
                assert (
                    sub_mask.dtype == np.int8
                ), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}"

                valid_action_mask = sub_mask == 1
                assert np.all(
                    np.logical_or(sub_mask == 0, valid_action_mask)
                ), f"Expects all masks values to 0 or 1, actual values: {sub_mask}"

                if np.any(valid_action_mask):
                    # Pick uniformly among the indices whose mask value is 1.
                    return self.np_random.choice(np.where(valid_action_mask)[0])
                # Every action is masked out: fall back to the default action 0.
                return 0

            # Nested case: a tuple of masks, one per entry along this axis of nvec.
            assert isinstance(
                sub_mask, tuple
            ), f"Expects the mask to be a tuple or np.ndarray, actual type: {type(sub_mask)}"
            # A tuple paired with a scalar nvec entry means the mask is nested one level
            # too deep; fail with a clear message instead of a TypeError from len().
            assert isinstance(
                sub_nvec, np.ndarray
            ), f"Expects the mask for a single action to be a np.ndarray, actual type: {type(sub_mask)}, nvec: {sub_nvec}"
            assert len(sub_mask) == len(
                sub_nvec
            ), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
            return [
                _apply_mask(new_mask, new_nvec)
                for new_mask, new_nvec in zip(sub_mask, sub_nvec)
            ]

        return np.array(_apply_mask(mask, self.nvec), dtype=self.dtype)

    # Unmasked: scale random floats elementwise by nvec and cast to the space dtype.
    return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||
|
||||
def contains(self, x) -> bool:
|
||||
|
Reference in New Issue
Block a user