2022-07-11 16:39:04 +01:00
""" Implementation of a space that represents textual strings. """
2022-11-15 14:09:22 +00:00
from __future__ import annotations
from typing import Any
2022-07-11 16:39:04 +01:00
import numpy as np
2022-11-15 14:09:22 +00:00
import numpy . typing as npt
2022-07-11 16:39:04 +01:00
2022-09-08 10:10:07 +01:00
from gymnasium . spaces . space import Space
2022-07-11 16:39:04 +01:00
2022-12-04 22:24:02 +08:00
2022-11-15 14:09:22 +00:00
# Default character set for `Text`: the 26 lower-case and 26 upper-case ASCII
# letters followed by the 10 decimal digits (62 characters, no whitespace).
alphanumeric: frozenset[str] = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
class Text(Space[str]):
    r"""A space representing a string comprised of characters from a given charset.

    Example:
        >>> from gymnasium.spaces import Text
        >>> # {"B", "B5", "hello", ...} (no empty string: ``min_length`` defaults to 1)
        >>> Text(5)
        Text(1, 5, characters=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz)
        >>> # {"0", "42", "0123456789", ...}
        >>> import string
        >>> Text(min_length = 1,
        ...      max_length = 10,
        ...      charset = string.digits)
        Text(1, 10, characters=0123456789)
    """

    def __init__(
        self,
        max_length: int,
        *,
        min_length: int = 1,
        charset: set[str] | str = alphanumeric,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Text` space.

        Both bounds for text length are inclusive.

        Args:
            max_length (int): Maximum text length (in characters).
            min_length (int): Minimum text length (in characters). Defaults to 1 to prevent empty strings.
            charset (Union[set, str]): Character set, defaults to the lower and upper english alphabet plus latin digits.
                A character that appears more than once is only counted once.
            seed: The seed for sampling from the space.
        """
        assert np.issubdtype(
            type(min_length), np.integer
        ), f"Expects the min_length to be an integer, actual type: {type(min_length)}"
        assert np.issubdtype(
            type(max_length), np.integer
        ), f"Expects the max_length to be an integer, actual type: {type(max_length)}"
        assert (
            0 <= min_length
        ), f"Minimum text length must be non-negative, actual value: {min_length}"
        assert (
            min_length <= max_length
        ), f"The min_length must be less than or equal to the max_length, min_length: {min_length}, max_length: {max_length}"

        self.min_length: int = int(min_length)
        self.max_length: int = int(max_length)

        # De-duplicate once (keeping first-seen order) and derive every view of the
        # charset from that single tuple. Building the list/index from the raw
        # ``charset`` would let a string such as "aab" produce a character list that
        # is longer than the character set, so the sample mask shape, the
        # ``character_index`` lookup and ``character_list`` would disagree.
        unique_chars: tuple[str, ...] = tuple(dict.fromkeys(charset))

        self._char_set: frozenset[str] = frozenset(unique_chars)
        self._char_list: tuple[str, ...] = unique_chars
        self._char_index: dict[str, np.int32] = {
            val: np.int32(i) for i, val in enumerate(unique_chars)
        }
        self._char_str: str = "".join(sorted(unique_chars))

        # As the shape is dynamic (between min_length and max_length) then None
        super().__init__(dtype=str, seed=seed)

    def sample(
        self,
        mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
    ) -> str:
        """Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.

        Args:
            mask: An optional tuple of length and mask for the text.
                The length, if not ``None``, must be an integer between `min_length` and `max_length` (inclusive);
                if it is ``None`` a random integer between `min_length` and `max_length` is selected.
                For the mask, we expect a numpy array of length of the charset passed with `dtype == np.int8`
                containing only 0 and 1, where 1 marks a character of `character_list` that may be sampled.
                If the charlist mask is all zero and `min_length` is 0 then an empty string is returned
                (even if a length was given).

        Returns:
            A sampled string from the space

        Raises:
            ValueError: If the charlist mask is all zero while `min_length` is greater than zero.
        """
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expects the mask type to be a tuple, actual type: {type(mask)}"
            assert (
                len(mask) == 2
            ), f"Expects the mask length to be two, actual length: {len(mask)}"
            length, charlist_mask = mask

            if length is not None:
                assert np.issubdtype(
                    type(length), np.integer
                ), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
                assert (
                    self.min_length <= length <= self.max_length
                ), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"

            if charlist_mask is not None:
                assert isinstance(
                    charlist_mask, np.ndarray
                ), f"Expects the Text sample mask to be an np.ndarray, actual type: {type(charlist_mask)}"
                assert (
                    charlist_mask.dtype == np.int8
                ), f"Expects the Text sample mask to be an np.ndarray, actual dtype: {charlist_mask.dtype}"
                assert charlist_mask.shape == (
                    len(self.character_set),
                ), f"expects the Text sample mask to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"
                assert np.all(
                    np.logical_or(charlist_mask == 0, charlist_mask == 1)
                ), f"Expects all masks values to 0 or 1, actual values: {charlist_mask}"
        else:
            length, charlist_mask = None, None

        if length is None:
            # The upper bound of `integers` is exclusive, hence the + 1
            length = self.np_random.integers(self.min_length, self.max_length + 1)

        if charlist_mask is None:
            string = self.np_random.choice(self.character_list, size=length)
        else:
            valid_mask = charlist_mask == 1
            valid_indexes = np.where(valid_mask)[0]
            if len(valid_indexes) == 0:
                if self.min_length == 0:
                    string = ""
                else:
                    # Otherwise the string will not be contained in the space
                    raise ValueError(
                        f"Trying to sample with a minimum length > 0 ({self.min_length}) but the character mask is all zero meaning that no character could be sampled."
                    )
            else:
                string = "".join(
                    self.character_list[index]
                    for index in self.np_random.choice(valid_indexes, size=length)
                )

        # `string` is either a str or an array of single characters; join handles both
        return "".join(string)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        if isinstance(x, str):
            if self.min_length <= len(x) <= self.max_length:
                return all(c in self.character_set for c in x)
        return False

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return (
            f"Text({self.min_length}, {self.max_length}, characters={self.characters})"
        )

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return (
            isinstance(other, Text)
            and self.min_length == other.min_length
            and self.max_length == other.max_length
            and self.character_set == other.character_set
        )

    @property
    def character_set(self) -> frozenset[str]:
        """Returns the character set for the space."""
        return self._char_set

    @property
    def character_list(self) -> tuple[str, ...]:
        """Returns a tuple of characters in the space."""
        return self._char_list

    def character_index(self, char: str) -> np.int32:
        """Returns a unique index for each character in the space's character set."""
        return self._char_index[char]

    @property
    def characters(self) -> str:
        """Returns a string with all Text characters."""
        return self._char_str

    @property
    def is_np_flattenable(self) -> bool:
        """The flattened version is an integer array for each character, padded to the max character length."""
        return True