Add probability masking to space.sample (#1310)

Co-authored-by: Mario Jerez <jerezmario1@gmail.com>
This commit is contained in:
Mark Towers
2025-02-21 13:39:23 +00:00
committed by GitHub
parent 1dffcc6ed4
commit e4c1f901e9
21 changed files with 1053 additions and 182 deletions

View File

@@ -342,7 +342,7 @@ class Box(Space[NDArray[Any]]):
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
)
def sample(self, mask: None = None) -> NDArray[Any]:
def sample(self, mask: None = None, probability: None = None) -> NDArray[Any]:
r"""Generates a single random sample inside the Box.
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
@@ -355,6 +355,7 @@ class Box(Space[NDArray[Any]]):
Args:
mask: A mask for sampling values from the Box space, currently unsupported.
probability: A probability mask for sampling values from the Box space, currently unsupported.
Returns:
A sampled value from the Box
@@ -363,6 +364,10 @@ class Box(Space[NDArray[Any]]):
raise gym.error.Error(
f"Box.sample cannot be provided a mask, actual value: {mask}"
)
elif probability is not None:
raise gym.error.Error(
f"Box.sample cannot be provided a probability mask, actual value: {probability}"
)
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
sample = np.empty(self.shape)

View File

@@ -149,27 +149,49 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
)
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
def sample(
self,
mask: dict[str, Any] | None = None,
probability: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Generates a single random sample from this space.
The sample is an ordered dictionary of independent samples from the constituent spaces.
Args:
mask: An optional mask for each of the subspaces, expects the same keys as the space
probability: An optional probability mask for each of the subspaces, expects the same keys as the space
Returns:
A dictionary with the same key and sampled values from :attr:`self.spaces`
"""
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
assert isinstance(
mask, dict
), f"Expects mask to be a dict, actual type: {type(mask)}"
), f"Expected sample mask to be a dict, actual type: {type(mask)}"
assert (
mask.keys() == self.spaces.keys()
), f"Expect mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
return {k: space.sample(mask=mask[k]) for k, space in self.spaces.items()}
), f"Expected sample mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
return {k: space.sample() for k, space in self.spaces.items()}
return {k: space.sample(mask=mask[k]) for k, space in self.spaces.items()}
elif probability is not None:
assert isinstance(
probability, dict
), f"Expected sample probability mask to be a dict, actual type: {type(probability)}"
assert (
probability.keys() == self.spaces.keys()
), f"Expected sample probability mask keys to be same as space keys, mask keys: {probability.keys()}, space keys: {self.spaces.keys()}"
return {
k: space.sample(probability=probability[k])
for k, space in self.spaces.items()
}
else:
return {k: space.sample() for k, space in self.spaces.items()}
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""

View File

