mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-07 08:21:48 +00:00
124 lines
4.7 KiB
Python
124 lines
4.7 KiB
Python
![]() |
"""Implementation of a space that represents textual strings.

Defines the :class:`Text` space, whose elements are Python ``str`` values whose
length lies within inclusive bounds and whose characters come from a fixed
character set.
"""
|
||
|
from typing import Any, FrozenSet, List, Optional, Set, Tuple, Union
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from gym.spaces.space import Space
|
||
|
from gym.utils import seeding
|
||
|
|
||
|
# Default character set for the ``Text`` space: the 26 lower-case and 26
# upper-case ASCII letters followed by the 10 decimal digits (62 characters).
alphanumeric: FrozenSet[str] = frozenset(
    "abcdefghijklmnopqrstuvwxyz"  # lower-case letters
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # upper-case letters
    "0123456789"  # decimal digits
)
|
||
|
|
||
|
|
||
|
class Text(Space[str]):
    r"""A space representing a string comprised of characters from a given charset.

    Example::

        >>> # {"", "B5", "hello", ...}
        >>> Text(5)
        >>> # {"0", "42", "0123456789", ...}
        >>> import string
        >>> Text(min_length=1,
        ...      max_length=10,
        ...      charset=string.digits)
    """

    def __init__(
        self,
        max_length: int,
        *,
        min_length: int = 0,
        charset: Union[Set[str], str] = alphanumeric,
        seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
    ):
        r"""Constructor of :class:`Text` space.

        Both bounds for text length are inclusive.

        Args:
            min_length (int): Minimum text length (in characters).
            max_length (int): Maximum text length (in characters).
            charset (Union[set, str]): Character set, defaults to the lower and upper english alphabet plus latin digits.
                A ``str`` (or other sequence) keeps its given character order; a ``set``/``frozenset`` is
                sorted. This order is the one a sampling mask is indexed against.
            seed: The seed for sampling from the space.
        """
        assert np.issubdtype(
            type(min_length), np.integer
        ), f"Expects the min_length to be an integer, actual type: {type(min_length)}"
        assert np.issubdtype(
            type(max_length), np.integer
        ), f"Expects the max_length to be an integer, actual type: {type(max_length)}"
        assert (
            0 <= min_length
        ), f"Minimum text length must be non-negative, actual value: {min_length}"
        assert (
            min_length <= max_length
        ), f"The min_length must be less than or equal to the max_length, min_length: {min_length}, max_length: {max_length}"

        self.min_length: int = int(min_length)
        self.max_length: int = int(max_length)
        self.charset: FrozenSet[str] = frozenset(charset)
        # Iterating a set of strings follows string hashes, which Python
        # randomises per process (PYTHONHASHSEED). Sorting set charsets keeps
        # seeded sampling and mask indexing reproducible across runs; an
        # explicitly ordered charset (e.g. a str) keeps the caller's order.
        if isinstance(charset, (set, frozenset)):
            self._charlist: List[str] = sorted(charset)
        else:
            self._charlist = list(charset)
        self._charset_str: str = "".join(sorted(self._charlist))

        # The shape is dynamic (anywhere between min_length and max_length), so no shape is given.
        super().__init__(dtype=str, seed=seed)

    def sample(
        self, mask: Optional[Tuple[Optional[int], Optional[np.ndarray]]] = None
    ) -> str:
        """Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.

        Args:
            mask: An optional tuple ``(length, charlist_mask)``.
                If ``length`` is not None it must lie between `min_length` and `max_length` (inclusive);
                if it is None, a random length in that range is drawn.
                ``charlist_mask`` is a numpy array of dtype ``np.int8`` with one entry per character of the
                space's internal character list; only characters whose entry is non-zero can be sampled.
                If it is None, every character can be sampled.

        Returns:
            A sampled string from the space

        Raises:
            ValueError: If ``charlist_mask`` allows no character but a non-empty string is requested.
        """
        if mask is not None:
            length, charlist_mask = mask
            if length is not None:
                assert self.min_length <= length <= self.max_length
            if charlist_mask is not None:
                assert isinstance(charlist_mask, np.ndarray)
                # A dtype instance is never *identical* to the scalar type
                # object, so `is` would always fail; dtype equality is needed.
                assert charlist_mask.dtype == np.int8
                assert charlist_mask.shape == (len(self._charlist),)
        else:
            length, charlist_mask = None, None

        if length is None:
            # `integers` is the np.random.Generator API (high is exclusive,
            # hence the +1); `randint` is not available on a plain Generator.
            length = self.np_random.integers(self.min_length, self.max_length + 1)

        if charlist_mask is None:
            chars = self.np_random.choice(self._charlist, size=length)
        else:
            # Keep only the characters whose mask entry is non-zero.
            masked_charlist = [
                self._charlist[index] for index in np.flatnonzero(charlist_mask)
            ]
            if not masked_charlist:
                if length == 0:
                    return ""
                raise ValueError(
                    f"Cannot sample a string of length {length}: the character mask allows no characters."
                )
            chars = self.np_random.choice(masked_charlist, size=length)

        return "".join(chars)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        if isinstance(x, str):
            if self.min_length <= len(x) <= self.max_length:
                return all(c in self.charset for c in x)
        return False

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return (
            f"Text({self.min_length}, {self.max_length}, charset={self._charset_str})"
        )

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return (
            isinstance(other, Text)
            and self.min_length == other.min_length
            and self.max_length == other.max_length
            and self.charset == other.charset
        )
|