2022-12-05 19:14:56 +00:00
"""A collection of observation wrappers using a lambda function.

* ``LambdaObservation`` - Transforms the observation with a function
* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
* ``FlattenObservation`` - Flattens the observations
* ``GrayscaleObservation`` - Converts an RGB observation to a grayscale observation
* ``ResizeObservation`` - Resizes an array-based observation (normally an RGB observation)
* ``ReshapeObservation`` - Reshapes an array-based observation
* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservation`` - Converts the dtype of an observation
"""
2022-12-03 19:45:39 +00:00
from __future__ import annotations
2022-11-20 00:57:10 +01:00
2022-12-03 19:45:39 +00:00
from typing import Any , Callable , Sequence
2022-12-05 19:14:56 +00:00
from typing_extensions import Final
2022-12-03 19:45:39 +00:00
import jumpy as jp
import numpy as np
2022-11-20 00:57:10 +01:00
import gymnasium as gym
2022-12-03 19:45:39 +00:00
from gymnasium import spaces
2022-11-20 00:57:10 +01:00
from gymnasium . core import ObsType
2022-12-03 19:45:39 +00:00
from gymnasium . error import DependencyNotInstalled
2022-12-05 19:14:56 +00:00
from gymnasium . spaces import Box , utils
2022-11-20 00:57:10 +01:00
2022-11-29 23:37:53 +00:00
class LambdaObservationV0(gym.ObservationWrapper):
    """Transforms an observation via a function provided to the wrapper.

    The function :attr:`func` will be applied to all observations.
    If the observations from :attr:`func` are outside the bounds of the ``env``'s spaces, provide an :attr:`observation_space`.

    Example:
        >>> import gymnasium as gym
        >>> import numpy as np
        >>> env = gym.make('CartPole-v1')
        >>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape))
        >>> obs, info = env.reset()
        >>> obs.shape
        (4,)
    """

    def __init__(
        self,
        env: gym.Env,
        func: Callable[[ObsType], Any],
        observation_space: gym.Space | None = None,
    ):
        """Constructor for the lambda observation wrapper.

        Args:
            env: The environment to wrap
            func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
            observation_space: The observation space of the wrapper. If ``None`` (the default), the wrapper keeps ``env.observation_space``.
        """
        super().__init__(env)

        # Only override the space when one is given; otherwise the wrapper keeps the
        # observation space it received from ``env`` during ``super().__init__``.
        if observation_space is not None:
            self.observation_space = observation_space

        self.func = func

    def observation(self, observation: ObsType) -> Any:
        """Apply :attr:`func` to the observation and return the result."""
        return self.func(observation)
