Add plugin system for third-party environments (#2383)

This commit is contained in:
Jesse Farebrother
2021-09-14 20:14:05 -06:00
committed by GitHub
parent 590db96e36
commit e212043a93
3 changed files with 71 additions and 1 deletions

View File

@@ -1,4 +1,14 @@
from gym.envs.registration import registry, register, make, spec
from gym.envs.registration import (
registry,
register,
make,
spec,
load_plugins as _load_plugins,
)
# Hook to load plugins from entry points
_load_plugins()
# Classic
# ----------------------------------------

View File

@@ -1,7 +1,15 @@
import re
import sys
import copy
import importlib
from contextlib import contextmanager
if sys.version_info < (3, 8):
import importlib_metadata as metadata
else:
import importlib.metadata as metadata
from gym import error, logger
# This format is true today, but it's *not* an official spec.
@@ -11,6 +19,9 @@ from gym import error, logger
# to include an optional username.
env_id_re = re.compile(r"^(?:[\w:-]+\/)?([\w:.-]+)-v(\d+)$")
# Whitelist of plugins which can hook into the `gym.envs.internal` entry point.
plugin_internal_whitelist = {"ale_py.gym"}
def load(name):
mod_name, attr_name = name.split(":")
@@ -95,6 +106,7 @@ class EnvRegistry(object):
def __init__(self):
self.env_specs = {}
self._ns = None
def make(self, path, **kwargs):
if len(kwargs) > 0:
@@ -183,10 +195,25 @@ class EnvRegistry(object):
raise error.UnregisteredEnv("No registered env with id: {}".format(id))
def register(self, id, **kwargs):
if self._ns is not None:
if "/" in id:
namespace, id = id.split("/")
logger.warn(
f"Custom namespace '{namespace}' is being overrode by namespace '{self._ns}'. "
"If you are developing a plugin you shouldn't specify a namespace in `register` calls. "
"The namespace is specified through the entry point key."
)
id = f"{self._ns}/{id}"
if id in self.env_specs:
logger.warn("Overriding environment {}".format(id))
self.env_specs[id] = EnvSpec(id, **kwargs)
@contextmanager
def namespace(self, ns):
self._ns = ns
yield
self._ns = None
# Have a global registry
registry = EnvRegistry()
@@ -202,3 +229,35 @@ def make(id, **kwargs):
def spec(id):
return registry.spec(id)
@contextmanager
def namespace(ns):
with registry.namespace(ns):
yield
def load_plugins(
third_party_entry_point="gym.envs", internal_entry_point="gym.envs.internal"
):
# Load third-party environments
for external in metadata.entry_points().get(third_party_entry_point, []):
if external.attr is not None:
raise error.Error(
"Gym environment plugins must specify a root module to load, not a function"
)
# Force namespace on all `register` calls for third-party envs
with namespace(external.name):
external.load()
# Load plugins which hook into `gym.envs.internal`
# These plugins must be in the whitelist defined at the top of this file
# We don't force a namespace on register calls in this module
for internal in metadata.entry_points().get(internal_entry_point, []):
if internal.module not in plugin_internal_whitelist:
continue
if external.attr is not None:
raise error.Error(
"Gym environment plugins must specify a root module to load, not a function"
)
internal.load()

View File

@@ -44,6 +44,7 @@ setup(
install_requires=[
"numpy>=1.18.0",
"cloudpickle>=1.2.0",
"importlib_metadata>=4.8.1; python_version < '3.8'",
],
extras_require=extras,
package_data={