mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2026-01-29 22:57:15 +00:00
Typing/basics (#2529)
* Typing in gym/envs/registration.py * Add registration to type checked list * Adds type hints to space.py * Typing in gym.core.Env * Typing in seeding.py * fixup Typing after rebase * revert accidental change * Install dependencies in pyright runner * fix: can only install dependencies after checkout * fix: install types in a venv * fix path * skip env activation, install directly from venv interpreter * absolute path to venv * use central python installation * skip one more typecheck * cleanup gh actions .yml * Add py.typed to signal using sources for typechecking * black! Co-authored-by: sj_petterson <sj_petterson@gmail.com> Co-authored-by: J K Terry <justinkterry@gmail.com>
This commit is contained in:
@@ -3,9 +3,10 @@ import sys
|
||||
import copy
|
||||
import importlib
|
||||
import contextlib
|
||||
from typing import Callable, Type, Optional, Union, Dict, Set, Tuple, Generator
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata as metadata
|
||||
import importlib_metadata as metadata # type: ignore
|
||||
else:
|
||||
import importlib.metadata as metadata
|
||||
|
||||
@@ -13,9 +14,7 @@ from collections import defaultdict
|
||||
from collections.abc import MutableMapping
|
||||
from operator import getitem
|
||||
|
||||
from typing import Optional, Union, Dict, Set, Tuple, Generator
|
||||
|
||||
from gym import error, logger
|
||||
from gym import error, logger, Env
|
||||
|
||||
# This format is true today, but it's *not* an official spec.
|
||||
# [username/](env-name)-v(version) env-name is group 1, version is group 2
|
||||
@@ -325,7 +324,7 @@ internal_env_namespace_relocation_map: Dict[str, Tuple[str, str]] = {
|
||||
}
|
||||
|
||||
|
||||
def load(name):
|
||||
def load(name: str) -> Type:
|
||||
mod_name, attr_name = name.split(":")
|
||||
mod = importlib.import_module(mod_name)
|
||||
fn = getattr(mod, attr_name)
|
||||
@@ -337,25 +336,25 @@ class EnvSpec:
|
||||
to register the parameters for official evaluations.
|
||||
|
||||
Args:
|
||||
id (str): The official environment ID
|
||||
entry_point (Optional[str]): The Python entrypoint of the environment class (e.g. module.name:Class)
|
||||
reward_threshold (Optional[int]): The reward threshold before the task is considered solved
|
||||
nondeterministic (bool): Whether this environment is non-deterministic even after seeding
|
||||
max_episode_steps (Optional[int]): The maximum number of steps that an episode can consist of
|
||||
order_enforce (Optional[int]): Whether to wrap the environment in an orderEnforcing wrapper
|
||||
kwargs (dict): The kwargs to pass to the environment class
|
||||
id: The official environment ID
|
||||
entry_point: The Python entrypoint of the environment class (e.g. module.name:Class)
|
||||
reward_threshold: The reward threshold before the task is considered solved
|
||||
nondeterministic: Whether this environment is non-deterministic even after seeding
|
||||
max_episode_steps: The maximum number of steps that an episode can consist of
|
||||
order_enforce: Whether to wrap the environment in an orderEnforcing wrapper
|
||||
kwargs: The kwargs to pass to the environment class
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id,
|
||||
entry_point=None,
|
||||
reward_threshold=None,
|
||||
nondeterministic=False,
|
||||
max_episode_steps=None,
|
||||
order_enforce=True,
|
||||
kwargs=None,
|
||||
id: str,
|
||||
entry_point: Union[Callable, str, None] = None,
|
||||
reward_threshold: Optional[int] = None,
|
||||
nondeterministic: bool = False,
|
||||
max_episode_steps: Optional[int] = None,
|
||||
order_enforce: Optional[int] = True,
|
||||
kwargs: dict = None,
|
||||
):
|
||||
self.id = id
|
||||
self.entry_point = entry_point
|
||||
@@ -373,7 +372,7 @@ class EnvSpec:
|
||||
)
|
||||
self._env_name = match.group("name")
|
||||
|
||||
def make(self, **kwargs):
|
||||
def make(self, **kwargs) -> Env:
|
||||
"""Instantiates an instance of the environment with appropriate kwargs"""
|
||||
if self.entry_point is None:
|
||||
raise error.Error(
|
||||
@@ -468,7 +467,7 @@ class EnvSpecTree(MutableMapping):
|
||||
match = env_id_re.fullmatch(key)
|
||||
if match is None:
|
||||
raise KeyError(f"Malformed environment spec key {key}.")
|
||||
return match.group("namespace", "name", "version")
|
||||
return match.group("namespace", "name", "version") # type: ignore #
|
||||
|
||||
def _exists(self, namespace: Optional[str], name: str, version: str) -> bool:
|
||||
# Helper which can look if an ID exists in the tree.
|
||||
@@ -581,9 +580,9 @@ class EnvRegistry:
|
||||
|
||||
def __init__(self):
|
||||
self.env_specs = EnvSpecTree()
|
||||
self._ns = None
|
||||
self._ns: Optional[str] = None
|
||||
|
||||
def make(self, path, **kwargs):
|
||||
def make(self, path: str, **kwargs) -> Env:
|
||||
if len(kwargs) > 0:
|
||||
logger.info("Making new env: %s (%s)", path, kwargs)
|
||||
else:
|
||||
@@ -616,7 +615,7 @@ class EnvRegistry:
|
||||
def all(self):
|
||||
return self.env_specs.values()
|
||||
|
||||
def spec(self, path):
|
||||
def spec(self, path: str) -> EnvSpec:
|
||||
if ":" in path:
|
||||
mod_name, _, id = path.partition(":")
|
||||
try:
|
||||
@@ -690,7 +689,7 @@ class EnvRegistry:
|
||||
version_not_found_error_msg += ", ".join(
|
||||
map(
|
||||
lambda version: env_id_from_parts(
|
||||
getitem(version, 1), name, getitem(version, 0)
|
||||
getitem(version, 1), name, getitem(version, 0) # type: ignore
|
||||
),
|
||||
versions,
|
||||
)
|
||||
@@ -700,14 +699,14 @@ class EnvRegistry:
|
||||
# If we've requested a version less than the
|
||||
# most recent version it's considered deprecated.
|
||||
# Otherwise it isn't registered.
|
||||
if int(version) < getitem(max(versions), 0):
|
||||
if int(version) < getitem(max(versions), 0): # type: ignore
|
||||
raise error.DeprecatedEnv(version_not_found_error_msg)
|
||||
else:
|
||||
raise error.UnregisteredEnv(version_not_found_error_msg)
|
||||
|
||||
return self.env_specs[id]
|
||||
|
||||
def register(self, id, **kwargs):
|
||||
def register(self, id: str, **kwargs) -> None:
|
||||
# Match ID and and get environment parts
|
||||
match = env_id_re.fullmatch(id)
|
||||
if match is None:
|
||||
@@ -715,7 +714,6 @@ class EnvRegistry:
|
||||
f"Attempted to register malformed environment ID: {id.encode('utf-8')}. "
|
||||
f"(Currently all IDs must be of the form {env_id_re.pattern}.)"
|
||||
)
|
||||
|
||||
if self._ns is not None:
|
||||
namespace, name, version = match.group("namespace", "name", "version")
|
||||
if namespace is not None:
|
||||
@@ -766,7 +764,7 @@ class EnvRegistry:
|
||||
return versions
|
||||
|
||||
@contextlib.contextmanager
|
||||
def namespace(self, ns):
|
||||
def namespace(self, ns: str):
|
||||
self._ns = ns
|
||||
yield
|
||||
self._ns = None
|
||||
@@ -776,36 +774,38 @@ class EnvRegistry:
|
||||
registry = EnvRegistry()
|
||||
|
||||
|
||||
def register(id, **kwargs):
|
||||
def register(id: str, **kwargs) -> None:
|
||||
return registry.register(id, **kwargs)
|
||||
|
||||
|
||||
def make(id, **kwargs):
|
||||
def make(id: str, **kwargs) -> Env:
|
||||
return registry.make(id, **kwargs)
|
||||
|
||||
|
||||
def spec(id):
|
||||
def spec(id: str) -> EnvSpec:
|
||||
return registry.spec(id)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def namespace(ns):
|
||||
def namespace(ns: str):
|
||||
with registry.namespace(ns):
|
||||
yield
|
||||
|
||||
|
||||
def load_env_plugins(entry_point="gym.envs"):
|
||||
def load_env_plugins(entry_point: str = "gym.envs") -> None:
|
||||
# Load third-party environments
|
||||
for plugin in metadata.entry_points().get(entry_point, []):
|
||||
# Python 3.8 doesn't support plugin.module, plugin.attr
|
||||
# So we'll have to try and parse this ourselves
|
||||
try:
|
||||
module, attr = plugin.module, plugin.attr
|
||||
module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
|
||||
except AttributeError:
|
||||
if ":" in plugin.value:
|
||||
module, attr = plugin.value.split(":", maxsplit=1)
|
||||
else:
|
||||
module, attr = plugin.value, None
|
||||
except:
|
||||
module, attr = None, None
|
||||
finally:
|
||||
if attr is None:
|
||||
raise error.Error(
|
||||
|
||||
Reference in New Issue
Block a user