2022-12-10 22:04:14 +00:00
""" Functions for registering environments within gymnasium using public functions ``make``, ``register`` and ``spec``. """
from __future__ import annotations
2022-03-31 12:50:38 -07:00
import contextlib
2020-04-24 23:49:41 +02:00
import copy
2023-02-24 11:34:20 +00:00
import dataclasses
2022-01-19 13:50:25 -05:00
import difflib
2019-03-08 14:50:32 -08:00
import importlib
2021-12-22 18:51:33 -05:00
import importlib . util
2023-02-24 11:34:20 +00:00
import json
2022-03-31 12:50:38 -07:00
import re
import sys
2023-01-11 01:12:12 +00:00
import traceback
2022-11-16 12:59:42 +00:00
from collections import defaultdict
2022-04-21 20:41:15 +02:00
from dataclasses import dataclass , field
2023-02-12 07:49:37 -05:00
from typing import Any , Callable , Iterable , Sequence
2022-04-21 20:41:15 +02:00
2023-06-21 17:04:11 +01:00
import gymnasium as gym
2023-02-12 07:49:37 -05:00
from gymnasium import Env , Wrapper , error , logger
2022-04-21 20:41:15 +02:00
2022-12-04 22:24:02 +08:00
2021-12-22 13:54:20 -05:00
if sys . version_info < ( 3 , 10 ) :
2021-12-22 19:12:57 +01:00
import importlib_metadata as metadata # type: ignore
2021-09-14 20:14:05 -06:00
else :
import importlib . metadata as metadata
2023-07-03 17:28:18 +03:00
from typing import Protocol
2022-04-21 20:41:15 +02:00
2022-12-04 22:24:02 +08:00
2022-05-25 15:28:19 +01:00
# Pattern for environment ids of the form ``[namespace/](env-name)[-v(version)]``.
# The ``namespace`` and ``version`` groups are optional; ``name`` is matched lazily so that
# a trailing ``-v<digits>`` suffix is captured by the ``version`` group instead of the name.
ENV_ID_RE = re.compile(
    r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
)
2021-09-14 20:14:05 -06:00
2019-03-08 14:50:32 -08:00
2023-02-05 00:05:59 +00:00
# Names exported by ``from <this module> import *``.
__all__ = [
    # Module-level state
    "registry",
    "current_namespace",
    # Specification classes
    "EnvSpec",
    "WrapperSpec",
    # Functions
    "register",
    "make",
    "make_vec",
    "spec",
    "pprint_registry",
]
2022-06-06 16:21:45 +01:00
2023-02-05 00:05:59 +00:00
class EnvCreator(Protocol):
    """Structural type for a callable that builds an environment from keyword arguments."""

    def __call__(self, **kwargs: Any) -> Env:
        ...
2023-02-12 07:49:37 -05:00
class VectorEnvCreator(Protocol):
    """Structural type for a callable that builds a vector environment from keyword arguments."""

    def __call__(self, **kwargs: Any) -> gym.experimental.vector.VectorEnv:
        ...
2023-02-24 11:34:20 +00:00
@dataclass
class WrapperSpec:
    """Record of a single wrapper and the configuration it was created with.

    * name: The name of the wrapper.
    * entry_point: The location of the wrapper to create from.
    * kwargs: Additional keyword arguments passed to the wrapper.
      If the wrapper doesn't inherit from EzPickle then this is ``None``.
    """

    name: str
    entry_point: str
    kwargs: dict[str, Any] | None
