Typing/basics (#2529)

* Typing in gym/envs/registration.py

* Add registration to type checked list

* Adds type hints to space.py

* Typing in gym.core.Env

* Typing in seeding.py

* fixup Typing after rebase

* revert accidental change

* Install dependencies in pyright runner

* fix: can only install dependencies after checkout

* fix: install types in a venv

* fix path

* skip env activation, install directly from venv interpreter

* absolute path to venv

* use central python installation

* skip one more typecheck

* cleanup gh actions .yml

* Add py.typed to signal using sources for typechecking

* black!

Co-authored-by: sj_petterson <sj_petterson@gmail.com>
Co-authored-by: J K Terry <justinkterry@gmail.com>
This commit is contained in:
Ilya Kamen
2021-12-22 19:12:57 +01:00
committed by GitHub
parent 4966c5fccf
commit 65eeb73366
7 changed files with 104 additions and 72 deletions

View File

@@ -3,9 +3,10 @@ import sys
import copy
import importlib
import contextlib
from typing import Callable, Type, Optional, Union, Dict, Set, Tuple, Generator
if sys.version_info < (3, 8):
import importlib_metadata as metadata
import importlib_metadata as metadata # type: ignore
else:
import importlib.metadata as metadata
@@ -13,9 +14,7 @@ from collections import defaultdict
from collections.abc import MutableMapping
from operator import getitem
from typing import Optional, Union, Dict, Set, Tuple, Generator
from gym import error, logger
from gym import error, logger, Env
# This format is true today, but it's *not* an official spec.
# [username/](env-name)-v(version) env-name is group 1, version is group 2
@@ -325,7 +324,7 @@ internal_env_namespace_relocation_map: Dict[str, Tuple[str, str]] = {
}
def load(name):
def load(name: str) -> Type:
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
fn = getattr(mod, attr_name)
@@ -337,25 +336,25 @@ class EnvSpec:
to register the parameters for official evaluations.
Args:
id (str): The official environment ID
entry_point (Optional[str]): The Python entrypoint of the environment class (e.g. module.name:Class)
reward_threshold (Optional[int]): The reward threshold before the task is considered solved
nondeterministic (bool): Whether this environment is non-deterministic even after seeding
max_episode_steps (Optional[int]): The maximum number of steps that an episode can consist of
order_enforce (Optional[int]): Whether to wrap the environment in an orderEnforcing wrapper
kwargs (dict): The kwargs to pass to the environment class
id: The official environment ID
entry_point: The Python entrypoint of the environment class (e.g. module.name:Class)
reward_threshold: The reward threshold before the task is considered solved
nondeterministic: Whether this environment is non-deterministic even after seeding
max_episode_steps: The maximum number of steps that an episode can consist of
order_enforce: Whether to wrap the environment in an orderEnforcing wrapper
kwargs: The kwargs to pass to the environment class
"""
def __init__(
self,
id,
entry_point=None,
reward_threshold=None,
nondeterministic=False,
max_episode_steps=None,
order_enforce=True,
kwargs=None,
id: str,
entry_point: Union[Callable, str, None] = None,
reward_threshold: Optional[int] = None,
nondeterministic: bool = False,
max_episode_steps: Optional[int] = None,
order_enforce: Optional[int] = True,
kwargs: dict = None,
):
self.id = id
self.entry_point = entry_point
@@ -373,7 +372,7 @@ class EnvSpec:
)
self._env_name = match.group("name")
def make(self, **kwargs):
def make(self, **kwargs) -> Env:
"""Instantiates an instance of the environment with appropriate kwargs"""
if self.entry_point is None:
raise error.Error(
@@ -468,7 +467,7 @@ class EnvSpecTree(MutableMapping):
match = env_id_re.fullmatch(key)
if match is None:
raise KeyError(f"Malformed environment spec key {key}.")
return match.group("namespace", "name", "version")
return match.group("namespace", "name", "version") # type: ignore #
def _exists(self, namespace: Optional[str], name: str, version: str) -> bool:
# Helper which can look if an ID exists in the tree.
@@ -581,9 +580,9 @@ class EnvRegistry:
def __init__(self):
self.env_specs = EnvSpecTree()
self._ns = None
self._ns: Optional[str] = None
def make(self, path, **kwargs):
def make(self, path: str, **kwargs) -> Env:
if len(kwargs) > 0:
logger.info("Making new env: %s (%s)", path, kwargs)
else:
@@ -616,7 +615,7 @@ class EnvRegistry:
def all(self):
return self.env_specs.values()
def spec(self, path):
def spec(self, path: str) -> EnvSpec:
if ":" in path:
mod_name, _, id = path.partition(":")
try:
@@ -690,7 +689,7 @@ class EnvRegistry:
version_not_found_error_msg += ", ".join(
map(
lambda version: env_id_from_parts(
getitem(version, 1), name, getitem(version, 0)
getitem(version, 1), name, getitem(version, 0) # type: ignore
),
versions,
)
@@ -700,14 +699,14 @@ class EnvRegistry:
# If we've requested a version less than the
# most recent version it's considered deprecated.
# Otherwise it isn't registered.
if int(version) < getitem(max(versions), 0):
if int(version) < getitem(max(versions), 0): # type: ignore
raise error.DeprecatedEnv(version_not_found_error_msg)
else:
raise error.UnregisteredEnv(version_not_found_error_msg)
return self.env_specs[id]
def register(self, id, **kwargs):
def register(self, id: str, **kwargs) -> None:
# Match ID and and get environment parts
match = env_id_re.fullmatch(id)
if match is None:
@@ -715,7 +714,6 @@ class EnvRegistry:
f"Attempted to register malformed environment ID: {id.encode('utf-8')}. "
f"(Currently all IDs must be of the form {env_id_re.pattern}.)"
)
if self._ns is not None:
namespace, name, version = match.group("namespace", "name", "version")
if namespace is not None:
@@ -766,7 +764,7 @@ class EnvRegistry:
return versions
@contextlib.contextmanager
def namespace(self, ns):
def namespace(self, ns: str):
self._ns = ns
yield
self._ns = None
@@ -776,36 +774,38 @@ class EnvRegistry:
registry = EnvRegistry()
def register(id, **kwargs):
def register(id: str, **kwargs) -> None:
return registry.register(id, **kwargs)
def make(id, **kwargs):
def make(id: str, **kwargs) -> Env:
return registry.make(id, **kwargs)
def spec(id):
def spec(id: str) -> EnvSpec:
return registry.spec(id)
@contextlib.contextmanager
def namespace(ns):
def namespace(ns: str):
with registry.namespace(ns):
yield
def load_env_plugins(entry_point="gym.envs"):
def load_env_plugins(entry_point: str = "gym.envs") -> None:
# Load third-party environments
for plugin in metadata.entry_points().get(entry_point, []):
# Python 3.8 doesn't support plugin.module, plugin.attr
# So we'll have to try and parse this ourselves
try:
module, attr = plugin.module, plugin.attr
module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
except AttributeError:
if ":" in plugin.value:
module, attr = plugin.value.split(":", maxsplit=1)
else:
module, attr = plugin.value, None
except:
module, attr = None, None
finally:
if attr is None:
raise error.Error(