2022-12-03 19:45:39 +00:00
class FilterObservationV0(LambdaObservationV0):
    """Filter a ``Dict`` or ``Tuple`` observation space down to the given keys.

    For a :class:`Dict` observation space, ``filter_keys`` are the string keys to keep.
    For a :class:`Tuple` observation space, ``filter_keys`` are the integer indices to keep,
    in the order given.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.spaces import Dict, Discrete
        >>> base_env = gym.make('CartPole-v1')
        >>> env = LambdaObservationV0(
        ...     base_env,
        ...     lambda obs: {'obs': obs, 'time': 0},
        ...     Dict(obs=base_env.observation_space, time=Discrete(1)),
        ... )
        >>> env = FilterObservationV0(env, filter_keys=['obs'])
        >>> obs, info = env.reset()
        >>> list(obs.keys())
        ['obs']
    """

    def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
        """Constructor for an environment with a dictionary or tuple observation space where all :attr:`filter_keys` are in the observation space.

        Args:
            env: The environment to wrap; its observation space must be a :class:`Dict` or :class:`Tuple`.
            filter_keys: The keys (``Dict``) or indices (``Tuple``) of the sub-spaces to keep.

        Raises:
            ValueError: If a key/index is not part of the observation space, if filtering
                leaves an empty observation space, or if the observation space is neither
                a ``Dict`` nor a ``Tuple``.
        """
        assert isinstance(filter_keys, Sequence)
        # Snapshot the keys so later mutation of the caller's sequence cannot change
        # which entries are extracted while the observation space stays fixed.
        keys = tuple(filter_keys)

        # Filters for dictionary space
        if isinstance(env.observation_space, spaces.Dict):
            assert all(isinstance(key, str) for key in keys)

            observation_keys = env.observation_space.spaces.keys()
            missing_keys = [key for key in keys if key not in observation_keys]
            if missing_keys:
                raise ValueError(
                    "All the `filter_keys` must be included in the observation space.\n"
                    f"Filter keys: {filter_keys}\n"
                    f"Observation keys: {list(observation_keys)}\n"
                    f"Missing keys: {missing_keys}"
                )

            new_observation_space = spaces.Dict(
                {key: env.observation_space[key] for key in keys}
            )
            if len(new_observation_space) == 0:
                raise ValueError("The observation space is empty due to filtering all keys.")

            super().__init__(
                env,
                lambda obs: {key: obs[key] for key in keys},
                new_observation_space,
            )

        # Filter for tuple observation
        elif isinstance(env.observation_space, spaces.Tuple):
            assert all(isinstance(key, int) for key in keys)
            assert len(set(keys)) == len(
                keys
            ), f"Duplicate keys exist, filter_keys: {filter_keys}"

            num_spaces = len(env.observation_space)
            # Valid indices follow Python sequence indexing: -n <= key < n. Anything else
            # must raise the documented ValueError, not an IndexError from the lookup below.
            missing_index = [key for key in keys if not -num_spaces <= key < num_spaces]
            if missing_index:
                raise ValueError(
                    "All the `filter_keys` must be included in the length of the observation space.\n"
                    f"Filter keys: {filter_keys}, length of observation: {num_spaces}, "
                    f"missing indexes: {missing_index}"
                )

            new_observation_spaces = spaces.Tuple(
                env.observation_space[key] for key in keys
            )
            if len(new_observation_spaces) == 0:
                raise ValueError("The observation space is empty due to filtering all keys.")

            super().__init__(
                env,
                lambda obs: tuple(obs[key] for key in keys),
                new_observation_spaces,
            )
        else:
            raise ValueError(
                f"FilterObservation wrapper is only usable with ``Dict`` and ``Tuple`` observations, actual type: {type(env.observation_space)}"
            )

        self.filter_keys: Final[Sequence[str | int]] = filter_keys
2022-12-03 19:45:39 +00:00
class FlattenObservationV0(LambdaObservationV0):
    """Observation wrapper that flattens the observation.

    The wrapper's observation space is ``spaces.utils.flatten_space`` applied to the
    wrapped environment's space. Every observation is passed through
    ``spaces.utils.flatten`` with that same source space.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservationV0(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs, info = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env):
        """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""

        def flatten_observation(obs):
            # The wrapped env's observation space is looked up on every call.
            return utils.flatten(env.observation_space, obs)

        flat_space = utils.flatten_space(env.observation_space)
        super().__init__(env, flatten_observation, flat_space)
class GrayscaleObservationV0(LambdaObservationV0):
    """Observation wrapper that converts an RGB image to grayscale.

    With :attr:`keep_dim` set, a trailing channel axis of size 1 is kept;
    otherwise the channel axis is dropped.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CarRacing-v1")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> grayscale_env = GrayscaleObservationV0(env)
        >>> grayscale_env.observation_space.shape
        (96, 96)
        >>> grayscale_env = GrayscaleObservationV0(env, keep_dim=True)
        >>> grayscale_env.observation_space.shape
        (96, 96, 1)
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Constructor for an RGB image based environments to make the image grayscale.

        Args:
            env: The environment to wrap. Its observation space must be a ``uint8``
                :class:`Box` with bounds ``[0, 255]`` and a shape of ``(H, W, 3)``.
            keep_dim: If ``True``, observations have shape ``(H, W, 1)``; otherwise ``(H, W)``.
        """
        rgb_space = env.observation_space
        assert isinstance(rgb_space, spaces.Box)
        assert len(rgb_space.shape) == 3 and rgb_space.shape[-1] == 3
        assert (
            np.all(rgb_space.low == 0)
            and np.all(rgb_space.high == 255)
            and rgb_space.dtype == np.uint8
        )

        self.keep_dim: Final[bool] = keep_dim

        gray_shape = rgb_space.shape[:2] + ((1,) if keep_dim else ())

        def to_grayscale(obs):
            # Weighted sum over the last (colour) axis. The weights sum to 1.0, and the
            # result is truncated to uint8.
            luminance = jp.sum(
                jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
            ).astype(np.uint8)
            if not keep_dim:
                return luminance
            return jp.expand_dims(luminance, axis=-1)

        super().__init__(
            env,
            to_grayscale,
            spaces.Box(low=0, high=255, shape=gray_shape, dtype=np.uint8),
        )
