Add types to gym.make("CartPole-v1") and such (#2594)

* Add typing to gym.make via literals

* Fix for type object not subscriptable

* Use Env[ObsType, ActType] instead of specific types to account for wrapping

* Move make and all overloads to `registration.py`

* Hack around py37 typing checks
This commit is contained in:
Ilya Kamen
2022-02-16 18:29:24 +01:00
committed by GitHub
parent 1400cfbcac
commit 27108c98d9
2 changed files with 87 additions and 3 deletions

View File

@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
args:
- --ignore=E203,E402,E712,E722,E731,E741,F401,F403,F405,F524,F841,W503
- --ignore=E203,E402,E712,E722,E731,E741,F401,F403,F405,F524,F841,W503,E302,E704
- --max-complexity=30
- --max-line-length=456
- --show-source

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import re
import sys
import copy
@@ -5,7 +7,19 @@ import difflib
import importlib
import importlib.util
import contextlib
from typing import Callable, Type, Optional, Union, Tuple, Generator, Sequence, cast
from typing import (
Callable,
Type,
Optional,
Union,
Tuple,
Generator,
Sequence,
cast,
SupportsFloat,
overload,
Any,
)
if sys.version_info < (3, 10):
import importlib_metadata as metadata # type: ignore
@@ -16,10 +30,21 @@ from dataclasses import dataclass, field, InitVar
from collections import defaultdict
from collections.abc import MutableMapping
import numpy as np
from gym import error, logger, Env
from gym.envs.__relocated__ import internal_env_relocation_map
if sys.version_info >= (3, 8):
from typing import Literal
else:
class Literal(str):
def __class_getitem__(cls, item):
return Any
ENV_ID_RE: re.Pattern = re.compile(
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
)
@@ -588,7 +613,66 @@ def register(id: str, **kwargs) -> None:
return registry.register(id, **kwargs)
def make(id: str, **kwargs) -> Env:
# fmt: off
# Continuous
# ----------------------------------------
@overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Box2d
# ----------------------------------------
@overload
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
@overload
def make(id: Literal["CarRacing-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
# Toy Text
# ----------------------------------------
@overload
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
@overload
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
# Mujoco
# ----------------------------------------
@overload
def make(id: Literal[
"Reacher-v2",
"Pusher-v2",
"Thrower-v2",
"Striker-v2",
"InvertedPendulum-v2",
"InvertedDoublePendulum-v2",
"HalfCheetah-v2", "HalfCheetah-v3",
"Hopper-v2", "Hopper-v3",
"Swimmer-v2", "Swimmer-v3",
"Walker2d-v2", "Walker2d-v3",
"Ant-v2"
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
# ----------------------------------------
@overload
def make(id: str, **kwargs) -> "Env": ...
# fmt: on
def make(id: str, **kwargs) -> "Env":
return registry.make(id, **kwargs)