2022-05-20 14:49:30 +01:00
""" An async vector environment. """
2023-11-07 13:27:25 +00:00
from __future__ import annotations
import multiprocessing
2019-06-21 17:29:44 -04:00
import sys
2022-03-31 12:50:38 -07:00
import time
2019-06-21 17:29:44 -04:00
from copy import deepcopy
2022-03-31 12:50:38 -07:00
from enum import Enum
2023-11-07 13:27:25 +00:00
from multiprocessing import Queue
from multiprocessing . connection import Connection
from typing import Any , Callable , Sequence
2022-03-31 12:50:38 -07:00
import numpy as np
2019-06-21 17:29:44 -04:00
2022-09-08 10:10:07 +01:00
from gymnasium import logger
2023-11-07 13:27:25 +00:00
from gymnasium . core import ActType , Env , ObsType , RenderFrame
2022-09-08 10:10:07 +01:00
from gymnasium . error import (
2021-07-29 02:26:34 +02:00
AlreadyPendingCallError ,
ClosedEnvironmentError ,
CustomSpaceError ,
2022-03-31 12:50:38 -07:00
NoAsyncCallError ,
2021-07-29 02:26:34 +02:00
)
2022-09-08 10:10:07 +01:00
from gymnasium . vector . utils import (
2021-07-29 02:26:34 +02:00
CloudpickleWrapper ,
2023-11-07 13:27:25 +00:00
batch_space ,
2021-07-29 02:26:34 +02:00
clear_mpi_env_vars ,
2022-03-31 12:50:38 -07:00
concatenate ,
create_empty_array ,
create_shared_memory ,
iterate ,
read_from_shared_memory ,
write_to_shared_memory ,
2021-07-29 02:26:34 +02:00
)
2023-11-07 13:27:25 +00:00
from gymnasium . vector . vector_env import ArrayType , VectorEnv
2021-07-29 02:26:34 +02:00
2022-12-04 22:24:02 +08:00
2023-11-07 13:27:25 +00:00
__all__ = [ " AsyncVectorEnv " , " AsyncState " ]
2019-06-21 17:29:44 -04:00
class AsyncState(Enum):
    """Tracks which asynchronous call (if any) an :class:`AsyncVectorEnv` is waiting on.

    ``DEFAULT`` means no call is pending. Every other member's value is the prefix of the
    ``<value>_async`` / ``<value>_wait`` method pair whose ``_wait`` half is still outstanding.
    """

    # Nothing is pending; a new asynchronous call may be issued.
    DEFAULT = "default"
    # `reset_async` was sent; `reset_wait` must be called next.
    WAITING_RESET = "reset"
    # `step_async` was sent; `step_wait` must be called next.
    WAITING_STEP = "step"
    # `call_async` was sent; `call_wait` must be called next.
    WAITING_CALL = "call"
