mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Reorder functions and refactor registration.py
(#289)
Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
@@ -2,36 +2,39 @@
|
|||||||
title: Registry
|
title: Registry
|
||||||
---
|
---
|
||||||
|
|
||||||
# Registry
|
# Register and Make
|
||||||
|
|
||||||
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers.
|
```{eval-rst}
|
||||||
Environments can also be created through python imports.
|
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers through the :meth:`gymnasium.make` function. To do this, the environment must be registered prior with :meth:`gymnasium.register`. To get the environment specifications for a registered environment, use :meth:`gymnasium.spec` and to print the whole registry, use :meth:`gymnasium.pprint_registry`.
|
||||||
|
```
|
||||||
## Make
|
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. autofunction:: gymnasium.make
|
.. autofunction:: gymnasium.make
|
||||||
```
|
|
||||||
|
|
||||||
## Register
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autofunction:: gymnasium.register
|
.. autofunction:: gymnasium.register
|
||||||
```
|
|
||||||
|
|
||||||
## All registered environments
|
|
||||||
|
|
||||||
To find all the registered Gymnasium environments, use the `gymnasium.pprint_registry()`.
|
|
||||||
This will not include environments registered only in OpenAI Gym however can be loaded by `gymnasium.make`.
|
|
||||||
|
|
||||||
## Spec
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autofunction:: gymnasium.spec
|
.. autofunction:: gymnasium.spec
|
||||||
```
|
|
||||||
|
|
||||||
## Pretty print registry
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autofunction:: gymnasium.pprint_registry
|
.. autofunction:: gymnasium.pprint_registry
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Core variables
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.envs.registration.EnvSpec
|
||||||
|
.. attribute:: gymnasium.envs.registration.registry
|
||||||
|
|
||||||
|
The Global registry for gymnasium which is where environment specifications are stored by :meth:`gymnasium.register` and from which :meth:`gymnasium.make` is used to create environments.
|
||||||
|
|
||||||
|
.. attribute:: gymnasium.envs.registration.current_namespace
|
||||||
|
|
||||||
|
The current namespace when creating or registering environments. This is by default ``None`` by with :meth:`namespace` this can be modified to automatically set the environment id namespace.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Additional functions
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autofunction:: gymnasium.envs.registration.get_env_id
|
||||||
|
.. autofunction:: gymnasium.envs.registration.parse_env_id
|
||||||
|
.. autofunction:: gymnasium.envs.registration.find_highest_version
|
||||||
|
.. autofunction:: gymnasium.envs.registration.namespace
|
||||||
|
.. autofunction:: gymnasium.envs.registration.load_env
|
||||||
|
.. autofunction:: gymnasium.envs.registration.load_plugin_envs
|
||||||
|
```
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from gymnasium.envs.registration import (
|
from gymnasium.envs.registration import (
|
||||||
load_env_plugins,
|
load_plugin_envs,
|
||||||
make,
|
make,
|
||||||
pprint_registry,
|
pprint_registry,
|
||||||
register,
|
register,
|
||||||
@@ -363,4 +363,4 @@ register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error)
|
|||||||
|
|
||||||
|
|
||||||
# Hook to load plugins from entry points
|
# Hook to load plugins from entry points
|
||||||
load_env_plugins()
|
load_plugin_envs()
|
||||||
|
@@ -9,13 +9,11 @@ import importlib.util
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Iterable, Sequence, SupportsFloat, overload
|
from typing import Any, Iterable
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
from gymnasium import Env, error, logger
|
||||||
from gymnasium.wrappers import (
|
from gymnasium.wrappers import (
|
||||||
AutoResetWrapper,
|
AutoResetWrapper,
|
||||||
HumanRendering,
|
HumanRendering,
|
||||||
@@ -32,12 +30,10 @@ if sys.version_info < (3, 10):
|
|||||||
else:
|
else:
|
||||||
import importlib.metadata as metadata
|
import importlib.metadata as metadata
|
||||||
|
|
||||||
if sys.version_info >= (3, 8):
|
if sys.version_info < (3, 8):
|
||||||
from typing import Literal
|
from typing_extensions import Protocol
|
||||||
else:
|
else:
|
||||||
from typing_extensions import Literal
|
from typing import Protocol
|
||||||
|
|
||||||
from gymnasium import Env, error, logger
|
|
||||||
|
|
||||||
|
|
||||||
ENV_ID_RE = re.compile(
|
ENV_ID_RE = re.compile(
|
||||||
@@ -45,50 +41,99 @@ ENV_ID_RE = re.compile(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(name: str) -> Callable:
|
__all__ = [
|
||||||
"""Loads an environment with name and returns an environment creation function.
|
"EnvSpec",
|
||||||
|
"registry",
|
||||||
|
"current_namespace",
|
||||||
|
"register",
|
||||||
|
"make",
|
||||||
|
"spec",
|
||||||
|
"pprint_registry",
|
||||||
|
]
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The environment name
|
|
||||||
|
|
||||||
Returns:
|
class EnvCreator(Protocol):
|
||||||
Calls the environment constructor
|
"""Function type expected for an environment."""
|
||||||
|
|
||||||
|
def __call__(self, **kwargs: Any) -> Env:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnvSpec:
|
||||||
|
"""A specification for creating environments with :meth:`gymnasium.make`.
|
||||||
|
|
||||||
|
* **id**: The string used to create the environment with :meth:`gymnasium.make`
|
||||||
|
* **entry_point**: A string for the environment location, ``(import path):(environment name)`` or a function that creates the environment.
|
||||||
|
* **reward_threshold**: The reward threshold for completing the environment.
|
||||||
|
* **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
|
||||||
|
* **max_episode_steps**: The max number of steps that the environment can take before truncation
|
||||||
|
* **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
|
||||||
|
* **autoreset**: If to automatically reset the environment on episode end
|
||||||
|
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
||||||
|
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
||||||
"""
|
"""
|
||||||
mod_name, attr_name = name.split(":")
|
|
||||||
mod = importlib.import_module(mod_name)
|
id: str
|
||||||
fn = getattr(mod, attr_name)
|
entry_point: EnvCreator | str
|
||||||
return fn
|
|
||||||
|
# Environment attributes
|
||||||
|
reward_threshold: float | None = field(default=None)
|
||||||
|
nondeterministic: bool = field(default=False)
|
||||||
|
|
||||||
|
# Wrappers
|
||||||
|
max_episode_steps: int | None = field(default=None)
|
||||||
|
order_enforce: bool = field(default=True)
|
||||||
|
autoreset: bool = field(default=False)
|
||||||
|
disable_env_checker: bool = field(default=False)
|
||||||
|
apply_api_compatibility: bool = field(default=False)
|
||||||
|
|
||||||
|
# post-init attributes
|
||||||
|
namespace: str | None = field(init=False)
|
||||||
|
name: str = field(init=False)
|
||||||
|
version: int | None = field(init=False)
|
||||||
|
|
||||||
|
# Environment arguments
|
||||||
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Calls after the spec is created to extract the namespace, name and version from the id."""
|
||||||
|
# Initialize namespace, name, version
|
||||||
|
self.namespace, self.name, self.version = parse_env_id(self.id)
|
||||||
|
|
||||||
|
def make(self, **kwargs: Any) -> Env:
|
||||||
|
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
||||||
|
# For compatibility purposes
|
||||||
|
return make(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def parse_env_id(id: str) -> tuple[str | None, str, int | None]:
|
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||||
"""Parse environment ID string format.
|
registry: dict[str, EnvSpec] = {}
|
||||||
|
current_namespace: str | None = None
|
||||||
|
|
||||||
This format is true today, but it's *not* an official spec.
|
|
||||||
[namespace/](env-name)-v(version) env-name is group 1, version is group 2
|
|
||||||
|
|
||||||
2016-10-31: We're experimentally expanding the environment ID format
|
def parse_env_id(env_id: str) -> tuple[str | None, str, int | None]:
|
||||||
to include an optional namespace.
|
"""Parse environment ID string format - ``[namespace/](env-name)[-v(version)]`` where the namespace and version are optional.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
id: The environment id to parse
|
env_id: The environment id to parse
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of environment namespace, environment name and version number
|
A tuple of environment namespace, environment name and version number
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Error: If the environment id does not a valid environment regex
|
Error: If the environment id is not valid environment regex
|
||||||
"""
|
"""
|
||||||
match = ENV_ID_RE.fullmatch(id)
|
match = ENV_ID_RE.fullmatch(env_id)
|
||||||
if not match:
|
if not match:
|
||||||
raise error.Error(
|
raise error.Error(
|
||||||
f"Malformed environment ID: {id}."
|
f"Malformed environment ID: {env_id}. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
|
||||||
f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
|
|
||||||
)
|
)
|
||||||
namespace, name, version = match.group("namespace", "name", "version")
|
ns, name, version = match.group("namespace", "name", "version")
|
||||||
if version is not None:
|
if version is not None:
|
||||||
version = int(version)
|
version = int(version)
|
||||||
|
|
||||||
return namespace, name, version
|
return ns, name, version
|
||||||
|
|
||||||
|
|
||||||
def get_env_id(ns: str | None, name: str, version: int | None) -> str:
|
def get_env_id(ns: str | None, name: str, version: int | None) -> str:
|
||||||
@@ -103,97 +148,80 @@ def get_env_id(ns: str | None, name: str, version: int | None) -> str:
|
|||||||
The environment id
|
The environment id
|
||||||
"""
|
"""
|
||||||
full_name = name
|
full_name = name
|
||||||
if version is not None:
|
|
||||||
full_name += f"-v{version}"
|
|
||||||
if ns is not None:
|
if ns is not None:
|
||||||
full_name = ns + "/" + full_name
|
full_name = f"{ns}/{name}"
|
||||||
|
if version is not None:
|
||||||
|
full_name = f"{full_name}-v{version}"
|
||||||
|
|
||||||
return full_name
|
return full_name
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
def find_highest_version(ns: str | None, name: str) -> int | None:
|
||||||
class EnvSpec:
|
"""Finds the highest registered version of the environment given the namespace and name in the registry.
|
||||||
"""A specification for creating environments with `gym.make`.
|
|
||||||
|
|
||||||
* id: The string used to create the environment with `gym.make`
|
Args:
|
||||||
* entry_point: The location of the environment to create from
|
ns: The environment namespace
|
||||||
* reward_threshold: The reward threshold for completing the environment.
|
name: The environment name (id)
|
||||||
* nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
|
|
||||||
* max_episode_steps: The max number of steps that the environment can take before truncation
|
Returns:
|
||||||
* order_enforce: If to enforce the order of `reset` before `step` and `render` functions
|
The highest version of an environment with matching namespace and name, otherwise ``None`` is returned.
|
||||||
* autoreset: If to automatically reset the environment on episode end
|
|
||||||
* disable_env_checker: If to disable the environment checker wrapper in `gym.make`, by default False (runs the environment checker)
|
|
||||||
* kwargs: Additional keyword arguments passed to the environments through `gym.make`
|
|
||||||
"""
|
"""
|
||||||
|
version: list[int] = [
|
||||||
id: str
|
env_spec.version
|
||||||
entry_point: Callable | str
|
for env_spec in registry.values()
|
||||||
|
if env_spec.namespace == ns
|
||||||
# Environment attributes
|
and env_spec.name == name
|
||||||
reward_threshold: float | None = field(default=None)
|
and env_spec.version is not None
|
||||||
nondeterministic: bool = field(default=False)
|
]
|
||||||
|
return max(version, default=None)
|
||||||
# Wrappers
|
|
||||||
max_episode_steps: int | None = field(default=None)
|
|
||||||
order_enforce: bool = field(default=True)
|
|
||||||
autoreset: bool = field(default=False)
|
|
||||||
disable_env_checker: bool = field(default=False)
|
|
||||||
apply_api_compatibility: bool = field(default=False)
|
|
||||||
|
|
||||||
# Environment arguments
|
|
||||||
kwargs: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
# post-init attributes
|
|
||||||
namespace: str | None = field(init=False)
|
|
||||||
name: str = field(init=False)
|
|
||||||
version: int | None = field(init=False)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Calls after the spec is created to extract the namespace, name and version from the id."""
|
|
||||||
# Initialize namespace, name, version
|
|
||||||
self.namespace, self.name, self.version = parse_env_id(self.id)
|
|
||||||
|
|
||||||
def make(self, **kwargs: Any) -> Env:
|
|
||||||
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
|
||||||
# For compatibility purposes
|
|
||||||
return make(self, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_namespace_exists(ns: str | None):
|
def _check_namespace_exists(ns: str | None):
|
||||||
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
|
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
|
||||||
|
# If the namespace is none, then the namespace does exist
|
||||||
if ns is None:
|
if ns is None:
|
||||||
return
|
return
|
||||||
namespaces = {
|
|
||||||
spec_.namespace for spec_ in registry.values() if spec_.namespace is not None
|
# Check if the namespace exists in one of the registry's specs
|
||||||
|
namespaces: set[str] = {
|
||||||
|
env_spec.namespace
|
||||||
|
for env_spec in registry.values()
|
||||||
|
if env_spec.namespace is not None
|
||||||
}
|
}
|
||||||
if ns in namespaces:
|
if ns in namespaces:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Otherwise, the namespace doesn't exist and raise a helpful message
|
||||||
suggestion = (
|
suggestion = (
|
||||||
difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
|
difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
|
||||||
)
|
)
|
||||||
suggestion_msg = (
|
if suggestion:
|
||||||
f"Did you mean: `{suggestion[0]}`?"
|
suggestion_msg = f"Did you mean: `{suggestion[0]}`?"
|
||||||
if suggestion
|
else:
|
||||||
else f"Have you installed the proper package for {ns}?"
|
suggestion_msg = f"Have you installed the proper package for {ns}?"
|
||||||
)
|
|
||||||
|
|
||||||
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
|
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
|
||||||
|
|
||||||
|
|
||||||
def _check_name_exists(ns: str | None, name: str):
|
def _check_name_exists(ns: str | None, name: str):
|
||||||
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
|
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
|
||||||
|
# First check if the namespace exists
|
||||||
_check_namespace_exists(ns)
|
_check_namespace_exists(ns)
|
||||||
names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}
|
|
||||||
|
|
||||||
|
# Then check if the name exists
|
||||||
|
names: set[str] = {
|
||||||
|
env_spec.name for env_spec in registry.values() if env_spec.namespace == ns
|
||||||
|
}
|
||||||
if name in names:
|
if name in names:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Otherwise, raise a helpful error to the user
|
||||||
suggestion = difflib.get_close_matches(name, names, n=1)
|
suggestion = difflib.get_close_matches(name, names, n=1)
|
||||||
namespace_msg = f" in namespace {ns}" if ns else ""
|
namespace_msg = f" in namespace {ns}" if ns else ""
|
||||||
suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else ""
|
suggestion_msg = f" Did you mean: `{suggestion[0]}`?" if suggestion else ""
|
||||||
|
|
||||||
raise error.NameNotFound(
|
raise error.NameNotFound(
|
||||||
f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
|
f"Environment `{name}` doesn't exist{namespace_msg}.{suggestion_msg}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -222,26 +250,28 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
|
|||||||
message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."
|
message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."
|
||||||
|
|
||||||
env_specs = [
|
env_specs = [
|
||||||
spec_
|
env_spec
|
||||||
for spec_ in registry.values()
|
for env_spec in registry.values()
|
||||||
if spec_.namespace == ns and spec_.name == name
|
if env_spec.namespace == ns and env_spec.name == name
|
||||||
]
|
]
|
||||||
env_specs = sorted(env_specs, key=lambda spec_: int(spec_.version or -1))
|
env_specs = sorted(env_specs, key=lambda env_spec: int(env_spec.version or -1))
|
||||||
|
|
||||||
default_spec = [spec_ for spec_ in env_specs if spec_.version is None]
|
default_spec = [env_spec for env_spec in env_specs if env_spec.version is None]
|
||||||
|
|
||||||
if default_spec:
|
if default_spec:
|
||||||
message += f" It provides the default version {default_spec[0].id}`."
|
message += f" It provides the default version `{default_spec[0].id}`."
|
||||||
if len(env_specs) == 1:
|
if len(env_specs) == 1:
|
||||||
raise error.DeprecatedEnv(message)
|
raise error.DeprecatedEnv(message)
|
||||||
|
|
||||||
# Process possible versioned environments
|
# Process possible versioned environments
|
||||||
|
|
||||||
versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None]
|
versioned_specs = [
|
||||||
|
env_spec for env_spec in env_specs if env_spec.version is not None
|
||||||
|
]
|
||||||
|
|
||||||
latest_spec = max(versioned_specs, key=lambda spec: spec.version, default=None) # type: ignore
|
latest_spec = max(versioned_specs, key=lambda env_spec: env_spec.version, default=None) # type: ignore
|
||||||
if latest_spec is not None and version > latest_spec.version:
|
if latest_spec is not None and version > latest_spec.version:
|
||||||
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs)
|
version_list_msg = ", ".join(f"`v{env_spec.version}`" for env_spec in env_specs)
|
||||||
message += f" It provides versioned environments: [ {version_list_msg} ]."
|
message += f" It provides versioned environments: [ {version_list_msg} ]."
|
||||||
|
|
||||||
raise error.VersionNotFound(message)
|
raise error.VersionNotFound(message)
|
||||||
@@ -253,18 +283,80 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_highest_version(ns: str | None, name: str) -> int | None:
|
def _check_spec_register(testing_spec: EnvSpec):
|
||||||
"""Finds the highest registered version of the environment in the registry."""
|
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
|
||||||
version: list[int] = [
|
latest_versioned_spec = max(
|
||||||
spec_.version
|
(
|
||||||
for spec_ in registry.values()
|
env_spec
|
||||||
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
|
for env_spec in registry.values()
|
||||||
]
|
if env_spec.namespace == testing_spec.namespace
|
||||||
return max(version, default=None)
|
and env_spec.name == testing_spec.name
|
||||||
|
and env_spec.version is not None
|
||||||
|
),
|
||||||
|
key=lambda spec_: int(spec_.version), # type: ignore
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
unversioned_spec = next(
|
||||||
|
(
|
||||||
|
env_spec
|
||||||
|
for env_spec in registry.values()
|
||||||
|
if env_spec.namespace == testing_spec.namespace
|
||||||
|
and env_spec.name == testing_spec.name
|
||||||
|
and env_spec.version is None
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if unversioned_spec is not None and testing_spec.version is not None:
|
||||||
|
raise error.RegistrationError(
|
||||||
|
"Can't register the versioned environment "
|
||||||
|
f"`{testing_spec.id}` when the unversioned environment "
|
||||||
|
f"`{unversioned_spec.id}` of the same name already exists."
|
||||||
|
)
|
||||||
|
elif latest_versioned_spec is not None and testing_spec.version is None:
|
||||||
|
raise error.RegistrationError(
|
||||||
|
f"Can't register the unversioned environment `{testing_spec.id}` when the versioned environment "
|
||||||
|
f"`{latest_versioned_spec.id}` of the same name already exists. Note: the default behavior is "
|
||||||
|
"that `gym.make` with the unversioned environment will return the latest versioned environment"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
def _check_metadata(testing_metadata: dict[str, Any]):
|
||||||
"""Load modules (plugins) using the gymnasium entry points == to `entry_points`.
|
"""Check the metadata of an environment."""
|
||||||
|
if not isinstance(testing_metadata, dict):
|
||||||
|
raise error.InvalidMetadata(
|
||||||
|
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
render_modes = testing_metadata.get("render_modes")
|
||||||
|
if render_modes is None:
|
||||||
|
logger.warn(
|
||||||
|
f"The environment creator metadata doesn't include `render_modes`, contains: {list(testing_metadata.keys())}"
|
||||||
|
)
|
||||||
|
elif not isinstance(render_modes, Iterable):
|
||||||
|
logger.warn(
|
||||||
|
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_env(name: str) -> EnvCreator:
|
||||||
|
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The environment name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The environment constructor for the given environment name.
|
||||||
|
"""
|
||||||
|
mod_name, attr_name = name.split(":")
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
fn = getattr(mod, attr_name)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
||||||
|
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entry_point: The string for the entry point.
|
entry_point: The string for the entry point.
|
||||||
@@ -282,7 +374,7 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
|||||||
else:
|
else:
|
||||||
module, attr = plugin.value, None
|
module, attr = plugin.value, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
warnings.warn(
|
logger.warn(
|
||||||
f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
|
f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
|
||||||
)
|
)
|
||||||
module, attr = None, None
|
module, attr = None, None
|
||||||
@@ -314,136 +406,6 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
|||||||
logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
|
logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
@overload
|
|
||||||
def make(id: str, **kwargs) -> Env: ...
|
|
||||||
@overload
|
|
||||||
def make(id: EnvSpec, **kwargs) -> Env: ...
|
|
||||||
|
|
||||||
|
|
||||||
# Classic control
|
|
||||||
# ----------------------------------------
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
|
|
||||||
|
|
||||||
# Box2d
|
|
||||||
# ----------------------------------------
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
|
||||||
|
|
||||||
|
|
||||||
# Toy Text
|
|
||||||
# ----------------------------------------
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
@overload
|
|
||||||
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
|
||||||
|
|
||||||
|
|
||||||
# Mujoco
|
|
||||||
# ----------------------------------------
|
|
||||||
@overload
|
|
||||||
def make(id: Literal[
|
|
||||||
"Reacher-v2", "Reacher-v4",
|
|
||||||
"Pusher-v2", "Pusher-v4",
|
|
||||||
"InvertedPendulum-v2", "InvertedPendulum-v4",
|
|
||||||
"InvertedDoublePendulum-v2", "InvertedDoublePendulum-v4",
|
|
||||||
"HalfCheetah-v2", "HalfCheetah-v3", "HalfCheetah-v4",
|
|
||||||
"Hopper-v2", "Hopper-v3", "Hopper-v4",
|
|
||||||
"Swimmer-v2", "Swimmer-v3", "Swimmer-v4",
|
|
||||||
"Walker2d-v2", "Walker2d-v3", "Walker2d-v4",
|
|
||||||
"Ant-v2", "Ant-v3", "Ant-v4",
|
|
||||||
"HumanoidStandup-v2", "HumanoidStandup-v4",
|
|
||||||
"Humanoid-v2", "Humanoid-v3", "Humanoid-v4",
|
|
||||||
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
# Global registry of environments. Meant to be accessed through `register` and `make`
|
|
||||||
registry: dict[str, EnvSpec] = {}
|
|
||||||
current_namespace: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def _check_spec_register(spec: EnvSpec):
|
|
||||||
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
|
|
||||||
global registry
|
|
||||||
latest_versioned_spec = max(
|
|
||||||
(
|
|
||||||
spec_
|
|
||||||
for spec_ in registry.values()
|
|
||||||
if spec_.namespace == spec.namespace
|
|
||||||
and spec_.name == spec.name
|
|
||||||
and spec_.version is not None
|
|
||||||
),
|
|
||||||
key=lambda spec_: int(spec_.version), # type: ignore
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
unversioned_spec = next(
|
|
||||||
(
|
|
||||||
spec_
|
|
||||||
for spec_ in registry.values()
|
|
||||||
if spec_.namespace == spec.namespace
|
|
||||||
and spec_.name == spec.name
|
|
||||||
and spec_.version is None
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if unversioned_spec is not None and spec.version is not None:
|
|
||||||
raise error.RegistrationError(
|
|
||||||
"Can't register the versioned environment "
|
|
||||||
f"`{spec.id}` when the unversioned environment "
|
|
||||||
f"`{unversioned_spec.id}` of the same name already exists."
|
|
||||||
)
|
|
||||||
elif latest_versioned_spec is not None and spec.version is None:
|
|
||||||
raise error.RegistrationError(
|
|
||||||
"Can't register the unversioned environment "
|
|
||||||
f"`{spec.id}` when the versioned environment "
|
|
||||||
f"`{latest_versioned_spec.id}` of the same name "
|
|
||||||
f"already exists. Note: the default behavior is "
|
|
||||||
f"that `gym.make` with the unversioned environment "
|
|
||||||
f"will return the latest versioned environment"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_metadata(metadata_: dict):
|
|
||||||
if not isinstance(metadata_, dict):
|
|
||||||
raise error.InvalidMetadata(
|
|
||||||
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
render_modes = metadata_.get("render_modes")
|
|
||||||
if render_modes is None:
|
|
||||||
logger.warn(
|
|
||||||
f"The environment creator metadata doesn't include `render_modes`, contains: {list(metadata_.keys())}"
|
|
||||||
)
|
|
||||||
elif not isinstance(render_modes, Iterable):
|
|
||||||
logger.warn(
|
|
||||||
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Public API
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def namespace(ns: str):
|
def namespace(ns: str):
|
||||||
"""Context manager for modifying the current namespace."""
|
"""Context manager for modifying the current namespace."""
|
||||||
@@ -456,7 +418,7 @@ def namespace(ns: str):
|
|||||||
|
|
||||||
def register(
|
def register(
|
||||||
id: str,
|
id: str,
|
||||||
entry_point: Callable | str,
|
entry_point: EnvCreator | str,
|
||||||
reward_threshold: float | None = None,
|
reward_threshold: float | None = None,
|
||||||
nondeterministic: bool = False,
|
nondeterministic: bool = False,
|
||||||
max_episode_steps: int | None = None,
|
max_episode_steps: int | None = None,
|
||||||
@@ -464,26 +426,28 @@ def register(
|
|||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
disable_env_checker: bool = False,
|
disable_env_checker: bool = False,
|
||||||
apply_api_compatibility: bool = False,
|
apply_api_compatibility: bool = False,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Register an environment with gymnasium.
|
"""Registers an environment in gymnasium with an ``id`` to use with :meth:`gymnasium.make` with the ``entry_point`` being a string or callable for creating the environment.
|
||||||
|
|
||||||
The `id` parameter corresponds to the name of the environment, with the syntax as follows:
|
The ``id`` parameter corresponds to the name of the environment, with the syntax as follows:
|
||||||
`(namespace)/(env_name)-v(version)` where `namespace` is optional.
|
``[namespace/](env_name)[-v(version)]`` where ``namespace`` and ``-v(version)`` is optional.
|
||||||
|
|
||||||
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
|
It takes arbitrary keyword arguments, which are passed to the :class:`EnvSpec` ``kwargs`` parameter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
id: The environment id
|
id: The environment id
|
||||||
entry_point: The entry point for creating the environment
|
entry_point: The entry point for creating the environment
|
||||||
reward_threshold: The reward threshold considered to have learnt an environment
|
reward_threshold: The reward threshold considered for an agent to have learnt the environment
|
||||||
nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions)
|
nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions, the same state cannot be reached)
|
||||||
max_episode_steps: The maximum number of episodes steps before truncation. Used by the Time Limit wrapper.
|
max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
|
||||||
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order
|
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
|
||||||
autoreset: If to add the autoreset wrapper such that reset does not need to be called.
|
If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
|
||||||
disable_env_checker: If to disable the environment checker for the environment. Recommended to False.
|
autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called.
|
||||||
apply_api_compatibility: If to apply the `StepAPICompatibility` wrapper.
|
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
|
||||||
**kwargs: arbitrary keyword arguments which are passed to the environment constructor
|
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
|
||||||
|
Use if the environment is implemented in the gym v0.21 environment API.
|
||||||
|
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
|
||||||
"""
|
"""
|
||||||
global registry, current_namespace
|
global registry, current_namespace
|
||||||
ns, name, version = parse_env_id(id)
|
ns, name, version = parse_env_id(id)
|
||||||
@@ -502,10 +466,10 @@ def register(
|
|||||||
else:
|
else:
|
||||||
ns_id = ns
|
ns_id = ns
|
||||||
|
|
||||||
full_id = get_env_id(ns_id, name, version)
|
full_env_id = get_env_id(ns_id, name, version)
|
||||||
|
|
||||||
new_spec = EnvSpec(
|
new_spec = EnvSpec(
|
||||||
id=full_id,
|
id=full_env_id,
|
||||||
entry_point=entry_point,
|
entry_point=entry_point,
|
||||||
reward_threshold=reward_threshold,
|
reward_threshold=reward_threshold,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
@@ -517,6 +481,7 @@ def register(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
_check_spec_register(new_spec)
|
_check_spec_register(new_spec)
|
||||||
|
|
||||||
if new_spec.id in registry:
|
if new_spec.id in registry:
|
||||||
logger.warn(f"Overriding environment {new_spec.id} already in registry.")
|
logger.warn(f"Overriding environment {new_spec.id} already in registry.")
|
||||||
registry[new_spec.id] = new_spec
|
registry[new_spec.id] = new_spec
|
||||||
@@ -528,36 +493,36 @@ def make(
|
|||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
apply_api_compatibility: bool | None = None,
|
apply_api_compatibility: bool | None = None,
|
||||||
disable_env_checker: bool | None = None,
|
disable_env_checker: bool | None = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> Env:
|
) -> Env:
|
||||||
"""Create an environment according to the given ID.
|
"""Creates an environment previously registered with :meth:`gymnasium.register` or a :class:`EnvSpec`.
|
||||||
|
|
||||||
To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids.
|
To find all available environments use ``gymnasium.envs.registry.keys()`` for all valid ids.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
|
id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
|
||||||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
This is equivalent to importing the module first to register the environment followed by making the environment.
|
||||||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``.
|
||||||
apply_api_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that
|
The value is used by :class:`gymnasium.wrappers.TimeLimit`.
|
||||||
|
autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`).
|
||||||
|
apply_api_compatibility: Whether to wrap the environment with the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper that
|
||||||
converts the environment step from a done bool to return termination and truncation bools.
|
converts the environment step from a done bool to return termination and truncation bools.
|
||||||
By default, the argument is None to which the environment specification `apply_api_compatibility` is used
|
By default, the argument is None in which the :class:`EnvSpec` ``apply_api_compatibility`` is used, otherwise this variable is used in favor.
|
||||||
which defaults to False. Otherwise, the value of `apply_api_compatibility` is used.
|
disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the
|
||||||
If `True`, the wrapper is applied otherwise, the wrapper is not applied.
|
:class:`EnvSpec` ``disable_env_checker`` value otherwise use this value will be used.
|
||||||
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
|
|
||||||
(which is by default False, running the environment checker),
|
|
||||||
otherwise will run according to this parameter (`True` = not run, `False` = run)
|
|
||||||
kwargs: Additional arguments to pass to the environment constructor.
|
kwargs: Additional arguments to pass to the environment constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the environment.
|
An instance of the environment with wrappers applied.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Error: If the ``id`` doesn't exist then an error is raised
|
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
||||||
"""
|
"""
|
||||||
if isinstance(id, EnvSpec):
|
if isinstance(id, EnvSpec):
|
||||||
spec_ = id
|
env_spec = id
|
||||||
else:
|
else:
|
||||||
module, id = (None, id) if ":" not in id else id.split(":")
|
# The environment name can include an unloaded module in "module:env_name" style
|
||||||
|
module, env_name = (None, id) if ":" not in id else id.split(":")
|
||||||
if module is not None:
|
if module is not None:
|
||||||
try:
|
try:
|
||||||
importlib.import_module(module)
|
importlib.import_module(module)
|
||||||
@@ -566,9 +531,13 @@ def make(
|
|||||||
f"{e}. Environment registration via importing a module failed. "
|
f"{e}. Environment registration via importing a module failed. "
|
||||||
f"Check whether '{module}' contains env registration and can be imported."
|
f"Check whether '{module}' contains env registration and can be imported."
|
||||||
) from e
|
) from e
|
||||||
spec_ = registry.get(id)
|
|
||||||
|
|
||||||
ns, name, version = parse_env_id(id)
|
# load the env spec from the registry
|
||||||
|
env_spec = registry.get(env_name)
|
||||||
|
|
||||||
|
# update env spec is not version provided, raise warning if out of date
|
||||||
|
ns, name, version = parse_env_id(env_name)
|
||||||
|
|
||||||
latest_version = find_highest_version(ns, name)
|
latest_version = find_highest_version(ns, name)
|
||||||
if (
|
if (
|
||||||
version is not None
|
version is not None
|
||||||
@@ -576,38 +545,44 @@ def make(
|
|||||||
and latest_version > version
|
and latest_version > version
|
||||||
):
|
):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"The environment {id} is out of date. You should consider "
|
f"The environment {env_name} is out of date. You should consider "
|
||||||
f"upgrading to version `v{latest_version}`."
|
f"upgrading to version `v{latest_version}`."
|
||||||
)
|
)
|
||||||
if version is None and latest_version is not None:
|
if version is None and latest_version is not None:
|
||||||
version = latest_version
|
version = latest_version
|
||||||
new_env_id = get_env_id(ns, name, version)
|
new_env_id = get_env_id(ns, name, version)
|
||||||
spec_ = registry.get(new_env_id)
|
env_spec = registry.get(new_env_id)
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"Using the latest versioned environment `{new_env_id}` "
|
f"Using the latest versioned environment `{new_env_id}` "
|
||||||
f"instead of the unversioned environment `{id}`."
|
f"instead of the unversioned environment `{env_name}`."
|
||||||
)
|
)
|
||||||
|
|
||||||
if spec_ is None:
|
if env_spec is None:
|
||||||
_check_version_exists(ns, name, version)
|
_check_version_exists(ns, name, version)
|
||||||
raise error.Error(f"No registered env with id: {id}")
|
raise error.Error(f"No registered env with id: {env_name}")
|
||||||
|
|
||||||
_kwargs = spec_.kwargs.copy()
|
assert isinstance(
|
||||||
_kwargs.update(kwargs)
|
env_spec, EnvSpec
|
||||||
|
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}"
|
||||||
|
# Extract the spec kwargs and append the make kwargs
|
||||||
|
spec_kwargs = env_spec.kwargs.copy()
|
||||||
|
spec_kwargs.update(kwargs)
|
||||||
|
|
||||||
if spec_.entry_point is None:
|
# Load the environment creator
|
||||||
raise error.Error(f"{spec_.id} registered but entry_point is not specified")
|
if env_spec.entry_point is None:
|
||||||
elif callable(spec_.entry_point):
|
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
||||||
env_creator = spec_.entry_point
|
elif callable(env_spec.entry_point):
|
||||||
|
env_creator = env_spec.entry_point
|
||||||
else:
|
else:
|
||||||
# Assume it's a string
|
# Assume it's a string
|
||||||
env_creator = load(spec_.entry_point)
|
env_creator = load_env(env_spec.entry_point)
|
||||||
|
|
||||||
render_modes = None
|
# Determine if to use the rendering
|
||||||
|
render_modes: list[str] | None = None
|
||||||
if hasattr(env_creator, "metadata"):
|
if hasattr(env_creator, "metadata"):
|
||||||
_check_metadata(env_creator.metadata)
|
_check_metadata(env_creator.metadata)
|
||||||
render_modes = env_creator.metadata.get("render_modes")
|
render_modes = env_creator.metadata.get("render_modes")
|
||||||
mode = _kwargs.get("render_mode")
|
mode = spec_kwargs.get("render_mode")
|
||||||
apply_human_rendering = False
|
apply_human_rendering = False
|
||||||
apply_render_collection = False
|
apply_render_collection = False
|
||||||
|
|
||||||
@@ -619,10 +594,10 @@ def make(
|
|||||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
||||||
"The HumanRendering wrapper is being applied to your environment."
|
"The HumanRendering wrapper is being applied to your environment."
|
||||||
)
|
)
|
||||||
_kwargs["render_mode"] = displayable_modes.pop()
|
spec_kwargs["render_mode"] = displayable_modes.pop()
|
||||||
apply_human_rendering = True
|
apply_human_rendering = True
|
||||||
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
||||||
_kwargs["render_mode"] = mode[: -len("_list")]
|
spec_kwargs["render_mode"] = mode[: -len("_list")]
|
||||||
apply_render_collection = True
|
apply_render_collection = True
|
||||||
else:
|
else:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
@@ -630,16 +605,16 @@ def make(
|
|||||||
f"that is not in the possible render_modes ({render_modes})."
|
f"that is not in the possible render_modes ({render_modes})."
|
||||||
)
|
)
|
||||||
|
|
||||||
if apply_api_compatibility is True or (
|
if apply_api_compatibility or (
|
||||||
apply_api_compatibility is None and spec_.apply_api_compatibility is True
|
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
||||||
):
|
):
|
||||||
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
||||||
render_mode = _kwargs.pop("render_mode", None)
|
render_mode = spec_kwargs.pop("render_mode", None)
|
||||||
else:
|
else:
|
||||||
render_mode = None
|
render_mode = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
env = env_creator(**_kwargs)
|
env = env_creator(**spec_kwargs)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
if (
|
if (
|
||||||
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
||||||
@@ -654,31 +629,31 @@ def make(
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||||
spec_ = copy.deepcopy(spec_)
|
env_spec = copy.deepcopy(env_spec)
|
||||||
spec_.kwargs = _kwargs
|
env_spec.kwargs = spec_kwargs
|
||||||
env.unwrapped.spec = spec_
|
env.unwrapped.spec = env_spec
|
||||||
|
|
||||||
# Add step API wrapper
|
# Add step API wrapper
|
||||||
if apply_api_compatibility is True or (
|
if apply_api_compatibility is True or (
|
||||||
apply_api_compatibility is None and spec_.apply_api_compatibility is True
|
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
||||||
):
|
):
|
||||||
env = EnvCompatibility(env, render_mode)
|
env = EnvCompatibility(env, render_mode)
|
||||||
|
|
||||||
# Run the environment checker as the lowest level wrapper
|
# Run the environment checker as the lowest level wrapper
|
||||||
if disable_env_checker is False or (
|
if disable_env_checker is False or (
|
||||||
disable_env_checker is None and spec_.disable_env_checker is False
|
disable_env_checker is None and env_spec.disable_env_checker is False
|
||||||
):
|
):
|
||||||
env = PassiveEnvChecker(env)
|
env = PassiveEnvChecker(env)
|
||||||
|
|
||||||
# Add the order enforcing wrapper
|
# Add the order enforcing wrapper
|
||||||
if spec_.order_enforce:
|
if env_spec.order_enforce:
|
||||||
env = OrderEnforcing(env)
|
env = OrderEnforcing(env)
|
||||||
|
|
||||||
# Add the time limit wrapper
|
# Add the time limit wrapper
|
||||||
if max_episode_steps is not None:
|
if max_episode_steps is not None:
|
||||||
env = TimeLimit(env, max_episode_steps)
|
env = TimeLimit(env, max_episode_steps)
|
||||||
elif spec_.max_episode_steps is not None:
|
elif env_spec.max_episode_steps is not None:
|
||||||
env = TimeLimit(env, spec_.max_episode_steps)
|
env = TimeLimit(env, env_spec.max_episode_steps)
|
||||||
|
|
||||||
# Add the autoreset wrapper
|
# Add the autoreset wrapper
|
||||||
if autoreset:
|
if autoreset:
|
||||||
@@ -694,74 +669,101 @@ def make(
|
|||||||
|
|
||||||
|
|
||||||
def spec(env_id: str) -> EnvSpec:
|
def spec(env_id: str) -> EnvSpec:
|
||||||
"""Retrieve the spec for the given environment from the global registry."""
|
"""Retrieve the :class:`EnvSpec` for the environment id from the :attr:`registry`.
|
||||||
spec_ = registry.get(env_id)
|
|
||||||
if spec_ is None:
|
Args:
|
||||||
|
env_id: The environment id with the expected format of ``[(namespace)/]id[-v(version)]``
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The environment spec if it exists
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Error: If the environment id doesn't exist
|
||||||
|
"""
|
||||||
|
env_spec = registry.get(env_id)
|
||||||
|
if env_spec is None:
|
||||||
ns, name, version = parse_env_id(env_id)
|
ns, name, version = parse_env_id(env_id)
|
||||||
_check_version_exists(ns, name, version)
|
_check_version_exists(ns, name, version)
|
||||||
raise error.Error(f"No registered env with id: {env_id}")
|
raise error.Error(f"No registered env with id: {env_id}")
|
||||||
else:
|
else:
|
||||||
assert isinstance(spec_, EnvSpec)
|
assert isinstance(
|
||||||
return spec_
|
env_spec, EnvSpec
|
||||||
|
), f"Expected the registry for {env_id} to be an `EnvSpec`, actual type is {type(env_spec)}"
|
||||||
|
return env_spec
|
||||||
|
|
||||||
|
|
||||||
def pprint_registry(
|
def pprint_registry(
|
||||||
_registry: dict = registry,
|
print_registry: dict[str, EnvSpec] = registry,
|
||||||
|
*,
|
||||||
num_cols: int = 3,
|
num_cols: int = 3,
|
||||||
exclude_namespaces: list[str] | None = None,
|
exclude_namespaces: list[str] | None = None,
|
||||||
disable_print: bool = False,
|
disable_print: bool = False,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Pretty print the environments in the registry.
|
"""Pretty prints all environments in the :attr:`registry`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
All arguments are keyword only
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_registry: Environment registry to be printed.
|
print_registry: Environment registry to be printed. By default, :attr:`registry`
|
||||||
num_cols: Number of columns to arrange environments in, for display.
|
num_cols: Number of columns to arrange environments in, for display.
|
||||||
exclude_namespaces: Exclude any namespaces from being printed.
|
exclude_namespaces: A list of namespaces to be excluded from printing. Helpful if only ALE environments are wanted.
|
||||||
disable_print: Whether to return a string of all the namespaces and environment IDs
|
disable_print: Whether to return a string of all the namespaces and environment IDs
|
||||||
instead of printing it to console.
|
or to print the string to console.
|
||||||
"""
|
"""
|
||||||
# Defaultdict to store environment names according to namespace.
|
# Defaultdict to store environment ids according to namespace.
|
||||||
namespace_envs = defaultdict(lambda: [])
|
namespace_envs: dict[str, list[str]] = defaultdict(lambda: [])
|
||||||
max_justify = float("-inf")
|
max_justify = float("-inf")
|
||||||
for env in _registry.values():
|
|
||||||
namespace, _, _ = parse_env_id(env.id)
|
|
||||||
if namespace is None:
|
|
||||||
# Since namespace is currently none, use regex to obtain namespace from entrypoints.
|
|
||||||
env_entry_point = re.sub(r":\w+", "", env.entry_point)
|
|
||||||
e_ep_split = env_entry_point.split(".")
|
|
||||||
if len(e_ep_split) >= 3:
|
|
||||||
# If namespace is of the format - gymnasium.envs.mujoco.ant_v4:AntEnv
|
|
||||||
# or gymnasium.envs.mujoco:HumanoidEnv
|
|
||||||
idx = 2
|
|
||||||
namespace = e_ep_split[idx]
|
|
||||||
elif len(e_ep_split) > 1:
|
|
||||||
# If namespace is of the format - shimmy.atari_env
|
|
||||||
idx = 1
|
|
||||||
namespace = e_ep_split[idx]
|
|
||||||
else:
|
|
||||||
# If namespace cannot be found, default to env id.
|
|
||||||
namespace = env.id
|
|
||||||
namespace_envs[namespace].append(env.id)
|
|
||||||
max_justify = max(max_justify, len(env.id))
|
|
||||||
|
|
||||||
# Iterate through each namespace and print environment alphabetically.
|
# Find the namespace associated with each environment spec
|
||||||
return_str = ""
|
for env_spec in print_registry.values():
|
||||||
for namespace, envs in namespace_envs.items():
|
ns = env_spec.namespace
|
||||||
|
|
||||||
|
if ns is None and isinstance(env_spec.entry_point, str):
|
||||||
|
# Use regex to obtain namespace from entrypoints.
|
||||||
|
env_entry_point = re.sub(r":\w+", "", env_spec.entry_point)
|
||||||
|
split_entry_point = env_entry_point.split(".")
|
||||||
|
|
||||||
|
if len(split_entry_point) >= 3:
|
||||||
|
# If namespace is of the format:
|
||||||
|
# - gymnasium.envs.mujoco.ant_v4:AntEnv
|
||||||
|
# - gymnasium.envs.mujoco:HumanoidEnv
|
||||||
|
ns = split_entry_point[2]
|
||||||
|
elif len(split_entry_point) > 1:
|
||||||
|
# If namespace is of the format - shimmy.atari_env
|
||||||
|
ns = split_entry_point[1]
|
||||||
|
else:
|
||||||
|
# If namespace cannot be found, default to env name
|
||||||
|
ns = env_spec.name
|
||||||
|
|
||||||
|
namespace_envs[ns].append(env_spec.id)
|
||||||
|
max_justify = max(max_justify, len(env_spec.name))
|
||||||
|
|
||||||
|
# Iterate through each namespace and print environment alphabetically
|
||||||
|
output: list[str] = []
|
||||||
|
for ns, env_ids in namespace_envs.items():
|
||||||
# Ignore namespaces to exclude.
|
# Ignore namespaces to exclude.
|
||||||
if exclude_namespaces is not None and namespace in exclude_namespaces:
|
if exclude_namespaces is not None and ns in exclude_namespaces:
|
||||||
continue
|
continue
|
||||||
return_str += f"{'=' * 5} {namespace} {'=' * 5}\n" # Print namespace.
|
|
||||||
|
# Print the namespace
|
||||||
|
namespace_output = f"{'=' * 5} {ns} {'=' * 5}\n"
|
||||||
|
|
||||||
# Reference: https://stackoverflow.com/a/33464001
|
# Reference: https://stackoverflow.com/a/33464001
|
||||||
for count, item in enumerate(sorted(envs), 1):
|
for count, env_id in enumerate(sorted(env_ids), 1):
|
||||||
return_str += (
|
# Print column with justification.
|
||||||
item.ljust(max_justify) + " "
|
namespace_output += env_id.ljust(max_justify) + " "
|
||||||
) # Print column with justification.
|
|
||||||
# Once all rows printed, switch to new column.
|
# Once all rows printed, switch to new column.
|
||||||
if count % num_cols == 0 or count == len(envs):
|
if count % num_cols == 0:
|
||||||
return_str = return_str.rstrip(" ") + "\n"
|
namespace_output = namespace_output.rstrip(" ")
|
||||||
return_str += "\n"
|
|
||||||
|
if count != len(env_ids):
|
||||||
|
namespace_output += "\n"
|
||||||
|
|
||||||
|
output.append(namespace_output.rstrip(" "))
|
||||||
|
|
||||||
if disable_print:
|
if disable_print:
|
||||||
return return_str
|
return "\n".join(output)
|
||||||
else:
|
else:
|
||||||
print(return_str, end="")
|
print("\n".join(output))
|
||||||
|
0
tests/envs/registration/__init__.py
Normal file
0
tests/envs/registration/__init__.py
Normal file
@@ -1,14 +1,14 @@
|
|||||||
"""Tests that gym.make works as expected."""
|
"""Tests that gym.make works as expected."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.classic_control import cartpole
|
from gymnasium.envs.classic_control import CartPoleEnv
|
||||||
from gymnasium.wrappers import (
|
from gymnasium.wrappers import (
|
||||||
AutoResetWrapper,
|
AutoResetWrapper,
|
||||||
HumanRendering,
|
HumanRendering,
|
||||||
@@ -16,9 +16,8 @@ from gymnasium.wrappers import (
|
|||||||
TimeLimit,
|
TimeLimit,
|
||||||
)
|
)
|
||||||
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
||||||
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
|
from tests.envs.registration.utils_envs import ArgumentEnv
|
||||||
from tests.envs.utils import all_testing_env_specs
|
from tests.envs.utils import all_testing_env_specs
|
||||||
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
|
||||||
from tests.testing_env import GenericTestEnv, old_step_func
|
from tests.testing_env import GenericTestEnv, old_step_func
|
||||||
from tests.wrappers.utils import has_wrapper
|
from tests.wrappers.utils import has_wrapper
|
||||||
|
|
||||||
@@ -30,16 +29,11 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def register_make_testing_envs():
|
def register_testing_envs():
|
||||||
"""Registers testing envs for `gym.make`"""
|
"""Registers testing envs for `gym.make`"""
|
||||||
gym.register(
|
|
||||||
"RegisterDuringMakeEnv-v0",
|
|
||||||
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
|
|
||||||
)
|
|
||||||
|
|
||||||
gym.register(
|
gym.register(
|
||||||
id="test.ArgumentEnv-v0",
|
id="test.ArgumentEnv-v0",
|
||||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
||||||
kwargs={
|
kwargs={
|
||||||
"arg1": "arg1",
|
"arg1": "arg1",
|
||||||
"arg2": "arg2",
|
"arg2": "arg2",
|
||||||
@@ -48,26 +42,25 @@ def register_make_testing_envs():
|
|||||||
|
|
||||||
gym.register(
|
gym.register(
|
||||||
id="test/NoHuman-v0",
|
id="test/NoHuman-v0",
|
||||||
entry_point="tests.envs.utils_envs:NoHuman",
|
entry_point="tests.envs.registration.utils_envs:NoHuman",
|
||||||
)
|
)
|
||||||
gym.register(
|
gym.register(
|
||||||
id="test/NoHumanOldAPI-v0",
|
id="test/NoHumanOldAPI-v0",
|
||||||
entry_point="tests.envs.utils_envs:NoHumanOldAPI",
|
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
|
||||||
)
|
)
|
||||||
|
|
||||||
gym.register(
|
gym.register(
|
||||||
id="test/NoHumanNoRGB-v0",
|
id="test/NoHumanNoRGB-v0",
|
||||||
entry_point="tests.envs.utils_envs:NoHumanNoRGB",
|
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
|
||||||
)
|
)
|
||||||
|
|
||||||
gym.register(
|
gym.register(
|
||||||
id="test/NoRenderModesMetadata-v0",
|
id="test/NoRenderModesMetadata-v0",
|
||||||
entry_point="tests.envs.utils_envs:NoRenderModesMetadata",
|
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
del gym.envs.registration.registry["RegisterDuringMakeEnv-v0"]
|
|
||||||
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
|
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
|
||||||
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
|
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
|
||||||
del gym.envs.registration.registry["test/NoHuman-v0"]
|
del gym.envs.registration.registry["test/NoHuman-v0"]
|
||||||
@@ -76,14 +69,16 @@ def register_make_testing_envs():
|
|||||||
|
|
||||||
|
|
||||||
def test_make():
|
def test_make():
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
"""Test basic `gym.make`."""
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
assert env.spec is not None
|
assert env.spec is not None
|
||||||
assert env.spec.id == "CartPole-v1"
|
assert env.spec.id == "CartPole-v1"
|
||||||
assert isinstance(env.unwrapped, cartpole.CartPoleEnv)
|
assert isinstance(env.unwrapped, CartPoleEnv)
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
def test_make_deprecated():
|
def test_make_deprecated():
|
||||||
|
"""Test make with a deprecated environment (i.e., doesn't exist)."""
|
||||||
with warnings.catch_warnings(record=True):
|
with warnings.catch_warnings(record=True):
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
gym.error.Error,
|
gym.error.Error,
|
||||||
@@ -91,21 +86,20 @@ def test_make_deprecated():
|
|||||||
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
|
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
gym.make("Humanoid-v0", disable_env_checker=True)
|
gym.make("Humanoid-v0")
|
||||||
|
|
||||||
|
|
||||||
def test_make_max_episode_steps(register_make_testing_envs):
|
def test_make_max_episode_steps(register_testing_envs):
|
||||||
# Default, uses the spec's
|
# Default, uses the spec's
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
env = gym.make("CartPole-v1")
|
||||||
assert has_wrapper(env, TimeLimit)
|
assert has_wrapper(env, TimeLimit)
|
||||||
assert env.spec is not None
|
assert env.spec is not None
|
||||||
assert (
|
assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps
|
||||||
env.spec.max_episode_steps == gym.envs.registry["CartPole-v1"].max_episode_steps
|
|
||||||
)
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
# Custom max episode steps
|
# Custom max episode steps
|
||||||
env = gym.make("CartPole-v1", max_episode_steps=100, disable_env_checker=True)
|
assert gym.spec("CartPole-v1").max_episode_steps != 100
|
||||||
|
env = gym.make("CartPole-v1", max_episode_steps=100)
|
||||||
assert has_wrapper(env, TimeLimit)
|
assert has_wrapper(env, TimeLimit)
|
||||||
assert env.spec is not None
|
assert env.spec is not None
|
||||||
assert env.spec.max_episode_steps == 100
|
assert env.spec.max_episode_steps == 100
|
||||||
@@ -113,20 +107,20 @@ def test_make_max_episode_steps(register_make_testing_envs):
|
|||||||
|
|
||||||
# Env spec has no max episode steps
|
# Env spec has no max episode steps
|
||||||
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
|
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
|
||||||
env = gym.make(
|
env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None)
|
||||||
"test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None, disable_env_checker=True
|
assert env.spec is not None
|
||||||
)
|
assert env.spec.max_episode_steps is None
|
||||||
assert has_wrapper(env, TimeLimit) is False
|
assert has_wrapper(env, TimeLimit) is False
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
def test_gym_make_autoreset():
|
def test_make_autoreset():
|
||||||
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
|
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
env = gym.make("CartPole-v1")
|
||||||
assert has_wrapper(env, AutoResetWrapper) is False
|
assert has_wrapper(env, AutoResetWrapper) is False
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", autoreset=False, disable_env_checker=True)
|
env = gym.make("CartPole-v1", autoreset=False)
|
||||||
assert has_wrapper(env, AutoResetWrapper) is False
|
assert has_wrapper(env, AutoResetWrapper) is False
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
@@ -135,43 +129,49 @@ def test_gym_make_autoreset():
|
|||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
def test_make_disable_env_checker():
|
@pytest.mark.parametrize(
|
||||||
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`."""
|
"registration_disabled, make_disabled, if_disabled",
|
||||||
spec = deepcopy(gym.spec("CartPole-v1"))
|
[
|
||||||
|
[False, False, False],
|
||||||
|
[False, True, True],
|
||||||
|
[True, False, False],
|
||||||
|
[True, True, True],
|
||||||
|
[False, None, False],
|
||||||
|
[True, None, True],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_make_disable_env_checker(
|
||||||
|
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
|
||||||
|
):
|
||||||
|
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
|
||||||
|
|
||||||
# Test with spec disable env checker
|
The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
|
||||||
spec.disable_env_checker = False
|
"""
|
||||||
env = gym.make(spec)
|
gym.register(
|
||||||
assert has_wrapper(env, PassiveEnvChecker)
|
"testing-env-v0",
|
||||||
|
lambda: GenericTestEnv(),
|
||||||
|
disable_env_checker=registration_disabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test when the registered EnvSpec.disable_env_checker = False
|
||||||
|
env = gym.make("testing-env-v0", disable_env_checker=make_disabled)
|
||||||
|
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
# Test with overwritten spec using make disable env checker
|
del gym.registry["testing-env-v0"]
|
||||||
assert spec.disable_env_checker is False
|
|
||||||
env = gym.make(spec, disable_env_checker=True)
|
|
||||||
assert has_wrapper(env, PassiveEnvChecker) is False
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
# Test with spec enabled disable env checker
|
|
||||||
spec.disable_env_checker = True
|
|
||||||
env = gym.make(spec)
|
|
||||||
assert has_wrapper(env, PassiveEnvChecker) is False
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
# Test with overwritten spec using make disable env checker
|
|
||||||
assert spec.disable_env_checker is True
|
|
||||||
env = gym.make(spec, disable_env_checker=False)
|
|
||||||
assert has_wrapper(env, PassiveEnvChecker)
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_apply_api_compatibility():
|
def test_make_apply_api_compatibility():
|
||||||
|
"""Test the API compatibility wrapper."""
|
||||||
gym.register(
|
gym.register(
|
||||||
"testing-old-env",
|
"testing-old-env",
|
||||||
lambda: GenericTestEnv(step_func=old_step_func),
|
lambda: GenericTestEnv(step_func=old_step_func),
|
||||||
apply_api_compatibility=True,
|
apply_api_compatibility=True,
|
||||||
max_episode_steps=3,
|
max_episode_steps=3,
|
||||||
)
|
)
|
||||||
|
# Apply the environment compatibility and check it works as intended
|
||||||
env = gym.make("testing-old-env")
|
env = gym.make("testing-old-env")
|
||||||
|
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
||||||
|
|
||||||
env.reset()
|
env.reset()
|
||||||
assert len(env.step(env.action_space.sample())) == 5
|
assert len(env.step(env.action_space.sample())) == 5
|
||||||
@@ -179,11 +179,19 @@ def test_apply_api_compatibility():
|
|||||||
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
||||||
assert termination is False and truncation is True
|
assert termination is False and truncation is True
|
||||||
|
|
||||||
|
# Turn off the spec api compatibility
|
||||||
gym.spec("testing-old-env").apply_api_compatibility = False
|
gym.spec("testing-old-env").apply_api_compatibility = False
|
||||||
env = gym.make("testing-old-env")
|
env = gym.make("testing-old-env")
|
||||||
# Cannot run reset and step as will not work
|
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
|
||||||
|
env.reset()
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)")
|
||||||
|
):
|
||||||
|
env.step(env.action_space.sample())
|
||||||
|
|
||||||
|
# Apply the environment compatibility and check it works as intended
|
||||||
env = gym.make("testing-old-env", apply_api_compatibility=True)
|
env = gym.make("testing-old-env", apply_api_compatibility=True)
|
||||||
|
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
||||||
|
|
||||||
env.reset()
|
env.reset()
|
||||||
assert len(env.step(env.action_space.sample())) == 5
|
assert len(env.step(env.action_space.sample())) == 5
|
||||||
@@ -191,57 +199,63 @@ def test_apply_api_compatibility():
|
|||||||
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
||||||
assert termination is False and truncation is True
|
assert termination is False and truncation is True
|
||||||
|
|
||||||
gym.envs.registry.pop("testing-old-env")
|
del gym.registry["testing-old-env"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
|
|
||||||
)
|
|
||||||
def test_passive_checker_wrapper_warnings(spec):
|
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
|
||||||
env = gym.make(spec) # disable_env_checker=False
|
|
||||||
env.reset()
|
|
||||||
env.step(env.action_space.sample())
|
|
||||||
# todo, add check for render, bugged due to mujoco v2/3 and v4 envs
|
|
||||||
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
for warning in caught_warnings:
|
|
||||||
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
|
|
||||||
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
|
||||||
|
|
||||||
|
|
||||||
def test_make_order_enforcing():
|
def test_make_order_enforcing():
|
||||||
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
|
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
|
||||||
assert all(spec.order_enforce is True for spec in all_testing_env_specs)
|
assert all(spec.order_enforce is True for spec in all_testing_env_specs)
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
env = gym.make("CartPole-v1")
|
||||||
assert has_wrapper(env, OrderEnforcing)
|
assert has_wrapper(env, OrderEnforcing)
|
||||||
# We can assume that there all other specs will also have the order enforcing
|
# We can assume that there all other specs will also have the order enforcing
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
gym.register(
|
gym.register(
|
||||||
id="test.OrderlessArgumentEnv-v0",
|
id="test.OrderlessArgumentEnv-v0",
|
||||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
||||||
order_enforce=False,
|
order_enforce=False,
|
||||||
kwargs={"arg1": None, "arg2": None, "arg3": None},
|
kwargs={"arg1": None, "arg2": None, "arg3": None},
|
||||||
)
|
)
|
||||||
|
|
||||||
env = gym.make("test.OrderlessArgumentEnv-v0", disable_env_checker=True)
|
env = gym.make("test.OrderlessArgumentEnv-v0")
|
||||||
assert has_wrapper(env, OrderEnforcing) is False
|
assert has_wrapper(env, OrderEnforcing) is False
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
# There is no `make(..., order_enforcing=...)` so we don't test that
|
||||||
|
|
||||||
def test_make_render_mode(register_make_testing_envs):
|
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
def test_make_render_mode():
|
||||||
|
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
|
||||||
|
env = gym.make("CartPole-v1", render_mode=None)
|
||||||
assert env.render_mode is None
|
assert env.render_mode is None
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
assert "rgb_array" in env.metadata["render_modes"]
|
||||||
|
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
assert env.render_mode == "rgb_array"
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
assert "no-render-mode" not in env.metadata["render_modes"]
|
||||||
|
# cartpole is special that it doesn't check the render_mode passed at initialisation
|
||||||
|
with pytest.warns(
|
||||||
|
UserWarning,
|
||||||
|
match=re.escape(
|
||||||
|
"\x1b[33mWARN: The environment is being initialised with render_mode='no-render-mode' that is not in the possible render_modes (['human', 'rgb_array']).\x1b[0m"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
env = gym.make("CartPole-v1", render_mode="no-render-mode")
|
||||||
|
assert env.render_mode == "no-render-mode"
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_render_collection():
|
||||||
# Make sure that render_mode is applied correctly
|
# Make sure that render_mode is applied correctly
|
||||||
env = gym.make(
|
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
|
||||||
"CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
|
assert has_wrapper(env, gym.wrappers.RenderCollection)
|
||||||
)
|
|
||||||
assert env.render_mode == "rgb_array_list"
|
assert env.render_mode == "rgb_array_list"
|
||||||
|
assert env.unwrapped.render_mode == "rgb_array"
|
||||||
|
|
||||||
env.reset()
|
env.reset()
|
||||||
renders = env.render()
|
renders = env.render()
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@@ -250,24 +264,10 @@ def test_make_render_mode(register_make_testing_envs):
|
|||||||
assert isinstance(renders[0], np.ndarray)
|
assert isinstance(renders[0], np.ndarray)
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", render_mode=None, disable_env_checker=True)
|
|
||||||
assert env.render_mode is None
|
|
||||||
valid_render_modes = env.metadata["render_modes"]
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
assert len(valid_render_modes) > 0
|
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
|
||||||
env = gym.make(
|
|
||||||
"CartPole-v1", render_mode=valid_render_modes[0], disable_env_checker=True
|
|
||||||
)
|
|
||||||
assert env.render_mode == valid_render_modes[0]
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
for warning in caught_warnings:
|
|
||||||
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
|
||||||
|
|
||||||
|
def test_make_human_rendering(register_testing_envs):
|
||||||
# Make sure that native rendering is used when possible
|
# Make sure that native rendering is used when possible
|
||||||
env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True)
|
env = gym.make("CartPole-v1", render_mode="human")
|
||||||
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
|
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
|
||||||
assert env.render_mode == "human"
|
assert env.render_mode == "human"
|
||||||
env.close()
|
env.close()
|
||||||
@@ -278,10 +278,8 @@ def test_make_render_mode(register_make_testing_envs):
|
|||||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. The HumanRendering wrapper is being applied to your environment."
|
"You are trying to use 'human' rendering for an environment that doesn't natively support it. The HumanRendering wrapper is being applied to your environment."
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
# Make sure that `HumanRendering` is applied here
|
# Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
|
||||||
env = gym.make(
|
env = gym.make("test/NoHuman-v0", render_mode="human")
|
||||||
"test/NoHuman-v0", render_mode="human", disable_env_checker=True
|
|
||||||
) # This environment doesn't use native rendering
|
|
||||||
assert has_wrapper(env, HumanRendering)
|
assert has_wrapper(env, HumanRendering)
|
||||||
assert env.render_mode == "human"
|
assert env.render_mode == "human"
|
||||||
env.close()
|
env.close()
|
||||||
@@ -292,7 +290,6 @@ def test_make_render_mode(register_make_testing_envs):
|
|||||||
gym.make(
|
gym.make(
|
||||||
"test/NoHumanOldAPI-v0",
|
"test/NoHumanOldAPI-v0",
|
||||||
render_mode="rgb_array_list",
|
render_mode="rgb_array_list",
|
||||||
disable_env_checker=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure that an additional error is thrown a user tries to use the wrapper on an environment with old API
|
# Make sure that an additional error is thrown a user tries to use the wrapper on an environment with old API
|
||||||
@@ -303,9 +300,7 @@ def test_make_render_mode(register_make_testing_envs):
|
|||||||
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively."
|
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively."
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
gym.make(
|
gym.make("test/NoHumanOldAPI-v0", render_mode="human")
|
||||||
"test/NoHumanOldAPI-v0", render_mode="human", disable_env_checker=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
|
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
|
||||||
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
|
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
|
||||||
@@ -326,15 +321,20 @@ def test_make_render_mode(register_make_testing_envs):
|
|||||||
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array")
|
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array")
|
||||||
|
|
||||||
|
|
||||||
def test_make_kwargs(register_make_testing_envs):
|
def test_make_kwargs(register_testing_envs):
|
||||||
env = gym.make(
|
env = gym.make(
|
||||||
"test.ArgumentEnv-v0",
|
"test.ArgumentEnv-v0",
|
||||||
arg2="override_arg2",
|
arg2="override_arg2",
|
||||||
arg3="override_arg3",
|
arg3="override_arg3",
|
||||||
disable_env_checker=True,
|
|
||||||
)
|
)
|
||||||
assert env.spec is not None
|
assert env.spec is not None
|
||||||
assert env.spec.id == "test.ArgumentEnv-v0"
|
assert env.spec.id == "test.ArgumentEnv-v0"
|
||||||
|
assert env.spec.kwargs == {
|
||||||
|
"arg1": "arg1",
|
||||||
|
"arg2": "override_arg2",
|
||||||
|
"arg3": "override_arg3",
|
||||||
|
}
|
||||||
|
|
||||||
assert isinstance(env.unwrapped, ArgumentEnv)
|
assert isinstance(env.unwrapped, ArgumentEnv)
|
||||||
assert env.arg1 == "arg1"
|
assert env.arg1 == "arg1"
|
||||||
assert env.arg2 == "override_arg2"
|
assert env.arg2 == "override_arg2"
|
||||||
@@ -342,11 +342,16 @@ def test_make_kwargs(register_make_testing_envs):
|
|||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
def test_import_module_during_make(register_make_testing_envs):
|
def test_import_module_during_make():
|
||||||
# Test custom environment which is registered at make
|
# Test custom environment which is registered at make
|
||||||
|
assert "RegisterDuringMake-v0" not in gym.registry
|
||||||
env = gym.make(
|
env = gym.make(
|
||||||
"tests.envs.utils:RegisterDuringMakeEnv-v0",
|
"tests.envs.registration.utils_unregistered_env:RegisterDuringMake-v0"
|
||||||
disable_env_checker=True,
|
|
||||||
)
|
)
|
||||||
|
assert "RegisterDuringMake-v0" in gym.registry
|
||||||
|
from tests.envs.registration.utils_unregistered_env import RegisterDuringMakeEnv
|
||||||
|
|
||||||
assert isinstance(env.unwrapped, RegisterDuringMakeEnv)
|
assert isinstance(env.unwrapped, RegisterDuringMakeEnv)
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
del gym.registry["RegisterDuringMake-v0"]
|
112
tests/envs/registration/test_pprint_registry.py
Normal file
112
tests/envs/registration/test_pprint_registry.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.envs.classic_control import CartPoleEnv
|
||||||
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
|
# To ignore the trailing whitespaces, will need flake to ignore this file.
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
EXAMPLE_ENTRY_POINT = "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pprint_default_registry():
|
||||||
|
out = gym.pprint_registry(disable_print=True)
|
||||||
|
assert isinstance(out, str) and len(out) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_pprint_example_registry():
|
||||||
|
"""Testing a registry different from default."""
|
||||||
|
example_registry: dict[str, EnvSpec] = {
|
||||||
|
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
|
||||||
|
}
|
||||||
|
|
||||||
|
out = gym.pprint_registry(example_registry, disable_print=True)
|
||||||
|
correct_out = """===== classic_control =====
|
||||||
|
CartPole-v0 CartPole-v1 CartPole-v2
|
||||||
|
CartPole-v3"""
|
||||||
|
assert out == correct_out
|
||||||
|
|
||||||
|
|
||||||
|
def test_pprint_namespace():
|
||||||
|
example_registry: dict[str, EnvSpec] = {
|
||||||
|
"CartPole-v0": EnvSpec(
|
||||||
|
"CartPole-v0", "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
|
||||||
|
),
|
||||||
|
"CartPole-v1": EnvSpec(
|
||||||
|
"CartPole-v1", "gymnasium.envs.classic_control:CartPoleEnv"
|
||||||
|
),
|
||||||
|
"CartPole-v2": EnvSpec("CartPole-v2", "gymnasium.cartpole:CartPoleEnv"),
|
||||||
|
"CartPole-v3": EnvSpec("CartPole-v3", lambda: CartPoleEnv()),
|
||||||
|
"ExampleNamespace/CartPole-v2": EnvSpec(
|
||||||
|
"ExampleNamespace/CartPole-v2", "gymnasium.envs.classic_control:CartPoleEnv"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
out = gym.pprint_registry(example_registry, disable_print=True)
|
||||||
|
correct_out = """===== classic_control =====
|
||||||
|
CartPole-v0 CartPole-v1
|
||||||
|
===== cartpole =====
|
||||||
|
CartPole-v2
|
||||||
|
===== None =====
|
||||||
|
CartPole-v3
|
||||||
|
===== ExampleNamespace =====
|
||||||
|
ExampleNamespace/CartPole-v2"""
|
||||||
|
assert out == correct_out
|
||||||
|
|
||||||
|
|
||||||
|
def test_pprint_n_columns():
|
||||||
|
example_registry = {
|
||||||
|
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
|
||||||
|
}
|
||||||
|
|
||||||
|
out = gym.pprint_registry(example_registry, num_cols=2, disable_print=True)
|
||||||
|
correct_out = """===== classic_control =====
|
||||||
|
CartPole-v0 CartPole-v1
|
||||||
|
CartPole-v2 CartPole-v3"""
|
||||||
|
assert out == correct_out
|
||||||
|
|
||||||
|
out = gym.pprint_registry(example_registry, num_cols=5, disable_print=True)
|
||||||
|
correct_out = """===== classic_control =====
|
||||||
|
CartPole-v0 CartPole-v1 CartPole-v2 CartPole-v3"""
|
||||||
|
assert out == correct_out
|
||||||
|
|
||||||
|
|
||||||
|
def test_pprint_exclude_namespace():
|
||||||
|
example_registry: dict[str, EnvSpec] = {
|
||||||
|
"Test/CartPole-v0": EnvSpec("Test/CartPole-v0", EXAMPLE_ENTRY_POINT),
|
||||||
|
"Test/CartPole-v1": EnvSpec("Test/CartPole-v1", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
|
||||||
|
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
|
||||||
|
}
|
||||||
|
|
||||||
|
out = gym.pprint_registry(
|
||||||
|
example_registry, exclude_namespaces=["Test"], disable_print=True
|
||||||
|
)
|
||||||
|
correct_out = """===== classic_control =====
|
||||||
|
CartPole-v2 CartPole-v3"""
|
||||||
|
assert out == correct_out
|
||||||
|
|
||||||
|
out = gym.pprint_registry(
|
||||||
|
example_registry, exclude_namespaces=["classic_control"], disable_print=True
|
||||||
|
)
|
||||||
|
correct_out = """===== Test =====
|
||||||
|
Test/CartPole-v0 Test/CartPole-v1"""
|
||||||
|
assert out == correct_out
|
||||||
|
|
||||||
|
example_registry["Example/CartPole-v4"] = EnvSpec(
|
||||||
|
"Example/CartPole-v4", EXAMPLE_ENTRY_POINT
|
||||||
|
)
|
||||||
|
out = gym.pprint_registry(
|
||||||
|
example_registry, exclude_namespaces=["Test", "Example"], disable_print=True
|
||||||
|
)
|
||||||
|
correct_out = """===== classic_control =====
|
||||||
|
CartPole-v2 CartPole-v3"""
|
||||||
|
assert out == correct_out
|
@@ -18,7 +18,7 @@ def register_registration_testing_envs():
|
|||||||
env_id = f"{namespace}/{versioned_name}-v{version}"
|
env_id = f"{namespace}/{versioned_name}-v{version}"
|
||||||
gym.register(
|
gym.register(
|
||||||
id=env_id,
|
id=env_id,
|
||||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
||||||
kwargs={
|
kwargs={
|
||||||
"arg1": "arg1",
|
"arg1": "arg1",
|
||||||
"arg2": "arg2",
|
"arg2": "arg2",
|
||||||
@@ -111,7 +111,7 @@ def test_env_suggestions(
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
|
gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
|
||||||
):
|
):
|
||||||
gym.make(env_id_input, disable_env_checker=True)
|
gym.make(env_id_input)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -136,13 +136,13 @@ def test_env_version_suggestions(
|
|||||||
gym.error.DeprecatedEnv,
|
gym.error.DeprecatedEnv,
|
||||||
match="It provides the default version", # env name,
|
match="It provides the default version", # env name,
|
||||||
):
|
):
|
||||||
gym.make(env_id_input, disable_env_checker=True)
|
gym.make(env_id_input)
|
||||||
else:
|
else:
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
gym.error.UnregisteredEnv,
|
gym.error.UnregisteredEnv,
|
||||||
match=f"It provides versioned environments: \\[ {suggested_versions} \\]",
|
match=f"It provides versioned environments: \\[ {suggested_versions} \\]",
|
||||||
):
|
):
|
||||||
gym.make(env_id_input, disable_env_checker=True)
|
gym.make(env_id_input)
|
||||||
|
|
||||||
|
|
||||||
def test_register_versioned_unversioned():
|
def test_register_versioned_unversioned():
|
||||||
@@ -185,9 +185,7 @@ def test_make_latest_versioned_env(register_registration_testing_envs):
|
|||||||
"Using the latest versioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv-v5` instead of the unversioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv`."
|
"Using the latest versioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv-v5` instead of the unversioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv`."
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
env = gym.make(
|
env = gym.make("MyAwesomeNamespace/MyAwesomeVersionedEnv")
|
||||||
"MyAwesomeNamespace/MyAwesomeVersionedEnv", disable_env_checker=True
|
|
||||||
)
|
|
||||||
assert env.spec is not None
|
assert env.spec is not None
|
||||||
assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"
|
assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"
|
||||||
|
|
95
tests/envs/registration/test_spec.py
Normal file
95
tests/envs/registration/test_spec.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""Tests that `gym.spec` works as expected."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
def test_spec():
|
||||||
|
spec = gym.spec("CartPole-v1")
|
||||||
|
assert spec.id == "CartPole-v1"
|
||||||
|
assert spec is gym.envs.registry["CartPole-v1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_spec_missing_lookup():
|
||||||
|
gym.register(id="TestEnv-v0", entry_point="no-entry-point")
|
||||||
|
gym.register(id="TestEnv-v15", entry_point="no-entry-point")
|
||||||
|
gym.register(id="TestEnv-v9", entry_point="no-entry-point")
|
||||||
|
gym.register(id="OtherEnv-v100", entry_point="no-entry-point")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.DeprecatedEnv,
|
||||||
|
match=re.escape(
|
||||||
|
"Environment version v1 for `TestEnv` is deprecated. Please use `TestEnv-v15` instead."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.spec("TestEnv-v1")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.UnregisteredEnv,
|
||||||
|
match=re.escape(
|
||||||
|
"Environment version `v1000` for environment `TestEnv` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.spec("TestEnv-v1000")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.UnregisteredEnv,
|
||||||
|
match=re.escape("Environment `UnknownEnv` doesn't exist."),
|
||||||
|
):
|
||||||
|
gym.spec("UnknownEnv-v1")
|
||||||
|
|
||||||
|
del gym.registry["TestEnv-v0"]
|
||||||
|
del gym.registry["TestEnv-v15"]
|
||||||
|
del gym.registry["TestEnv-v9"]
|
||||||
|
del gym.registry["OtherEnv-v100"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_spec_malformed_lookup():
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.Error,
|
||||||
|
match=re.escape(
|
||||||
|
"Malformed environment ID: “Breakout-v0”. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.spec("“Breakout-v0”")
|
||||||
|
|
||||||
|
|
||||||
|
def test_spec_versioned_lookups():
|
||||||
|
gym.register("test/TestEnv-v5", "no-entry-point")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.VersionNotFound,
|
||||||
|
match=re.escape(
|
||||||
|
"Environment version `v9` for environment `test/TestEnv` doesn't exist. It provides versioned environments: [ `v5` ]."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.spec("test/TestEnv-v9")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.DeprecatedEnv,
|
||||||
|
match=re.escape(
|
||||||
|
"Environment version v4 for `test/TestEnv` is deprecated. Please use `test/TestEnv-v5` instead."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.spec("test/TestEnv-v4")
|
||||||
|
|
||||||
|
assert gym.spec("test/TestEnv-v5") is not None
|
||||||
|
del gym.registry["test/TestEnv-v5"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_spec_default_lookups():
|
||||||
|
gym.register("test/TestEnv", "no-entry-point")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.DeprecatedEnv,
|
||||||
|
match=re.escape(
|
||||||
|
"Environment version `v0` for environment `test/TestEnv` doesn't exist. It provides the default version `test/TestEnv`."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.spec("test/TestEnv-v0")
|
||||||
|
|
||||||
|
assert gym.spec("test/TestEnv") is not None
|
||||||
|
del gym.registry["test/TestEnv"]
|
@@ -1,14 +1,8 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class RegisterDuringMakeEnv(gym.Env):
|
|
||||||
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.action_space = gym.spaces.Discrete(1)
|
|
||||||
self.observation_space = gym.spaces.Discrete(1)
|
|
||||||
|
|
||||||
|
|
||||||
class ArgumentEnv(gym.Env):
|
class ArgumentEnv(gym.Env):
|
||||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
@@ -23,9 +17,12 @@ class ArgumentEnv(gym.Env):
|
|||||||
class NoHuman(gym.Env):
|
class NoHuman(gym.Env):
|
||||||
"""Environment that does not have human-rendering."""
|
"""Environment that does not have human-rendering."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
|
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
|
||||||
def __init__(self, render_mode=None):
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 4}
|
||||||
|
|
||||||
|
def __init__(self, render_mode: list[str] = None):
|
||||||
assert render_mode in self.metadata["render_modes"]
|
assert render_mode in self.metadata["render_modes"]
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
|
|
||||||
@@ -33,6 +30,9 @@ class NoHuman(gym.Env):
|
|||||||
class NoHumanOldAPI(gym.Env):
|
class NoHumanOldAPI(gym.Env):
|
||||||
"""Environment that does not have human-rendering."""
|
"""Environment that does not have human-rendering."""
|
||||||
|
|
||||||
|
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
|
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -42,6 +42,9 @@ class NoHumanOldAPI(gym.Env):
|
|||||||
class NoHumanNoRGB(gym.Env):
|
class NoHumanNoRGB(gym.Env):
|
||||||
"""Environment that has neither human- nor rgb-rendering"""
|
"""Environment that has neither human- nor rgb-rendering"""
|
||||||
|
|
||||||
|
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
|
||||||
metadata = {"render_modes": ["ascii"], "render_fps": 4}
|
metadata = {"render_modes": ["ascii"], "render_fps": 4}
|
||||||
|
|
||||||
def __init__(self, render_mode=None):
|
def __init__(self, render_mode=None):
|
||||||
@@ -52,6 +55,9 @@ class NoHumanNoRGB(gym.Env):
|
|||||||
class NoRenderModesMetadata(gym.Env):
|
class NoRenderModesMetadata(gym.Env):
|
||||||
"""An environment that has rendering but has not updated the metadata."""
|
"""An environment that has rendering but has not updated the metadata."""
|
||||||
|
|
||||||
|
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||||
|
|
||||||
# metadata: dict[str, Any] = {"render_modes": []}
|
# metadata: dict[str, Any] = {"render_modes": []}
|
||||||
|
|
||||||
def __init__(self, render_mode):
|
def __init__(self, render_mode):
|
13
tests/envs/registration/utils_unregistered_env.py
Normal file
13
tests/envs/registration/utils_unregistered_env.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""This utility file contains an environment that is registered upon loading the file."""
|
||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterDuringMakeEnv(gym.Env):
|
||||||
|
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.action_space = gym.spaces.Discrete(1)
|
||||||
|
self.observation_space = gym.spaces.Discrete(1)
|
||||||
|
|
||||||
|
|
||||||
|
gym.register(id="RegisterDuringMake-v0", entry_point=RegisterDuringMakeEnv)
|
@@ -39,7 +39,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
|
|||||||
all_testing_env_specs,
|
all_testing_env_specs,
|
||||||
ids=[spec.id for spec in all_testing_env_specs],
|
ids=[spec.id for spec in all_testing_env_specs],
|
||||||
)
|
)
|
||||||
def test_envs_pass_env_checker(spec):
|
def test_all_env_api(spec):
|
||||||
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
env = spec.make(disable_env_checker=True).unwrapped
|
env = spec.make(disable_env_checker=True).unwrapped
|
||||||
@@ -52,6 +52,22 @@ def test_envs_pass_env_checker(spec):
|
|||||||
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
|
||||||
|
)
|
||||||
|
def test_all_env_passive_env_checker(spec):
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
env = gym.make(spec.id)
|
||||||
|
env.reset()
|
||||||
|
env.step(env.action_space.sample())
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
for warning in caught_warnings:
|
||||||
|
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
|
||||||
|
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||||
|
|
||||||
|
|
||||||
# Note that this precludes running this test in multiple threads.
|
# Note that this precludes running this test in multiple threads.
|
||||||
# However, we probably already can't do multithreading due to some environments.
|
# However, we probably already can't do multithreading due to some environments.
|
||||||
SEED = 0
|
SEED = 0
|
||||||
|
@@ -1,27 +0,0 @@
|
|||||||
import gymnasium as gym
|
|
||||||
from gymnasium.envs.registration import EnvSpec
|
|
||||||
|
|
||||||
|
|
||||||
# To ignore the trailing whitespaces, will need flake to ignore this file.
|
|
||||||
# flake8: noqa
|
|
||||||
|
|
||||||
reduced_registry = {
|
|
||||||
env_id: env_spec
|
|
||||||
for env_id, env_spec in gym.registry.items()
|
|
||||||
if env_spec.entry_point != "shimmy.atari_env:AtariEnv"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_pprint_custom_registry():
|
|
||||||
"""Testing a registry different from default."""
|
|
||||||
a = {
|
|
||||||
"CartPole-v0": gym.envs.registry["CartPole-v0"],
|
|
||||||
"CartPole-v1": gym.envs.registry["CartPole-v1"],
|
|
||||||
}
|
|
||||||
out = gym.pprint_registry(a, disable_print=True)
|
|
||||||
|
|
||||||
correct_out = """===== classic_control =====
|
|
||||||
CartPole-v0 CartPole-v1
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert out == correct_out
|
|
@@ -1,93 +0,0 @@
|
|||||||
"""Tests that gym.spec works as expected."""
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
|
|
||||||
|
|
||||||
def test_spec():
|
|
||||||
spec = gym.spec("CartPole-v1")
|
|
||||||
assert spec.id == "CartPole-v1"
|
|
||||||
assert spec is gym.envs.registry["CartPole-v1"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_spec_kwargs():
|
|
||||||
map_name_value = "8x8"
|
|
||||||
env = gym.make("FrozenLake-v1", map_name=map_name_value)
|
|
||||||
assert env.spec is not None
|
|
||||||
assert env.spec.kwargs["map_name"] == map_name_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_spec_missing_lookup():
|
|
||||||
gym.register(id="Test1-v0", entry_point="no-entry-point")
|
|
||||||
gym.register(id="Test1-v15", entry_point="no-entry-point")
|
|
||||||
gym.register(id="Test1-v9", entry_point="no-entry-point")
|
|
||||||
gym.register(id="Other1-v100", entry_point="no-entry-point")
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
gym.error.DeprecatedEnv,
|
|
||||||
match=re.escape(
|
|
||||||
"Environment version v1 for `Test1` is deprecated. Please use `Test1-v15` instead."
|
|
||||||
),
|
|
||||||
):
|
|
||||||
gym.spec("Test1-v1")
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
gym.error.UnregisteredEnv,
|
|
||||||
match=re.escape(
|
|
||||||
"Environment version `v1000` for environment `Test1` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
|
|
||||||
),
|
|
||||||
):
|
|
||||||
gym.spec("Test1-v1000")
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
gym.error.UnregisteredEnv,
|
|
||||||
match=re.escape("Environment Unknown1 doesn't exist. "),
|
|
||||||
):
|
|
||||||
gym.spec("Unknown1-v1")
|
|
||||||
|
|
||||||
|
|
||||||
def test_spec_malformed_lookup():
|
|
||||||
with pytest.raises(
|
|
||||||
gym.error.Error,
|
|
||||||
match=f'^{re.escape("Malformed environment ID: “Breakout-v0”.(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))")}$',
|
|
||||||
):
|
|
||||||
gym.spec("“Breakout-v0”")
|
|
||||||
|
|
||||||
|
|
||||||
def test_spec_versioned_lookups():
|
|
||||||
gym.register("test/Test2-v5", "no-entry-point")
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
gym.error.VersionNotFound,
|
|
||||||
match=re.escape(
|
|
||||||
"Environment version `v9` for environment `test/Test2` doesn't exist. It provides versioned environments: [ `v5` ]."
|
|
||||||
),
|
|
||||||
):
|
|
||||||
gym.spec("test/Test2-v9")
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
gym.error.DeprecatedEnv,
|
|
||||||
match=re.escape(
|
|
||||||
"Environment version v4 for `test/Test2` is deprecated. Please use `test/Test2-v5` instead."
|
|
||||||
),
|
|
||||||
):
|
|
||||||
gym.spec("test/Test2-v4")
|
|
||||||
|
|
||||||
assert gym.spec("test/Test2-v5") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_spec_default_lookups():
|
|
||||||
gym.register("test/Test3", "no-entry-point")
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
gym.error.DeprecatedEnv,
|
|
||||||
match=re.escape(
|
|
||||||
"Environment version `v0` for environment `test/Test3` doesn't exist. It provides the default version test/Test3`."
|
|
||||||
),
|
|
||||||
):
|
|
||||||
gym.spec("test/Test3-v0")
|
|
||||||
|
|
||||||
assert gym.spec("test/Test3") is not None
|
|
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user