Reorder functions and refactor registration.py (#289)

Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
Mark Towers
2023-02-05 00:05:59 +00:00
committed by GitHub
parent 61f80d62a0
commit 4dd526d370
14 changed files with 757 additions and 625 deletions

View File

@@ -2,36 +2,39 @@
title: Registry title: Registry
--- ---
# Registry # Register and Make
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers. ```{eval-rst}
Environments can also be created through python imports. Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers through the :meth:`gymnasium.make` function. To do this, the environment must be registered prior with :meth:`gymnasium.register`. To get the environment specifications for a registered environment, use :meth:`gymnasium.spec` and to print the whole registry, use :meth:`gymnasium.pprint_registry`.
```
## Make
```{eval-rst} ```{eval-rst}
.. autofunction:: gymnasium.make .. autofunction:: gymnasium.make
```
## Register
```{eval-rst}
.. autofunction:: gymnasium.register .. autofunction:: gymnasium.register
```
## All registered environments
To find all the registered Gymnasium environments, use the `gymnasium.pprint_registry()`.
This will not include environments registered only in OpenAI Gym however can be loaded by `gymnasium.make`.
## Spec
```{eval-rst}
.. autofunction:: gymnasium.spec .. autofunction:: gymnasium.spec
```
## Pretty print registry
```{eval-rst}
.. autofunction:: gymnasium.pprint_registry .. autofunction:: gymnasium.pprint_registry
``` ```
## Core variables
```{eval-rst}
.. autoclass:: gymnasium.envs.registration.EnvSpec
.. attribute:: gymnasium.envs.registration.registry
The Global registry for gymnasium which is where environment specifications are stored by :meth:`gymnasium.register` and from which :meth:`gymnasium.make` is used to create environments.
.. attribute:: gymnasium.envs.registration.current_namespace
The current namespace when creating or registering environments. This is by default ``None`` by with :meth:`namespace` this can be modified to automatically set the environment id namespace.
```
## Additional functions
```{eval-rst}
.. autofunction:: gymnasium.envs.registration.get_env_id
.. autofunction:: gymnasium.envs.registration.parse_env_id
.. autofunction:: gymnasium.envs.registration.find_highest_version
.. autofunction:: gymnasium.envs.registration.namespace
.. autofunction:: gymnasium.envs.registration.load_env
.. autofunction:: gymnasium.envs.registration.load_plugin_envs
```

View File

@@ -2,7 +2,7 @@
from typing import Any from typing import Any
from gymnasium.envs.registration import ( from gymnasium.envs.registration import (
load_env_plugins, load_plugin_envs,
make, make,
pprint_registry, pprint_registry,
register, register,
@@ -363,4 +363,4 @@ register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error)
# Hook to load plugins from entry points # Hook to load plugins from entry points
load_env_plugins() load_plugin_envs()

View File

@@ -9,13 +9,11 @@ import importlib.util
import re import re
import sys import sys
import traceback import traceback
import warnings
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Sequence, SupportsFloat, overload from typing import Any, Iterable
import numpy as np
from gymnasium import Env, error, logger
from gymnasium.wrappers import ( from gymnasium.wrappers import (
AutoResetWrapper, AutoResetWrapper,
HumanRendering, HumanRendering,
@@ -32,12 +30,10 @@ if sys.version_info < (3, 10):
else: else:
import importlib.metadata as metadata import importlib.metadata as metadata
if sys.version_info >= (3, 8): if sys.version_info < (3, 8):
from typing import Literal from typing_extensions import Protocol
else: else:
from typing_extensions import Literal from typing import Protocol
from gymnasium import Env, error, logger
ENV_ID_RE = re.compile( ENV_ID_RE = re.compile(
@@ -45,50 +41,99 @@ ENV_ID_RE = re.compile(
) )
def load(name: str) -> Callable: __all__ = [
"""Loads an environment with name and returns an environment creation function. "EnvSpec",
"registry",
"current_namespace",
"register",
"make",
"spec",
"pprint_registry",
]
Args:
name: The environment name
Returns: class EnvCreator(Protocol):
Calls the environment constructor """Function type expected for an environment."""
def __call__(self, **kwargs: Any) -> Env:
...
@dataclass
class EnvSpec:
"""A specification for creating environments with :meth:`gymnasium.make`.
* **id**: The string used to create the environment with :meth:`gymnasium.make`
* **entry_point**: A string for the environment location, ``(import path):(environment name)`` or a function that creates the environment.
* **reward_threshold**: The reward threshold for completing the environment.
* **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
* **max_episode_steps**: The max number of steps that the environment can take before truncation
* **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
* **autoreset**: If to automatically reset the environment on episode end
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
""" """
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name) id: str
fn = getattr(mod, attr_name) entry_point: EnvCreator | str
return fn
# Environment attributes
reward_threshold: float | None = field(default=None)
nondeterministic: bool = field(default=False)
# Wrappers
max_episode_steps: int | None = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
apply_api_compatibility: bool = field(default=False)
# post-init attributes
namespace: str | None = field(init=False)
name: str = field(init=False)
version: int | None = field(init=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
def __post_init__(self):
"""Calls after the spec is created to extract the namespace, name and version from the id."""
# Initialize namespace, name, version
self.namespace, self.name, self.version = parse_env_id(self.id)
def make(self, **kwargs: Any) -> Env:
"""Calls ``make`` using the environment spec and any keyword arguments."""
# For compatibility purposes
return make(self, **kwargs)
def parse_env_id(id: str) -> tuple[str | None, str, int | None]: # Global registry of environments. Meant to be accessed through `register` and `make`
"""Parse environment ID string format. registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
This format is true today, but it's *not* an official spec.
[namespace/](env-name)-v(version) env-name is group 1, version is group 2
2016-10-31: We're experimentally expanding the environment ID format def parse_env_id(env_id: str) -> tuple[str | None, str, int | None]:
to include an optional namespace. """Parse environment ID string format - ``[namespace/](env-name)[-v(version)]`` where the namespace and version are optional.
Args: Args:
id: The environment id to parse env_id: The environment id to parse
Returns: Returns:
A tuple of environment namespace, environment name and version number A tuple of environment namespace, environment name and version number
Raises: Raises:
Error: If the environment id does not a valid environment regex Error: If the environment id is not valid environment regex
""" """
match = ENV_ID_RE.fullmatch(id) match = ENV_ID_RE.fullmatch(env_id)
if not match: if not match:
raise error.Error( raise error.Error(
f"Malformed environment ID: {id}." f"Malformed environment ID: {env_id}. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
) )
namespace, name, version = match.group("namespace", "name", "version") ns, name, version = match.group("namespace", "name", "version")
if version is not None: if version is not None:
version = int(version) version = int(version)
return namespace, name, version return ns, name, version
def get_env_id(ns: str | None, name: str, version: int | None) -> str: def get_env_id(ns: str | None, name: str, version: int | None) -> str:
@@ -103,97 +148,80 @@ def get_env_id(ns: str | None, name: str, version: int | None) -> str:
The environment id The environment id
""" """
full_name = name full_name = name
if version is not None:
full_name += f"-v{version}"
if ns is not None: if ns is not None:
full_name = ns + "/" + full_name full_name = f"{ns}/{name}"
if version is not None:
full_name = f"{full_name}-v{version}"
return full_name return full_name
@dataclass def find_highest_version(ns: str | None, name: str) -> int | None:
class EnvSpec: """Finds the highest registered version of the environment given the namespace and name in the registry.
"""A specification for creating environments with `gym.make`.
* id: The string used to create the environment with `gym.make` Args:
* entry_point: The location of the environment to create from ns: The environment namespace
* reward_threshold: The reward threshold for completing the environment. name: The environment name (id)
* nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
* max_episode_steps: The max number of steps that the environment can take before truncation Returns:
* order_enforce: If to enforce the order of `reset` before `step` and `render` functions The highest version of an environment with matching namespace and name, otherwise ``None`` is returned.
* autoreset: If to automatically reset the environment on episode end
* disable_env_checker: If to disable the environment checker wrapper in `gym.make`, by default False (runs the environment checker)
* kwargs: Additional keyword arguments passed to the environments through `gym.make`
""" """
version: list[int] = [
id: str env_spec.version
entry_point: Callable | str for env_spec in registry.values()
if env_spec.namespace == ns
# Environment attributes and env_spec.name == name
reward_threshold: float | None = field(default=None) and env_spec.version is not None
nondeterministic: bool = field(default=False) ]
return max(version, default=None)
# Wrappers
max_episode_steps: int | None = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
apply_api_compatibility: bool = field(default=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
# post-init attributes
namespace: str | None = field(init=False)
name: str = field(init=False)
version: int | None = field(init=False)
def __post_init__(self):
"""Calls after the spec is created to extract the namespace, name and version from the id."""
# Initialize namespace, name, version
self.namespace, self.name, self.version = parse_env_id(self.id)
def make(self, **kwargs: Any) -> Env:
"""Calls ``make`` using the environment spec and any keyword arguments."""
# For compatibility purposes
return make(self, **kwargs)
def _check_namespace_exists(ns: str | None): def _check_namespace_exists(ns: str | None):
"""Check if a namespace exists. If it doesn't, print a helpful error message.""" """Check if a namespace exists. If it doesn't, print a helpful error message."""
# If the namespace is none, then the namespace does exist
if ns is None: if ns is None:
return return
namespaces = {
spec_.namespace for spec_ in registry.values() if spec_.namespace is not None # Check if the namespace exists in one of the registry's specs
namespaces: set[str] = {
env_spec.namespace
for env_spec in registry.values()
if env_spec.namespace is not None
} }
if ns in namespaces: if ns in namespaces:
return return
# Otherwise, the namespace doesn't exist and raise a helpful message
suggestion = ( suggestion = (
difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
) )
suggestion_msg = ( if suggestion:
f"Did you mean: `{suggestion[0]}`?" suggestion_msg = f"Did you mean: `{suggestion[0]}`?"
if suggestion else:
else f"Have you installed the proper package for {ns}?" suggestion_msg = f"Have you installed the proper package for {ns}?"
)
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}") raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
def _check_name_exists(ns: str | None, name: str): def _check_name_exists(ns: str | None, name: str):
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message.""" """Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
# First check if the namespace exists
_check_namespace_exists(ns) _check_namespace_exists(ns)
names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}
# Then check if the name exists
names: set[str] = {
env_spec.name for env_spec in registry.values() if env_spec.namespace == ns
}
if name in names: if name in names:
return return
# Otherwise, raise a helpful error to the user
suggestion = difflib.get_close_matches(name, names, n=1) suggestion = difflib.get_close_matches(name, names, n=1)
namespace_msg = f" in namespace {ns}" if ns else "" namespace_msg = f" in namespace {ns}" if ns else ""
suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else "" suggestion_msg = f" Did you mean: `{suggestion[0]}`?" if suggestion else ""
raise error.NameNotFound( raise error.NameNotFound(
f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}" f"Environment `{name}` doesn't exist{namespace_msg}.{suggestion_msg}"
) )
@@ -222,26 +250,28 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist." message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."
env_specs = [ env_specs = [
spec_ env_spec
for spec_ in registry.values() for env_spec in registry.values()
if spec_.namespace == ns and spec_.name == name if env_spec.namespace == ns and env_spec.name == name
] ]
env_specs = sorted(env_specs, key=lambda spec_: int(spec_.version or -1)) env_specs = sorted(env_specs, key=lambda env_spec: int(env_spec.version or -1))
default_spec = [spec_ for spec_ in env_specs if spec_.version is None] default_spec = [env_spec for env_spec in env_specs if env_spec.version is None]
if default_spec: if default_spec:
message += f" It provides the default version {default_spec[0].id}`." message += f" It provides the default version `{default_spec[0].id}`."
if len(env_specs) == 1: if len(env_specs) == 1:
raise error.DeprecatedEnv(message) raise error.DeprecatedEnv(message)
# Process possible versioned environments # Process possible versioned environments
versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None] versioned_specs = [
env_spec for env_spec in env_specs if env_spec.version is not None
]
latest_spec = max(versioned_specs, key=lambda spec: spec.version, default=None) # type: ignore latest_spec = max(versioned_specs, key=lambda env_spec: env_spec.version, default=None) # type: ignore
if latest_spec is not None and version > latest_spec.version: if latest_spec is not None and version > latest_spec.version:
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs) version_list_msg = ", ".join(f"`v{env_spec.version}`" for env_spec in env_specs)
message += f" It provides versioned environments: [ {version_list_msg} ]." message += f" It provides versioned environments: [ {version_list_msg} ]."
raise error.VersionNotFound(message) raise error.VersionNotFound(message)
@@ -253,18 +283,80 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
) )
def find_highest_version(ns: str | None, name: str) -> int | None: def _check_spec_register(testing_spec: EnvSpec):
"""Finds the highest registered version of the environment in the registry.""" """Checks whether the spec is valid to be registered. Helper function for `register`."""
version: list[int] = [ latest_versioned_spec = max(
spec_.version (
for spec_ in registry.values() env_spec
if spec_.namespace == ns and spec_.name == name and spec_.version is not None for env_spec in registry.values()
] if env_spec.namespace == testing_spec.namespace
return max(version, default=None) and env_spec.name == testing_spec.name
and env_spec.version is not None
),
key=lambda spec_: int(spec_.version), # type: ignore
default=None,
)
unversioned_spec = next(
(
env_spec
for env_spec in registry.values()
if env_spec.namespace == testing_spec.namespace
and env_spec.name == testing_spec.name
and env_spec.version is None
),
None,
)
if unversioned_spec is not None and testing_spec.version is not None:
raise error.RegistrationError(
"Can't register the versioned environment "
f"`{testing_spec.id}` when the unversioned environment "
f"`{unversioned_spec.id}` of the same name already exists."
)
elif latest_versioned_spec is not None and testing_spec.version is None:
raise error.RegistrationError(
f"Can't register the unversioned environment `{testing_spec.id}` when the versioned environment "
f"`{latest_versioned_spec.id}` of the same name already exists. Note: the default behavior is "
"that `gym.make` with the unversioned environment will return the latest versioned environment"
)
def load_env_plugins(entry_point: str = "gymnasium.envs") -> None: def _check_metadata(testing_metadata: dict[str, Any]):
"""Load modules (plugins) using the gymnasium entry points == to `entry_points`. """Check the metadata of an environment."""
if not isinstance(testing_metadata, dict):
raise error.InvalidMetadata(
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
)
render_modes = testing_metadata.get("render_modes")
if render_modes is None:
logger.warn(
f"The environment creator metadata doesn't include `render_modes`, contains: {list(testing_metadata.keys())}"
)
elif not isinstance(render_modes, Iterable):
logger.warn(
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
)
def load_env(name: str) -> EnvCreator:
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
Args:
name: The environment name
Returns:
The environment constructor for the given environment name.
"""
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
fn = getattr(mod, attr_name)
return fn
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
Args: Args:
entry_point: The string for the entry point. entry_point: The string for the entry point.
@@ -282,7 +374,7 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
else: else:
module, attr = plugin.value, None module, attr = plugin.value, None
except Exception as e: except Exception as e:
warnings.warn( logger.warn(
f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}" f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
) )
module, attr = None, None module, attr = None, None
@@ -314,136 +406,6 @@ def load_env_plugins(entry_point: str = "gymnasium.envs") -> None:
logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}") logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
# fmt: off
@overload
def make(id: str, **kwargs) -> Env: ...
@overload
def make(id: EnvSpec, **kwargs) -> Env: ...
# Classic control
# ----------------------------------------
@overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Box2d
# ----------------------------------------
@overload
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
# Toy Text
# ----------------------------------------
@overload
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Mujoco
# ----------------------------------------
@overload
def make(id: Literal[
"Reacher-v2", "Reacher-v4",
"Pusher-v2", "Pusher-v4",
"InvertedPendulum-v2", "InvertedPendulum-v4",
"InvertedDoublePendulum-v2", "InvertedDoublePendulum-v4",
"HalfCheetah-v2", "HalfCheetah-v3", "HalfCheetah-v4",
"Hopper-v2", "Hopper-v3", "Hopper-v4",
"Swimmer-v2", "Swimmer-v3", "Swimmer-v4",
"Walker2d-v2", "Walker2d-v3", "Walker2d-v4",
"Ant-v2", "Ant-v3", "Ant-v4",
"HumanoidStandup-v2", "HumanoidStandup-v4",
"Humanoid-v2", "Humanoid-v3", "Humanoid-v4",
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
# fmt: on
# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
current_namespace: str | None = None
def _check_spec_register(spec: EnvSpec):
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
global registry
latest_versioned_spec = max(
(
spec_
for spec_ in registry.values()
if spec_.namespace == spec.namespace
and spec_.name == spec.name
and spec_.version is not None
),
key=lambda spec_: int(spec_.version), # type: ignore
default=None,
)
unversioned_spec = next(
(
spec_
for spec_ in registry.values()
if spec_.namespace == spec.namespace
and spec_.name == spec.name
and spec_.version is None
),
None,
)
if unversioned_spec is not None and spec.version is not None:
raise error.RegistrationError(
"Can't register the versioned environment "
f"`{spec.id}` when the unversioned environment "
f"`{unversioned_spec.id}` of the same name already exists."
)
elif latest_versioned_spec is not None and spec.version is None:
raise error.RegistrationError(
"Can't register the unversioned environment "
f"`{spec.id}` when the versioned environment "
f"`{latest_versioned_spec.id}` of the same name "
f"already exists. Note: the default behavior is "
f"that `gym.make` with the unversioned environment "
f"will return the latest versioned environment"
)
def _check_metadata(metadata_: dict):
if not isinstance(metadata_, dict):
raise error.InvalidMetadata(
f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
)
render_modes = metadata_.get("render_modes")
if render_modes is None:
logger.warn(
f"The environment creator metadata doesn't include `render_modes`, contains: {list(metadata_.keys())}"
)
elif not isinstance(render_modes, Iterable):
logger.warn(
f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
)
# Public API
@contextlib.contextmanager @contextlib.contextmanager
def namespace(ns: str): def namespace(ns: str):
"""Context manager for modifying the current namespace.""" """Context manager for modifying the current namespace."""
@@ -456,7 +418,7 @@ def namespace(ns: str):
def register( def register(
id: str, id: str,
entry_point: Callable | str, entry_point: EnvCreator | str,
reward_threshold: float | None = None, reward_threshold: float | None = None,
nondeterministic: bool = False, nondeterministic: bool = False,
max_episode_steps: int | None = None, max_episode_steps: int | None = None,
@@ -464,26 +426,28 @@ def register(
autoreset: bool = False, autoreset: bool = False,
disable_env_checker: bool = False, disable_env_checker: bool = False,
apply_api_compatibility: bool = False, apply_api_compatibility: bool = False,
**kwargs, **kwargs: Any,
): ):
"""Register an environment with gymnasium. """Registers an environment in gymnasium with an ``id`` to use with :meth:`gymnasium.make` with the ``entry_point`` being a string or callable for creating the environment.
The `id` parameter corresponds to the name of the environment, with the syntax as follows: The ``id`` parameter corresponds to the name of the environment, with the syntax as follows:
`(namespace)/(env_name)-v(version)` where `namespace` is optional. ``[namespace/](env_name)[-v(version)]`` where ``namespace`` and ``-v(version)`` is optional.
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor. It takes arbitrary keyword arguments, which are passed to the :class:`EnvSpec` ``kwargs`` parameter.
Args: Args:
id: The environment id id: The environment id
entry_point: The entry point for creating the environment entry_point: The entry point for creating the environment
reward_threshold: The reward threshold considered to have learnt an environment reward_threshold: The reward threshold considered for an agent to have learnt the environment
nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions) nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions, the same state cannot be reached)
max_episode_steps: The maximum number of episodes steps before truncation. Used by the Time Limit wrapper. max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
autoreset: If to add the autoreset wrapper such that reset does not need to be called. If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
disable_env_checker: If to disable the environment checker for the environment. Recommended to False. autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called.
apply_api_compatibility: If to apply the `StepAPICompatibility` wrapper. disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
**kwargs: arbitrary keyword arguments which are passed to the environment constructor apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
Use if the environment is implemented in the gym v0.21 environment API.
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
""" """
global registry, current_namespace global registry, current_namespace
ns, name, version = parse_env_id(id) ns, name, version = parse_env_id(id)
@@ -502,10 +466,10 @@ def register(
else: else:
ns_id = ns ns_id = ns
full_id = get_env_id(ns_id, name, version) full_env_id = get_env_id(ns_id, name, version)
new_spec = EnvSpec( new_spec = EnvSpec(
id=full_id, id=full_env_id,
entry_point=entry_point, entry_point=entry_point,
reward_threshold=reward_threshold, reward_threshold=reward_threshold,
nondeterministic=nondeterministic, nondeterministic=nondeterministic,
@@ -517,6 +481,7 @@ def register(
**kwargs, **kwargs,
) )
_check_spec_register(new_spec) _check_spec_register(new_spec)
if new_spec.id in registry: if new_spec.id in registry:
logger.warn(f"Overriding environment {new_spec.id} already in registry.") logger.warn(f"Overriding environment {new_spec.id} already in registry.")
registry[new_spec.id] = new_spec registry[new_spec.id] = new_spec
@@ -528,36 +493,36 @@ def make(
autoreset: bool = False, autoreset: bool = False,
apply_api_compatibility: bool | None = None, apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None, disable_env_checker: bool | None = None,
**kwargs, **kwargs: Any,
) -> Env: ) -> Env:
"""Create an environment according to the given ID. """Creates an environment previously registered with :meth:`gymnasium.register` or a :class:`EnvSpec`.
To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids. To find all available environments use ``gymnasium.envs.registry.keys()`` for all valid ids.
Args: Args:
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
max_episode_steps: Maximum length of an episode (TimeLimit wrapper). This is equivalent to importing the module first to register the environment followed by making the environment.
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``.
apply_api_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that The value is used by :class:`gymnasium.wrappers.TimeLimit`.
autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`).
apply_api_compatibility: Whether to wrap the environment with the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper that
converts the environment step from a done bool to return termination and truncation bools. converts the environment step from a done bool to return termination and truncation bools.
By default, the argument is None to which the environment specification `apply_api_compatibility` is used By default, the argument is None in which the :class:`EnvSpec` ``apply_api_compatibility`` is used, otherwise this variable is used in favor.
which defaults to False. Otherwise, the value of `apply_api_compatibility` is used. disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the
If `True`, the wrapper is applied otherwise, the wrapper is not applied. :class:`EnvSpec` ``disable_env_checker`` value otherwise use this value will be used.
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
(which is by default False, running the environment checker),
otherwise will run according to this parameter (`True` = not run, `False` = run)
kwargs: Additional arguments to pass to the environment constructor. kwargs: Additional arguments to pass to the environment constructor.
Returns: Returns:
An instance of the environment. An instance of the environment with wrappers applied.
Raises: Raises:
Error: If the ``id`` doesn't exist then an error is raised Error: If the ``id`` doesn't exist in the :attr:`registry`
""" """
if isinstance(id, EnvSpec): if isinstance(id, EnvSpec):
spec_ = id env_spec = id
else: else:
module, id = (None, id) if ":" not in id else id.split(":") # The environment name can include an unloaded module in "module:env_name" style
module, env_name = (None, id) if ":" not in id else id.split(":")
if module is not None: if module is not None:
try: try:
importlib.import_module(module) importlib.import_module(module)
@@ -566,9 +531,13 @@ def make(
f"{e}. Environment registration via importing a module failed. " f"{e}. Environment registration via importing a module failed. "
f"Check whether '{module}' contains env registration and can be imported." f"Check whether '{module}' contains env registration and can be imported."
) from e ) from e
spec_ = registry.get(id)
ns, name, version = parse_env_id(id) # load the env spec from the registry
env_spec = registry.get(env_name)
# update env spec is not version provided, raise warning if out of date
ns, name, version = parse_env_id(env_name)
latest_version = find_highest_version(ns, name) latest_version = find_highest_version(ns, name)
if ( if (
version is not None version is not None
@@ -576,38 +545,44 @@ def make(
and latest_version > version and latest_version > version
): ):
logger.warn( logger.warn(
f"The environment {id} is out of date. You should consider " f"The environment {env_name} is out of date. You should consider "
f"upgrading to version `v{latest_version}`." f"upgrading to version `v{latest_version}`."
) )
if version is None and latest_version is not None: if version is None and latest_version is not None:
version = latest_version version = latest_version
new_env_id = get_env_id(ns, name, version) new_env_id = get_env_id(ns, name, version)
spec_ = registry.get(new_env_id) env_spec = registry.get(new_env_id)
logger.warn( logger.warn(
f"Using the latest versioned environment `{new_env_id}` " f"Using the latest versioned environment `{new_env_id}` "
f"instead of the unversioned environment `{id}`." f"instead of the unversioned environment `{env_name}`."
) )
if spec_ is None: if env_spec is None:
_check_version_exists(ns, name, version) _check_version_exists(ns, name, version)
raise error.Error(f"No registered env with id: {id}") raise error.Error(f"No registered env with id: {env_name}")
_kwargs = spec_.kwargs.copy() assert isinstance(
_kwargs.update(kwargs) env_spec, EnvSpec
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}"
# Extract the spec kwargs and append the make kwargs
spec_kwargs = env_spec.kwargs.copy()
spec_kwargs.update(kwargs)
if spec_.entry_point is None: # Load the environment creator
raise error.Error(f"{spec_.id} registered but entry_point is not specified") if env_spec.entry_point is None:
elif callable(spec_.entry_point): raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
env_creator = spec_.entry_point elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else: else:
# Assume it's a string # Assume it's a string
env_creator = load(spec_.entry_point) env_creator = load_env(env_spec.entry_point)
render_modes = None # Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"): if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata) _check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes") render_modes = env_creator.metadata.get("render_modes")
mode = _kwargs.get("render_mode") mode = spec_kwargs.get("render_mode")
apply_human_rendering = False apply_human_rendering = False
apply_render_collection = False apply_render_collection = False
@@ -619,10 +594,10 @@ def make(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. " "You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment." "The HumanRendering wrapper is being applied to your environment."
) )
_kwargs["render_mode"] = displayable_modes.pop() spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True apply_human_rendering = True
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes: elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
_kwargs["render_mode"] = mode[: -len("_list")] spec_kwargs["render_mode"] = mode[: -len("_list")]
apply_render_collection = True apply_render_collection = True
else: else:
logger.warn( logger.warn(
@@ -630,16 +605,16 @@ def make(
f"that is not in the possible render_modes ({render_modes})." f"that is not in the possible render_modes ({render_modes})."
) )
if apply_api_compatibility is True or ( if apply_api_compatibility or (
apply_api_compatibility is None and spec_.apply_api_compatibility is True apply_api_compatibility is None and env_spec.apply_api_compatibility
): ):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator # If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = _kwargs.pop("render_mode", None) render_mode = spec_kwargs.pop("render_mode", None)
else: else:
render_mode = None render_mode = None
try: try:
env = env_creator(**_kwargs) env = env_creator(**spec_kwargs)
except TypeError as e: except TypeError as e:
if ( if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0 str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
@@ -654,31 +629,31 @@ def make(
raise e raise e
# Copies the environment creation specification and kwargs to add to the environment specification details # Copies the environment creation specification and kwargs to add to the environment specification details
spec_ = copy.deepcopy(spec_) env_spec = copy.deepcopy(env_spec)
spec_.kwargs = _kwargs env_spec.kwargs = spec_kwargs
env.unwrapped.spec = spec_ env.unwrapped.spec = env_spec
# Add step API wrapper # Add step API wrapper
if apply_api_compatibility is True or ( if apply_api_compatibility is True or (
apply_api_compatibility is None and spec_.apply_api_compatibility is True apply_api_compatibility is None and env_spec.apply_api_compatibility is True
): ):
env = EnvCompatibility(env, render_mode) env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper # Run the environment checker as the lowest level wrapper
if disable_env_checker is False or ( if disable_env_checker is False or (
disable_env_checker is None and spec_.disable_env_checker is False disable_env_checker is None and env_spec.disable_env_checker is False
): ):
env = PassiveEnvChecker(env) env = PassiveEnvChecker(env)
# Add the order enforcing wrapper # Add the order enforcing wrapper
if spec_.order_enforce: if env_spec.order_enforce:
env = OrderEnforcing(env) env = OrderEnforcing(env)
# Add the time limit wrapper # Add the time limit wrapper
if max_episode_steps is not None: if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps) env = TimeLimit(env, max_episode_steps)
elif spec_.max_episode_steps is not None: elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps) env = TimeLimit(env, env_spec.max_episode_steps)
# Add the autoreset wrapper # Add the autoreset wrapper
if autoreset: if autoreset:
@@ -694,74 +669,101 @@ def make(
def spec(env_id: str) -> EnvSpec: def spec(env_id: str) -> EnvSpec:
"""Retrieve the spec for the given environment from the global registry.""" """Retrieve the :class:`EnvSpec` for the environment id from the :attr:`registry`.
spec_ = registry.get(env_id)
if spec_ is None: Args:
env_id: The environment id with the expected format of ``[(namespace)/]id[-v(version)]``
Returns:
The environment spec if it exists
Raises:
Error: If the environment id doesn't exist
"""
env_spec = registry.get(env_id)
if env_spec is None:
ns, name, version = parse_env_id(env_id) ns, name, version = parse_env_id(env_id)
_check_version_exists(ns, name, version) _check_version_exists(ns, name, version)
raise error.Error(f"No registered env with id: {env_id}") raise error.Error(f"No registered env with id: {env_id}")
else: else:
assert isinstance(spec_, EnvSpec) assert isinstance(
return spec_ env_spec, EnvSpec
), f"Expected the registry for {env_id} to be an `EnvSpec`, actual type is {type(env_spec)}"
return env_spec
def pprint_registry( def pprint_registry(
_registry: dict = registry, print_registry: dict[str, EnvSpec] = registry,
*,
num_cols: int = 3, num_cols: int = 3,
exclude_namespaces: list[str] | None = None, exclude_namespaces: list[str] | None = None,
disable_print: bool = False, disable_print: bool = False,
) -> str | None: ) -> str | None:
"""Pretty print the environments in the registry. """Pretty prints all environments in the :attr:`registry`.
Note:
All arguments are keyword only
Args: Args:
_registry: Environment registry to be printed. print_registry: Environment registry to be printed. By default, :attr:`registry`
num_cols: Number of columns to arrange environments in, for display. num_cols: Number of columns to arrange environments in, for display.
exclude_namespaces: Exclude any namespaces from being printed. exclude_namespaces: A list of namespaces to be excluded from printing. Helpful if only ALE environments are wanted.
disable_print: Whether to return a string of all the namespaces and environment IDs disable_print: Whether to return a string of all the namespaces and environment IDs
instead of printing it to console. or to print the string to console.
""" """
# Defaultdict to store environment names according to namespace. # Defaultdict to store environment ids according to namespace.
namespace_envs = defaultdict(lambda: []) namespace_envs: dict[str, list[str]] = defaultdict(lambda: [])
max_justify = float("-inf") max_justify = float("-inf")
for env in _registry.values():
namespace, _, _ = parse_env_id(env.id)
if namespace is None:
# Since namespace is currently none, use regex to obtain namespace from entrypoints.
env_entry_point = re.sub(r":\w+", "", env.entry_point)
e_ep_split = env_entry_point.split(".")
if len(e_ep_split) >= 3:
# If namespace is of the format - gymnasium.envs.mujoco.ant_v4:AntEnv
# or gymnasium.envs.mujoco:HumanoidEnv
idx = 2
namespace = e_ep_split[idx]
elif len(e_ep_split) > 1:
# If namespace is of the format - shimmy.atari_env
idx = 1
namespace = e_ep_split[idx]
else:
# If namespace cannot be found, default to env id.
namespace = env.id
namespace_envs[namespace].append(env.id)
max_justify = max(max_justify, len(env.id))
# Iterate through each namespace and print environment alphabetically. # Find the namespace associated with each environment spec
return_str = "" for env_spec in print_registry.values():
for namespace, envs in namespace_envs.items(): ns = env_spec.namespace
if ns is None and isinstance(env_spec.entry_point, str):
# Use regex to obtain namespace from entrypoints.
env_entry_point = re.sub(r":\w+", "", env_spec.entry_point)
split_entry_point = env_entry_point.split(".")
if len(split_entry_point) >= 3:
# If namespace is of the format:
# - gymnasium.envs.mujoco.ant_v4:AntEnv
# - gymnasium.envs.mujoco:HumanoidEnv
ns = split_entry_point[2]
elif len(split_entry_point) > 1:
# If namespace is of the format - shimmy.atari_env
ns = split_entry_point[1]
else:
# If namespace cannot be found, default to env name
ns = env_spec.name
namespace_envs[ns].append(env_spec.id)
max_justify = max(max_justify, len(env_spec.name))
# Iterate through each namespace and print environment alphabetically
output: list[str] = []
for ns, env_ids in namespace_envs.items():
# Ignore namespaces to exclude. # Ignore namespaces to exclude.
if exclude_namespaces is not None and namespace in exclude_namespaces: if exclude_namespaces is not None and ns in exclude_namespaces:
continue continue
return_str += f"{'=' * 5} {namespace} {'=' * 5}\n" # Print namespace.
# Print the namespace
namespace_output = f"{'=' * 5} {ns} {'=' * 5}\n"
# Reference: https://stackoverflow.com/a/33464001 # Reference: https://stackoverflow.com/a/33464001
for count, item in enumerate(sorted(envs), 1): for count, env_id in enumerate(sorted(env_ids), 1):
return_str += ( # Print column with justification.
item.ljust(max_justify) + " " namespace_output += env_id.ljust(max_justify) + " "
) # Print column with justification.
# Once all rows printed, switch to new column. # Once all rows printed, switch to new column.
if count % num_cols == 0 or count == len(envs): if count % num_cols == 0:
return_str = return_str.rstrip(" ") + "\n" namespace_output = namespace_output.rstrip(" ")
return_str += "\n"
if count != len(env_ids):
namespace_output += "\n"
output.append(namespace_output.rstrip(" "))
if disable_print: if disable_print:
return return_str return "\n".join(output)
else: else:
print(return_str, end="") print("\n".join(output))

View File

View File

@@ -1,14 +1,14 @@
"""Tests that gym.make works as expected.""" """Tests that gym.make works as expected."""
from __future__ import annotations
import re import re
import warnings import warnings
from copy import deepcopy
import numpy as np import numpy as np
import pytest import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.classic_control import cartpole from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.wrappers import ( from gymnasium.wrappers import (
AutoResetWrapper, AutoResetWrapper,
HumanRendering, HumanRendering,
@@ -16,9 +16,8 @@ from gymnasium.wrappers import (
TimeLimit, TimeLimit,
) )
from gymnasium.wrappers.env_checker import PassiveEnvChecker from gymnasium.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.registration.utils_envs import ArgumentEnv
from tests.envs.utils import all_testing_env_specs from tests.envs.utils import all_testing_env_specs
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
from tests.testing_env import GenericTestEnv, old_step_func from tests.testing_env import GenericTestEnv, old_step_func
from tests.wrappers.utils import has_wrapper from tests.wrappers.utils import has_wrapper
@@ -30,16 +29,11 @@ except ImportError:
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def register_make_testing_envs(): def register_testing_envs():
"""Registers testing envs for `gym.make`""" """Registers testing envs for `gym.make`"""
gym.register(
"RegisterDuringMakeEnv-v0",
entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv",
)
gym.register( gym.register(
id="test.ArgumentEnv-v0", id="test.ArgumentEnv-v0",
entry_point="tests.envs.utils_envs:ArgumentEnv", entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
kwargs={ kwargs={
"arg1": "arg1", "arg1": "arg1",
"arg2": "arg2", "arg2": "arg2",
@@ -48,26 +42,25 @@ def register_make_testing_envs():
gym.register( gym.register(
id="test/NoHuman-v0", id="test/NoHuman-v0",
entry_point="tests.envs.utils_envs:NoHuman", entry_point="tests.envs.registration.utils_envs:NoHuman",
) )
gym.register( gym.register(
id="test/NoHumanOldAPI-v0", id="test/NoHumanOldAPI-v0",
entry_point="tests.envs.utils_envs:NoHumanOldAPI", entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
) )
gym.register( gym.register(
id="test/NoHumanNoRGB-v0", id="test/NoHumanNoRGB-v0",
entry_point="tests.envs.utils_envs:NoHumanNoRGB", entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
) )
gym.register( gym.register(
id="test/NoRenderModesMetadata-v0", id="test/NoRenderModesMetadata-v0",
entry_point="tests.envs.utils_envs:NoRenderModesMetadata", entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
) )
yield yield
del gym.envs.registration.registry["RegisterDuringMakeEnv-v0"]
del gym.envs.registration.registry["test.ArgumentEnv-v0"] del gym.envs.registration.registry["test.ArgumentEnv-v0"]
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"] del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
del gym.envs.registration.registry["test/NoHuman-v0"] del gym.envs.registration.registry["test/NoHuman-v0"]
@@ -76,14 +69,16 @@ def register_make_testing_envs():
def test_make(): def test_make():
env = gym.make("CartPole-v1", disable_env_checker=True) """Test basic `gym.make`."""
env = gym.make("CartPole-v1")
assert env.spec is not None assert env.spec is not None
assert env.spec.id == "CartPole-v1" assert env.spec.id == "CartPole-v1"
assert isinstance(env.unwrapped, cartpole.CartPoleEnv) assert isinstance(env.unwrapped, CartPoleEnv)
env.close() env.close()
def test_make_deprecated(): def test_make_deprecated():
"""Test make with a deprecated environment (i.e., doesn't exist)."""
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):
with pytest.raises( with pytest.raises(
gym.error.Error, gym.error.Error,
@@ -91,21 +86,20 @@ def test_make_deprecated():
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead." "Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
), ),
): ):
gym.make("Humanoid-v0", disable_env_checker=True) gym.make("Humanoid-v0")
def test_make_max_episode_steps(register_make_testing_envs): def test_make_max_episode_steps(register_testing_envs):
# Default, uses the spec's # Default, uses the spec's
env = gym.make("CartPole-v1", disable_env_checker=True) env = gym.make("CartPole-v1")
assert has_wrapper(env, TimeLimit) assert has_wrapper(env, TimeLimit)
assert env.spec is not None assert env.spec is not None
assert ( assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps
env.spec.max_episode_steps == gym.envs.registry["CartPole-v1"].max_episode_steps
)
env.close() env.close()
# Custom max episode steps # Custom max episode steps
env = gym.make("CartPole-v1", max_episode_steps=100, disable_env_checker=True) assert gym.spec("CartPole-v1").max_episode_steps != 100
env = gym.make("CartPole-v1", max_episode_steps=100)
assert has_wrapper(env, TimeLimit) assert has_wrapper(env, TimeLimit)
assert env.spec is not None assert env.spec is not None
assert env.spec.max_episode_steps == 100 assert env.spec.max_episode_steps == 100
@@ -113,20 +107,20 @@ def test_make_max_episode_steps(register_make_testing_envs):
# Env spec has no max episode steps # Env spec has no max episode steps
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
env = gym.make( env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None)
"test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None, disable_env_checker=True assert env.spec is not None
) assert env.spec.max_episode_steps is None
assert has_wrapper(env, TimeLimit) is False assert has_wrapper(env, TimeLimit) is False
env.close() env.close()
def test_gym_make_autoreset(): def test_make_autoreset():
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`.""" """Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
env = gym.make("CartPole-v1", disable_env_checker=True) env = gym.make("CartPole-v1")
assert has_wrapper(env, AutoResetWrapper) is False assert has_wrapper(env, AutoResetWrapper) is False
env.close() env.close()
env = gym.make("CartPole-v1", autoreset=False, disable_env_checker=True) env = gym.make("CartPole-v1", autoreset=False)
assert has_wrapper(env, AutoResetWrapper) is False assert has_wrapper(env, AutoResetWrapper) is False
env.close() env.close()
@@ -135,43 +129,49 @@ def test_gym_make_autoreset():
env.close() env.close()
def test_make_disable_env_checker(): @pytest.mark.parametrize(
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.""" "registration_disabled, make_disabled, if_disabled",
spec = deepcopy(gym.spec("CartPole-v1")) [
[False, False, False],
[False, True, True],
[True, False, False],
[True, True, True],
[False, None, False],
[True, None, True],
],
)
def test_make_disable_env_checker(
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
):
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
# Test with spec disable env checker The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
spec.disable_env_checker = False """
env = gym.make(spec) gym.register(
assert has_wrapper(env, PassiveEnvChecker) "testing-env-v0",
lambda: GenericTestEnv(),
disable_env_checker=registration_disabled,
)
# Test when the registered EnvSpec.disable_env_checker = False
env = gym.make("testing-env-v0", disable_env_checker=make_disabled)
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
env.close() env.close()
# Test with overwritten spec using make disable env checker del gym.registry["testing-env-v0"]
assert spec.disable_env_checker is False
env = gym.make(spec, disable_env_checker=True)
assert has_wrapper(env, PassiveEnvChecker) is False
env.close()
# Test with spec enabled disable env checker
spec.disable_env_checker = True
env = gym.make(spec)
assert has_wrapper(env, PassiveEnvChecker) is False
env.close()
# Test with overwritten spec using make disable env checker
assert spec.disable_env_checker is True
env = gym.make(spec, disable_env_checker=False)
assert has_wrapper(env, PassiveEnvChecker)
env.close()
def test_apply_api_compatibility(): def test_make_apply_api_compatibility():
"""Test the API compatibility wrapper."""
gym.register( gym.register(
"testing-old-env", "testing-old-env",
lambda: GenericTestEnv(step_func=old_step_func), lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=True, apply_api_compatibility=True,
max_episode_steps=3, max_episode_steps=3,
) )
# Apply the environment compatibility and check it works as intended
env = gym.make("testing-old-env") env = gym.make("testing-old-env")
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
env.reset() env.reset()
assert len(env.step(env.action_space.sample())) == 5 assert len(env.step(env.action_space.sample())) == 5
@@ -179,11 +179,19 @@ def test_apply_api_compatibility():
_, _, termination, truncation, _ = env.step(env.action_space.sample()) _, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True assert termination is False and truncation is True
# Turn off the spec api compatibility
gym.spec("testing-old-env").apply_api_compatibility = False gym.spec("testing-old-env").apply_api_compatibility = False
env = gym.make("testing-old-env") env = gym.make("testing-old-env")
# Cannot run reset and step as will not work assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
env.reset()
with pytest.raises(
ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)")
):
env.step(env.action_space.sample())
# Apply the environment compatibility and check it works as intended
env = gym.make("testing-old-env", apply_api_compatibility=True) env = gym.make("testing-old-env", apply_api_compatibility=True)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
env.reset() env.reset()
assert len(env.step(env.action_space.sample())) == 5 assert len(env.step(env.action_space.sample())) == 5
@@ -191,57 +199,63 @@ def test_apply_api_compatibility():
_, _, termination, truncation, _ = env.step(env.action_space.sample()) _, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True assert termination is False and truncation is True
gym.envs.registry.pop("testing-old-env") del gym.registry["testing-old-env"]
@pytest.mark.parametrize(
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_passive_checker_wrapper_warnings(spec):
with warnings.catch_warnings(record=True) as caught_warnings:
env = gym.make(spec) # disable_env_checker=False
env.reset()
env.step(env.action_space.sample())
# todo, add check for render, bugged due to mujoco v2/3 and v4 envs
env.close()
for warning in caught_warnings:
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
raise gym.error.Error(f"Unexpected warning: {warning.message}")
def test_make_order_enforcing(): def test_make_order_enforcing():
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper.""" """Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
assert all(spec.order_enforce is True for spec in all_testing_env_specs) assert all(spec.order_enforce is True for spec in all_testing_env_specs)
env = gym.make("CartPole-v1", disable_env_checker=True) env = gym.make("CartPole-v1")
assert has_wrapper(env, OrderEnforcing) assert has_wrapper(env, OrderEnforcing)
# We can assume that there all other specs will also have the order enforcing # We can assume that there all other specs will also have the order enforcing
env.close() env.close()
gym.register( gym.register(
id="test.OrderlessArgumentEnv-v0", id="test.OrderlessArgumentEnv-v0",
entry_point="tests.envs.utils_envs:ArgumentEnv", entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
order_enforce=False, order_enforce=False,
kwargs={"arg1": None, "arg2": None, "arg3": None}, kwargs={"arg1": None, "arg2": None, "arg3": None},
) )
env = gym.make("test.OrderlessArgumentEnv-v0", disable_env_checker=True) env = gym.make("test.OrderlessArgumentEnv-v0")
assert has_wrapper(env, OrderEnforcing) is False assert has_wrapper(env, OrderEnforcing) is False
env.close() env.close()
# There is no `make(..., order_enforcing=...)` so we don't test that
def test_make_render_mode(register_make_testing_envs):
env = gym.make("CartPole-v1", disable_env_checker=True) def test_make_render_mode():
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
env = gym.make("CartPole-v1", render_mode=None)
assert env.render_mode is None assert env.render_mode is None
env.close() env.close()
assert "rgb_array" in env.metadata["render_modes"]
env = gym.make("CartPole-v1", render_mode="rgb_array")
assert env.render_mode == "rgb_array"
env.close()
assert "no-render-mode" not in env.metadata["render_modes"]
# cartpole is special that it doesn't check the render_mode passed at initialisation
with pytest.warns(
UserWarning,
match=re.escape(
"\x1b[33mWARN: The environment is being initialised with render_mode='no-render-mode' that is not in the possible render_modes (['human', 'rgb_array']).\x1b[0m"
),
):
env = gym.make("CartPole-v1", render_mode="no-render-mode")
assert env.render_mode == "no-render-mode"
env.close()
def test_make_render_collection():
# Make sure that render_mode is applied correctly # Make sure that render_mode is applied correctly
env = gym.make( env = gym.make("CartPole-v1", render_mode="rgb_array_list")
"CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True assert has_wrapper(env, gym.wrappers.RenderCollection)
)
assert env.render_mode == "rgb_array_list" assert env.render_mode == "rgb_array_list"
assert env.unwrapped.render_mode == "rgb_array"
env.reset() env.reset()
renders = env.render() renders = env.render()
assert isinstance( assert isinstance(
@@ -250,24 +264,10 @@ def test_make_render_mode(register_make_testing_envs):
assert isinstance(renders[0], np.ndarray) assert isinstance(renders[0], np.ndarray)
env.close() env.close()
env = gym.make("CartPole-v1", render_mode=None, disable_env_checker=True)
assert env.render_mode is None
valid_render_modes = env.metadata["render_modes"]
env.close()
assert len(valid_render_modes) > 0
with warnings.catch_warnings(record=True) as caught_warnings:
env = gym.make(
"CartPole-v1", render_mode=valid_render_modes[0], disable_env_checker=True
)
assert env.render_mode == valid_render_modes[0]
env.close()
for warning in caught_warnings:
raise gym.error.Error(f"Unexpected warning: {warning.message}")
def test_make_human_rendering(register_testing_envs):
# Make sure that native rendering is used when possible # Make sure that native rendering is used when possible
env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True) env = gym.make("CartPole-v1", render_mode="human")
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
assert env.render_mode == "human" assert env.render_mode == "human"
env.close() env.close()
@@ -278,10 +278,8 @@ def test_make_render_mode(register_make_testing_envs):
"You are trying to use 'human' rendering for an environment that doesn't natively support it. The HumanRendering wrapper is being applied to your environment." "You are trying to use 'human' rendering for an environment that doesn't natively support it. The HumanRendering wrapper is being applied to your environment."
), ),
): ):
# Make sure that `HumanRendering` is applied here # Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
env = gym.make( env = gym.make("test/NoHuman-v0", render_mode="human")
"test/NoHuman-v0", render_mode="human", disable_env_checker=True
) # This environment doesn't use native rendering
assert has_wrapper(env, HumanRendering) assert has_wrapper(env, HumanRendering)
assert env.render_mode == "human" assert env.render_mode == "human"
env.close() env.close()
@@ -292,7 +290,6 @@ def test_make_render_mode(register_make_testing_envs):
gym.make( gym.make(
"test/NoHumanOldAPI-v0", "test/NoHumanOldAPI-v0",
render_mode="rgb_array_list", render_mode="rgb_array_list",
disable_env_checker=True,
) )
# Make sure that an additional error is thrown a user tries to use the wrapper on an environment with old API # Make sure that an additional error is thrown a user tries to use the wrapper on an environment with old API
@@ -303,9 +300,7 @@ def test_make_render_mode(register_make_testing_envs):
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively." "You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively."
), ),
): ):
gym.make( gym.make("test/NoHumanOldAPI-v0", render_mode="human")
"test/NoHumanOldAPI-v0", render_mode="human", disable_env_checker=True
)
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like # This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from # your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
@@ -326,15 +321,20 @@ def test_make_render_mode(register_make_testing_envs):
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array") gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array")
def test_make_kwargs(register_make_testing_envs): def test_make_kwargs(register_testing_envs):
env = gym.make( env = gym.make(
"test.ArgumentEnv-v0", "test.ArgumentEnv-v0",
arg2="override_arg2", arg2="override_arg2",
arg3="override_arg3", arg3="override_arg3",
disable_env_checker=True,
) )
assert env.spec is not None assert env.spec is not None
assert env.spec.id == "test.ArgumentEnv-v0" assert env.spec.id == "test.ArgumentEnv-v0"
assert env.spec.kwargs == {
"arg1": "arg1",
"arg2": "override_arg2",
"arg3": "override_arg3",
}
assert isinstance(env.unwrapped, ArgumentEnv) assert isinstance(env.unwrapped, ArgumentEnv)
assert env.arg1 == "arg1" assert env.arg1 == "arg1"
assert env.arg2 == "override_arg2" assert env.arg2 == "override_arg2"
@@ -342,11 +342,16 @@ def test_make_kwargs(register_make_testing_envs):
env.close() env.close()
def test_import_module_during_make(register_make_testing_envs): def test_import_module_during_make():
# Test custom environment which is registered at make # Test custom environment which is registered at make
assert "RegisterDuringMake-v0" not in gym.registry
env = gym.make( env = gym.make(
"tests.envs.utils:RegisterDuringMakeEnv-v0", "tests.envs.registration.utils_unregistered_env:RegisterDuringMake-v0"
disable_env_checker=True,
) )
assert "RegisterDuringMake-v0" in gym.registry
from tests.envs.registration.utils_unregistered_env import RegisterDuringMakeEnv
assert isinstance(env.unwrapped, RegisterDuringMakeEnv) assert isinstance(env.unwrapped, RegisterDuringMakeEnv)
env.close() env.close()
del gym.registry["RegisterDuringMake-v0"]

