2022-05-20 14:49:30 +01:00
""" Utility functions for vector environments to share memory between processes. """
2023-11-07 13:27:25 +00:00
from __future__ import annotations
2019-06-28 19:35:19 -04:00
import multiprocessing as mp
2019-06-21 17:29:44 -04:00
from collections import OrderedDict
2022-03-31 12:50:38 -07:00
from ctypes import c_bool
from functools import singledispatch
2023-11-07 13:27:25 +00:00
from typing import Any
2022-03-31 12:50:38 -07:00
import numpy as np
2019-06-21 17:29:44 -04:00
2022-09-08 10:10:07 +01:00
from gymnasium . error import CustomSpaceError
2022-09-08 10:11:31 +01:00
from gymnasium . spaces import (
Box ,
Dict ,
Discrete ,
2023-11-07 13:27:25 +00:00
Graph ,
2022-09-08 10:11:31 +01:00
MultiBinary ,
MultiDiscrete ,
2023-11-07 13:27:25 +00:00
Sequence ,
2022-09-08 10:11:31 +01:00
Space ,
2023-11-07 13:27:25 +00:00
Text ,
2022-09-08 10:11:31 +01:00
Tuple ,
2023-11-07 13:27:25 +00:00
flatten ,
2022-09-08 10:11:31 +01:00
)
2019-06-21 17:29:44 -04:00
2022-12-04 22:24:02 +08:00
2021-07-29 02:26:34 +02:00
# Names exported by a star-import of this module: the three dispatching entry points.
__all__ = [ " create_shared_memory " , " read_from_shared_memory " , " write_to_shared_memory " ]
2019-06-21 17:29:44 -04:00
2022-01-21 11:28:34 -05:00
@singledispatch
def create_shared_memory(
    space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
    """Create a shared memory object, to be shared across processes.

    This eventually contains the observations from the vectorized environment.
    The function is a ``singledispatch`` entry point: this body is only the
    fallback that runs when no implementation is registered for ``type(space)``,
    and it always raises.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment (i.e. the number of processes).
        ctx: The multiprocessing module (or an object exposing the same ``Array`` factory).

    Returns:
        shared_memory for the shared object across processes (produced by the
        registered implementations; the fallback itself never returns).

    Raises:
        CustomSpaceError: ``space`` is a :class:`gymnasium.Space` instance whose type
            has no registered ``create_shared_memory`` implementation.
        TypeError: ``space`` is not a :class:`gymnasium.Space` instance.
    """
    # Reject non-Space inputs first; everything left is an unregistered Space.
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )

    raise CustomSpaceError(
        f"Space of type `{type(space)}` doesn't have an registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
    )
2022-01-21 11:28:34 -05:00
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
    """Allocate one flat shared array holding ``n`` samples of a fixed-shape space.

    Args:
        space: The single-environment space; its ``dtype`` and ``shape`` determine
            the element type and the per-environment element count.
        n: Number of environments to reserve room for.
        ctx: Object providing the ``Array`` factory (the ``multiprocessing`` module by default).

    Returns:
        The object returned by ``ctx.Array`` with ``n * prod(space.shape)`` elements.
    """
    assert space.dtype is not None
    typecode = space.dtype.char
    # The NumPy boolean type character "?" is swapped for `ctypes.c_bool`.
    # Use equality here: a substring test (`in "?"`) would also accept "".
    if typecode == "?":
        typecode = c_bool
    return ctx.Array(typecode, n * int(np.prod(space.shape)))
2019-06-21 17:29:44 -04:00
2021-07-29 02:26:34 +02:00
2022-01-21 11:28:34 -05:00
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
    """Create shared memory for every member of a ``Tuple`` space.

    Each subspace in ``space.spaces`` is dispatched back through
    ``create_shared_memory`` with the same ``n`` and ``ctx``.

    Returns:
        A tuple of shared-memory objects, in the same order as ``space.spaces``.
    """
    buffers = [
        create_shared_memory(member, n=n, ctx=ctx) for member in space.spaces
    ]
    return tuple(buffers)
