2022-07-11 16:39:04 +01:00
""" Implementation of a space that represents textual strings. """
2024-06-10 17:07:47 +01:00
2022-11-15 14:09:22 +00:00
from __future__ import annotations
from typing import Any
2022-07-11 16:39:04 +01:00
import numpy as np
2023-02-13 18:18:40 +01:00
from numpy . typing import NDArray
2022-07-11 16:39:04 +01:00
2022-09-08 10:10:07 +01:00
from gymnasium . spaces . space import Space
2022-07-11 16:39:04 +01:00
2022-12-04 22:24:02 +08:00
2022-11-15 14:09:22 +00:00
# Default character set used by :class:`Text`: the 26 lowercase ASCII letters,
# the 26 uppercase ASCII letters and the ten decimal digits (62 characters).
alphanumeric: frozenset[str] = frozenset(
    "abcdefghijklmnopqrstuvwxyz"  # lowercase letters
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # uppercase letters
    "0123456789"  # digits
)
class Text(Space[str]):
    r"""A space representing a string comprised of characters from a given charset.

    Example:
        >>> from gymnasium.spaces import Text
        >>> # {"B", "b5", "hello", ...}  (the default ``min_length=1`` excludes "")
        >>> Text(5)
        Text(1, 5, charset=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz)
        >>> # {"0", "42", "0123456789", ...}
        >>> import string
        >>> Text(min_length=1,
        ...      max_length=10,
        ...      charset=string.digits)
        Text(1, 10, charset=0123456789)
    """

    def __init__(
        self,
        max_length: int,
        *,
        min_length: int = 1,
        charset: frozenset[str] | str = alphanumeric,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Text` space.

        Both bounds for text length are inclusive.

        Args:
            max_length (int): Maximum text length (in characters).
            min_length (int): Minimum text length (in characters). Defaults to 1 to prevent empty strings.
            charset (frozenset[str] | str): Character set, defaults to the lower and upper english alphabet
                plus latin digits. Duplicate characters are ignored. For a ``str`` the order of first
                occurrence defines :attr:`character_list`; a set/frozenset is sorted so that the order
                (and therefore seeded sampling and :meth:`character_index`) does not depend on Python's
                per-process string hash randomisation.
            seed: The seed for sampling from the space.

        Raises:
            AssertionError: If a length is not an integer, ``min_length`` is negative or
                ``min_length > max_length``.
        """
        assert np.issubdtype(
            type(min_length), np.integer
        ), f"Expects the min_length to be an integer, actual type: {type(min_length)}"
        assert np.issubdtype(
            type(max_length), np.integer
        ), f"Expects the max_length to be an integer, actual type: {type(max_length)}"
        assert (
            0 <= min_length
        ), f"Minimum text length must be non-negative, actual value: {min_length}"
        assert (
            min_length <= max_length
        ), f"The min_length must be less than or equal to the max_length, min_length: {min_length}, max_length: {max_length}"

        self.min_length: int = int(min_length)
        self.max_length: int = int(max_length)

        # Build an ordered, duplicate-free character list. Duplicates would make
        # ``character_list`` longer than ``character_set``, while sampling sizes its
        # probability vector by the set, so ``choice`` would reject it. Set iteration
        # order follows string hashes (randomised per process), hence the sort.
        if isinstance(charset, (set, frozenset)):
            ordered_chars = sorted(charset)
        else:
            ordered_chars = list(dict.fromkeys(charset))

        self._char_set: frozenset[str] = frozenset(ordered_chars)
        self._char_list: tuple[str, ...] = tuple(ordered_chars)
        # Indices are positions in ``_char_list``, so they lie in [0, len(charset)).
        self._char_index: dict[str, np.int32] = {
            char: np.int32(i) for i, char in enumerate(self._char_list)
        }
        self._char_str: str = "".join(sorted(self._char_list))

        # No shape is passed because the length is dynamic (between min_length and max_length).
        super().__init__(dtype=str, seed=seed)

    def sample(
        self,
        mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
        probability: None | (tuple[int | None, NDArray[np.float64] | None]) = None,
    ) -> str:
        """Generates a single random sample from this space.

        By default the length is drawn uniformly between ``min_length`` and ``max_length``
        (inclusive) and every character is drawn uniformly from the ``charset``.

        Args:
            mask: An optional tuple of length and mask for the text.
                The length is expected to be between the ``min_length`` and ``max_length``.
                If it is ``None``, a random integer between ``min_length`` and ``max_length`` is selected.
                For the mask, we expect a numpy array with one entry per character of
                :attr:`character_list` and ``dtype == np.int8``.
                If the charlist mask is all zero then an empty string is returned when
                ``min_length == 0``; otherwise a ``ValueError`` is raised.
            probability: An optional tuple of length and probability mask for the text.
                The length is expected to be between the ``min_length`` and ``max_length``.
                If it is ``None``, a random integer between ``min_length`` and ``max_length`` is selected.
                For the probability mask, we expect a numpy array with one entry per character of
                :attr:`character_list` and ``dtype == np.float64``.
                The sum of the probability mask should be 1, otherwise an exception is raised;
                small rounding deviations accepted by ``np.isclose`` are renormalised away.

        Returns:
            A sampled string from the space

        Raises:
            ValueError: If both ``mask`` and ``probability`` are given, or if the character mask
                is all zero while ``min_length > 0``.
            AssertionError: If ``mask`` or ``probability`` is malformed.
        """
        if mask is not None and probability is not None:
            raise ValueError(
                f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
            )
        elif mask is not None:
            length, charlist_mask = self._validate_mask(mask, np.int8, "mask")

            if charlist_mask is not None:
                assert np.all(
                    np.logical_or(charlist_mask == 0, charlist_mask == 1)
                ), f"Expects all mask values to 0 or 1, actual values: {charlist_mask}"

                # normalise the mask to use as a probability (an all-zero mask is handled below)
                if np.sum(charlist_mask) > 0:
                    charlist_mask = charlist_mask / np.sum(charlist_mask)
        elif probability is not None:
            length, charlist_mask = self._validate_mask(
                probability, np.float64, "probability"
            )

            if charlist_mask is not None:
                assert np.all(
                    np.logical_and(charlist_mask >= 0, charlist_mask <= 1)
                ), f"Expects all probability mask values to be within 0 and 1, actual values: {charlist_mask}"
                assert np.isclose(
                    np.sum(charlist_mask), 1
                ), f"Expects the sum of the probability mask to be 1, actual sum: {np.sum(charlist_mask)}"
                # ``np.isclose`` (rtol=1e-5) is far looser than ``Generator.choice``, which
                # rejects ``p`` whose sum is off by more than ~1.5e-8. Renormalise so every
                # accepted vector can be sampled from (identity when the sum is exactly 1).
                charlist_mask = charlist_mask / np.sum(charlist_mask)
        else:
            length = charlist_mask = None

        if length is None:
            length = self.np_random.integers(self.min_length, self.max_length + 1)
        if charlist_mask is None:  # uniform sampling
            charlist_mask = np.ones(len(self.character_list)) / len(self.character_list)

        if np.all(charlist_mask == 0):
            if self.min_length == 0:
                return ""
            # Otherwise the string will not be contained in the space
            raise ValueError(
                f"Trying to sample with a minimum length > 0 (actual minimum length={self.min_length}) but the character mask is all zero meaning that no character could be sampled."
            )

        string = self.np_random.choice(
            self.character_list, size=length, p=charlist_mask
        )
        return "".join(string)

    def _validate_mask(
        self,
        mask: tuple[int | None, NDArray[np.int8] | NDArray[np.float64] | None],
        expected_dtype: type[np.generic],
        mask_type: str,
    ) -> tuple[int | None, NDArray[np.int8] | NDArray[np.float64] | None]:
        """Checks a ``(length, charlist_mask)`` tuple passed to :meth:`sample` and unpacks it.

        Args:
            mask: The ``(length, charlist_mask)`` tuple; either element may be ``None``.
            expected_dtype: The dtype the array element must have (``np.int8`` or ``np.float64``).
            mask_type: Name of the argument (``"mask"`` or ``"probability"``) used in error messages.

        Returns:
            The unpacked ``(length, charlist_mask)`` tuple, unchanged.

        Raises:
            AssertionError: If the tuple, the length or the array is malformed.
        """
        assert isinstance(
            mask, tuple
        ), f"Expects the `{mask_type}` type to be a tuple, actual type: {type(mask)}"
        assert (
            len(mask) == 2
        ), f"Expects the `{mask_type}` length to be two, actual length: {len(mask)}"

        length, charlist_mask = mask

        if length is not None:
            assert np.issubdtype(
                type(length), np.integer
            ), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
            assert (
                self.min_length <= length <= self.max_length
            ), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"

        if charlist_mask is not None:
            assert isinstance(
                charlist_mask, np.ndarray
            ), f"Expects the Text sample `{mask_type}` to be an np.ndarray, actual type: {type(charlist_mask)}"
            assert (
                charlist_mask.dtype == expected_dtype
            ), f"Expects the Text sample `{mask_type}` to be type {expected_dtype}, actual dtype: {charlist_mask.dtype}"
            assert charlist_mask.shape == (
                len(self.character_set),
            ), f"expects the Text sample `{mask_type}` to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"

        return length, charlist_mask

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` is a member when it is a ``str`` whose length lies in
        ``[min_length, max_length]`` and every character belongs to the charset.
        """
        if isinstance(x, str):
            if self.min_length <= len(x) <= self.max_length:
                return all(c in self.character_set for c in x)
        return False

    def __repr__(self) -> str:
        """Gives a string representation of this space (characters shown in sorted order)."""
        return f"Text({self.min_length}, {self.max_length}, charset={self.characters})"

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance.

        Two ``Text`` spaces are equal when their length bounds and character sets match;
        the order of :attr:`character_list` is not compared.
        """
        return (
            isinstance(other, Text)
            and self.min_length == other.min_length
            and self.max_length == other.max_length
            and self.character_set == other.character_set
        )

    @property
    def character_set(self) -> frozenset[str]:
        """Returns the character set for the space."""
        return self._char_set

    @property
    def character_list(self) -> tuple[str, ...]:
        """Returns a duplicate-free tuple of the characters in the space, in sampling order."""
        return self._char_list

    def character_index(self, char: str) -> np.int32:
        """Returns the unique index of ``char`` in :attr:`character_list`.

        Raises:
            KeyError: If ``char`` is not in the space's character set.
        """
        return self._char_index[char]

    @property
    def characters(self) -> str:
        """Returns a string with all Text characters, in sorted order."""
        return self._char_str

    @property
    def is_np_flattenable(self) -> bool:
        """The flattened version is an integer array for each character, padded to ``max_length``."""
        return True