View File

@@ -0,0 +1,112 @@
from __future__ import annotations
import gymnasium as gym
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.registration import EnvSpec
# To ignore the trailing whitespaces, will need flake to ignore this file.
# flake8: noqa
EXAMPLE_ENTRY_POINT = "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
def test_pprint_default_registry():
out = gym.pprint_registry(disable_print=True)
assert isinstance(out, str) and len(out) > 0
def test_pprint_example_registry():
"""Testing a registry different from default."""
example_registry: dict[str, EnvSpec] = {
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
}
out = gym.pprint_registry(example_registry, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1 CartPole-v2
CartPole-v3"""
assert out == correct_out
def test_pprint_namespace():
example_registry: dict[str, EnvSpec] = {
"CartPole-v0": EnvSpec(
"CartPole-v0", "gymnasium.envs.classic_control.cartpole:CartPoleEnv"
),
"CartPole-v1": EnvSpec(
"CartPole-v1", "gymnasium.envs.classic_control:CartPoleEnv"
),
"CartPole-v2": EnvSpec("CartPole-v2", "gymnasium.cartpole:CartPoleEnv"),
"CartPole-v3": EnvSpec("CartPole-v3", lambda: CartPoleEnv()),
"ExampleNamespace/CartPole-v2": EnvSpec(
"ExampleNamespace/CartPole-v2", "gymnasium.envs.classic_control:CartPoleEnv"
),
}
out = gym.pprint_registry(example_registry, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1
===== cartpole =====
CartPole-v2
===== None =====
CartPole-v3
===== ExampleNamespace =====
ExampleNamespace/CartPole-v2"""
assert out == correct_out
def test_pprint_n_columns():
example_registry = {
"CartPole-v0": EnvSpec("CartPole-v0", EXAMPLE_ENTRY_POINT),
"CartPole-v1": EnvSpec("CartPole-v1", EXAMPLE_ENTRY_POINT),
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
}
out = gym.pprint_registry(example_registry, num_cols=2, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1
CartPole-v2 CartPole-v3"""
assert out == correct_out
out = gym.pprint_registry(example_registry, num_cols=5, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1 CartPole-v2 CartPole-v3"""
assert out == correct_out
def test_pprint_exclude_namespace():
example_registry: dict[str, EnvSpec] = {
"Test/CartPole-v0": EnvSpec("Test/CartPole-v0", EXAMPLE_ENTRY_POINT),
"Test/CartPole-v1": EnvSpec("Test/CartPole-v1", EXAMPLE_ENTRY_POINT),
"CartPole-v2": EnvSpec("CartPole-v2", EXAMPLE_ENTRY_POINT),
"CartPole-v3": EnvSpec("CartPole-v3", EXAMPLE_ENTRY_POINT),
}
out = gym.pprint_registry(
example_registry, exclude_namespaces=["Test"], disable_print=True
)
correct_out = """===== classic_control =====
CartPole-v2 CartPole-v3"""
assert out == correct_out
out = gym.pprint_registry(
example_registry, exclude_namespaces=["classic_control"], disable_print=True
)
correct_out = """===== Test =====
Test/CartPole-v0 Test/CartPole-v1"""
assert out == correct_out
example_registry["Example/CartPole-v4"] = EnvSpec(
"Example/CartPole-v4", EXAMPLE_ENTRY_POINT
)
out = gym.pprint_registry(
example_registry, exclude_namespaces=["Test", "Example"], disable_print=True
)
correct_out = """===== classic_control =====
CartPole-v2 CartPole-v3"""
assert out == correct_out

View File

@@ -18,7 +18,7 @@ def register_registration_testing_envs():
env_id = f"{namespace}/{versioned_name}-v{version}" env_id = f"{namespace}/{versioned_name}-v{version}"
gym.register( gym.register(
id=env_id, id=env_id,
entry_point="tests.envs.utils_envs:ArgumentEnv", entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
kwargs={ kwargs={
"arg1": "arg1", "arg1": "arg1",
"arg2": "arg2", "arg2": "arg2",
@@ -111,7 +111,7 @@ def test_env_suggestions(
with pytest.raises( with pytest.raises(
gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?" gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
): ):
gym.make(env_id_input, disable_env_checker=True) gym.make(env_id_input)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -136,13 +136,13 @@ def test_env_version_suggestions(
gym.error.DeprecatedEnv, gym.error.DeprecatedEnv,
match="It provides the default version", # env name, match="It provides the default version", # env name,
): ):
gym.make(env_id_input, disable_env_checker=True) gym.make(env_id_input)
else: else:
with pytest.raises( with pytest.raises(
gym.error.UnregisteredEnv, gym.error.UnregisteredEnv,
match=f"It provides versioned environments: \\[ {suggested_versions} \\]", match=f"It provides versioned environments: \\[ {suggested_versions} \\]",
): ):
gym.make(env_id_input, disable_env_checker=True) gym.make(env_id_input)
def test_register_versioned_unversioned(): def test_register_versioned_unversioned():
@@ -185,9 +185,7 @@ def test_make_latest_versioned_env(register_registration_testing_envs):
"Using the latest versioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv-v5` instead of the unversioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv`." "Using the latest versioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv-v5` instead of the unversioned environment `MyAwesomeNamespace/MyAwesomeVersionedEnv`."
), ),
): ):
env = gym.make( env = gym.make("MyAwesomeNamespace/MyAwesomeVersionedEnv")
"MyAwesomeNamespace/MyAwesomeVersionedEnv", disable_env_checker=True
)
assert env.spec is not None assert env.spec is not None
assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5" assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"

View File

@@ -0,0 +1,95 @@
"""Tests that `gym.spec` works as expected."""
import re
import pytest
import gymnasium as gym
def test_spec():
spec = gym.spec("CartPole-v1")
assert spec.id == "CartPole-v1"
assert spec is gym.envs.registry["CartPole-v1"]
def test_spec_missing_lookup():
gym.register(id="TestEnv-v0", entry_point="no-entry-point")
gym.register(id="TestEnv-v15", entry_point="no-entry-point")
gym.register(id="TestEnv-v9", entry_point="no-entry-point")
gym.register(id="OtherEnv-v100", entry_point="no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v1 for `TestEnv` is deprecated. Please use `TestEnv-v15` instead."
),
):
gym.spec("TestEnv-v1")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape(
"Environment version `v1000` for environment `TestEnv` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
),
):
gym.spec("TestEnv-v1000")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape("Environment `UnknownEnv` doesn't exist."),
):
gym.spec("UnknownEnv-v1")
del gym.registry["TestEnv-v0"]
del gym.registry["TestEnv-v15"]
del gym.registry["TestEnv-v9"]
del gym.registry["OtherEnv-v100"]
def test_spec_malformed_lookup():
with pytest.raises(
gym.error.Error,
match=re.escape(
"Malformed environment ID: “Breakout-v0”. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
),
):
gym.spec("“Breakout-v0”")
def test_spec_versioned_lookups():
gym.register("test/TestEnv-v5", "no-entry-point")
with pytest.raises(
gym.error.VersionNotFound,
match=re.escape(
"Environment version `v9` for environment `test/TestEnv` doesn't exist. It provides versioned environments: [ `v5` ]."
),
):
gym.spec("test/TestEnv-v9")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v4 for `test/TestEnv` is deprecated. Please use `test/TestEnv-v5` instead."
),
):
gym.spec("test/TestEnv-v4")
assert gym.spec("test/TestEnv-v5") is not None
del gym.registry["test/TestEnv-v5"]
def test_spec_default_lookups():
gym.register("test/TestEnv", "no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version `v0` for environment `test/TestEnv` doesn't exist. It provides the default version `test/TestEnv`."
),
):
gym.spec("test/TestEnv-v0")
assert gym.spec("test/TestEnv") is not None
del gym.registry["test/TestEnv"]