2023-02-05 00:05:59 +00:00
@dataclass
class EnvSpec:
    """A specification for creating environments with :meth:`gymnasium.make`.

    * **id**: The string used to create the environment with :meth:`gymnasium.make`
    * **entry_point**: A string for the environment location, ``(import path):(environment name)`` or a function that creates the environment.
    * **reward_threshold**: The reward threshold for completing the environment.
    * **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
    * **max_episode_steps**: The max number of steps that the environment can take before truncation
    * **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
    * **autoreset**: If to automatically reset the environment on episode end
    * **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
    * **apply_api_compatibility**: If to apply the step API compatibility wrapper to the environment, by default False
    * **kwargs**: Additional keyword arguments passed to the environment during initialisation
    * **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec)
    * **vector_entry_point**: The location of the vectorized environment to create from

    The ``namespace``, ``name`` and ``version`` attributes are not constructor arguments;
    they are derived from ``id`` in :meth:`__post_init__`.
    """

    id: str
    entry_point: EnvCreator | str | None = field(default=None)

    # Environment attributes
    reward_threshold: float | None = field(default=None)
    nondeterministic: bool = field(default=False)

    # Wrappers
    max_episode_steps: int | None = field(default=None)
    order_enforce: bool = field(default=True)
    autoreset: bool = field(default=False)
    disable_env_checker: bool = field(default=False)
    apply_api_compatibility: bool = field(default=False)

    # Environment arguments
    kwargs: dict = field(default_factory=dict)

    # post-init attributes
    namespace: str | None = field(init=False)
    name: str = field(init=False)
    version: int | None = field(init=False)

    # applied wrappers
    additional_wrappers: tuple[WrapperSpec, ...] = field(default_factory=tuple)

    # Vectorized environment entry point
    vector_entry_point: VectorEnvCreator | str | None = field(default=None)

    def __post_init__(self):
        """Calls after the spec is created to extract the namespace, name and version from the environment id."""
        self.namespace, self.name, self.version = parse_env_id(self.id)

    def make(self, **kwargs: Any) -> Env:
        """Calls ``make`` using the environment spec and any keyword arguments."""
        return make(self, **kwargs)

    def to_json(self) -> str:
        """Converts the environment spec into a json compatible string.

        Returns:
            A JSON string for the environment spec

        Raises:
            ValueError: If any top-level attribute of the spec is a callable
        """
        env_spec_dict = dataclasses.asdict(self)
        # As the namespace, name and version are initialised after `init` then we remove the attributes
        env_spec_dict.pop("namespace")
        env_spec_dict.pop("name")
        env_spec_dict.pop("version")

        # To check that the environment spec can be transformed to a json compatible type
        self._check_can_jsonify(env_spec_dict)

        return json.dumps(env_spec_dict)

    @staticmethod
    def _check_can_jsonify(env_spec: dict[str, Any]):
        """Raises an error if the spec contains a callable, as callables cannot be serialised to JSON.

        Only the top-level values of ``env_spec`` are inspected.

        Args:
            env_spec: An environment or wrapper specification as a dictionary.

        Raises:
            ValueError: If any value in ``env_spec`` is callable
        """
        # `to_json` removes "name" before calling this, so fall back to the "id" entry
        spec_name = env_spec["name"] if "name" in env_spec else env_spec["id"]

        for key, value in env_spec.items():
            if callable(value):
                # The exception was previously constructed but never raised, making this check a no-op
                raise ValueError(
                    f"Callable found in {spec_name} for {key} attribute with value={value}. Currently, Gymnasium does not support serialising callables."
                )

    @staticmethod
    def from_json(json_env_spec: str) -> EnvSpec:
        """Converts a JSON string into a specification stack.

        Args:
            json_env_spec: A JSON string representing the env specification.

        Returns:
            An environment spec

        Raises:
            ValueError: If a wrapper entry or the remaining data cannot be used to build the specs
        """
        parsed_env_spec = json.loads(json_env_spec)

        applied_wrapper_specs: list[WrapperSpec] = []
        for wrapper_spec_json in parsed_env_spec.pop("additional_wrappers"):
            try:
                applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
            except Exception as e:
                raise ValueError(
                    f"An issue occurred when trying to make {wrapper_spec_json} a WrapperSpec"
                ) from e

        try:
            env_spec = EnvSpec(**parsed_env_spec)
            env_spec.additional_wrappers = tuple(applied_wrapper_specs)
        except Exception as e:
            raise ValueError(
                f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
            ) from e

        return env_spec

    def pprint(
        self,
        disable_print: bool = False,
        include_entry_points: bool = False,
        print_all: bool = False,
    ) -> str | None:
        """Pretty prints the environment spec.

        Args:
            disable_print: If to disable print and return the output
            include_entry_points: If to include the entry_points in the output
            print_all: If to print all information, including variables with default values

        Returns:
            If ``disable_print is True`` a string otherwise ``None``
        """
        output = f"id={self.id}"
        if print_all or include_entry_points:
            output += f"\nentry_point={self.entry_point}"

        # Attributes are only shown when they differ from their default, unless `print_all` is set
        if print_all or self.reward_threshold is not None:
            output += f"\nreward_threshold={self.reward_threshold}"
        if print_all or self.nondeterministic is not False:
            output += f"\nnondeterministic={self.nondeterministic}"

        if print_all or self.max_episode_steps is not None:
            output += f"\nmax_episode_steps={self.max_episode_steps}"
        if print_all or self.order_enforce is not True:
            output += f"\norder_enforce={self.order_enforce}"
        if print_all or self.autoreset is not False:
            output += f"\nautoreset={self.autoreset}"
        if print_all or self.disable_env_checker is not False:
            output += f"\ndisable_env_checker={self.disable_env_checker}"
        if print_all or self.apply_api_compatibility is not False:
            output += f"\napplied_api_compatibility={self.apply_api_compatibility}"

        if print_all or self.additional_wrappers:
            wrapper_output: list[str] = []
            for wrapper_spec in self.additional_wrappers:
                if include_entry_points:
                    wrapper_output.append(
                        f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
                    )
                else:
                    wrapper_output.append(
                        f"\n\tname={wrapper_spec.name}, kwargs={wrapper_spec.kwargs}"
                    )

            if len(wrapper_output) == 0:
                output += "\nadditional_wrappers=[]"
            else:
                output += f"\nadditional_wrappers=[{','.join(wrapper_output)}\n]"

        if disable_print:
            return output
        else:
            print(output)
2023-02-05 00:05:59 +00:00
# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
# Namespace that `register` applies to new environment ids when it is not ``None``;
# it is swapped in and out by the `namespace` context manager.
current_namespace: str | None = None
def parse_env_id(env_id: str) -> tuple[str | None, str, int | None]:
    """Split an environment id of the form ``[namespace/](env-name)[-v(version)]`` into its parts.

    The namespace and the version are both optional.

    Args:
        env_id: The environment id to parse

    Returns:
        A tuple of environment namespace, environment name and version number

    Raises:
        Error: If the environment id does not match the environment id regex
    """
    parsed = ENV_ID_RE.fullmatch(env_id)
    if parsed is None:
        raise error.Error(
            f"Malformed environment ID: {env_id}. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
        )

    raw_version = parsed.group("version")
    return (
        parsed.group("namespace"),
        parsed.group("name"),
        # The regex captures the version as a digit string; callers expect an int
        None if raw_version is None else int(raw_version),
    )
