Rewriting of the registration mechanism (#2748)

* First version of the new registration

* Almost done

* Hopefully final commit

* Minor fixes

* Missing error

* Type fixes

* Type fixes

* Add some type hinting stuff

* Fix an error?

* Fix literal import

* Add a comment

* Add some docstrings

Remove old tests

* Add some docstrings, rename helper functions

* Rename a function

* Registration check fix

* Consistently use `register` instead of `envs.register` in tests

* Fix the malformed registration error message to not use a write-only format

* Change an error back to a warning when double-registering an environment
This commit is contained in:
Ariel Kwiatkowski
2022-04-21 20:41:15 +02:00
committed by GitHub
parent 0a5f543d6a
commit 00a60e6cc8
4 changed files with 359 additions and 649 deletions

View File

@@ -106,8 +106,8 @@ class Env(Generic[ObsType, ActType]):
integer seed right after initialization and then never again. integer seed right after initialization and then never again.
Args: Args:
seed (int or None): The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom). However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. seed (int or None): The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action. However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action.
return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step` return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step`
options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment) options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment)

View File

@@ -7,34 +7,29 @@ import importlib
import importlib.util import importlib.util
import re import re
import sys import sys
from dataclasses import dataclass, field
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Generator,
Optional, Optional,
Sequence, Sequence,
SupportsFloat, SupportsFloat,
Tuple, Tuple,
Type, Type,
Union, Union,
cast,
overload, overload,
) )
import numpy as np
from gym.envs.__relocated__ import internal_env_relocation_map
from gym.wrappers import AutoResetWrapper, OrderEnforcing, TimeLimit
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
import importlib_metadata as metadata # type: ignore import importlib_metadata as metadata # type: ignore
else: else:
import importlib.metadata as metadata import importlib.metadata as metadata
from collections import defaultdict
from collections.abc import MutableMapping
from dataclasses import InitVar, dataclass, field
import numpy as np
from gym import Env, error, logger
from gym.envs.__relocated__ import internal_env_relocation_map
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import Literal from typing import Literal
else: else:
@@ -44,6 +39,8 @@ else:
return Any return Any
from gym import Env, error, logger
ENV_ID_RE: re.Pattern = re.compile( ENV_ID_RE: re.Pattern = re.compile(
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$" r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
) )
@@ -60,16 +57,16 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
"""Parse environment ID string format. """Parse environment ID string format.
This format is true today, but it's *not* an official spec. This format is true today, but it's *not* an official spec.
[username/](env-name)-v(version) env-name is group 1, version is group 2 [namespace/](env-name)-v(version) env-name is group 1, version is group 2
2016-10-31: We're experimentally expanding the environment ID format 2016-10-31: We're experimentally expanding the environment ID format
to include an optional username. to include an optional namespace.
""" """
match = ENV_ID_RE.fullmatch(id) match = ENV_ID_RE.fullmatch(id)
if not match: if not match:
raise error.Error( raise error.Error(
f"Malformed environment ID: {id}." f"Malformed environment ID: {id}."
f"(Currently all IDs must be of the form {ENV_ID_RE}.)" f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
) )
namespace, name, version = match.group("namespace", "name", "version") namespace, name, version = match.group("namespace", "name", "version")
if version is not None: if version is not None:
@@ -78,24 +75,21 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
return namespace, name, version return namespace, name, version
def get_env_id(ns: Optional[str], name: str, version: Optional[int]):
"""Get the full env ID given a name and (optional) version and namespace.
Inverse of parse_env_id."""
full_name = name
if version is not None:
full_name += f"-v{version}"
if ns is not None:
full_name = ns + "/" + full_name
return full_name
@dataclass @dataclass
class EnvSpec: class EnvSpec:
"""A specification for a particular instance of the environment. Used id: str
to register the parameters for official evaluations.
Args:
id_requested: The official environment ID
entry_point: The Python entrypoint of the environment class (e.g. module.name:Class)
reward_threshold: The reward threshold before the task is considered solved
nondeterministic: Whether this environment is non-deterministic even after seeding
max_episode_steps: The maximum number of steps that an episode can consist of
order_enforce: Whether to wrap the environment in an orderEnforcing wrapper
autoreset: Whether the environment should automatically reset when it reaches the done state
kwargs: The kwargs to pass to the environment class
"""
id_requested: InitVar[str]
entry_point: Optional[Union[Callable, str]] = field(default=None) entry_point: Optional[Union[Callable, str]] = field(default=None)
reward_threshold: Optional[float] = field(default=None) reward_threshold: Optional[float] = field(default=None)
nondeterministic: bool = field(default=False) nondeterministic: bool = field(default=False)
@@ -103,533 +97,173 @@ class EnvSpec:
order_enforce: bool = field(default=True) order_enforce: bool = field(default=True)
autoreset: bool = field(default=False) autoreset: bool = field(default=False)
kwargs: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict)
namespace: Optional[str] = field(init=False) namespace: Optional[str] = field(init=False)
name: str = field(init=False) name: str = field(init=False)
version: Optional[int] = field(init=False) version: Optional[int] = field(init=False)
def __post_init__(self, id_requested): def __post_init__(self):
# Initialize namespace, name, version # Initialize namespace, name, version
self.namespace, self.name, self.version = parse_env_id(id_requested) self.namespace, self.name, self.version = parse_env_id(self.id)
@property
def id(self) -> str:
"""
`id_requested` is an InitVar meaning it's only used at initialization to parse
the namespace, name, and version. This means we can define the dynamic
property `id` to construct the `id` from the parsed fields. This has the
benefit that we update the fields and obtain a dynamic id.
"""
namespace = "" if self.namespace is None else f"{self.namespace}/"
name = self.name
version = "" if self.version is None else f"-v{self.version}"
return f"{namespace}{name}{version}"
def make(self, **kwargs) -> Env: def make(self, **kwargs) -> Env:
"""Instantiates an instance of the environment with appropriate kwargs""" # For compatibility purposes
if self.entry_point is None: return make(self, **kwargs)
raise error.Error(
f"Attempting to make deprecated env {self.id}. "
"(HINT: is there a newer registered version of this env?)"
)
_kwargs = self.kwargs.copy()
_kwargs.update(kwargs)
if "autoreset" in _kwargs:
self.autoreset = _kwargs["autoreset"]
del _kwargs["autoreset"]
if callable(self.entry_point):
env = self.entry_point(**_kwargs)
else:
cls = load(self.entry_point)
env = cls(**_kwargs)
# Make the environment aware of which spec it came from.
spec = copy.deepcopy(self)
spec.kwargs = _kwargs
env.unwrapped.spec = spec
if self.order_enforce:
from gym.wrappers.order_enforcing import OrderEnforcing
env = OrderEnforcing(env)
assert env.spec is not None, "expected spec to be set to the unwrapped env."
if env.spec.max_episode_steps is not None:
from gym.wrappers.time_limit import TimeLimit
env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
if self.autoreset:
from gym.wrappers.autoreset import AutoResetWrapper
env = AutoResetWrapper(env)
return env
class EnvSpecTree(MutableMapping): def _check_namespace_exists(ns: Optional[str]):
""" """Check if a namespace exists. If it doesn't, print a helpful error message."""
The EnvSpecTree provides a dict-like mapping object if ns is None:
from environment IDs to specifications. return
namespaces = {
The EnvSpecTree is backed by a tree-like structure. spec_.namespace for spec_ in registry.values() if spec_.namespace is not None
The environment ID format is [{namespace}/]{name}-v{version}.
The tree has multiple root nodes corresponding to a namespace.
The children of a namespace node corresponds to the environment name.
Furthermore, each name has a mapping from versions to specifications.
It looks like the following,
{
None: {
MountainCar: {
0: EnvSpec(...),
1: EnvSpec(...)
}
},
ALE: {
Tetris: {
5: EnvSpec(...)
}
}
} }
if ns in namespaces:
return
The tree-structure isn't user-facing and the EnvSpecTree will act suggestion = (
like a dictionary. For example, to lookup an environment ID: difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
)
suggestion_msg = (
f"Did you mean: `{suggestion[0]}`?"
if suggestion
else f"Have you installed the proper package for {ns}?"
)
``` raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
specs = EnvSpecTree()
specs["My/Env-v0"] = EnvSpec(...)
assert specs["My/Env-v0"] == EnvSpec(...)
assert specs.tree["My"]["Env"]["0"] == specs["My/Env-v0"] def _check_name_exists(ns: Optional[str], name: str):
``` """Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
""" _check_namespace_exists(ns)
names = {spec_.name for spec_ in registry.values()}
def __init__(self): if name in names:
# Initialize the tree as a nested sequence of defaultdicts return
self.tree = defaultdict(lambda: defaultdict(dict))
self._length = 0
def versions(self, namespace: Optional[str], name: str) -> Sequence[EnvSpec]: if namespace is None and name in internal_env_relocation_map:
""" relocated_namespace, relocated_package = internal_env_relocation_map[name]
Returns the versions associated with a namespace and name. message = f"The environment `{name}` has been moved out of Gym to the package `{relocated_package}`."
Note: This function takes into account environment relocations. # Check if the package is installed
For example, `versions(None, "Breakout")` will return, # If not instruct the user to install the package and then how to instantiate the env
``` if importlib.util.find_spec(relocated_package) is None:
[ message += (
EnvSpec(namespace=None, name="Breakout", version=0), f" Please install the package via `pip install {relocated_package}`."
EnvSpec(namespace=None, name="Breakout", version=4), )
EnvSpec(namespace="ALE", name="Breakout", version=5)
]
```
Notice the last environment which is outside of the requested namespace.
This only applies to environments which are in the `internal_env_relocation_map`.
See `gym/envs/__relocated__.py` for more info.
"""
self._assert_name_exists(namespace, name)
versions = list(self.tree[namespace][name].values()) # Otherwise the user should be able to instantiate the environment directly
if namespace != relocated_namespace:
message += f" You can instantiate the new namespaced environment as `{relocated_namespace}/{name}`."
if namespace is None and name in internal_env_relocation_map:
relocated_namespace, _ = internal_env_relocation_map[name]
try:
self._assert_name_exists(relocated_namespace, name)
versions += list(self.tree[relocated_namespace][name].values())
except error.UnregisteredEnv:
pass
return versions
def names(self, namespace: Optional[str]) -> Sequence[str]:
"""
Returns all the environment names associated with a namespace.
"""
self._assert_namespace_exists(namespace)
return list(self.tree[namespace].keys())
def namespaces(self) -> Sequence[str]:
"""
Returns all the namespaces contained in the tree.
"""
return list(filter(None, self.tree.keys()))
def __iter__(self) -> Generator[str, None, None]:
# Iterate through the structure and generate the IDs contained in the tree.
for namespace, names in self.tree.items():
for name, versions in names.items():
for version, spec in versions.items():
assert spec.namespace == namespace
assert spec.name == name
assert spec.version == version
yield spec.id
def _assert_namespace_exists(self, namespace: Optional[str]) -> None:
if namespace in self.tree:
return
message = f"Namespace `{namespace}` does not exist."
if namespace:
suggestions = difflib.get_close_matches(namespace, self.namespaces(), n=1)
if suggestions:
message += f" Did you mean: `{suggestions[0]}`?"
else:
message += f" Have you installed the proper package for `{namespace}`?"
raise error.NamespaceNotFound(message)
def _assert_name_exists(self, namespace: Optional[str], name: str) -> None:
self._assert_namespace_exists(namespace)
if name in self.tree[namespace]:
return
if namespace is None and name in internal_env_relocation_map:
relocated_namespace, relocated_package = internal_env_relocation_map[name]
message = f"The environment `{name}` has been moved out of Gym to the package `{relocated_package}`."
# Check if the package is installed
# If not instruct the user to install the package and then how to instantiate the env
if importlib.util.find_spec(relocated_package) is None:
message += f" Please install the package via `pip install {relocated_package}`."
# Otherwise the user should be able to instantiate the environment directly
if namespace != relocated_namespace:
message += f" You can instantiate the new namespaced environment as `{relocated_namespace}/{name}`."
# If the environment hasn't been relocated we'll construct a generic error message
else:
message = f"Environment `{name}` doesn't exist"
if namespace is not None:
message += f" in namespace `{namespace}`"
message += "."
suggestions = difflib.get_close_matches(name, self.names(namespace), n=1)
if suggestions:
message += f" Did you mean: `{suggestions[0]}`?"
# Throw the error
raise error.NameNotFound(message) raise error.NameNotFound(message)
def _assert_version_exists( suggestion = difflib.get_close_matches(name, names, n=1)
self, namespace: Optional[str], name: str, version: Optional[int] namespace_msg = f" in namespace {ns}" if ns else ""
): suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else ""
self._assert_name_exists(namespace, name)
if version in self.tree[namespace][name]:
return
# Construct the appropriate exception. raise error.NameNotFound(
# If the version is less than the latest version f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
# then we throw an error.DeprecatedEnv exception. )
# Otherwise we throw error.VersionNotFound.
versions = self.tree[namespace][name]
assert len(versions) > 0
versioned_specs = list(
filter(lambda spec: isinstance(spec.version, int), versions.values())
)
default_spec = versions[None] if None in versions else None
assert len(versioned_specs) > 0 or default_spec is not None
latest_spec = max( def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
versioned_specs, key=lambda spec: spec.version, default=default_spec """Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
) This is a complete test whether an environment identifier is valid, and will provide the best available hints."""
if get_env_id(ns, name, version) in registry:
return
if version is not None: _check_name_exists(ns, name)
message = f"Environment version `v{version}` for `" if version is None:
else: return
message = "The default version for `"
if namespace is not None: message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."
message += f"{namespace}/"
message += f"{name}` "
# If this version doesn't exist but there exists a newer non-default env_specs = [
# version we should warn the user this version is deprecated. spec_
if ( for spec_ in registry.values()
latest_spec if spec_.namespace == ns and spec_.name == name
and latest_spec.version is not None ]
and version is not None env_specs = sorted(env_specs, key=lambda spec_: int(spec_.version or -1))
and version < latest_spec.version
): default_spec = [spec_ for spec_ in env_specs if spec_.version is None]
message += "is deprecated. "
message += f"Please use the latest version `v{latest_spec.version}`." if default_spec:
message += f" It provides the default version {default_spec[0].id}`."
if len(env_specs) == 1:
raise error.DeprecatedEnv(message) raise error.DeprecatedEnv(message)
# If this version doesn't exist and there only exists a default version
elif latest_spec and latest_spec.version is None:
message += "is deprecated. "
message += f"`{latest_spec.name}` only provides the default version. "
message += (
f'You can initialize the environment as `gym.make("{latest_spec.id}")`.'
)
raise error.DeprecatedEnv(message)
# Otherwise we've asked for a version that doesn't exist.
else:
message += f"could not be found. `{name}` provides "
if default_spec: # Process possible versioned environments
message += "a default version"
if versioned_specs:
message += " and "
if versioned_specs:
message += "the versioned environments: [ "
versioned_specs_sorted = sorted(
versioned_specs, key=lambda spec: spec.version
)
message += ", ".join(
map(lambda spec: f"`v{spec.version}`", versioned_specs_sorted)
)
message += " ]"
message += "."
raise error.VersionNotFound(message)
def __getitem__(self, key: str) -> EnvSpec: versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None]
# Get an item from the tree.
# We first parse the components so we can look up the
# appropriate environment ID.
namespace, name, version = parse_env_id(key)
self._assert_version_exists(namespace, name, version)
return self.tree[namespace][name][version] latest_spec = max(versioned_specs, key=lambda spec: spec.version, default=None) # type: ignore
if latest_spec is not None and version > latest_spec.version:
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs)
message += f" It provides versioned environments: [ {version_list_msg} ]."
def __setitem__(self, key: str, value: EnvSpec) -> None: raise error.VersionNotFound(message)
# Insert an item into the tree.
# First we parse the components to get the path
# for insertion.
namespace, name, version = parse_env_id(key)
self.tree[namespace][name][version] = value
# Increase the size
self._length += 1
def __delitem__(self, key: str) -> None: if latest_spec is not None and version < latest_spec.version:
# Delete an item from the tree. raise error.DeprecatedEnv(
# First parse the components so we can follow the f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. "
# path to delete. f"Please use `{latest_spec.id}` instead."
namespace, name, version = parse_env_id(key)
self._assert_version_exists(namespace, name, version)
# Remove the envspec with this version.
self.tree[namespace][name].pop(version)
# Remove the name if it's empty.
if len(self.tree[namespace][name]) == 0:
self.tree[namespace].pop(name)
# Remove the namespace if it's empty.
if len(self.tree[namespace]) == 0:
self.tree.pop(namespace)
# Decrease the size
self._length -= 1
def __contains__(self, key: str) -> bool:
# Check if the tree contains a path for this key.
namespace, name, version = parse_env_id(key)
if (
namespace in self.tree
and name in self.tree[namespace]
and version in self.tree[namespace][name]
):
return True
return False
def __repr__(self) -> str:
# Construct a tree-like representation structure
# so we can easily look at the contents of the tree.
tree_repr = ""
for namespace, names in self.tree.items():
# For each namespace we'll iterate over the names
root = namespace is None
# Insert a separator if we're between depths
if len(tree_repr) > 0:
tree_repr += "\n"
# if this isn't the root we'll display the namespace
if not root:
tree_repr += f"├──{str(namespace)}\n"
# Construct the namespace string so we can print this for
# our children.
namespace = f"{namespace}/" if namespace is not None else ""
for name_idx, (name, versions) in enumerate(names.items()):
# If this isn't the root we'll have to increase our
# depth, i.e., insert some space
if not root:
tree_repr += ""
# If this is the last item make sure we use the
# termination character. Otherwise use the nested
# character.
if name_idx == len(names) - 1:
tree_repr += "└──"
else:
tree_repr += "├──"
# Print the namespace and the name
# and get ready to print the versions.
tree_repr += f"{namespace}{name}: [ "
# Print each version comma separated
for version_idx, version in enumerate(versions.keys()):
if version is not None:
tree_repr += f"v{version}"
else:
tree_repr += ""
if version_idx < len(versions) - 1:
tree_repr += ", "
tree_repr += " ]\n"
return tree_repr
def __len__(self):
# Return the length of the container
return self._length
class EnvRegistry:
"""Register an env by ID. IDs remain stable over time and are
guaranteed to resolve to the same environment dynamics (or be
desupported). The goal is that results on a particular environment
should always be comparable, and not depend on the version of the
code that was running.
"""
def __init__(self):
self.env_specs = EnvSpecTree()
self._ns: Optional[str] = None
def make(self, path: str, **kwargs) -> Env:
if len(kwargs) > 0:
logger.info("Making new env: %s (%s)", path, kwargs)
else:
logger.info("Making new env: %s", path)
# We need to manually parse the ID so we can check
# the version without error-ing out in self.spec
namespace, name, version = parse_env_id(path)
# Get all versions of this spec.
versions = self.env_specs.versions(namespace, name)
# We check what the latest version of the environment is and display
# a warning if the user is attempting to initialize an older version
# or an unversioned one.
latest_versioned_spec = max(
filter(lambda spec: spec.version, versions),
key=lambda spec: cast(int, spec.version),
default=None,
) )
if (
latest_versioned_spec
and version is not None
and version < cast(int, latest_versioned_spec.version)
):
logger.warn(
f"The environment {path} is out of date. You should consider "
f"upgrading to version `v{latest_versioned_spec.version}` "
f"with the environment ID `{latest_versioned_spec.id}`."
)
elif latest_versioned_spec and version is None:
logger.warn(
f"Using the latest versioned environment `{latest_versioned_spec.id}` "
f"instead of the unversioned environment `{path}`"
)
path = latest_versioned_spec.id
# Lookup our path
spec = self.spec(path)
# Construct the environment
return spec.make(**kwargs)
def all(self): def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
return self.env_specs.values() version: list[int] = [
spec_.version
for spec_ in registry.values()
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
]
return max(version, default=None)
def spec(self, path: str) -> EnvSpec:
if ":" in path:
mod_name, _, id = path.partition(":")
try:
importlib.import_module(mod_name)
except ModuleNotFoundError:
raise error.Error(
f"A module ({mod_name}) was specified for the environment but was not found, "
"make sure the package is installed with `pip install` before calling `gym.make()`"
)
else:
id = path
# We can go ahead and return the env_spec.
# The EnvSpecTree will take care of any exceptions.
return self.env_specs[id]
def register(self, id: str, **kwargs) -> None:
spec = EnvSpec(id, **kwargs)
if self._ns is not None:
if spec.namespace is not None:
logger.warn(
f"Custom namespace `{spec.namespace}` is being overridden "
f"by namespace `{self._ns}`. If you are developing a "
"plugin you shouldn't specify a namespace in `register` "
"calls. The namespace is specified through the "
"entry point package metadata."
)
# Replace namespace
spec.namespace = self._ns
def load_env_plugins(entry_point: str = "gym.envs") -> None:
# Load third-party environments
for plugin in metadata.entry_points(group=entry_point):
# Python 3.8 doesn't support plugin.module, plugin.attr
# So we'll have to try and parse this ourselves
try: try:
# Get all versions of this spec. module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
versions = self.env_specs.versions(spec.namespace, spec.name) except AttributeError:
if ":" in plugin.value:
# We raise an error if the user is attempting to initialize an module, attr = plugin.value.split(":", maxsplit=1)
# unversioned environment when a versioned one already exists. else:
latest_versioned_spec = max( module, attr = plugin.value, None
filter(lambda spec: isinstance(spec.version, int), versions), except:
key=lambda spec: cast(int, spec.version), module, attr = None, None
default=None,
)
unversioned_spec = next(
filter(lambda spec: spec.version is None, versions), None
)
# Trying to register an unversioned spec when versioned spec exists
if unversioned_spec and spec.version is not None:
message = (
"Can't register the versioned environment "
f"`{spec.id}` when the unversioned environment "
f"`{unversioned_spec.id}` of the same name already exists."
)
raise error.RegistrationError(message)
elif latest_versioned_spec and spec.version is None:
message = (
f"Can't register the unversioned environment `{spec.id}` "
f"when version `{latest_versioned_spec.version}` "
"of the same name already exists. Note: the default "
"behavior is that the `gym.make` with the unversioned "
"environment will return the latest versioned environment."
)
raise error.RegistrationError(message)
# We might not find this namespace or name in which case
# we should continue to register the environment.
except (error.NamespaceNotFound, error.NameNotFound):
pass
finally: finally:
if spec.id in self.env_specs: if attr is None:
logger.warn(f"Overriding environment {id}") raise error.Error(
self.env_specs[spec.id] = spec f"Gym environment plugin `{module}` must specify a function to execute, not a root module"
)
@contextlib.contextmanager context = namespace(plugin.name)
def namespace(self, ns: str): if plugin.name.startswith("__") and plugin.name.endswith("__"):
self._ns = ns # `__internal__` is an artifact of the plugin system when
yield # the root namespace had an allow-list. The allow-list is now
self._ns = None # removed and plugins can register environments in the root
# namespace with the `__root__` magic key.
if plugin.name == "__root__" or plugin.name == "__internal__":
context = contextlib.nullcontext()
else:
logger.warn(
f"The environment namespace magic key `{plugin.name}` is unsupported. "
"To register an environment at the root namespace you should specify "
"the `__root__` namespace."
)
def __repr__(self): with context:
return repr(self.env_specs) fn = plugin.load()
try:
fn()
# Have a global registry except Exception as e:
registry = EnvRegistry() logger.warn(str(e))
def register(id: str, **kwargs) -> None:
return registry.register(id, **kwargs)
# fmt: off # fmt: off
# Continuous
# ----------------------------------------
@overload @overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ... def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload @overload
@@ -680,63 +314,191 @@ def make(id: Literal[
"Ant-v2" "Ant-v2"
], **kwargs) -> Env[np.ndarray, np.ndarray]: ... ], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
# ----------------------------------------
@overload @overload
def make(id: str, **kwargs) -> Env: ... def make(id: str, **kwargs) -> Env: ...
@overload
def make(id: EnvSpec, **kwargs) -> Env: ...
# fmt: on # fmt: on
def make(id: str, **kwargs) -> Env:
return registry.make(id, **kwargs)
def spec(id: str) -> EnvSpec: # Global registry of environments. Meant to be accessed through `register` and `make`
return registry.spec(id) registry: dict[str, EnvSpec] = dict()
current_namespace: Optional[str] = None
def _check_spec_register(spec: EnvSpec):
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
global registry, current_namespace
if current_namespace is not None:
if spec.namespace is not None:
logger.warn(
f"Custom namespace `{spec.namespace}` is being overridden "
f"by namespace `{current_namespace}`. If you are developing a "
"plugin you shouldn't specify a namespace in `register` "
"calls. The namespace is specified through the "
"entry point package metadata."
)
latest_versioned_spec = max(
(
spec_
for spec_ in registry.values()
if spec_.namespace == spec.namespace
and spec_.name == spec.name
and spec_.version is not None
),
key=lambda spec_: int(spec_.version), # type: ignore
default=None,
)
unversioned_spec = next(
(
spec_
for spec_ in registry.values()
if spec_.namespace == spec.namespace
and spec_.name == spec.name
and spec_.version is None
),
None,
)
if unversioned_spec and spec.version is not None:
raise error.RegistrationError(
"Can't register the versioned environment "
f"`{spec.id}` when the unversioned environment "
f"`{unversioned_spec.id}` of the same name already exists."
)
elif latest_versioned_spec and spec.version is None:
raise error.RegistrationError(
"Can't register the unversioned environment "
f"`{spec.id}` when the versioned environment "
f"`{latest_versioned_spec.id}` of the same name "
f"already exists. Note: the default behavior is "
f"that `gym.make` with the unversioned environment "
f"will return the latest versioned environment"
)
# Public API
@contextlib.contextmanager @contextlib.contextmanager
def namespace(ns: str): def namespace(ns: str):
with registry.namespace(ns): global current_namespace
yield old_namespace = current_namespace
current_namespace = ns
yield
current_namespace = old_namespace
def load_env_plugins(entry_point: str = "gym.envs") -> None: def register(id: str, **kwargs):
# Load third-party environments """
for plugin in metadata.entry_points(group=entry_point): Register an environment with gym. The `id` parameter corresponds to the name of the environment,
# Python 3.8 doesn't support plugin.module, plugin.attr with the syntax as follows:
# So we'll have to try and parse this ourselves `(namespace)/(env_name)-(version)`
try: where `namespace` is optional.
module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
except AttributeError:
if ":" in plugin.value:
module, attr = plugin.value.split(":", maxsplit=1)
else:
module, attr = plugin.value, None
except:
module, attr = None, None
finally:
if attr is None:
raise error.Error(
f"Gym environment plugin `{module}` must specify a function to execute, not a root module"
)
context = namespace(plugin.name) It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
if plugin.name.startswith("__") and plugin.name.endswith("__"): """
# `__internal__` is an artifact of the plugin system when global registry, current_namespace
# the root namespace had an allow-list. The allow-list is now full_id = (current_namespace or "") + id
# removed and plugins can register environments in the root spec = EnvSpec(id=full_id, **kwargs)
# namespace with the `__root__` magic key. _check_spec_register(spec)
if plugin.name == "__root__" or plugin.name == "__internal__": if spec.id in registry:
context = contextlib.nullcontext() logger.warn(f"Overriding environment {id}")
else: registry[spec.id] = spec
logger.warn(
f"The environment namespace magic key `{plugin.name}` is unsupported. "
"To register an environment at the root namespace you should specify "
"the `__root__` namespace."
)
with context:
fn = plugin.load() def make(
try: id: str | EnvSpec,
fn() max_episode_steps: Optional[int] = None,
except Exception as e: autoreset: bool = False,
logger.warn(str(e)) **kwargs,
) -> Env:
"""
Create an environment according to the given ID.
Args:
id: Name of the environment.
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
kwargs: Additional arguments to pass to the environment constructor.
Returns:
An instance of the environment.
"""
if isinstance(id, EnvSpec):
spec_ = id
else:
spec_ = registry.get(id)
ns, name, version = parse_env_id(id)
latest_version = find_highest_version(ns, name)
if (
version is not None
and latest_version is not None
and latest_version > version
):
logger.warn(
f"The environment {id} is out of date. You should consider "
f"upgrading to version `v{latest_version}`."
)
if version is None and latest_version is not None:
version = latest_version
new_env_id = get_env_id(ns, name, version)
spec_ = registry.get(new_env_id)
logger.warn(
f"Using the latest versioned environment `{new_env_id}` "
f"instead of the unversioned environment `{id}`."
)
if spec_ is None:
_check_version_exists(ns, name, version)
raise error.Error(f"No registered env with id: {id}")
_kwargs = spec_.kwargs.copy()
_kwargs.update(kwargs)
# TODO: add a minimal env checker on initialization
if spec_.entry_point is None:
raise error.Error(f"{spec_.id} registered but entry_point is not specified")
elif callable(spec_.entry_point):
cls = spec_.entry_point
else:
# Assume it's a string
cls = load(spec_.entry_point)
env = cls(**_kwargs)
spec_ = copy.deepcopy(spec_)
spec_.kwargs = _kwargs
env.unwrapped.spec = spec_
if spec_.order_enforce:
env = OrderEnforcing(env)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps)
if autoreset:
env = AutoResetWrapper(env)
return env
def spec(env_id: str) -> EnvSpec:
"""
Retrieve the spec for the given environment from the global registry.
"""
spec_ = registry.get(env_id)
if spec_ is None:
ns, name, version = parse_env_id(env_id)
_check_version_exists(ns, name, version)
raise error.Error(f"No registered env with id: {env_id}")
else:
assert isinstance(spec_, EnvSpec)
return spec_

View File

@@ -50,6 +50,6 @@ def should_skip_env_spec_for_tests(spec):
spec_list = [ spec_list = [
spec spec
for spec in sorted(envs.registry.all(), key=lambda x: x.id) for spec in sorted(envs.registry.values(), key=lambda x: x.id)
if spec.entry_point is not None and not should_skip_env_spec_for_tests(spec) if spec.entry_point is not None and not should_skip_env_spec_for_tests(spec)
] ]

View File

@@ -2,9 +2,9 @@ import pytest
import gym import gym
from gym import envs, error from gym import envs, error
from gym.envs import registration from gym.envs import register, registration, registry, spec
from gym.envs.classic_control import cartpole from gym.envs.classic_control import cartpole
from gym.envs.registration import EnvSpec, EnvSpecTree from gym.envs.registration import EnvSpec
class ArgumentEnv(gym.Env): class ArgumentEnv(gym.Env):
@@ -55,8 +55,8 @@ def register_some_envs():
for version in versions: for version in versions:
env_id = f"{namespace}/{versioned_name}-v{version}" env_id = f"{namespace}/{versioned_name}-v{version}"
del gym.envs.registry.env_specs[env_id] del gym.envs.registry[env_id]
del gym.envs.registry.env_specs[f"{namespace}/{unversioned_name}"] del gym.envs.registry[f"{namespace}/{unversioned_name}"]
def test_make(): def test_make():
@@ -83,10 +83,15 @@ def test_make():
], ],
) )
def test_register(env_id, namespace, name, version): def test_register(env_id, namespace, name, version):
envs.register(env_id) register(env_id)
assert gym.envs.spec(env_id).id == env_id assert gym.envs.spec(env_id).id == env_id
assert version in gym.envs.registry.env_specs.tree[namespace][name].keys() full_name = f"{name}"
del gym.envs.registry.env_specs[env_id] if namespace:
full_name = f"{namespace}/{full_name}"
if version is not None:
full_name = f"{full_name}-v{version}"
assert full_name in gym.envs.registry.keys()
del gym.envs.registry[env_id]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -99,7 +104,7 @@ def test_register(env_id, namespace, name, version):
) )
def test_register_error(env_id): def test_register_error(env_id):
with pytest.raises(error.Error, match="Malformed environment ID"): with pytest.raises(error.Error, match="Malformed environment ID"):
envs.register(env_id) register(env_id)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -188,27 +193,23 @@ def test_spec_with_kwargs():
def test_missing_lookup(): def test_missing_lookup():
registry = registration.EnvRegistry() register(id="Test1-v0", entry_point=None)
registry.register(id="Test-v0", entry_point=None) register(id="Test1-v15", entry_point=None)
registry.register(id="Test-v15", entry_point=None) register(id="Test1-v9", entry_point=None)
registry.register(id="Test-v9", entry_point=None) register(id="Other1-v100", entry_point=None)
registry.register(id="Other-v100", entry_point=None)
try: with pytest.raises(error.DeprecatedEnv):
registry.spec("Test-v1") # must match an env name but not the version above spec("Test1-v1")
except error.DeprecatedEnv:
pass
else:
assert False
try: try:
registry.spec("Test-v1000") spec("Test1-v1000")
except error.UnregisteredEnv: except error.UnregisteredEnv:
pass pass
else: else:
assert False assert False
try: try:
registry.spec("Unknown-v1") spec("Unknown1-v1")
except error.UnregisteredEnv: except error.UnregisteredEnv:
pass pass
else: else:
@@ -216,9 +217,8 @@ def test_missing_lookup():
def test_malformed_lookup(): def test_malformed_lookup():
registry = registration.EnvRegistry()
try: try:
registry.spec("“Breakout-v0”") spec("“Breakout-v0”")
except error.Error as e: except error.Error as e:
assert "Malformed environment ID" in f"{e}", f"Unexpected message: {e}" assert "Malformed environment ID" in f"{e}", f"Unexpected message: {e}"
else: else:
@@ -226,99 +226,47 @@ def test_malformed_lookup():
def test_versioned_lookups(): def test_versioned_lookups():
registry = registration.EnvRegistry() register("test/Test2-v5")
registry.register("test/Test-v5")
with pytest.raises(error.VersionNotFound): with pytest.raises(error.VersionNotFound):
registry.spec("test/Test-v9") spec("test/Test2-v9")
with pytest.raises(error.DeprecatedEnv): with pytest.raises(error.DeprecatedEnv):
registry.spec("test/Test-v4") spec("test/Test2-v4")
assert registry.spec("test/Test-v5") assert spec("test/Test2-v5")
def test_default_lookups(): def test_default_lookups():
registry = registration.EnvRegistry() register("test/Test3")
registry.register("test/Test")
with pytest.raises(error.DeprecatedEnv): with pytest.raises(error.DeprecatedEnv):
registry.spec("test/Test-v0") spec("test/Test3-v0")
# Lookup default # Lookup default
registry.spec("test/Test") spec("test/Test3")
def test_env_spec_tree():
spec_tree = EnvSpecTree()
# Add with namespace
spec = EnvSpec("test/Test-v0")
spec_tree["test/Test-v0"] = spec
assert spec_tree.tree.keys() == {"test"}
assert spec_tree.tree["test"].keys() == {"Test"}
assert spec_tree.tree["test"]["Test"].keys() == {0}
assert spec_tree.tree["test"]["Test"][0] == spec
assert spec_tree["test/Test-v0"] == spec
# Add without namespace
spec = EnvSpec("Test-v0")
spec_tree["Test-v0"] = spec
assert spec_tree.tree.keys() == {"test", None}
assert spec_tree.tree[None].keys() == {"Test"}
assert spec_tree.tree[None]["Test"].keys() == {0}
assert spec_tree.tree[None]["Test"][0] == spec
# Delete last version deletes entire subtree
del spec_tree["test/Test-v0"]
assert spec_tree.tree.keys() == {None}
# Append second version for same name
spec_tree["Test-v1"] = EnvSpec("Test-v1")
assert spec_tree.tree.keys() == {None}
assert spec_tree.tree[None].keys() == {"Test"}
assert spec_tree.tree[None]["Test"].keys() == {0, 1}
# Deleting one version leaves other
del spec_tree["Test-v0"]
assert spec_tree.tree.keys() == {None}
assert spec_tree.tree[None].keys() == {"Test"}
assert spec_tree.tree[None]["Test"].keys() == {1}
# Add without version
myenv = "MyAwesomeEnv"
spec = EnvSpec(myenv)
spec_tree[myenv] = spec
assert spec_tree.tree.keys() == {None}
assert myenv in spec_tree.tree[None].keys()
assert spec_tree.tree[None][myenv].keys() == {None}
assert spec_tree.tree[None][myenv][None] == spec
assert spec_tree.__repr__() == "├──Test: [ v1 ]\n" + f"└──{myenv}: [ ]\n"
def test_register_versioned_unversioned(): def test_register_versioned_unversioned():
# Register versioned then unversioned # Register versioned then unversioned
versioned_env = "Test/MyEnv-v0" versioned_env = "Test/MyEnv-v0"
envs.register(versioned_env) register(versioned_env)
assert gym.envs.spec(versioned_env).id == versioned_env assert gym.envs.spec(versioned_env).id == versioned_env
unversioned_env = "Test/MyEnv" unversioned_env = "Test/MyEnv"
with pytest.raises(error.RegistrationError): with pytest.raises(error.RegistrationError):
envs.register(unversioned_env) register(unversioned_env)
# Clean everything # Clean everything
del gym.envs.registry.env_specs[versioned_env] del gym.envs.registry[versioned_env]
# Register unversioned then versioned # Register unversioned then versioned
with pytest.warns(UserWarning): register(unversioned_env)
envs.register(unversioned_env)
assert gym.envs.spec(unversioned_env).id == unversioned_env assert gym.envs.spec(unversioned_env).id == unversioned_env
with pytest.raises(error.RegistrationError): with pytest.raises(error.RegistrationError):
envs.register(versioned_env) register(versioned_env)
# Clean everything # Clean everything
envs_list = [versioned_env, unversioned_env] del gym.envs.registry[unversioned_env]
for env in envs_list:
del gym.envs.registry.env_specs[env]
def test_return_latest_versioned_env(register_some_envs): def test_return_latest_versioned_env(register_some_envs):