@@ -22,6 +22,12 @@ class Discrete(Space[np.int64]):
>>> observation_space = Discrete(3, start=-1, seed=42) # {-1, 0, 1}
>>> observation_space.sample()
np.int64(-1)
>>> observation_space.sample(mask=np.array([0,0,1], dtype=np.int8))
np.int64(1)
>>> observation_space.sample(probability=np.array([0,0,1], dtype=np.float64))
np.int64(1)
>>> observation_space.sample(probability=np.array([0,0.3,0.7], dtype=np.float64))
np.int64(1)
"""
def __init__(
@@ -56,41 +62,74 @@ class Discrete(Space[np.int64]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: MaskNDArray | None = None) -> np.int64:
def sample(
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
) -> np.int64:
"""Generates a single random sample from this space.
A sample will be chosen uniformly at random with the mask if provided
A sample will be chosen uniformly at random with the mask if provided, or it will be chosen according to a specified probability distribution if the probability mask is provided.
Args:
mask: An optional mask for if an action can be selected.
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions.
If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned.
probability: An optional probability mask describing the probability of each action being selected.
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.float64`` where each value is in the range ``[0, 1]`` and the sum of all values is 1.
If the values do not sum to 1, an exception will be thrown.
Returns:
A sampled integer from the space
"""
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
# binary mask sampling
elif mask is not None:
assert isinstance(
mask, np.ndarray
), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
), f"The expected type of the sample mask is np.ndarray, actual type: {type(mask)}"
assert (
mask.dtype == np.int8
), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
), f"The expected dtype of the sample mask is np.int8, actual dtype: {mask.dtype}"
assert mask.shape == (
self.n,
), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"
), f"The expected shape of the sample mask is {(int(self.n),)}, actual shape: {mask.shape}"
valid_action_mask = mask == 1
assert np.all(
np.logical_or(mask == 0, valid_action_mask)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
), f"All values of the sample mask should be 0 or 1, actual values: {mask}"
if np.any(valid_action_mask):
return self.start + self.np_random.choice(
np.where(valid_action_mask)[0]
)
else:
return self.start
# probability mask sampling
elif probability is not None:
assert isinstance(
probability, np.ndarray
), f"The expected type of the sample probability is np.ndarray, actual type: {type(probability)}"
assert (
probability.dtype == np.float64
), f"The expected dtype of the sample probability is np.float64, actual dtype: {probability.dtype}"
assert probability.shape == (
self.n,
), f"The expected shape of the sample probability is {(int(self.n),)}, actual shape: {probability.shape}"
return self.start + self.np_random.integers(self.n)
assert np.all(
np.logical_and(probability >= 0, probability <= 1)
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
assert np.isclose(
np.sum(probability), 1
), f"The sum of the sample probability should be equal to 1, actual sum: {np.sum(probability)}"
return self.start + self.np_random.choice(np.arange(self.n), p=probability)
# uniform sampling
else:
return self.start + self.np_random.integers(self.n)
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""

View File

@@ -183,6 +183,12 @@ class Graph(Space[GraphInstance]):
NDArray[Any] | tuple[Any, ...] | None,
]
) = None,
probability: None | (
tuple[
NDArray[Any] | tuple[Any, ...] | None,
NDArray[Any] | tuple[Any, ...] | None,
]
) = None,
num_nodes: int = 10,
num_edges: int | None = None,
) -> GraphInstance:
@@ -192,6 +198,9 @@ class Graph(Space[GraphInstance]):
mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
(Box spaces don't support sample masks).
If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
probability: An optional tuple of optional node and edge probability mask that is only possible with Discrete spaces
(Box spaces don't support sample probability masks).
If no ``num_edges`` is provided then the edge probability mask is multiplied by the number of edges
num_nodes: The number of nodes that will be sampled, the default is `10` nodes
num_edges: An optional number of edges, otherwise, a random number between `0` and :math:`num_nodes^2`
@@ -202,10 +211,18 @@ class Graph(Space[GraphInstance]):
num_nodes > 0
), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
node_space_mask, edge_space_mask = mask
mask_type = "mask"
elif probability is not None:
node_space_mask, edge_space_mask = probability
mask_type = "probability"
else:
node_space_mask, edge_space_mask = None, None
node_space_mask = edge_space_mask = mask_type = None
# we only have edges when we have at least 2 nodes
if num_edges is None:
@@ -228,15 +245,19 @@ class Graph(Space[GraphInstance]):
assert num_edges is not None
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
assert sampled_node_space is not None
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
assert sampled_node_space is not None
sampled_nodes = sampled_node_space.sample(node_space_mask)
sampled_edges = (
sampled_edge_space.sample(edge_space_mask)
if sampled_edge_space is not None
else None
)
if mask_type is not None:
node_sample_kwargs = {mask_type: node_space_mask}
edge_sample_kwargs = {mask_type: edge_space_mask}
else:
node_sample_kwargs = edge_sample_kwargs = {}
sampled_nodes = sampled_node_space.sample(**node_sample_kwargs)
sampled_edges = None
if sampled_edge_space is not None:
sampled_edges = sampled_edge_space.sample(**edge_sample_kwargs)
sampled_edge_links = None
if sampled_edges is not None and num_edges > 0:

View File

@@ -59,19 +59,29 @@ class MultiBinary(Space[NDArray[np.int8]]):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
def sample(
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
) -> NDArray[np.int8]:
"""Generates a single random sample from this space.
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
For ``mask == 0`` then the samples will be ``0`` and ``mask == 1` then random samples will be generated.
mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
For ``mask == 0`` the samples will be ``0``, and for ``mask == 1`` the samples will be ``1``.
For random samples, use a mask value of ``2``.
The expected mask shape is the space shape and mask dtype is ``np.int8``.
probability: An optional ``np.ndarray`` to mask samples with expected shape of space.shape where each element
represents the probability of the corresponding sample element being a 1.
The expected mask shape is the space shape and mask dtype is ``np.float64``.
Returns:
Sampled values from space
"""
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
if mask is not None:
assert isinstance(
mask, np.ndarray
@@ -91,8 +101,25 @@ class MultiBinary(Space[NDArray[np.int8]]):
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
mask.astype(self.dtype),
)
elif probability is not None:
assert isinstance(
probability, np.ndarray
), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}"
assert (
probability.dtype == np.float64
), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}"
assert (
probability.shape == self.shape
), f"The expected shape of the probability is {self.shape}, actual shape: {probability}"
assert np.all(
np.logical_and(probability >= 0, probability <= 1)
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
return (self.np_random.random(size=self.shape) <= probability).astype(
self.dtype
)
else:
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""

View File

@@ -96,70 +96,107 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
return True
def sample(
self, mask: tuple[MaskNDArray, ...] | None = None
self,
mask: tuple[MaskNDArray, ...] | None = None,
probability: tuple[MaskNDArray, ...] | None = None,
) -> NDArray[np.integer[Any]]:
"""Generates a single random sample this space.
"""Generates a single random sample from this space.
Args:
mask: An optional mask for multi-discrete, expects tuples with a ``np.ndarray`` mask in the position of each
action with shape ``(n,)`` where ``n`` is the number of actions and ``dtype=np.int8``.
Only ``mask values == 1`` are possible to sample unless all mask values for an action are ``0`` then the default action ``self.start`` (the smallest element) is sampled.
probability: An optional probability mask for multi-discrete, expects tuples with a ``np.ndarray`` probability mask in the position of each
action with shape ``(n,)`` where ``n`` is the number of actions and ``dtype=np.float64``.
All probability mask values must be within ``[0, 1]`` and the values for each action must sum to ``1``.
Returns:
An ``np.ndarray`` of :meth:`Space.shape`
"""
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
return np.array(
self._apply_mask(mask, self.nvec, self.start, "mask"),
dtype=self.dtype,
)
elif probability is not None:
return np.array(
self._apply_mask(probability, self.nvec, self.start, "probability"),
dtype=self.dtype,
)
else:
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(
self.dtype
) + self.start
def _apply_mask(
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
sub_nvec: MaskNDArray | np.integer[Any],
sub_start: MaskNDArray | np.integer[Any],
) -> int | list[Any]:
if isinstance(sub_nvec, np.ndarray):
assert isinstance(
sub_mask, tuple
), f"Expects the mask to be a tuple for sub_nvec ({sub_nvec}), actual type: {type(sub_mask)}"
assert len(sub_mask) == len(
sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
return [
_apply_mask(new_mask, new_nvec, new_start)
for new_mask, new_nvec, new_start in zip(
sub_mask, sub_nvec, sub_start
)
]
else:
assert np.issubdtype(
type(sub_nvec), np.integer
), f"Expects the sub_nvec to be an action, actually: {sub_nvec}, {type(sub_nvec)}"
assert isinstance(
sub_mask, np.ndarray
), f"Expects the sub mask to be np.ndarray, actual type: {type(sub_mask)}"
assert (
len(sub_mask) == sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, action: {sub_nvec}"
assert (
sub_mask.dtype == np.int8
), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}"
def _apply_mask(
self,
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
sub_nvec: MaskNDArray | np.integer[Any],
sub_start: MaskNDArray | np.integer[Any],
mask_type: str,
) -> int | list[Any]:
"""Returns a sample using the provided mask or probability mask."""
if isinstance(sub_nvec, np.ndarray):
assert isinstance(
sub_mask, tuple
), f"Expects the mask to be a tuple for sub_nvec ({sub_nvec}), actual type: {type(sub_mask)}"
assert len(sub_mask) == len(
sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
return [
self._apply_mask(new_mask, new_nvec, new_start, mask_type)
for new_mask, new_nvec, new_start in zip(sub_mask, sub_nvec, sub_start)
]
valid_action_mask = sub_mask == 1
assert np.all(
np.logical_or(sub_mask == 0, valid_action_mask)
), f"Expects all masks values to 0 or 1, actual values: {sub_mask}"
assert np.issubdtype(
type(sub_nvec), np.integer
), f"Expects the sub_nvec to be an action, actually: {sub_nvec}, {type(sub_nvec)}"
assert isinstance(
sub_mask, np.ndarray
), f"Expects the sub mask to be np.ndarray, actual type: {type(sub_mask)}"
assert (
len(sub_mask) == sub_nvec
), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, action: {sub_nvec}"
if np.any(valid_action_mask):
return (
self.np_random.choice(np.where(valid_action_mask)[0])
+ sub_start
)
else:
return sub_start
if mask_type == "mask":
assert (
sub_mask.dtype == np.int8
), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}"
return np.array(_apply_mask(mask, self.nvec, self.start), dtype=self.dtype)
valid_action_mask = sub_mask == 1
assert np.all(
np.logical_or(sub_mask == 0, valid_action_mask)
), f"Expects all masks values to 0 or 1, actual values: {sub_mask}"
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(
self.dtype
) + self.start
if np.any(valid_action_mask):
return self.np_random.choice(np.where(valid_action_mask)[0]) + sub_start
else:
return sub_start
elif mask_type == "probability":
assert (
sub_mask.dtype == np.float64
), f"Expects the mask dtype to be np.float64, actual dtype: {sub_mask.dtype}"
valid_action_mask = np.logical_and(sub_mask > 0, sub_mask <= 1)
assert np.all(
np.logical_or(sub_mask == 0, valid_action_mask)
), f"Expects all masks values to be between 0 and 1, actual values: {sub_mask}"
assert np.isclose(
np.sum(sub_mask), 1
), f"Expects the sum of all mask values to be 1, actual sum: {np.sum(sub_mask)}"
normalized_sub_mask = sub_mask / np.sum(sub_mask)
return (
self.np_random.choice(
np.where(valid_action_mask)[0],
p=normalized_sub_mask[valid_action_mask],
)
+ sub_start
)
raise ValueError(f"Unsupported mask type: {mask_type}")
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""

View File

@@ -18,9 +18,9 @@ class OneOf(Space[Any]):
Example:
>>> from gymnasium.spaces import OneOf, Box, Discrete
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=123)
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
>>> observation_space.sample() # the first element is the space index (Discrete in this case) and the second element is the sample from Discrete
(np.int64(0), np.int64(0))
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
>>> observation_space.sample() # this time the Box space was sampled as index=1
(np.int64(1), array([-0.00711833, -0.7257502 ], dtype=float32))
>>> observation_space[0]
Discrete(2)
@@ -100,7 +100,11 @@ class OneOf(Space[Any]):
f"Expected None, int, or tuple of ints, actual type: {type(seed)}"
)
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
def sample(
self,
mask: tuple[Any | None, ...] | None = None,
probability: tuple[Any | None, ...] | None = None,
) -> tuple[int, Any]:
"""Generates a single random sample inside this space.
This method draws independent samples from the subspaces.
@@ -108,23 +112,42 @@ class OneOf(Space[Any]):
Args:
mask: An optional tuple of optional masks for each of the subspace's samples,
expects the same number of masks as spaces
probability: An optional tuple of optional probability masks for each of the subspace's samples,
expects the same number of probability masks as spaces
Returns:
Tuple of the subspace's samples
"""
subspace_idx = self.np_random.integers(0, len(self.spaces), dtype=np.int64)
subspace = self.spaces[subspace_idx]
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
assert isinstance(
mask, tuple
), f"Expected type of mask is tuple, actual type: {type(mask)}"
), f"Expected type of `mask` is tuple, actual type: {type(mask)}"
assert len(mask) == len(
self.spaces
), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"
), f"Expected length of `mask` is {len(self.spaces)}, actual length: {len(mask)}"
mask = mask[subspace_idx]
subspace_sample = subspace.sample(mask=mask[subspace_idx])
return subspace_idx, subspace.sample(mask=mask)
elif probability is not None:
assert isinstance(
probability, tuple
), f"Expected type of `probability` is tuple, actual type: {type(probability)}"
assert len(probability) == len(
self.spaces
), f"Expected length of `probability` is {len(self.spaces)}, actual length: {len(probability)}"
subspace_sample = subspace.sample(probability=probability[subspace_idx])
else:
subspace_sample = subspace.sample()
return subspace_idx, subspace_sample
def contains(self, x: tuple[int, Any]) -> bool:
"""Return boolean specifying if x is a valid member of this space."""

View File

@@ -103,7 +103,13 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
self,
mask: None | (
tuple[
None | np.integer | NDArray[np.integer],
None | int | NDArray[np.integer],
Any,
]
) = None,
probability: None | (
tuple[
None | int | NDArray[np.integer],
Any,
]
) = None,
@@ -114,50 +120,37 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
If you specify ``mask``, it is expected to be a tuple of the form ``(length_mask, sample_mask)`` where ``length_mask`` is
* ``None`` The length will be randomly drawn from a geometric distribution
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
* ``int`` for a fixed length sample
* ``None`` - The length will be randomly drawn from a geometric distribution
* ``int`` - Fixed length
* ``np.ndarray`` of integers - Length of the sampled sequence is randomly drawn from this array.
The second element of the mask tuple ``sample`` mask specifies a mask that is applied when
sampling elements from the base space. The mask is applied for each feature space sample.
The second element of the tuple ``sample_mask`` specifies how the feature space will be sampled.
Whether ``mask`` or ``probability`` is used determines which keyword argument is passed to the feature space's ``sample`` method.
probability: See mask description above, the only difference is on the ``sample_mask`` for the feature space being probability rather than mask.
Returns:
A tuple of random length with random samples of elements from the :attr:`feature_space`.
"""
if mask is not None:
length_mask, feature_mask = mask
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
sample_length = self.generate_sample_length(mask[0], "mask")
sampled_values = tuple(
self.feature_space.sample(mask=mask[1]) for _ in range(sample_length)
)
elif probability is not None:
sample_length = self.generate_sample_length(probability[0], "probability")
sampled_values = tuple(
self.feature_space.sample(probability=probability[1])
for _ in range(sample_length)
)
else:
length_mask, feature_mask = None, None
if length_mask is not None:
if np.issubdtype(type(length_mask), np.integer):
assert (
0 <= length_mask
), f"Expects the length mask to be greater than or equal to zero, actual value: {length_mask}"
length = length_mask
elif isinstance(length_mask, np.ndarray):
assert (
len(length_mask.shape) == 1
), f"Expects the shape of the length mask to be 1-dimensional, actual shape: {length_mask.shape}"
assert np.all(
0 <= length_mask
), f"Expects all values in the length_mask to be greater than or equal to zero, actual values: {length_mask}"
assert np.issubdtype(
length_mask.dtype, np.integer
), f"Expects the length mask array to have dtype to be an numpy integer, actual type: {length_mask.dtype}"
length = self.np_random.choice(length_mask)
else:
raise TypeError(
f"Expects the type of length_mask to an integer or a np.ndarray, actual type: {type(length_mask)}"
)
else:
# The choice of 0.25 is arbitrary
length = self.np_random.geometric(0.25)
# Generate sample values from feature_space.
sampled_values = tuple(
self.feature_space.sample(mask=feature_mask) for _ in range(length)
)
sample_length = self.np_random.geometric(0.25)
sampled_values = tuple(
self.feature_space.sample() for _ in range(sample_length)
)
if self.stack:
# Concatenate values if stacked.
@@ -168,6 +161,39 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
return sampled_values
def generate_sample_length(
self,
length_mask: None | np.integer | NDArray[np.integer],
mask_type: None | str,
) -> int:
"""Generate the sample length for a given length mask and mask type."""
if length_mask is not None:
if np.issubdtype(type(length_mask), np.integer):
assert (
0 <= length_mask
), f"Expects the length mask of `{mask_type}` to be greater than or equal to zero, actual value: {length_mask}"
return length_mask
elif isinstance(length_mask, np.ndarray):
assert (
len(length_mask.shape) == 1
), f"Expects the shape of the length mask of `{mask_type}` to be 1-dimensional, actual shape: {length_mask.shape}"
assert np.all(
0 <= length_mask
), f"Expects all values in the length_mask of `{mask_type}` to be greater than or equal to zero, actual values: {length_mask}"
assert np.issubdtype(
length_mask.dtype, np.integer
), f"Expects the length mask array of `{mask_type}` to have dtype of np.integer, actual type: {length_mask.dtype}"
return self.np_random.choice(length_mask)
else:
raise TypeError(
f"Expects the type of length_mask of `{mask_type}` to be an integer or a np.ndarray, actual type: {type(length_mask)}"
)
else:
# The choice of 0.25 is arbitrary
return self.np_random.geometric(0.25)
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
# by definition, any sequence is an iterable

View File

@@ -90,13 +90,16 @@ class Space(Generic[T_cov]):
"""Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`."""
raise NotImplementedError
def sample(self, mask: Any | None = None) -> T_cov:
def sample(self, mask: Any | None = None, probability: Any | None = None) -> T_cov:
"""Randomly sample an element of this space.
Can be uniform or non-uniform sampling based on boundedness of space.
The binary mask and the probability mask can't be used at the same time.
Args:
mask: A mask used for sampling, expected ``dtype=np.int8`` and see sample implementation for expected shape.
mask: A mask used for random sampling, expected ``dtype=np.int8`` and see sample implementation for expected shape.
probability: A probability mask used for sampling according to the given probability distribution, expected ``dtype=np.float64`` and see sample implementation for expected shape.
Returns:
A sampled action from the space