2022-01-19 13:50:25 -05:00
2022-12-10 22:04:14 +00:00
def get_env_id(ns: str | None, name: str, version: int | None) -> str:
    """Build the full environment id from its parts. Inverse of :meth:`parse_env_id`.

    Args:
        ns: The environment namespace, or ``None`` for no namespace prefix
        name: The environment name
        version: The environment version, or ``None`` for no ``-v`` suffix

    Returns:
        The environment id
    """
    # Nothing to attach: the id is just the name
    if ns is None and version is None:
        return name

    namespace_prefix = "" if ns is None else f"{ns}/"
    version_suffix = "" if version is None else f"-v{version}"
    return f"{namespace_prefix}{name}{version_suffix}"
2022-07-11 02:45:24 +01:00
2022-04-21 20:41:15 +02:00
2023-02-05 00:05:59 +00:00
def find_highest_version(ns: str | None, name: str) -> int | None:
    """Look up the largest registered version number for an environment.

    Args:
        ns: The environment namespace
        name: The environment name (id)

    Returns:
        The highest version of an environment with matching namespace and name, otherwise ``None`` is returned.
    """
    highest: int | None = None
    for candidate in registry.values():
        if not (candidate.namespace == ns and candidate.name == name):
            continue
        # Unversioned specs do not take part in the comparison
        if candidate.version is None:
            continue
        if highest is None or candidate.version > highest:
            highest = candidate.version
    return highest
2016-04-27 08:00:58 -07:00
2022-04-08 09:54:49 -05:00
2022-12-10 22:04:14 +00:00
def _check_namespace_exists ( ns : str | None ) :
2022-04-21 20:41:15 +02:00
""" Check if a namespace exists. If it doesn ' t, print a helpful error message. """
2023-02-05 00:05:59 +00:00
# If the namespace is none, then the namespace does exist
2022-04-21 20:41:15 +02:00
if ns is None :
return
2023-02-05 00:05:59 +00:00
# Check if the namespace exists in one of the registry's specs
namespaces : set [ str ] = {
env_spec . namespace
for env_spec in registry . values ( )
if env_spec . namespace is not None
2022-04-21 20:41:15 +02:00
}
if ns in namespaces :
return
2022-01-13 15:59:55 -05:00
2023-02-05 00:05:59 +00:00
# Otherwise, the namespace doesn't exist and raise a helpful message
2022-04-21 20:41:15 +02:00
suggestion = (
difflib . get_close_matches ( ns , namespaces , n = 1 ) if len ( namespaces ) > 0 else None
)
2023-02-05 00:05:59 +00:00
if suggestion :
suggestion_msg = f " Did you mean: ` { suggestion [ 0 ] } `? "
else :
suggestion_msg = f " Have you installed the proper package for { ns } ? "
2022-04-08 09:54:49 -05:00
2022-04-21 20:41:15 +02:00
raise error . NamespaceNotFound ( f " Namespace { ns } not found. { suggestion_msg } " )
2017-02-01 13:10:59 -08:00
2022-04-08 09:54:49 -05:00
2022-12-10 22:04:14 +00:00
def _check_name_exists(ns: str | None, name: str):
    """Verify that an environment called ``name`` is registered under namespace ``ns``.

    Args:
        ns: The environment namespace
        name: The environment name

    Raises:
        NameNotFound: If no spec in the namespace has this name. The message
            carries the closest registered name when one is found.
    """
    # The namespace itself has to be valid first
    _check_namespace_exists(ns)

    # Names registered under this namespace
    registered_names: set[str] = set()
    for env_spec in registry.values():
        if env_spec.namespace == ns:
            registered_names.add(env_spec.name)

    if name in registered_names:
        return

    # Unknown name: report it together with the closest registered name, if any
    closest = difflib.get_close_matches(name, registered_names, n=1)
    namespace_msg = f" in namespace {ns}" if ns else ""
    suggestion_msg = f" Did you mean: `{closest[0]}`?" if closest else ""

    raise error.NameNotFound(
        f"Environment `{name}` doesn't exist{namespace_msg}.{suggestion_msg}"
    )
2021-11-20 10:43:36 -05:00
2022-12-10 22:04:14 +00:00
def _check_version_exists(ns: str | None, name: str, version: int | None):
    """Check if an env version exists in a namespace. If it doesn't, raise a helpful error message.

    This is a complete test whether an environment identifier is valid, and will provide the best available hints.
    The function returns without raising when the exact id is registered, or when ``version`` is ``None``
    and the name check passes.

    Args:
        ns: The environment namespace
        name: The environment name
        version: The environment version

    Raises:
        DeprecatedEnv: Only an unversioned (default) spec is registered for the name,
            or the requested version is lower than the latest registered version
        VersionNotFound: The requested version is higher than the latest registered version
    """
    # Exact id is registered: nothing to report
    if get_env_id(ns, name, version) in registry:
        return

    # Validates the namespace and the name before looking at versions
    _check_name_exists(ns, name)
    if version is None:
        return

    message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."

    env_specs = [
        env_spec
        for env_spec in registry.values()
        if env_spec.namespace == ns and env_spec.name == name
    ]
    # Unversioned specs are given the sort key -1
    env_specs = sorted(env_specs, key=lambda env_spec: int(env_spec.version or -1))

    default_spec = [env_spec for env_spec in env_specs if env_spec.version is None]

    if default_spec:
        message += f" It provides the default version `{default_spec[0].id}`."
        # The unversioned spec is the only one registered for this name
        if len(env_specs) == 1:
            raise error.DeprecatedEnv(message)

    # Process possible versioned environments

    versioned_specs = [
        env_spec for env_spec in env_specs if env_spec.version is not None
    ]

    latest_spec = max(versioned_specs, key=lambda env_spec: env_spec.version, default=None)  # type: ignore
    if latest_spec is not None and version > latest_spec.version:
        version_list_msg = ", ".join(f"`v{env_spec.version}`" for env_spec in env_specs)
        message += f" It provides versioned environments: [ {version_list_msg} ]."

        raise error.VersionNotFound(message)

    if latest_spec is not None and version < latest_spec.version:
        raise error.DeprecatedEnv(
            f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. "
            f"Please use `{latest_spec.id}` instead."
        )
