mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
704 lines
26 KiB
Python
704 lines
26 KiB
Python
import contextlib
|
|
import copy
|
|
import difflib
|
|
import importlib
|
|
import importlib.util
|
|
import re
|
|
import sys
|
|
import warnings
|
|
from dataclasses import dataclass, field
|
|
from typing import (
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
SupportsFloat,
|
|
Tuple,
|
|
Union,
|
|
overload,
|
|
)
|
|
|
|
import numpy as np
|
|
|
|
from gymnasium.wrappers import (
|
|
AutoResetWrapper,
|
|
HumanRendering,
|
|
OrderEnforcing,
|
|
RenderCollection,
|
|
TimeLimit,
|
|
)
|
|
from gymnasium.wrappers.compatibility import EnvCompatibility
|
|
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
|
|
|
if sys.version_info < (3, 10):
|
|
import importlib_metadata as metadata # type: ignore
|
|
else:
|
|
import importlib.metadata as metadata
|
|
|
|
if sys.version_info >= (3, 8):
|
|
from typing import Literal
|
|
else:
|
|
from typing_extensions import Literal
|
|
|
|
from gymnasium import Env, error, logger
|
|
|
|
# Environment id grammar: `[namespace/]name[-vN]`, exposing the named groups
# `namespace`, `name` and `version` (each part except `name` is optional).
ENV_ID_RE = re.compile(
    r"^(?:(?P<namespace>[\w:-]+)\/)?"  # optional namespace terminated by '/'
    r"(?:(?P<name>[\w:.-]+?))"  # lazy, so a trailing `-vN` is left for the version group
    r"(?:-v(?P<version>\d+))?$"  # optional integer version suffix
)
|
|
|
|
|
|
def load(name: str) -> Callable:
    """Import and return the object referenced by a ``"module:attribute"`` string.

    The referenced object (e.g. an environment class or factory function) is
    returned as-is; it is *not* called.

    Args:
        name: The reference to load, formatted as ``"module.path:attribute"``

    Returns:
        The attribute named ``attribute`` of the imported module ``module.path``

    Raises:
        ValueError: If ``name`` does not contain exactly one ``":"``
    """
    mod_name, sep, attr_name = name.partition(":")
    if not sep or ":" in attr_name:
        raise ValueError(
            f"Expected an entry point of the form `module:attribute`, got `{name}`"
        )
    mod = importlib.import_module(mod_name)
    return getattr(mod, attr_name)
|
|
|
|
|
|
def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
    """Parse an environment ID string into its namespace, name and version.

    The format is ``[namespace/](env-name)[-v(version)]``, matched by ``ENV_ID_RE``
    whose named groups are ``namespace``, ``name`` and ``version``.
    This format is true today, but it's *not* an official spec.

    2016-10-31: We're experimentally expanding the environment ID format
    to include an optional namespace.

    Args:
        id: The environment id to parse

    Returns:
        A tuple of environment namespace (or None), environment name and
        version number (or None)

    Raises:
        Error: If the environment id does not match the environment id regex
    """
    match = ENV_ID_RE.fullmatch(id)
    if not match:
        raise error.Error(
            f"Malformed environment ID: {id}. "
            "(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
        )
    namespace, name, version = match.group("namespace", "name", "version")
    if version is not None:
        version = int(version)

    return namespace, name, version
|
|
|
|
|
|
def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
    """Build the full env ID from its parts; the inverse of :meth:`parse_env_id`.

    Args:
        ns: The environment namespace, or None for the root namespace
        name: The environment name
        version: The environment version, or None for an unversioned id

    Returns:
        The environment id, formatted as ``[ns/]name[-v{version}]``
    """
    versioned_name = name if version is None else name + f"-v{version}"
    if ns is None:
        return versioned_name
    return ns + "/" + versioned_name
