Files
Gymnasium/gym/spaces/text.py

123 lines
4.6 KiB
Python
Raw Normal View History

"""Implementation of a space that represents textual strings."""
from typing import Any, FrozenSet, List, Optional, Set, Tuple, Union
import numpy as np
from gym.spaces.space import Space
# Default character set of :class:`Text`: the 26 lowercase and 26 uppercase
# ASCII letters followed by the ten decimal digits (62 characters in total).
alphanumeric: FrozenSet[str] = frozenset(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
)
class Text(Space[str]):
    r"""A space representing a string comprised of characters from a given charset.

    Example::

        >>> # {"", "B5", "hello", ...}
        >>> Text(5)
        >>> # {"0", "42", "0123456789", ...}
        >>> import string
        >>> Text(min_length=1,
        ...      max_length=10,
        ...      charset=string.digits)
    """

    def __init__(
        self,
        max_length: int,
        *,
        min_length: int = 0,
        charset: Union[Set[str], str] = alphanumeric,
        seed: Optional[Union[int, np.random.Generator]] = None,
    ):
        r"""Constructor of :class:`Text` space.

        Both bounds for text length are inclusive.

        Args:
            min_length (int): Minimum text length (in characters).
            max_length (int): Maximum text length (in characters).
            charset (Union[set, str]): Character set, defaults to the lower and upper case English
                alphabet plus the digits 0-9. Duplicate characters are ignored. Positions in a
                sampling mask (see :meth:`sample`) follow the sorted character order when
                ``charset`` is a set, and the order of first occurrence otherwise (e.g. a string).
            seed: The seed for sampling from the space.
        """
        assert np.issubdtype(
            type(min_length), np.integer
        ), f"Expects the min_length to be an integer, actual type: {type(min_length)}"
        assert np.issubdtype(
            type(max_length), np.integer
        ), f"Expects the max_length to be an integer, actual type: {type(max_length)}"
        assert (
            0 <= min_length
        ), f"Minimum text length must be non-negative, actual value: {min_length}"
        assert (
            min_length <= max_length
        ), f"The min_length must be less than or equal to the max_length, min_length: {min_length}, max_length: {max_length}"

        self.min_length: int = int(min_length)
        self.max_length: int = int(max_length)
        self.charset: FrozenSet[str] = frozenset(charset)

        # Fix a deterministic character order. Iterating a (frozen)set of str
        # depends on per-process hash randomization, which would make seeded
        # samples and mask positions differ between interpreter runs.
        if isinstance(charset, (set, frozenset)):
            ordered_chars = sorted(charset)
        else:
            # dict.fromkeys drops duplicates while keeping first-occurrence order,
            # so the mask length always equals len(self.charset).
            ordered_chars = list(dict.fromkeys(charset))
        self._charlist: List[str] = ordered_chars
        self._charset_str: str = "".join(sorted(self._charlist))

        # The text length varies between min_length and max_length, so no fixed
        # ``shape`` argument is passed to the base class.
        super().__init__(dtype=str, seed=seed)

    def sample(
        self, mask: Optional[Tuple[Optional[int], Optional[np.ndarray]]] = None
    ) -> str:
        """Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.

        Args:
            mask: An optional ``(length, charlist_mask)`` tuple; either entry may be ``None``.
                ``length`` must be an integer between `min_length` and `max_length` (inclusive);
                when ``None``, a random length in that range is drawn.
                ``charlist_mask`` must be a numpy array with dtype ``np.int8`` and one entry per
                distinct character of the charset; an entry of 1 allows that character and 0
                excludes it. When ``None``, every character is allowed.

        Returns:
            A sampled string from the space

        Raises:
            ValueError: If the mask (or an empty charset) leaves no character to draw from but a
                non-empty string is required by `min_length` or by the requested length.
        """
        if mask is not None:
            length, charlist_mask = mask
            if length is not None:
                assert np.issubdtype(
                    type(length), np.integer
                ), f"Expects the mask length to be an integer, actual type: {type(length)}"
                assert (
                    self.min_length <= length <= self.max_length
                ), f"Expects the mask length to be between {self.min_length} and {self.max_length}, actual value: {length}"
            if charlist_mask is not None:
                assert isinstance(
                    charlist_mask, np.ndarray
                ), f"Expects the character mask to be a numpy array, actual type: {type(charlist_mask)}"
                # ``ndarray.dtype`` is a ``np.dtype`` instance, so it must be compared
                # with ``==``; an identity check against ``np.int8`` is never true.
                assert (
                    charlist_mask.dtype == np.int8
                ), f"Expects the character mask dtype to be np.int8, actual dtype: {charlist_mask.dtype}"
                assert charlist_mask.shape == (
                    len(self._charlist),
                ), f"Expects the character mask shape to be {(len(self._charlist),)}, actual shape: {charlist_mask.shape}"
                assert np.all(
                    (charlist_mask == 0) | (charlist_mask == 1)
                ), f"Expects all character mask values to be 0 or 1, actual values: {charlist_mask}"
        else:
            length, charlist_mask = None, None

        # Indices into self._charlist that may be drawn.
        if charlist_mask is None:
            valid_indices = np.arange(len(self._charlist))
        else:
            valid_indices = np.flatnonzero(charlist_mask)

        if valid_indices.size == 0:
            # No character can be drawn, so the empty string is the only possible
            # sample; it is valid only if the space and the requested length allow it.
            if self.min_length == 0 and not length:
                return ""
            raise ValueError(
                f"Cannot sample a non-empty string (min_length: {self.min_length}, requested length: {length}) "
                "because the character mask or charset leaves no character to sample from."
            )

        if length is None:
            length = self.np_random.integers(self.min_length, self.max_length + 1)
        chosen = self.np_random.choice(valid_indices, size=length)
        return "".join(self._charlist[index] for index in chosen)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` must be a ``str`` whose length lies within ``[min_length, max_length]``
        and whose every character belongs to ``charset``.
        """
        if isinstance(x, str):
            if self.min_length <= len(x) <= self.max_length:
                return all(c in self.charset for c in x)
        return False

    def __repr__(self) -> str:
        """Gives a string representation of this space (charset shown as sorted characters)."""
        return (
            f"Text({self.min_length}, {self.max_length}, charset={self._charset_str})"
        )

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance.

        Two Text spaces are equal when their length bounds and character sets match;
        the seed / random state is not compared.
        """
        return (
            isinstance(other, Text)
            and self.min_length == other.min_length
            and self.max_length == other.max_length
            and self.charset == other.charset
        )