2022-12-10 22:04:14 +00:00
""" Utility functions for the wrappers. """
2023-02-23 16:07:58 +00:00
from functools import singledispatch
2022-12-10 22:04:14 +00:00
import numpy as np
2023-02-23 16:07:58 +00:00
from gymnasium import Space
from gymnasium . error import CustomSpaceError
from gymnasium . spaces import (
Box ,
Dict ,
Discrete ,
Graph ,
GraphInstance ,
MultiBinary ,
MultiDiscrete ,
2024-03-11 13:30:50 +01:00
OneOf ,
2023-02-23 16:07:58 +00:00
Sequence ,
Text ,
Tuple ,
)
from gymnasium . spaces . space import T_cov
2022-12-10 22:04:14 +00:00
2023-06-21 17:04:11 +01:00
__all__ = [ " RunningMeanStd " , " update_mean_var_count_from_moments " , " create_zero_array " ]
2022-12-10 22:04:14 +00:00
class RunningMeanStd:
    """Running estimate of the mean, variance and sample count of a stream of values.

    Batches are merged into the running statistics with the parallel algorithm from
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    """

    def __init__(self, epsilon=1e-4, shape=(), dtype=np.float64):
        """Initialises the statistics with a zero mean and a unit variance.

        Args:
            epsilon: Initial value of the sample count
            shape: Shape of the ``mean`` and ``var`` arrays
            dtype: Numpy dtype of the ``mean`` and ``var`` arrays
        """
        self.count = epsilon
        self.var = np.ones(shape, dtype=dtype)
        self.mean = np.zeros(shape, dtype=dtype)

    def update(self, x):
        """Folds a batch of samples into the running statistics.

        Args:
            x: Batch of samples, the moments are computed over the first axis
        """
        self.update_from_moments(
            np.mean(x, axis=0),
            np.var(x, axis=0),
            x.shape[0],
        )

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Folds precomputed batch moments into the running statistics.

        Args:
            batch_mean: Mean of the batch
            batch_var: Variance of the batch
            batch_count: Number of samples in the batch
        """
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )
        self.mean, self.var, self.count = merged
def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merges the moments of a new batch into previously accumulated moments.

    Args:
        mean: Previously accumulated mean
        var: Previously accumulated variance
        count: Previously accumulated sample count
        batch_mean: Mean of the new batch
        batch_var: Variance of the new batch
        batch_count: Number of samples in the new batch

    Returns:
        Tuple of the merged ``(mean, var, count)``
    """
    total = count + batch_count
    shift = batch_mean - mean

    # Sums of squared deviations of both parts plus the correction term that
    # accounts for the two parts having different means.
    prior_m2 = var * count
    batch_m2 = batch_var * batch_count
    merged_m2 = prior_m2 + batch_m2 + np.square(shift) * count * batch_count / total

    return mean + shift * batch_count / total, merged_m2 / total, total
2023-02-23 16:07:58 +00:00
@singledispatch
def create_zero_array(space: Space[T_cov]) -> T_cov:
    """Creates a sample of the space that is as close to zero as possible.

    This is similar to ``create_empty_array`` except that the result is always a valid
    sample from the space. For example, a ``Box`` whose ``low`` is above zero or whose
    ``high`` is below zero does not contain the all-zero array.

    This function is the single-dispatch fallback: it is only reached for types that
    have no registered implementation and therefore always raises.

    Args:
        space: The space to create a zero array for

    Returns:
        Valid sample from the space that is as close to zero as possible

    Raises:
        TypeError: If ``space`` is not a gymnasium ``Space`` instance
        CustomSpaceError: If ``space`` is a ``Space`` without a registered implementation
    """
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `create_zero_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )

    # A genuine Space ended up in the fallback, so nothing was registered for its type.
    raise CustomSpaceError(
        f"Space of type `{type(space)}` doesn't have an registered `create_zero_array` function. Register `{type(space)}` for `create_zero_array` to support it."
    )
@create_zero_array.register(Box)
def _create_box_zero_array(space: Box):
    """Returns zeros, replaced by ``low`` where ``low > 0`` and by ``high`` where ``high < 0``."""
    zeros = np.zeros(space.shape, dtype=space.dtype)
    # The ``high < 0`` substitution is applied last, so it takes precedence.
    return np.where(
        space.high < 0,
        space.high,
        np.where(space.low > 0, space.low, zeros),
    )
@create_zero_array.register(Discrete)
def _create_discrete_zero_array(space: Discrete):
    """Returns ``space.start`` as the zero value of a ``Discrete`` space."""
    zero_value = space.start
    return zero_value
@create_zero_array.register(MultiDiscrete)
def _create_multidiscrete_zero_array(space: MultiDiscrete):
    """Returns a fresh copy of ``space.start`` in the dtype of the space."""
    # ``copy=True`` ensures the caller never receives the space's own array.
    start_copy = np.array(space.start, dtype=space.dtype, copy=True)
    return start_copy
2023-02-23 16:07:58 +00:00
@create_zero_array.register(MultiBinary)
def _create_array_zero_array(space: MultiBinary):
    """Returns an all-zero array with the shape and dtype of the ``MultiBinary`` space."""
    return np.zeros(shape=space.shape, dtype=space.dtype)
@create_zero_array.register(Tuple)
def _create_tuple_zero_array(space: Tuple):
    """Returns a tuple holding the zero array of every subspace, in subspace order."""
    return tuple(map(create_zero_array, space.spaces))
@create_zero_array.register(Dict)
def _create_dict_zero_array(space: Dict):
    """Returns a plain dict mapping every key of the space to the zero array of its subspace."""
    zero_arrays = {}
    for name, subspace in space.spaces.items():
        zero_arrays[name] = create_zero_array(subspace)
    return zero_arrays
2023-02-23 16:07:58 +00:00
@create_zero_array.register(Sequence)
def _create_sequence_zero_array(space: Sequence):
    """Returns an empty tuple, or the zero array of the stacked feature space when ``space.stack`` is set."""
    if not space.stack:
        return ()
    return create_zero_array(space.stacked_feature_space)
@create_zero_array.register(Text)
def _create_text_zero_array(space: Text):
    """Returns a string of ``space.min_length`` copies of the first character of ``space.characters``."""
    repeated = [space.characters[0] for _ in range(space.min_length)]
    return "".join(repeated)
@create_zero_array.register(Graph)
def _create_graph_zero_array(space: Graph):
    """Returns a ``GraphInstance`` with a single zero node and, if the space has edges, a single zero edge.

    When ``space.edge_space`` is ``None`` both ``edges`` and ``edge_links`` are ``None``;
    otherwise the single edge has the all-zero link ``[[0, 0]]``.
    """
    # A leading axis of length one turns the single zero sample into a batch of one node.
    single_node = np.expand_dims(create_zero_array(space.node_space), axis=0)

    if space.edge_space is None:
        return GraphInstance(nodes=single_node, edges=None, edge_links=None)

    single_edge = np.expand_dims(create_zero_array(space.edge_space), axis=0)
    zero_links = np.zeros((1, 2), dtype=np.int64)
    return GraphInstance(nodes=single_node, edges=single_edge, edge_links=zero_links)
2024-03-11 13:30:50 +01:00
@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
    """Returns the pair ``(0, zero_array)`` where ``zero_array`` is built from the first subspace."""
    first_subspace = space.spaces[0]
    return (0, create_zero_array(first_subspace))