View File

@@ -1,14 +1,8 @@
from __future__ import annotations
import gymnasium as gym import gymnasium as gym
class RegisterDuringMakeEnv(gym.Env):
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
def __init__(self):
self.action_space = gym.spaces.Discrete(1)
self.observation_space = gym.spaces.Discrete(1)
class ArgumentEnv(gym.Env): class ArgumentEnv(gym.Env):
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
@@ -23,9 +17,12 @@ class ArgumentEnv(gym.Env):
class NoHuman(gym.Env): class NoHuman(gym.Env):
"""Environment that does not have human-rendering.""" """Environment that does not have human-rendering."""
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
def __init__(self, render_mode=None): metadata = {"render_modes": ["rgb_array"], "render_fps": 4}
def __init__(self, render_mode: list[str] = None):
assert render_mode in self.metadata["render_modes"] assert render_mode in self.metadata["render_modes"]
self.render_mode = render_mode self.render_mode = render_mode
@@ -33,6 +30,9 @@ class NoHuman(gym.Env):
class NoHumanOldAPI(gym.Env): class NoHumanOldAPI(gym.Env):
"""Environment that does not have human-rendering.""" """Environment that does not have human-rendering."""
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
def __init__(self): def __init__(self):
@@ -42,6 +42,9 @@ class NoHumanOldAPI(gym.Env):
class NoHumanNoRGB(gym.Env): class NoHumanNoRGB(gym.Env):
"""Environment that has neither human- nor rgb-rendering""" """Environment that has neither human- nor rgb-rendering"""
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
metadata = {"render_modes": ["ascii"], "render_fps": 4} metadata = {"render_modes": ["ascii"], "render_fps": 4}
def __init__(self, render_mode=None): def __init__(self, render_mode=None):
@@ -52,6 +55,9 @@ class NoHumanNoRGB(gym.Env):
class NoRenderModesMetadata(gym.Env): class NoRenderModesMetadata(gym.Env):
"""An environment that has rendering but has not updated the metadata.""" """An environment that has rendering but has not updated the metadata."""
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
# metadata: dict[str, Any] = {"render_modes": []} # metadata: dict[str, Any] = {"render_modes": []}
def __init__(self, render_mode): def __init__(self, render_mode):