View File

@@ -78,75 +78,108 @@ class Text(Space[str]):
def sample(
self,
mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
probability: None | (tuple[int | None, NDArray[np.float64] | None]) = None,
) -> str:
"""Generates a single random sample from this space with by default a random length between ``min_length`` and ``max_length`` and sampled from the ``charset``.
Args:
mask: An optional tuple of length and mask for the text.
The length is expected to be between the ``min_length`` and ``max_length`` otherwise a random integer between ``min_length`` and ``max_length`` is selected.
The length is expected to be between the ``min_length`` and ``max_length``.
If the length is ``None``, a random integer between ``min_length`` and ``max_length`` is selected.
For the mask, we expect a numpy array of length of the charset passed with ``dtype == np.int8``.
If the charlist mask is all zero then an empty string is returned when ``min_length == 0``, otherwise a ``ValueError`` is raised
probability: An optional tuple of length and probability mask for the text.
The length is expected to be between the ``min_length`` and ``max_length``.
If the length is ``None``, a random integer between ``min_length`` and ``max_length`` is selected.
For the probability mask, we expect a numpy array of length of the charset passed with ``dtype == np.float64``.
The sum of the probability mask should be 1, otherwise an exception is raised.
Returns:
A sampled string from the space
"""
if mask is not None:
assert isinstance(
mask, tuple
), f"Expects the mask type to be a tuple, actual type: {type(mask)}"
assert (
len(mask) == 2
), f"Expects the mask length to be two, actual length: {len(mask)}"
length, charlist_mask = mask
if length is not None:
assert np.issubdtype(
type(length), np.integer
), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
assert (
self.min_length <= length <= self.max_length
), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
length, charlist_mask = self._validate_mask(mask, np.int8, "mask")
if charlist_mask is not None:
assert isinstance(
charlist_mask, np.ndarray
), f"Expects the Text sample mask to be an np.ndarray, actual type: {type(charlist_mask)}"
assert (
charlist_mask.dtype == np.int8
), f"Expects the Text sample mask to be an np.ndarray, actual dtype: {charlist_mask.dtype}"
assert charlist_mask.shape == (
len(self.character_set),
), f"expects the Text sample mask to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"
assert np.all(
np.logical_or(charlist_mask == 0, charlist_mask == 1)
), f"Expects all masks values to 0 or 1, actual values: {charlist_mask}"
), f"Expects all mask values to 0 or 1, actual values: {charlist_mask}"
# normalise the mask to use as a probability
if np.sum(charlist_mask) > 0:
charlist_mask = charlist_mask / np.sum(charlist_mask)
elif probability is not None:
length, charlist_mask = self._validate_mask(
probability, np.float64, "probability"
)
if charlist_mask is not None:
assert np.all(
np.logical_and(charlist_mask >= 0, charlist_mask <= 1)
), f"Expects all probability mask values to be within 0 and 1, actual values: {charlist_mask}"
assert np.isclose(
np.sum(charlist_mask), 1
), f"Expects the sum of the probability mask to be 1, actual sum: {np.sum(charlist_mask)}"
else:
length, charlist_mask = None, None
length = charlist_mask = None
if length is None:
length = self.np_random.integers(self.min_length, self.max_length + 1)
if charlist_mask is None: # uniform sampling
charlist_mask = np.ones(len(self.character_set)) / len(self.character_set)
if charlist_mask is None:
string = self.np_random.choice(self.character_list, size=length)
else:
valid_mask = charlist_mask == 1
valid_indexes = np.where(valid_mask)[0]
if len(valid_indexes) == 0:
if self.min_length == 0:
string = ""
else:
# Otherwise the string will not be contained in the space
raise ValueError(
f"Trying to sample with a minimum length > 0 ({self.min_length}) but the character mask is all zero meaning that no character could be sampled."
)
if np.all(charlist_mask == 0):
if self.min_length == 0:
return ""
else:
string = "".join(
self.character_list[index]
for index in self.np_random.choice(valid_indexes, size=length)
# Otherwise the string will not be contained in the space
raise ValueError(
f"Trying to sample with a minimum length > 0 (actual minimum length={self.min_length}) but the character mask is all zero meaning that no character could be sampled."
)
string = self.np_random.choice(
self.character_list, size=length, p=charlist_mask
)
return "".join(string)
def _validate_mask(
self,
mask: tuple[int | None, NDArray[np.int8] | NDArray[np.float64] | None],
expected_dtype: np.dtype,
mask_type: str,
) -> tuple[int | None, NDArray[np.int8] | NDArray[np.float64] | None]:
assert isinstance(
mask, tuple
), f"Expects the `{mask_type}` type to be a tuple, actual type: {type(mask)}"
assert (
len(mask) == 2
), f"Expects the `{mask_type}` length to be two, actual length: {len(mask)}"
length, charlist_mask = mask
if length is not None:
assert np.issubdtype(
type(length), np.integer
), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
assert (
self.min_length <= length <= self.max_length
), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"
if charlist_mask is not None:
assert isinstance(
charlist_mask, np.ndarray
), f"Expects the Text sample `{mask_type}` to be an np.ndarray, actual type: {type(charlist_mask)}"
assert (
charlist_mask.dtype == expected_dtype
), f"Expects the Text sample `{mask_type}` to be type {expected_dtype}, actual dtype: {charlist_mask.dtype}"
assert charlist_mask.shape == (
len(self.character_set),
), f"expects the Text sample `{mask_type}` to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"
return length, charlist_mask
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
if isinstance(x, str):

View File

@@ -87,7 +87,11 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
)
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[Any, ...]:
def sample(
self,
mask: tuple[Any | None, ...] | None = None,
probability: tuple[Any | None, ...] | None = None,
) -> tuple[Any, ...]:
"""Generates a single random sample inside this space.
This method draws independent samples from the subspaces.
@@ -95,24 +99,43 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
Args:
mask: An optional tuple of optional masks for each of the subspace's samples,
expects the same number of masks as spaces
probability: An optional tuple of optional probability masks for each of the subspace's samples,
expects the same number of probability masks as spaces
Returns:
Tuple of the subspace's samples
"""
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
assert isinstance(
mask, tuple
), f"Expected type of mask is tuple, actual type: {type(mask)}"
), f"Expected type of `mask` to be tuple, actual type: {type(mask)}"
assert len(mask) == len(
self.spaces
), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"
), f"Expected length of `mask` to be {len(self.spaces)}, actual length: {len(mask)}"
return tuple(
space.sample(mask=sub_mask)
for space, sub_mask in zip(self.spaces, mask)
space.sample(mask=space_mask)
for space, space_mask in zip(self.spaces, mask)
)
return tuple(space.sample() for space in self.spaces)
elif probability is not None:
assert isinstance(
probability, tuple
), f"Expected type of `probability` to be tuple, actual type: {type(probability)}"
assert len(probability) == len(
self.spaces
), f"Expected length of `probability` to be {len(self.spaces)}, actual length: {len(probability)}"
return tuple(
space.sample(probability=space_probability)
for space, space_probability in zip(self.spaces, probability)
)
else:
return tuple(space.sample() for space in self.spaces)
def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""