2021-07-29 02:26:34 +02:00
2019-06-21 17:29:44 -04:00
2022-01-21 11:28:34 -05:00
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
    """Create shared memory for every entry of a ``Dict`` space.

    Each subspace in ``space.spaces`` is dispatched back through
    ``create_shared_memory`` with the same ``n`` and ``ctx``.

    Returns:
        An ``OrderedDict`` mapping each key of ``space.spaces`` to its
        shared-memory object, in the iteration order of ``space.spaces``.
    """
    buffers = OrderedDict()
    for name, member in space.spaces.items():
        buffers[name] = create_shared_memory(member, n=n, ctx=ctx)
    return buffers
2019-06-21 17:29:44 -04:00
2023-11-07 13:27:25 +00:00
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
    """Allocate a flat shared array of 32-bit integers for ``n`` text samples.

    Every environment gets ``space.max_length`` slots, so the array holds
    ``n * space.max_length`` elements in total.
    """
    typecode = np.dtype(np.int32).char
    total_slots = n * space.max_length
    return ctx.Array(typecode, total_slots)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
    """Refuse to create shared memory for ``Graph`` and ``Sequence`` spaces.

    Raises:
        TypeError: Always; the message explains that a space with a dynamic
            shape cannot be backed by a static shared memory.
    """
    message = f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
    raise TypeError(message)
2022-01-21 11:28:34 -05:00
@singledispatch
def read_from_shared_memory(
    space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
    """Read the batch of observations from shared memory as a numpy array.

    .. note::
        The numpy array objects returned by `read_from_shared_memory` shares the
        memory of `shared_memory`. Any changes to `shared_memory` are forwarded
        to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.

    The function is a ``singledispatch`` entry point: this body is only the
    fallback that runs when no implementation is registered for ``type(space)``,
    and it always raises.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
            This object is created with `create_shared_memory`.
        n: Number of environments in the vectorized environment (i.e. the number of processes).

    Returns:
        Batch of observations as a (possibly nested) numpy array (produced by the
        registered implementations; the fallback itself never returns).

    Raises:
        CustomSpaceError: ``space`` is a :class:`gymnasium.Space` instance whose type
            has no registered ``read_from_shared_memory`` implementation.
        TypeError: ``space`` is not a :class:`gymnasium.Space` instance.
    """
    # Reject non-Space inputs first; everything left is an unregistered Space.
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )

    raise CustomSpaceError(
        f"Space of type `{type(space)}` doesn't have an registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
    )
2022-01-21 11:28:34 -05:00
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
    """Expose a flat shared array as a batched numpy array.

    The buffer behind ``shared_memory.get_obj()`` is interpreted with
    ``space.dtype`` and reshaped to ``(n,) + space.shape``.

    Returns:
        A numpy array of shape ``(n,) + space.shape`` built with ``np.frombuffer``.
    """
    flat = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    batched_shape = (n, *space.shape)
    return flat.reshape(batched_shape)
2021-07-29 02:26:34 +02:00
2019-06-21 17:29:44 -04:00
2022-01-21 11:28:34 -05:00
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
    """Read every member of a ``Tuple`` space from its own shared memory.

    ``shared_memory`` is iterated in lockstep with ``space.spaces`` and each
    pair is dispatched back through ``read_from_shared_memory``.

    Returns:
        A tuple with one read result per subspace.
    """
    views = []
    for memory, subspace in zip(shared_memory, space.spaces):
        views.append(read_from_shared_memory(subspace, memory, n=n))
    return tuple(views)
2021-07-29 02:26:34 +02:00
2019-06-21 17:29:44 -04:00
2022-01-21 11:28:34 -05:00
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
    """Read every entry of a ``Dict`` space from its own shared memory.

    For each key in ``space.spaces`` the matching ``shared_memory[key]`` is
    dispatched back through ``read_from_shared_memory``.

    Returns:
        An ``OrderedDict`` mapping each key to the read result of its subspace,
        in the iteration order of ``space.spaces``.
    """
    result = OrderedDict()
    for name, member in space.spaces.items():
        result[name] = read_from_shared_memory(member, shared_memory[name], n=n)
    return result
