mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 23:12:46 +00:00
Add probability masking to space.sample
(#1310)
Co-authored-by: Mario Jerez <jerezmario1@gmail.com>
This commit is contained in:
@@ -59,19 +59,29 @@ class MultiBinary(Space[NDArray[np.int8]]):
|
||||
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
|
||||
return True
|
||||
|
||||
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
|
||||
def sample(
|
||||
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
|
||||
) -> NDArray[np.int8]:
|
||||
"""Generates a single random sample from this space.
|
||||
|
||||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||
|
||||
Args:
|
||||
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
|
||||
For ``mask == 0`` then the samples will be ``0`` and ``mask == 1` then random samples will be generated.
|
||||
mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
|
||||
For ``mask == 0`` then the samples will be ``0``, for a ``mask == 1`` then the samples will be ``1``.
|
||||
For random samples, using a mask value of ``2``.
|
||||
The expected mask shape is the space shape and mask dtype is ``np.int8``.
|
||||
probability: An optional ``np.ndarray`` to mask samples with expected shape of space.shape where each element
|
||||
represents the probability of the corresponding sample element being a 1.
|
||||
The expected mask shape is the space shape and mask dtype is ``np.float64``.
|
||||
|
||||
Returns:
|
||||
Sampled values from space
|
||||
"""
|
||||
if mask is not None and probability is not None:
|
||||
raise ValueError(
|
||||
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
|
||||
)
|
||||
if mask is not None:
|
||||
assert isinstance(
|
||||
mask, np.ndarray
|
||||
@@ -91,8 +101,25 @@ class MultiBinary(Space[NDArray[np.int8]]):
|
||||
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
|
||||
mask.astype(self.dtype),
|
||||
)
|
||||
elif probability is not None:
|
||||
assert isinstance(
|
||||
probability, np.ndarray
|
||||
), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}"
|
||||
assert (
|
||||
probability.dtype == np.float64
|
||||
), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}"
|
||||
assert (
|
||||
probability.shape == self.shape
|
||||
), f"The expected shape of the probability is {self.shape}, actual shape: {probability}"
|
||||
assert np.all(
|
||||
np.logical_and(probability >= 0, probability <= 1)
|
||||
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
|
||||
|
||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
return (self.np_random.random(size=self.shape) <= probability).astype(
|
||||
self.dtype
|
||||
)
|
||||
else:
|
||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
|
Reference in New Issue
Block a user