View File

@@ -373,3 +373,15 @@ def test_sample_mask():
match=re.escape("Box.sample cannot be provided a mask, actual value: "),
):
space.sample(mask=np.array([0, 1, 0], dtype=np.int8))
def test_sample_probability_mask():
    """Box cannot have a probability mask applied."""
    space = Box(0, 1)

    # Any non-None probability is rejected, whatever its shape or dtype
    with pytest.raises(
        gym.error.Error,
        match=re.escape(
            "Box.sample cannot be provided a probability mask, actual value: "
        ),
    ):
        space.sample(probability=np.array([0, 1, 0], dtype=np.float64))

View File

@@ -170,3 +170,146 @@ def test_keys_contains():
assert "a" in space.keys()
assert "c" not in space.keys()
def test_sample_with_mask():
    """Test the sample method with valid masks."""
    space = Dict(
        {
            "a": Discrete(5),
            "b": Box(low=0, high=1, shape=(2,)),
        }
    )

    mask = {
        "a": np.array(
            [0, 1, 0, 0, 0], dtype=np.int8
        ),  # Only allow sampling the value 1
        "b": None,  # No mask for Box space
    }

    for _ in range(10):
        sample = space.sample(mask=mask)
        assert sample["a"] == 1  # Discrete space should only return 1
        assert space["b"].contains(sample["b"])
