mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Reorder functions and refactor registration.py
(#289)
Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
@@ -2,36 +2,39 @@
|
||||
title: Registry
|
||||
---
|
||||
|
||||
# Registry
|
||||
# Register and Make
|
||||
|
||||
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers.
|
||||
Environments can also be created through python imports.
|
||||
|
||||
## Make
|
||||
```{eval-rst}
|
||||
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers through the :meth:`gymnasium.make` function. To do this, the environment must be registered prior with :meth:`gymnasium.register`. To get the environment specifications for a registered environment, use :meth:`gymnasium.spec` and to print the whole registry, use :meth:`gymnasium.pprint_registry`.
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.make
|
||||
```
|
||||
|
||||
## Register
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.register
|
||||
```
|
||||
|
||||
## All registered environments
|
||||
|
||||
To find all the registered Gymnasium environments, use the `gymnasium.pprint_registry()`.
|
||||
This will not include environments registered only in OpenAI Gym however can be loaded by `gymnasium.make`.
|
||||
|
||||
## Spec
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.spec
|
||||
```
|
||||
|
||||
## Pretty print registry
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.pprint_registry
|
||||
```
|
||||
|
||||
## Core variables
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.envs.registration.EnvSpec
|
||||
.. attribute:: gymnasium.envs.registration.registry
|
||||
|
||||
The Global registry for gymnasium which is where environment specifications are stored by :meth:`gymnasium.register` and from which :meth:`gymnasium.make` is used to create environments.
|
||||
|
||||
.. attribute:: gymnasium.envs.registration.current_namespace
|
||||
|
||||
The current namespace when creating or registering environments. This is by default ``None`` by with :meth:`namespace` this can be modified to automatically set the environment id namespace.
|
||||
```
|
||||
|
||||
## Additional functions
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.envs.registration.get_env_id
|
||||
.. autofunction:: gymnasium.envs.registration.parse_env_id
|
||||
.. autofunction:: gymnasium.envs.registration.find_highest_version
|
||||
.. autofunction:: gymnasium.envs.registration.namespace
|
||||
.. autofunction:: gymnasium.envs.registration.load_env
|
||||
.. autofunction:: gymnasium.envs.registration.load_plugin_envs
|
||||
```
|
||||
|
@@ -2,7 +2,7 @@
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.envs.registration import (
|
||||
load_env_plugins,
|
||||
load_plugin_envs,
|
||||
make,
|
||||
pprint_registry,
|
||||
register,
|
||||
@@ -363,4 +363,4 @@ register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error)
|
||||
|
||||
|
||||
# Hook to load plugins from entry points
|
||||
load_env_plugins()
|
||||
load_plugin_envs()
|
||||
|
@@ -9,13 +9,11 @@ import importlib.util
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Iterable, Sequence, SupportsFloat, overload
|
||||
|
||||
import numpy as np
|
||||
from typing import Any, Iterable
|
||||
|
||||
from gymnasium import Env, error, logger
|
||||
from gymnasium.wrappers import (
|
||||
AutoResetWrapper,
|
||||
HumanRendering,
|
||||
@@ -32,12 +30,10 @@ if sys.version_info < (3, 10):
|
||||
else:
|
||||
import importlib.metadata as metadata
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Protocol
|
||||
else:
|
||||
from typing_extensions import Literal
|
||||
|
||||
from gymnasium import Env, error, logger
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
ENV_ID_RE = re.compile(
|
||||
@@ -45,50 +41,99 @@ ENV_ID_RE = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def load(name: str) -> Callable:
|
||||
"""Loads an environment with name and returns an environment creation function.
|
||||
__all__ = [
|
||||
"EnvSpec",
|
||||
"registry",
|
||||
"current_namespace",
|
||||
"register",
|
||||
"make",
|
||||
"spec",
|
||||
"pprint_registry",
|
||||
]
|
||||
|
||||
Args:
|
||||
name: The environment name
|
||||
|
||||
Returns:
|
||||
Calls the environment constructor
|
||||
class EnvCreator(Protocol):
|
||||
"""Function type expected for an environment."""
|
||||
|
||||
def __call__(self, **kwargs: Any) -> Env:
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvSpec:
|
||||
"""A specification for creating environments with :meth:`gymnasium.make`.
|
||||
|
||||
* **id**: The string used to create the environment with :meth:`gymnasium.make`
|
||||
* **entry_point**: A string for the environment location, ``(import path):(environment name)`` or a function that creates the environment.
|
||||
* **reward_threshold**: The reward threshold for completing the environment.
|
||||
* **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
|
||||
* **max_episode_steps**: The max number of steps that the environment can take before truncation
|
||||
* **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
|
||||
* **autoreset**: If to automatically reset the environment on episode end
|
||||
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
||||
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
||||
"""
|
||||
mod_name, attr_name = name.split(":")
|
||||
mod = importlib.import_module(mod_name)
|
||||
fn = getattr(mod, attr_name)
|
||||
return fn
|
||||
|
||||
id: str
|
||||
entry_point: EnvCreator | str
|
||||
|
||||
# Environment attributes
|
||||
reward_threshold: float | None = field(default=None)
|
||||
nondeterministic: bool = field(default=False)
|
||||
|
||||
# Wrappers
|
||||
max_episode_steps: int | None = field(default=None)
|
||||
order_enforce: bool = field(default=True)
|
||||
autoreset: bool = field(default=False)
|
||||
disable_env_checker: bool = field(default=False)
|
||||
apply_api_compatibility: bool = field(default=False)
|
||||
|
||||
# post-init attributes
|
||||
namespace: str | None = field(init=False)
|
||||
name: str = field(init=False)
|
||||
version: int | None = field(init=False)
|
||||
|
||||
# Environment arguments
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Calls after the spec is created to extract the namespace, name and version from the id."""
|
||||
# Initialize namespace, name, version
|
||||
self.namespace, self.name, self.version = parse_env_id(self.id)
|
||||
|
||||
def make(self, **kwargs: Any) -> Env:
|
||||
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
||||
# For compatibility purposes
|
||||
return make(self, **kwargs)
|
||||
|
||||
|
||||
def parse_env_id(id: str) -> tuple[str | None, str, int | None]:
|
||||
"""Parse environment ID string format.
|
||||
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||
registry: dict[str, EnvSpec] = {}
|
||||
current_namespace: str | None = None
|
||||
|
||||
This format is true today, but it's *not* an official spec.
|
||||
[namespace/](env-name)-v(version) env-name is group 1, version is group 2
|
||||
|
||||
2016-10-31: We're experimentally expanding the environment ID format
|
||||
to include an optional namespace.
|
||||
def parse_env_id(env_id: str) -> tuple[str | None, str, int | None]:
|
||||
"""Parse environment ID string format - ``[namespace/](env-name)[-v(version)]`` where the namespace and version are optional.
|
||||
|
||||
Args:
|
||||
id: The environment id to parse
|
||||
env_id: The environment id to parse
|
||||
|
||||
Returns:
|
||||
A tuple of environment namespace, environment name and version number
|
||||
|
||||
Raises:
|
||||
Error: If the environment id does not a valid environment regex
|
||||
Error: If the environment id is not valid environment regex
|
||||
"""
|
||||
match = ENV_ID_RE.fullmatch(id)
|
||||
match = ENV_ID_RE.fullmatch(env_id)
|
||||
if not match:
|
||||
raise error.Error(
|
||||
f"Malformed environment ID: {id}."
|
||||
f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
|
||||
f"Malformed environment ID: {env_id}. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
|
||||
)
|
||||
namespace, name, version = match.group("namespace", "name", "version")
|
||||
ns, name, version = match.group("namespace", "name", "version")
|
||||
if version is not None:
|
||||
version = int(version)
|
||||
|
||||
return namespace, name, version
|
||||
return ns, name, version
|
||||
|
||||
|
||||
def get_env_id(ns: str | None, name: str, version: int | None) -> str:
|
||||
@@ -103,97 +148,80 @@ def get_env_id(ns: str | None, name: str, version: int | None) -> str:
|
||||
The environment id
|
||||
"""
|
||||
full_name = name
|
||||
if version is not None:
|
||||
full_name += f"-v{version}"
|
||||
if ns is not None:
|
||||
full_name = ns + "/" + full_name
|
||||
full_name = f"{ns}/{name}"
|
||||
if version is not None:
|
||||
full_name = f"{full_name}-v{version}"
|
||||
|
||||
return full_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvSpec:
|
||||
"""A specification for creating environments with `gym.make`.
|
||||
def find_highest_version(ns: str | None, name: str) -> int | None:
|
||||
"""Finds the highest registered version of the environment given the namespace and name in the registry.
|
||||
|
||||
* id: The string used to create the environment with `gym.make`
|
||||
* entry_point: The location of the environment to create from
|
||||
* reward_threshold: The reward threshold for completing the environment.
|
||||
* nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
|
||||
* max_episode_steps: The max number of steps that the environment can take before truncation
|
||||
* order_enforce: If to enforce the order of `reset` before `step` and `render` functions
|
||||
* autoreset: If to automatically reset the environment on episode end
|
||||
* disable_env_checker: If to disable the environment checker wrapper in `gym.make`, by default False (runs the environment checker)
|
||||
* kwargs: Additional keyword arguments passed to the environments through `gym.make`
|
||||
Args:
|
||||
ns: The environment namespace
|
||||
name: The environment name (id)
|
||||
|
||||
Returns:
|
||||
The highest version of an environment with matching namespace and name, otherwise ``None`` is returned.
|
||||
"""
|
||||
|
||||
id: str
|
||||
entry_point: Callable | str
|
||||
|
||||
# Environment attributes
|
||||
reward_threshold: float | None = field(default=None)
|
||||
nondeterministic: bool = field(default=False)
|
||||
|
||||
# Wrappers
|
||||
max_episode_steps: int | None = field(default=None)
|
||||
order_enforce: bool = field(default=True)
|
||||
autoreset: bool = field(default=False)
|
||||
disable_env_checker: bool = field(default=False)
|
||||
apply_api_compatibility: bool = field(default=False)
|
||||
|
||||
# Environment arguments
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
|
||||
# post-init attributes
|
||||
namespace: str | None = field(init=False)
|
||||
name: str = field(init=False)
|
||||
version: int | None = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Calls after the spec is created to extract the namespace, name and version from the id."""
|
||||
# Initialize namespace, name, version
|
||||
self.namespace, self.name, self.version = parse_env_id(self.id)
|
||||
|
||||
def make(self, **kwargs: Any) -> Env:
|
||||
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
||||
# For compatibility purposes
|
||||
return make(self, **kwargs)
|
||||
version: list[int] = [
|
||||
env_spec.version
|
||||
for env_spec in registry.values()
|
||||
if env_spec.namespace == ns
|
||||
and env_spec.name == name
|
||||
and env_spec.version is not None
|
||||
]
|
||||
return max(version, default=None)
|
||||
|
||||
|
||||
def _check_namespace_exists(ns: str | None):
|
||||
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
|
||||
# If the namespace is none, then the namespace does exist
|
||||
if ns is None:
|
||||
return
|
||||
namespaces = {
|
||||
spec_.namespace for spec_ in registry.values() if spec_.namespace is not None
|
||||
|
||||
# Check if the namespace exists in one of the registry's specs
|
||||
namespaces: set[str] = {
|
||||
env_spec.namespace
|
||||
for env_spec in registry.values()
|
||||
if env_spec.namespace is not None
|
||||
}
|
||||
if ns in namespaces:
|
||||
return
|
||||
|
||||
# Otherwise, the namespace doesn't exist and raise a helpful message
|
||||
suggestion = (
|
||||
difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
|
||||
)
|
||||
suggestion_msg = (
|
||||
f"Did you mean: `{suggestion[0]}`?"
|
||||
if suggestion
|
||||
else f"Have you installed the proper package for {ns}?"
|
||||
)
|
||||
if suggestion:
|
||||
suggestion_msg = f"Did you mean: `{suggestion[0]}`?"
|
||||
else:
|
||||
suggestion_msg = f"Have you installed the proper package for {ns}?"
|
||||
|
||||
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
|
||||
|
||||
|
||||
def _check_name_exists(ns: str | None, name: str):
|
||||
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
|
||||
# First check if the namespace exists
|
||||
_check_namespace_exists(ns)
|
||||
names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}
|
||||
|
||||
# Then check if the name exists
|
||||
names: set[str] = {
|
||||
env_spec.name for env_spec in registry.values() if env_spec.namespace == ns
|
||||
}
|
||||
if name in names:
|
||||
return
|
||||
|
||||
# Otherwise, raise a helpful error to the user
|
||||
suggestion = difflib.get_close_matches(name, names, n=1)
|
||||
namespace_msg = f" in namespace {ns}" if ns else ""
|
||||
suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else ""
|
||||
suggestion_msg = f" Did you mean: `{suggestion[0]}`?" if suggestion else ""
|
||||
|
||||
raise error.NameNotFound(
|
||||
f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
|
||||
f"Environment `{name}` doesn't exist{namespace_msg}.{suggestion_msg}"
|
||||
)
|
||||
|
||||
|
||||
@@ -222,26 +250,28 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
|
||||
message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."
|
||||
|
||||
env_specs = [
|
||||
spec_
|
||||
for spec_ in registry.values()
|
||||
if spec_.namespace == ns and spec_.name == name
|
||||
env_spec
|
||||
for env_spec in registry.values()
|
||||
if env_spec.namespace == ns and env_spec.name == name
|
||||
]
|
||||
env_specs = sorted(env_specs, key=lambda spec_: int(spec_.version or -1))
|
||||
env_specs = sorted(env_specs, key=lambda env_spec: int(env_spec.version or -1))
|
||||
|
||||
default_spec = [spec_ for spec_ in env_specs if spec_.version is None]
|
||||
default_spec = [env_spec for env_spec in env_specs if env_spec.version is None]
|
||||
|
||||
if default_spec:
|
||||
message += f" It provides the default version {default_spec[0].id}`."
|
||||
message += f" It provides the default version `{default_spec[0].id}`."
|
||||
if len(env_specs) == 1:
|
||||
raise error.DeprecatedEnv(message)
|
||||
|
||||
# Process possible versioned environments
|
||||
|
||||
versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None]
|
||||
versioned_specs = [
|
||||
env_spec for env_spec in env_specs if env_spec.version is not None
|
||||
]
|
||||
|
||||
latest_spec = max(versioned_specs, key=lambda spec: spec.version, default=None) # type: ignore
|
||||
latest_spec = max(versioned_specs, key=lambda env_spec: env_spec.version, default=None) # type: ignore
|
||||
if latest_spec is not None and version > latest_spec.version:
|
||||
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs)
|
||||
version_list_msg = ", ".join(f"`v{env_spec.version}`" for env_spec in env_specs)
|
||||
message += f" It provides versioned environments: [ {version_list_msg} ]."
|
||||
|
||||
raise error.VersionNotFound(message)
|
||||
@@ -253,18 +283,80 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
|
||||
)
|
||||
|
||||
|
||||
def find_highest_version(ns: str | None, name: str) -> int | None:
|
||||
"""Finds the highest registered version of the environment in the registry."""
|
||||
version: list[int] = [
|
||||
spec_.version
|
||||
for spec_ in registry.values()
|
||||
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
|
||||
]
|
||||
return max(version, default=None)
|
||||
def _check_spec_register(testing_spec: EnvSpec):
|
||||
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
|
||||
latest_versioned_spec = max(
|
||||
(
|
||||
env_spec
|
||||
for env_spec in registry.values()
|
||||
if env_spec.namespace == testing_spec.namespace
|
||||
and env_spec.name == testing_spec.name
|
||||
and env_spec.version is not None
|
||||
),
|
||||
key=lambda spec_: int(spec_.version), # type: ignore
|
||||
default=None,
|
||||
)
|
||||
|
||||
unversioned_spec = next(
|
||||
(
|
||||
env_spec
|
||||
for env_spec in registry.values()
|
||||
if env_spec.namespace == testing_spec.namespace
|
||||
and env_spec.name == testing_spec.name
|
||||
and env_spec.version is None
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if unversioned_spec is not None and testing_spec.version is not None:
|
||||
raise error.RegistrationError(
|
||||
"Can't register the versioned environment "
|
||||
f"`{testing_spec.id}` when the unversioned environment "
|
||||
f"`{unversioned_spec.id}` of the same name already exists."
|
||||
)
|
||||
elif latest_versioned_spec is not None and testing_spec.version is None:
|
||||
raise error.RegistrationError(
|
||||
f"Can't register the unversioned environment `{testing_spec.id}` when the versioned environment "
|
||||
f"`{latest_versioned_spec.id}` of the same name already exists. Note: the default behavior is "
|
||||
"that `gym.make` with the unversioned environment will return the latest versioned environment"
|
||||
)
|
||||
|
||||
|
||||
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
||||
"""Load modules (plugins) using the gymnasium entry points == to `entry_points`.
|
||||
def _check_metadata(testing_metadata: dict[str, Any]):
|
||||
"""Check the metadata of an environment."""
|
||||
if not isinstance(testing_metadata, dict):
|
||||
raise error.InvalidMetadata(
|
||||
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
|
||||
)
|
||||
|
||||
render_modes = testing_metadata.get("render_modes")
|
||||
if render_modes is None:
|
||||
logger.warn(
|
||||
f"The environment creator metadata doesn't include `render_modes`, contains: {list(testing_metadata.keys())}"
|
||||
)
|
||||
elif not isinstance(render_modes, Iterable):
|
||||
logger.warn(
|
||||
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
|
||||
)
|
||||
|
||||
|
||||
def load_env(name: str) -> EnvCreator:
|
||||
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
|
||||
|
||||
Args:
|
||||
name: The environment name
|
||||
|
||||
Returns:
|
||||
The environment constructor for the given environment name.
|
||||
"""
|
||||
mod_name, attr_name = name.split(":")
|
||||
mod = importlib.import_module(mod_name)
|
||||
fn = getattr(mod, attr_name)
|
||||
return fn
|
||||
|
||||
|
||||
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
||||
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
||||
|
||||
Args:
|
||||
entry_point: The string for the entry point.
|
||||
@@ -282,7 +374,7 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
||||
else:
|
||||
module, attr = plugin.value, None
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
logger.warn(
|
||||
f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
|
||||
)
|
||||
module, attr = None, None
|
||||
@@ -314,136 +406,6 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
|
||||
logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
|
||||
|
||||
|
||||
# fmt: off
|
||||
@overload
|
||||
def make(id: str, **kwargs) -> Env: ...
|
||||
@overload
|
||||
def make(id: EnvSpec, **kwargs) -> Env: ...
|
||||
|
||||
|
||||
# Classic control
|
||||
# ----------------------------------------
|
||||
@overload
|
||||
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
@overload
|
||||
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
@overload
|
||||
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
|
||||
|
||||
# Box2d
|
||||
# ----------------------------------------
|
||||
@overload
|
||||
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
@overload
|
||||
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
|
||||
|
||||
# Toy Text
|
||||
# ----------------------------------------
|
||||
@overload
|
||||
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
|
||||
|
||||
# Mujoco
|
||||
# ----------------------------------------
|
||||
@overload
|
||||
def make(id: Literal[
|
||||
"Reacher-v2", "Reacher-v4",
|
||||
"Pusher-v2", "Pusher-v4",
|
||||
"InvertedPendulum-v2", "InvertedPendulum-v4",
|
||||
"InvertedDoublePendulum-v2", "InvertedDoublePendulum-v4",
|
||||
"HalfCheetah-v2", "HalfCheetah-v3", "HalfCheetah-v4",
|
||||
"Hopper-v2", "Hopper-v3", "Hopper-v4",
|
||||
"Swimmer-v2", "Swimmer-v3", "Swimmer-v4",
|
||||
"Walker2d-v2", "Walker2d-v3", "Walker2d-v4",
|
||||
"Ant-v2", "Ant-v3", "Ant-v4",
|
||||
"HumanoidStandup-v2", "HumanoidStandup-v4",
|
||||
"Humanoid-v2", "Humanoid-v3", "Humanoid-v4",
|
||||
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
|
||||
# fmt: on
|
||||
|
||||
|
||||
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||
registry: dict[str, EnvSpec] = {}
|
||||
current_namespace: str | None = None
|
||||
|
||||
|
||||
def _check_spec_register(spec: EnvSpec):
|
||||
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
|
||||
global registry
|
||||
latest_versioned_spec = max(
|
||||
(
|
||||
spec_
|
||||
for spec_ in registry.values()
|
||||
if spec_.namespace == spec.namespace
|
||||
and spec_.name == spec.name
|
||||
and spec_.version is not None
|
||||
),
|
||||
key=lambda spec_: int(spec_.version), # type: ignore
|
||||
default=None,
|
||||
)
|
||||
|
||||
unversioned_spec = next(
|
||||
(
|
||||
spec_
|
||||
for spec_ in registry.values()
|
||||
if spec_.namespace == spec.namespace
|
||||
and spec_.name == spec.name
|
||||
and spec_.version is None
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if unversioned_spec is not None and spec.version is not None:
|
||||
raise error.RegistrationError(
|
||||
"Can't register the versioned environment "
|
||||
f"`{spec.id}` when the unversioned environment "
|
||||
f"`{unversioned_spec.id}` of the same name already exists."
|
||||
)
|
||||
elif latest_versioned_spec is not None and spec.version is None:
|
||||
raise error.RegistrationError(
|
||||
"Can't register the unversioned environment "
|
||||
f"`{spec.id}` when the versioned environment "
|
||||
f"`{latest_versioned_spec.id}` of the same name "
|
||||
f"already exists. Note: the default behavior is "
|
||||
f"that `gym.make` with the unversioned environment "
|
||||
f"will return the latest versioned environment"
|
||||
)
|
||||
|
||||
|
||||
def _check_metadata(metadata_: dict):
|
||||
if not isinstance(metadata_, dict):
|
||||
raise error.InvalidMetadata(
|
||||
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
|
||||
)
|
||||
|
||||
render_modes = metadata_.get("render_modes")
|
||||
if render_modes is None:
|
||||
logger.warn(
|
||||
f"The environment creator metadata doesn't include `render_modes`, contains: {list(metadata_.keys())}"
|
||||
)
|
||||
elif not isinstance(render_modes, Iterable):
|
||||
logger.warn(
|
||||
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
|
||||
)
|
||||
|
||||
|
||||
# Public API
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def namespace(ns: str):
|
||||
"""Context manager for modifying the current namespace."""
|
||||
@@ -456,7 +418,7 @@ def namespace(ns: str):
|
||||
|
||||
def register(
|
||||
id: str,
|
||||
entry_point: Callable | str,
|
||||
entry_point: EnvCreator | str,
|
||||
reward_threshold: float | None = None,
|
||||
nondeterministic: bool = False,
|
||||
max_episode_steps: int | None = None,
|
||||
@@ -464,26 +426,28 @@ def register(
|
||||
autoreset: bool = False,
|
||||
disable_env_checker: bool = False,
|
||||
apply_api_compatibility: bool = False,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Register an environment with gymnasium.
|
||||
"""Registers an environment in gymnasium with an ``id`` to use with :meth:`gymnasium.make` with the ``entry_point`` being a string or callable for creating the environment.
|
||||
|
||||
The `id` parameter corresponds to the name of the environment, with the syntax as follows:
|
||||
`(namespace)/(env_name)-v(version)` where `namespace` is optional.
|
||||
The ``id`` parameter corresponds to the name of the environment, with the syntax as follows:
|
||||
``[namespace/](env_name)[-v(version)]`` where ``namespace`` and ``-v(version)`` is optional.
|
||||
|
||||
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
|
||||
It takes arbitrary keyword arguments, which are passed to the :class:`EnvSpec` ``kwargs`` parameter.
|
||||
|
||||
Args:
|
||||
id: The environment id
|
||||
entry_point: The entry point for creating the environment
|
||||
reward_threshold: The reward threshold considered to have learnt an environment
|
||||
nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions)
|
||||
max_episode_steps: The maximum number of episodes steps before truncation. Used by the Time Limit wrapper.
|
||||
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order
|
||||
autoreset: If to add the autoreset wrapper such that reset does not need to be called.
|
||||
disable_env_checker: If to disable the environment checker for the environment. Recommended to False.
|
||||
apply_api_compatibility: If to apply the `StepAPICompatibility` wrapper.
|
||||
**kwargs: arbitrary keyword arguments which are passed to the environment constructor
|
||||
reward_threshold: The reward threshold considered for an agent to have learnt the environment
|
||||
nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions, the same state cannot be reached)
|
||||
max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
|
||||
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
|
||||
If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
|
||||
autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called.
|
||||
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
|
||||
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
|
||||
Use if the environment is implemented in the gym v0.21 environment API.
|
||||
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
|
||||
"""
|
||||
global registry, current_namespace
|
||||
ns, name, version = parse_env_id(id)
|
||||
@@ -502,10 +466,10 @@ def register(
|
||||
else:
|
||||
ns_id = ns
|
||||
|
||||
full_id = get_env_id(ns_id, name, version)
|
||||
full_env_id = get_env_id(ns_id, name, version)
|
||||
|
||||
new_spec = EnvSpec(
|
||||
id=full_id,
|
||||
id=full_env_id,
|
||||
entry_point=entry_point,
|
||||
reward_threshold=reward_threshold,
|
||||
nondeterministic=nondeterministic,
|
||||
@@ -517,6 +481,7 @@ def register(
|
||||
**kwargs,
|
||||
)
|
||||
_check_spec_register(new_spec)
|
||||
|
||||
if new_spec.id in registry:
|
||||
logger.warn(f"Overriding environment {new_spec.id} already in registry.")
|
||||
registry[new_spec.id] = new_spec
|
||||
@@ -528,36 +493,36 @@ def make(
|
||||
autoreset: bool = False,
|
||||
apply_api_compatibility: bool | None = None,
|
||||
disable_env_checker: bool | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> Env:
|
||||
"""Create an environment according to the given ID.
|
||||
"""Creates an environment previously registered with :meth:`gymnasium.register` or a :class:`EnvSpec`.
|
||||
|
||||
To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids.
|
||||
To find all available environments use ``gymnasium.envs.registry.keys()`` for all valid ids.
|
||||
|
||||
Args:
|
||||
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
|
||||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
||||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
||||
apply_api_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that
|
||||
id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
|
||||
This is equivalent to importing the module first to register the environment followed by making the environment.
|
||||
max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``.
|
||||
The value is used by :class:`gymnasium.wrappers.TimeLimit`.
|
||||
autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`).
|
||||
apply_api_compatibility: Whether to wrap the environment with the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper that
|
||||
converts the environment step from a done bool to return termination and truncation bools.
|
||||
By default, the argument is None to which the environment specification `apply_api_compatibility` is used
|
||||
which defaults to False. Otherwise, the value of `apply_api_compatibility` is used.
|
||||
If `True`, the wrapper is applied otherwise, the wrapper is not applied.
|
||||
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
|
||||
(which is by default False, running the environment checker),
|
||||
otherwise will run according to this parameter (`True` = not run, `False` = run)
|
||||
By default, the argument is None in which the :class:`EnvSpec` ``apply_api_compatibility`` is used, otherwise this variable is used in favor.
|
||||
disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the
|
||||
:class:`EnvSpec` ``disable_env_checker`` value otherwise use this value will be used.
|
||||
kwargs: Additional arguments to pass to the environment constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the environment.
|
||||
An instance of the environment with wrappers applied.
|
||||
|
||||
Raises:
|
||||
Error: If the ``id`` doesn't exist then an error is raised
|
||||
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
||||
"""
|
||||
if isinstance(id, EnvSpec):
|
||||
spec_ = id
|
||||
env_spec = id
|
||||
else:
|
||||
module, id = (None, id) if ":" not in id else id.split(":")
|
||||
# The environment name can include an unloaded module in "module:env_name" style
|
||||
module, env_name = (None, id) if ":" not in id else id.split(":")
|
||||
if module is not None:
|
||||
try:
|
||||
importlib.import_module(module)
|
||||
@@ -566,9 +531,13 @@ def make(
|
||||
f"{e}. Environment registration via importing a module failed. "
|
||||
f"Check whether '{module}' contains env registration and can be imported."
|
||||
) from e
|
||||
spec_ = registry.get(id)
|
||||
|
||||
ns, name, version = parse_env_id(id)
|
||||
# load the env spec from the registry
|
||||
env_spec = registry.get(env_name)
|
||||
|
||||
# update env spec is not version provided, raise warning if out of date
|
||||
ns, name, version = parse_env_id(env_name)
|
||||
|
||||
latest_version = find_highest_version(ns, name)
|
||||
if (
|
||||
version is not None
|
||||
@@ -576,38 +545,44 @@ def make(
|
||||
and latest_version > version
|
||||
):
|
||||
logger.warn(
|
||||
f"The environment {id} is out of date. You should consider "
|
||||
f"The environment {env_name} is out of date. You should consider "
|
||||
f"upgrading to version `v{latest_version}`."
|
||||
)
|
||||
if version is None and latest_version is not None:
|
||||
version = latest_version
|
||||
new_env_id = get_env_id(ns, name, version)
|
||||
spec_ = registry.get(new_env_id)
|
||||
env_spec = registry.get(new_env_id)
|
||||
logger.warn(
|
||||
f"Using the latest versioned environment `{new_env_id}` "
|
||||
f"instead of the unversioned environment `{id}`."
|
||||
f"instead of the unversioned environment `{env_name}`."
|
||||
)
|
||||
|
||||
if spec_ is None:
|
||||
if env_spec is None:
|
||||
_check_version_exists(ns, name, version)
|
||||
raise error.Error(f"No registered env with id: {id}")
|
||||
raise error.Error(f"No registered env with id: {env_name}")
|
||||
|
||||
_kwargs = spec_.kwargs.copy()
|
||||
_kwargs.update(kwargs)
|
||||
assert isinstance(
|
||||
env_spec, EnvSpec
|
||||
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}"
|
||||
# Extract the spec kwargs and append the make kwargs
|
||||
spec_kwargs = env_spec.kwargs.copy()
|
||||
spec_kwargs.update(kwargs)
|
||||
|
||||
if spec_.entry_point is None:
|
||||
raise error.Error(f"{spec_.id} registered but entry_point is not specified")
|
||||
elif callable(spec_.entry_point):
|
||||
env_creator = spec_.entry_point
|
||||
# Load the environment creator
|
||||
if env_spec.entry_point is None:
|
||||
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
||||
elif callable(env_spec.entry_point):
|
||||
env_creator = env_spec.entry_point
|
||||
else:
|
||||
# Assume it's a string
|
||||
env_creator = load(spec_.entry_point)
|
||||
env_creator = load_env(env_spec.entry_point)
|
||||
|
||||
render_modes = None
|
||||
# Determine if to use the rendering
|
||||
render_modes: list[str] | None = None
|
||||
if hasattr(env_creator, "metadata"):
|
||||
_check_metadata(env_creator.metadata)
|
||||
render_modes = env_creator.metadata.get("render_modes")
|
||||
mode = _kwargs.get("render_mode")
|
||||
mode = spec_kwargs.get("render_mode")
|
||||
apply_human_rendering = False
|
||||
apply_render_collection = False
|
||||
|
||||
@@ -619,10 +594,10 @@ def make(
|
||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
||||
"The HumanRendering wrapper is being applied to your environment."
|
||||
)
|
||||
_kwargs["render_mode"] = displayable_modes.pop()
|
||||
spec_kwargs["render_mode"] = displayable_modes.pop()
|
||||
apply_human_rendering = True
|
||||
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
||||
_kwargs["render_mode"] = mode[: -len("_list")]
|
||||
spec_kwargs["render_mode"] = mode[: -len("_list")]
|
||||
apply_render_collection = True
|
||||
else:
|
||||
logger.warn(
|
||||
@@ -630,16 +605,16 @@ def make(
|
||||
f"that is not in the possible render_modes ({render_modes})."
|
||||
)
|
||||
|
||||
if apply_api_compatibility is True or (
|
||||
apply_api_compatibility is None and spec_.apply_api_compatibility is True
|
||||
if apply_api_compatibility or (
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
||||
):
|
||||
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
||||
render_mode = _kwargs.pop("render_mode", None)
|
||||
render_mode = spec_kwargs.pop("render_mode", None)
|
||||
else:
|
||||
render_mode = None
|
||||
|
||||
try:
|
||||
env = env_creator(**_kwargs)
|
||||
env = env_creator(**spec_kwargs)
|
||||
except TypeError as e:
|
||||
if (
|
||||
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
||||
@@ -654,31 +629,31 @@ def make(
|
||||
raise e
|
||||
|
||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||
spec_ = copy.deepcopy(spec_)
|
||||
spec_.kwargs = _kwargs
|
||||
env.unwrapped.spec = spec_
|
||||
env_spec = copy.deepcopy(env_spec)
|
||||
env_spec.kwargs = spec_kwargs
|
||||
env.unwrapped.spec = env_spec
|
||||
|
||||
# Add step API wrapper
|
||||
if apply_api_compatibility is True or (
|
||||
apply_api_compatibility is None and spec_.apply_api_compatibility is True
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
||||
):
|
||||
env = EnvCompatibility(env, render_mode)
|
||||
|
||||
# Run the environment checker as the lowest level wrapper
|
||||
if disable_env_checker is False or (
|
||||
disable_env_checker is None and spec_.disable_env_checker is False
|
||||
disable_env_checker is None and env_spec.disable_env_checker is False
|
||||
):
|
||||
env = PassiveEnvChecker(env)
|
||||
|
||||
# Add the order enforcing wrapper
|
||||
if spec_.order_enforce:
|
||||
if env_spec.order_enforce:
|
||||
env = OrderEnforcing(env)
|
||||
|
||||
# Add the time limit wrapper
|
||||
if max_episode_steps is not None:
|
||||
env = TimeLimit(env, max_episode_steps)
|
||||
elif spec_.max_episode_steps is not None:
|
||||
env = TimeLimit(env, spec_.max_episode_steps)
|
||||
elif env_spec.max_episode_steps is not None:
|
||||
env = TimeLimit(env, env_spec.max_episode_steps)
|
||||
|
||||
# Add the autoreset wrapper
|
||||
if autoreset:
|
||||
@@ -694,74 +669,101 @@ def make(
|
||||
|
||||
|
||||
def spec(env_id: str) -> EnvSpec:
|
||||
"""Retrieve the spec for the given environment from the global registry."""
|
||||
spec_ = registry.get(env_id)
|
||||
if spec_ is None:
|
||||
"""Retrieve the :class:`EnvSpec` for the environment id from the :attr:`registry`.
|
||||
|
||||
Args:
|
||||
env_id: The environment id with the expected format of ``[(namespace)/]id[-v(version)]``
|
||||
|
||||
Returns:
|
||||
The environment spec if it exists
|
||||
|
||||
Raises:
|
||||
Error: If the environment id doesn't exist
|
||||
"""
|
||||
env_spec = registry.get(env_id)
|
||||
if env_spec is None:
|
||||
ns, name, version = parse_env_id(env_id)
|
||||
_check_version_exists(ns, name, version)
|
||||
raise error.Error(f"No registered env with id: {env_id}")
|
||||
else:
|
||||
assert isinstance(spec_, EnvSpec)
|
||||
return spec_
|
||||
assert isinstance(
|
||||
env_spec, EnvSpec
|
||||
), f"Expected the registry for {env_id} to be an `EnvSpec`, actual type is {type(env_spec)}"
|
||||
return env_spec
|
||||
|
||||
|
||||
def pprint_registry(
|
||||
_registry: dict = registry,
|
||||
print_registry: dict[str, EnvSpec] = registry,
|
||||
*,
|
||||
num_cols: int = 3,
|
||||
exclude_namespaces: list[str] | None = None,
|
||||
disable_print: bool = False,
|
||||
) -> str | None:
|
||||
"""Pretty print the environments in the registry.
|
||||
"""Pretty prints all environments in the :attr:`registry`.
|
||||
|
||||
Note:
|
||||
All arguments are keyword only
|
||||
|
||||
Args:
|
||||
_registry: Environment registry to be printed.
|
||||
print_registry: Environment registry to be printed. By default, :attr:`registry`
|
||||
num_cols: Number of columns to arrange environments in, for display.
|
||||
exclude_namespaces: Exclude any namespaces from being printed.
|
||||
exclude_namespaces: A list of namespaces to be excluded from printing. Helpful if only ALE environments are wanted.
|
||||
disable_print: Whether to return a string of all the namespaces and environment IDs
|
||||
instead of printing it to console.
|
||||
or to print the string to console.
|
||||
"""
|
||||
# Defaultdict to store environment names according to namespace.
|
||||
namespace_envs = defaultdict(lambda: [])
|
||||
# Defaultdict to store environment ids according to namespace.
|
||||
namespace_envs: dict[str, list[str]] = defaultdict(lambda: [])
|
||||
max_justify = float("-inf")
|
||||
for env in _registry.values():
|
||||
namespace, _, _ = parse_env_id(env.id)
|
||||
if namespace is None:
|
||||
# Since namespace is currently none, use regex to obtain namespace from entrypoints.
|
||||
env_entry_point = re.sub(r":\w+", "", env.entry_point)
|
||||
e_ep_split = env_entry_point.split(".")
|
||||
if len(e_ep_split) >= 3:
|
||||
# If namespace is of the format - gymnasium.envs.mujoco.ant_v4:AntEnv
|
||||
# or gymnasium.envs.mujoco:HumanoidEnv
|
||||
idx = 2
|
||||
namespace = e_ep_split[idx]
|
||||
elif len(e_ep_split) > 1:
|
||||
# If namespace is of the format - shimmy.atari_env
|
||||
idx = 1
|
||||
namespace = e_ep_split[idx]
|
||||
else:
|
||||
# If namespace cannot be found, default to env id.
|
||||
namespace = env.id
|
||||
namespace_envs[namespace].append(env.id)
|
||||
max_justify = max(max_justify, len(env.id))
|
||||
|
||||
# Iterate through each namespace and print environment alphabetically.
|
||||
return_str = ""
|
||||
for namespace, envs in namespace_envs.items():
|
||||
# Find the namespace associated with each environment spec
|
||||
for env_spec in print_registry.values():
|
||||
ns = env_spec.namespace
|
||||
|
||||
if ns is None and isinstance(env_spec.entry_point, str):
|
||||
# Use regex to obtain namespace from entrypoints.
|
||||
env_entry_point = re.sub(r":\w+", "", env_spec.entry_point)
|
||||
split_entry_point = env_entry_point.split(".")
|
||||
|
||||
if len(split_entry_point) >= 3:
|
||||
# If namespace is of the format:
|
||||
# - gymnasium.envs.mujoco.ant_v4:AntEnv
|
||||
# - gymnasium.envs.mujoco:HumanoidEnv
|
||||
ns = split_entry_point[2]
|
||||
elif len(split_entry_point) > 1:
|
||||
# If namespace is of the format - shimmy.atari_env
|
||||
ns = split_entry_point[1]
|
||||
else:
|
||||
# If namespace cannot be found, default to env name
|
||||
ns = env_spec.name
|
||||
|
||||
namespace_envs[ns].append(env_spec.id)
|
||||
max_justify = max(max_justify, len(env_spec.name))
|
||||
|
||||
# Iterate through each namespace and print environment alphabetically
|
||||
output: list[str] = []
|
||||
for ns, env_ids in namespace_envs.items():
|
||||
# Ignore namespaces to exclude.
|
||||
if exclude_namespaces is not None and namespace in exclude_namespaces:
|
||||
if exclude_namespaces is not None and ns in exclude_namespaces:
|
||||
continue
|
||||
return_str += f"{'=' * 5} {namespace} {'=' * 5}\n" # Print namespace.
|
||||
|
||||
# Print the namespace
|
||||
namespace_output = f"{'=' * 5} {ns} {'=' * 5}\n"
|
||||
|
||||
# Reference: https://stackoverflow.com/a/33464001
|
||||
for count, item in enumerate(sorted(envs), 1):
|
||||
return_str += (
|
||||
item.ljust(max_justify) + " "
|
||||
) # Print column with justification.
|
||||
for count, env_id in enumerate(sorted(env_ids), 1):
|
||||
# Print column with justification.
|
||||
namespace_output += env_id.ljust(max_justify) + " "
|
||||
|
||||
# Once all rows printed, switch to new column.
|
||||
if count % num_cols == 0 or count == len(envs):
|
||||
return_str = return_str.rstrip(" ") + "\n"
|
||||
return_str += "\n"
|
||||
if count % num_cols == 0:
|
||||
namespace_output = namespace_output.rstrip(" ")
|
||||
|
||||
if count != len(env_ids):
|
||||
namespace_output += "\n"
|
||||
|
||||
output.append(namespace_output.rstrip(" "))
|
||||
|
||||
if disable_print:
|
||||
return return_str
|
||||
return "\n".join(output)
|
||||
else:
|
||||
print(return_str, end="")
|
||||
print("\n".join(output))
|
||||
|
0
tests/envs/registration/__init__.py
Normal file
0
tests/envs/registration/__init__.py
Normal file
@@ -1,14 +1,14 @@
|
||||
"""Tests that gym.make works as expected."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.classic_control import cartpole
|
||||
from gymnasium.envs.classic_control import CartPoleEnv
|
||||
from gymnasium.wrappers import (
|
||||
AutoResetWrapper,
|
||||
HumanRendering,
|
||||
@@ -16,9 +16,8 @@ from gymnasium.wrappers import (
|
||||
TimeLimit,
|
||||
)
|
||||
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
||||
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
|
||||
from tests.envs.registration.utils_envs import ArgumentEnv
|
||||
from tests.envs.utils import all_testing_env_specs
|
||||
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
||||
from tests.testing_env import GenericTestEnv, old_step_func
|
||||
from tests.wrappers.utils import has_wrapper
|
||||
|
||||
@@ -30,16 +29,11 @@ except ImportError:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def register_make_testing_envs():
|
||||
def register_testing_envs():
|
||||
"""Registers testing envs for `gym.make`"""
|
||||
gym.register(
|
||||
"RegisterDuringMakeEnv-v0",
|
||||
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
|
||||
)
|
||||
|
||||
gym.register(
|
||||
id="test.ArgumentEnv-v0",
|
||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
||||
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
||||
kwargs={
|
||||
"arg1": "arg1",
|
||||
"arg2": "arg2",
|
||||
@@ -48,26 +42,25 @@ def register_make_testing_envs():
|
||||
|
||||
gym.register(
|
||||
id="test/NoHuman-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHuman",
|
||||
entry_point="tests.envs.registration.utils_envs:NoHuman",
|
||||
)
|
||||
gym.register(
|
||||
id="test/NoHumanOldAPI-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHumanOldAPI",
|
||||
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
|
||||
)
|
||||
|
||||
gym.register(
|
||||
id="test/NoHumanNoRGB-v0",
|
||||
entry_point="tests.envs.utils_envs:NoHumanNoRGB",
|
||||
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
|
||||
)
|
||||
|
||||
gym.register(
|
||||
id="test/NoRenderModesMetadata-v0",
|
||||
entry_point="tests.envs.utils_envs:NoRenderModesMetadata",
|
||||
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
del gym.envs.registration.registry["RegisterDuringMakeEnv-v0"]
|
||||
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
|
||||
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
|
||||
del gym.envs.registration.registry["test/NoHuman-v0"]
|
||||
@@ -76,14 +69,16 @@ def register_make_testing_envs():
|
||||
|
||||
|
||||
def test_make():
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
"""Test basic `gym.make`."""
|
||||
env = gym.make("CartPole-v1")
|
||||
assert env.spec is not None
|
||||
assert env.spec.id == "CartPole-v1"
|
||||
assert isinstance(env.unwrapped, cartpole.CartPoleEnv)
|
||||
assert isinstance(env.unwrapped, CartPoleEnv)
|
||||
env.close()
|
||||
|
||||
|
||||
def test_make_deprecated():
|
||||
"""Test make with a deprecated environment (i.e., doesn't exist)."""
|
||||
with warnings.catch_warnings(record=True):
|
||||
with pytest.raises(
|
||||
gym.error.Error,
|
||||
@@ -91,21 +86,20 @@ def test_make_deprecated():
|
||||
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
|
||||
),
|
||||
):
|
||||
gym.make("Humanoid-v0", disable_env_checker=True)
|
||||
gym.make("Humanoid-v0")
|
||||
|
||||
|
||||
def test_make_max_episode_steps(register_make_testing_envs):
|
||||
def test_make_max_episode_steps(register_testing_envs):
|
||||
# Default, uses the spec's
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
env = gym.make("CartPole-v1")
|
||||
assert has_wrapper(env, TimeLimit)
|
||||
assert env.spec is not None
|
||||
assert (
|
||||
env.spec.max_episode_steps == gym.envs.registry["CartPole-v1"].max_episode_steps
|
||||
)
|
||||
assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps
|
||||
env.close()
|
||||
|
||||
# Custom max episode steps
|
||||
env = gym.make("CartPole-v1", max_episode_steps=100, disable_env_checker=True)
|
||||
assert gym.spec("CartPole-v1").max_episode_steps != 100
|
||||
env = gym.make("CartPole-v1", max_episode_steps=100)
|
||||
assert has_wrapper(env, TimeLimit)
|
||||
assert env.spec is not None
|
||||
assert env.spec.max_episode_steps == 100
|
||||
@@ -113,20 +107,20 @@ def test_make_max_episode_steps(register_make_testing_envs):
|
||||
|
||||
# Env spec has no max episode steps
|
||||
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
|
||||
env = gym.make(
|
||||
"test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None, disable_env_checker=True
|
||||
)
|
||||
env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None)
|
||||
assert env.spec is not None
|
||||
assert env.spec.max_episode_steps is None
|
||||
assert has_wrapper(env, TimeLimit) is False
|
||||
env.close()
|
||||
|
||||
|
||||
def test_gym_make_autoreset():
|
||||
def test_make_autoreset():
|
||||
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
env = gym.make("CartPole-v1")
|
||||
assert has_wrapper(env, AutoResetWrapper) is False
|
||||
env.close()
|
||||
|
||||
env = gym.make("CartPole-v1", autoreset=False, disable_env_checker=True)
|
||||
env = gym.make("CartPole-v1", autoreset=False)
|
||||
assert has_wrapper(env, AutoResetWrapper) is False
|
||||
env.close()
|
||||
|
||||
@@ -135,43 +129,49 @@ def test_gym_make_autoreset():
|
||||
env.close()
|
||||
|
||||
|
||||
def test_make_disable_env_checker():
|
||||
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`."""
|
||||
spec = deepcopy(gym.spec("CartPole-v1"))
|
||||
@pytest.mark.parametrize(
|
||||
"registration_disabled, make_disabled, if_disabled",
|
||||
[
|
||||
[False, False, False],
|
||||
[False, True, True],
|
||||
[True, False, False],
|
||||
[True, True, True],
|
||||
[False, None, False],
|
||||
[True, None, True],
|
||||
],
|
||||
)
|
||||
def test_make_disable_env_checker(
|
||||
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
|
||||
):
|
||||
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
|
||||
|
||||
# Test with spec disable env checker
|
||||
spec.disable_env_checker = False
|
||||
env = gym.make(spec)
|
||||
assert has_wrapper(env, PassiveEnvChecker)
|
||||
The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
|
||||
"""
|
||||
gym.register(
|
||||
"testing-env-v0",
|
||||
lambda: GenericTestEnv(),
|
||||
disable_env_checker=registration_disabled,
|
||||
)
|
||||
|
||||
# Test when the registered EnvSpec.disable_env_checker = False
|
||||
env = gym.make("testing-env-v0", disable_env_checker=make_disabled)
|
||||
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
|
||||
env.close()
|
||||
|
||||
# Test with overwritten spec using make disable env checker
|
||||
assert spec.disable_env_checker is False
|
||||
env = gym.make(spec, disable_env_checker=True)
|
||||
assert has_wrapper(env, PassiveEnvChecker) is False
|
||||
env.close()
|
||||
|
||||
# Test with spec enabled disable env checker
|
||||
spec.disable_env_checker = True
|
||||
env = gym.make(spec)
|
||||
assert has_wrapper(env, PassiveEnvChecker) is False
|
||||
env.close()
|
||||
|
||||
# Test with overwritten spec using make disable env checker
|
||||
assert spec.disable_env_checker is True
|
||||
env = gym.make(spec, disable_env_checker=False)
|
||||
assert has_wrapper(env, PassiveEnvChecker)
|
||||
env.close()
|
||||
del gym.registry["testing-env-v0"]
|
||||
|
||||
|
||||
def test_apply_api_compatibility():
|
||||
def test_make_apply_api_compatibility():
|
||||
"""Test the API compatibility wrapper."""
|
||||
gym.register(
|
||||
"testing-old-env",
|
||||
lambda: GenericTestEnv(step_func=old_step_func),
|
||||
apply_api_compatibility=True,
|
||||
max_episode_steps=3,
|
||||
)
|
||||
# Apply the environment compatibility and check it works as intended
|
||||
env = gym.make("testing-old-env")
|
||||
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
||||
|
||||
env.reset()
|
||||
assert len(env.step(env.action_space.sample())) == 5
|
||||
@@ -179,11 +179,19 @@ def test_apply_api_compatibility():
|
||||
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
||||
assert termination is False and truncation is True
|
||||
|
||||
# Turn off the spec api compatibility
|
||||
gym.spec("testing-old-env").apply_api_compatibility = False
|
||||
env = gym.make("testing-old-env")
|
||||
# Cannot run reset and step as will not work
|
||||
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
|
||||
env.reset()
|
||||
with pytest.raises(
|
||||
ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)")
|
||||
):
|
||||
env.step(env.action_space.sample())
|
||||
|
||||
# Apply the environment compatibility and check it works as intended
|
||||
env = gym.make("testing-old-env", apply_api_compatibility=True)
|
||||
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
||||
|
||||
env.reset()
|
||||
assert len(env.step(env.action_space.sample())) == 5
|
||||
@@ -191,57 +199,63 @@ def test_apply_api_compatibility():
|
||||
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
||||
assert termination is False and truncation is True
|
||||
|
||||
gym.envs.registry.pop("testing-old-env")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
|
||||
)
|
||||
def test_passive_checker_wrapper_warnings(spec):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env = gym.make(spec) # disable_env_checker=False
|
||||
env.reset()
|
||||
env.step(env.action_space.sample())
|
||||
# todo, add check for render, bugged due to mujoco v2/3 and v4 envs
|
||||
|
||||
env.close()
|
||||
|
||||
for warning in caught_warnings:
|
||||
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
|
||||
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||
del gym.registry["testing-old-env"]
|
||||
|
||||
|
||||
def test_make_order_enforcing():
|
||||
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
|
||||
assert all(spec.order_enforce is True for spec in all_testing_env_specs)
|
||||
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
env = gym.make("CartPole-v1")
|
||||
assert has_wrapper(env, OrderEnforcing)
|
||||
# We can assume that there all other specs will also have the order enforcing
|
||||
env.close()
|
||||
|
||||
gym.register(
|
||||
id="test.OrderlessArgumentEnv-v0",
|
||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
||||
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
||||
order_enforce=False,
|
||||
kwargs={"arg1": None, "arg2": None, "arg3": None},
|
||||
)
|
||||
|
||||
env = gym.make("test.OrderlessArgumentEnv-v0", disable_env_checker=True)
|
||||
env = gym.make("test.OrderlessArgumentEnv-v0")
|
||||
assert has_wrapper(env, OrderEnforcing) is False
|
||||
env.close()
|
||||
|
||||
# There is no `make(..., order_enforcing=...)` so we don't test that
|
||||
|
||||
def test_make_render_mode(register_make_testing_envs):
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
|
||||
def test_make_render_mode():
|
||||
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
|
||||
env = gym.make("CartPole-v1", render_mode=None)
|
||||
assert env.render_mode is None
|
||||
env.close()
|
||||
|
||||
assert "rgb_array" in env.metadata["render_modes"]
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||
assert env.render_mode == "rgb_array"
|
||||
env.close()
|
||||
|
||||
assert "no-render-mode" not in env.metadata["render_modes"]
|
||||
# cartpole is special that it doesn't check the render_mode passed at initialisation
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match=re.escape(
|
||||
"\x1b[33mWARN: The environment is being initialised with render_mode='no-render-mode' that is not in the possible render_modes (['human', 'rgb_array']).\x1b[0m"
|
||||
),
|
||||
):
|
||||
env = gym.make("CartPole-v1", render_mode="no-render-mode")
|
||||
assert env.render_mode == "no-render-mode"
|
||||
env.close()
|
||||
|
||||
|
||||
def test_make_render_collection():
|
||||
# Make sure that render_mode is applied correctly
|
||||
env = gym.make(
|
||||
"CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
|
||||
)
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
|
||||
assert has_wrapper(env, gym.wrappers.RenderCollection)
|
||||
assert env.render_mode == "rgb_array_list"
|
||||
assert env.unwrapped.render_mode == "rgb_array"
|
||||
|
||||
env.reset()
|
||||
renders = env.render()
|
||||
assert isinstance(
|
||||
@@ -250,24 +264,10 @@ def test_make_render_mode(register_make_testing_envs):
|
||||
assert isinstance(renders[0], np.ndarray)
|
||||
env.close()
|
||||
|
||||
env = gym.make("CartPole-v1", render_mode=None, disable_env_checker=True)
|
||||
assert env.render_mode is None
|
||||
valid_render_modes = env.metadata["render_modes"]
|
||||
env.close()
|
||||
|
||||
assert len(valid_render_modes) > 0
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env = gym.make(
|
||||
"CartPole-v1", render_mode=valid_render_modes[0], disable_env_checker=True
|
||||
)
|
||||
assert env.render_mode == valid_render_modes[0]
|
||||
env.close()
|
||||
|
||||
for warning in caught_warnings:
|
||||
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||
|
||||
def test_make_human_rendering(register_testing_envs):
|
||||
# Make sure that native rendering is used when possible
|
||||
env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True)
|
||||
env = gym.make("CartPole-v1", render_mode="human")
|
||||
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
|
||||
assert env.render_mode == "human"
|
||||
env.close()
|
||||
@@ -278,10 +278,8 @@ def test_make_render_mode(register_make_testing_envs):
|
||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. The HumanRendering wrapper is being applied to your environment."
|
||||
),
|
||||
):
|
||||
# Make sure that `HumanRendering` is applied here
|
||||
env = gym.make(
|
||||
"test/NoHuman-v0", render_mode="human", disable_env_checker=True
|
||||
) # This environment doesn't use native rendering
|
||||
# Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
|
||||
env = gym.make("test/NoHuman-v0", render_mode="human")
|
||||
assert has_wrapper(env, HumanRendering)
|
||||
assert env.render_mode == "human"
|
||||
env.close()
|
||||
@@ -292,7 +290,6 @@ def test_make_render_mode(register_make_testing_envs):
|
||||
gym.make(
|
||||
"test/NoHumanOldAPI-v0",
|
||||
render_mode="rgb_array_list",
|
||||
disable_env_checker=True,
|
||||
)
|
||||
|
||||
# Make sure that an additional error is thrown a user tries to use the wrapper on an environment with old API
|
||||
@@ -303,9 +300,7 @@ def test_make_render_mode(register_make_testing_envs):
|
||||
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively."
|
||||
),
|
||||
):
|
||||
gym.make(
|
||||
"test/NoHumanOldAPI-v0", render_mode="human", disable_env_checker=True
|
||||
)
|
||||
gym.make("test/NoHumanOldAPI-v0", render_mode="human")
|
||||
|
||||
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
|
||||
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
|
||||
@@ -326,15 +321,20 @@ def test_make_render_mode(register_make_testing_envs):
|
||||
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array")
|
||||
|
||||
|
||||
def test_make_kwargs(register_make_testing_envs):
|
||||
def test_make_kwargs(register_testing_envs):
|
||||
env = gym.make(
|
||||
"test.ArgumentEnv-v0",
|
||||
arg2="override_arg2",
|
||||
arg3="override_arg3",
|
||||
disable_env_checker=True,
|
||||
)
|
||||
assert env.spec is not None
|
||||
assert env.spec.id == "test.ArgumentEnv-v0"
|
||||
assert env.spec.kwargs == {
|
||||
"arg1": "arg1",
|
||||
"arg2": "override_arg2",
|
||||
"arg3": "override_arg3",
|
||||
}
|
||||
|
||||
assert isinstance(env.unwrapped, ArgumentEnv)
|
||||
assert env.arg1 == "arg1"
|
||||
assert env.arg2 == "override_arg2"
|
||||
@@ -342,11 +342,16 @@ def test_make_kwargs(register_make_testing_envs):
|
||||
env.close()
|
||||
|
||||
|
||||
def test_import_module_during_make(register_make_testing_envs):
|
||||
def test_import_module_during_make():
|
||||
# Test custom environment which is registered at make
|
||||
assert "RegisterDuringMake-v0" not in gym.registry
|
||||
env = gym.make(
|
||||
"tests.envs.utils:RegisterDuringMakeEnv-v0",
|
||||
disable_env_checker=True,
|
||||
"tests.envs.registration.utils_unregistered_env:RegisterDuringMake-v0"
|
||||
)
|
||||
assert "RegisterDuringMake-v0" in gym.registry
|
||||
from tests.envs.registration.utils_unregistered_env import RegisterDuringMakeEnv
|
||||
|
||||
assert isinstance(env.unwrapped, RegisterDuringMakeEnv)
|
||||
env.close()
|
||||
|
||||
del gym.registry["RegisterDuringMake-v0"]
|
112
tests/envs/registration/test_pprint_registry.py
Normal file
112
tests/envs/registration/test_pprint_registry.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.classic_control import CartPoleEnv
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
|
||||
# To ignore the trailing whitespaces, will need flake to ignore this file.
|
||||
# flake8: noqa
|
||||
|
||||
EXAMPLE_ENTRY_POINT = "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
|
||||
|
||||
|
||||
def test_pprint_default_registry():
|
||||
out = gym.pprint_registry(disable_print=True)
|
||||
assert isinstance(out, str) and len(out) > 0
|
||||
|
||||
|
||||
def test_pprint_example_registry():
|
||||
"""Testing a registry different from default."""
|
||||
example_registry: dict[str, EnvSpec] = {
|
||||
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
|
||||
}
|
||||
|
||||
out = gym.pprint_registry(example_registry, disable_print=True)
|
||||
correct_out = """===== classic_control =====
|
||||
CartPole-v0 CartPole-v1 CartPole-v2
|
||||
CartPole-v3"""
|
||||
assert out == correct_out
|
||||
|
||||
|
||||
def test_pprint_namespace():
|
||||
example_registry: dict[str, EnvSpec] = {
|
||||
"CartPole-v0": EnvSpec(
|
||||
"CartPole-v0", "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
|
||||
),
|
||||
"CartPole-v1": EnvSpec(
|
||||
"CartPole-v1", "gymnasium.envs.classic_control:CartPoleEnv"
|
||||
),
|
||||
"CartPole-v2": EnvSpec("CartPole-v2", "gymnasium.cartpole:CartPoleEnv"),
|
||||
"CartPole-v3": EnvSpec("CartPole-v3", lambda: CartPoleEnv()),
|
||||
"ExampleNamespace/CartPole-v2": EnvSpec(
|
||||
"ExampleNamespace/CartPole-v2", "gymnasium.envs.classic_control:CartPoleEnv"
|
||||
),
|
||||
}
|
||||
|
||||
out = gym.pprint_registry(example_registry, disable_print=True)
|
||||
correct_out = """===== classic_control =====
|
||||
CartPole-v0 CartPole-v1
|
||||
===== cartpole =====
|
||||
CartPole-v2
|
||||
===== None =====
|
||||
CartPole-v3
|
||||
===== ExampleNamespace =====
|
||||
ExampleNamespace/CartPole-v2"""
|
||||
assert out == correct_out
|
||||
|
||||
|
||||
def test_pprint_n_columns():
|
||||
example_registry = {
|
||||
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
|
||||
}
|
||||
|
||||
out = gym.pprint_registry(example_registry, num_cols=2, disable_print=True)
|
||||
correct_out = """===== classic_control =====
|
||||
CartPole-v0 CartPole-v1
|
||||
CartPole-v2 CartPole-v3"""
|
||||
assert out == correct_out
|
||||
|
||||
out = gym.pprint_registry(example_registry, num_cols=5, disable_print=True)
|
||||
correct_out = """===== classic_control =====
|
||||
CartPole-v0 CartPole-v1 CartPole-v2 CartPole-v3"""
|
||||
assert out == correct_out
|
||||
|
||||
|
||||
def test_pprint_exclude_namespace():
|
||||
example_registry: dict[str, EnvSpec] = {
|
||||
"Test/CartPole-v0": EnvSpec("Test/CartPole-v0", EXAMPLE_ENTRY_POINT),
|
||||
"Test/CartPole-v1": EnvSpec("Test/CartPole-v1", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
|
||||
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
|
||||
}
|
||||
|
||||
out = gym.pprint_registry(
|
||||
example_registry, exclude_namespaces=["Test"], disable_print=True
|
||||
)
|
||||
correct_out = """===== classic_control =====
|
||||
CartPole-v2 CartPole-v3"""
|
||||
assert out == correct_out
|
||||
|
||||
out = gym.pprint_registry(
|
||||
example_registry, exclude_namespaces=["classic_control"], disable_print=True
|
||||
)
|
||||
correct_out = """===== Test =====
|
||||
Test/CartPole-v0 Test/CartPole-v1"""
|
||||
assert out == correct_out
|
||||
|
||||
example_registry["Example/CartPole-v4"] = EnvSpec(
|
||||
"Example/CartPole-v4", EXAMPLE_ENTRY_POINT
|
||||
)
|
||||
out = gym.pprint_registry(
|
||||
example_registry, exclude_namespaces=["Test", "Example"], disable_print=True
|
||||
)
|
||||
correct_out = """===== classic_control =====
|
||||
CartPole-v2 CartPole-v3"""
|
||||
assert out == correct_out
|
@@ -18,7 +18,7 @@ def register_registration_testing_envs():
|
||||
env_id = f"{namespace}/{versioned_name}-v{version}"
|
||||
gym.register(
|
||||
id=env_id,
|
||||
entry_point="tests.envs.utils_envs:ArgumentEnv",
|
||||
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
||||
kwargs={
|
||||
"arg1": "arg1",
|
||||
"arg2": "arg2",
|
||||
@@ -111,7 +111,7 @@ def test_env_suggestions(
|
||||
with pytest.raises(
|
||||
gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
|
||||
):
|
||||
gym.make(env_id_input, disable_env_checker=True)
|
||||
gym.make(env_id_input)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -136,13 +136,13 @@ def test_env_version_suggestions(
|
||||
gym.error.DeprecatedEnv,
|
||||
match="It provides the default version", # env name,
|
||||
):
|
||||
gym.make(env_id_input, disable_env_checker=True)
|
||||
gym.make(env_id_input)
|
||||
else:
|
||||
with pytest.raises(
|
||||
gym.error.UnregisteredEnv,
|
||||
match=f"It provides versioned environments: \\[ {suggested_versions} \\]",
|
||||
):
|
||||
gym.make(env_id_input, disable_env_checker=True)
|
||||
gym.make(env_id_input)
|
||||
|
||||
|
||||
def test_register_versioned_unversioned():
|
||||
@@ -185,9 +185,7 @@ def test_make_latest_versioned_env(register_registration_testing_envs):
|
||||
"Using the latest versioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv-v5` instead of the unversioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv`."
|
||||
),
|
||||
):
|
||||
env = gym.make(
|
||||
"MyAwesomeNamespace/MyAwesomeVersionedEnv", disable_env_checker=True
|
||||
)
|
||||
env = gym.make("MyAwesomeNamespace/MyAwesomeVersionedEnv")
|
||||
assert env.spec is not None
|
||||
assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"
|
||||
|
95
tests/envs/registration/test_spec.py
Normal file
95
tests/envs/registration/test_spec.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Tests that `gym.spec` works as expected."""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
def test_spec():
|
||||
spec = gym.spec("CartPole-v1")
|
||||
assert spec.id == "CartPole-v1"
|
||||
assert spec is gym.envs.registry["CartPole-v1"]
|
||||
|
||||
|
||||
def test_spec_missing_lookup():
|
||||
gym.register(id="TestEnv-v0", entry_point="no-entry-point")
|
||||
gym.register(id="TestEnv-v15", entry_point="no-entry-point")
|
||||
gym.register(id="TestEnv-v9", entry_point="no-entry-point")
|
||||
gym.register(id="OtherEnv-v100", entry_point="no-entry-point")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.DeprecatedEnv,
|
||||
match=re.escape(
|
||||
"Environment version v1 for `TestEnv` is deprecated. Please use `TestEnv-v15` instead."
|
||||
),
|
||||
):
|
||||
gym.spec("TestEnv-v1")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.UnregisteredEnv,
|
||||
match=re.escape(
|
||||
"Environment version `v1000` for environment `TestEnv` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
|
||||
),
|
||||
):
|
||||
gym.spec("TestEnv-v1000")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.UnregisteredEnv,
|
||||
match=re.escape("Environment `UnknownEnv` doesn't exist."),
|
||||
):
|
||||
gym.spec("UnknownEnv-v1")
|
||||
|
||||
del gym.registry["TestEnv-v0"]
|
||||
del gym.registry["TestEnv-v15"]
|
||||
del gym.registry["TestEnv-v9"]
|
||||
del gym.registry["OtherEnv-v100"]
|
||||
|
||||
|
||||
def test_spec_malformed_lookup():
|
||||
with pytest.raises(
|
||||
gym.error.Error,
|
||||
match=re.escape(
|
||||
"Malformed environment ID: “Breakout-v0”. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
|
||||
),
|
||||
):
|
||||
gym.spec("“Breakout-v0”")
|
||||
|
||||
|
||||
def test_spec_versioned_lookups():
|
||||
gym.register("test/TestEnv-v5", "no-entry-point")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.VersionNotFound,
|
||||
match=re.escape(
|
||||
"Environment version `v9` for environment `test/TestEnv` doesn't exist. It provides versioned environments: [ `v5` ]."
|
||||
),
|
||||
):
|
||||
gym.spec("test/TestEnv-v9")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.DeprecatedEnv,
|
||||
match=re.escape(
|
||||
"Environment version v4 for `test/TestEnv` is deprecated. Please use `test/TestEnv-v5` instead."
|
||||
),
|
||||
):
|
||||
gym.spec("test/TestEnv-v4")
|
||||
|
||||
assert gym.spec("test/TestEnv-v5") is not None
|
||||
del gym.registry["test/TestEnv-v5"]
|
||||
|
||||
|
||||
def test_spec_default_lookups():
|
||||
gym.register("test/TestEnv", "no-entry-point")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.DeprecatedEnv,
|
||||
match=re.escape(
|
||||
"Environment version `v0` for environment `test/TestEnv` doesn't exist. It provides the default version `test/TestEnv`."
|
||||
),
|
||||
):
|
||||
gym.spec("test/TestEnv-v0")
|
||||
|
||||
assert gym.spec("test/TestEnv") is not None
|
||||
del gym.registry["test/TestEnv"]
|
@@ -1,14 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class RegisterDuringMakeEnv(gym.Env):
|
||||
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
|
||||
|
||||
def __init__(self):
|
||||
self.action_space = gym.spaces.Discrete(1)
|
||||
self.observation_space = gym.spaces.Discrete(1)
|
||||
|
||||
|
||||
class ArgumentEnv(gym.Env):
|
||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
@@ -23,9 +17,12 @@ class ArgumentEnv(gym.Env):
|
||||
class NoHuman(gym.Env):
|
||||
"""Environment that does not have human-rendering."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
|
||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
|
||||
def __init__(self, render_mode=None):
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 4}
|
||||
|
||||
def __init__(self, render_mode: list[str] = None):
|
||||
assert render_mode in self.metadata["render_modes"]
|
||||
self.render_mode = render_mode
|
||||
|
||||
@@ -33,6 +30,9 @@ class NoHuman(gym.Env):
|
||||
class NoHumanOldAPI(gym.Env):
|
||||
"""Environment that does not have human-rendering."""
|
||||
|
||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
|
||||
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
|
||||
|
||||
def __init__(self):
|
||||
@@ -42,6 +42,9 @@ class NoHumanOldAPI(gym.Env):
|
||||
class NoHumanNoRGB(gym.Env):
|
||||
"""Environment that has neither human- nor rgb-rendering"""
|
||||
|
||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
|
||||
metadata = {"render_modes": ["ascii"], "render_fps": 4}
|
||||
|
||||
def __init__(self, render_mode=None):
|
||||
@@ -52,6 +55,9 @@ class NoHumanNoRGB(gym.Env):
|
||||
class NoRenderModesMetadata(gym.Env):
|
||||
"""An environment that has rendering but has not updated the metadata."""
|
||||
|
||||
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
|
||||
|
||||
# metadata: dict[str, Any] = {"render_modes": []}
|
||||
|
||||
def __init__(self, render_mode):
|
13
tests/envs/registration/utils_unregistered_env.py
Normal file
13
tests/envs/registration/utils_unregistered_env.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""This utility file contains an environment that is registered upon loading the file."""
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class RegisterDuringMakeEnv(gym.Env):
|
||||
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
|
||||
|
||||
def __init__(self):
|
||||
self.action_space = gym.spaces.Discrete(1)
|
||||
self.observation_space = gym.spaces.Discrete(1)
|
||||
|
||||
|
||||
gym.register(id="RegisterDuringMake-v0", entry_point=RegisterDuringMakeEnv)
|
@@ -39,7 +39,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
|
||||
all_testing_env_specs,
|
||||
ids=[spec.id for spec in all_testing_env_specs],
|
||||
)
|
||||
def test_envs_pass_env_checker(spec):
|
||||
def test_all_env_api(spec):
|
||||
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env = spec.make(disable_env_checker=True).unwrapped
|
||||
@@ -52,6 +52,22 @@ def test_envs_pass_env_checker(spec):
|
||||
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
|
||||
)
|
||||
def test_all_env_passive_env_checker(spec):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env = gym.make(spec.id)
|
||||
env.reset()
|
||||
env.step(env.action_space.sample())
|
||||
|
||||
env.close()
|
||||
|
||||
for warning in caught_warnings:
|
||||
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
|
||||
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||
|
||||
|
||||
# Note that this precludes running this test in multiple threads.
|
||||
# However, we probably already can't do multithreading due to some environments.
|
||||
SEED = 0
|
||||
|
@@ -1,27 +0,0 @@
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
|
||||
# To ignore the trailing whitespaces, will need flake to ignore this file.
|
||||
# flake8: noqa
|
||||
|
||||
reduced_registry = {
|
||||
env_id: env_spec
|
||||
for env_id, env_spec in gym.registry.items()
|
||||
if env_spec.entry_point != "shimmy.atari_env:AtariEnv"
|
||||
}
|
||||
|
||||
|
||||
def test_pprint_custom_registry():
|
||||
"""Testing a registry different from default."""
|
||||
a = {
|
||||
"CartPole-v0": gym.envs.registry["CartPole-v0"],
|
||||
"CartPole-v1": gym.envs.registry["CartPole-v1"],
|
||||
}
|
||||
out = gym.pprint_registry(a, disable_print=True)
|
||||
|
||||
correct_out = """===== classic_control =====
|
||||
CartPole-v0 CartPole-v1
|
||||
|
||||
"""
|
||||
assert out == correct_out
|
@@ -1,93 +0,0 @@
|
||||
"""Tests that gym.spec works as expected."""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
def test_spec():
|
||||
spec = gym.spec("CartPole-v1")
|
||||
assert spec.id == "CartPole-v1"
|
||||
assert spec is gym.envs.registry["CartPole-v1"]
|
||||
|
||||
|
||||
def test_spec_kwargs():
|
||||
map_name_value = "8x8"
|
||||
env = gym.make("FrozenLake-v1", map_name=map_name_value)
|
||||
assert env.spec is not None
|
||||
assert env.spec.kwargs["map_name"] == map_name_value
|
||||
|
||||
|
||||
def test_spec_missing_lookup():
|
||||
gym.register(id="Test1-v0", entry_point="no-entry-point")
|
||||
gym.register(id="Test1-v15", entry_point="no-entry-point")
|
||||
gym.register(id="Test1-v9", entry_point="no-entry-point")
|
||||
gym.register(id="Other1-v100", entry_point="no-entry-point")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.DeprecatedEnv,
|
||||
match=re.escape(
|
||||
"Environment version v1 for `Test1` is deprecated. Please use `Test1-v15` instead."
|
||||
),
|
||||
):
|
||||
gym.spec("Test1-v1")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.UnregisteredEnv,
|
||||
match=re.escape(
|
||||
"Environment version `v1000` for environment `Test1` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
|
||||
),
|
||||
):
|
||||
gym.spec("Test1-v1000")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.UnregisteredEnv,
|
||||
match=re.escape("Environment Unknown1 doesn't exist. "),
|
||||
):
|
||||
gym.spec("Unknown1-v1")
|
||||
|
||||
|
||||
def test_spec_malformed_lookup():
|
||||
with pytest.raises(
|
||||
gym.error.Error,
|
||||
match=f'^{re.escape("Malformed environment ID: “Breakout-v0”.(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))")}$',
|
||||
):
|
||||
gym.spec("“Breakout-v0”")
|
||||
|
||||
|
||||
def test_spec_versioned_lookups():
|
||||
gym.register("test/Test2-v5", "no-entry-point")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.VersionNotFound,
|
||||
match=re.escape(
|
||||
"Environment version `v9` for environment `test/Test2` doesn't exist. It provides versioned environments: [ `v5` ]."
|
||||
),
|
||||
):
|
||||
gym.spec("test/Test2-v9")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.DeprecatedEnv,
|
||||
match=re.escape(
|
||||
"Environment version v4 for `test/Test2` is deprecated. Please use `test/Test2-v5` instead."
|
||||
),
|
||||
):
|
||||
gym.spec("test/Test2-v4")
|
||||
|
||||
assert gym.spec("test/Test2-v5") is not None
|
||||
|
||||
|
||||
def test_spec_default_lookups():
|
||||
gym.register("test/Test3", "no-entry-point")
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.DeprecatedEnv,
|
||||
match=re.escape(
|
||||
"Environment version `v0` for environment `test/Test3` doesn't exist. It provides the default version test/Test3`."
|
||||
),
|
||||
):
|
||||
gym.spec("test/Test3-v0")
|
||||
|
||||
assert gym.spec("test/Test3") is not None
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user