2016-04-27 08:00:58 -07:00
2022-01-19 13:50:25 -05:00
2023-02-05 00:05:59 +00:00
def _check_spec_register(testing_spec: EnvSpec):
    """Reject specs that would mix versioned and unversioned ids for one name. Helper function for `register`.

    Args:
        testing_spec: The spec that is about to be registered

    Raises:
        RegistrationError: If a versioned spec is registered while an unversioned spec of the same
            namespace and name exists, or the other way round
    """
    latest_versioned_spec = None
    unversioned_spec = None

    # Single pass over the registry, restricted to specs sharing the namespace and name
    for env_spec in registry.values():
        if not (
            env_spec.namespace == testing_spec.namespace
            and env_spec.name == testing_spec.name
        ):
            continue

        if env_spec.version is None:
            # Keep the first unversioned spec encountered
            if unversioned_spec is None:
                unversioned_spec = env_spec
        elif latest_versioned_spec is None or int(env_spec.version) > int(
            latest_versioned_spec.version
        ):
            latest_versioned_spec = env_spec

    if unversioned_spec is not None and testing_spec.version is not None:
        raise error.RegistrationError(
            "Can't register the versioned environment "
            f"`{testing_spec.id}` when the unversioned environment "
            f"`{unversioned_spec.id}` of the same name already exists."
        )

    if latest_versioned_spec is not None and testing_spec.version is None:
        raise error.RegistrationError(
            f"Can't register the unversioned environment `{testing_spec.id}` when the versioned environment "
            f"`{latest_versioned_spec.id}` of the same name already exists. Note: the default behavior is "
            "that `gym.make` with the unversioned environment will return the latest versioned environment"
        )
def _check_metadata ( testing_metadata : dict [ str , Any ] ) :
""" Check the metadata of an environment. """
if not isinstance ( testing_metadata , dict ) :
raise error . InvalidMetadata (
f " Expect the environment metadata to be dict, actual type: { type ( metadata ) } "
)
render_modes = testing_metadata . get ( " render_modes " )
if render_modes is None :
logger . warn (
f " The environment creator metadata doesn ' t include `render_modes`, contains: { list ( testing_metadata . keys ( ) ) } "
)
elif not isinstance ( render_modes , Iterable ) :
logger . warn (
f " Expects the environment metadata render_modes to be a Iterable, actual type: { type ( render_modes ) } "
)
2023-02-24 11:34:20 +00:00
def _find_spec(env_id: str) -> EnvSpec:
    """Find the registered :class:`EnvSpec` for an environment id string.

    Args:
        env_id: The environment id, optionally prefixed with a module to import as ``"module:env_id"``

    Returns:
        The environment spec found in the registry

    Raises:
        ModuleNotFoundError: If the module given in ``"module:env_id"`` cannot be imported
        Error: If no spec is registered for the id and the version check does not raise first
    """
    # For string id's, load the environment spec from the registry then make the environment spec
    assert isinstance(env_id, str)

    # The environment name can include an unloaded module in "module:env_name" style
    module, env_name = (None, env_id) if ":" not in env_id else env_id.split(":")
    if module is not None:
        try:
            importlib.import_module(module)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"{e}. Environment registration via importing a module failed. "
                f"Check whether '{module}' contains env registration and can be imported."
            ) from e

    # load the env spec from the registry
    env_spec = registry.get(env_name)

    # update the env spec if no version is provided, raise a warning if out of date
    ns, name, version = parse_env_id(env_name)

    latest_version = find_highest_version(ns, name)
    if version is not None and latest_version is not None and latest_version > version:
        logger.deprecation(
            f"The environment {env_name} is out of date. You should consider "
            f"upgrading to version `v{latest_version}`."
        )
    if version is None and latest_version is not None:
        # An unversioned id resolves to the highest registered version
        version = latest_version
        new_env_id = get_env_id(ns, name, version)
        env_spec = registry.get(new_env_id)
        logger.warn(
            f"Using the latest versioned environment `{new_env_id}` "
            f"instead of the unversioned environment `{env_name}`."
        )

    if env_spec is None:
        # Gives the version check a chance to raise a more specific error first
        _check_version_exists(ns, name, version)
        raise error.Error(f"No registered env with id: {env_name}")

    return env_spec
2023-02-24 11:34:20 +00:00
def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
    """Import and return the object referenced by a ``"(import path):(environment name)"`` string.

    The returned object is the environment creation function, normally the environment class type.

    Args:
        name: The environment name

    Returns:
        The environment constructor for the given environment name.
    """
    module_path, attribute = name.split(":")
    return getattr(importlib.import_module(module_path), attribute)
