From 00a60e6cc85c9ee84ecd65d4d3ee5040b2d04230 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 21 Apr 2022 20:41:15 +0200 Subject: [PATCH] Rewriting of the registration mechanism (#2748) * First version of the new registration * Almost done * Hopefully final commit * Minor fixes * Missing error * Type fixes * Type fixes * Add some type hinting stuff * Fix an error? * Fix literal import * Add a comment * Add some docstrings Remove old tests * Add some docstrings, rename helper functions * Rename a function * Registration check fix * Consistently use `register` instead of `envs.register` in tests * Fix the malformed registration error message to not use a write-only format * Change an error back to a warning when double-registering an environment --- gym/core.py | 4 +- gym/envs/registration.py | 878 ++++++++++++-------------------- tests/envs/spec_list.py | 2 +- tests/envs/test_registration.py | 124 ++--- 4 files changed, 359 insertions(+), 649 deletions(-) diff --git a/gym/core.py b/gym/core.py index eabaea695..03eccf9fe 100644 --- a/gym/core.py +++ b/gym/core.py @@ -106,8 +106,8 @@ class Env(Generic[ObsType, ActType]): integer seed right after initialization and then never again. Args: - seed (int or None): The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom). However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. - If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action. + seed (int or None): The seed that is used to initialize the environment's PRNG. 
If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom). + However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action. return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step` options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment) diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 09471c62d..48b7d6f96 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -7,34 +7,29 @@ import importlib import importlib.util import re import sys +from dataclasses import dataclass, field from typing import ( Any, Callable, - Generator, Optional, Sequence, SupportsFloat, Tuple, Type, Union, - cast, overload, ) +import numpy as np + +from gym.envs.__relocated__ import internal_env_relocation_map +from gym.wrappers import AutoResetWrapper, OrderEnforcing, TimeLimit + if sys.version_info < (3, 10): import importlib_metadata as metadata # type: ignore else: import importlib.metadata as metadata -from collections import defaultdict -from collections.abc import MutableMapping -from dataclasses import InitVar, dataclass, field - -import numpy as np - -from gym import Env, error, logger -from gym.envs.__relocated__ import internal_env_relocation_map - if sys.version_info >= (3, 8): from typing import Literal else: @@ -44,6 +39,8 @@ else: return Any +from gym import Env, error, logger + ENV_ID_RE: re.Pattern = re.compile( 
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$" ) @@ -60,16 +57,16 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]: """Parse environment ID string format. This format is true today, but it's *not* an official spec. - [username/](env-name)-v(version) env-name is group 1, version is group 2 + [namespace/](env-name)-v(version) env-name is group 1, version is group 2 2016-10-31: We're experimentally expanding the environment ID format - to include an optional username. + to include an optional namespace. """ match = ENV_ID_RE.fullmatch(id) if not match: raise error.Error( f"Malformed environment ID: {id}." - f"(Currently all IDs must be of the form {ENV_ID_RE}.)" + f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))" ) namespace, name, version = match.group("namespace", "name", "version") if version is not None: @@ -78,24 +75,21 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]: return namespace, name, version +def get_env_id(ns: Optional[str], name: str, version: Optional[int]): + """Get the full env ID given a name and (optional) version and namespace. + Inverse of parse_env_id.""" + + full_name = name + if version is not None: + full_name += f"-v{version}" + if ns is not None: + full_name = ns + "/" + full_name + return full_name + + @dataclass class EnvSpec: - """A specification for a particular instance of the environment. Used - to register the parameters for official evaluations. - - Args: - id_requested: The official environment ID - entry_point: The Python entrypoint of the environment class (e.g. 
module.name:Class) - reward_threshold: The reward threshold before the task is considered solved - nondeterministic: Whether this environment is non-deterministic even after seeding - max_episode_steps: The maximum number of steps that an episode can consist of - order_enforce: Whether to wrap the environment in an orderEnforcing wrapper - autoreset: Whether the environment should automatically reset when it reaches the done state - kwargs: The kwargs to pass to the environment class - - """ - - id_requested: InitVar[str] + id: str entry_point: Optional[Union[Callable, str]] = field(default=None) reward_threshold: Optional[float] = field(default=None) nondeterministic: bool = field(default=False) @@ -103,533 +97,173 @@ class EnvSpec: order_enforce: bool = field(default=True) autoreset: bool = field(default=False) kwargs: dict = field(default_factory=dict) + namespace: Optional[str] = field(init=False) name: str = field(init=False) version: Optional[int] = field(init=False) - def __post_init__(self, id_requested): + def __post_init__(self): # Initialize namespace, name, version - self.namespace, self.name, self.version = parse_env_id(id_requested) - - @property - def id(self) -> str: - """ - `id_requested` is an InitVar meaning it's only used at initialization to parse - the namespace, name, and version. This means we can define the dynamic - property `id` to construct the `id` from the parsed fields. This has the - benefit that we update the fields and obtain a dynamic id. - """ - namespace = "" if self.namespace is None else f"{self.namespace}/" - name = self.name - version = "" if self.version is None else f"-v{self.version}" - return f"{namespace}{name}{version}" + self.namespace, self.name, self.version = parse_env_id(self.id) def make(self, **kwargs) -> Env: - """Instantiates an instance of the environment with appropriate kwargs""" - if self.entry_point is None: - raise error.Error( - f"Attempting to make deprecated env {self.id}. 
" - "(HINT: is there a newer registered version of this env?)" - ) - _kwargs = self.kwargs.copy() - _kwargs.update(kwargs) - - if "autoreset" in _kwargs: - self.autoreset = _kwargs["autoreset"] - del _kwargs["autoreset"] - - if callable(self.entry_point): - env = self.entry_point(**_kwargs) - else: - cls = load(self.entry_point) - env = cls(**_kwargs) - - # Make the environment aware of which spec it came from. - spec = copy.deepcopy(self) - spec.kwargs = _kwargs - env.unwrapped.spec = spec - - if self.order_enforce: - from gym.wrappers.order_enforcing import OrderEnforcing - - env = OrderEnforcing(env) - - assert env.spec is not None, "expected spec to be set to the unwrapped env." - if env.spec.max_episode_steps is not None: - from gym.wrappers.time_limit import TimeLimit - - env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps) - - if self.autoreset: - from gym.wrappers.autoreset import AutoResetWrapper - - env = AutoResetWrapper(env) - - return env + # For compatibility purposes + return make(self, **kwargs) -class EnvSpecTree(MutableMapping): - """ - The EnvSpecTree provides a dict-like mapping object - from environment IDs to specifications. - - The EnvSpecTree is backed by a tree-like structure. - The environment ID format is [{namespace}/]{name}-v{version}. - - The tree has multiple root nodes corresponding to a namespace. - The children of a namespace node corresponds to the environment name. - Furthermore, each name has a mapping from versions to specifications. - It looks like the following, - - { - None: { - MountainCar: { - 0: EnvSpec(...), - 1: EnvSpec(...) - } - }, - ALE: { - Tetris: { - 5: EnvSpec(...) - } - } +def _check_namespace_exists(ns: Optional[str]): + """Check if a namespace exists. 
If it doesn't, print a helpful error message.""" + if ns is None: + return + namespaces = { + spec_.namespace for spec_ in registry.values() if spec_.namespace is not None } + if ns in namespaces: + return - The tree-structure isn't user-facing and the EnvSpecTree will act - like a dictionary. For example, to lookup an environment ID: + suggestion = ( + difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None + ) + suggestion_msg = ( + f"Did you mean: `{suggestion[0]}`?" + if suggestion + else f"Have you installed the proper package for {ns}?" + ) - ``` - specs = EnvSpecTree() + raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}") - specs["My/Env-v0"] = EnvSpec(...) - assert specs["My/Env-v0"] == EnvSpec(...) - assert specs.tree["My"]["Env"]["0"] == specs["My/Env-v0"] - ``` - """ +def _check_name_exists(ns: Optional[str], name: str): + """Check if an env exists in a namespace. If it doesn't, print a helpful error message.""" + _check_namespace_exists(ns) + names = {spec_.name for spec_ in registry.values()} - def __init__(self): - # Initialize the tree as a nested sequence of defaultdicts - self.tree = defaultdict(lambda: defaultdict(dict)) - self._length = 0 + if name in names: + return - def versions(self, namespace: Optional[str], name: str) -> Sequence[EnvSpec]: - """ - Returns the versions associated with a namespace and name. + if namespace is None and name in internal_env_relocation_map: + relocated_namespace, relocated_package = internal_env_relocation_map[name] + message = f"The environment `{name}` has been moved out of Gym to the package `{relocated_package}`." - Note: This function takes into account environment relocations. 
- For example, `versions(None, "Breakout")` will return, - ``` - [ - EnvSpec(namespace=None, name="Breakout", version=0), - EnvSpec(namespace=None, name="Breakout", version=4), - EnvSpec(namespace="ALE", name="Breakout", version=5) - ] - ``` - Notice the last environment which is outside of the requested namespace. - This only applies to environments which are in the `internal_env_relocation_map`. - See `gym/envs/__relocated__.py` for more info. - """ - self._assert_name_exists(namespace, name) + # Check if the package is installed + # If not instruct the user to install the package and then how to instantiate the env + if importlib.util.find_spec(relocated_package) is None: + message += ( + f" Please install the package via `pip install {relocated_package}`." + ) - versions = list(self.tree[namespace][name].values()) + # Otherwise the user should be able to instantiate the environment directly + if namespace != relocated_namespace: + message += f" You can instantiate the new namespaced environment as `{relocated_namespace}/{name}`." - if namespace is None and name in internal_env_relocation_map: - relocated_namespace, _ = internal_env_relocation_map[name] - try: - self._assert_name_exists(relocated_namespace, name) - versions += list(self.tree[relocated_namespace][name].values()) - except error.UnregisteredEnv: - pass - - return versions - - def names(self, namespace: Optional[str]) -> Sequence[str]: - """ - Returns all the environment names associated with a namespace. - """ - self._assert_namespace_exists(namespace) - return list(self.tree[namespace].keys()) - - def namespaces(self) -> Sequence[str]: - """ - Returns all the namespaces contained in the tree. - """ - return list(filter(None, self.tree.keys())) - - def __iter__(self) -> Generator[str, None, None]: - # Iterate through the structure and generate the IDs contained in the tree. 
- for namespace, names in self.tree.items(): - for name, versions in names.items(): - for version, spec in versions.items(): - assert spec.namespace == namespace - assert spec.name == name - assert spec.version == version - yield spec.id - - def _assert_namespace_exists(self, namespace: Optional[str]) -> None: - if namespace in self.tree: - return - - message = f"Namespace `{namespace}` does not exist." - if namespace: - suggestions = difflib.get_close_matches(namespace, self.namespaces(), n=1) - if suggestions: - message += f" Did you mean: `{suggestions[0]}`?" - else: - message += f" Have you installed the proper package for `{namespace}`?" - raise error.NamespaceNotFound(message) - - def _assert_name_exists(self, namespace: Optional[str], name: str) -> None: - self._assert_namespace_exists(namespace) - if name in self.tree[namespace]: - return - - if namespace is None and name in internal_env_relocation_map: - relocated_namespace, relocated_package = internal_env_relocation_map[name] - message = f"The environment `{name}` has been moved out of Gym to the package `{relocated_package}`." - - # Check if the package is installed - # If not instruct the user to install the package and then how to instantiate the env - if importlib.util.find_spec(relocated_package) is None: - message += f" Please install the package via `pip install {relocated_package}`." - - # Otherwise the user should be able to instantiate the environment directly - if namespace != relocated_namespace: - message += f" You can instantiate the new namespaced environment as `{relocated_namespace}/{name}`." - # If the environment hasn't been relocated we'll construct a generic error message - else: - message = f"Environment `{name}` doesn't exist" - if namespace is not None: - message += f" in namespace `{namespace}`" - message += "." - suggestions = difflib.get_close_matches(name, self.names(namespace), n=1) - if suggestions: - message += f" Did you mean: `{suggestions[0]}`?" 
- # Throw the error raise error.NameNotFound(message) - def _assert_version_exists( - self, namespace: Optional[str], name: str, version: Optional[int] - ): - self._assert_name_exists(namespace, name) - if version in self.tree[namespace][name]: - return + suggestion = difflib.get_close_matches(name, names, n=1) + namespace_msg = f" in namespace {ns}" if ns else "" + suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else "" - # Construct the appropriate exception. - # If the version is less than the latest version - # then we throw an error.DeprecatedEnv exception. - # Otherwise we throw error.VersionNotFound. - versions = self.tree[namespace][name] - assert len(versions) > 0 + raise error.NameNotFound( + f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}" + ) - versioned_specs = list( - filter(lambda spec: isinstance(spec.version, int), versions.values()) - ) - default_spec = versions[None] if None in versions else None - assert len(versioned_specs) > 0 or default_spec is not None - latest_spec = max( - versioned_specs, key=lambda spec: spec.version, default=default_spec - ) +def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]): + """Check if an env version exists in a namespace. If it doesn't, print a helpful error message. + This is a complete test whether an environment identifier is valid, and will provide the best available hints.""" + if get_env_id(ns, name, version) in registry: + return - if version is not None: - message = f"Environment version `v{version}` for `" - else: - message = "The default version for `" + _check_name_exists(ns, name) + if version is None: + return - if namespace is not None: - message += f"{namespace}/" - message += f"{name}` " + message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist." - # If this version doesn't exist but there exists a newer non-default - # version we should warn the user this version is deprecated. 
- if ( - latest_spec - and latest_spec.version is not None - and version is not None - and version < latest_spec.version - ): - message += "is deprecated. " - message += f"Please use the latest version `v{latest_spec.version}`." + env_specs = [ + spec_ + for spec_ in registry.values() + if spec_.namespace == ns and spec_.name == name + ] + env_specs = sorted(env_specs, key=lambda spec_: int(spec_.version or -1)) + + default_spec = [spec_ for spec_ in env_specs if spec_.version is None] + + if default_spec: + message += f" It provides the default version {default_spec[0].id}`." + if len(env_specs) == 1: raise error.DeprecatedEnv(message) - # If this version doesn't exist and there only exists a default version - elif latest_spec and latest_spec.version is None: - message += "is deprecated. " - message += f"`{latest_spec.name}` only provides the default version. " - message += ( - f'You can initialize the environment as `gym.make("{latest_spec.id}")`.' - ) - raise error.DeprecatedEnv(message) - # Otherwise we've asked for a version that doesn't exist. - else: - message += f"could not be found. `{name}` provides " - if default_spec: - message += "a default version" - if versioned_specs: - message += " and " - if versioned_specs: - message += "the versioned environments: [ " - versioned_specs_sorted = sorted( - versioned_specs, key=lambda spec: spec.version - ) - message += ", ".join( - map(lambda spec: f"`v{spec.version}`", versioned_specs_sorted) - ) - message += " ]" - message += "." - raise error.VersionNotFound(message) + # Process possible versioned environments - def __getitem__(self, key: str) -> EnvSpec: - # Get an item from the tree. - # We first parse the components so we can look up the - # appropriate environment ID. 
- namespace, name, version = parse_env_id(key) - self._assert_version_exists(namespace, name, version) + versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None] - return self.tree[namespace][name][version] + latest_spec = max(versioned_specs, key=lambda spec: spec.version, default=None) # type: ignore + if latest_spec is not None and version > latest_spec.version: + version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs) + message += f" It provides versioned environments: [ {version_list_msg} ]." - def __setitem__(self, key: str, value: EnvSpec) -> None: - # Insert an item into the tree. - # First we parse the components to get the path - # for insertion. - namespace, name, version = parse_env_id(key) - self.tree[namespace][name][version] = value - # Increase the size - self._length += 1 + raise error.VersionNotFound(message) - def __delitem__(self, key: str) -> None: - # Delete an item from the tree. - # First parse the components so we can follow the - # path to delete. - namespace, name, version = parse_env_id(key) - self._assert_version_exists(namespace, name, version) - - # Remove the envspec with this version. - self.tree[namespace][name].pop(version) - # Remove the name if it's empty. - if len(self.tree[namespace][name]) == 0: - self.tree[namespace].pop(name) - # Remove the namespace if it's empty. - if len(self.tree[namespace]) == 0: - self.tree.pop(namespace) - # Decrease the size - self._length -= 1 - - def __contains__(self, key: str) -> bool: - # Check if the tree contains a path for this key. - namespace, name, version = parse_env_id(key) - if ( - namespace in self.tree - and name in self.tree[namespace] - and version in self.tree[namespace][name] - ): - return True - return False - - def __repr__(self) -> str: - # Construct a tree-like representation structure - # so we can easily look at the contents of the tree. 
- tree_repr = "" - for namespace, names in self.tree.items(): - # For each namespace we'll iterate over the names - root = namespace is None - # Insert a separator if we're between depths - if len(tree_repr) > 0: - tree_repr += "│\n" - # if this isn't the root we'll display the namespace - if not root: - tree_repr += f"├──{str(namespace)}\n" - - # Construct the namespace string so we can print this for - # our children. - namespace = f"{namespace}/" if namespace is not None else "" - for name_idx, (name, versions) in enumerate(names.items()): - # If this isn't the root we'll have to increase our - # depth, i.e., insert some space - if not root: - tree_repr += "│ " - # If this is the last item make sure we use the - # termination character. Otherwise use the nested - # character. - if name_idx == len(names) - 1: - tree_repr += "└──" - else: - tree_repr += "├──" - # Print the namespace and the name - # and get ready to print the versions. - tree_repr += f"{namespace}{name}: [ " - # Print each version comma separated - for version_idx, version in enumerate(versions.keys()): - if version is not None: - tree_repr += f"v{version}" - else: - tree_repr += "" - if version_idx < len(versions) - 1: - tree_repr += ", " - tree_repr += " ]\n" - - return tree_repr - - def __len__(self): - # Return the length of the container - return self._length - - -class EnvRegistry: - """Register an env by ID. IDs remain stable over time and are - guaranteed to resolve to the same environment dynamics (or be - desupported). The goal is that results on a particular environment - should always be comparable, and not depend on the version of the - code that was running. 
- """ - - def __init__(self): - self.env_specs = EnvSpecTree() - self._ns: Optional[str] = None - - def make(self, path: str, **kwargs) -> Env: - if len(kwargs) > 0: - logger.info("Making new env: %s (%s)", path, kwargs) - else: - logger.info("Making new env: %s", path) - - # We need to manually parse the ID so we can check - # the version without error-ing out in self.spec - namespace, name, version = parse_env_id(path) - - # Get all versions of this spec. - versions = self.env_specs.versions(namespace, name) - - # We check what the latest version of the environment is and display - # a warning if the user is attempting to initialize an older version - # or an unversioned one. - latest_versioned_spec = max( - filter(lambda spec: spec.version, versions), - key=lambda spec: cast(int, spec.version), - default=None, + if latest_spec is not None and version < latest_spec.version: + raise error.DeprecatedEnv( + f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. " + f"Please use `{latest_spec.id}` instead." ) - if ( - latest_versioned_spec - and version is not None - and version < cast(int, latest_versioned_spec.version) - ): - logger.warn( - f"The environment {path} is out of date. You should consider " - f"upgrading to version `v{latest_versioned_spec.version}` " - f"with the environment ID `{latest_versioned_spec.id}`." 
- ) - elif latest_versioned_spec and version is None: - logger.warn( - f"Using the latest versioned environment `{latest_versioned_spec.id}` " - f"instead of the unversioned environment `{path}`" - ) - path = latest_versioned_spec.id - # Lookup our path - spec = self.spec(path) - # Construct the environment - return spec.make(**kwargs) - def all(self): - return self.env_specs.values() +def find_highest_version(ns: Optional[str], name: str) -> Optional[int]: + version: list[int] = [ + spec_.version + for spec_ in registry.values() + if spec_.namespace == ns and spec_.name == name and spec_.version is not None + ] + return max(version, default=None) - def spec(self, path: str) -> EnvSpec: - if ":" in path: - mod_name, _, id = path.partition(":") - try: - importlib.import_module(mod_name) - except ModuleNotFoundError: - raise error.Error( - f"A module ({mod_name}) was specified for the environment but was not found, " - "make sure the package is installed with `pip install` before calling `gym.make()`" - ) - else: - id = path - - # We can go ahead and return the env_spec. - # The EnvSpecTree will take care of any exceptions. - return self.env_specs[id] - - def register(self, id: str, **kwargs) -> None: - spec = EnvSpec(id, **kwargs) - - if self._ns is not None: - if spec.namespace is not None: - logger.warn( - f"Custom namespace `{spec.namespace}` is being overridden " - f"by namespace `{self._ns}`. If you are developing a " - "plugin you shouldn't specify a namespace in `register` " - "calls. The namespace is specified through the " - "entry point package metadata." - ) - # Replace namespace - spec.namespace = self._ns +def load_env_plugins(entry_point: str = "gym.envs") -> None: + # Load third-party environments + for plugin in metadata.entry_points(group=entry_point): + # Python 3.8 doesn't support plugin.module, plugin.attr + # So we'll have to try and parse this ourselves try: - # Get all versions of this spec. 
- versions = self.env_specs.versions(spec.namespace, spec.name) - - # We raise an error if the user is attempting to initialize an - # unversioned environment when a versioned one already exists. - latest_versioned_spec = max( - filter(lambda spec: isinstance(spec.version, int), versions), - key=lambda spec: cast(int, spec.version), - default=None, - ) - unversioned_spec = next( - filter(lambda spec: spec.version is None, versions), None - ) - - # Trying to register an unversioned spec when versioned spec exists - if unversioned_spec and spec.version is not None: - message = ( - "Can't register the versioned environment " - f"`{spec.id}` when the unversioned environment " - f"`{unversioned_spec.id}` of the same name already exists." - ) - raise error.RegistrationError(message) - elif latest_versioned_spec and spec.version is None: - message = ( - f"Can't register the unversioned environment `{spec.id}` " - f"when version `{latest_versioned_spec.version}` " - "of the same name already exists. Note: the default " - "behavior is that the `gym.make` with the unversioned " - "environment will return the latest versioned environment." - ) - raise error.RegistrationError(message) - # We might not find this namespace or name in which case - # we should continue to register the environment. 
- except (error.NamespaceNotFound, error.NameNotFound): - pass + module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint" + except AttributeError: + if ":" in plugin.value: + module, attr = plugin.value.split(":", maxsplit=1) + else: + module, attr = plugin.value, None + except: + module, attr = None, None finally: - if spec.id in self.env_specs: - logger.warn(f"Overriding environment {id}") - self.env_specs[spec.id] = spec + if attr is None: + raise error.Error( + f"Gym environment plugin `{module}` must specify a function to execute, not a root module" + ) - @contextlib.contextmanager - def namespace(self, ns: str): - self._ns = ns - yield - self._ns = None + context = namespace(plugin.name) + if plugin.name.startswith("__") and plugin.name.endswith("__"): + # `__internal__` is an artifact of the plugin system when + # the root namespace had an allow-list. The allow-list is now + # removed and plugins can register environments in the root + # namespace with the `__root__` magic key. + if plugin.name == "__root__" or plugin.name == "__internal__": + context = contextlib.nullcontext() + else: + logger.warn( + f"The environment namespace magic key `{plugin.name}` is unsupported. " + "To register an environment at the root namespace you should specify " + "the `__root__` namespace." + ) - def __repr__(self): - return repr(self.env_specs) - - -# Have a global registry -registry = EnvRegistry() - - -def register(id: str, **kwargs) -> None: - return registry.register(id, **kwargs) + with context: + fn = plugin.load() + try: + fn() + except Exception as e: + logger.warn(str(e)) # fmt: off -# Continuous -# ---------------------------------------- - @overload def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... @overload @@ -680,63 +314,191 @@ def make(id: Literal[ "Ant-v2" ], **kwargs) -> Env[np.ndarray, np.ndarray]: ... 
@overload
def make(id: str, **kwargs) -> Env: ...
@overload
def make(id: EnvSpec, **kwargs) -> Env: ...

# fmt: on


# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
# Namespace applied to ids passed to `register`; set and restored by `namespace`.
current_namespace: Optional[str] = None


def _check_spec_register(spec: EnvSpec):
    """Check whether the spec is valid to be registered. Helper function for `register`.

    A versioned and an unversioned environment of the same namespace and name
    may not coexist in the registry.

    Args:
        spec: The spec that is about to be added to the registry.

    Raises:
        error.RegistrationError: If `spec` is versioned while an unversioned
            spec of the same namespace/name is registered, or the other way
            around.
    """
    # Every already-registered spec that names the same environment.
    same_env_specs = [
        spec_
        for spec_ in registry.values()
        if spec_.namespace == spec.namespace and spec_.name == spec.name
    ]

    latest_versioned_spec = max(
        (spec_ for spec_ in same_env_specs if spec_.version is not None),
        key=lambda spec_: int(spec_.version),  # type: ignore
        default=None,
    )
    unversioned_spec = next(
        (spec_ for spec_ in same_env_specs if spec_.version is None),
        None,
    )

    if unversioned_spec and spec.version is not None:
        raise error.RegistrationError(
            "Can't register the versioned environment "
            f"`{spec.id}` when the unversioned environment "
            f"`{unversioned_spec.id}` of the same name already exists."
        )
    elif latest_versioned_spec and spec.version is None:
        raise error.RegistrationError(
            "Can't register the unversioned environment "
            f"`{spec.id}` when the versioned environment "
            f"`{latest_versioned_spec.id}` of the same name "
            f"already exists. Note: the default behavior is "
            f"that `gym.make` with the unversioned environment "
            f"will return the latest versioned environment"
        )


# Public API


@contextlib.contextmanager
def namespace(ns: str):
    """Temporarily set the namespace used by `register`.

    Inside the `with` block the module-level `current_namespace` is `ns`; the
    previous value is put back when the block exits.

    Args:
        ns: Namespace to apply while the context is active.
    """
    global current_namespace
    old_namespace = current_namespace
    current_namespace = ns
    try:
        yield
    finally:
        # Restore even when the body raises, so one failing registration
        # cannot leak its namespace into later `register` calls.
        current_namespace = old_namespace


def register(id: str, **kwargs):
    """Register an environment with gym.

    The `id` parameter corresponds to the name of the environment, with the
    syntax as follows:
    `(namespace)/(env_name)-(version)`
    where `namespace` is optional.

    While a `namespace` context is active, that namespace replaces whatever
    namespace `id` carries (a warning is logged when they differ).

    Args:
        id: The environment id to register.
        kwargs: Arbitrary keyword arguments, which are passed to the `EnvSpec`
            constructor.

    Raises:
        error.RegistrationError: If a versioned/unversioned environment of the
            same name conflicts with this one (see `_check_spec_register`).
    """
    if not current_namespace:
        full_id = id
    else:
        # Rebuild the id from its parts instead of concatenating strings:
        # plain concatenation drops the `/` separator and cannot replace a
        # namespace that is already part of `id`.
        ns, name, version = parse_env_id(id)
        if ns is not None and ns != current_namespace:
            logger.warn(
                f"Custom namespace `{ns}` is being overridden "
                f"by namespace `{current_namespace}`. If you are developing a "
                "plugin you shouldn't specify a namespace in `register` "
                "calls. The namespace is specified through the "
                "entry point package metadata."
            )
        full_id = get_env_id(current_namespace, name, version)

    new_spec = EnvSpec(id=full_id, **kwargs)
    _check_spec_register(new_spec)
    if new_spec.id in registry:
        # Double registration is deliberately a warning, not an error.
        logger.warn(f"Overriding environment {new_spec.id}")
    registry[new_spec.id] = new_spec


def make(
    id: str | EnvSpec,
    max_episode_steps: Optional[int] = None,
    autoreset: bool = False,
    **kwargs,
) -> Env:
    """Create an environment according to the given ID.

    Args:
        id: Name of the environment, or an `EnvSpec` to instantiate directly
            (the registry lookup is skipped in that case).
        max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
            When given, it is used instead of the value stored on the spec.
        autoreset: Whether to automatically reset the environment after each
            episode (AutoResetWrapper).
        kwargs: Additional arguments to pass to the environment constructor.
            They override the kwargs stored on the spec.

    Returns:
        An instance of the environment.

    Raises:
        error.Error: If no spec is registered for `id`, or the spec has no
            entry point.
    """
    if isinstance(id, EnvSpec):
        spec_ = id
    else:
        spec_ = registry.get(id)

        ns, name, version = parse_env_id(id)
        latest_version = find_highest_version(ns, name)
        if (
            version is not None
            and latest_version is not None
            and latest_version > version
        ):
            logger.warn(
                f"The environment {id} is out of date. You should consider "
                f"upgrading to version `v{latest_version}`."
            )
        if version is None and latest_version is not None:
            # An unversioned id resolves to the newest registered version.
            version = latest_version
            new_env_id = get_env_id(ns, name, version)
            spec_ = registry.get(new_env_id)
            logger.warn(
                f"Using the latest versioned environment `{new_env_id}` "
                f"instead of the unversioned environment `{id}`."
            )

        if spec_ is None:
            # Give the version check the chance to raise a more specific
            # error before falling back to the generic one.
            _check_version_exists(ns, name, version)
            raise error.Error(f"No registered env with id: {id}")

    # Call-site kwargs take precedence over the ones stored on the spec.
    _kwargs = spec_.kwargs.copy()
    _kwargs.update(kwargs)

    # TODO: add a minimal env checker on initialization
    if spec_.entry_point is None:
        raise error.Error(f"{spec_.id} registered but entry_point is not specified")
    elif callable(spec_.entry_point):
        cls = spec_.entry_point
    else:
        # Assume it's a string
        cls = load(spec_.entry_point)

    env = cls(**_kwargs)

    # Attach a private copy of the spec recording the kwargs actually used,
    # leaving the registered spec untouched.
    spec_ = copy.deepcopy(spec_)
    spec_.kwargs = _kwargs

    env.unwrapped.spec = spec_

    if spec_.order_enforce:
        env = OrderEnforcing(env)

    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps)
    elif spec_.max_episode_steps is not None:
        env = TimeLimit(env, spec_.max_episode_steps)

    if autoreset:
        env = AutoResetWrapper(env)

    return env


def spec(env_id: str) -> EnvSpec:
    """Retrieve the spec for the given environment from the global registry.

    Args:
        env_id: The exact id the environment was registered under.

    Returns:
        The registered `EnvSpec`.

    Raises:
        error.Error: If no spec is registered under `env_id`.
    """
    spec_ = registry.get(env_id)
    if spec_ is None:
        ns, name, version = parse_env_id(env_id)
        # May raise a more specific error before the generic one below.
        _check_version_exists(ns, name, version)
        raise error.Error(f"No registered env with id: {env_id}")
    else:
        assert isinstance(spec_, EnvSpec)
        return spec_
+from gym.envs import register, registration, registry, spec from gym.envs.classic_control import cartpole -from gym.envs.registration import EnvSpec, EnvSpecTree +from gym.envs.registration import EnvSpec class ArgumentEnv(gym.Env): @@ -55,8 +55,8 @@ def register_some_envs(): for version in versions: env_id = f"{namespace}/{versioned_name}-v{version}" - del gym.envs.registry.env_specs[env_id] - del gym.envs.registry.env_specs[f"{namespace}/{unversioned_name}"] + del gym.envs.registry[env_id] + del gym.envs.registry[f"{namespace}/{unversioned_name}"] def test_make(): @@ -83,10 +83,15 @@ def test_make(): ], ) def test_register(env_id, namespace, name, version): - envs.register(env_id) + register(env_id) assert gym.envs.spec(env_id).id == env_id - assert version in gym.envs.registry.env_specs.tree[namespace][name].keys() - del gym.envs.registry.env_specs[env_id] + full_name = f"{name}" + if namespace: + full_name = f"{namespace}/{full_name}" + if version is not None: + full_name = f"{full_name}-v{version}" + assert full_name in gym.envs.registry.keys() + del gym.envs.registry[env_id] @pytest.mark.parametrize( @@ -99,7 +104,7 @@ def test_register(env_id, namespace, name, version): ) def test_register_error(env_id): with pytest.raises(error.Error, match="Malformed environment ID"): - envs.register(env_id) + register(env_id) @pytest.mark.parametrize( @@ -188,27 +193,23 @@ def test_spec_with_kwargs(): def test_missing_lookup(): - registry = registration.EnvRegistry() - registry.register(id="Test-v0", entry_point=None) - registry.register(id="Test-v15", entry_point=None) - registry.register(id="Test-v9", entry_point=None) - registry.register(id="Other-v100", entry_point=None) - try: - registry.spec("Test-v1") # must match an env name but not the version above - except error.DeprecatedEnv: - pass - else: - assert False + register(id="Test1-v0", entry_point=None) + register(id="Test1-v15", entry_point=None) + register(id="Test1-v9", entry_point=None) + 
register(id="Other1-v100", entry_point=None) + + with pytest.raises(error.DeprecatedEnv): + spec("Test1-v1") try: - registry.spec("Test-v1000") + spec("Test1-v1000") except error.UnregisteredEnv: pass else: assert False try: - registry.spec("Unknown-v1") + spec("Unknown1-v1") except error.UnregisteredEnv: pass else: @@ -216,9 +217,8 @@ def test_missing_lookup(): def test_malformed_lookup(): - registry = registration.EnvRegistry() try: - registry.spec("“Breakout-v0”") + spec("“Breakout-v0”") except error.Error as e: assert "Malformed environment ID" in f"{e}", f"Unexpected message: {e}" else: @@ -226,99 +226,47 @@ def test_malformed_lookup(): def test_versioned_lookups(): - registry = registration.EnvRegistry() - registry.register("test/Test-v5") + register("test/Test2-v5") with pytest.raises(error.VersionNotFound): - registry.spec("test/Test-v9") + spec("test/Test2-v9") with pytest.raises(error.DeprecatedEnv): - registry.spec("test/Test-v4") + spec("test/Test2-v4") - assert registry.spec("test/Test-v5") + assert spec("test/Test2-v5") def test_default_lookups(): - registry = registration.EnvRegistry() - registry.register("test/Test") + register("test/Test3") with pytest.raises(error.DeprecatedEnv): - registry.spec("test/Test-v0") + spec("test/Test3-v0") # Lookup default - registry.spec("test/Test") - - -def test_env_spec_tree(): - spec_tree = EnvSpecTree() - - # Add with namespace - spec = EnvSpec("test/Test-v0") - spec_tree["test/Test-v0"] = spec - assert spec_tree.tree.keys() == {"test"} - assert spec_tree.tree["test"].keys() == {"Test"} - assert spec_tree.tree["test"]["Test"].keys() == {0} - assert spec_tree.tree["test"]["Test"][0] == spec - assert spec_tree["test/Test-v0"] == spec - - # Add without namespace - spec = EnvSpec("Test-v0") - spec_tree["Test-v0"] = spec - assert spec_tree.tree.keys() == {"test", None} - assert spec_tree.tree[None].keys() == {"Test"} - assert spec_tree.tree[None]["Test"].keys() == {0} - assert spec_tree.tree[None]["Test"][0] == spec - 
- # Delete last version deletes entire subtree - del spec_tree["test/Test-v0"] - assert spec_tree.tree.keys() == {None} - - # Append second version for same name - spec_tree["Test-v1"] = EnvSpec("Test-v1") - assert spec_tree.tree.keys() == {None} - assert spec_tree.tree[None].keys() == {"Test"} - assert spec_tree.tree[None]["Test"].keys() == {0, 1} - - # Deleting one version leaves other - del spec_tree["Test-v0"] - assert spec_tree.tree.keys() == {None} - assert spec_tree.tree[None].keys() == {"Test"} - assert spec_tree.tree[None]["Test"].keys() == {1} - - # Add without version - myenv = "MyAwesomeEnv" - spec = EnvSpec(myenv) - spec_tree[myenv] = spec - assert spec_tree.tree.keys() == {None} - assert myenv in spec_tree.tree[None].keys() - assert spec_tree.tree[None][myenv].keys() == {None} - assert spec_tree.tree[None][myenv][None] == spec - assert spec_tree.__repr__() == "├──Test: [ v1 ]\n" + f"└──{myenv}: [ ]\n" + spec("test/Test3") def test_register_versioned_unversioned(): # Register versioned then unversioned versioned_env = "Test/MyEnv-v0" - envs.register(versioned_env) + register(versioned_env) assert gym.envs.spec(versioned_env).id == versioned_env unversioned_env = "Test/MyEnv" with pytest.raises(error.RegistrationError): - envs.register(unversioned_env) + register(unversioned_env) # Clean everything - del gym.envs.registry.env_specs[versioned_env] + del gym.envs.registry[versioned_env] # Register unversioned then versioned - with pytest.warns(UserWarning): - envs.register(unversioned_env) + register(unversioned_env) assert gym.envs.spec(unversioned_env).id == unversioned_env with pytest.raises(error.RegistrationError): - envs.register(versioned_env) + register(versioned_env) # Clean everything - envs_list = [versioned_env, unversioned_env] - for env in envs_list: - del gym.envs.registry.env_specs[env] + del gym.envs.registry[unversioned_env] def test_return_latest_versioned_env(register_some_envs):