class AsyncVectorEnv(VectorEnv):
    """Vectorized environment that runs multiple environments in parallel.

    It uses ``multiprocessing`` processes, and pipes for communication.

    Example:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="async")
        >>> envs
        AsyncVectorEnv(Pendulum-v1, num_envs=2)
        >>> envs = gym.vector.AsyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v1", g=9.81),
        ...     lambda: gym.make("Pendulum-v1", g=1.62)
        ... ])
        >>> envs
        AsyncVectorEnv(num_envs=2)
        >>> observations, infos = envs.reset(seed=42)
        >>> observations
        array([[-0.14995256,  0.9886932 , -0.12224312],
               [ 0.5760367 ,  0.8174238 , -0.91244936]], dtype=float32)
        >>> infos
        {}
        >>> _ = envs.action_space.seed(123)
        >>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
        >>> observations
        array([[-0.1851753 ,  0.98270553,  0.714599  ],
               [ 0.6193494 ,  0.7851154 , -1.0808398 ]], dtype=float32)
        >>> rewards
        array([-2.96495728, -1.00214607])
        >>> terminations
        array([False, False])
        >>> truncations
        array([False, False])
        >>> infos
        {}
    """

    def __init__(
        self,
        env_fns: Sequence[Callable[[], Env]],
        shared_memory: bool = True,
        copy: bool = True,
        context: str | None = None,
        daemon: bool = True,
        worker: Callable[
            [int, Callable[[], Env], Connection, Connection, bool, Queue], None
        ]
        | None = None,
    ):
        """Vectorized environment that runs multiple environments in parallel.

        Args:
            env_fns: Functions that create the environments.
            shared_memory: If ``True``, then the observations from the worker processes are communicated back through
                shared variables. This can improve the efficiency if the observations are large (e.g. images).
            copy: If ``True``, then the :meth:`AsyncVectorEnv.reset` and :meth:`AsyncVectorEnv.step` methods
                return a copy of the observations.
            context: Context for `multiprocessing`. If ``None``, then the default context is used.
            daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
                the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
                so for some environments you may want to have it set to ``False``.
            worker: If set, then use that worker in a subprocess instead of a default one.
                Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.

        Warnings:
            worker is an advanced mode option. It provides a high degree of flexibility and a high chance
            to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
            from the code for ``_async_worker`` and add changes.

        Raises:
            RuntimeError: If the observation or action space of some sub-environment does not match the
                spaces of the first sub-environment.
            ValueError: If the observation space is a custom space (i.e. not a default space in Gymnasium,
                such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
        """
        self.env_fns = env_fns
        self.shared_memory = shared_memory
        self.copy = copy

        self.num_envs = len(env_fns)

        # This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes.
        # Create a dummy environment to gather the metadata and observation / action space of the environment.
        dummy_env = env_fns[0]()

        # As we support `make_vec(spec)` then we can't include a `spec = dummy_env.spec` as this doesn't
        # guarantee we can actually recreate the vector env.
        self.metadata = dummy_env.metadata
        self.render_mode = dummy_env.render_mode

        self.single_observation_space = dummy_env.observation_space
        self.single_action_space = dummy_env.action_space

        self.observation_space = batch_space(
            self.single_observation_space, self.num_envs
        )
        self.action_space = batch_space(self.single_action_space, self.num_envs)

        dummy_env.close()
        del dummy_env

        # Generate the multiprocessing context for the observation buffer
        ctx = multiprocessing.get_context(context)
        if self.shared_memory:
            try:
                _obs_buffer = create_shared_memory(
                    self.single_observation_space, n=self.num_envs, ctx=ctx
                )
                self.observations = read_from_shared_memory(
                    self.single_observation_space, _obs_buffer, n=self.num_envs
                )
            except CustomSpaceError as e:
                raise ValueError(
                    "Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), "
                    "and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. "
                    "Set `shared_memory=False` if you use custom observation spaces."
                ) from e
        else:
            _obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros
            )

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        target = worker or _async_worker
        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = ctx.Pipe()
                process = ctx.Process(
                    target=target,
                    name=f"Worker<{type(self).__name__}>-{idx}",
                    args=(
                        idx,
                        CloudpickleWrapper(env_fn),
                        child_pipe,
                        parent_pipe,
                        _obs_buffer,
                        self.error_queue,
                    ),
                )

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)

                process.daemon = daemon
                process.start()
                # The worker owns the child end now; the parent keeps only its own end.
                child_pipe.close()

        self._state = AsyncState.DEFAULT
        self._check_spaces()

    @property
    def np_random_seed(self) -> tuple[int, ...]:
        """Returns a tuple of np_random seeds for all the wrapped envs."""
        return self.get_attr("np_random_seed")

    @property
    def np_random(self) -> tuple[np.random.Generator, ...]:
        """Returns the tuple of the numpy random number generators for the wrapped envs."""
        return self.get_attr("np_random")

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets all sub-environments in parallel and return a batch of concatenated observations and info.

        Args:
            seed: The environment reset seeds (see :meth:`reset_async`).
            options: The reset options, passed unchanged to every sub-environment's ``reset``.

        Returns:
            A batch of observations and info from the vectorized environment.
        """
        self.reset_async(seed=seed, options=options)
        return self.reset_wait()

    def reset_async(
        self,
        seed: int | list[int] | None = None,
        options: dict | None = None,
    ):
        """Send calls to the :obj:`reset` methods of the sub-environments.

        To get the results of these calls, you may invoke :meth:`reset_wait`.

        Args:
            seed: ``None`` for no explicit seeding, an ``int`` (sub-environment ``i`` is reset with
                ``seed + i``), or a list with one seed per sub-environment.
            options: The reset options, sent unchanged to every sub-environment.

        Raises:
            AssertionError: If ``seed`` is a list whose length differs from ``num_envs``.
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: If the environment is already waiting for a pending call to another
                method (e.g. :meth:`step_async`). This can be caused by two consecutive
                calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
        """
        self._assert_is_running()

        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        elif isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert (
            len(seed) == self.num_envs
        ), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
                str(self._state.value),
            )

        for pipe, env_seed in zip(self.parent_pipes, seed):
            env_kwargs = {"seed": env_seed, "options": options}
            pipe.send(("reset", env_kwargs))
        self._state = AsyncState.WAITING_RESET

    def reset_wait(
        self,
        timeout: int | float | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.

        Args:
            timeout: Number of seconds before the call to ``reset_wait`` times out. If `None`, the call to ``reset_wait`` never times out.

        Returns:
            A tuple of the batched observations and a dictionary of the batched infos.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
            TimeoutError: If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior " "call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

        infos = {}
        results, info_data = zip(*results)
        for i, info in enumerate(info_data):
            infos = self._add_info(infos, info, i)

        # With shared memory the workers already wrote into `self.observations`.
        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space, results, self.observations
            )

        self._state = AsyncState.DEFAULT
        return (deepcopy(self.observations) if self.copy else self.observations), infos

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Take an action for each parallel environment.

        Args:
            actions: element of :attr:`action_space` batch of actions.

        Returns:
            Batch of (observations, rewards, terminations, truncations, infos)
        """
        self.step_async(actions)
        return self.step_wait()

    def step_async(self, actions: np.ndarray):
        """Send the calls to :meth:`Env.step` to each sub-environment.

        Args:
            actions: Batch of actions. element of :attr:`VectorEnv.action_space`

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: If the environment is already waiting for a pending call to another
                method (e.g. :meth:`reset_async`). This can be caused by two consecutive
                calls to :meth:`step_async`, with no call to :meth:`step_wait` in
                between.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
                str(self._state.value),
            )

        iter_actions = iterate(self.action_space, actions)
        for pipe, action in zip(self.parent_pipes, iter_actions):
            pipe.send(("step", action))
        self._state = AsyncState.WAITING_STEP

    def step_wait(
        self, timeout: int | float | None = None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
        """Wait for the calls to :obj:`step` in each sub-environment to finish.

        Args:
            timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.

        Returns:
            The batched environment step information, (obs, reward, terminated, truncated, info)

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
            TimeoutError: If :meth:`step_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call " "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError(
                f"The call to `step_wait` has timed out after {timeout} second(s)."
            )

        observations, rewards, terminations, truncations, infos = [], [], [], [], {}
        successes = []
        # Receive from every pipe (even after a failure) so no stale reply is left in a pipe.
        for env_idx, pipe in enumerate(self.parent_pipes):
            env_step_return, success = pipe.recv()

            successes.append(success)
            if success:
                observations.append(env_step_return[0])
                rewards.append(env_step_return[1])
                terminations.append(env_step_return[2])
                truncations.append(env_step_return[3])
                infos = self._add_info(infos, env_step_return[4], env_idx)

        self._raise_if_errors(successes)

        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space,
                observations,
                self.observations,
            )

        self._state = AsyncState.DEFAULT
        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards, dtype=np.float64),
            np.array(terminations, dtype=np.bool_),
            np.array(truncations, dtype=np.bool_),
            infos,
        )

    def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
        """Call a method from each parallel environment with args and kwargs.

        Args:
            name (str): Name of the method or property to call.
            *args: Position arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Returns:
            List of the results of the individual calls to the method or property for each environment.
        """
        self.call_async(name, *args, **kwargs)
        return self.call_wait()

    def render(self) -> tuple[RenderFrame, ...] | None:
        """Returns a list of rendered frames from the environments."""
        return self.call("render")

    def call_async(self, name: str, *args, **kwargs):
        """Calls the method with name asynchronously and apply args and kwargs to the method.

        Args:
            name: Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `call_async` while waiting for a pending call to `{self._state.value}` to complete.",
                str(self._state.value),
            )

        for pipe in self.parent_pipes:
            pipe.send(("_call", (name, args, kwargs)))
        self._state = AsyncState.WAITING_CALL

    def call_wait(self, timeout: int | float | None = None) -> tuple[Any, ...]:
        """Calls all parent pipes and waits for the results.

        Args:
            timeout: Number of seconds before the call to :meth:`call_wait` times out.
                If ``None`` (default), the call to :meth:`call_wait` never times out.

        Returns:
            List of the results of the individual calls to the method or property for each environment.

        Raises:
            NoAsyncCallError: Calling :meth:`call_wait` without any prior call to :meth:`call_async`.
            TimeoutError: The call to :meth:`call_wait` has timed out after timeout second(s).
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results

    def get_attr(self, name: str) -> tuple[Any, ...]:
        """Get a property from each parallel environment.

        Args:
            name (str): Name of the property to get from each individual environment.

        Returns:
            The property with name
        """
        return self.call(name)

    def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
        """Sets an attribute of the sub-environments.

        Args:
            name: Name of the property to be set in each individual environment.
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise a single value is set for all environments.

        Raises:
            ValueError: Values must be a list or tuple with length equal to the number of environments.
            AlreadyPendingCallError: Calling :meth:`set_attr` while waiting for a pending call to complete.
        """
        self._assert_is_running()
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the number of environments. "
                f"Got `{len(values)}` values for {self.num_envs} environments."
            )

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `set_attr` while waiting for a pending call to `{self._state.value}` to complete.",
                str(self._state.value),
            )

        # Synchronous: send to every worker, then wait for every acknowledgement.
        for pipe, value in zip(self.parent_pipes, values):
            pipe.send(("_setattr", (name, value)))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
        """Close the environments & clean up the extra resources (processes and pipes).

        If a call is still pending, its ``*_wait`` method is awaited first. If that wait times out,
        or if ``terminate`` is ``True``, all worker processes are terminated instead of being
        asked to close.

        Args:
            timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
                the call to :meth:`close` never times out. If the call to :meth:`close`
                times out, then all processes are terminated.
            terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.
        """
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
                )
                function = getattr(self, f"{self._state.value}_wait")
                function(timeout)
        except multiprocessing.TimeoutError:
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            # Pipes of workers that already failed were replaced by `None` and are skipped.
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.send(("close", None))
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.recv()

        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()

    def _poll_pipe_envs(self, timeout: int | float | None = None):
        """Return whether every worker has a reply ready within ``timeout`` seconds in total.

        ``None`` disables the check (always ``True``). A missing (``None``) or closed pipe counts
        as not ready.
        """
        self._assert_is_running()

        if timeout is None:
            return True

        # One shared deadline: each pipe gets only the time remaining, not a fresh `timeout`.
        end_time = time.perf_counter() + timeout
        for pipe in self.parent_pipes:
            delta = max(end_time - time.perf_counter(), 0)

            if pipe is None:
                return False
            if pipe.closed or (not pipe.poll(delta)):
                return False
        return True

    def _check_spaces(self):
        """Ask every worker to compare its spaces with the first sub-environment's spaces.

        Raises:
            RuntimeError: If some sub-environment's observation or action space differs.
        """
        self._assert_is_running()
        spaces = (self.single_observation_space, self.single_action_space)

        for pipe in self.parent_pipes:
            pipe.send(("_check_spaces", spaces))

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        same_observation_spaces, same_action_spaces = zip(*results)

        if not all(same_observation_spaces):
            raise RuntimeError(
                f"Some environments have an observation space different from `{self.single_observation_space}`. "
                "In order to batch observations, the observation spaces from all environments must be equal."
            )
        if not all(same_action_spaces):
            raise RuntimeError(
                f"Some environments have an action space different from `{self.single_action_space}`. "
                "In order to batch actions, the action spaces from all environments must be equal."
            )

    def _assert_is_running(self):
        """Raise :class:`ClosedEnvironmentError` if :meth:`close` has already been called."""
        if self.closed:
            raise ClosedEnvironmentError(
                f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
            )

    def _raise_if_errors(self, successes: list[bool] | tuple[bool]):
        """Re-raise, in the main process, the errors reported by failed workers.

        For every failed worker, the error is read from ``error_queue`` and logged. The worker's
        pipe is then closed and replaced by ``None``, and the last error is re-raised.

        Args:
            successes: One success flag per sub-environment, as returned through the pipes.
        """
        if all(successes):
            return

        num_errors = self.num_envs - sum(successes)
        assert num_errors > 0
        for i in range(num_errors):
            index, exctype, value = self.error_queue.get()

            logger.error(
                f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
            )
            logger.error(f"Shutting down Worker-{index}.")

            self.parent_pipes[index].close()
            self.parent_pipes[index] = None

            if i == num_errors - 1:
                logger.error("Raising the last exception back to the main process.")
                # Every pipe's reply has already been received by the caller, so nothing is
                # pending anymore. Without this reset, `close()` would re-enter `<state>_wait`
                # and call `recv()` on the pipes set to `None` above.
                self._state = AsyncState.DEFAULT
                raise exctype(value)

    def __del__(self):
        """On deleting the object, checks that the vector environment is closed."""
        # `_state` is only set at the end of `__init__`; skip cleanup if construction failed early.
        if not getattr(self, "closed", True) and hasattr(self, "_state"):
            self.close(terminate=True)
def _async_worker (
index : int ,
env_fn : callable ,
pipe : Connection ,
parent_pipe : Connection ,
shared_memory : bool ,
error_queue : Queue ,
) :
2019-06-21 17:29:44 -04:00
env = env_fn ( )
2023-11-07 13:27:25 +00:00
observation_space = env . observation_space
action_space = env . action_space
2023-12-03 19:50:18 +01:00
autoreset = False
2023-11-07 13:27:25 +00:00
2019-06-21 17:29:44 -04:00
parent_pipe . close ( )
2023-11-07 13:27:25 +00:00
2019-06-21 17:29:44 -04:00
try :
while True :
command , data = pipe . recv ( )
2023-11-07 13:27:25 +00:00
2021-07-29 02:26:34 +02:00
if command == " reset " :
2022-08-23 11:09:54 -04:00
observation , info = env . reset ( * * data )
2023-11-07 13:27:25 +00:00
if shared_memory :
write_to_shared_memory (
observation_space , index , observation , shared_memory
)
observation = None
2023-12-03 19:50:18 +01:00
autoreset = False
2022-08-23 11:09:54 -04:00
pipe . send ( ( ( observation , info ) , True ) )
2021-07-29 02:26:34 +02:00
elif command == " step " :
2023-12-03 19:50:18 +01:00
if autoreset :
2022-08-23 11:09:54 -04:00
observation , info = env . reset ( )
2023-12-03 19:50:18 +01:00
reward , terminated , truncated = 0 , False , False
else :
(
observation ,
reward ,
terminated ,
truncated ,
info ,
) = env . step ( data )
autoreset = terminated or truncated
2023-11-07 13:27:25 +00:00
if shared_memory :
write_to_shared_memory (
observation_space , index , observation , shared_memory
)
observation = None
2022-07-10 02:18:06 +05:30
pipe . send ( ( ( observation , reward , terminated , truncated , info ) , True ) )
2021-07-29 02:26:34 +02:00
elif command == " close " :
2019-06-28 17:42:21 -04:00
pipe . send ( ( None , True ) )
2019-06-21 17:29:44 -04:00
break
2022-01-29 12:32:35 -05:00
elif command == " _call " :
name , args , kwargs = data
2023-11-07 13:27:25 +00:00
if name in [ " reset " , " step " , " close " , " set_wrapper_attr " ] :
2022-01-29 12:32:35 -05:00
raise ValueError (
2023-11-07 13:27:25 +00:00
f " Trying to call function ` { name } ` with `call`, use ` { name } ` directly instead. "
2022-01-29 12:32:35 -05:00
)
2023-11-07 13:27:25 +00:00
attr = env . get_wrapper_attr ( name )
if callable ( attr ) :
pipe . send ( ( attr ( * args , * * kwargs ) , True ) )
2022-01-29 12:32:35 -05:00
else :
2023-11-07 13:27:25 +00:00
pipe . send ( ( attr , True ) )
2022-01-29 12:32:35 -05:00
elif command == " _setattr " :
name , value = data
2023-11-07 13:27:25 +00:00
env . set_wrapper_attr ( name , value )
2022-01-29 12:32:35 -05:00
pipe . send ( ( None , True ) )
2021-12-08 21:31:41 -05:00
elif command == " _check_spaces " :
pipe . send (
(
2023-11-07 13:27:25 +00:00
( data [ 0 ] == observation_space , data [ 1 ] == action_space ) ,
2021-12-08 21:31:41 -05:00
True ,
)
)
2019-06-21 17:29:44 -04:00
else :
2021-07-29 02:26:34 +02:00
raise RuntimeError (
2023-11-07 13:27:25 +00:00
f " Received unknown command ` { command } `. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]. "
2021-07-29 02:26:34 +02:00
)
2019-06-21 22:14:29 -04:00
except ( KeyboardInterrupt , Exception ) :
2019-06-21 17:29:44 -04:00
error_queue . put ( ( index , ) + sys . exc_info ( ) [ : 2 ] )
2019-06-28 17:42:21 -04:00
pipe . send ( ( None , False ) )
2019-06-21 17:29:44 -04:00
finally :
env . close ( )