2022-01-19 13:50:25 -05:00
2021-11-20 10:43:36 -05:00
2023-02-05 00:05:59 +00:00
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
    """Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.

    Each entry point is loaded and called inside a namespace context named after the entry point,
    except for the ``__root__`` and ``__internal__`` magic keys which use no namespace.

    Args:
        entry_point: The string for the entry point.

    Raises:
        Error: If an entry point does not name a function to execute (only a root module)
    """
    # Load third-party environments
    for plugin in metadata.entry_points(group=entry_point):
        # Python 3.8 doesn't support plugin.module, plugin.attr
        # So we'll have to try and parse this ourselves
        module, attr = None, None
        try:
            module, attr = plugin.module, plugin.attr  # type: ignore  ## error: Cannot access member "attr" for type "EntryPoint"
        except AttributeError:
            if ":" in plugin.value:
                module, attr = plugin.value.split(":", maxsplit=1)
            else:
                module, attr = plugin.value, None
        except Exception as e:
            logger.warn(
                f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
            )
            module, attr = None, None
        finally:
            # Runs on every path above: an entry point without a callable attribute is rejected
            if attr is None:
                raise error.Error(
                    f"Gymnasium environment plugin `{module}` must specify a function to execute, not a root module"
                )

        context = namespace(plugin.name)
        if plugin.name.startswith("__") and plugin.name.endswith("__"):
            # `__internal__` is an artifact of the plugin system when the root namespace had an allow-list.
            # The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key.
            if plugin.name == "__root__" or plugin.name == "__internal__":
                context = contextlib.nullcontext()
            else:
                # Unsupported magic keys keep the namespace context created above
                logger.warn(
                    f"The environment namespace magic key `{plugin.name}` is unsupported. "
                    "To register an environment at the root namespace you should specify the `__root__` namespace."
                )

        with context:
            fn = plugin.load()
            try:
                fn()
            except Exception:
                # A failing plugin function is reported as a warning and does not stop the loop
                logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
2016-08-20 16:05:50 -07:00
2021-07-29 02:26:34 +02:00
2021-09-16 08:23:32 -06:00
@contextlib.contextmanager
def namespace(ns: str):
    """Context manager for modifying the current namespace.

    The previous value of ``current_namespace`` is restored when the block exits,
    including when the block raises an exception.

    Args:
        ns: The namespace to use as ``current_namespace`` inside the ``with`` block
    """
    global current_namespace
    old_namespace = current_namespace
    current_namespace = ns
    try:
        yield
    finally:
        # Without the `finally`, an exception raised in the `with` body would leave the
        # temporary namespace in place for all later `register` calls
        current_namespace = old_namespace
2021-09-14 20:14:05 -06:00
2022-09-01 16:02:31 +01:00
def register(
    id: str,
    entry_point: EnvCreator | str | None = None,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    additional_wrappers: tuple[WrapperSpec, ...] = (),
    vector_entry_point: VectorEnvCreator | str | None = None,
    **kwargs: Any,
):
    """Registers an environment in gymnasium with an ``id`` to use with :meth:`gymnasium.make` with the ``entry_point`` being a string or callable for creating the environment.

    The ``id`` parameter corresponds to the name of the environment, with the syntax as follows:
    ``[namespace/](env_name)[-v(version)]`` where ``namespace`` and ``-v(version)`` is optional.

    If ``current_namespace`` is set (see the ``namespace`` context manager), it replaces any namespace given in ``id``.
    Registering an id that is already in the registry overrides the existing spec with a warning.

    Args:
        id: The environment id
        entry_point: The entry point for creating the environment
        reward_threshold: The reward threshold considered for an agent to have learnt the environment
        nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions, the same state cannot be reached)
        max_episode_steps: The maximum number of episode steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
        order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
            If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
        autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called.
        disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
        apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
            Use if the environment is implemented in the gym v0.21 environment API.
        additional_wrappers: Additional wrappers to apply to the environment.
        vector_entry_point: The entry point for creating the vector environment
        **kwargs: Additional keyword arguments, forwarded as keyword arguments to the :class:`EnvSpec` constructor.

    Raises:
        AssertionError: If neither ``entry_point`` nor ``vector_entry_point`` is provided
    """
    assert (
        entry_point is not None or vector_entry_point is not None
    ), "Either `entry_point` or `vector_entry_point` (or both) must be provided"
    global registry, current_namespace
    ns, name, version = parse_env_id(id)

    if current_namespace is not None:
        if (
            kwargs.get("namespace") is not None
            and kwargs.get("namespace") != current_namespace
        ):
            logger.warn(
                f"Custom namespace `{kwargs.get('namespace')}` is being overridden by namespace `{current_namespace}`. "
                f"If you are developing a plugin you shouldn't specify a namespace in `register` calls. "
                "The namespace is specified through the entry point package metadata."
            )
        # The active namespace takes precedence over the one parsed from `id`
        ns_id = current_namespace
    else:
        ns_id = ns
    full_env_id = get_env_id(ns_id, name, version)

    new_spec = EnvSpec(
        id=full_env_id,
        entry_point=entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        autoreset=autoreset,
        disable_env_checker=disable_env_checker,
        apply_api_compatibility=apply_api_compatibility,
        **kwargs,
        additional_wrappers=additional_wrappers,
        vector_entry_point=vector_entry_point,
    )
    # Rejects mixing versioned and unversioned specs of the same name
    _check_spec_register(new_spec)

    if new_spec.id in registry:
        logger.warn(f"Overriding environment {new_spec.id} already in registry.")
    registry[new_spec.id] = new_spec