def test_sample_with_probability():
    """Test the sample method with valid probabilities."""
    space = Dict(
        {
            "a": Discrete(3),
            "b": Box(low=0, high=1, shape=(2,)),
        }
    )

    probability = {
        "a": np.array(
            [0.1, 0.7, 0.2], dtype=np.float64
        ),  # Sampling probabilities for Discrete space
        "b": None,  # No probability for Box space
    }

    samples = [space.sample(probability=probability)["a"] for _ in range(1000)]

    # Check that the sampling roughly follows the probability distribution
    counts = np.bincount(samples, minlength=3) / len(samples)
    np.testing.assert_almost_equal(counts, probability["a"], decimal=1)
def test_sample_with_invalid_mask():
    """Test the sample method with an invalid mask."""
    space = Dict(
        {
            "a": Discrete(5),
            "b": Box(low=0, high=1, shape=(2,)),
        }
    )

    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The expected shape of the sample mask is (5,), actual shape: (3,)"
        ),
    ):
        space.sample(
            mask={
                "a": np.array([1, 0, 0], dtype=np.int8),  # Length mismatch
                "b": None,
            }
        )

    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The expected dtype of the sample mask is np.int8, actual dtype: float32"
        ),
    ):
        space.sample(
            mask={
                "a": np.array([1, 0, 0, 1, 1], dtype=np.float32),  # dtype mismatch
                "b": None,
            }
        )
def test_sample_with_invalid_probability():
    """Test the sample method with an invalid probability."""
    space = Dict(
        {
            "a": Discrete(5),
            "b": Box(low=0, high=1, shape=(2,)),
        }
    )

    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The expected shape of the sample probability is (5,), actual shape: (2,)"
        ),
    ):
        space.sample(
            probability={
                "a": np.array([0.5, 0.5], dtype=np.float64),  # Length mismatch
                "b": None,
            }
        )

    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The expected dtype of the sample probability is np.float64, actual dtype: int8"
        ),
    ):
        space.sample(
            probability={
                # Correct shape, so the dtype is the only thing wrong with this input
                "a": np.array([0, 1, 0, 0, 0], dtype=np.int8),  # dtype mismatch
                "b": None,
            }
        )