|
|
|
|
|
|
@dataclass
class EnvSpec:
    """A specification for creating environments with `gymnasium.make`.

    * id: The string used to create the environment with `gymnasium.make`
    * entry_point: The location of the environment to create from
    * reward_threshold: The reward threshold for completing the environment.
    * nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
    * max_episode_steps: The max number of steps that the environment can take before truncation
    * order_enforce: If to enforce the order of `reset` before `step` and `render` functions
    * autoreset: If to automatically reset the environment on episode end
    * disable_env_checker: If to disable the environment checker wrapper in `gymnasium.make`, by default False (runs the environment checker)
    * apply_api_compatibility: If to apply the `EnvCompatibility` wrapper in `gymnasium.make`, by default False
    * kwargs: Additional keyword arguments passed to the environments through `gymnasium.make`

    The ``namespace``, ``name`` and ``version`` attributes are not constructor
    arguments; they are parsed from ``id`` after initialisation.
    """

    id: str
    entry_point: Union[Callable, str]

    # Environment attributes
    reward_threshold: Optional[float] = field(default=None)
    nondeterministic: bool = field(default=False)

    # Wrappers
    max_episode_steps: Optional[int] = field(default=None)
    order_enforce: bool = field(default=True)
    autoreset: bool = field(default=False)
    disable_env_checker: bool = field(default=False)
    apply_api_compatibility: bool = field(default=False)

    # Environment arguments
    kwargs: dict = field(default_factory=dict)

    # post-init attributes (derived from `id` in `__post_init__`)
    namespace: Optional[str] = field(init=False)
    name: str = field(init=False)
    version: Optional[int] = field(init=False)

    def __post_init__(self):
        """Initialize ``namespace``, ``name`` and ``version`` by parsing ``id``."""
        self.namespace, self.name, self.version = parse_env_id(self.id)

    def make(self, **kwargs) -> Env:
        """Create an environment from this spec by calling this module's ``make(self, **kwargs)``."""
        # For compatibility purposes
        return make(self, **kwargs)
|
|
|
|
|
|
def _check_namespace_exists(ns: Optional[str]):
    """Raise ``NamespaceNotFound`` with a hint if ``ns`` is not a registered namespace.

    A ``None`` namespace (the root namespace) always passes.
    """
    if ns is None:
        return

    namespaces = {
        env_spec.namespace
        for env_spec in registry.values()
        if env_spec.namespace is not None
    }
    if ns not in namespaces:
        closest = difflib.get_close_matches(ns, namespaces, n=1) if namespaces else []
        if closest:
            hint = f"Did you mean: `{closest[0]}`?"
        else:
            hint = f"Have you installed the proper package for {ns}?"
        raise error.NamespaceNotFound(f"Namespace {ns} not found. {hint}")
|
|
|
|
|
|
def _check_name_exists(ns: Optional[str], name: str):
    """Raise a helpful error if no environment called ``name`` exists in namespace ``ns``.

    The namespace itself is validated first (``NamespaceNotFound``); a missing
    name raises ``NameNotFound`` with the closest known name as a suggestion.
    """
    _check_namespace_exists(ns)

    known_names = {
        env_spec.name for env_spec in registry.values() if env_spec.namespace == ns
    }
    if name not in known_names:
        matches = difflib.get_close_matches(name, known_names, n=1)
        namespace_msg = "" if not ns else f" in namespace {ns}"
        suggestion_msg = "" if not matches else f"Did you mean: `{matches[0]}`?"
        raise error.NameNotFound(
            f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
        )
|
|
|
|
|
|
def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
    """Check if an env version exists in a namespace. If it doesn't, raise a helpful error.

    This is a complete test of whether an environment identifier is valid, and
    provides the best available hints. It returns normally when the id is
    registered, when ``version`` is None, or when no more specific error applies.

    Args:
        ns: The environment namespace
        name: The environment name
        version: The environment version

    Raises:
        NamespaceNotFound: The namespace doesn't exist (from ``_check_name_exists``)
        NameNotFound: The name doesn't exist in the namespace (from ``_check_name_exists``)
        DeprecatedEnv: The requested version doesn't exist but only a default (unversioned)
            environment does, or the requested version is older than the latest one
        VersionNotFound: The requested version is newer than every registered version
    """
    if get_env_id(ns, name, version) in registry:
        return

    _check_name_exists(ns, name)
    if version is None:
        return

    message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."

    env_specs = sorted(
        (
            env_spec
            for env_spec in registry.values()
            if env_spec.namespace == ns and env_spec.name == name
        ),
        # Unversioned specs sort first; an explicit None check keeps `v0` at 0
        # (`version or -1` would have mapped it to -1).
        key=lambda env_spec: -1 if env_spec.version is None else int(env_spec.version),
    )

    default_spec = [env_spec for env_spec in env_specs if env_spec.version is None]

    if default_spec:
        message += f" It provides the default version `{default_spec[0].id}`."
        if len(env_specs) == 1:
            raise error.DeprecatedEnv(message)

    # Process possible versioned environments
    versioned_specs = [env_spec for env_spec in env_specs if env_spec.version is not None]

    latest_spec = max(versioned_specs, key=lambda env_spec: env_spec.version, default=None)  # type: ignore
    if latest_spec is None:
        return

    if version > latest_spec.version:
        # Only versioned specs are listed, so the message never contains `vNone`.
        version_list_msg = ", ".join(f"`v{env_spec.version}`" for env_spec in versioned_specs)
        message += f" It provides versioned environments: [ {version_list_msg} ]."

        raise error.VersionNotFound(message)

    if version < latest_spec.version:
        raise error.DeprecatedEnv(
            f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. "
            f"Please use `{latest_spec.id}` instead."
        )