2019-06-21 17:29:44 -04:00
2023-11-07 13:27:25 +00:00
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(
    space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
    """Decode ``n`` strings from a shared array of 32-bit character indices.

    The buffer behind ``shared_memory.get_obj()`` is viewed as an
    ``(n, space.max_length)`` int32 array. Each row is turned into a string by
    looking every index up in ``space.character_list``.

    Args:
        space: The ``Text`` space describing the maximum length and character table.
        shared_memory: Shared array exposing ``get_obj()``.
        n: Number of environments (rows) stored in the shared array.

    Returns:
        A tuple of ``n`` decoded strings, one per environment.
    """
    data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
        (n, space.max_length)
    )
    # Loop invariants: look these up once rather than once per stored element.
    characters = space.character_list
    num_characters = len(space.character_set)
    # Indices >= num_characters are skipped (presumably the padding value used
    # when a shorter string is flattened -- confirm against `flatten` for Text).
    return tuple(
        "".join([characters[val] for val in values if val < num_characters])
        for values in data
    )
2022-01-21 11:28:34 -05:00
@singledispatch
def write_to_shared_memory(
    space: Space,
    index: int,
    value: np.ndarray,
    shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
):
    """Write the observation of a single environment into shared memory.

    The function is a ``singledispatch`` entry point: this body is only the
    fallback that runs when no implementation is registered for ``type(space)``,
    and it always raises.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        index: Index of the environment (must be in `[0, num_envs)`).
        value: Observation of the single environment to write to shared memory.
        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
            This object is created with `create_shared_memory`.

    Raises:
        CustomSpaceError: ``space`` is a :class:`gymnasium.Space` instance whose type
            has no registered ``write_to_shared_memory`` implementation.
        TypeError: ``space`` is not a :class:`gymnasium.Space` instance.
    """
    # Reject non-Space inputs first; everything left is an unregistered Space.
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )

    raise CustomSpaceError(
        f"Space of type `{type(space)}` doesn't have an registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
    )
2022-01-21 11:28:34 -05:00
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    index: int,
    value,
    shared_memory,
):
    """Copy one observation into its slot of a flat shared array.

    The buffer behind ``shared_memory.get_obj()`` is viewed with ``space.dtype``;
    environment ``index`` owns the ``prod(space.shape)`` consecutive elements
    starting at ``index * prod(space.shape)``.

    Args:
        space: The single-environment space giving ``dtype`` and ``shape``.
        index: Index of the environment whose slot is overwritten.
        value: Observation to store; converted to ``space.dtype`` and flattened.
        shared_memory: Shared array exposing ``get_obj()``.
    """
    per_env = int(np.prod(space.shape))
    buffer = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    start = index * per_env
    slot = buffer[start : start + per_env]
    flat_value = np.asarray(value, dtype=space.dtype).flatten()
    np.copyto(slot, flat_value)
2019-06-21 17:29:44 -04:00
2022-01-21 11:28:34 -05:00
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(
    space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
    """Write each member of a ``Tuple`` observation into its own shared memory.

    ``values``, ``shared_memory`` and ``space.spaces`` are walked in lockstep and
    every triple is dispatched back through ``write_to_shared_memory`` with the
    same ``index``.
    """
    members = zip(values, shared_memory, space.spaces)
    for member_value, member_memory, member_space in members:
        write_to_shared_memory(member_space, index, member_value, member_memory)
2019-06-21 17:29:44 -04:00
2021-07-29 02:26:34 +02:00
2022-01-21 11:28:34 -05:00
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(
    space: Dict, index: int, values: dict[str, Any], shared_memory
):
    """Write each entry of a ``Dict`` observation into its own shared memory.

    For every key of ``space.spaces`` the matching ``values[key]`` and
    ``shared_memory[key]`` are dispatched back through
    ``write_to_shared_memory`` with the same ``index``.
    """
    subspaces = space.spaces
    for name in subspaces:
        write_to_shared_memory(
            subspaces[name], index, values[name], shared_memory[name]
        )
2023-11-07 13:27:25 +00:00
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
    """Copy one flattened text observation into its slot of a shared int32 array.

    The buffer behind ``shared_memory.get_obj()`` is viewed as int32; environment
    ``index`` owns the ``space.max_length`` consecutive elements starting at
    ``index * space.max_length``. The string is converted with
    ``flatten(space, values)`` before being copied.
    """
    width = space.max_length
    buffer = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
    start = index * width
    slot = buffer[start : start + width]
    np.copyto(slot, flatten(space, values))