def test_sample_with_mask_and_probability():
    """Ensure an error is raised when both mask and probability are provided."""
    space = Dict(
        {
            "a": Discrete(3),
            "b": Box(low=0, high=1, shape=(2,)),
        }
    )

    mask = {
        "a": np.array([1, 0, 1], dtype=np.int8),
        "b": None,
    }
    probability = {
        "a": np.array([0.5, 0.2, 0.3], dtype=np.float64),
        "b": None,
    }

    with pytest.raises(
        ValueError,
        match=re.escape("Only one of `mask` or `probability` can be provided"),
    ):
        space.sample(mask=mask, probability=probability)

View File

@@ -1,6 +1,8 @@
import re
from copy import deepcopy
import numpy as np
import pytest
from gymnasium.spaces import Discrete
@@ -27,8 +29,83 @@ def test_space_legacy_pickling():
def test_sample_mask():
    """Test that the mask parameter of the sample function works as expected."""
    space = Discrete(4, start=2)
    assert 2 <= space.sample() < 6
    assert space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3
    # An all-zero mask is expected to yield the space's `start` value
    assert space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2
    assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5]
def test_probability_mask():
    """Test that the probability parameter of the sample function works as expected."""
    space = Discrete(4, start=2)

    # Index i of the probability array corresponds to the value `start + i`
    assert space.sample(probability=np.array([0, 1, 0, 0], dtype=np.float64)) == 3
    assert space.sample(probability=np.array([0, 0.5, 0, 0.5], dtype=np.float64)) in [
        3,
        5,
    ]
    assert space.sample(
        probability=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float64)
    ) in [
        2,
        3,
        4,
        5,
    ]
def test_sample_with_mask_and_probability():
    """Ensure an error is raised when both mask and probability are provided."""
    space = Discrete(4, start=2)

    with pytest.raises(
        ValueError,
        match=re.escape("Only one of `mask` or `probability` can be provided"),
    ):
        space.sample(
            mask=np.array([0, 1, 0, 0], dtype=np.int8),
            probability=np.array([0, 1, 0, 0], dtype=np.float64),
        )
def test_invalid_probability_mask_dtype():
    """Test that invalid probability mask dtype raises the correct exception."""
    space = Discrete(4, start=2)

    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The expected dtype of the sample probability is np.float64, actual dtype: int8"
        ),
    ):
        space.sample(probability=np.array([0, 1, 0, 0], dtype=np.int8))
def test_invalid_probability_mask_values():
    """Test that invalid probability mask values raises the correct exception."""
    space = Discrete(4, start=2)

    # A negative entry
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "All values of the sample probability should be between 0 and 1, actual values: [-0.5  1.   0.5  0. ]"
        ),
    ):
        space.sample(probability=np.array([-0.5, 1, 0.5, 0], dtype=np.float64))

    # Entries that sum to more than one
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The sum of the sample probability should be equal to 1, actual sum: 1.1"
        ),
    ):
        space.sample(probability=np.array([0.2, 0.3, 0.4, 0.2], dtype=np.float64))

    # All-zero probabilities
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "The sum of the sample probability should be equal to 1, actual sum: 0.0"
        ),
    ):
        space.sample(probability=np.array([0, 0, 0, 0], dtype=np.float64))

View File

@@ -135,3 +135,98 @@ def test_edge_space_sample():
def test_not_contains(sample):
space = Graph(node_space=Discrete(2), edge_space=Discrete(2))
assert sample not in space
def test_probability_node_sampling():
    """
    Test the probability parameter for node sampling.
    Ensures nodes are sampled according to the given probability distribution.
    """
    space = Graph(node_space=Discrete(3), edge_space=None)
    space.seed(42)

    # Define a probability distribution for nodes
    probability = np.array([0.7, 0.2, 0.1], dtype=np.float64)
    num_samples = 1000

    # Collect samples with the given probability
    samples = [
        space.sample(probability=((probability,), None), num_nodes=1).nodes[0]
        for _ in range(num_samples)
    ]

    # Check the empirical distribution of the samples
    counts = np.bincount(samples, minlength=3)
    empirical_distribution = counts / num_samples

    assert np.allclose(
        empirical_distribution, probability, atol=0.05
    ), f"Empirical distribution {empirical_distribution} does not match expected probability {probability}"
def test_probability_edge_sampling():
    """
    Test the probability parameter for edge sampling.
    Ensures edges are sampled according to the given probability distribution.
    """
    space = Graph(node_space=Discrete(3), edge_space=Discrete(3))
    space.seed(42)

    # Define a probability distribution for edges
    probability = np.array([0.5, 0.3, 0.2], dtype=np.float64)
    num_samples = 1000

    # Collect samples with the given probability
    samples = [
        space.sample(probability=(None, (probability,)), num_edges=1).edges[0]
        for _ in range(num_samples)
    ]

    # Check the empirical distribution of the samples
    counts = np.bincount(samples, minlength=3)
    empirical_distribution = counts / num_samples

    assert np.allclose(
        empirical_distribution, probability, atol=0.05
    ), f"Empirical distribution {empirical_distribution} does not match expected probability {probability}"
def test_probability_node_and_edge_sampling():
    """
    Test the probability parameter for both node and edge sampling.
    Ensures nodes and edges are sampled correctly according to their respective probability distributions.
    """
    space = Graph(node_space=Discrete(3), edge_space=Discrete(3))
    space.seed(42)

    # Define probability distributions for nodes and edges
    node_probability = np.array([0.6, 0.3, 0.1], dtype=np.float64)
    edge_probability = np.array([0.4, 0.4, 0.2], dtype=np.float64)
    num_samples = 1000

    # Collect samples with the given probabilities
    node_samples = []
    edge_samples = []
    for _ in range(num_samples):
        sample = space.sample(
            probability=((node_probability,), (edge_probability,)),
            num_nodes=1,
            num_edges=1,
        )
        node_samples.append(sample.nodes[0])
        edge_samples.append(sample.edges[0])

    # Check the empirical distributions of the samples
    node_counts = np.bincount(node_samples, minlength=3)
    edge_counts = np.bincount(edge_samples, minlength=3)
    node_empirical_distribution = node_counts / num_samples
    edge_empirical_distribution = edge_counts / num_samples

    assert np.allclose(
        node_empirical_distribution, node_probability, atol=0.05
    ), f"Node empirical distribution {node_empirical_distribution} does not match expected probability {node_probability}"
    assert np.allclose(
        edge_empirical_distribution, edge_probability, atol=0.05
    ), f"Edge empirical distribution {edge_empirical_distribution} does not match expected probability {edge_probability}"

View File

