2022-07-11 16:39:04 +01:00
""" Implementation of a space that represents textual strings. """
from typing import Any , FrozenSet , List , Optional , Set , Tuple , Union
import numpy as np
from gym . spaces . space import Space
# Default character set for `Text`: the 26 lower-case and 26 upper-case English
# letters followed by the ten decimal digits (62 characters, no whitespace).
alphanumeric: FrozenSet[str] = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
class Text(Space[str]):
    r"""A space representing a string comprised of characters from a given charset.

    Example::

        >>> # {"", "B5", "hello", ...}
        >>> Text(5)
        >>> # {"0", "42", "0123456789", ...}
        >>> import string
        >>> Text(min_length=1,
        ...      max_length=10,
        ...      charset=string.digits)
    """

    def __init__(
        self,
        max_length: int,
        *,
        min_length: int = 0,
        charset: Union[Set[str], str] = alphanumeric,
        seed: Optional[Union[int, np.random.Generator]] = None,
    ):
        r"""Constructor of :class:`Text` space.

        Both bounds for text length are inclusive.

        Args:
            max_length (int): Maximum text length (in characters).
            min_length (int): Minimum text length (in characters).
            charset (Union[Set[str], str]): Character set, defaults to the lower and
                upper english alphabet plus latin digits.
            seed: The seed for sampling from the space.

        Raises:
            AssertionError: If a length is not an integer, ``min_length`` is
                negative, or ``min_length`` exceeds ``max_length``.
        """
        assert np.issubdtype(
            type(min_length), np.integer
        ), f"Expects the min_length to be an integer, actual type: {type(min_length)}"
        assert np.issubdtype(
            type(max_length), np.integer
        ), f"Expects the max_length to be an integer, actual type: {type(max_length)}"
        assert (
            0 <= min_length
        ), f"Minimum text length must be non-negative, actual value: {min_length}"
        assert (
            min_length <= max_length
        ), f"The min_length must be less than or equal to the max_length, min_length: {min_length}, max_length: {max_length}"

        self.min_length: int = int(min_length)
        self.max_length: int = int(max_length)
        # Set form is used for membership tests, list form for indexed sampling.
        self.charset: FrozenSet[str] = frozenset(charset)
        self._charlist: List[str] = list(charset)
        # Sorted so that the representation does not depend on iteration order.
        self._charset_str: str = "".join(sorted(self._charlist))

        # As the shape is dynamic (between min_length and max_length) then None
        super().__init__(dtype=str, seed=seed)

    def sample(
        self, mask: Optional[Tuple[Optional[int], Optional[np.ndarray]]] = None
    ) -> str:
        """Generates a single random sample from this space.

        By default the sample has a random length between `min_length` and
        `max_length` (both inclusive) and its characters are drawn from the `charset`.

        Args:
            mask: An optional tuple ``(length, charlist_mask)``.
                ``length`` must lie between `min_length` and `max_length`; when it is
                ``None`` a random integer in that range is drawn instead.
                ``charlist_mask`` must be a numpy array with ``dtype == np.int8`` and
                one entry per character of the character list; only characters whose
                entry is non-zero can be sampled. When it is ``None`` every character
                is eligible.

        Returns:
            A sampled string from the space

        Raises:
            AssertionError: If ``length`` is out of bounds or ``charlist_mask`` is not
                an ``np.int8`` array of the expected shape.
        """
        if mask is not None:
            length, charlist_mask = mask
            if length is not None:
                assert self.min_length <= length <= self.max_length
            if charlist_mask is not None:
                assert isinstance(charlist_mask, np.ndarray)
                # `dtype` is an `np.dtype` instance, so it has to be compared by
                # equality; an identity check against `np.int8` can never succeed.
                assert charlist_mask.dtype == np.int8
                assert charlist_mask.shape == (len(self._charlist),)
        else:
            length, charlist_mask = None, None

        if length is None:
            # The upper bound of `integers` is exclusive, hence the `+ 1`.
            length = self.np_random.integers(self.min_length, self.max_length + 1)

        if charlist_mask is None:
            candidates = self._charlist
        else:
            # Keep only the characters enabled by the mask. `_charlist` is a plain
            # list and therefore cannot be indexed with an index array directly.
            candidates = [
                char for char, keep in zip(self._charlist, charlist_mask) if keep
            ]

        chars = self.np_random.choice(candidates, size=length)
        # Join without a separator so the result has exactly `length` characters.
        return "".join(chars)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        if isinstance(x, str):
            if self.min_length <= len(x) <= self.max_length:
                return all(c in self.charset for c in x)
        return False

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return (
            f"Text({self.min_length}, {self.max_length}, charset={self._charset_str})"
        )

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return (
            isinstance(other, Text)
            and self.min_length == other.min_length
            and self.max_length == other.max_length
            and self.charset == other.charset
        )