mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Add plugin system for third-party environments (#2383)
This commit is contained in:
committed by
GitHub
parent
590db96e36
commit
e212043a93
@@ -1,4 +1,14 @@
|
||||
from gym.envs.registration import registry, register, make, spec
|
||||
from gym.envs.registration import (
|
||||
registry,
|
||||
register,
|
||||
make,
|
||||
spec,
|
||||
load_plugins as _load_plugins,
|
||||
)
|
||||
|
||||
# Hook to load plugins from entry points
|
||||
_load_plugins()
|
||||
|
||||
|
||||
# Classic
|
||||
# ----------------------------------------
|
||||
|
@@ -1,7 +1,15 @@
|
||||
import re
|
||||
import sys
|
||||
import copy
|
||||
import importlib
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata as metadata
|
||||
else:
|
||||
import importlib.metadata as metadata
|
||||
|
||||
from gym import error, logger
|
||||
|
||||
# This format is true today, but it's *not* an official spec.
|
||||
@@ -11,6 +19,9 @@ from gym import error, logger
|
||||
# to include an optional username.
|
||||
env_id_re = re.compile(r"^(?:[\w:-]+\/)?([\w:.-]+)-v(\d+)$")
|
||||
|
||||
# Whitelist of plugins which can hook into the `gym.envs.internal` entry point.
|
||||
plugin_internal_whitelist = {"ale_py.gym"}
|
||||
|
||||
|
||||
def load(name):
|
||||
mod_name, attr_name = name.split(":")
|
||||
@@ -95,6 +106,7 @@ class EnvRegistry(object):
|
||||
|
||||
def __init__(self):
|
||||
self.env_specs = {}
|
||||
self._ns = None
|
||||
|
||||
def make(self, path, **kwargs):
|
||||
if len(kwargs) > 0:
|
||||
@@ -183,10 +195,25 @@ class EnvRegistry(object):
|
||||
raise error.UnregisteredEnv("No registered env with id: {}".format(id))
|
||||
|
||||
def register(self, id, **kwargs):
|
||||
if self._ns is not None:
|
||||
if "/" in id:
|
||||
namespace, id = id.split("/")
|
||||
logger.warn(
|
||||
f"Custom namespace '{namespace}' is being overrode by namespace '{self._ns}'. "
|
||||
"If you are developing a plugin you shouldn't specify a namespace in `register` calls. "
|
||||
"The namespace is specified through the entry point key."
|
||||
)
|
||||
id = f"{self._ns}/{id}"
|
||||
if id in self.env_specs:
|
||||
logger.warn("Overriding environment {}".format(id))
|
||||
self.env_specs[id] = EnvSpec(id, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def namespace(self, ns):
|
||||
self._ns = ns
|
||||
yield
|
||||
self._ns = None
|
||||
|
||||
|
||||
# Have a global registry
|
||||
registry = EnvRegistry()
|
||||
@@ -202,3 +229,35 @@ def make(id, **kwargs):
|
||||
|
||||
def spec(id):
|
||||
return registry.spec(id)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def namespace(ns):
|
||||
with registry.namespace(ns):
|
||||
yield
|
||||
|
||||
|
||||
def load_plugins(
|
||||
third_party_entry_point="gym.envs", internal_entry_point="gym.envs.internal"
|
||||
):
|
||||
# Load third-party environments
|
||||
for external in metadata.entry_points().get(third_party_entry_point, []):
|
||||
if external.attr is not None:
|
||||
raise error.Error(
|
||||
"Gym environment plugins must specify a root module to load, not a function"
|
||||
)
|
||||
# Force namespace on all `register` calls for third-party envs
|
||||
with namespace(external.name):
|
||||
external.load()
|
||||
|
||||
# Load plugins which hook into `gym.envs.internal`
|
||||
# These plugins must be in the whitelist defined at the top of this file
|
||||
# We don't force a namespace on register calls in this module
|
||||
for internal in metadata.entry_points().get(internal_entry_point, []):
|
||||
if internal.module not in plugin_internal_whitelist:
|
||||
continue
|
||||
if external.attr is not None:
|
||||
raise error.Error(
|
||||
"Gym environment plugins must specify a root module to load, not a function"
|
||||
)
|
||||
internal.load()
|
||||
|
Reference in New Issue
Block a user