Files
Gymnasium/gym/envs/registration.py
Andrea PIERRÉ e913bc81b8 Improve pre-commit workflow (#2602)
* feat: add `isort` to `pre-commit`

* ci: skip `__init__.py` file for `isort`

* ci: make `isort` mandatory in lint pipeline

* docs: add a section on Git hooks

* ci: check isort diff

* fix: isort from master branch

* docs: add pre-commit badge

* ci: update black + bandit versions

* feat: add PR template

* refactor: PR template

* ci: remove bandit

* docs: add Black badge

* ci: try to remove all `|| true` statements

* ci: remove lint_python job

- Remove `lint_python` CI job
- Move `pyupgrade` job to `pre-commit` workflow

* fix: avoid messing with typing

* docs: add a note on running `pre-commit` manually

* ci: apply `pre-commit` to the whole codebase
2022-03-31 15:50:38 -04:00

729 lines
27 KiB
Python

from __future__ import annotations
import contextlib
import copy
import difflib
import importlib
import importlib.util
import re
import sys
from typing import (
Any,
Callable,
Generator,
Optional,
Sequence,
SupportsFloat,
Tuple,
Type,
Union,
cast,
overload,
)
if sys.version_info < (3, 10):
import importlib_metadata as metadata # type: ignore
else:
import importlib.metadata as metadata
from collections import defaultdict
from collections.abc import MutableMapping
from dataclasses import InitVar, dataclass, field
import numpy as np
from gym import Env, error, logger
from gym.envs.__relocated__ import internal_env_relocation_map
# `typing.Literal` is imported directly on Python 3.8+.
if sys.version_info >= (3, 8):
    from typing import Literal
else:
    # Fallback for older interpreters: a placeholder class whose subscription
    # (e.g. `Literal["CartPole-v1"]`) goes through `__class_getitem__` and
    # simply evaluates to `Any`.
    class Literal(str):
        def __class_getitem__(cls, item):
            return Any
# Pattern for environment IDs of the form `[namespace/]name[-v<version>]`.
# The `namespace/` prefix and the `-v<digits>` suffix are both optional; the
# name group is lazy so that a trailing `-v<digits>` is captured as `version`
# instead of being absorbed into `name`.
ENV_ID_RE: re.Pattern = re.compile(
    r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
)
def load(name: str) -> Type:
    """Import and return the object referenced by a `module:attribute` string.

    Args:
        name: A reference of the form `package.module:attribute`.

    Returns:
        The attribute looked up on the imported module.
    """
    module_path, attribute = name.split(":")
    return getattr(importlib.import_module(module_path), attribute)
def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
    """Parse an environment ID string into its components.

    The accepted format is `[namespace/](env-name)[-v(version)]`, as defined by
    `ENV_ID_RE`. This format is true today, but it's *not* an official spec.

    Args:
        id: The environment ID to parse.

    Returns:
        A `(namespace, name, version)` tuple. `namespace` and `version` are
        `None` when they are absent from the ID; a present version is
        converted to an `int`.

    Raises:
        error.Error: If `id` does not match `ENV_ID_RE`.
    """
    match = ENV_ID_RE.fullmatch(id)
    if not match:
        # Note the trailing space after the first sentence: the two literals
        # are concatenated into a single message.
        raise error.Error(
            f"Malformed environment ID: {id}. "
            f"(Currently all IDs must be of the form {ENV_ID_RE}.)"
        )
    namespace, name, version = match.group("namespace", "name", "version")
    if version is not None:
        version = int(version)
    return namespace, name, version
@dataclass
class EnvSpec:
    """Specification describing how to build one registered environment.

    Used to register the parameters for official evaluations.

    Args:
        id_requested: The official environment ID
        entry_point: The Python entrypoint of the environment class (e.g. module.name:Class)
        reward_threshold: The reward threshold before the task is considered solved
        nondeterministic: Whether this environment is non-deterministic even after seeding
        max_episode_steps: The maximum number of steps that an episode can consist of
        order_enforce: Whether to wrap the environment in an orderEnforcing wrapper
        kwargs: The kwargs to pass to the environment class
    """

    id_requested: InitVar[str]
    entry_point: Optional[Union[Callable, str]] = field(default=None)
    reward_threshold: Optional[int] = field(default=None)
    nondeterministic: bool = field(default=False)
    max_episode_steps: Optional[int] = field(default=None)
    order_enforce: bool = field(default=True)
    kwargs: dict = field(default_factory=dict)
    # Derived from `id_requested` in `__post_init__`, hence excluded from `__init__`.
    namespace: Optional[str] = field(init=False)
    name: str = field(init=False)
    version: Optional[int] = field(init=False)

    def __post_init__(self, id_requested):
        # Split the requested ID into the three stored components.
        parsed = parse_env_id(id_requested)
        self.namespace, self.name, self.version = parsed

    @property
    def id(self) -> str:
        """
        The environment ID rebuilt from `namespace`, `name` and `version`.

        `id_requested` is an InitVar, so it is only available while the spec is
        being initialized. Computing the ID on demand means that updating any of
        the parsed fields is immediately reflected here.
        """
        prefix = f"{self.namespace}/" if self.namespace is not None else ""
        suffix = f"-v{self.version}" if self.version is not None else ""
        return f"{prefix}{self.name}{suffix}"

    def make(self, **kwargs) -> Env:
        """Instantiates an instance of the environment with appropriate kwargs"""
        if self.entry_point is None:
            raise error.Error(
                f"Attempting to make deprecated env {self.id}. "
                "(HINT: is there a newer registered version of this env?)"
            )

        # Call-time kwargs take precedence over the registered ones.
        merged_kwargs = {**self.kwargs, **kwargs}

        if callable(self.entry_point):
            factory = self.entry_point
        else:
            factory = load(self.entry_point)
        env = factory(**merged_kwargs)

        # Make the environment aware of which spec it came from. A copy is
        # attached so the registered spec keeps its original kwargs.
        attached_spec = copy.deepcopy(self)
        attached_spec.kwargs = merged_kwargs
        env.unwrapped.spec = attached_spec

        if self.order_enforce:
            from gym.wrappers.order_enforcing import OrderEnforcing

            env = OrderEnforcing(env)

        assert env.spec is not None, "expected spec to be set to the unwrapped env."
        step_limit = env.spec.max_episode_steps
        if step_limit is not None:
            from gym.wrappers.time_limit import TimeLimit

            env = TimeLimit(env, max_episode_steps=step_limit)

        return env
class EnvSpecTree(MutableMapping):
    """
    The EnvSpecTree provides a dict-like mapping object
    from environment IDs to specifications.

    The EnvSpecTree is backed by a tree-like structure.
    The environment ID format is [{namespace}/]{name}-v{version}.

    The tree has multiple root nodes corresponding to a namespace.
    The children of a namespace node corresponds to the environment name.
    Furthermore, each name has a mapping from versions to specifications.
    It looks like the following,

    {
        None: {
            MountainCar: {
                0: EnvSpec(...),
                1: EnvSpec(...)
            }
        },
        ALE: {
            Tetris: {
                5: EnvSpec(...)
            }
        }
    }

    The tree-structure isn't user-facing and the EnvSpecTree will act
    like a dictionary. For example, to lookup an environment ID:

    ```
    specs = EnvSpecTree()
    specs["My/Env-v0"] = EnvSpec(...)
    assert specs["My/Env-v0"] == EnvSpec(...)
    assert specs.tree["My"]["Env"][0] == specs["My/Env-v0"]
    ```

    Note that versions are stored under their parsed value: an `int`, or
    `None` for an unversioned ID.
    """

    def __init__(self):
        # Initialize the tree as a nested sequence of defaultdicts:
        # namespace -> name -> version -> EnvSpec
        self.tree = defaultdict(lambda: defaultdict(dict))
        # Number of stored specs, maintained by __setitem__ / __delitem__.
        self._length = 0

    def versions(self, namespace: Optional[str], name: str) -> Sequence[EnvSpec]:
        """
        Returns the versions associated with a namespace and name.

        Note: This function takes into account environment relocations.
        For example, `versions(None, "Breakout")` will return,
        ```
        [
            EnvSpec(namespace=None, name="Breakout", version=0),
            EnvSpec(namespace=None, name="Breakout", version=4),
            EnvSpec(namespace="ALE", name="Breakout", version=5)
        ]
        ```
        Notice the last environment which is outside of the requested namespace.
        This only applies to environments which are in the `internal_env_relocation_map`.
        See `gym/envs/__relocated__.py` for more info.

        Raises:
            error.NamespaceNotFound: If `namespace` isn't in the tree.
            error.NameNotFound: If `name` isn't registered under `namespace`.
        """
        self._assert_name_exists(namespace, name)

        versions = list(self.tree[namespace][name].values())

        if namespace is None and name in internal_env_relocation_map:
            relocated_namespace, _ = internal_env_relocation_map[name]
            try:
                self._assert_name_exists(relocated_namespace, name)
                versions += list(self.tree[relocated_namespace][name].values())
            except error.UnregisteredEnv:
                # The relocated environment isn't registered; only the
                # versions from the requested namespace are returned.
                pass

        return versions

    def names(self, namespace: Optional[str]) -> Sequence[str]:
        """
        Returns all the environment names associated with a namespace.

        Raises:
            error.NamespaceNotFound: If `namespace` isn't in the tree.
        """
        self._assert_namespace_exists(namespace)
        return list(self.tree[namespace].keys())

    def namespaces(self) -> Sequence[str]:
        """
        Returns all the namespaces contained in the tree.

        Falsy keys (i.e. the root namespace `None`) are excluded.
        """
        return list(filter(None, self.tree.keys()))

    def __iter__(self) -> Generator[str, None, None]:
        # Iterate through the structure and generate the IDs contained in the tree.
        for namespace, names in self.tree.items():
            for name, versions in names.items():
                for version, spec in versions.items():
                    # The path through the tree must agree with the spec itself.
                    assert spec.namespace == namespace
                    assert spec.name == name
                    assert spec.version == version
                    yield spec.id

    def _assert_namespace_exists(self, namespace: Optional[str]) -> None:
        """Raise `error.NamespaceNotFound` unless `namespace` is in the tree."""
        if namespace in self.tree:
            return

        message = f"Namespace `{namespace}` does not exist."
        if namespace:
            suggestions = difflib.get_close_matches(namespace, self.namespaces(), n=1)
            if suggestions:
                message += f" Did you mean: `{suggestions[0]}`?"
            else:
                message += f" Have you installed the proper package for `{namespace}`?"

        raise error.NamespaceNotFound(message)

    def _assert_name_exists(self, namespace: Optional[str], name: str) -> None:
        """Raise unless `name` is registered under `namespace`.

        Raises:
            error.NamespaceNotFound: If `namespace` isn't in the tree.
            error.NameNotFound: If `name` isn't registered under `namespace`.
        """
        self._assert_namespace_exists(namespace)
        if name in self.tree[namespace]:
            return

        if namespace is None and name in internal_env_relocation_map:
            relocated_namespace, relocated_package = internal_env_relocation_map[name]
            message = f"The environment `{name}` has been moved out of Gym to the package `{relocated_package}`."

            # Check if the package is installed
            # If not instruct the user to install the package and then how to instantiate the env
            if importlib.util.find_spec(relocated_package) is None:
                message += f" Please install the package via `pip install {relocated_package}`."

            # Otherwise the user should be able to instantiate the environment directly
            if namespace != relocated_namespace:
                message += f" You can instantiate the new namespaced environment as `{relocated_namespace}/{name}`."
        # If the environment hasn't been relocated we'll construct a generic error message
        else:
            message = f"Environment `{name}` doesn't exist"
            if namespace is not None:
                message += f" in namespace `{namespace}`"
            message += "."
            suggestions = difflib.get_close_matches(name, self.names(namespace), n=1)
            if suggestions:
                message += f" Did you mean: `{suggestions[0]}`?"

        # Throw the error
        raise error.NameNotFound(message)

    def _assert_version_exists(
        self, namespace: Optional[str], name: str, version: Optional[int]
    ):
        """Raise unless `version` of `namespace`/`name` is registered.

        Raises:
            error.NamespaceNotFound: If `namespace` isn't in the tree.
            error.NameNotFound: If `name` isn't registered under `namespace`.
            error.DeprecatedEnv: If a newer version exists, or only the
                default (unversioned) environment exists.
            error.VersionNotFound: Otherwise.
        """
        self._assert_name_exists(namespace, name)
        if version in self.tree[namespace][name]:
            return

        # Construct the appropriate exception.
        # If the version is less than the latest version
        # then we throw an error.DeprecatedEnv exception.
        # Otherwise we throw error.VersionNotFound.
        versions = self.tree[namespace][name]
        assert len(versions) > 0

        versioned_specs = [
            spec for spec in versions.values() if isinstance(spec.version, int)
        ]
        default_spec = versions[None] if None in versions else None
        assert len(versioned_specs) > 0 or default_spec is not None
        # Falls back to the default spec when no versioned spec exists.
        latest_spec = max(
            versioned_specs, key=lambda spec: spec.version, default=default_spec
        )

        if version is not None:
            message = f"Environment version `v{version}` for `"
        else:
            message = "The default version for `"
        if namespace is not None:
            message += f"{namespace}/"
        message += f"{name}` "

        # If this version doesn't exist but there exists a newer non-default
        # version we should warn the user this version is deprecated.
        if (
            latest_spec
            and latest_spec.version is not None
            and version is not None
            and version < latest_spec.version
        ):
            message += "is deprecated. "
            message += f"Please use the latest version `v{latest_spec.version}`."
            raise error.DeprecatedEnv(message)
        # If this version doesn't exist and there only exists a default version
        elif latest_spec and latest_spec.version is None:
            message += "is deprecated. "
            message += f"`{latest_spec.name}` only provides the default version. "
            message += (
                f'You can initialize the environment as `gym.make("{latest_spec.id}")`.'
            )
            raise error.DeprecatedEnv(message)
        # Otherwise we've asked for a version that doesn't exist.
        else:
            message += f"could not be found. `{name}` provides "
            if default_spec:
                message += "a default version"
                if versioned_specs:
                    message += " and "
            if versioned_specs:
                message += "the versioned environments: [ "
                versioned_specs_sorted = sorted(
                    versioned_specs, key=lambda spec: spec.version
                )
                message += ", ".join(
                    f"`v{spec.version}`" for spec in versioned_specs_sorted
                )
                message += " ]"
            message += "."
            raise error.VersionNotFound(message)

    def __getitem__(self, key: str) -> EnvSpec:
        # Get an item from the tree.
        # We first parse the components so we can look up the
        # appropriate environment ID.
        namespace, name, version = parse_env_id(key)
        self._assert_version_exists(namespace, name, version)

        return self.tree[namespace][name][version]

    def __setitem__(self, key: str, value: EnvSpec) -> None:
        # Insert an item into the tree.
        # First we parse the components to get the path
        # for insertion.
        namespace, name, version = parse_env_id(key)
        versions = self.tree[namespace][name]
        # Only grow the size for a new ID; overwriting an existing spec
        # must leave `len()` unchanged.
        if version not in versions:
            self._length += 1
        versions[version] = value

    def __delitem__(self, key: str) -> None:
        # Delete an item from the tree.
        # First parse the components so we can follow the
        # path to delete.
        namespace, name, version = parse_env_id(key)
        self._assert_version_exists(namespace, name, version)

        # Remove the envspec with this version.
        self.tree[namespace][name].pop(version)
        # Remove the name if it's empty.
        if len(self.tree[namespace][name]) == 0:
            self.tree[namespace].pop(name)
        # Remove the namespace if it's empty.
        if len(self.tree[namespace]) == 0:
            self.tree.pop(namespace)
        # Decrease the size
        self._length -= 1

    def __contains__(self, key: str) -> bool:
        # Check if the tree contains a path for this key.
        # Membership tests short-circuit so that the nested defaultdicts
        # never create entries for missing namespaces or names.
        namespace, name, version = parse_env_id(key)
        return (
            namespace in self.tree
            and name in self.tree[namespace]
            and version in self.tree[namespace][name]
        )

    def __repr__(self) -> str:
        # Construct a tree-like representation structure
        # so we can easily look at the contents of the tree.
        parts = []
        for namespace, names in self.tree.items():
            # For each namespace we'll iterate over the names
            root = namespace is None
            # Insert a separator if we're between depths
            if parts:
                parts.append("\n")
            # if this isn't the root we'll display the namespace
            if not root:
                parts.append(f"├──{str(namespace)}\n")

            # Construct the namespace string so we can print this for
            # our children.
            prefix = f"{namespace}/" if namespace is not None else ""
            last_idx = len(names) - 1
            for name_idx, (name, versions) in enumerate(names.items()):
                # If this isn't the root we'll have to increase our
                # depth, i.e., indent the entry under its namespace.
                if not root:
                    parts.append("│   ")
                # If this is the last item make sure we use the
                # termination character. Otherwise use the nested
                # character.
                parts.append("└──" if name_idx == last_idx else "├──")
                # Print the namespace and the name followed by each
                # version, comma separated. An unversioned entry
                # contributes an empty string.
                version_list = ", ".join(
                    f"v{version}" if version is not None else ""
                    for version in versions
                )
                parts.append(f"{prefix}{name}: [ {version_list} ]\n")

        return "".join(parts)

    def __len__(self) -> int:
        # Return the length of the container
        return self._length
class EnvRegistry:
    """Register an env by ID. IDs remain stable over time and are
    guaranteed to resolve to the same environment dynamics (or be
    desupported). The goal is that results on a particular environment
    should always be comparable, and not depend on the version of the
    code that was running.
    """

    def __init__(self):
        self.env_specs = EnvSpecTree()
        # Namespace forced onto every registration while inside `namespace()`.
        self._ns: Optional[str] = None

    @staticmethod
    def _import_module_prefix(path: str) -> str:
        """Import the module of a `module:env_id` path and return the bare env ID.

        Paths without a `:` are returned unchanged and nothing is imported.

        Raises:
            error.Error: If the module prefix can't be found.
        """
        mod_name, separator, env_id = path.partition(":")
        if not separator:
            return path
        try:
            importlib.import_module(mod_name)
        except ModuleNotFoundError as e:
            raise error.Error(
                f"A module ({mod_name}) was specified for the environment but was not found, "
                "make sure the package is installed with `pip install` before calling `gym.make()`"
            ) from e
        return env_id

    def make(self, path: str, **kwargs) -> Env:
        """Create the environment identified by `path`.

        Args:
            path: An environment ID, optionally prefixed by `module:` in which
                case the module is imported first (so it can register the env).
            **kwargs: Overrides for the kwargs stored in the spec.

        Returns:
            The environment built by the matching spec. An unversioned ID
            resolves to the latest versioned environment when one exists.
        """
        if len(kwargs) > 0:
            logger.info("Making new env: %s (%s)", path, kwargs)
        else:
            logger.info("Making new env: %s", path)

        # Import the optional module prefix *before* parsing, otherwise
        # `module:Env-v0` would be looked up as an environment named `module:Env`.
        env_id = self._import_module_prefix(path)

        # We need to manually parse the ID so we can check
        # the version without error-ing out in self.spec
        namespace, name, version = parse_env_id(env_id)

        # Get all versions of this spec.
        versions = self.env_specs.versions(namespace, name)

        # We check what the latest version of the environment is and display
        # a warning if the user is attempting to initialize an older version
        # or an unversioned one. `is not None` (rather than truthiness) keeps
        # version 0 in the running.
        latest_versioned_spec = max(
            (spec for spec in versions if spec.version is not None),
            key=lambda spec: cast(int, spec.version),
            default=None,
        )
        if latest_versioned_spec is not None:
            if version is None:
                logger.warn(
                    f"Using the latest versioned environment `{latest_versioned_spec.id}` "
                    f"instead of the unversioned environment `{path}`"
                )
                env_id = latest_versioned_spec.id
            elif version < cast(int, latest_versioned_spec.version):
                logger.warn(
                    f"The environment {path} is out of date. You should consider "
                    f"upgrading to version `v{latest_versioned_spec.version}` "
                    f"with the environment ID `{latest_versioned_spec.id}`."
                )

        # Lookup our ID; the EnvSpecTree will take care of any exceptions.
        spec = self.env_specs[env_id]
        # Construct the environment
        return spec.make(**kwargs)

    def all(self):
        """Return a view over every registered `EnvSpec`."""
        return self.env_specs.values()

    def spec(self, path: str) -> EnvSpec:
        """Return the spec registered for `path`.

        Args:
            path: An environment ID, optionally prefixed by `module:` in which
                case the module is imported first.

        Raises:
            error.Error: If the module prefix can't be found.
        """
        env_id = self._import_module_prefix(path)
        # We can go ahead and return the env_spec.
        # The EnvSpecTree will take care of any exceptions.
        return self.env_specs[env_id]

    def register(self, id: str, **kwargs) -> None:
        """Register a new environment spec.

        Args:
            id: The environment ID to register.
            **kwargs: Remaining `EnvSpec` fields (e.g. `entry_point`).

        Raises:
            error.RegistrationError: If a versioned environment is registered
                while an unversioned one of the same name exists, or vice
                versa. Nothing is registered in that case.
        """
        spec = EnvSpec(id, **kwargs)

        if self._ns is not None:
            if spec.namespace is not None:
                logger.warn(
                    f"Custom namespace `{spec.namespace}` is being overridden "
                    f"by namespace `{self._ns}`. If you are developing a "
                    "plugin you shouldn't specify a namespace in `register` "
                    "calls. The namespace is specified through the "
                    "entry point package metadata."
                )
            # Replace namespace
            spec.namespace = self._ns

        try:
            # Get all versions of this spec.
            versions = self.env_specs.versions(spec.namespace, spec.name)
        except (error.NamespaceNotFound, error.NameNotFound):
            # We might not find this namespace or name in which case
            # there is nothing to conflict with.
            versions = []

        latest_versioned_spec = max(
            (known for known in versions if isinstance(known.version, int)),
            key=lambda known: cast(int, known.version),
            default=None,
        )
        unversioned_spec = next(
            (known for known in versions if known.version is None), None
        )

        # Trying to register a versioned spec when an unversioned spec exists
        if unversioned_spec is not None and spec.version is not None:
            raise error.RegistrationError(
                "Can't register the versioned environment "
                f"`{spec.id}` when the unversioned environment "
                f"`{unversioned_spec.id}` of the same name already exists."
            )
        # Trying to register an unversioned spec when a versioned spec exists
        if latest_versioned_spec is not None and spec.version is None:
            raise error.RegistrationError(
                f"Can't register the unversioned environment `{spec.id}` "
                f"when version `{latest_versioned_spec.version}` "
                "of the same name already exists. Note: the default "
                "behavior is that the `gym.make` with the unversioned "
                "environment will return the latest versioned environment."
            )

        # Only reached when the checks above passed, so a rejected
        # registration never ends up in the tree.
        if spec.id in self.env_specs:
            logger.warn(f"Overriding environment {spec.id}")
        self.env_specs[spec.id] = spec

    @contextlib.contextmanager
    def namespace(self, ns: str):
        """Force every registration made inside the `with` block into `ns`.

        The previously active namespace is restored on exit, including when
        the body raises.
        """
        previous = self._ns
        self._ns = ns
        try:
            yield
        finally:
            self._ns = previous

    def __repr__(self):
        return repr(self.env_specs)
# Have a global registry: the module-level `register`, `make`, `spec` and
# `namespace` helpers below all delegate to this single instance.
registry = EnvRegistry()
def register(id: str, **kwargs) -> None:
    """Register an environment with the global `registry`.

    Args:
        id: The environment ID to register.
        **kwargs: Forwarded unchanged to `registry.register`.
    """
    registry.register(id, **kwargs)
# The `@overload` declarations below are typing-only: they narrow the type
# parameters of the returned `Env` for known environment IDs. The final,
# undecorated `make` is the implementation that actually runs.

# fmt: off
# Classic control
# ----------------------------------------
@overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...

# Box2d
# ----------------------------------------
@overload
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["CarRacing-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...

# Toy Text
# ----------------------------------------
@overload
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...

# Mujoco
# ----------------------------------------
@overload
def make(id: Literal[
    "Reacher-v2",
    "Pusher-v2",
    "Thrower-v2",
    "Striker-v2",
    "InvertedPendulum-v2",
    "InvertedDoublePendulum-v2",
    "HalfCheetah-v2", "HalfCheetah-v3",
    "Hopper-v2", "Hopper-v3",
    "Swimmer-v2", "Swimmer-v3",
    "Walker2d-v2", "Walker2d-v3",
    "Ant-v2"
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...

# Fallback for any other ID: no narrowing of the `Env` type parameters.
# ----------------------------------------
@overload
def make(id: str, **kwargs) -> Env: ...
# fmt: on


def make(id: str, **kwargs) -> Env:
    """Create an environment from its ID by delegating to the global `registry`.

    Args:
        id: The environment ID to instantiate.
        **kwargs: Forwarded unchanged to `registry.make`.

    Returns:
        The environment returned by `registry.make`.
    """
    return registry.make(id, **kwargs)
def spec(id: str) -> EnvSpec:
    """Look up the `EnvSpec` for `id` in the global `registry`.

    Args:
        id: The environment ID to look up.

    Returns:
        The spec returned by `registry.spec`.
    """
    found = registry.spec(id)
    return found
@contextlib.contextmanager
def namespace(ns: str):
    """Context manager that enters `registry.namespace(ns)` on the global `registry`.

    Args:
        ns: The namespace passed on to `registry.namespace`.
    """
    scope = registry.namespace(ns)
    with scope:
        yield
def load_env_plugins(entry_point: str = "gym.envs") -> None:
    """Load third-party environments advertised through package entry points.

    Every entry point in the `entry_point` group is loaded and called inside
    a namespace named after the entry point. The magic names `__root__` and
    `__internal__` register into the root namespace instead.

    Args:
        entry_point: The entry point group to scan for plugins.

    Raises:
        error.Error: If a plugin references a module without a function to call.
    """
    # Load third-party environments
    for plugin in metadata.entry_points(group=entry_point):
        # Python 3.8 doesn't support plugin.module, plugin.attr
        # So we'll have to try and parse this ourselves
        try:
            module, attr = plugin.module, plugin.attr  # type: ignore  ## error: Cannot access member "attr" for type "EntryPoint"
        except AttributeError:
            if ":" in plugin.value:
                module, attr = plugin.value.split(":", maxsplit=1)
            else:
                module, attr = plugin.value, None
        except Exception:
            # Any other failure while inspecting the entry point is reported
            # below as a plugin without a callable. Unlike a bare `except:`,
            # this lets KeyboardInterrupt / SystemExit propagate.
            module, attr = None, None

        # Validate after (not in a `finally` of) the try block so that a
        # propagating exception is never masked and `attr` is always bound.
        if attr is None:
            raise error.Error(
                f"Gym environment plugin `{module}` must specify a function to execute, not a root module"
            )

        context = namespace(plugin.name)
        if plugin.name.startswith("__") and plugin.name.endswith("__"):
            # `__internal__` is an artifact of the plugin system when
            # the root namespace had an allow-list. The allow-list is now
            # removed and plugins can register environments in the root
            # namespace with the `__root__` magic key.
            if plugin.name == "__root__" or plugin.name == "__internal__":
                context = contextlib.nullcontext()
            else:
                logger.warn(
                    f"The environment namespace magic key `{plugin.name}` is unsupported. "
                    "To register an environment at the root namespace you should specify "
                    "the `__root__` namespace."
                )

        with context:
            fn = plugin.load()
            try:
                fn()
            except Exception as e:
                # A failing plugin shouldn't prevent the remaining ones
                # from loading; report it and carry on.
                logger.warn(str(e))