2022-04-21 20:41:15 +02:00
def make(
    id: str | EnvSpec,
    max_episode_steps: int | None = None,
    autoreset: bool | None = None,
    apply_api_compatibility: bool | None = None,
    disable_env_checker: bool | None = None,
    **kwargs: Any,
) -> Env:
    """Creates an environment previously registered with :meth:`gymnasium.register` or a :class:`EnvSpec`.

    To find all available environments use ``gymnasium.envs.registry.keys()`` for all valid ids.

    Args:
        id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
            This is equivalent to importing the module first to register the environment followed by making the environment.
        max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``.
            The value is used by :class:`gymnasium.wrappers.TimeLimit`.
        autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`).
            ``None`` defers to the :class:`EnvSpec` ``autoreset`` value.
        apply_api_compatibility: Whether to wrap the environment with :class:`gymnasium.wrappers.EnvCompatibility`.
            ``None`` defers to the :class:`EnvSpec` ``apply_api_compatibility`` value. When the wrapper is applied,
            ``render_mode`` is removed from the constructor kwargs and handed to the wrapper instead.
        disable_env_checker: Whether to skip adding :class:`gymnasium.wrappers.PassiveEnvChecker`.
            ``None`` defers to the :class:`EnvSpec` ``disable_env_checker`` value.
        kwargs: Additional arguments to pass to the environment constructor; they override the spec's own kwargs.

    Returns:
        An instance of the environment with wrappers applied.

    Raises:
        Error: If the ``id`` doesn't exist in the :attr:`registry`, if the spec has no ``entry_point``, or if
            ``render_mode="human"`` required the ``HumanRendering`` wrapper but the constructor rejected ``render_mode``.
        TypeError: Any other ``TypeError`` raised by the environment constructor, re-raised with the env id and kwargs.
        ValueError: If the created environment's wrappers disagree with the spec's ``additional_wrappers``, or a
            wrapper listed in the spec has no recorded kwargs and therefore cannot be recreated.
    """
    if isinstance(id, EnvSpec):
        env_spec = id
        if not hasattr(env_spec, "additional_wrappers"):
            logger.warn(
                f"The env spec passed to `make` does not have a `additional_wrappers`, set it to an empty tuple. Env_spec={env_spec}"
            )
            env_spec.additional_wrappers = ()
    else:
        # For string id's, load the environment spec from the registry then make the environment spec
        assert isinstance(id, str)

        # The environment name can include an unloaded module in "module:env_name" style
        env_spec = _find_spec(id)

    assert isinstance(env_spec, EnvSpec)

    # Update the env spec kwargs with the `make` kwargs (deep copy so the registered spec is never mutated)
    env_spec_kwargs = copy.deepcopy(env_spec.kwargs)
    env_spec_kwargs.update(kwargs)

    # Load the environment creator
    if env_spec.entry_point is None:
        raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
    elif callable(env_spec.entry_point):
        env_creator = env_spec.entry_point
    else:
        # Assume it's a string
        env_creator = load_env_creator(env_spec.entry_point)

    # Determine if to use the rendering
    render_modes: list[str] | None = None
    if hasattr(env_creator, "metadata"):
        _check_metadata(env_creator.metadata)
        render_modes = env_creator.metadata.get("render_modes")
    render_mode = env_spec_kwargs.get("render_mode")
    apply_human_rendering = False
    apply_render_collection = False

    # If mode is not valid, try applying HumanRendering/RenderCollection wrappers
    if (
        render_mode is not None
        and render_modes is not None
        and render_mode not in render_modes
    ):
        displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
        if render_mode == "human" and len(displayable_modes) > 0:
            logger.warn(
                "You are trying to use 'human' rendering for an environment that doesn't natively support it. "
                "The HumanRendering wrapper is being applied to your environment."
            )
            # Pick deterministically (set.pop() order is arbitrary): prefer "rgb_array" when available
            env_spec_kwargs["render_mode"] = (
                "rgb_array" if "rgb_array" in displayable_modes else "rgb_array_list"
            )
            apply_human_rendering = True
        elif (
            render_mode.endswith("_list")
            and render_mode[: -len("_list")] in render_modes
        ):
            # Construct with the base mode and collect frames through the RenderCollection wrapper
            env_spec_kwargs["render_mode"] = render_mode[: -len("_list")]
            apply_render_collection = True
        else:
            logger.warn(
                f"The environment is being initialised with render_mode={render_mode!r} "
                f"that is not in the possible render_modes ({render_modes})."
            )

    if apply_api_compatibility or (
        apply_api_compatibility is None and env_spec.apply_api_compatibility
    ):
        # If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
        render_mode = env_spec_kwargs.pop("render_mode", None)
    else:
        render_mode = None

    try:
        env = env_creator(**env_spec_kwargs)
    except TypeError as e:
        if (
            "got an unexpected keyword argument 'render_mode'" in str(e)
            and apply_human_rendering
        ):
            raise error.Error(
                f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
                "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
                "rendering API, which is not supported by the HumanRendering wrapper."
            ) from e
        else:
            # Chain explicitly so the traceback marks the constructor failure as the direct cause
            raise type(e)(
                f"{e} was raised from the environment creator for {env_spec.id} with kwargs ({env_spec_kwargs})"
            ) from e

    # Set the minimal env spec for the environment.
    env.unwrapped.spec = EnvSpec(
        id=env_spec.id,
        entry_point=env_spec.entry_point,
        reward_threshold=env_spec.reward_threshold,
        nondeterministic=env_spec.nondeterministic,
        max_episode_steps=None,
        order_enforce=False,
        autoreset=False,
        disable_env_checker=True,
        apply_api_compatibility=False,
        kwargs=env_spec_kwargs,
        additional_wrappers=(),
        vector_entry_point=env_spec.vector_entry_point,
    )

    # Check if pre-wrapped wrappers
    assert env.spec is not None
    saved_wrappers = env_spec.additional_wrappers
    prior_wrappers = env.spec.additional_wrappers
    num_prior_wrappers = len(prior_wrappers)
    if saved_wrappers[:num_prior_wrappers] != prior_wrappers:
        wrapper_pairs = list(zip(saved_wrappers, prior_wrappers))
        # When the saved spec lists no wrappers there is nothing to compare against and no error is raised.
        if wrapper_pairs:
            # Report the first pair that actually differs; fall back to the first pair when only the lengths differ
            env_spec_wrapper_spec, recreated_wrapper_spec = next(
                (pair for pair in wrapper_pairs if pair[0] != pair[1]),
                wrapper_pairs[0],
            )
            raise ValueError(
                f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}"
            )

    # Add step API wrapper
    if apply_api_compatibility is True or (
        apply_api_compatibility is None and env_spec.apply_api_compatibility is True
    ):
        env = gym.wrappers.EnvCompatibility(env, render_mode)

    # Run the environment checker as the lowest level wrapper
    if disable_env_checker is False or (
        disable_env_checker is None and env_spec.disable_env_checker is False
    ):
        env = gym.wrappers.PassiveEnvChecker(env)

    # Add the order enforcing wrapper
    if env_spec.order_enforce:
        env = gym.wrappers.OrderEnforcing(env)

    # Add the time limit wrapper (the `make` argument takes precedence over the spec value)
    if max_episode_steps is not None:
        env = gym.wrappers.TimeLimit(env, max_episode_steps)
    elif env_spec.max_episode_steps is not None:
        env = gym.wrappers.TimeLimit(env, env_spec.max_episode_steps)

    # Add the auto-reset wrapper
    if autoreset is True or (autoreset is None and env_spec.autoreset is True):
        env = gym.wrappers.AutoResetWrapper(env)

    # Recreate only the wrappers that the environment creator has not already applied
    for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]:
        if wrapper_spec.kwargs is None:
            raise ValueError(
                f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
            )

        env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)

    # Add human rendering wrapper
    if apply_human_rendering:
        env = gym.wrappers.HumanRendering(env)
    elif apply_render_collection:
        env = gym.wrappers.RenderCollection(env)

    return env