|
|
|
|
|
|
def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
    """Return the largest registered version of ``ns/name``, or None if it has no versioned spec."""
    registered_versions = (
        env_spec.version
        for env_spec in registry.values()
        if env_spec.namespace == ns
        and env_spec.name == name
        and env_spec.version is not None
    )
    return max(registered_versions, default=None)
|
|
|
|
|
|
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
    """Load third-party environment plugins advertised under an entry-point group.

    Each entry point must reference a callable (``module:function``), which is
    called inside a ``namespace(plugin.name)`` context so the environments it
    registers land in that namespace. The magic names ``__root__`` and
    ``__internal__`` register into the root namespace instead.

    A plugin whose metadata cannot be read is reported with a warning and
    skipped. Exceptions raised by the plugin's callable are logged as warnings.

    Args:
        entry_point: The entry-point group to scan for plugins.

    Raises:
        Error: If a plugin's entry point names a module but no attribute to call.
    """
    # Load third-party environments
    for plugin in metadata.entry_points(group=entry_point):
        # Python 3.8 doesn't support plugin.module, plugin.attr
        # So we'll have to try and parse this ourselves
        try:
            module, attr = plugin.module, plugin.attr  # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
        except AttributeError:
            if ":" in plugin.value:
                module, attr = plugin.value.split(":", maxsplit=1)
            else:
                module, attr = plugin.value, None
        except Exception as e:
            warnings.warn(
                f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
            )
            # The warning already reports the problem; skip this plugin rather
            # than failing below with a misleading "must specify a function" error.
            continue

        if attr is None:
            raise error.Error(
                f"Gymnasium environment plugin `{module}` must specify a function to execute, not a root module"
            )

        context = namespace(plugin.name)
        if plugin.name.startswith("__") and plugin.name.endswith("__"):
            # `__internal__` is an artifact of the plugin system when
            # the root namespace had an allow-list. The allow-list is now
            # removed and plugins can register environments in the root
            # namespace with the `__root__` magic key.
            if plugin.name in ("__root__", "__internal__"):
                context = contextlib.nullcontext()
            else:
                logger.warn(
                    f"The environment namespace magic key `{plugin.name}` is unsupported. "
                    "To register an environment at the root namespace you should specify the `__root__` namespace."
                )

        with context:
            fn = plugin.load()
            try:
                fn()
            except Exception as e:
                logger.warn(str(e))
|
|
|
|
|
|
# fmt: off
# Overload order matters: type checkers use the first signature that matches, so the
# specific `Literal` ids must come before the catch-all `str` / `EnvSpec` signatures
# (otherwise every `Literal` overload below would be unreachable).

# Classic control
# ----------------------------------------
@overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
@overload
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
@overload
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...


# Box2d
# ----------------------------------------
@overload
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
@overload
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...


