2022-05-20 14:49:30 +01:00
""" Base class for vectorized environments. """
2022-07-10 02:18:06 +05:30
from typing import Any , List , Optional , Tuple , Union
2021-12-08 22:14:15 +01:00
2022-05-24 16:36:35 +02:00
import numpy as np
2019-06-21 17:29:44 -04:00
import gym
Update the flake8 pre-commit ignores (#2778)
* Remove additional ignores from flake8
* Remove all unused imports
* Remove all unused imports
* Update flake8 and pyupgrade
* F841, removed unused variables
* E731, removed lambda assignment to variables
* Remove E731, F403, F405, F524
* Remove E722, bare exceptions
* Remove E712, compare variable == True or == False to is True or is False
* Remove E402, module level import not at top of file
* Added --pre-file-ignores
* Add --per-file-ignores removing E741, E302 and E704
* Add E741, do not use variables named ‘l’, ‘O’, or ‘I’ to ignore issues in classic control
* Fixed issues for pytest==6.2
* Remove unnecessary # noqa
* Edit comment with the removal of E302
* Added warnings and declared module, attr for pyright type hinting
* Remove unused import
* Removed flake8 E302
* Updated flake8 from 3.9.2 to 4.0.1
* Remove unused variable
2022-04-26 16:18:37 +01:00
from gym . logger import deprecation
2019-06-21 17:29:44 -04:00
from gym . vector . utils . spaces import batch_space
2021-07-29 02:26:34 +02:00
# Public API of this module: the vectorized-environment base class and the
# base class for wrappers around vectorized environments.
__all__ = ["VectorEnv", "VectorEnvWrapper"]
2019-06-21 17:29:44 -04:00
class VectorEnv(gym.Env):
    """Base class for vectorized environments. Runs multiple independent copies of the same environment in parallel.

    This is not the same as 1 environment that has multiple subcomponents, but it is many copies of the same base env.

    Each observation returned from vectorized environment is a batch of observations for each parallel environment.
    And :meth:`step` is also expected to receive a batch of actions for each parallel environment.

    Notes:
        All parallel environments should share the identical observation and action spaces.
        In other words, a vector of multiple different environments is not supported.
    """

    def __init__(
        self,
        num_envs: int,
        observation_space: gym.Space,
        action_space: gym.Space,
        new_step_api: bool = False,
    ):
        """Base class for vectorized environments.

        Args:
            num_envs: Number of environments in the vectorized environment.
            observation_space: Observation space of a single environment.
            action_space: Action space of a single environment.
            new_step_api (bool): Whether the vector environment's step method outputs two boolean arrays (new API)
                or one boolean array (old API)
        """
        self.num_envs = num_envs
        self.is_vector_env = True
        # Batched spaces: one entry per sub-environment.
        self.observation_space = batch_space(observation_space, n=num_envs)
        self.action_space = batch_space(action_space, n=num_envs)

        self.closed = False
        self.viewer = None

        # The observation and action spaces of a single environment are
        # kept in separate properties
        self.single_observation_space = observation_space
        self.single_action_space = action_space

        self.new_step_api = new_step_api
        if not self.new_step_api:
            deprecation(
                "Initializing vector env in old step API which returns one bool array instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
            )

    def reset_async(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Reset the sub-environments asynchronously.

        This method will return ``None``. A call to :meth:`reset_async` should be followed
        by a call to :meth:`reset_wait` to retrieve the results.

        The base implementation does nothing; subclasses override it.

        Args:
            seed: The reset seed
            return_info: If to return info
            options: Reset options
        """

    def reset_wait(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Retrieves the results of a :meth:`reset_async` call.

        A call to this method must always be preceded by a call to :meth:`reset_async`.

        Args:
            seed: The reset seed
            return_info: If to return info
            options: Reset options

        Returns:
            The results from :meth:`reset_async`

        Raises:
            NotImplementedError: VectorEnv does not implement function
        """
        raise NotImplementedError("VectorEnv does not implement function")

    def reset(
        self,
        *,
        seed: Optional[Union[int, List[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Reset all parallel environments and return a batch of initial observations.

        Calls :meth:`reset_async` followed by :meth:`reset_wait`, passing the same
        keyword arguments to both.

        Args:
            seed: The environment reset seeds
            return_info: If to return the info
            options: Reset options

        Returns:
            Whatever :meth:`reset_wait` returns: a batch of observations from the vectorized environment.
        """
        self.reset_async(seed=seed, return_info=return_info, options=options)
        return self.reset_wait(seed=seed, return_info=return_info, options=options)

    def step_async(self, actions):
        """Asynchronously performs steps in the sub-environments.

        The results can be retrieved via a call to :meth:`step_wait`.
        The base implementation does nothing; subclasses override it.

        Args:
            actions: The actions to take asynchronously
        """

    def step_wait(self, **kwargs):
        """Retrieves the results of a :meth:`step_async` call.

        A call to this method must always be preceded by a call to :meth:`step_async`.

        Args:
            **kwargs: Additional keywords for vector implementation

        Returns:
            The results from the :meth:`step_async` call

        Raises:
            NotImplementedError: VectorEnv does not implement function
        """
        # Mirror `reset_wait`: fail loudly instead of letting `step` silently
        # return ``None`` when a subclass forgot to implement this method.
        raise NotImplementedError("VectorEnv does not implement function")

    def step(self, actions):
        """Take an action for each parallel environment.

        Calls :meth:`step_async` with ``actions`` and returns the result of :meth:`step_wait`.

        Args:
            actions: element of :attr:`action_space` Batch of actions.

        Returns:
            Batch of (observations, rewards, terminated, truncated, infos) or (observations, rewards, dones, infos)
        """
        self.step_async(actions)
        return self.step_wait()

    def call_async(self, name, *args, **kwargs):
        """Calls a method name for each parallel environment asynchronously.

        The base implementation does nothing; subclasses override it.
        """

    def call_wait(self, **kwargs) -> List[Any]:  # type: ignore
        """After calling a method in :meth:`call_async`, this function collects the results.

        The base implementation does nothing (and so returns ``None``); subclasses override it.
        """

    def call(self, name: str, *args, **kwargs) -> List[Any]:
        """Call a method, or get a property, from each parallel environment.

        Args:
            name (str): Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Returns:
            List of the results of the individual calls to the method or property for each environment.
        """
        self.call_async(name, *args, **kwargs)
        return self.call_wait()

    def get_attr(self, name: str):
        """Get a property from each parallel environment.

        Args:
            name (str): Name of the property to be get from each individual environment.

        Returns:
            The property with name (the result of ``self.call(name)``)
        """
        return self.call(name)

    def set_attr(self, name: str, values: Union[list, tuple, object]):
        """Set a property in each sub-environment.

        The base implementation does nothing; subclasses override it.

        Args:
            name (str): Name of the property to be set in each individual environment.
            values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
                tuple, then it corresponds to the values for each individual environment, otherwise a single value
                is set for all environments.
        """

    def close_extras(self, **kwargs):
        """Clean up the extra resources e.g. beyond what's in this base class."""

    def close(self, **kwargs):
        """Close all parallel environments and release resources.

        It also closes all the existing image viewers, then calls :meth:`close_extras` and set
        :attr:`closed` as ``True``.

        Warnings:
            This function itself does not close the environments, it should be handled
            in :meth:`close_extras`. This is generic for both synchronous and asynchronous
            vectorized environments.

        Notes:
            This will be automatically called when garbage collected or program exited.

        Args:
            **kwargs: Keyword arguments passed to :meth:`close_extras`
        """
        # Closing is idempotent: a second call is a no-op.
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras(**kwargs)
        self.closed = True

    def seed(self, seed=None):
        """Set the random seed in all parallel environments.

        The base implementation only emits a deprecation warning; it does not use ``seed``.

        Args:
            seed: Random seed for each parallel environment. If ``seed`` is a list of
                length ``num_envs``, then the items of the list are chosen as random
                seeds. If ``seed`` is an int, then each parallel environment uses the random
                seed ``seed + n``, where ``n`` is the index of the parallel environment
                (between ``0`` and ``num_envs - 1``).
        """
        deprecation(
            "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
            "Please use `env.reset(seed=seed) instead in VectorEnvs."
        )

    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
        """Add env info to the info dictionary of the vectorized environment.

        Given the `info` of a single environment add it to the `infos` dictionary
        which represents all the infos of the vectorized environment.
        Every `key` of `info` is paired with a boolean mask `_key` representing
        whether or not the i-indexed environment has this `info`.

        Args:
            infos (dict): the infos of the vectorized environment
            info (dict): the info coming from the single environment
            env_num (int): the index of the single environment

        Returns:
            infos (dict): the (updated) infos of the vectorized environment
        """
        for key, value in info.items():
            mask_key = f"_{key}"
            if key not in infos:
                # First time this key is seen: allocate per-env storage typed
                # after the first value encountered.
                info_array, array_mask = self._init_info_arrays(type(value))
            else:
                info_array, array_mask = infos[key], infos[mask_key]

            info_array[env_num], array_mask[env_num] = value, True
            infos[key], infos[mask_key] = info_array, array_mask
        return infos

    def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
        """Initialize the info array.

        Initialize the info array. If the dtype is numeric
        the info array will have the same dtype, otherwise
        will be an array of `None`. Also, a boolean array
        of the same length is returned. It will be used for
        assessing which environment has info data.

        Args:
            dtype (type): data type of the info coming from the env.

        Returns:
            array (np.ndarray): the initialized info array.
            array_mask (np.ndarray): the initialized boolean array.
        """
        if dtype in [int, float, bool] or issubclass(dtype, np.number):
            array = np.zeros(self.num_envs, dtype=dtype)
        else:
            # Non-numeric payloads are stored as objects, defaulting to None.
            array = np.zeros(self.num_envs, dtype=object)
            array[:] = None
        array_mask = np.zeros(self.num_envs, dtype=bool)
        return array, array_mask

    def __del__(self):
        """Closes the vector environment."""
        # `closed` may be missing if __init__ never ran to completion; the
        # default of True makes the finalizer a no-op in that case.
        if not getattr(self, "closed", True):
            self.close()

    def __repr__(self) -> str:
        """Returns a string representation of the vector environment.

        Returns:
            A string containing the class name, number of environments and environment spec id
        """
        if self.spec is None:
            return f"{self.__class__.__name__}({self.num_envs})"
        else:
            return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
2019-11-01 22:29:39 +01:00
class VectorEnvWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    Notes:
        Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
    """

    def __init__(self, env: VectorEnv):
        """Store the vectorized environment to wrap.

        Args:
            env: The vectorized environment to wrap; must be a :class:`VectorEnv` instance.
        """
        assert isinstance(env, VectorEnv)
        self.env = env

    # explicitly forward the methods defined in VectorEnv
    # to self.env (instead of the base class)

    def reset_async(self, **kwargs):
        """Forward to ``self.env.reset_async``."""
        return self.env.reset_async(**kwargs)

    def reset_wait(self, **kwargs):
        """Forward to ``self.env.reset_wait``."""
        return self.env.reset_wait(**kwargs)

    def step_async(self, actions):
        """Forward to ``self.env.step_async``."""
        return self.env.step_async(actions)

    def step_wait(self, **kwargs):
        """Forward to ``self.env.step_wait``.

        Args:
            **kwargs: Additional keywords passed on unchanged, matching the
                signature of :meth:`VectorEnv.step_wait`.
        """
        return self.env.step_wait(**kwargs)

    def close(self, **kwargs):
        """Forward to ``self.env.close``."""
        return self.env.close(**kwargs)

    def close_extras(self, **kwargs):
        """Forward to ``self.env.close_extras``."""
        return self.env.close_extras(**kwargs)

    def seed(self, seed=None):
        """Forward to ``self.env.seed``."""
        return self.env.seed(seed)

    def call(self, name, *args, **kwargs):
        """Forward to ``self.env.call``."""
        return self.env.call(name, *args, **kwargs)

    def set_attr(self, name, values):
        """Forward to ``self.env.set_attr``."""
        return self.env.set_attr(name, values)

    # implicitly forward all other methods and attributes to self.env
    def __getattr__(self, name):
        """Look up attributes not found on the wrapper on the wrapped environment.

        Raises:
            AttributeError: If ``name`` is private (starts with ``_``), or if ``name`` is
                ``env`` itself, i.e. the wrapper has no wrapped environment.
        """
        if name.startswith("_"):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")
        if name == "env":
            # __getattr__ only runs when normal lookup failed, so reaching here
            # means `env` was never set. Evaluating `self.env` below would
            # re-enter this method and recurse until RecursionError.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'env'"
            )
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        """The ``unwrapped`` attribute of the wrapped environment."""
        return self.env.unwrapped

    def __repr__(self):
        """Return ``<ClassName, wrapped-env-repr>``."""
        return f"<{self.__class__.__name__}, {self.env}>"

    def __del__(self):
        """Run the wrapped environment's finalizer, if there is a wrapped environment."""
        # `env` is absent when __init__ did not complete; skip cleanup then.
        env = self.__dict__.get("env")
        if env is not None:
            env.__del__()