Reorder functions and refactor registration.py (#289)

Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
Mark Towers
2023-02-05 00:05:59 +00:00
committed by GitHub
parent 61f80d62a0
commit 4dd526d370
14 changed files with 757 additions and 625 deletions

View File

@@ -2,36 +2,39 @@
title: Registry
---
# Registry
# Register and Make
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers.
Environments can also be created through python imports.
## Make
```{eval-rst}
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers through the :meth:`gymnasium.make` function. To do this, the environment must be registered prior with :meth:`gymnasium.register`. To get the environment specifications for a registered environment, use :meth:`gymnasium.spec` and to print the whole registry, use :meth:`gymnasium.pprint_registry`.
```
```{eval-rst}
.. autofunction:: gymnasium.make
```
## Register
```{eval-rst}
.. autofunction:: gymnasium.register
```
## All registered environments
To find all the registered Gymnasium environments, use the `gymnasium.pprint_registry()`.
This will not include environments registered only in OpenAI Gym however can be loaded by `gymnasium.make`.
## Spec
```{eval-rst}
.. autofunction:: gymnasium.spec
```
## Pretty print registry
```{eval-rst}
.. autofunction:: gymnasium.pprint_registry
```
## Core variables
```{eval-rst}
.. autoclass:: gymnasium.envs.registration.EnvSpec
.. attribute:: gymnasium.envs.registration.registry
The Global registry for gymnasium which is where environment specifications are stored by :meth:`gymnasium.register` and from which :meth:`gymnasium.make` is used to create environments.
.. attribute:: gymnasium.envs.registration.current_namespace
The current namespace when creating or registering environments. This is by default ``None`` by with :meth:`namespace` this can be modified to automatically set the environment id namespace.
```
## Additional functions
```{eval-rst}
.. autofunction:: gymnasium.envs.registration.get_env_id
.. autofunction:: gymnasium.envs.registration.parse_env_id
.. autofunction:: gymnasium.envs.registration.find_highest_version
.. autofunction:: gymnasium.envs.registration.namespace
.. autofunction:: gymnasium.envs.registration.load_env
.. autofunction:: gymnasium.envs.registration.load_plugin_envs
```

View File

@@ -2,7 +2,7 @@
from typing import Any
from gymnasium.envs.registration import (
load_env_plugins,
load_plugin_envs,
make,
pprint_registry,
register,
@@ -363,4 +363,4 @@ register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error)
# Hook to load plugins from entry points
load_env_plugins()
load_plugin_envs()

View File

@@ -9,13 +9,11 @@ import importlib.util
import re
import sys
import traceback
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Sequence, SupportsFloat, overload
import numpy as np
from typing import Any, Iterable
from gymnasium import Env, error, logger
from gymnasium.wrappers import (
AutoResetWrapper,
HumanRendering,
@@ -32,12 +30,10 @@ if sys.version_info < (3, 10):
else:
import importlib.metadata as metadata
if sys.version_info >= (3, 8):
from typing import Literal
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing_extensions import Literal
from gymnasium import Env, error, logger
from typing import Protocol
ENV_ID_RE = re.compile(
@@ -45,50 +41,99 @@ ENV_ID_RE = re.compile(
)
def load(name: str) -> Callable:
"""Loads an environment with name and returns an environment creation function.
__all__ = [
"EnvSpec",
"registry",
"current_namespace",
"register",
"make",
"spec",
"pprint_registry",
]
Args:
name: The environment name
Returns:
Calls the environment constructor
class EnvCreator(Protocol):
"""Function type expected for an environment."""
def __call__(self, **kwargs: Any) -> Env:
...
@dataclass
class EnvSpec:
"""A specification for creating environments with :meth:`gymnasium.make`.
* **id**: The string used to create the environment with :meth:`gymnasium.make`
* **entry_point**: A string for the environment location, ``(import path):(environment name)`` or a function that creates the environment.
* **reward_threshold**: The reward threshold for completing the environment.
* **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
* **max_episode_steps**: The max number of steps that the environment can take before truncation
* **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
* **autoreset**: If to automatically reset the environment on episode end
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
"""
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
fn = getattr(mod, attr_name)
return fn
id: str
entry_point: EnvCreator | str
# Environment attributes
reward_threshold: float | None = field(default=None)
nondeterministic: bool = field(default=False)
# Wrappers
max_episode_steps: int | None = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
apply_api_compatibility: bool = field(default=False)
# post-init attributes
namespace: str | None = field(init=False)
name: str = field(init=False)
version: int | None = field(init=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
def __post_init__(self):
"""Calls after the spec is created to extract the namespace, name and version from the id."""
# Initialize namespace, name, version
self.namespace, self.name, self.version = parse_env_id(self.id)
def make(self, **kwargs: Any) -> Env:
"""Calls ``make`` using the environment spec and any keyword arguments."""
# For compatibility purposes
return make(self, **kwargs)
def parse_env_id(id: str) -> tuple[str | None, str, int | None]:
"""Parse environment ID string format.
# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
This format is true today, but it's *not* an official spec.
[namespace/](env-name)-v(version) env-name is group 1, version is group 2
2016-10-31: We're experimentally expanding the environment ID format
to include an optional namespace.
def parse_env_id(env_id: str) -> tuple[str | None, str, int | None]:
"""Parse environment ID string format - ``[namespace/](env-name)[-v(version)]`` where the namespace and version are optional.
Args:
id: The environment id to parse
env_id: The environment id to parse
Returns:
A tuple of environment namespace, environment name and version number
Raises:
Error: If the environment id does not a valid environment regex
Error: If the environment id is not valid environment regex
"""
match = ENV_ID_RE.fullmatch(id)
match = ENV_ID_RE.fullmatch(env_id)
if not match:
raise error.Error(
f"Malformed environment ID: {id}."
f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
f"Malformed environment ID: {env_id}. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
)
namespace, name, version = match.group("namespace", "name", "version")
ns, name, version = match.group("namespace", "name", "version")
if version is not None:
version = int(version)
return namespace, name, version
return ns, name, version
def get_env_id(ns: str | None, name: str, version: int | None) -> str:
@@ -103,97 +148,80 @@ def get_env_id(ns: str | None, name: str, version: int | None) -> str:
The environment id
"""
full_name = name
if version is not None:
full_name += f"-v{version}"
if ns is not None:
full_name = ns + "/" + full_name
full_name = f"{ns}/{name}"
if version is not None:
full_name = f"{full_name}-v{version}"
return full_name
@dataclass
class EnvSpec:
"""A specification for creating environments with `gym.make`.
def find_highest_version(ns: str | None, name: str) -> int | None:
"""Finds the highest registered version of the environment given the namespace and name in the registry.
* id: The string used to create the environment with `gym.make`
* entry_point: The location of the environment to create from
* reward_threshold: The reward threshold for completing the environment.
* nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
* max_episode_steps: The max number of steps that the environment can take before truncation
* order_enforce: If to enforce the order of `reset` before `step` and `render` functions
* autoreset: If to automatically reset the environment on episode end
* disable_env_checker: If to disable the environment checker wrapper in `gym.make`, by default False (runs the environment checker)
* kwargs: Additional keyword arguments passed to the environments through `gym.make`
Args:
ns: The environment namespace
name: The environment name (id)
Returns:
The highest version of an environment with matching namespace and name, otherwise ``None`` is returned.
"""
id: str
entry_point: Callable | str
# Environment attributes
reward_threshold: float | None = field(default=None)
nondeterministic: bool = field(default=False)
# Wrappers
max_episode_steps: int | None = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
apply_api_compatibility: bool = field(default=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
# post-init attributes
namespace: str | None = field(init=False)
name: str = field(init=False)
version: int | None = field(init=False)
def __post_init__(self):
"""Calls after the spec is created to extract the namespace, name and version from the id."""
# Initialize namespace, name, version
self.namespace, self.name, self.version = parse_env_id(self.id)
def make(self, **kwargs: Any) -> Env:
"""Calls ``make`` using the environment spec and any keyword arguments."""
# For compatibility purposes
return make(self, **kwargs)
version: list[int] = [
env_spec.version
for env_spec in registry.values()
if env_spec.namespace == ns
and env_spec.name == name
and env_spec.version is not None
]
return max(version, default=None)
def _check_namespace_exists(ns: str | None):
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
# If the namespace is none, then the namespace does exist
if ns is None:
return
namespaces = {
spec_.namespace for spec_ in registry.values() if spec_.namespace is not None
# Check if the namespace exists in one of the registry's specs
namespaces: set[str] = {
env_spec.namespace
for env_spec in registry.values()
if env_spec.namespace is not None
}
if ns in namespaces:
return
# Otherwise, the namespace doesn't exist and raise a helpful message
suggestion = (
difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
)
suggestion_msg = (
f"Did you mean: `{suggestion[0]}`?"
if suggestion
else f"Have you installed the proper package for {ns}?"
)
if suggestion:
suggestion_msg = f"Did you mean: `{suggestion[0]}`?"
else:
suggestion_msg = f"Have you installed the proper package for {ns}?"
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
def _check_name_exists(ns: str | None, name: str):
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
# First check if the namespace exists
_check_namespace_exists(ns)
names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}
# Then check if the name exists
names: set[str] = {
env_spec.name for env_spec in registry.values() if env_spec.namespace == ns
}
if name in names:
return
# Otherwise, raise a helpful error to the user
suggestion = difflib.get_close_matches(name, names, n=1)
namespace_msg = f" in namespace {ns}" if ns else ""
suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else ""
suggestion_msg = f" Did you mean: `{suggestion[0]}`?" if suggestion else ""
raise error.NameNotFound(
f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
f"Environment `{name}` doesn't exist{namespace_msg}.{suggestion_msg}"
)
@@ -222,26 +250,28 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."
env_specs = [
spec_
for spec_ in registry.values()
if spec_.namespace == ns and spec_.name == name
env_spec
for env_spec in registry.values()
if env_spec.namespace == ns and env_spec.name == name
]
env_specs = sorted(env_specs, key=lambda spec_: int(spec_.version or -1))
env_specs = sorted(env_specs, key=lambda env_spec: int(env_spec.version or -1))
default_spec = [spec_ for spec_ in env_specs if spec_.version is None]
default_spec = [env_spec for env_spec in env_specs if env_spec.version is None]
if default_spec:
message += f" It provides the default version {default_spec[0].id}`."
message += f" It provides the default version `{default_spec[0].id}`."
if len(env_specs) == 1:
raise error.DeprecatedEnv(message)
# Process possible versioned environments
versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None]
versioned_specs = [
env_spec for env_spec in env_specs if env_spec.version is not None
]
latest_spec = max(versioned_specs, key=lambda spec: spec.version, default=None) # type: ignore
latest_spec = max(versioned_specs, key=lambda env_spec: env_spec.version, default=None) # type: ignore
if latest_spec is not None and version > latest_spec.version:
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs)
version_list_msg = ", ".join(f"`v{env_spec.version}`" for env_spec in env_specs)
message += f" It provides versioned environments: [ {version_list_msg} ]."
raise error.VersionNotFound(message)
@@ -253,18 +283,80 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
)
def find_highest_version(ns: str | None, name: str) -> int | None:
"""Finds the highest registered version of the environment in the registry."""
version: list[int] = [
spec_.version
for spec_ in registry.values()
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
]
return max(version, default=None)
def _check_spec_register(testing_spec: EnvSpec):
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
latest_versioned_spec = max(
(
env_spec
for env_spec in registry.values()
if env_spec.namespace == testing_spec.namespace
and env_spec.name == testing_spec.name
and env_spec.version is not None
),
key=lambda spec_: int(spec_.version), # type: ignore
default=None,
)
unversioned_spec = next(
(
env_spec
for env_spec in registry.values()
if env_spec.namespace == testing_spec.namespace
and env_spec.name == testing_spec.name
and env_spec.version is None
),
None,
)
if unversioned_spec is not None and testing_spec.version is not None:
raise error.RegistrationError(
"Can't register the versioned environment "
f"`{testing_spec.id}` when the unversioned environment "
f"`{unversioned_spec.id}` of the same name already exists."
)
elif latest_versioned_spec is not None and testing_spec.version is None:
raise error.RegistrationError(
f"Can't register the unversioned environment `{testing_spec.id}` when the versioned environment "
f"`{latest_versioned_spec.id}` of the same name already exists. Note: the default behavior is "
"that `gym.make` with the unversioned environment will return the latest versioned environment"
)
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
"""Load modules (plugins) using the gymnasium entry points == to `entry_points`.
def _check_metadata(testing_metadata: dict[str, Any]):
"""Check the metadata of an environment."""
if not isinstance(testing_metadata, dict):
raise error.InvalidMetadata(
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
)
render_modes = testing_metadata.get("render_modes")
if render_modes is None:
logger.warn(
f"The environment creator metadata doesn't include `render_modes`, contains: {list(testing_metadata.keys())}"
)
elif not isinstance(render_modes, Iterable):
logger.warn(
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
)
def load_env(name: str) -> EnvCreator:
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
Args:
name: The environment name
Returns:
The environment constructor for the given environment name.
"""
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
fn = getattr(mod, attr_name)
return fn
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
Args:
entry_point: The string for the entry point.
@@ -282,7 +374,7 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
else:
module, attr = plugin.value, None
except Exception as e:
warnings.warn(
logger.warn(
f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
)
module, attr = None, None
@@ -314,136 +406,6 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
# fmt: off
@overload
def make(id: str, **kwargs) -> Env: ...
@overload
def make(id: EnvSpec, **kwargs) -> Env: ...
# Classic control
# ----------------------------------------
@overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Box2d
# ----------------------------------------
@overload
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
# Toy Text
# ----------------------------------------
@overload
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Mujoco
# ----------------------------------------
@overload
def make(id: Literal[
"Reacher-v2", "Reacher-v4",
"Pusher-v2", "Pusher-v4",
"InvertedPendulum-v2", "InvertedPendulum-v4",
"InvertedDoublePendulum-v2", "InvertedDoublePendulum-v4",
"HalfCheetah-v2", "HalfCheetah-v3", "HalfCheetah-v4",
"Hopper-v2", "Hopper-v3", "Hopper-v4",
"Swimmer-v2", "Swimmer-v3", "Swimmer-v4",
"Walker2d-v2", "Walker2d-v3", "Walker2d-v4",
"Ant-v2", "Ant-v3", "Ant-v4",
"HumanoidStandup-v2", "HumanoidStandup-v4",
"Humanoid-v2", "Humanoid-v3", "Humanoid-v4",
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
# fmt: on
# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
def _check_spec_register(spec: EnvSpec):
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
global registry
latest_versioned_spec = max(
(
spec_
for spec_ in registry.values()
if spec_.namespace == spec.namespace
and spec_.name == spec.name
and spec_.version is not None
),
key=lambda spec_: int(spec_.version), # type: ignore
default=None,
)
unversioned_spec = next(
(
spec_
for spec_ in registry.values()
if spec_.namespace == spec.namespace
and spec_.name == spec.name
and spec_.version is None
),
None,
)
if unversioned_spec is not None and spec.version is not None:
raise error.RegistrationError(
"Can't register the versioned environment "
f"`{spec.id}` when the unversioned environment "
f"`{unversioned_spec.id}` of the same name already exists."
)
elif latest_versioned_spec is not None and spec.version is None:
raise error.RegistrationError(
"Can't register the unversioned environment "
f"`{spec.id}` when the versioned environment "
f"`{latest_versioned_spec.id}` of the same name "
f"already exists. Note: the default behavior is "
f"that `gym.make` with the unversioned environment "
f"will return the latest versioned environment"
)
def _check_metadata(metadata_: dict):
if not isinstance(metadata_, dict):
raise error.InvalidMetadata(
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
)
render_modes = metadata_.get("render_modes")
if render_modes is None:
logger.warn(
f"The environment creator metadata doesn't include `render_modes`, contains: {list(metadata_.keys())}"
)
elif not isinstance(render_modes, Iterable):
logger.warn(
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
)
# Public API
@contextlib.contextmanager
def namespace(ns: str):
"""Context manager for modifying the current namespace."""
@@ -456,7 +418,7 @@ def namespace(ns: str):
def register(
id: str,
entry_point: Callable | str,
entry_point: EnvCreator | str,
reward_threshold: float | None = None,
nondeterministic: bool = False,
max_episode_steps: int | None = None,
@@ -464,26 +426,28 @@ def register(
autoreset: bool = False,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
**kwargs,
**kwargs: Any,
):
"""Register an environment with gymnasium.
"""Registers an environment in gymnasium with an ``id`` to use with :meth:`gymnasium.make` with the ``entry_point`` being a string or callable for creating the environment.
The `id` parameter corresponds to the name of the environment, with the syntax as follows:
`(namespace)/(env_name)-v(version)` where `namespace` is optional.
The ``id`` parameter corresponds to the name of the environment, with the syntax as follows:
``[namespace/](env_name)[-v(version)]`` where ``namespace`` and ``-v(version)`` is optional.
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
It takes arbitrary keyword arguments, which are passed to the :class:`EnvSpec` ``kwargs`` parameter.
Args:
id: The environment id
entry_point: The entry point for creating the environment
reward_threshold: The reward threshold considered to have learnt an environment
nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions)
max_episode_steps: The maximum number of episodes steps before truncation. Used by the Time Limit wrapper.
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order
autoreset: If to add the autoreset wrapper such that reset does not need to be called.
disable_env_checker: If to disable the environment checker for the environment. Recommended to False.
apply_api_compatibility: If to apply the `StepAPICompatibility` wrapper.
**kwargs: arbitrary keyword arguments which are passed to the environment constructor
reward_threshold: The reward threshold considered for an agent to have learnt the environment
nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions, the same state cannot be reached)
max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called.
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
Use if the environment is implemented in the gym v0.21 environment API.
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
"""
global registry, current_namespace
ns, name, version = parse_env_id(id)
@@ -502,10 +466,10 @@ def register(
else:
ns_id = ns
full_id = get_env_id(ns_id, name, version)
full_env_id = get_env_id(ns_id, name, version)
new_spec = EnvSpec(
id=full_id,
id=full_env_id,
entry_point=entry_point,
reward_threshold=reward_threshold,
nondeterministic=nondeterministic,
@@ -517,6 +481,7 @@ def register(
**kwargs,
)
_check_spec_register(new_spec)
if new_spec.id in registry:
logger.warn(f"Overriding environment {new_spec.id} already in registry.")
registry[new_spec.id] = new_spec
@@ -528,36 +493,36 @@ def make(
autoreset: bool = False,
apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None,
**kwargs,
**kwargs: Any,
) -> Env:
"""Create an environment according to the given ID.
"""Creates an environment previously registered with :meth:`gymnasium.register` or a :class:`EnvSpec`.
To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids.
To find all available environments use ``gymnasium.envs.registry.keys()`` for all valid ids.
Args:
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
apply_api_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that
id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
This is equivalent to importing the module first to register the environment followed by making the environment.
max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``.
The value is used by :class:`gymnasium.wrappers.TimeLimit`.
autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`).
apply_api_compatibility: Whether to wrap the environment with the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper that
converts the environment step from a done bool to return termination and truncation bools.
By default, the argument is None to which the environment specification `apply_api_compatibility` is used
which defaults to False. Otherwise, the value of `apply_api_compatibility` is used.
If `True`, the wrapper is applied otherwise, the wrapper is not applied.
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
(which is by default False, running the environment checker),
otherwise will run according to this parameter (`True` = not run, `False` = run)
By default, the argument is None in which the :class:`EnvSpec` ``apply_api_compatibility`` is used, otherwise this variable is used in favor.
disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the
:class:`EnvSpec` ``disable_env_checker`` value otherwise use this value will be used.
kwargs: Additional arguments to pass to the environment constructor.
Returns:
An instance of the environment.
An instance of the environment with wrappers applied.
Raises:
Error: If the ``id`` doesn't exist then an error is raised
Error: If the ``id`` doesn't exist in the :attr:`registry`
"""
if isinstance(id, EnvSpec):
spec_ = id
env_spec = id
else:
module, id = (None, id) if ":" not in id else id.split(":")
# The environment name can include an unloaded module in "module:env_name" style
module, env_name = (None, id) if ":" not in id else id.split(":")
if module is not None:
try:
importlib.import_module(module)
@@ -566,9 +531,13 @@ def make(
f"{e}. Environment registration via importing a module failed. "
f"Check whether '{module}' contains env registration and can be imported."
) from e
spec_ = registry.get(id)
ns, name, version = parse_env_id(id)
# load the env spec from the registry
env_spec = registry.get(env_name)
# update env spec is not version provided, raise warning if out of date
ns, name, version = parse_env_id(env_name)
latest_version = find_highest_version(ns, name)
if (
version is not None
@@ -576,38 +545,44 @@ def make(
and latest_version > version
):
logger.warn(
f"The environment {id} is out of date. You should consider "
f"The environment {env_name} is out of date. You should consider "
f"upgrading to version `v{latest_version}`."
)
if version is None and latest_version is not None:
version = latest_version
new_env_id = get_env_id(ns, name, version)
spec_ = registry.get(new_env_id)
env_spec = registry.get(new_env_id)
logger.warn(
f"Using the latest versioned environment `{new_env_id}` "
f"instead of the unversioned environment `{id}`."
f"instead of the unversioned environment `{env_name}`."
)
if spec_ is None:
if env_spec is None:
_check_version_exists(ns, name, version)
raise error.Error(f"No registered env with id: {id}")
raise error.Error(f"No registered env with id: {env_name}")
_kwargs = spec_.kwargs.copy()
_kwargs.update(kwargs)
assert isinstance(
env_spec, EnvSpec
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}"
# Extract the spec kwargs and append the make kwargs
spec_kwargs = env_spec.kwargs.copy()
spec_kwargs.update(kwargs)
if spec_.entry_point is None:
raise error.Error(f"{spec_.id} registered but entry_point is not specified")
elif callable(spec_.entry_point):
env_creator = spec_.entry_point
# Load the environment creator
if env_spec.entry_point is None:
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
# Assume it's a string
env_creator = load(spec_.entry_point)
env_creator = load_env(env_spec.entry_point)
render_modes = None
# Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes")
mode = _kwargs.get("render_mode")
mode = spec_kwargs.get("render_mode")
apply_human_rendering = False
apply_render_collection = False
@@ -619,10 +594,10 @@ def make(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment."
)
_kwargs["render_mode"] = displayable_modes.pop()
spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
_kwargs["render_mode"] = mode[: -len("_list")]
spec_kwargs["render_mode"] = mode[: -len("_list")]
apply_render_collection = True
else:
logger.warn(
@@ -630,16 +605,16 @@ def make(
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility is True or (
apply_api_compatibility is None and spec_.apply_api_compatibility is True
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = _kwargs.pop("render_mode", None)
render_mode = spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**_kwargs)
env = env_creator(**spec_kwargs)
except TypeError as e:
if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
@@ -654,31 +629,31 @@ def make(
raise e
# Copies the environment creation specification and kwargs to add to the environment specification details
spec_ = copy.deepcopy(spec_)
spec_.kwargs = _kwargs
env.unwrapped.spec = spec_
env_spec = copy.deepcopy(env_spec)
env_spec.kwargs = spec_kwargs
env.unwrapped.spec = env_spec
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and spec_.apply_api_compatibility is True
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and spec_.disable_env_checker is False
disable_env_checker is None and env_spec.disable_env_checker is False
):
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if spec_.order_enforce:
if env_spec.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, env_spec.max_episode_steps)
# Add the autoreset wrapper
if autoreset:
@@ -694,74 +669,101 @@ def make(
def spec(env_id: str) -> EnvSpec:
"""Retrieve the spec for the given environment from the global registry."""
spec_ = registry.get(env_id)
if spec_ is None:
"""Retrieve the :class:`EnvSpec` for the environment id from the :attr:`registry`.
Args:
env_id: The environment id with the expected format of ``[(namespace)/]id[-v(version)]``
Returns:
The environment spec if it exists
Raises:
Error: If the environment id doesn't exist
"""
env_spec = registry.get(env_id)
if env_spec is None:
ns, name, version = parse_env_id(env_id)
_check_version_exists(ns, name, version)
raise error.Error(f"No registered env with id: {env_id}")
else:
assert isinstance(spec_, EnvSpec)
return spec_
assert isinstance(
env_spec, EnvSpec
), f"Expected the registry for {env_id} to be an `EnvSpec`, actual type is {type(env_spec)}"
return env_spec
def pprint_registry(
_registry: dict = registry,
print_registry: dict[str, EnvSpec] = registry,
*,
num_cols: int = 3,
exclude_namespaces: list[str] | None = None,
disable_print: bool = False,
) -> str | None:
"""Pretty print the environments in the registry.
"""Pretty prints all environments in the :attr:`registry`.
Note:
All arguments are keyword only
Args:
_registry: Environment registry to be printed.
print_registry: Environment registry to be printed. By default, :attr:`registry`
num_cols: Number of columns to arrange environments in, for display.
exclude_namespaces: Exclude any namespaces from being printed.
exclude_namespaces: A list of namespaces to be excluded from printing. Helpful if only ALE environments are wanted.
disable_print: Whether to return a string of all the namespaces and environment IDs
instead of printing it to console.
or to print the string to console.
"""
# Defaultdict to store environment names according to namespace.
namespace_envs = defaultdict(lambda: [])
# Defaultdict to store environment ids according to namespace.
namespace_envs: dict[str, list[str]] = defaultdict(lambda: [])
max_justify = float("-inf")
for env in _registry.values():
namespace, _, _ = parse_env_id(env.id)
if namespace is None:
# Since namespace is currently none, use regex to obtain namespace from entrypoints.
env_entry_point = re.sub(r":\w+", "", env.entry_point)
e_ep_split = env_entry_point.split(".")
if len(e_ep_split) >= 3:
# If namespace is of the format - gymnasium.envs.mujoco.ant_v4:AntEnv
# or gymnasium.envs.mujoco:HumanoidEnv
idx = 2
namespace = e_ep_split[idx]
elif len(e_ep_split) > 1:
# If namespace is of the format - shimmy.atari_env
idx = 1
namespace = e_ep_split[idx]
else:
# If namespace cannot be found, default to env id.
namespace = env.id
namespace_envs[namespace].append(env.id)
max_justify = max(max_justify, len(env.id))
# Iterate through each namespace and print environment alphabetically.
return_str = ""
for namespace, envs in namespace_envs.items():
# Find the namespace associated with each environment spec
for env_spec in print_registry.values():
ns = env_spec.namespace
if ns is None and isinstance(env_spec.entry_point, str):
# Use regex to obtain namespace from entrypoints.
env_entry_point = re.sub(r":\w+", "", env_spec.entry_point)
split_entry_point = env_entry_point.split(".")
if len(split_entry_point) >= 3:
# If namespace is of the format:
# - gymnasium.envs.mujoco.ant_v4:AntEnv
# - gymnasium.envs.mujoco:HumanoidEnv
ns = split_entry_point[2]
elif len(split_entry_point) > 1:
# If namespace is of the format - shimmy.atari_env
ns = split_entry_point[1]
else:
# If namespace cannot be found, default to env name
ns = env_spec.name
namespace_envs[ns].append(env_spec.id)
max_justify = max(max_justify, len(env_spec.name))
# Iterate through each namespace and print environment alphabetically
output: list[str] = []
for ns, env_ids in namespace_envs.items():
# Ignore namespaces to exclude.
if exclude_namespaces is not None and namespace in exclude_namespaces:
if exclude_namespaces is not None and ns in exclude_namespaces:
continue
return_str += f"{'=' * 5} {namespace} {'=' * 5}\n" # Print namespace.
# Print the namespace
namespace_output = f"{'=' * 5} {ns} {'=' * 5}\n"
# Reference: https://stackoverflow.com/a/33464001
for count, item in enumerate(sorted(envs), 1):
return_str += (
item.ljust(max_justify) + " "
) # Print column with justification.
for count, env_id in enumerate(sorted(env_ids), 1):
# Print column with justification.
namespace_output += env_id.ljust(max_justify) + " "
# Once all rows printed, switch to new column.
if count % num_cols == 0 or count == len(envs):
return_str = return_str.rstrip(" ") + "\n"
return_str += "\n"
if count % num_cols == 0:
namespace_output = namespace_output.rstrip(" ")
if count != len(env_ids):
namespace_output += "\n"
output.append(namespace_output.rstrip(" "))
if disable_print:
return return_str
return "\n".join(output)
else:
print(return_str, end="")
print("\n".join(output))

View File

View File

@@ -1,14 +1,14 @@
"""Tests that gym.make works as expected."""
from __future__ import annotations
import re
import warnings
from copy import deepcopy
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.envs.classic_control import cartpole
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.wrappers import (
AutoResetWrapper,
HumanRendering,
@@ -16,9 +16,8 @@ from gymnasium.wrappers import (
TimeLimit,
)
from gymnasium.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.registration.utils_envs import ArgumentEnv
from tests.envs.utils import all_testing_env_specs
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
from tests.testing_env import GenericTestEnv, old_step_func
from tests.wrappers.utils import has_wrapper
@@ -30,16 +29,11 @@ except ImportError:
@pytest.fixture(scope="function")
def register_make_testing_envs():
def register_testing_envs():
"""Registers testing envs for `gym.make`"""
gym.register(
"RegisterDuringMakeEnv-v0",
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
)
gym.register(
id="test.ArgumentEnv-v0",
entry_point="tests.envs.utils_envs:ArgumentEnv",
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
kwargs={
"arg1": "arg1",
"arg2": "arg2",
@@ -48,26 +42,25 @@ def register_make_testing_envs():
gym.register(
id="test/NoHuman-v0",
entry_point="tests.envs.utils_envs:NoHuman",
entry_point="tests.envs.registration.utils_envs:NoHuman",
)
gym.register(
id="test/NoHumanOldAPI-v0",
entry_point="tests.envs.utils_envs:NoHumanOldAPI",
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
)
gym.register(
id="test/NoHumanNoRGB-v0",
entry_point="tests.envs.utils_envs:NoHumanNoRGB",
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
)
gym.register(
id="test/NoRenderModesMetadata-v0",
entry_point="tests.envs.utils_envs:NoRenderModesMetadata",
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
)
yield
del gym.envs.registration.registry["RegisterDuringMakeEnv-v0"]
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
del gym.envs.registration.registry["test/NoHuman-v0"]
@@ -76,14 +69,16 @@ def register_make_testing_envs():
def test_make():
env = gym.make("CartPole-v1", disable_env_checker=True)
"""Test basic `gym.make`."""
env = gym.make("CartPole-v1")
assert env.spec is not None
assert env.spec.id == "CartPole-v1"
assert isinstance(env.unwrapped, cartpole.CartPoleEnv)
assert isinstance(env.unwrapped, CartPoleEnv)
env.close()
def test_make_deprecated():
"""Test make with a deprecated environment (i.e., doesn't exist)."""
with warnings.catch_warnings(record=True):
with pytest.raises(
gym.error.Error,
@@ -91,21 +86,20 @@ def test_make_deprecated():
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
),
):
gym.make("Humanoid-v0", disable_env_checker=True)
gym.make("Humanoid-v0")
def test_make_max_episode_steps(register_make_testing_envs):
def test_make_max_episode_steps(register_testing_envs):
# Default, uses the spec's
env = gym.make("CartPole-v1", disable_env_checker=True)
env = gym.make("CartPole-v1")
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert (
env.spec.max_episode_steps == gym.envs.registry["CartPole-v1"].max_episode_steps
)
assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps
env.close()
# Custom max episode steps
env = gym.make("CartPole-v1", max_episode_steps=100, disable_env_checker=True)
assert gym.spec("CartPole-v1").max_episode_steps != 100
env = gym.make("CartPole-v1", max_episode_steps=100)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == 100
@@ -113,20 +107,20 @@ def test_make_max_episode_steps(register_make_testing_envs):
# Env spec has no max episode steps
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
env = gym.make(
"test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None, disable_env_checker=True
)
env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None)
assert env.spec is not None
assert env.spec.max_episode_steps is None
assert has_wrapper(env, TimeLimit) is False
env.close()
def test_gym_make_autoreset():
def test_make_autoreset():
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
env = gym.make("CartPole-v1", disable_env_checker=True)
env = gym.make("CartPole-v1")
assert has_wrapper(env, AutoResetWrapper) is False
env.close()
env = gym.make("CartPole-v1", autoreset=False, disable_env_checker=True)
env = gym.make("CartPole-v1", autoreset=False)
assert has_wrapper(env, AutoResetWrapper) is False
env.close()
@@ -135,43 +129,49 @@ def test_gym_make_autoreset():
env.close()
def test_make_disable_env_checker():
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`."""
spec = deepcopy(gym.spec("CartPole-v1"))
@pytest.mark.parametrize(
"registration_disabled, make_disabled, if_disabled",
[
[False, False, False],
[False, True, True],
[True, False, False],
[True, True, True],
[False, None, False],
[True, None, True],
],
)
def test_make_disable_env_checker(
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
):
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
# Test with spec disable env checker
spec.disable_env_checker = False
env = gym.make(spec)
assert has_wrapper(env, PassiveEnvChecker)
The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
"""
gym.register(
"testing-env-v0",
lambda: GenericTestEnv(),
disable_env_checker=registration_disabled,
)
# Test when the registered EnvSpec.disable_env_checker = False
env = gym.make("testing-env-v0", disable_env_checker=make_disabled)
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
env.close()
# Test with overwritten spec using make disable env checker
assert spec.disable_env_checker is False
env = gym.make(spec, disable_env_checker=True)
assert has_wrapper(env, PassiveEnvChecker) is False
env.close()
# Test with spec enabled disable env checker
spec.disable_env_checker = True
env = gym.make(spec)
assert has_wrapper(env, PassiveEnvChecker) is False
env.close()
# Test with overwritten spec using make disable env checker
assert spec.disable_env_checker is True
env = gym.make(spec, disable_env_checker=False)
assert has_wrapper(env, PassiveEnvChecker)
env.close()
del gym.registry["testing-env-v0"]
def test_apply_api_compatibility():
def test_make_apply_api_compatibility():
"""Test the API compatibility wrapper."""
gym.register(
"testing-old-env",
lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=True,
max_episode_steps=3,
)
# Apply the environment compatibility and check it works as intended
env = gym.make("testing-old-env")
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
env.reset()
assert len(env.step(env.action_space.sample())) == 5
@@ -179,11 +179,19 @@ def test_apply_api_compatibility():
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
# Turn off the spec api compatibility
gym.spec("testing-old-env").apply_api_compatibility = False
env = gym.make("testing-old-env")
# Cannot run reset and step as will not work
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
env.reset()
with pytest.raises(
ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)")
):
env.step(env.action_space.sample())
# Apply the environment compatibility and check it works as intended
env = gym.make("testing-old-env", apply_api_compatibility=True)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
env.reset()
assert len(env.step(env.action_space.sample())) == 5
@@ -191,57 +199,63 @@ def test_apply_api_compatibility():
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
gym.envs.registry.pop("testing-old-env")
@pytest.mark.parametrize(
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_passive_checker_wrapper_warnings(spec):
with warnings.catch_warnings(record=True) as caught_warnings:
env = gym.make(spec) # disable_env_checker=False
env.reset()
env.step(env.action_space.sample())
# todo, add check for render, bugged due to mujoco v2/3 and v4 envs
env.close()
for warning in caught_warnings:
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
raise gym.error.Error(f"Unexpected warning: {warning.message}")
del gym.registry["testing-old-env"]
def test_make_order_enforcing():
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
assert all(spec.order_enforce is True for spec in all_testing_env_specs)
env = gym.make("CartPole-v1", disable_env_checker=True)
env = gym.make("CartPole-v1")
assert has_wrapper(env, OrderEnforcing)
# We can assume that there all other specs will also have the order enforcing
env.close()
gym.register(
id="test.OrderlessArgumentEnv-v0",
entry_point="tests.envs.utils_envs:ArgumentEnv",
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
order_enforce=False,
kwargs={"arg1": None, "arg2": None, "arg3": None},
)
env = gym.make("test.OrderlessArgumentEnv-v0", disable_env_checker=True)
env = gym.make("test.OrderlessArgumentEnv-v0")
assert has_wrapper(env, OrderEnforcing) is False
env.close()
# There is no `make(..., order_enforcing=...)` so we don't test that
def test_make_render_mode(register_make_testing_envs):
env = gym.make("CartPole-v1", disable_env_checker=True)
def test_make_render_mode():
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
env = gym.make("CartPole-v1", render_mode=None)
assert env.render_mode is None
env.close()
assert "rgb_array" in env.metadata["render_modes"]
env = gym.make("CartPole-v1", render_mode="rgb_array")
assert env.render_mode == "rgb_array"
env.close()
assert "no-render-mode" not in env.metadata["render_modes"]
# cartpole is special that it doesn't check the render_mode passed at initialisation
with pytest.warns(
UserWarning,
match=re.escape(
"\x1b[33mWARN: The environment is being initialised with render_mode='no-render-mode' that is not in the possible render_modes (['human', 'rgb_array']).\x1b[0m"
),
):
env = gym.make("CartPole-v1", render_mode="no-render-mode")
assert env.render_mode == "no-render-mode"
env.close()
def test_make_render_collection():
# Make sure that render_mode is applied correctly
env = gym.make(
"CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
)
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
assert has_wrapper(env, gym.wrappers.RenderCollection)
assert env.render_mode == "rgb_array_list"
assert env.unwrapped.render_mode == "rgb_array"
env.reset()
renders = env.render()
assert isinstance(
@@ -250,24 +264,10 @@ def test_make_render_mode(register_make_testing_envs):
assert isinstance(renders[0], np.ndarray)
env.close()
env = gym.make("CartPole-v1", render_mode=None, disable_env_checker=True)
assert env.render_mode is None
valid_render_modes = env.metadata["render_modes"]
env.close()
assert len(valid_render_modes) > 0
with warnings.catch_warnings(record=True) as caught_warnings:
env = gym.make(
"CartPole-v1", render_mode=valid_render_modes[0], disable_env_checker=True
)
assert env.render_mode == valid_render_modes[0]
env.close()
for warning in caught_warnings:
raise gym.error.Error(f"Unexpected warning: {warning.message}")
def test_make_human_rendering(register_testing_envs):
# Make sure that native rendering is used when possible
env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True)
env = gym.make("CartPole-v1", render_mode="human")
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
assert env.render_mode == "human"
env.close()
@@ -278,10 +278,8 @@ def test_make_render_mode(register_make_testing_envs):
"You are trying to use 'human' rendering for an environment that doesn't natively support it. The HumanRendering wrapper is being applied to your environment."
),
):
# Make sure that `HumanRendering` is applied here
env = gym.make(
"test/NoHuman-v0", render_mode="human", disable_env_checker=True
) # This environment doesn't use native rendering
# Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
env = gym.make("test/NoHuman-v0", render_mode="human")
assert has_wrapper(env, HumanRendering)
assert env.render_mode == "human"
env.close()
@@ -292,7 +290,6 @@ def test_make_render_mode(register_make_testing_envs):
gym.make(
"test/NoHumanOldAPI-v0",
render_mode="rgb_array_list",
disable_env_checker=True,
)
# Make sure that an additional error is thrown a user tries to use the wrapper on an environment with old API
@@ -303,9 +300,7 @@ def test_make_render_mode(register_make_testing_envs):
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively."
),
):
gym.make(
"test/NoHumanOldAPI-v0", render_mode="human", disable_env_checker=True
)
gym.make("test/NoHumanOldAPI-v0", render_mode="human")
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
@@ -326,15 +321,20 @@ def test_make_render_mode(register_make_testing_envs):
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array")
def test_make_kwargs(register_make_testing_envs):
def test_make_kwargs(register_testing_envs):
env = gym.make(
"test.ArgumentEnv-v0",
arg2="override_arg2",
arg3="override_arg3",
disable_env_checker=True,
)
assert env.spec is not None
assert env.spec.id == "test.ArgumentEnv-v0"
assert env.spec.kwargs == {
"arg1": "arg1",
"arg2": "override_arg2",
"arg3": "override_arg3",
}
assert isinstance(env.unwrapped, ArgumentEnv)
assert env.arg1 == "arg1"
assert env.arg2 == "override_arg2"
@@ -342,11 +342,16 @@ def test_make_kwargs(register_make_testing_envs):
env.close()
def test_import_module_during_make(register_make_testing_envs):
def test_import_module_during_make():
# Test custom environment which is registered at make
assert "RegisterDuringMake-v0" not in gym.registry
env = gym.make(
"tests.envs.utils:RegisterDuringMakeEnv-v0",
disable_env_checker=True,
"tests.envs.registration.utils_unregistered_env:RegisterDuringMake-v0"
)
assert "RegisterDuringMake-v0" in gym.registry
from tests.envs.registration.utils_unregistered_env import RegisterDuringMakeEnv
assert isinstance(env.unwrapped, RegisterDuringMakeEnv)
env.close()
del gym.registry["RegisterDuringMake-v0"]

View File

@@ -0,0 +1,112 @@
from __future__ import annotations
import gymnasium as gym
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.registration import EnvSpec
# To ignore the trailing whitespaces, will need flake to ignore this file.
# flake8: noqa
EXAMPLE_ENTRY_POINT = "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
def test_pprint_default_registry():
out = gym.pprint_registry(disable_print=True)
assert isinstance(out, str) and len(out) > 0
def test_pprint_example_registry():
"""Testing a registry different from default."""
example_registry: dict[str, EnvSpec] = {
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
}
out = gym.pprint_registry(example_registry, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1 CartPole-v2
CartPole-v3"""
assert out == correct_out
def test_pprint_namespace():
example_registry: dict[str, EnvSpec] = {
"CartPole-v0": EnvSpec(
"CartPole-v0", "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
),
"CartPole-v1": EnvSpec(
"CartPole-v1", "gymnasium.envs.classic_control:CartPoleEnv"
),
"CartPole-v2": EnvSpec("CartPole-v2", "gymnasium.cartpole:CartPoleEnv"),
"CartPole-v3": EnvSpec("CartPole-v3", lambda: CartPoleEnv()),
"ExampleNamespace/CartPole-v2": EnvSpec(
"ExampleNamespace/CartPole-v2", "gymnasium.envs.classic_control:CartPoleEnv"
),
}
out = gym.pprint_registry(example_registry, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1
===== cartpole =====
CartPole-v2
===== None =====
CartPole-v3
===== ExampleNamespace =====
ExampleNamespace/CartPole-v2"""
assert out == correct_out
def test_pprint_n_columns():
example_registry = {
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
}
out = gym.pprint_registry(example_registry, num_cols=2, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1
CartPole-v2 CartPole-v3"""
assert out == correct_out
out = gym.pprint_registry(example_registry, num_cols=5, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1 CartPole-v2 CartPole-v3"""
assert out == correct_out
def test_pprint_exclude_namespace():
example_registry: dict[str, EnvSpec] = {
"Test/CartPole-v0": EnvSpec("Test/CartPole-v0", EXAMPLE_ENTRY_POINT),
"Test/CartPole-v1": EnvSpec("Test/CartPole-v1", EXAMPLE_ENTRY_POINT),
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
}
out = gym.pprint_registry(
example_registry, exclude_namespaces=["Test"], disable_print=True
)
correct_out = """===== classic_control =====
CartPole-v2 CartPole-v3"""
assert out == correct_out
out = gym.pprint_registry(
example_registry, exclude_namespaces=["classic_control"], disable_print=True
)
correct_out = """===== Test =====
Test/CartPole-v0 Test/CartPole-v1"""
assert out == correct_out
example_registry["Example/CartPole-v4"] = EnvSpec(
"Example/CartPole-v4", EXAMPLE_ENTRY_POINT
)
out = gym.pprint_registry(
example_registry, exclude_namespaces=["Test", "Example"], disable_print=True
)
correct_out = """===== classic_control =====
CartPole-v2 CartPole-v3"""
assert out == correct_out

View File

@@ -18,7 +18,7 @@ def register_registration_testing_envs():
env_id = f"{namespace}/{versioned_name}-v{version}"
gym.register(
id=env_id,
entry_point="tests.envs.utils_envs:ArgumentEnv",
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
kwargs={
"arg1": "arg1",
"arg2": "arg2",
@@ -111,7 +111,7 @@ def test_env_suggestions(
with pytest.raises(
gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
):
gym.make(env_id_input, disable_env_checker=True)
gym.make(env_id_input)
@pytest.mark.parametrize(
@@ -136,13 +136,13 @@ def test_env_version_suggestions(
gym.error.DeprecatedEnv,
match="It provides the default version", # env name,
):
gym.make(env_id_input, disable_env_checker=True)
gym.make(env_id_input)
else:
with pytest.raises(
gym.error.UnregisteredEnv,
match=f"It provides versioned environments: \\[ {suggested_versions} \\]",
):
gym.make(env_id_input, disable_env_checker=True)
gym.make(env_id_input)
def test_register_versioned_unversioned():
@@ -185,9 +185,7 @@ def test_make_latest_versioned_env(register_registration_testing_envs):
"Using the latest versioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv-v5` instead of the unversioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv`."
),
):
env = gym.make(
"MyAwesomeNamespace/MyAwesomeVersionedEnv", disable_env_checker=True
)
env = gym.make("MyAwesomeNamespace/MyAwesomeVersionedEnv")
assert env.spec is not None
assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"

View File

@@ -0,0 +1,95 @@
"""Tests that `gym.spec` works as expected."""
import re
import pytest
import gymnasium as gym
def test_spec():
spec = gym.spec("CartPole-v1")
assert spec.id == "CartPole-v1"
assert spec is gym.envs.registry["CartPole-v1"]
def test_spec_missing_lookup():
gym.register(id="TestEnv-v0", entry_point="no-entry-point")
gym.register(id="TestEnv-v15", entry_point="no-entry-point")
gym.register(id="TestEnv-v9", entry_point="no-entry-point")
gym.register(id="OtherEnv-v100", entry_point="no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v1 for `TestEnv` is deprecated. Please use `TestEnv-v15` instead."
),
):
gym.spec("TestEnv-v1")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape(
"Environment version `v1000` for environment `TestEnv` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
),
):
gym.spec("TestEnv-v1000")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape("Environment `UnknownEnv` doesn't exist."),
):
gym.spec("UnknownEnv-v1")
del gym.registry["TestEnv-v0"]
del gym.registry["TestEnv-v15"]
del gym.registry["TestEnv-v9"]
del gym.registry["OtherEnv-v100"]
def test_spec_malformed_lookup():
with pytest.raises(
gym.error.Error,
match=re.escape(
"Malformed environment ID: “Breakout-v0”. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
),
):
gym.spec("“Breakout-v0”")
def test_spec_versioned_lookups():
gym.register("test/TestEnv-v5", "no-entry-point")
with pytest.raises(
gym.error.VersionNotFound,
match=re.escape(
"Environment version `v9` for environment `test/TestEnv` doesn't exist. It provides versioned environments: [ `v5` ]."
),
):
gym.spec("test/TestEnv-v9")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v4 for `test/TestEnv` is deprecated. Please use `test/TestEnv-v5` instead."
),
):
gym.spec("test/TestEnv-v4")
assert gym.spec("test/TestEnv-v5") is not None
del gym.registry["test/TestEnv-v5"]
def test_spec_default_lookups():
gym.register("test/TestEnv", "no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version `v0` for environment `test/TestEnv` doesn't exist. It provides the default version `test/TestEnv`."
),
):
gym.spec("test/TestEnv-v0")
assert gym.spec("test/TestEnv") is not None
del gym.registry["test/TestEnv"]

View File

@@ -1,14 +1,8 @@
from __future__ import annotations
import gymnasium as gym
class RegisterDuringMakeEnv(gym.Env):
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
def __init__(self):
self.action_space = gym.spaces.Discrete(1)
self.observation_space = gym.spaces.Discrete(1)
class ArgumentEnv(gym.Env):
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
@@ -23,9 +17,12 @@ class ArgumentEnv(gym.Env):
class NoHuman(gym.Env):
"""Environment that does not have human-rendering."""
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
def __init__(self, render_mode=None):
metadata = {"render_modes": ["rgb_array"], "render_fps": 4}
def __init__(self, render_mode: list[str] = None):
assert render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
@@ -33,6 +30,9 @@ class NoHuman(gym.Env):
class NoHumanOldAPI(gym.Env):
"""Environment that does not have human-rendering."""
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
def __init__(self):
@@ -42,6 +42,9 @@ class NoHumanOldAPI(gym.Env):
class NoHumanNoRGB(gym.Env):
"""Environment that has neither human- nor rgb-rendering"""
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
metadata = {"render_modes": ["ascii"], "render_fps": 4}
def __init__(self, render_mode=None):
@@ -52,6 +55,9 @@ class NoHumanNoRGB(gym.Env):
class NoRenderModesMetadata(gym.Env):
"""An environment that has rendering but has not updated the metadata."""
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
# metadata: dict[str, Any] = {"render_modes": []}
def __init__(self, render_mode):

View File

@@ -0,0 +1,13 @@
"""This utility file contains an environment that is registered upon loading the file."""
import gymnasium as gym
class RegisterDuringMakeEnv(gym.Env):
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
def __init__(self):
self.action_space = gym.spaces.Discrete(1)
self.observation_space = gym.spaces.Discrete(1)
gym.register(id="RegisterDuringMake-v0", entry_point=RegisterDuringMakeEnv)

View File

@@ -39,7 +39,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
all_testing_env_specs,
ids=[spec.id for spec in all_testing_env_specs],
)
def test_envs_pass_env_checker(spec):
def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make(disable_env_checker=True).unwrapped
@@ -52,6 +52,22 @@ def test_envs_pass_env_checker(spec):
raise gym.error.Error(f"Unexpected warning: {warning.message}")
@pytest.mark.parametrize(
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_all_env_passive_env_checker(spec):
with warnings.catch_warnings(record=True) as caught_warnings:
env = gym.make(spec.id)
env.reset()
env.step(env.action_space.sample())
env.close()
for warning in caught_warnings:
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
raise gym.error.Error(f"Unexpected warning: {warning.message}")
# Note that this precludes running this test in multiple threads.
# However, we probably already can't do multithreading due to some environments.
SEED = 0

View File

@@ -1,27 +0,0 @@
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
# To ignore the trailing whitespaces, will need flake to ignore this file.
# flake8: noqa
reduced_registry = {
env_id: env_spec
for env_id, env_spec in gym.registry.items()
if env_spec.entry_point != "shimmy.atari_env:AtariEnv"
}
def test_pprint_custom_registry():
"""Testing a registry different from default."""
a = {
"CartPole-v0": gym.envs.registry["CartPole-v0"],
"CartPole-v1": gym.envs.registry["CartPole-v1"],
}
out = gym.pprint_registry(a, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1
"""
assert out == correct_out

View File

@@ -1,93 +0,0 @@
"""Tests that gym.spec works as expected."""
import re
import pytest
import gymnasium as gym
def test_spec():
spec = gym.spec("CartPole-v1")
assert spec.id == "CartPole-v1"
assert spec is gym.envs.registry["CartPole-v1"]
def test_spec_kwargs():
map_name_value = "8x8"
env = gym.make("FrozenLake-v1", map_name=map_name_value)
assert env.spec is not None
assert env.spec.kwargs["map_name"] == map_name_value
def test_spec_missing_lookup():
gym.register(id="Test1-v0", entry_point="no-entry-point")
gym.register(id="Test1-v15", entry_point="no-entry-point")
gym.register(id="Test1-v9", entry_point="no-entry-point")
gym.register(id="Other1-v100", entry_point="no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v1 for `Test1` is deprecated. Please use `Test1-v15` instead."
),
):
gym.spec("Test1-v1")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape(
"Environment version `v1000` for environment `Test1` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
),
):
gym.spec("Test1-v1000")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape("Environment Unknown1 doesn't exist. "),
):
gym.spec("Unknown1-v1")
def test_spec_malformed_lookup():
with pytest.raises(
gym.error.Error,
match=f'^{re.escape("Malformed environment ID: “Breakout-v0”.(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))")}$',
):
gym.spec("“Breakout-v0”")
def test_spec_versioned_lookups():
gym.register("test/Test2-v5", "no-entry-point")
with pytest.raises(
gym.error.VersionNotFound,
match=re.escape(
"Environment version `v9` for environment `test/Test2` doesn't exist. It provides versioned environments: [ `v5` ]."
),
):
gym.spec("test/Test2-v9")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v4 for `test/Test2` is deprecated. Please use `test/Test2-v5` instead."
),
):
gym.spec("test/Test2-v4")
assert gym.spec("test/Test2-v5") is not None
def test_spec_default_lookups():
gym.register("test/Test3", "no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version `v0` for environment `test/Test3` doesn't exist. It provides the default version test/Test3`."
),
):
gym.spec("test/Test3-v0")
assert gym.spec("test/Test3") is not None

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import gymnasium as gym