@@ -17,3 +17,18 @@ def test_sample():
space = MultiBinary(np.array([2, 3]))
sample = space.sample(mask=np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int8))
assert np.all(sample == [[0, 0, 0], [1, 1, 1]]), sample
def test_sample_probabilities():
    """Test that each bit is set with the frequency given by its probability entry."""
    # Test sampling with probabilities
    space = MultiBinary(4)
    probabilities = np.array([0, 1, 0.5, 0.25], dtype=np.float64)
    samples = [space.sample(probability=probabilities) for _ in range(10000)]
    assert all(sample in space for sample in samples)
    samples = np.array(samples)

    # Check empirical probabilities
    for i in range(4):
        counts = np.sum(samples[:, i]) / len(samples)
        np.testing.assert_allclose(counts, probabilities[i], atol=0.05)

View File

@@ -196,3 +196,61 @@ def test_space_legacy_pickling():
new_legacy_space.__setstate__(legacy_state)
assert new_legacy_space == legacy_space
assert np.all(new_legacy_space.start == np.array([0, 0, 0]))
def test_multidiscrete_sample_edge_cases():
    """Test sampling when one dimension only has a single valid value."""
    # Test edge case where one dimension has size 1
    space = MultiDiscrete([5, 1, 3])
    samples = [space.sample() for _ in range(1000)]
    samples = np.array(samples)

    # The second dimension should always be 0 (only one valid value)
    assert np.all(samples[:, 1] == 0)
def test_multidiscrete_sample():
    """Test that unmasked samples stay below each dimension's size."""
    # Test sampling without a mask
    space = MultiDiscrete([5, 2, 3])
    samples = [space.sample() for _ in range(1000)]
    samples = np.array(samples)

    # Check that the samples fall within the bounds
    assert np.all(samples[:, 0] < 5)
    assert np.all(samples[:, 1] < 2)
    assert np.all(samples[:, 2] < 3)
def test_multidiscrete_sample_with_mask():
    """Test that masked-out values are never sampled."""
    # Test sampling with a mask
    space = MultiDiscrete([2, 3, 4])
    mask = (
        np.array([1, 0], dtype=np.int8),
        np.array([1, 1, 0], dtype=np.int8),
        np.array([1, 0, 1, 0], dtype=np.int8),
    )
    samples = [space.sample(mask=mask) for _ in range(1000)]
    assert all(sample in space for sample in samples)
    samples = np.array(samples)

    # Check that the samples respect the mask
    for i, dim in enumerate(space.nvec):
        for j in range(dim):
            if mask[i][j] == 0:
                assert np.all(samples[:, i] != j)
def test_multidiscrete_sample_probabilities():
    """Test that each dimension follows its own probability distribution."""
    # Test sampling with probabilities
    space = MultiDiscrete([3, 3])
    probabilities = (
        np.array([0.1, 0.7, 0.2], dtype=np.float64),
        np.array([0.3, 0.3, 0.4], dtype=np.float64),
    )
    samples = [space.sample(probability=probabilities) for _ in range(10000)]
    assert all(sample in space for sample in samples)
    samples = np.array(samples)

    # Check empirical probabilities
    for i in range(2):
        counts = np.bincount(samples[:, i], minlength=3) / len(samples)
        np.testing.assert_allclose(counts, probabilities[i], atol=0.05)

View File

@@ -65,3 +65,56 @@ def test_bad_oneof_seed():
match="Expected None, int, or tuple of ints, actual type: <class 'float'>",
):
space.seed(0.0)
def test_oneof_sample():
    """Tests the sample method with and without masks or probabilities."""
    space = OneOf([Discrete(2), Box(-1, 1, shape=(2,))])

    # Unmasked sampling
    sample = space.sample()
    assert isinstance(sample, tuple)
    sample_idx, sample_value = sample
    assert sample_idx in [0, 1]
    assert sample_value in space.spaces[sample_idx]

    # Masked sampling
    mask = (np.array([1, 0], dtype=np.int8), None)
    sample_idx, sample_value = space.sample(mask=mask)
    assert sample_idx in [0, 1]

    # Resample until the masked Discrete subspace is selected
    while sample_idx != 0:
        sample_idx, sample_value = space.sample(mask=mask)
    # The loop only exits with `sample_idx == 0`, where the mask only allows the value 0
    assert sample_value == 0

    # Probability sampling
    probability = (np.array([0.8, 0.2], dtype=np.float64), None)
    sample_idx, sample_value = space.sample(probability=probability)
    assert sample_idx in [0, 1]
def test_invalid_sample_inputs():
    """Tests that invalid inputs to sample raise appropriate errors."""
    space = OneOf([Discrete(2), Box(-1, 1, shape=(2,))])

    # Providing both mask and probability
    # (`match` is a regex, so no trailing "." that would act as a wildcard)
    with pytest.raises(
        ValueError, match="Only one of `mask` or `probability` can be provided"
    ):
        space.sample(mask=(None, None), probability=(0.5, 0.5))

    # Invalid mask type
    with pytest.raises(AssertionError, match="Expected type of `mask` is tuple"):
        space.sample(mask={"low": 0, "high": 1})

    # Invalid mask length
    with pytest.raises(AssertionError, match="Expected length of `mask` is 2"):
        space.sample(mask=(None,))

    # Invalid probability length
    with pytest.raises(AssertionError, match="Expected length of `probability` is 2"):
        space.sample(probability=(0.5,))

    # Invalid probability type
    with pytest.raises(AssertionError, match="Expected type of `probability` is tuple"):
        space.sample(probability=[0.5, 0.5])

View File

@@ -34,7 +34,7 @@ def test_sample():
with pytest.raises(
AssertionError,
match=re.escape(
"Expects the length mask to be greater than or equal to zero, actual value: -1"
"Expects the length mask of `mask` to be greater than or equal to zero, actual value: -1"
),
):
space.sample(mask=(-1, None))
@@ -51,7 +51,7 @@ def test_sample():
with pytest.raises(
AssertionError,
match=re.escape(
"Expects the shape of the length mask to be 1-dimensional, actual shape: (2, 2)"
"Expects the shape of the length mask of `mask` to be 1-dimensional, actual shape: (2, 2)"
),
):
space.sample(mask=(np.array([[2, 2], [2, 2]]), None))
@@ -59,7 +59,7 @@ def test_sample():
with pytest.raises(
AssertionError,
match=re.escape(
"Expects all values in the length_mask to be greater than or equal to zero, actual values: [ 1 2 -1]"
"Expects all values in the length_mask of `mask` to be greater than or equal to zero, actual values: [ 1 2 -1]"
),
):
space.sample(mask=(np.array([1, 2, -1]), None))
@@ -68,7 +68,63 @@ def test_sample():
with pytest.raises(
TypeError,
match=re.escape(
"Expects the type of length_mask to an integer or a np.ndarray, actual type: <class 'str'>"
"Expects the type of length_mask of `mask` to be an integer or a np.ndarray, actual type: <class 'str'>"
),
):
space.sample(mask=("abc", None))
with pytest.raises(
AssertionError,
match=re.escape(
"Expects the shape of the length mask of `probability` to be 1-dimensional, actual shape: (2, 2)"
),
):
space.sample(probability=(np.array([[2, 2], [2, 2]]), None))
with pytest.raises(
AssertionError,
match=re.escape(
"Expects all values in the length_mask of `probability` to be greater than or equal to zero, actual values: [ 1 2 -1]"
),
):
space.sample(probability=(np.array([1, 2, -1]), None))
# Test with an invalid length
with pytest.raises(
TypeError,
match=re.escape(
"Expects the type of length_mask of `probability` to be an integer or a np.ndarray, actual type: <class 'str'>"
),
):
space.sample(probability=("abc", None))
def test_sample_with_mask():
    """Tests sampling with mask"""
    space = gym.spaces.Sequence(gym.spaces.Discrete(2))
    sample = space.sample(mask=(np.array([20]), np.array([0, 1], dtype=np.int8)))

    # Check membership on the raw sample: `np.all(<generator>)` is always truthy,
    # so it cannot be used to verify the individual values.
    assert sample in space
    assert len(sample) == 20

    # The feature mask only allows the value 1
    values = np.array(sample)
    assert np.all(values == 1)
