mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 09:55:39 +00:00
Rewriting of the registration mechanism (#2748)
* First version of the new registration * Almost done * Hopefully final commit * Minor fixes * Missing error * Type fixes * Type fixes * Add some type hinting stuff * Fix an error? * Fix literal import * Add a comment * Add some docstrings Remove old tests * Add some docstrings, rename helper functions * Rename a function * Registration check fix * Consistently use `register` instead of `envs.register` in tests * Fix the malformed registration error message to not use a write-only format * Change an error back to a warning when double-registering an environment
This commit is contained in:
committed by
GitHub
parent
0a5f543d6a
commit
00a60e6cc8
@@ -106,8 +106,8 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
integer seed right after initialization and then never again.
|
integer seed right after initialization and then never again.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed (int or None): The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom). However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset.
|
seed (int or None): The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
|
||||||
If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action.
|
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action.
|
||||||
return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step`
|
return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step`
|
||||||
options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment)
|
options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment)
|
||||||
|
|
||||||
|
@@ -7,34 +7,29 @@ import importlib
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Generator,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
SupportsFloat,
|
SupportsFloat,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gym.envs.__relocated__ import internal_env_relocation_map
|
||||||
|
from gym.wrappers import AutoResetWrapper, OrderEnforcing, TimeLimit
|
||||||
|
|
||||||
if sys.version_info < (3, 10):
|
if sys.version_info < (3, 10):
|
||||||
import importlib_metadata as metadata # type: ignore
|
import importlib_metadata as metadata # type: ignore
|
||||||
else:
|
else:
|
||||||
import importlib.metadata as metadata
|
import importlib.metadata as metadata
|
||||||
|
|
||||||
from collections import defaultdict
|
|
||||||
from collections.abc import MutableMapping
|
|
||||||
from dataclasses import InitVar, dataclass, field
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from gym import Env, error, logger
|
|
||||||
from gym.envs.__relocated__ import internal_env_relocation_map
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 8):
|
if sys.version_info >= (3, 8):
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
else:
|
else:
|
||||||
@@ -44,6 +39,8 @@ else:
|
|||||||
return Any
|
return Any
|
||||||
|
|
||||||
|
|
||||||
|
from gym import Env, error, logger
|
||||||
|
|
||||||
ENV_ID_RE: re.Pattern = re.compile(
|
ENV_ID_RE: re.Pattern = re.compile(
|
||||||
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
|
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
|
||||||
)
|
)
|
||||||
@@ -60,16 +57,16 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
|||||||
"""Parse environment ID string format.
|
"""Parse environment ID string format.
|
||||||
|
|
||||||
This format is true today, but it's *not* an official spec.
|
This format is true today, but it's *not* an official spec.
|
||||||
[username/](env-name)-v(version) env-name is group 1, version is group 2
|
[namespace/](env-name)-v(version) env-name is group 1, version is group 2
|
||||||
|
|
||||||
2016-10-31: We're experimentally expanding the environment ID format
|
2016-10-31: We're experimentally expanding the environment ID format
|
||||||
to include an optional username.
|
to include an optional namespace.
|
||||||
"""
|
"""
|
||||||
match = ENV_ID_RE.fullmatch(id)
|
match = ENV_ID_RE.fullmatch(id)
|
||||||
if not match:
|
if not match:
|
||||||
raise error.Error(
|
raise error.Error(
|
||||||
f"Malformed environment ID: {id}."
|
f"Malformed environment ID: {id}."
|
||||||
f"(Currently all IDs must be of the form {ENV_ID_RE}.)"
|
f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
|
||||||
)
|
)
|
||||||
namespace, name, version = match.group("namespace", "name", "version")
|
namespace, name, version = match.group("namespace", "name", "version")
|
||||||
if version is not None:
|
if version is not None:
|
||||||
@@ -78,24 +75,21 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
|||||||
return namespace, name, version
|
return namespace, name, version
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_id(ns: Optional[str], name: str, version: Optional[int]):
|
||||||
|
"""Get the full env ID given a name and (optional) version and namespace.
|
||||||
|
Inverse of parse_env_id."""
|
||||||
|
|
||||||
|
full_name = name
|
||||||
|
if version is not None:
|
||||||
|
full_name += f"-v{version}"
|
||||||
|
if ns is not None:
|
||||||
|
full_name = ns + "/" + full_name
|
||||||
|
return full_name
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnvSpec:
|
class EnvSpec:
|
||||||
"""A specification for a particular instance of the environment. Used
|
id: str
|
||||||
to register the parameters for official evaluations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
id_requested: The official environment ID
|
|
||||||
entry_point: The Python entrypoint of the environment class (e.g. module.name:Class)
|
|
||||||
reward_threshold: The reward threshold before the task is considered solved
|
|
||||||
nondeterministic: Whether this environment is non-deterministic even after seeding
|
|
||||||
max_episode_steps: The maximum number of steps that an episode can consist of
|
|
||||||
order_enforce: Whether to wrap the environment in an orderEnforcing wrapper
|
|
||||||
autoreset: Whether the environment should automatically reset when it reaches the done state
|
|
||||||
kwargs: The kwargs to pass to the environment class
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
id_requested: InitVar[str]
|
|
||||||
entry_point: Optional[Union[Callable, str]] = field(default=None)
|
entry_point: Optional[Union[Callable, str]] = field(default=None)
|
||||||
reward_threshold: Optional[float] = field(default=None)
|
reward_threshold: Optional[float] = field(default=None)
|
||||||
nondeterministic: bool = field(default=False)
|
nondeterministic: bool = field(default=False)
|
||||||
@@ -103,533 +97,173 @@ class EnvSpec:
|
|||||||
order_enforce: bool = field(default=True)
|
order_enforce: bool = field(default=True)
|
||||||
autoreset: bool = field(default=False)
|
autoreset: bool = field(default=False)
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
namespace: Optional[str] = field(init=False)
|
namespace: Optional[str] = field(init=False)
|
||||||
name: str = field(init=False)
|
name: str = field(init=False)
|
||||||
version: Optional[int] = field(init=False)
|
version: Optional[int] = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self, id_requested):
|
def __post_init__(self):
|
||||||
# Initialize namespace, name, version
|
# Initialize namespace, name, version
|
||||||
self.namespace, self.name, self.version = parse_env_id(id_requested)
|
self.namespace, self.name, self.version = parse_env_id(self.id)
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self) -> str:
|
|
||||||
"""
|
|
||||||
`id_requested` is an InitVar meaning it's only used at initialization to parse
|
|
||||||
the namespace, name, and version. This means we can define the dynamic
|
|
||||||
property `id` to construct the `id` from the parsed fields. This has the
|
|
||||||
benefit that we update the fields and obtain a dynamic id.
|
|
||||||
"""
|
|
||||||
namespace = "" if self.namespace is None else f"{self.namespace}/"
|
|
||||||
name = self.name
|
|
||||||
version = "" if self.version is None else f"-v{self.version}"
|
|
||||||
return f"{namespace}{name}{version}"
|
|
||||||
|
|
||||||
def make(self, **kwargs) -> Env:
|
def make(self, **kwargs) -> Env:
|
||||||
"""Instantiates an instance of the environment with appropriate kwargs"""
|
# For compatibility purposes
|
||||||
if self.entry_point is None:
|
return make(self, **kwargs)
|
||||||
raise error.Error(
|
|
||||||
f"Attempting to make deprecated env {self.id}. "
|
|
||||||
"(HINT: is there a newer registered version of this env?)"
|
|
||||||
)
|
|
||||||
_kwargs = self.kwargs.copy()
|
|
||||||
_kwargs.update(kwargs)
|
|
||||||
|
|
||||||
if "autoreset" in _kwargs:
|
|
||||||
self.autoreset = _kwargs["autoreset"]
|
|
||||||
del _kwargs["autoreset"]
|
|
||||||
|
|
||||||
if callable(self.entry_point):
|
|
||||||
env = self.entry_point(**_kwargs)
|
|
||||||
else:
|
|
||||||
cls = load(self.entry_point)
|
|
||||||
env = cls(**_kwargs)
|
|
||||||
|
|
||||||
# Make the environment aware of which spec it came from.
|
|
||||||
spec = copy.deepcopy(self)
|
|
||||||
spec.kwargs = _kwargs
|
|
||||||
env.unwrapped.spec = spec
|
|
||||||
|
|
||||||
if self.order_enforce:
|
|
||||||
from gym.wrappers.order_enforcing import OrderEnforcing
|
|
||||||
|
|
||||||
env = OrderEnforcing(env)
|
|
||||||
|
|
||||||
assert env.spec is not None, "expected spec to be set to the unwrapped env."
|
|
||||||
if env.spec.max_episode_steps is not None:
|
|
||||||
from gym.wrappers.time_limit import TimeLimit
|
|
||||||
|
|
||||||
env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
|
|
||||||
|
|
||||||
if self.autoreset:
|
|
||||||
from gym.wrappers.autoreset import AutoResetWrapper
|
|
||||||
|
|
||||||
env = AutoResetWrapper(env)
|
|
||||||
|
|
||||||
return env
|
|
||||||
|
|
||||||
|
|
||||||
class EnvSpecTree(MutableMapping):
|
def _check_namespace_exists(ns: Optional[str]):
|
||||||
"""
|
"""Check if a namespace exists. If it doesn't, print a helpful error message."""
|
||||||
The EnvSpecTree provides a dict-like mapping object
|
if ns is None:
|
||||||
from environment IDs to specifications.
|
return
|
||||||
|
namespaces = {
|
||||||
The EnvSpecTree is backed by a tree-like structure.
|
spec_.namespace for spec_ in registry.values() if spec_.namespace is not None
|
||||||
The environment ID format is [{namespace}/]{name}-v{version}.
|
|
||||||
|
|
||||||
The tree has multiple root nodes corresponding to a namespace.
|
|
||||||
The children of a namespace node corresponds to the environment name.
|
|
||||||
Furthermore, each name has a mapping from versions to specifications.
|
|
||||||
It looks like the following,
|
|
||||||
|
|
||||||
{
|
|
||||||
None: {
|
|
||||||
MountainCar: {
|
|
||||||
0: EnvSpec(...),
|
|
||||||
1: EnvSpec(...)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ALE: {
|
|
||||||
Tetris: {
|
|
||||||
5: EnvSpec(...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if ns in namespaces:
|
||||||
|
return
|
||||||
|
|
||||||
The tree-structure isn't user-facing and the EnvSpecTree will act
|
suggestion = (
|
||||||
like a dictionary. For example, to lookup an environment ID:
|
difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
|
||||||
|
)
|
||||||
|
suggestion_msg = (
|
||||||
|
f"Did you mean: `{suggestion[0]}`?"
|
||||||
|
if suggestion
|
||||||
|
else f"Have you installed the proper package for {ns}?"
|
||||||
|
)
|
||||||
|
|
||||||
```
|
raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
|
||||||
specs = EnvSpecTree()
|
|
||||||
|
|
||||||
specs["My/Env-v0"] = EnvSpec(...)
|
|
||||||
assert specs["My/Env-v0"] == EnvSpec(...)
|
|
||||||
|
|
||||||
assert specs.tree["My"]["Env"]["0"] == specs["My/Env-v0"]
|
def _check_name_exists(ns: Optional[str], name: str):
|
||||||
```
|
"""Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
|
||||||
"""
|
_check_namespace_exists(ns)
|
||||||
|
names = {spec_.name for spec_ in registry.values()}
|
||||||
|
|
||||||
def __init__(self):
|
if name in names:
|
||||||
# Initialize the tree as a nested sequence of defaultdicts
|
return
|
||||||
self.tree = defaultdict(lambda: defaultdict(dict))
|
|
||||||
self._length = 0
|
|
||||||
|
|
||||||
def versions(self, namespace: Optional[str], name: str) -> Sequence[EnvSpec]:
|
if namespace is None and name in internal_env_relocation_map:
|
||||||
"""
|
relocated_namespace, relocated_package = internal_env_relocation_map[name]
|
||||||
Returns the versions associated with a namespace and name.
|
message = f"The environment `{name}` has been moved out of Gym to the package `{relocated_package}`."
|
||||||
|
|
||||||
Note: This function takes into account environment relocations.
|
# Check if the package is installed
|
||||||
For example, `versions(None, "Breakout")` will return,
|
# If not instruct the user to install the package and then how to instantiate the env
|
||||||
```
|
if importlib.util.find_spec(relocated_package) is None:
|
||||||
[
|
message += (
|
||||||
EnvSpec(namespace=None, name="Breakout", version=0),
|
f" Please install the package via `pip install {relocated_package}`."
|
||||||
EnvSpec(namespace=None, name="Breakout", version=4),
|
)
|
||||||
EnvSpec(namespace="ALE", name="Breakout", version=5)
|
|
||||||
]
|
|
||||||
```
|
|
||||||
Notice the last environment which is outside of the requested namespace.
|
|
||||||
This only applies to environments which are in the `internal_env_relocation_map`.
|
|
||||||
See `gym/envs/__relocated__.py` for more info.
|
|
||||||
"""
|
|
||||||
self._assert_name_exists(namespace, name)
|
|
||||||
|
|
||||||
versions = list(self.tree[namespace][name].values())
|
# Otherwise the user should be able to instantiate the environment directly
|
||||||
|
if namespace != relocated_namespace:
|
||||||
|
message += f" You can instantiate the new namespaced environment as `{relocated_namespace}/{name}`."
|
||||||
|
|
||||||
if namespace is None and name in internal_env_relocation_map:
|
|
||||||
relocated_namespace, _ = internal_env_relocation_map[name]
|
|
||||||
try:
|
|
||||||
self._assert_name_exists(relocated_namespace, name)
|
|
||||||
versions += list(self.tree[relocated_namespace][name].values())
|
|
||||||
except error.UnregisteredEnv:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return versions
|
|
||||||
|
|
||||||
def names(self, namespace: Optional[str]) -> Sequence[str]:
|
|
||||||
"""
|
|
||||||
Returns all the environment names associated with a namespace.
|
|
||||||
"""
|
|
||||||
self._assert_namespace_exists(namespace)
|
|
||||||
return list(self.tree[namespace].keys())
|
|
||||||
|
|
||||||
def namespaces(self) -> Sequence[str]:
|
|
||||||
"""
|
|
||||||
Returns all the namespaces contained in the tree.
|
|
||||||
"""
|
|
||||||
return list(filter(None, self.tree.keys()))
|
|
||||||
|
|
||||||
def __iter__(self) -> Generator[str, None, None]:
|
|
||||||
# Iterate through the structure and generate the IDs contained in the tree.
|
|
||||||
for namespace, names in self.tree.items():
|
|
||||||
for name, versions in names.items():
|
|
||||||
for version, spec in versions.items():
|
|
||||||
assert spec.namespace == namespace
|
|
||||||
assert spec.name == name
|
|
||||||
assert spec.version == version
|
|
||||||
yield spec.id
|
|
||||||
|
|
||||||
def _assert_namespace_exists(self, namespace: Optional[str]) -> None:
|
|
||||||
if namespace in self.tree:
|
|
||||||
return
|
|
||||||
|
|
||||||
message = f"Namespace `{namespace}` does not exist."
|
|
||||||
if namespace:
|
|
||||||
suggestions = difflib.get_close_matches(namespace, self.namespaces(), n=1)
|
|
||||||
if suggestions:
|
|
||||||
message += f" Did you mean: `{suggestions[0]}`?"
|
|
||||||
else:
|
|
||||||
message += f" Have you installed the proper package for `{namespace}`?"
|
|
||||||
raise error.NamespaceNotFound(message)
|
|
||||||
|
|
||||||
def _assert_name_exists(self, namespace: Optional[str], name: str) -> None:
|
|
||||||
self._assert_namespace_exists(namespace)
|
|
||||||
if name in self.tree[namespace]:
|
|
||||||
return
|
|
||||||
|
|
||||||
if namespace is None and name in internal_env_relocation_map:
|
|
||||||
relocated_namespace, relocated_package = internal_env_relocation_map[name]
|
|
||||||
message = f"The environment `{name}` has been moved out of Gym to the package `{relocated_package}`."
|
|
||||||
|
|
||||||
# Check if the package is installed
|
|
||||||
# If not instruct the user to install the package and then how to instantiate the env
|
|
||||||
if importlib.util.find_spec(relocated_package) is None:
|
|
||||||
message += f" Please install the package via `pip install {relocated_package}`."
|
|
||||||
|
|
||||||
# Otherwise the user should be able to instantiate the environment directly
|
|
||||||
if namespace != relocated_namespace:
|
|
||||||
message += f" You can instantiate the new namespaced environment as `{relocated_namespace}/{name}`."
|
|
||||||
# If the environment hasn't been relocated we'll construct a generic error message
|
|
||||||
else:
|
|
||||||
message = f"Environment `{name}` doesn't exist"
|
|
||||||
if namespace is not None:
|
|
||||||
message += f" in namespace `{namespace}`"
|
|
||||||
message += "."
|
|
||||||
suggestions = difflib.get_close_matches(name, self.names(namespace), n=1)
|
|
||||||
if suggestions:
|
|
||||||
message += f" Did you mean: `{suggestions[0]}`?"
|
|
||||||
# Throw the error
|
|
||||||
raise error.NameNotFound(message)
|
raise error.NameNotFound(message)
|
||||||
|
|
||||||
def _assert_version_exists(
|
suggestion = difflib.get_close_matches(name, names, n=1)
|
||||||
self, namespace: Optional[str], name: str, version: Optional[int]
|
namespace_msg = f" in namespace {ns}" if ns else ""
|
||||||
):
|
suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else ""
|
||||||
self._assert_name_exists(namespace, name)
|
|
||||||
if version in self.tree[namespace][name]:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Construct the appropriate exception.
|
raise error.NameNotFound(
|
||||||
# If the version is less than the latest version
|
f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
|
||||||
# then we throw an error.DeprecatedEnv exception.
|
)
|
||||||
# Otherwise we throw error.VersionNotFound.
|
|
||||||
versions = self.tree[namespace][name]
|
|
||||||
assert len(versions) > 0
|
|
||||||
|
|
||||||
versioned_specs = list(
|
|
||||||
filter(lambda spec: isinstance(spec.version, int), versions.values())
|
|
||||||
)
|
|
||||||
default_spec = versions[None] if None in versions else None
|
|
||||||
assert len(versioned_specs) > 0 or default_spec is not None
|
|
||||||
|
|
||||||
latest_spec = max(
|
def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
|
||||||
versioned_specs, key=lambda spec: spec.version, default=default_spec
|
"""Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
|
||||||
)
|
This is a complete test whether an environment identifier is valid, and will provide the best available hints."""
|
||||||
|
if get_env_id(ns, name, version) in registry:
|
||||||
|
return
|
||||||
|
|
||||||
if version is not None:
|
_check_name_exists(ns, name)
|
||||||
message = f"Environment version `v{version}` for `"
|
if version is None:
|
||||||
else:
|
return
|
||||||
message = "The default version for `"
|
|
||||||
|
|
||||||
if namespace is not None:
|
message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."
|
||||||
message += f"{namespace}/"
|
|
||||||
message += f"{name}` "
|
|
||||||
|
|
||||||
# If this version doesn't exist but there exists a newer non-default
|
env_specs = [
|
||||||
# version we should warn the user this version is deprecated.
|
spec_
|
||||||
if (
|
for spec_ in registry.values()
|
||||||
latest_spec
|
if spec_.namespace == ns and spec_.name == name
|
||||||
and latest_spec.version is not None
|
]
|
||||||
and version is not None
|
env_specs = sorted(env_specs, key=lambda spec_: int(spec_.version or -1))
|
||||||
and version < latest_spec.version
|
|
||||||
):
|
default_spec = [spec_ for spec_ in env_specs if spec_.version is None]
|
||||||
message += "is deprecated. "
|
|
||||||
message += f"Please use the latest version `v{latest_spec.version}`."
|
if default_spec:
|
||||||
|
message += f" It provides the default version {default_spec[0].id}`."
|
||||||
|
if len(env_specs) == 1:
|
||||||
raise error.DeprecatedEnv(message)
|
raise error.DeprecatedEnv(message)
|
||||||
# If this version doesn't exist and there only exists a default version
|
|
||||||
elif latest_spec and latest_spec.version is None:
|
|
||||||
message += "is deprecated. "
|
|
||||||
message += f"`{latest_spec.name}` only provides the default version. "
|
|
||||||
message += (
|
|
||||||
f'You can initialize the environment as `gym.make("{latest_spec.id}")`.'
|
|
||||||
)
|
|
||||||
raise error.DeprecatedEnv(message)
|
|
||||||
# Otherwise we've asked for a version that doesn't exist.
|
|
||||||
else:
|
|
||||||
message += f"could not be found. `{name}` provides "
|
|
||||||
|
|
||||||
if default_spec:
|
# Process possible versioned environments
|
||||||
message += "a default version"
|
|
||||||
if versioned_specs:
|
|
||||||
message += " and "
|
|
||||||
if versioned_specs:
|
|
||||||
message += "the versioned environments: [ "
|
|
||||||
versioned_specs_sorted = sorted(
|
|
||||||
versioned_specs, key=lambda spec: spec.version
|
|
||||||
)
|
|
||||||
message += ", ".join(
|
|
||||||
map(lambda spec: f"`v{spec.version}`", versioned_specs_sorted)
|
|
||||||
)
|
|
||||||
message += " ]"
|
|
||||||
message += "."
|
|
||||||
raise error.VersionNotFound(message)
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> EnvSpec:
|
versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None]
|
||||||
# Get an item from the tree.
|
|
||||||
# We first parse the components so we can look up the
|
|
||||||
# appropriate environment ID.
|
|
||||||
namespace, name, version = parse_env_id(key)
|
|
||||||
self._assert_version_exists(namespace, name, version)
|
|
||||||
|
|
||||||
return self.tree[namespace][name][version]
|
latest_spec = max(versioned_specs, key=lambda spec: spec.version, default=None) # type: ignore
|
||||||
|
if latest_spec is not None and version > latest_spec.version:
|
||||||
|
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs)
|
||||||
|
message += f" It provides versioned environments: [ {version_list_msg} ]."
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: EnvSpec) -> None:
|
raise error.VersionNotFound(message)
|
||||||
# Insert an item into the tree.
|
|
||||||
# First we parse the components to get the path
|
|
||||||
# for insertion.
|
|
||||||
namespace, name, version = parse_env_id(key)
|
|
||||||
self.tree[namespace][name][version] = value
|
|
||||||
# Increase the size
|
|
||||||
self._length += 1
|
|
||||||
|
|
||||||
def __delitem__(self, key: str) -> None:
|
if latest_spec is not None and version < latest_spec.version:
|
||||||
# Delete an item from the tree.
|
raise error.DeprecatedEnv(
|
||||||
# First parse the components so we can follow the
|
f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. "
|
||||||
# path to delete.
|
f"Please use `{latest_spec.id}` instead."
|
||||||
namespace, name, version = parse_env_id(key)
|
|
||||||
self._assert_version_exists(namespace, name, version)
|
|
||||||
|
|
||||||
# Remove the envspec with this version.
|
|
||||||
self.tree[namespace][name].pop(version)
|
|
||||||
# Remove the name if it's empty.
|
|
||||||
if len(self.tree[namespace][name]) == 0:
|
|
||||||
self.tree[namespace].pop(name)
|
|
||||||
# Remove the namespace if it's empty.
|
|
||||||
if len(self.tree[namespace]) == 0:
|
|
||||||
self.tree.pop(namespace)
|
|
||||||
# Decrease the size
|
|
||||||
self._length -= 1
|
|
||||||
|
|
||||||
def __contains__(self, key: str) -> bool:
|
|
||||||
# Check if the tree contains a path for this key.
|
|
||||||
namespace, name, version = parse_env_id(key)
|
|
||||||
if (
|
|
||||||
namespace in self.tree
|
|
||||||
and name in self.tree[namespace]
|
|
||||||
and version in self.tree[namespace][name]
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
# Construct a tree-like representation structure
|
|
||||||
# so we can easily look at the contents of the tree.
|
|
||||||
tree_repr = ""
|
|
||||||
for namespace, names in self.tree.items():
|
|
||||||
# For each namespace we'll iterate over the names
|
|
||||||
root = namespace is None
|
|
||||||
# Insert a separator if we're between depths
|
|
||||||
if len(tree_repr) > 0:
|
|
||||||
tree_repr += "│\n"
|
|
||||||
# if this isn't the root we'll display the namespace
|
|
||||||
if not root:
|
|
||||||
tree_repr += f"├──{str(namespace)}\n"
|
|
||||||
|
|
||||||
# Construct the namespace string so we can print this for
|
|
||||||
# our children.
|
|
||||||
namespace = f"{namespace}/" if namespace is not None else ""
|
|
||||||
for name_idx, (name, versions) in enumerate(names.items()):
|
|
||||||
# If this isn't the root we'll have to increase our
|
|
||||||
# depth, i.e., insert some space
|
|
||||||
if not root:
|
|
||||||
tree_repr += "│ "
|
|
||||||
# If this is the last item make sure we use the
|
|
||||||
# termination character. Otherwise use the nested
|
|
||||||
# character.
|
|
||||||
if name_idx == len(names) - 1:
|
|
||||||
tree_repr += "└──"
|
|
||||||
else:
|
|
||||||
tree_repr += "├──"
|
|
||||||
# Print the namespace and the name
|
|
||||||
# and get ready to print the versions.
|
|
||||||
tree_repr += f"{namespace}{name}: [ "
|
|
||||||
# Print each version comma separated
|
|
||||||
for version_idx, version in enumerate(versions.keys()):
|
|
||||||
if version is not None:
|
|
||||||
tree_repr += f"v{version}"
|
|
||||||
else:
|
|
||||||
tree_repr += ""
|
|
||||||
if version_idx < len(versions) - 1:
|
|
||||||
tree_repr += ", "
|
|
||||||
tree_repr += " ]\n"
|
|
||||||
|
|
||||||
return tree_repr
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
# Return the length of the container
|
|
||||||
return self._length
|
|
||||||
|
|
||||||
|
|
||||||
class EnvRegistry:
|
|
||||||
"""Register an env by ID. IDs remain stable over time and are
|
|
||||||
guaranteed to resolve to the same environment dynamics (or be
|
|
||||||
desupported). The goal is that results on a particular environment
|
|
||||||
should always be comparable, and not depend on the version of the
|
|
||||||
code that was running.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.env_specs = EnvSpecTree()
|
|
||||||
self._ns: Optional[str] = None
|
|
||||||
|
|
||||||
def make(self, path: str, **kwargs) -> Env:
|
|
||||||
if len(kwargs) > 0:
|
|
||||||
logger.info("Making new env: %s (%s)", path, kwargs)
|
|
||||||
else:
|
|
||||||
logger.info("Making new env: %s", path)
|
|
||||||
|
|
||||||
# We need to manually parse the ID so we can check
|
|
||||||
# the version without error-ing out in self.spec
|
|
||||||
namespace, name, version = parse_env_id(path)
|
|
||||||
|
|
||||||
# Get all versions of this spec.
|
|
||||||
versions = self.env_specs.versions(namespace, name)
|
|
||||||
|
|
||||||
# We check what the latest version of the environment is and display
|
|
||||||
# a warning if the user is attempting to initialize an older version
|
|
||||||
# or an unversioned one.
|
|
||||||
latest_versioned_spec = max(
|
|
||||||
filter(lambda spec: spec.version, versions),
|
|
||||||
key=lambda spec: cast(int, spec.version),
|
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
if (
|
|
||||||
latest_versioned_spec
|
|
||||||
and version is not None
|
|
||||||
and version < cast(int, latest_versioned_spec.version)
|
|
||||||
):
|
|
||||||
logger.warn(
|
|
||||||
f"The environment {path} is out of date. You should consider "
|
|
||||||
f"upgrading to version `v{latest_versioned_spec.version}` "
|
|
||||||
f"with the environment ID `{latest_versioned_spec.id}`."
|
|
||||||
)
|
|
||||||
elif latest_versioned_spec and version is None:
|
|
||||||
logger.warn(
|
|
||||||
f"Using the latest versioned environment `{latest_versioned_spec.id}` "
|
|
||||||
f"instead of the unversioned environment `{path}`"
|
|
||||||
)
|
|
||||||
path = latest_versioned_spec.id
|
|
||||||
|
|
||||||
# Lookup our path
|
|
||||||
spec = self.spec(path)
|
|
||||||
# Construct the environment
|
|
||||||
return spec.make(**kwargs)
|
|
||||||
|
|
||||||
def all(self):
|
def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
|
||||||
return self.env_specs.values()
|
version: list[int] = [
|
||||||
|
spec_.version
|
||||||
|
for spec_ in registry.values()
|
||||||
|
if spec_.namespace == ns and spec_.name == name and spec_.version is not None
|
||||||
|
]
|
||||||
|
return max(version, default=None)
|
||||||
|
|
||||||
def spec(self, path: str) -> EnvSpec:
|
|
||||||
if ":" in path:
|
|
||||||
mod_name, _, id = path.partition(":")
|
|
||||||
try:
|
|
||||||
importlib.import_module(mod_name)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
raise error.Error(
|
|
||||||
f"A module ({mod_name}) was specified for the environment but was not found, "
|
|
||||||
"make sure the package is installed with `pip install` before calling `gym.make()`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
id = path
|
|
||||||
|
|
||||||
# We can go ahead and return the env_spec.
|
|
||||||
# The EnvSpecTree will take care of any exceptions.
|
|
||||||
return self.env_specs[id]
|
|
||||||
|
|
||||||
def register(self, id: str, **kwargs) -> None:
|
|
||||||
spec = EnvSpec(id, **kwargs)
|
|
||||||
|
|
||||||
if self._ns is not None:
|
|
||||||
if spec.namespace is not None:
|
|
||||||
logger.warn(
|
|
||||||
f"Custom namespace `{spec.namespace}` is being overridden "
|
|
||||||
f"by namespace `{self._ns}`. If you are developing a "
|
|
||||||
"plugin you shouldn't specify a namespace in `register` "
|
|
||||||
"calls. The namespace is specified through the "
|
|
||||||
"entry point package metadata."
|
|
||||||
)
|
|
||||||
# Replace namespace
|
|
||||||
spec.namespace = self._ns
|
|
||||||
|
|
||||||
|
def load_env_plugins(entry_point: str = "gym.envs") -> None:
|
||||||
|
# Load third-party environments
|
||||||
|
for plugin in metadata.entry_points(group=entry_point):
|
||||||
|
# Python 3.8 doesn't support plugin.module, plugin.attr
|
||||||
|
# So we'll have to try and parse this ourselves
|
||||||
try:
|
try:
|
||||||
# Get all versions of this spec.
|
module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
|
||||||
versions = self.env_specs.versions(spec.namespace, spec.name)
|
except AttributeError:
|
||||||
|
if ":" in plugin.value:
|
||||||
# We raise an error if the user is attempting to initialize an
|
module, attr = plugin.value.split(":", maxsplit=1)
|
||||||
# unversioned environment when a versioned one already exists.
|
else:
|
||||||
latest_versioned_spec = max(
|
module, attr = plugin.value, None
|
||||||
filter(lambda spec: isinstance(spec.version, int), versions),
|
except:
|
||||||
key=lambda spec: cast(int, spec.version),
|
module, attr = None, None
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
unversioned_spec = next(
|
|
||||||
filter(lambda spec: spec.version is None, versions), None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Trying to register an unversioned spec when versioned spec exists
|
|
||||||
if unversioned_spec and spec.version is not None:
|
|
||||||
message = (
|
|
||||||
"Can't register the versioned environment "
|
|
||||||
f"`{spec.id}` when the unversioned environment "
|
|
||||||
f"`{unversioned_spec.id}` of the same name already exists."
|
|
||||||
)
|
|
||||||
raise error.RegistrationError(message)
|
|
||||||
elif latest_versioned_spec and spec.version is None:
|
|
||||||
message = (
|
|
||||||
f"Can't register the unversioned environment `{spec.id}` "
|
|
||||||
f"when version `{latest_versioned_spec.version}` "
|
|
||||||
"of the same name already exists. Note: the default "
|
|
||||||
"behavior is that the `gym.make` with the unversioned "
|
|
||||||
"environment will return the latest versioned environment."
|
|
||||||
)
|
|
||||||
raise error.RegistrationError(message)
|
|
||||||
# We might not find this namespace or name in which case
|
|
||||||
# we should continue to register the environment.
|
|
||||||
except (error.NamespaceNotFound, error.NameNotFound):
|
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
if spec.id in self.env_specs:
|
if attr is None:
|
||||||
logger.warn(f"Overriding environment {id}")
|
raise error.Error(
|
||||||
self.env_specs[spec.id] = spec
|
f"Gym environment plugin `{module}` must specify a function to execute, not a root module"
|
||||||
|
)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
context = namespace(plugin.name)
|
||||||
def namespace(self, ns: str):
|
if plugin.name.startswith("__") and plugin.name.endswith("__"):
|
||||||
self._ns = ns
|
# `__internal__` is an artifact of the plugin system when
|
||||||
yield
|
# the root namespace had an allow-list. The allow-list is now
|
||||||
self._ns = None
|
# removed and plugins can register environments in the root
|
||||||
|
# namespace with the `__root__` magic key.
|
||||||
|
if plugin.name == "__root__" or plugin.name == "__internal__":
|
||||||
|
context = contextlib.nullcontext()
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
f"The environment namespace magic key `{plugin.name}` is unsupported. "
|
||||||
|
"To register an environment at the root namespace you should specify "
|
||||||
|
"the `__root__` namespace."
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
with context:
|
||||||
return repr(self.env_specs)
|
fn = plugin.load()
|
||||||
|
try:
|
||||||
|
fn()
|
||||||
# Have a global registry
|
except Exception as e:
|
||||||
registry = EnvRegistry()
|
logger.warn(str(e))
|
||||||
|
|
||||||
|
|
||||||
def register(id: str, **kwargs) -> None:
|
|
||||||
return registry.register(id, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
# Continuous
|
|
||||||
# ----------------------------------------
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||||
@overload
|
@overload
|
||||||
@@ -680,63 +314,191 @@ def make(id: Literal[
|
|||||||
"Ant-v2"
|
"Ant-v2"
|
||||||
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
|
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
|
||||||
|
|
||||||
# ----------------------------------------
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def make(id: str, **kwargs) -> Env: ...
|
def make(id: str, **kwargs) -> Env: ...
|
||||||
|
@overload
|
||||||
|
def make(id: EnvSpec, **kwargs) -> Env: ...
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def make(id: str, **kwargs) -> Env:
|
|
||||||
return registry.make(id, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def spec(id: str) -> EnvSpec:
|
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||||
return registry.spec(id)
|
registry: dict[str, EnvSpec] = dict()
|
||||||
|
current_namespace: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _check_spec_register(spec: EnvSpec):
|
||||||
|
"""Checks whether the spec is valid to be registered. Helper function for `register`."""
|
||||||
|
global registry, current_namespace
|
||||||
|
if current_namespace is not None:
|
||||||
|
if spec.namespace is not None:
|
||||||
|
logger.warn(
|
||||||
|
f"Custom namespace `{spec.namespace}` is being overridden "
|
||||||
|
f"by namespace `{current_namespace}`. If you are developing a "
|
||||||
|
"plugin you shouldn't specify a namespace in `register` "
|
||||||
|
"calls. The namespace is specified through the "
|
||||||
|
"entry point package metadata."
|
||||||
|
)
|
||||||
|
|
||||||
|
latest_versioned_spec = max(
|
||||||
|
(
|
||||||
|
spec_
|
||||||
|
for spec_ in registry.values()
|
||||||
|
if spec_.namespace == spec.namespace
|
||||||
|
and spec_.name == spec.name
|
||||||
|
and spec_.version is not None
|
||||||
|
),
|
||||||
|
key=lambda spec_: int(spec_.version), # type: ignore
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
unversioned_spec = next(
|
||||||
|
(
|
||||||
|
spec_
|
||||||
|
for spec_ in registry.values()
|
||||||
|
if spec_.namespace == spec.namespace
|
||||||
|
and spec_.name == spec.name
|
||||||
|
and spec_.version is None
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if unversioned_spec and spec.version is not None:
|
||||||
|
raise error.RegistrationError(
|
||||||
|
"Can't register the versioned environment "
|
||||||
|
f"`{spec.id}` when the unversioned environment "
|
||||||
|
f"`{unversioned_spec.id}` of the same name already exists."
|
||||||
|
)
|
||||||
|
elif latest_versioned_spec and spec.version is None:
|
||||||
|
raise error.RegistrationError(
|
||||||
|
"Can't register the unversioned environment "
|
||||||
|
f"`{spec.id}` when the versioned environment "
|
||||||
|
f"`{latest_versioned_spec.id}` of the same name "
|
||||||
|
f"already exists. Note: the default behavior is "
|
||||||
|
f"that `gym.make` with the unversioned environment "
|
||||||
|
f"will return the latest versioned environment"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Public API
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def namespace(ns: str):
|
def namespace(ns: str):
|
||||||
with registry.namespace(ns):
|
global current_namespace
|
||||||
yield
|
old_namespace = current_namespace
|
||||||
|
current_namespace = ns
|
||||||
|
yield
|
||||||
|
current_namespace = old_namespace
|
||||||
|
|
||||||
|
|
||||||
def load_env_plugins(entry_point: str = "gym.envs") -> None:
|
def register(id: str, **kwargs):
|
||||||
# Load third-party environments
|
"""
|
||||||
for plugin in metadata.entry_points(group=entry_point):
|
Register an environment with gym. The `id` parameter corresponds to the name of the environment,
|
||||||
# Python 3.8 doesn't support plugin.module, plugin.attr
|
with the syntax as follows:
|
||||||
# So we'll have to try and parse this ourselves
|
`(namespace)/(env_name)-(version)`
|
||||||
try:
|
where `namespace` is optional.
|
||||||
module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
|
|
||||||
except AttributeError:
|
|
||||||
if ":" in plugin.value:
|
|
||||||
module, attr = plugin.value.split(":", maxsplit=1)
|
|
||||||
else:
|
|
||||||
module, attr = plugin.value, None
|
|
||||||
except:
|
|
||||||
module, attr = None, None
|
|
||||||
finally:
|
|
||||||
if attr is None:
|
|
||||||
raise error.Error(
|
|
||||||
f"Gym environment plugin `{module}` must specify a function to execute, not a root module"
|
|
||||||
)
|
|
||||||
|
|
||||||
context = namespace(plugin.name)
|
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
|
||||||
if plugin.name.startswith("__") and plugin.name.endswith("__"):
|
"""
|
||||||
# `__internal__` is an artifact of the plugin system when
|
global registry, current_namespace
|
||||||
# the root namespace had an allow-list. The allow-list is now
|
full_id = (current_namespace or "") + id
|
||||||
# removed and plugins can register environments in the root
|
spec = EnvSpec(id=full_id, **kwargs)
|
||||||
# namespace with the `__root__` magic key.
|
_check_spec_register(spec)
|
||||||
if plugin.name == "__root__" or plugin.name == "__internal__":
|
if spec.id in registry:
|
||||||
context = contextlib.nullcontext()
|
logger.warn(f"Overriding environment {id}")
|
||||||
else:
|
registry[spec.id] = spec
|
||||||
logger.warn(
|
|
||||||
f"The environment namespace magic key `{plugin.name}` is unsupported. "
|
|
||||||
"To register an environment at the root namespace you should specify "
|
|
||||||
"the `__root__` namespace."
|
|
||||||
)
|
|
||||||
|
|
||||||
with context:
|
|
||||||
fn = plugin.load()
|
def make(
|
||||||
try:
|
id: str | EnvSpec,
|
||||||
fn()
|
max_episode_steps: Optional[int] = None,
|
||||||
except Exception as e:
|
autoreset: bool = False,
|
||||||
logger.warn(str(e))
|
**kwargs,
|
||||||
|
) -> Env:
|
||||||
|
"""
|
||||||
|
Create an environment according to the given ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: Name of the environment.
|
||||||
|
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
||||||
|
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
||||||
|
kwargs: Additional arguments to pass to the environment constructor.
|
||||||
|
Returns:
|
||||||
|
An instance of the environment.
|
||||||
|
"""
|
||||||
|
if isinstance(id, EnvSpec):
|
||||||
|
spec_ = id
|
||||||
|
else:
|
||||||
|
spec_ = registry.get(id)
|
||||||
|
|
||||||
|
ns, name, version = parse_env_id(id)
|
||||||
|
latest_version = find_highest_version(ns, name)
|
||||||
|
if (
|
||||||
|
version is not None
|
||||||
|
and latest_version is not None
|
||||||
|
and latest_version > version
|
||||||
|
):
|
||||||
|
logger.warn(
|
||||||
|
f"The environment {id} is out of date. You should consider "
|
||||||
|
f"upgrading to version `v{latest_version}`."
|
||||||
|
)
|
||||||
|
if version is None and latest_version is not None:
|
||||||
|
version = latest_version
|
||||||
|
new_env_id = get_env_id(ns, name, version)
|
||||||
|
spec_ = registry.get(new_env_id)
|
||||||
|
logger.warn(
|
||||||
|
f"Using the latest versioned environment `{new_env_id}` "
|
||||||
|
f"instead of the unversioned environment `{id}`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if spec_ is None:
|
||||||
|
_check_version_exists(ns, name, version)
|
||||||
|
raise error.Error(f"No registered env with id: {id}")
|
||||||
|
|
||||||
|
_kwargs = spec_.kwargs.copy()
|
||||||
|
_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
# TODO: add a minimal env checker on initialization
|
||||||
|
if spec_.entry_point is None:
|
||||||
|
raise error.Error(f"{spec_.id} registered but entry_point is not specified")
|
||||||
|
elif callable(spec_.entry_point):
|
||||||
|
cls = spec_.entry_point
|
||||||
|
else:
|
||||||
|
# Assume it's a string
|
||||||
|
cls = load(spec_.entry_point)
|
||||||
|
|
||||||
|
env = cls(**_kwargs)
|
||||||
|
|
||||||
|
spec_ = copy.deepcopy(spec_)
|
||||||
|
spec_.kwargs = _kwargs
|
||||||
|
|
||||||
|
env.unwrapped.spec = spec_
|
||||||
|
|
||||||
|
if spec_.order_enforce:
|
||||||
|
env = OrderEnforcing(env)
|
||||||
|
|
||||||
|
if max_episode_steps is not None:
|
||||||
|
env = TimeLimit(env, max_episode_steps)
|
||||||
|
elif spec_.max_episode_steps is not None:
|
||||||
|
env = TimeLimit(env, spec_.max_episode_steps)
|
||||||
|
|
||||||
|
if autoreset:
|
||||||
|
env = AutoResetWrapper(env)
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def spec(env_id: str) -> EnvSpec:
|
||||||
|
"""
|
||||||
|
Retrieve the spec for the given environment from the global registry.
|
||||||
|
"""
|
||||||
|
spec_ = registry.get(env_id)
|
||||||
|
if spec_ is None:
|
||||||
|
ns, name, version = parse_env_id(env_id)
|
||||||
|
_check_version_exists(ns, name, version)
|
||||||
|
raise error.Error(f"No registered env with id: {env_id}")
|
||||||
|
else:
|
||||||
|
assert isinstance(spec_, EnvSpec)
|
||||||
|
return spec_
|
||||||
|
@@ -50,6 +50,6 @@ def should_skip_env_spec_for_tests(spec):
|
|||||||
|
|
||||||
spec_list = [
|
spec_list = [
|
||||||
spec
|
spec
|
||||||
for spec in sorted(envs.registry.all(), key=lambda x: x.id)
|
for spec in sorted(envs.registry.values(), key=lambda x: x.id)
|
||||||
if spec.entry_point is not None and not should_skip_env_spec_for_tests(spec)
|
if spec.entry_point is not None and not should_skip_env_spec_for_tests(spec)
|
||||||
]
|
]
|
||||||
|
@@ -2,9 +2,9 @@ import pytest
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import envs, error
|
from gym import envs, error
|
||||||
from gym.envs import registration
|
from gym.envs import register, registration, registry, spec
|
||||||
from gym.envs.classic_control import cartpole
|
from gym.envs.classic_control import cartpole
|
||||||
from gym.envs.registration import EnvSpec, EnvSpecTree
|
from gym.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
class ArgumentEnv(gym.Env):
|
class ArgumentEnv(gym.Env):
|
||||||
@@ -55,8 +55,8 @@ def register_some_envs():
|
|||||||
|
|
||||||
for version in versions:
|
for version in versions:
|
||||||
env_id = f"{namespace}/{versioned_name}-v{version}"
|
env_id = f"{namespace}/{versioned_name}-v{version}"
|
||||||
del gym.envs.registry.env_specs[env_id]
|
del gym.envs.registry[env_id]
|
||||||
del gym.envs.registry.env_specs[f"{namespace}/{unversioned_name}"]
|
del gym.envs.registry[f"{namespace}/{unversioned_name}"]
|
||||||
|
|
||||||
|
|
||||||
def test_make():
|
def test_make():
|
||||||
@@ -83,10 +83,15 @@ def test_make():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_register(env_id, namespace, name, version):
|
def test_register(env_id, namespace, name, version):
|
||||||
envs.register(env_id)
|
register(env_id)
|
||||||
assert gym.envs.spec(env_id).id == env_id
|
assert gym.envs.spec(env_id).id == env_id
|
||||||
assert version in gym.envs.registry.env_specs.tree[namespace][name].keys()
|
full_name = f"{name}"
|
||||||
del gym.envs.registry.env_specs[env_id]
|
if namespace:
|
||||||
|
full_name = f"{namespace}/{full_name}"
|
||||||
|
if version is not None:
|
||||||
|
full_name = f"{full_name}-v{version}"
|
||||||
|
assert full_name in gym.envs.registry.keys()
|
||||||
|
del gym.envs.registry[env_id]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -99,7 +104,7 @@ def test_register(env_id, namespace, name, version):
|
|||||||
)
|
)
|
||||||
def test_register_error(env_id):
|
def test_register_error(env_id):
|
||||||
with pytest.raises(error.Error, match="Malformed environment ID"):
|
with pytest.raises(error.Error, match="Malformed environment ID"):
|
||||||
envs.register(env_id)
|
register(env_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -188,27 +193,23 @@ def test_spec_with_kwargs():
|
|||||||
|
|
||||||
|
|
||||||
def test_missing_lookup():
|
def test_missing_lookup():
|
||||||
registry = registration.EnvRegistry()
|
register(id="Test1-v0", entry_point=None)
|
||||||
registry.register(id="Test-v0", entry_point=None)
|
register(id="Test1-v15", entry_point=None)
|
||||||
registry.register(id="Test-v15", entry_point=None)
|
register(id="Test1-v9", entry_point=None)
|
||||||
registry.register(id="Test-v9", entry_point=None)
|
register(id="Other1-v100", entry_point=None)
|
||||||
registry.register(id="Other-v100", entry_point=None)
|
|
||||||
try:
|
with pytest.raises(error.DeprecatedEnv):
|
||||||
registry.spec("Test-v1") # must match an env name but not the version above
|
spec("Test1-v1")
|
||||||
except error.DeprecatedEnv:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
assert False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registry.spec("Test-v1000")
|
spec("Test1-v1000")
|
||||||
except error.UnregisteredEnv:
|
except error.UnregisteredEnv:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registry.spec("Unknown-v1")
|
spec("Unknown1-v1")
|
||||||
except error.UnregisteredEnv:
|
except error.UnregisteredEnv:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@@ -216,9 +217,8 @@ def test_missing_lookup():
|
|||||||
|
|
||||||
|
|
||||||
def test_malformed_lookup():
|
def test_malformed_lookup():
|
||||||
registry = registration.EnvRegistry()
|
|
||||||
try:
|
try:
|
||||||
registry.spec("“Breakout-v0”")
|
spec("“Breakout-v0”")
|
||||||
except error.Error as e:
|
except error.Error as e:
|
||||||
assert "Malformed environment ID" in f"{e}", f"Unexpected message: {e}"
|
assert "Malformed environment ID" in f"{e}", f"Unexpected message: {e}"
|
||||||
else:
|
else:
|
||||||
@@ -226,99 +226,47 @@ def test_malformed_lookup():
|
|||||||
|
|
||||||
|
|
||||||
def test_versioned_lookups():
|
def test_versioned_lookups():
|
||||||
registry = registration.EnvRegistry()
|
register("test/Test2-v5")
|
||||||
registry.register("test/Test-v5")
|
|
||||||
|
|
||||||
with pytest.raises(error.VersionNotFound):
|
with pytest.raises(error.VersionNotFound):
|
||||||
registry.spec("test/Test-v9")
|
spec("test/Test2-v9")
|
||||||
|
|
||||||
with pytest.raises(error.DeprecatedEnv):
|
with pytest.raises(error.DeprecatedEnv):
|
||||||
registry.spec("test/Test-v4")
|
spec("test/Test2-v4")
|
||||||
|
|
||||||
assert registry.spec("test/Test-v5")
|
assert spec("test/Test2-v5")
|
||||||
|
|
||||||
|
|
||||||
def test_default_lookups():
|
def test_default_lookups():
|
||||||
registry = registration.EnvRegistry()
|
register("test/Test3")
|
||||||
registry.register("test/Test")
|
|
||||||
|
|
||||||
with pytest.raises(error.DeprecatedEnv):
|
with pytest.raises(error.DeprecatedEnv):
|
||||||
registry.spec("test/Test-v0")
|
spec("test/Test3-v0")
|
||||||
|
|
||||||
# Lookup default
|
# Lookup default
|
||||||
registry.spec("test/Test")
|
spec("test/Test3")
|
||||||
|
|
||||||
|
|
||||||
def test_env_spec_tree():
|
|
||||||
spec_tree = EnvSpecTree()
|
|
||||||
|
|
||||||
# Add with namespace
|
|
||||||
spec = EnvSpec("test/Test-v0")
|
|
||||||
spec_tree["test/Test-v0"] = spec
|
|
||||||
assert spec_tree.tree.keys() == {"test"}
|
|
||||||
assert spec_tree.tree["test"].keys() == {"Test"}
|
|
||||||
assert spec_tree.tree["test"]["Test"].keys() == {0}
|
|
||||||
assert spec_tree.tree["test"]["Test"][0] == spec
|
|
||||||
assert spec_tree["test/Test-v0"] == spec
|
|
||||||
|
|
||||||
# Add without namespace
|
|
||||||
spec = EnvSpec("Test-v0")
|
|
||||||
spec_tree["Test-v0"] = spec
|
|
||||||
assert spec_tree.tree.keys() == {"test", None}
|
|
||||||
assert spec_tree.tree[None].keys() == {"Test"}
|
|
||||||
assert spec_tree.tree[None]["Test"].keys() == {0}
|
|
||||||
assert spec_tree.tree[None]["Test"][0] == spec
|
|
||||||
|
|
||||||
# Delete last version deletes entire subtree
|
|
||||||
del spec_tree["test/Test-v0"]
|
|
||||||
assert spec_tree.tree.keys() == {None}
|
|
||||||
|
|
||||||
# Append second version for same name
|
|
||||||
spec_tree["Test-v1"] = EnvSpec("Test-v1")
|
|
||||||
assert spec_tree.tree.keys() == {None}
|
|
||||||
assert spec_tree.tree[None].keys() == {"Test"}
|
|
||||||
assert spec_tree.tree[None]["Test"].keys() == {0, 1}
|
|
||||||
|
|
||||||
# Deleting one version leaves other
|
|
||||||
del spec_tree["Test-v0"]
|
|
||||||
assert spec_tree.tree.keys() == {None}
|
|
||||||
assert spec_tree.tree[None].keys() == {"Test"}
|
|
||||||
assert spec_tree.tree[None]["Test"].keys() == {1}
|
|
||||||
|
|
||||||
# Add without version
|
|
||||||
myenv = "MyAwesomeEnv"
|
|
||||||
spec = EnvSpec(myenv)
|
|
||||||
spec_tree[myenv] = spec
|
|
||||||
assert spec_tree.tree.keys() == {None}
|
|
||||||
assert myenv in spec_tree.tree[None].keys()
|
|
||||||
assert spec_tree.tree[None][myenv].keys() == {None}
|
|
||||||
assert spec_tree.tree[None][myenv][None] == spec
|
|
||||||
assert spec_tree.__repr__() == "├──Test: [ v1 ]\n" + f"└──{myenv}: [ ]\n"
|
|
||||||
|
|
||||||
|
|
||||||
def test_register_versioned_unversioned():
|
def test_register_versioned_unversioned():
|
||||||
# Register versioned then unversioned
|
# Register versioned then unversioned
|
||||||
versioned_env = "Test/MyEnv-v0"
|
versioned_env = "Test/MyEnv-v0"
|
||||||
envs.register(versioned_env)
|
register(versioned_env)
|
||||||
assert gym.envs.spec(versioned_env).id == versioned_env
|
assert gym.envs.spec(versioned_env).id == versioned_env
|
||||||
unversioned_env = "Test/MyEnv"
|
unversioned_env = "Test/MyEnv"
|
||||||
with pytest.raises(error.RegistrationError):
|
with pytest.raises(error.RegistrationError):
|
||||||
envs.register(unversioned_env)
|
register(unversioned_env)
|
||||||
|
|
||||||
# Clean everything
|
# Clean everything
|
||||||
del gym.envs.registry.env_specs[versioned_env]
|
del gym.envs.registry[versioned_env]
|
||||||
|
|
||||||
# Register unversioned then versioned
|
# Register unversioned then versioned
|
||||||
with pytest.warns(UserWarning):
|
register(unversioned_env)
|
||||||
envs.register(unversioned_env)
|
|
||||||
assert gym.envs.spec(unversioned_env).id == unversioned_env
|
assert gym.envs.spec(unversioned_env).id == unversioned_env
|
||||||
with pytest.raises(error.RegistrationError):
|
with pytest.raises(error.RegistrationError):
|
||||||
envs.register(versioned_env)
|
register(versioned_env)
|
||||||
|
|
||||||
# Clean everything
|
# Clean everything
|
||||||
envs_list = [versioned_env, unversioned_env]
|
del gym.envs.registry[unversioned_env]
|
||||||
for env in envs_list:
|
|
||||||
del gym.envs.registry.env_specs[env]
|
|
||||||
|
|
||||||
|
|
||||||
def test_return_latest_versioned_env(register_some_envs):
|
def test_return_latest_versioned_env(register_some_envs):
|
||||||
|
Reference in New Issue
Block a user