# Toy Text
# ----------------------------------------
# Blackjack observes a (player sum, dealer card, usable ace) tuple; the grid worlds
# observe a single discrete state index.
@overload
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[Tuple[int, int, int], Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[int, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[int, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[int, Union[np.ndarray, int]]: ...


# Mujoco
# ----------------------------------------
@overload
def make(id: Literal[
    "Reacher-v2", "Reacher-v4",
    "Pusher-v2", "Pusher-v4",
    "InvertedPendulum-v2", "InvertedPendulum-v4",
    "InvertedDoublePendulum-v2", "InvertedDoublePendulum-v4",
    "HalfCheetah-v2", "HalfCheetah-v3", "HalfCheetah-v4",
    "Hopper-v2", "Hopper-v3", "Hopper-v4",
    "Swimmer-v2", "Swimmer-v3", "Swimmer-v4",
    "Walker2d-v2", "Walker2d-v3", "Walker2d-v4",
    "Ant-v2", "Ant-v3", "Ant-v4",
    "HumanoidStandup-v2", "HumanoidStandup-v4",
    "Humanoid-v2", "Humanoid-v3", "Humanoid-v4",
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...


# Catch-all signatures (must stay last)
# ----------------------------------------
@overload
def make(id: str, **kwargs) -> Env: ...
@overload
def make(id: EnvSpec, **kwargs) -> Env: ...
# fmt: on
|
|
|
|
|
|
# Global registry of environments, keyed by full environment id.
# Meant to be accessed through `register` and `make`.
registry: Dict[str, EnvSpec] = dict()

# Namespace that `register` forces onto new specs while inside a `namespace(...)`
# context; None means the namespace is taken from the id itself.
current_namespace: Optional[str] = None
|
|
|
|
|
|
def _check_spec_register(spec: EnvSpec):
    """Validate that ``spec`` may be registered. Helper function for `register`.

    Versioned and unversioned environments sharing a namespace and name cannot
    coexist, so registering one while the other exists raises ``RegistrationError``.
    """
    same_env_specs = [
        env_spec
        for env_spec in registry.values()
        if env_spec.namespace == spec.namespace and env_spec.name == spec.name
    ]

    latest_versioned_spec = max(
        (env_spec for env_spec in same_env_specs if env_spec.version is not None),
        key=lambda env_spec: int(env_spec.version),  # type: ignore
        default=None,
    )
    unversioned_spec = next(
        (env_spec for env_spec in same_env_specs if env_spec.version is None),
        None,
    )

    if unversioned_spec is not None and spec.version is not None:
        raise error.RegistrationError(
            "Can't register the versioned environment "
            f"`{spec.id}` when the unversioned environment "
            f"`{unversioned_spec.id}` of the same name already exists."
        )
    if latest_versioned_spec is not None and spec.version is None:
        raise error.RegistrationError(
            "Can't register the unversioned environment "
            f"`{spec.id}` when the versioned environment "
            f"`{latest_versioned_spec.id}` of the same name "
            "already exists. Note: the default behavior is "
            "that `gymnasium.make` with the unversioned environment "
            "will return the latest versioned environment"
        )
|
|
|
|
|
|
# Public API
|
|
|
|
|
|
@contextlib.contextmanager
def namespace(ns: str):
    """Set the namespace that `register` applies for the duration of a ``with`` block.

    The previous namespace is restored on exit, even if the body raises
    (e.g. a plugin failing to import), so it cannot leak into later registrations.
    """
    global current_namespace
    old_namespace = current_namespace
    current_namespace = ns
    try:
        yield
    finally:
        current_namespace = old_namespace
|
|
|
|
|
|
def register(
    id: str,
    entry_point: Union[Callable, str],
    reward_threshold: Optional[float] = None,
    nondeterministic: bool = False,
    max_episode_steps: Optional[int] = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Register an environment with gymnasium.

    The `id` parameter corresponds to the name of the environment, with the syntax as follows:
    `(namespace)/(env_name)-v(version)` where `namespace` is optional.

    Args:
        id: The environment id
        entry_point: The entry point for creating the environment
        reward_threshold: The reward threshold considered to have learnt an environment
        nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions)
        max_episode_steps: The maximum number of episodes steps before truncation. Used by the Time Limit wrapper.
        order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order
        autoreset: If to add the autoreset wrapper such that reset does not need to be called.
        disable_env_checker: If to disable the environment checker for the environment. Recommended to False.
        apply_api_compatibility: If to apply the `EnvCompatibility` wrapper.
        **kwargs: Additional keyword arguments forwarded to the `EnvSpec` constructor.
            Arguments for the environment constructor itself go in ``kwargs=dict(...)``.
            A ``namespace`` keyword is accepted but not forwarded: inside a plugin's
            `namespace` context it is overridden (with a warning), and otherwise it is
            ignored (with a warning) because the namespace must be part of ``id``.

    Raises:
        RegistrationError: If a versioned and an unversioned environment of the same name would coexist
    """
    ns, name, version = parse_env_id(id)

    # `namespace` is not an `EnvSpec` constructor argument (the spec derives it from
    # its id), so forwarding it would raise a TypeError; consume it here instead.
    custom_namespace = kwargs.pop("namespace", None)

    if current_namespace is not None:
        if custom_namespace is not None and custom_namespace != current_namespace:
            logger.warn(
                f"Custom namespace `{custom_namespace}` is being overridden by namespace `{current_namespace}`. "
                f"If you are developing a plugin you shouldn't specify a namespace in `register` calls. "
                "The namespace is specified through the entry point package metadata."
            )
        ns_id = current_namespace
    else:
        if custom_namespace is not None and custom_namespace != ns:
            logger.warn(
                f"Custom namespace `{custom_namespace}` is ignored; include the namespace in the "
                f"environment id instead (e.g. `{get_env_id(custom_namespace, name, version)}`)."
            )
        ns_id = ns

    full_id = get_env_id(ns_id, name, version)

    new_spec = EnvSpec(
        id=full_id,
        entry_point=entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        autoreset=autoreset,
        disable_env_checker=disable_env_checker,
        apply_api_compatibility=apply_api_compatibility,
        **kwargs,
    )
    _check_spec_register(new_spec)
    if new_spec.id in registry:
        logger.warn(f"Overriding environment {new_spec.id} already in registry.")
    registry[new_spec.id] = new_spec
|
|
|
|
|
|
def make(
    id: Union[str, EnvSpec],
    max_episode_steps: Optional[int] = None,
    autoreset: Optional[bool] = None,
    apply_api_compatibility: Optional[bool] = None,
    disable_env_checker: Optional[bool] = None,
    **kwargs,
) -> Env:
    """Create an environment according to the given ID.

    To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids.

    Args:
        id: Name of the environment, or an `EnvSpec`. Optionally, a module to import can be included, eg. 'module:Env-v0'
        max_episode_steps: Maximum length of an episode (TimeLimit wrapper). None uses the spec's `max_episode_steps`.
        autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
            By default, the argument is None, in which case the environment specification `autoreset` is used.
            Otherwise, the wrapper is applied if the value is truthy.
        apply_api_compatibility: Whether to wrap the environment with the `EnvCompatibility` wrapper that
            converts the environment step from a done bool to return termination and truncation bools.
            By default, the argument is None to which the environment specification `apply_api_compatibility` is used
            which defaults to False. Otherwise, the value of `apply_api_compatibility` is used.
            If `True`, the wrapper is applied otherwise, the wrapper is not applied.
        disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
            (which is by default False, running the environment checker),
            otherwise will run according to this parameter (`True` = not run, `False` = run)
        kwargs: Additional arguments to pass to the environment constructor; they override the spec's `kwargs`.

    Returns:
        An instance of the environment.

    Raises:
        Error: If the ``id`` doesn't exist then an error is raised (a more specific subclass is raised
            by the version/name/namespace checks when possible)
        ModuleNotFoundError: If the ``module:`` prefix of ``id`` cannot be imported
    """
    if isinstance(id, EnvSpec):
        spec_ = id
    else:
        # Only the first colon separates the module to import: env names and
        # namespaces may themselves contain colons (see `ENV_ID_RE`).
        module, id = (None, id) if ":" not in id else id.split(":", 1)
        if module is not None:
            try:
                importlib.import_module(module)
            except ModuleNotFoundError as e:
                raise ModuleNotFoundError(
                    f"{e}. Environment registration via importing a module failed. "
                    f"Check whether '{module}' contains env registration and can be imported."
                ) from e
        spec_ = registry.get(id)

        ns, name, version = parse_env_id(id)
        latest_version = find_highest_version(ns, name)
        if (
            version is not None
            and latest_version is not None
            and latest_version > version
        ):
            logger.warn(
                f"The environment {id} is out of date. You should consider "
                f"upgrading to version `v{latest_version}`."
            )
        if version is None and latest_version is not None:
            # An unversioned id resolves to the newest registered version.
            version = latest_version
            new_env_id = get_env_id(ns, name, version)
            spec_ = registry.get(new_env_id)
            logger.warn(
                f"Using the latest versioned environment `{new_env_id}` "
                f"instead of the unversioned environment `{id}`."
            )

        if spec_ is None:
            # Raises a more specific error (namespace/name/version) when it can.
            _check_version_exists(ns, name, version)
            raise error.Error(f"No registered env with id: {id}")

    _kwargs = spec_.kwargs.copy()
    _kwargs.update(kwargs)

    if spec_.entry_point is None:
        raise error.Error(f"{spec_.id} registered but entry_point is not specified")
    elif callable(spec_.entry_point):
        env_creator = spec_.entry_point
    else:
        # Assume it's a string
        env_creator = load(spec_.entry_point)

    mode = _kwargs.get("render_mode")
    apply_human_rendering = False
    apply_render_collection = False

    # If we have access to metadata we check that "render_mode" is valid and see if the HumanRendering wrapper needs to be applied
    if mode is not None and hasattr(env_creator, "metadata"):
        assert isinstance(
            env_creator.metadata, dict
        ), f"Expect the environment creator ({env_creator}) metadata to be dict, actual type: {type(env_creator.metadata)}"

        if "render_modes" in env_creator.metadata:
            render_modes = env_creator.metadata["render_modes"]
            if not isinstance(render_modes, Sequence):
                logger.warn(
                    f"Expects the environment metadata render_modes to be a Sequence (tuple or list), actual type: {type(render_modes)}"
                )

            # Apply the `HumanRendering` wrapper, if the mode=="human" but "human" not in render_modes
            if (
                mode == "human"
                and "human" not in render_modes
                and ("rgb_array" in render_modes or "rgb_array_list" in render_modes)
            ):
                logger.warn(
                    "You are trying to use 'human' rendering for an environment that doesn't natively support it. "
                    "The HumanRendering wrapper is being applied to your environment."
                )
                apply_human_rendering = True
                if "rgb_array" in render_modes:
                    _kwargs["render_mode"] = "rgb_array"
                else:
                    _kwargs["render_mode"] = "rgb_array_list"
            elif (
                mode not in render_modes
                and isinstance(mode, str)
                and mode.endswith("_list")
                and mode[: -len("_list")] in render_modes
            ):
                # e.g. "rgb_array_list" on an env that only has "rgb_array": render
                # with the base mode and collect frames with `RenderCollection`.
                _kwargs["render_mode"] = mode[: -len("_list")]
                apply_render_collection = True
            elif mode not in render_modes:
                logger.warn(
                    f"The environment is being initialised with mode ({mode}) that is not in the possible render_modes ({render_modes})."
                )
        else:
            logger.warn(
                f"The environment creator metadata doesn't include `render_modes`, contains: {list(env_creator.metadata.keys())}"
            )

    use_api_compatibility = apply_api_compatibility is True or (
        apply_api_compatibility is None and spec_.apply_api_compatibility is True
    )
    if use_api_compatibility:
        # If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
        render_mode = _kwargs.pop("render_mode", None)
    else:
        render_mode = None

    try:
        env = env_creator(**_kwargs)
    except TypeError as e:
        if (
            "got an unexpected keyword argument 'render_mode'" in str(e)
            and apply_human_rendering
        ):
            raise error.Error(
                f"You passed render_mode='human' although {spec_.id} doesn't implement human-rendering natively. "
                "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
                "rendering API, which is not supported by the HumanRendering wrapper."
            ) from e
        raise

    # Copies the environment creation specification and kwargs to add to the environment specification details
    spec_ = copy.deepcopy(spec_)
    spec_.kwargs = _kwargs
    env.unwrapped.spec = spec_

    # Add step API wrapper
    if use_api_compatibility:
        env = EnvCompatibility(env, render_mode)

    # Run the environment checker as the lowest level wrapper
    if disable_env_checker is False or (
        disable_env_checker is None and spec_.disable_env_checker is False
    ):
        env = PassiveEnvChecker(env)

    # Add the order enforcing wrapper
    if spec_.order_enforce:
        env = OrderEnforcing(env)

    # Add the time limit wrapper
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps)
    elif spec_.max_episode_steps is not None:
        env = TimeLimit(env, spec_.max_episode_steps)

    # Add the autoreset wrapper; the registered spec's `autoreset` applies unless overridden
    if autoreset or (autoreset is None and spec_.autoreset):
        env = AutoResetWrapper(env)

    # Add human rendering wrapper
    if apply_human_rendering:
        env = HumanRendering(env)
    elif apply_render_collection:
        env = RenderCollection(env)

    return env
|
|
|
|
|
|
def spec(env_id: str) -> EnvSpec:
    """Look up ``env_id`` in the global registry and return its `EnvSpec`.

    Raises:
        Error: If ``env_id`` is not registered; a more specific subclass is
            raised when the namespace, name or version can be diagnosed.
    """
    found = registry.get(env_id)
    if found is not None:
        assert isinstance(found, EnvSpec)
        return found

    ns, name, version = parse_env_id(env_id)
    _check_version_exists(ns, name, version)
    raise error.Error(f"No registered env with id: {env_id}")
|