class ResizeObservationV0(LambdaObservationV0):
    """Observation wrapper for resizing image observations using OpenCV.

    The two leading (spatial) dimensions are resized to ``shape``. Any trailing
    channel dimension of the original space is kept.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CarRacing-v1")
        >>> resized_env = ResizeObservationV0(env, (32, 32))
        >>> resized_env.observation_space.shape
        (32, 32, 3)
    """

    def __init__(self, env: gym.Env, shape: tuple[int, ...]):
        """Constructor that requires an image environment observation space with a shape.

        Args:
            env: The environment to wrap. Its observation space must be a ``uint8``
                :class:`Box` with bounds ``[0, 255]`` and a 2D ``(H, W)`` or 3D
                ``(H, W, C)`` shape.
            shape: The target ``(height, width)`` of the resized observation.

        Raises:
            DependencyNotInstalled: If ``cv2`` (opencv) cannot be imported.
        """
        assert isinstance(env.observation_space, spaces.Box)
        assert len(env.observation_space.shape) in [2, 3]
        assert np.all(env.observation_space.low == 0) and np.all(
            env.observation_space.high == 255
        )
        assert env.observation_space.dtype == np.uint8

        assert isinstance(shape, tuple)
        # cv2.resize only resizes the two spatial dimensions.
        assert len(shape) == 2, f"Expected a (height, width) shape, actual shape: {shape}"
        assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
        assert all(x > 0 for x in shape)

        try:
            import cv2
        except ImportError as e:
            raise DependencyNotInstalled(
                "opencv is not install, run `pip install gymnasium[other]`"
            ) from e

        self.shape: Final[tuple[int, ...]] = tuple(shape)

        new_observation_space = spaces.Box(
            low=0,
            high=255,
            shape=self.shape + env.observation_space.shape[2:],
            # Without an explicit dtype, Box defaults to float32, but the resized
            # observations are uint8.
            dtype=np.uint8,
        )

        # OpenCV's ``dsize`` is (width, height), the reverse of numpy's (rows, cols)
        # order used by ``self.shape``.
        dsize = (int(self.shape[1]), int(self.shape[0]))
        target_shape = new_observation_space.shape

        def resize(obs):
            resized = cv2.resize(obs, dsize, interpolation=cv2.INTER_AREA)
            # cv2.resize returns a 2D array for (H, W, 1) input; restore the channel
            # axis so the observation matches the observation space.
            return np.reshape(resized, target_shape)

        super().__init__(env, resize, new_observation_space)