2022-04-21 20:41:15 +02:00
2023-02-12 07:49:37 -05:00
def make_vec(
    id: str | EnvSpec,
    num_envs: int = 1,
    vectorization_mode: str = "async",
    vector_kwargs: dict[str, Any] | None = None,
    wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
    **kwargs,
) -> gym.experimental.vector.VectorEnv:
    """Create a vector environment according to the given ID.

    Note:
        This feature is experimental, and is likely to change in future releases.

    To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids.

    Args:
        id: Name of the environment or an :class:`EnvSpec`. Optionally, a module to import can be included, eg. 'module:Env-v0'
        num_envs: Number of environments to create
        vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom"
        vector_kwargs: Additional arguments to pass to the vectorized environment constructor.
            The dictionary is copied, the caller's object is never modified.
        wrappers: A sequence of wrapper functions to apply to the environment. Can only be used in "sync" or "async" mode.
        **kwargs: Additional arguments to pass to the environment constructor. In "custom" mode these are only
            recorded in the returned environment's spec; the vector entry point receives ``num_envs``,
            ``vector_kwargs`` and ``max_episode_steps``.

    Returns:
        An instance of the environment.

    Raises:
        Error: If the ``id`` doesn't exist then an error is raised. Also raised for an unknown
            ``vectorization_mode``, a missing (vector) entry point, or ``wrappers`` given in "custom" mode.
    """
    # Private copy: "custom" mode injects `max_episode_steps`, which must not leak into the caller's dict
    vector_kwargs = {} if vector_kwargs is None else dict(vector_kwargs)
    if wrappers is None:
        wrappers = []

    spec_ = id if isinstance(id, EnvSpec) else _find_spec(id)

    # Spec kwargs overridden by the kwargs given to `make_vec`
    _kwargs = spec_.kwargs.copy()
    _kwargs.update(kwargs)

    # Check if we have the necessary entry point
    if vectorization_mode in ("sync", "async"):
        if spec_.entry_point is None:
            raise error.Error(
                f"Cannot create vectorized environment for {id} because it doesn't have an entry point defined."
            )
        entry_point = spec_.entry_point
    elif vectorization_mode in ("custom",):
        if spec_.vector_entry_point is None:
            raise error.Error(
                f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined."
            )
        entry_point = spec_.vector_entry_point
    else:
        raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")

    if callable(entry_point):
        env_creator = entry_point
    else:
        # Assume it's a string
        env_creator = load_env_creator(entry_point)

    def _create_env():
        # Env creator for use with sync and async modes
        _kwargs_copy = _kwargs.copy()

        render_mode = _kwargs.get("render_mode", None)
        if render_mode is not None:
            # A "<mode>_list" request is built with the base mode and wrapped in RenderCollection below
            inner_render_mode = (
                render_mode[: -len("_list")]
                if render_mode.endswith("_list")
                else render_mode
            )
            _kwargs_copy["render_mode"] = inner_render_mode

        _env = env_creator(**_kwargs_copy)
        _env.spec = spec_
        if spec_.max_episode_steps is not None:
            _env = gym.wrappers.TimeLimit(_env, spec_.max_episode_steps)

        if render_mode is not None and render_mode.endswith("_list"):
            _env = gym.wrappers.RenderCollection(_env)

        for wrapper in wrappers:
            _env = wrapper(_env)
        return _env

    if vectorization_mode == "sync":
        env = gym.experimental.vector.SyncVectorEnv(
            env_fns=[_create_env for _ in range(num_envs)],
            **vector_kwargs,
        )
    elif vectorization_mode == "async":
        env = gym.experimental.vector.AsyncVectorEnv(
            env_fns=[_create_env for _ in range(num_envs)],
            **vector_kwargs,
        )
    else:
        # Only "custom" can reach here: the mode was validated when selecting the entry point
        if len(wrappers) > 0:
            raise error.Error("Cannot use custom vectorization mode with wrappers.")
        vector_kwargs["max_episode_steps"] = spec_.max_episode_steps
        env = env_creator(num_envs=num_envs, **vector_kwargs)

    # Copies the environment creation specification and kwargs to add to the environment specification details.
    # A separate name is used so the `_create_env` closure keeps referring to the original spec.
    vector_spec = copy.deepcopy(spec_)
    vector_spec.kwargs = _kwargs
    env.unwrapped.spec = vector_spec

    return env