def test_sample_with_probability():
    """Tests sampling with probability mask"""
    space = gym.spaces.Sequence(gym.spaces.Discrete(2))
    sample = space.sample(
        probability=(np.array([20]), np.array([0, 1], dtype=np.float64))
    )

    # Check membership on the raw sample: `np.all(<generator>)` is always truthy,
    # so it cannot be used to verify the individual values.
    assert sample in space
    assert len(sample) == 20
    values = np.array(sample)
    assert np.all(values == 1)

    space = gym.spaces.Sequence(gym.spaces.Discrete(3))
    probability = (np.array([1000]), np.array([0, 0.2, 0.8], dtype=np.float64))
    sample = space.sample(probability=probability)
    assert sample in space

    # The value 0 has zero probability, so it must never be sampled
    values = np.array(sample)
    assert np.all(np.isin(values, [1, 2]))

    # The empirical frequencies should roughly follow the feature probabilities
    counts = np.bincount(values, minlength=3) / len(values)
    np.testing.assert_allclose(counts, probability[1], atol=0.05)

View File

@@ -21,7 +21,7 @@ def test_sample_mask():
with pytest.raises(
ValueError,
match=re.escape(
"Trying to sample with a minimum length > 0 (1) but the character mask is all zero meaning that no character could be sampled."
"Trying to sample with a minimum length > 0 (actual minimum length=1) but the character mask is all zero meaning that no character could be sampled."
),
):
space.sample(mask=(3, np.zeros(len(space.character_set), dtype=np.int8)))
@@ -33,9 +33,47 @@ def test_sample_mask():
assert sample in space
assert sample == ""
sample = space.sample(mask=(0, None))
assert sample in space
assert sample == ""
# Test the sample characters
space = Text(max_length=5, charset="abcd")
sample = space.sample(mask=(3, np.array([0, 1, 0, 0], dtype=np.int8)))
assert sample in space
assert sample == "bbb"
def test_sample_probability():
    """Test the probability parameter of the Text space's sample function."""
    space = Text(min_length=1, max_length=5)

    # Test the sample length
    sample = space.sample(probability=(3, None))
    assert sample in space
    assert len(sample) == 3

    sample = space.sample(probability=None)
    assert sample in space
    assert 1 <= len(sample) <= 5

    # An all-zero character probability is rejected as it does not sum to one
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Expects the sum of the probability mask to be 1, actual sum: 0.0"
        ),
    ):
        space.sample(
            probability=(3, np.zeros(len(space.character_set), dtype=np.float64))
        )

    # Test the sample characters
    space = Text(max_length=5, charset="abcd")

    sample = space.sample(probability=(3, np.array([0, 1, 0, 0], dtype=np.float64)))
    assert sample in space
    assert sample == "bbb"

    sample = space.sample(probability=(2, np.array([0.5, 0.5, 0, 0], dtype=np.float64)))
    assert sample in space
    assert sample in ["aa", "bb", "ab", "ba"]

View File

@@ -105,3 +105,68 @@ def test_bad_seed():
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
):
space.seed(0.0)
def test_oneof_sample():
    """Tests the Tuple space's sample method with and without masks or probabilities."""
    space = gym.spaces.Tuple([Discrete(2), Box(-1, 1, shape=(2,))])

    # Unmasked sampling
    sample = space.sample()
    assert isinstance(sample, tuple)
    assert len(sample) == 2
    assert space.spaces[0].contains(sample[0])
    assert space.spaces[1].contains(sample[1])

    # Masked sampling
    mask = (np.array([1, 0], dtype=np.int8), None)
    sample = space.sample(mask=mask)
    assert space.spaces[0].contains(sample[0])
    assert space.spaces[1].contains(sample[1])
    assert sample[0] == 0

    # Probability sampling
    probability = (np.array([0.8, 0.2], dtype=np.float64), None)
    samples_discrete = np.array(
        [space.sample(probability=probability)[0] for _ in range(1000)]
    )
    counts = np.bincount(samples_discrete, minlength=2) / len(samples_discrete)
    np.testing.assert_allclose(counts, probability[0], atol=0.05)
def test_invalid_sample_inputs():
    """Tests that invalid inputs to sample raise appropriate errors."""
    space = gym.spaces.Tuple([Discrete(2), Box(-1, 1, shape=(2,))])

    # Providing both mask and probability
    # (`match` is a regex, so no trailing "." that would act as a wildcard)
    with pytest.raises(
        ValueError, match="Only one of `mask` or `probability` can be provided"
    ):
        space.sample(mask=(None, None), probability=(0.5, 0.5))

    # Invalid mask type
    with pytest.raises(
        AssertionError,
        match="Expected type of `mask` to be tuple, actual type: <class 'dict'>",
    ):
        space.sample(mask={"low": 0, "high": 1})

    # Invalid mask length
    with pytest.raises(
        AssertionError, match="Expected length of `mask` to be 2, actual length: 1"
    ):
        space.sample(mask=(None,))

    # Invalid probability length
    with pytest.raises(
        AssertionError,
        match="Expected length of `probability` to be 2, actual length: 1",
    ):
        space.sample(probability=(0.5,))

    # Invalid probability type
    with pytest.raises(
        AssertionError,
        match="Expected type of `probability` to be tuple, actual type: <class 'list'>",
    ):
        space.sample(probability=[0.5, 0.5])