mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 01:50:19 +00:00
Add probability masking to space.sample
(#1310)
Co-authored-by: Mario Jerez <jerezmario1@gmail.com>
This commit is contained in:
@@ -103,7 +103,13 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
self,
|
||||
mask: None | (
|
||||
tuple[
|
||||
None | np.integer | NDArray[np.integer],
|
||||
None | int | NDArray[np.integer],
|
||||
Any,
|
||||
]
|
||||
) = None,
|
||||
probability: None | (
|
||||
tuple[
|
||||
None | int | NDArray[np.integer],
|
||||
Any,
|
||||
]
|
||||
) = None,
|
||||
@@ -114,50 +120,37 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
|
||||
If you specify ``mask``, it is expected to be a tuple of the form ``(length_mask, sample_mask)`` where ``length_mask`` is
|
||||
|
||||
* ``None`` The length will be randomly drawn from a geometric distribution
|
||||
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
|
||||
* ``int`` for a fixed length sample
|
||||
* ``None`` - The length will be randomly drawn from a geometric distribution
|
||||
* ``int`` - Fixed length
|
||||
* ``np.ndarray`` of integers - Length of the sampled sequence is randomly drawn from this array.
|
||||
|
||||
The second element of the mask tuple ``sample`` mask specifies a mask that is applied when
|
||||
sampling elements from the base space. The mask is applied for each feature space sample.
|
||||
The second element of the tuple ``sample_mask`` specifies how the feature space will be sampled.
|
||||
Whether ``mask`` or ``probability`` is provided determines which argument is passed to the feature space's ``sample``.
|
||||
probability: See mask description above, the only difference is on the ``sample_mask`` for the feature space being probability rather than mask.
|
||||
|
||||
Returns:
|
||||
A tuple of random length with random samples of elements from the :attr:`feature_space`.
|
||||
"""
|
||||
if mask is not None:
|
||||
length_mask, feature_mask = mask
|
||||
if mask is not None and probability is not None:
|
||||
raise ValueError(
|
||||
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
|
||||
)
|
||||
elif mask is not None:
|
||||
sample_length = self.generate_sample_length(mask[0], "mask")
|
||||
sampled_values = tuple(
|
||||
self.feature_space.sample(mask=mask[1]) for _ in range(sample_length)
|
||||
)
|
||||
elif probability is not None:
|
||||
sample_length = self.generate_sample_length(probability[0], "probability")
|
||||
sampled_values = tuple(
|
||||
self.feature_space.sample(probability=probability[1])
|
||||
for _ in range(sample_length)
|
||||
)
|
||||
else:
|
||||
length_mask, feature_mask = None, None
|
||||
|
||||
if length_mask is not None:
|
||||
if np.issubdtype(type(length_mask), np.integer):
|
||||
assert (
|
||||
0 <= length_mask
|
||||
), f"Expects the length mask to be greater than or equal to zero, actual value: {length_mask}"
|
||||
length = length_mask
|
||||
elif isinstance(length_mask, np.ndarray):
|
||||
assert (
|
||||
len(length_mask.shape) == 1
|
||||
), f"Expects the shape of the length mask to be 1-dimensional, actual shape: {length_mask.shape}"
|
||||
assert np.all(
|
||||
0 <= length_mask
|
||||
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
|
||||
assert np.issubdtype(
|
||||
length_mask.dtype, np.integer
|
||||
), f"Expects the length mask array to have dtype to be an numpy integer, actual type: {length_mask.dtype}"
|
||||
length = self.np_random.choice(length_mask)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expects the type of length_mask to an integer or a np.ndarray, actual type: {type(length_mask)}"
|
||||
)
|
||||
else:
|
||||
# The choice of 0.25 is arbitrary
|
||||
length = self.np_random.geometric(0.25)
|
||||
|
||||
# Generate sample values from feature_space.
|
||||
sampled_values = tuple(
|
||||
self.feature_space.sample(mask=feature_mask) for _ in range(length)
|
||||
)
|
||||
sample_length = self.np_random.geometric(0.25)
|
||||
sampled_values = tuple(
|
||||
self.feature_space.sample() for _ in range(sample_length)
|
||||
)
|
||||
|
||||
if self.stack:
|
||||
# Concatenate values if stacked.
|
||||
@@ -168,6 +161,39 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
|
||||
return sampled_values
|
||||
|
||||
def generate_sample_length(
|
||||
self,
|
||||
length_mask: None | np.integer | NDArray[np.integer],
|
||||
mask_type: None | str,
|
||||
) -> int:
|
||||
"""Generate the sample length for a given length mask and mask type."""
|
||||
if length_mask is not None:
|
||||
if np.issubdtype(type(length_mask), np.integer):
|
||||
assert (
|
||||
0 <= length_mask
|
||||
), f"Expects the length mask of `{mask_type}` to be greater than or equal to zero, actual value: {length_mask}"
|
||||
|
||||
return length_mask
|
||||
elif isinstance(length_mask, np.ndarray):
|
||||
assert (
|
||||
len(length_mask.shape) == 1
|
||||
), f"Expects the shape of the length mask of `{mask_type}` to be 1-dimensional, actual shape: {length_mask.shape}"
|
||||
assert np.all(
|
||||
0 <= length_mask
|
||||
), f"Expects all values in the length_mask of `{mask_type}` to be greater than or equal to zero, actual values: {length_mask}"
|
||||
assert np.issubdtype(
|
||||
length_mask.dtype, np.integer
|
||||
), f"Expects the length mask array of `{mask_type}` to have dtype of np.integer, actual type: {length_mask.dtype}"
|
||||
|
||||
return self.np_random.choice(length_mask)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expects the type of length_mask of `{mask_type}` to be an integer or a np.ndarray, actual type: {type(length_mask)}"
|
||||
)
|
||||
else:
|
||||
# The choice of 0.25 is arbitrary
|
||||
return self.np_random.geometric(0.25)
|
||||
|
||||
def contains(self, x: Any) -> bool:
|
||||
"""Return boolean specifying if x is a valid member of this space."""
|
||||
# by definition, any sequence is an iterable
|
||||
|
Reference in New Issue
Block a user