2022-04-21 20:41:15 +02:00
def spec(env_id: str) -> EnvSpec:
    """Look up the :class:`EnvSpec` registered under ``env_id`` in the :attr:`registry`.

    Args:
        env_id: The environment id with the expected format of ``[(namespace)/]id[-v(version)]``

    Returns:
        The environment spec stored in the registry for ``env_id``

    Raises:
        Error: If no environment is registered under ``env_id``
    """
    found_spec = registry.get(env_id)

    if found_spec is not None:
        assert isinstance(
            found_spec, EnvSpec
        ), f"Expected the registry for {env_id} to be an `EnvSpec`, actual type is {type(found_spec)}"
        return found_spec

    # Unknown id: give the version check a chance to raise first, then fall back to a generic error
    namespace, env_name, env_version = parse_env_id(env_id)
    _check_version_exists(namespace, env_name, env_version)
    raise error.Error(f"No registered env with id: {env_id}")
2022-11-16 12:59:42 +00:00
def pprint_registry(
    print_registry: dict[str, EnvSpec] = registry,
    *,
    num_cols: int = 3,
    exclude_namespaces: list[str] | None = None,
    disable_print: bool = False,
) -> str | None:
    """Pretty prints all environments in the :attr:`registry`.

    Note:
        All arguments other than ``print_registry`` are keyword only

    Args:
        print_registry: Environment registry to be printed. By default, :attr:`registry`
        num_cols: Number of columns to arrange environments in, for display.
        exclude_namespaces: A list of namespaces to be excluded from printing. Helpful if only ALE environments are wanted.
        disable_print: Whether to return a string of all the namespaces and environment IDs
            or to print the string to console.

    Returns:
        The formatted string when ``disable_print`` is True, otherwise ``None`` (the string is printed).
    """
    # Defaultdict to store environment ids according to namespace.
    namespace_envs: dict[str, list[str]] = defaultdict(list)
    # Column width: the longest environment id, since ids (not bare names) are what get printed
    max_justify = 0

    # Find the namespace associated with each environment spec
    for env_spec in print_registry.values():
        ns = env_spec.namespace

        if ns is None and isinstance(env_spec.entry_point, str):
            # Use regex to obtain namespace from entrypoints.
            env_entry_point = re.sub(r":\w+", "", env_spec.entry_point)
            split_entry_point = env_entry_point.split(".")

            if len(split_entry_point) >= 3:
                # If namespace is of the format:
                #  - gymnasium.envs.mujoco.ant_v4:AntEnv
                #  - gymnasium.envs.mujoco:HumanoidEnv
                ns = split_entry_point[2]
            elif len(split_entry_point) > 1:
                # If namespace is of the format - shimmy.atari_env
                ns = split_entry_point[1]
            else:
                # If namespace cannot be found, default to env name
                ns = env_spec.name

        namespace_envs[ns].append(env_spec.id)
        max_justify = max(max_justify, len(env_spec.id))

    # Iterate through each namespace and print environment alphabetically
    output: list[str] = []
    for ns, env_ids in namespace_envs.items():
        # Ignore namespaces to exclude.
        if exclude_namespaces is not None and ns in exclude_namespaces:
            continue

        # Print the namespace
        namespace_output = f"{'=' * 5} {ns} {'=' * 5}\n"

        # Reference: https://stackoverflow.com/a/33464001
        for count, env_id in enumerate(sorted(env_ids), 1):
            # Print column with justification.
            namespace_output += env_id.ljust(max_justify) + " "

            # Once a row holds `num_cols` entries, drop its trailing padding and start a new row.
            if count % num_cols == 0:
                namespace_output = namespace_output.rstrip(" ")

                if count != len(env_ids):
                    namespace_output += "\n"

        output.append(namespace_output.rstrip(" "))

    if disable_print:
        return "\n".join(output)
    else:
        print("\n".join(output))