Add probability masking to space.sample (#1310)

Co-authored-by: Mario Jerez <jerezmario1@gmail.com>
This commit is contained in:
Mark Towers
2025-02-21 13:39:23 +00:00
committed by GitHub
parent 1dffcc6ed4
commit e4c1f901e9
21 changed files with 1053 additions and 182 deletions

View File

@@ -59,19 +59,29 @@ class MultiBinary(Space[NDArray[np.int8]]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
def sample(
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
) -> NDArray[np.int8]:
"""Generates a single random sample from this space.
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
For ``mask == 0`` then the samples will be ``0`` and ``mask == 1` then random samples will be generated.
mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
Where ``mask == 0`` the sample will be ``0``, and where ``mask == 1`` the sample will be ``1``.
For a randomly sampled element, use a mask value of ``2``.
The expected mask shape is the space shape and mask dtype is ``np.int8``.
probability: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape`` where each element
represents the probability of the corresponding sample element being a ``1``.
The expected probability shape is the space shape and probability dtype is ``np.float64``.
Returns:
Sampled values from space
"""
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
if mask is not None:
assert isinstance(
mask, np.ndarray
@@ -91,8 +101,25 @@ class MultiBinary(Space[NDArray[np.int8]]):
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
mask.astype(self.dtype),
)
elif probability is not None:
assert isinstance(
probability, np.ndarray
), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}"
assert (
probability.dtype == np.float64
), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}"
assert (
probability.shape == self.shape
), f"The expected shape of the probability is {self.shape}, actual shape: {probability}"
assert np.all(
np.logical_and(probability >= 0, probability <= 1)
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
return (self.np_random.random(size=self.shape) <= probability).astype(
self.dtype
)
else:
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""