View File

@@ -0,0 +1,13 @@
"""This utility file contains an environment that is registered upon loading the file."""
import gymnasium as gym
class RegisterDuringMakeEnv(gym.Env):
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""
def __init__(self):
self.action_space = gym.spaces.Discrete(1)
self.observation_space = gym.spaces.Discrete(1)
gym.register(id="RegisterDuringMake-v0", entry_point=RegisterDuringMakeEnv)

View File

@@ -39,7 +39,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
all_testing_env_specs, all_testing_env_specs,
ids=[spec.id for spec in all_testing_env_specs], ids=[spec.id for spec in all_testing_env_specs],
) )
def test_envs_pass_env_checker(spec): def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected.""" """Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make(disable_env_checker=True).unwrapped env = spec.make(disable_env_checker=True).unwrapped
@@ -52,6 +52,22 @@ def test_envs_pass_env_checker(spec):
raise gym.error.Error(f"Unexpected warning: {warning.message}") raise gym.error.Error(f"Unexpected warning: {warning.message}")
@pytest.mark.parametrize(
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_all_env_passive_env_checker(spec):
with warnings.catch_warnings(record=True) as caught_warnings:
env = gym.make(spec.id)
env.reset()
env.step(env.action_space.sample())
env.close()
for warning in caught_warnings:
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
raise gym.error.Error(f"Unexpected warning: {warning.message}")
# Note that this precludes running this test in multiple threads. # Note that this precludes running this test in multiple threads.
# However, we probably already can't do multithreading due to some environments. # However, we probably already can't do multithreading due to some environments.
SEED = 0 SEED = 0

View File

@@ -1,27 +0,0 @@
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
# To ignore the trailing whitespaces, will need flake to ignore this file.
# flake8: noqa
reduced_registry = {
env_id: env_spec
for env_id, env_spec in gym.registry.items()
if env_spec.entry_point != "shimmy.atari_env:AtariEnv"
}
def test_pprint_custom_registry():
"""Testing a registry different from default."""
a = {
"CartPole-v0": gym.envs.registry["CartPole-v0"],
"CartPole-v1": gym.envs.registry["CartPole-v1"],
}
out = gym.pprint_registry(a, disable_print=True)
correct_out = """===== classic_control =====
CartPole-v0 CartPole-v1
"""
assert out == correct_out

View File

@@ -1,93 +0,0 @@
"""Tests that gym.spec works as expected."""
import re
import pytest
import gymnasium as gym
def test_spec():
spec = gym.spec("CartPole-v1")
assert spec.id == "CartPole-v1"
assert spec is gym.envs.registry["CartPole-v1"]
def test_spec_kwargs():
map_name_value = "8x8"
env = gym.make("FrozenLake-v1", map_name=map_name_value)
assert env.spec is not None
assert env.spec.kwargs["map_name"] == map_name_value
def test_spec_missing_lookup():
gym.register(id="Test1-v0", entry_point="no-entry-point")
gym.register(id="Test1-v15", entry_point="no-entry-point")
gym.register(id="Test1-v9", entry_point="no-entry-point")
gym.register(id="Other1-v100", entry_point="no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v1 for `Test1` is deprecated. Please use `Test1-v15` instead."
),
):
gym.spec("Test1-v1")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape(
"Environment version `v1000` for environment `Test1` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
),
):
gym.spec("Test1-v1000")
with pytest.raises(
gym.error.UnregisteredEnv,
match=re.escape("Environment Unknown1 doesn't exist. "),
):
gym.spec("Unknown1-v1")
def test_spec_malformed_lookup():
with pytest.raises(
gym.error.Error,
match=f'^{re.escape("Malformed environment ID: “Breakout-v0”.(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))")}$',
):
gym.spec("“Breakout-v0”")
def test_spec_versioned_lookups():
gym.register("test/Test2-v5", "no-entry-point")
with pytest.raises(
gym.error.VersionNotFound,
match=re.escape(
"Environment version `v9` for environment `test/Test2` doesn't exist. It provides versioned environments: [ `v5` ]."
),
):
gym.spec("test/Test2-v9")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version v4 for `test/Test2` is deprecated. Please use `test/Test2-v5` instead."
),
):
gym.spec("test/Test2-v4")
assert gym.spec("test/Test2-v5") is not None
def test_spec_default_lookups():
gym.register("test/Test3", "no-entry-point")
with pytest.raises(
gym.error.DeprecatedEnv,
match=re.escape(
"Environment version `v0` for environment `test/Test3` doesn't exist. It provides the default version test/Test3`."
),
):
gym.spec("test/Test3-v0")
assert gym.spec("test/Test3") is not None

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import gymnasium as gym import gymnasium as gym