class ReshapeObservationV0(LambdaObservationV0):
    """Observation wrapper for reshaping the observation.

    The wrapped environment's :class:`Box` bounds are reshaped to the new shape, and
    each observation is reshaped with ``jp.reshape``.
    """

    def __init__(self, env: gym.Env, shape: int | tuple[int, ...]):
        """Constructor for env with Box observation space that has a shape product equal to the new shape product.

        Args:
            env: The environment to wrap; its observation space must be a :class:`Box`.
            shape: The new shape. A bare ``int`` is treated as the 1D shape ``(shape,)``.
                At most one dimension may be ``-1``; it is inferred from the other
                dimensions and the total size, as with ``numpy.reshape``.
        """
        assert isinstance(env.observation_space, spaces.Box)

        # The annotation allows a bare int; treat it as a 1D shape.
        if isinstance(shape, (int, np.integer)):
            shape = (shape,)
        assert isinstance(shape, tuple)
        assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
        assert all(x > 0 or x == -1 for x in shape)

        total_size = int(np.prod(env.observation_space.shape))
        if -1 in shape:
            # Resolve the single unknown dimension so the Box gets a concrete shape
            # (Box rejects negative dimensions).
            assert (
                sum(1 for x in shape if x == -1) == 1
            ), f"Only one dimension can be -1, shape: {shape}"
            known_size = int(np.prod([x for x in shape if x != -1]))
            assert total_size % known_size == 0
            shape = tuple(total_size // known_size if x == -1 else x for x in shape)
        assert np.prod(shape) == total_size

        new_observation_space = spaces.Box(
            low=np.reshape(np.ravel(env.observation_space.low), shape),
            high=np.reshape(np.ravel(env.observation_space.high), shape),
            shape=shape,
            dtype=env.observation_space.dtype,
        )
        self.shape = shape
        super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
2022-12-03 19:45:39 +00:00
class RescaleObservationV0(LambdaObservationV0):
    """Observation wrapper for rescaling the observations between a minimum and maximum value.

    Each observation element is mapped affinely from the wrapped environment's
    ``[low, high]`` bounds onto ``[min_obs, max_obs]``.
    """

    def __init__(
        self,
        env: gym.Env,
        min_obs: np.floating | np.integer | np.ndarray,
        max_obs: np.floating | np.integer | np.ndarray,
    ):
        """Constructor that requires the env observation spaces to be a :class:`Box`.

        Args:
            env: The environment to wrap. Its observation space must be a :class:`Box`
                with finite bounds.
            min_obs: The new lower bound. A scalar is broadcast to the observation shape;
                an array must already have that shape.
            max_obs: The new upper bound, with the same rules as ``min_obs``.
        """
        assert isinstance(env.observation_space, spaces.Box)
        # The affine map is undefined for unbounded spaces. Reject +inf, -inf and NaN
        # bounds (the usual unbounded case is a -inf lower bound).
        assert np.all(np.isfinite(env.observation_space.low)) and np.all(
            np.isfinite(env.observation_space.high)
        )

        if not isinstance(min_obs, np.ndarray):
            assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
                type(min_obs), np.floating
            )
            min_obs = np.full(env.observation_space.shape, min_obs)
        assert (
            min_obs.shape == env.observation_space.shape
        ), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
        assert np.all(np.isfinite(min_obs))

        if not isinstance(max_obs, np.ndarray):
            assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
                type(max_obs), np.floating
            )
            max_obs = np.full(env.observation_space.shape, max_obs)
        assert (
            max_obs.shape == env.observation_space.shape
        ), f"{max_obs.shape}, {env.observation_space.shape}, {max_obs}, {env.observation_space.high}"
        assert np.all(np.isfinite(max_obs))

        self.min_obs = min_obs
        self.max_obs = max_obs

        # Treat the old Box as the x-axis and the new Box as the y-axis. The line
        # y = gradient * x + intercept passes through (low, min_obs) and (high, max_obs).
        gradient = (max_obs - min_obs) / (
            env.observation_space.high - env.observation_space.low
        )
        intercept = gradient * -env.observation_space.low + min_obs

        super().__init__(
            env,
            lambda obs: gradient * obs + intercept,
            Box(
                low=min_obs,
                high=max_obs,
                shape=env.observation_space.shape,
                dtype=env.observation_space.dtype,
            ),
        )
class DtypeObservationV0(LambdaObservationV0):
    """Observation wrapper for transforming the dtype of an observation.

    Observations are converted by calling ``dtype(obs)``. ``Discrete`` and
    ``MultiBinary`` spaces become :class:`Box` spaces covering the same values.
    """

    def __init__(self, env: gym.Env, dtype: Any):
        """Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces.

        Args:
            env: The environment to wrap.
            dtype: A callable numpy scalar type (e.g. ``np.float32``). It is used both as
                the new space's dtype and to convert each observation.

        Raises:
            TypeError: If the observation space is not one of the supported types
                (reachable only when asserts are disabled).
        """
        assert isinstance(
            env.observation_space,
            (spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
        )

        self.dtype = dtype

        if isinstance(env.observation_space, spaces.Box):
            new_observation_space = spaces.Box(
                low=env.observation_space.low,
                high=env.observation_space.high,
                shape=env.observation_space.shape,
                dtype=self.dtype,
            )
        elif isinstance(env.observation_space, spaces.Discrete):
            # Discrete(n, start) contains start, ..., start + n - 1, and Box bounds are
            # inclusive, so the upper bound is the last valid value, not start + n.
            new_observation_space = spaces.Box(
                low=env.observation_space.start,
                high=env.observation_space.start + env.observation_space.n - 1,
                shape=(),
                dtype=self.dtype,
            )
        elif isinstance(env.observation_space, spaces.MultiDiscrete):
            new_observation_space = spaces.MultiDiscrete(
                env.observation_space.nvec, dtype=self.dtype
            )
        elif isinstance(env.observation_space, spaces.MultiBinary):
            new_observation_space = spaces.Box(
                low=0,
                high=1,
                shape=env.observation_space.shape,
                dtype=self.dtype,
            )
        else:
            raise TypeError(
                "DtypeObservation is only compatible with value / array-based observations."
            )

        super().__init__(env, lambda obs: dtype(obs), new_observation_space)