mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 14:26:33 +00:00
Add types to gym.make("CartPole-v1") and such (#2594)
* Add typing to gym.make via literals * Fix for type object not subscriptable * Use Env[ObsType, ActType] instead of specific types to account for wrapping * Move make and all overloads to `registration.py` * Hack around py37 typing checks
This commit is contained in:
@@ -23,7 +23,7 @@ repos:
|
||||
hooks:
|
||||
- id: flake8
|
||||
args:
|
||||
- --ignore=E203,E402,E712,E722,E731,E741,F401,F403,F405,F524,F841,W503
|
||||
- --ignore=E203,E402,E712,E722,E731,E741,F401,F403,F405,F524,F841,W503,E302,E704
|
||||
- --max-complexity=30
|
||||
- --max-line-length=456
|
||||
- --show-source
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
import copy
|
||||
@@ -5,7 +7,19 @@ import difflib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import contextlib
|
||||
from typing import Callable, Type, Optional, Union, Tuple, Generator, Sequence, cast
|
||||
from typing import (
|
||||
Callable,
|
||||
Type,
|
||||
Optional,
|
||||
Union,
|
||||
Tuple,
|
||||
Generator,
|
||||
Sequence,
|
||||
cast,
|
||||
SupportsFloat,
|
||||
overload,
|
||||
Any,
|
||||
)
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
import importlib_metadata as metadata # type: ignore
|
||||
@@ -16,10 +30,21 @@ from dataclasses import dataclass, field, InitVar
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym import error, logger, Env
|
||||
from gym.envs.__relocated__ import internal_env_relocation_map
|
||||
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
|
||||
class Literal(str):
|
||||
def __class_getitem__(cls, item):
|
||||
return Any
|
||||
|
||||
|
||||
ENV_ID_RE: re.Pattern = re.compile(
|
||||
r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
|
||||
)
|
||||
@@ -588,7 +613,66 @@ def register(id: str, **kwargs) -> None:
|
||||
return registry.register(id, **kwargs)
|
||||
|
||||
|
||||
def make(id: str, **kwargs) -> Env:
|
||||
# fmt: off
|
||||
# Continuous
|
||||
# ----------------------------------------
|
||||
|
||||
@overload
|
||||
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
@overload
|
||||
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
@overload
|
||||
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
|
||||
# Box2d
|
||||
# ----------------------------------------
|
||||
|
||||
@overload
|
||||
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
@overload
|
||||
def make(id: Literal["CarRacing-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | Sequence[SupportsFloat]]: ...
|
||||
|
||||
# Toy Text
|
||||
# ----------------------------------------
|
||||
|
||||
@overload
|
||||
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
@overload
|
||||
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, np.ndarray | int]: ...
|
||||
|
||||
# Mujoco
|
||||
# ----------------------------------------
|
||||
@overload
|
||||
def make(id: Literal[
|
||||
"Reacher-v2",
|
||||
"Pusher-v2",
|
||||
"Thrower-v2",
|
||||
"Striker-v2",
|
||||
"InvertedPendulum-v2",
|
||||
"InvertedDoublePendulum-v2",
|
||||
"HalfCheetah-v2", "HalfCheetah-v3",
|
||||
"Hopper-v2", "Hopper-v3",
|
||||
"Swimmer-v2", "Swimmer-v3",
|
||||
"Walker2d-v2", "Walker2d-v3",
|
||||
"Ant-v2"
|
||||
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
|
||||
|
||||
# ----------------------------------------
|
||||
|
||||
@overload
|
||||
def make(id: str, **kwargs) -> "Env": ...
|
||||
# fmt: on
|
||||
def make(id: str, **kwargs) -> "Env":
|
||||
return registry.make(id, **kwargs)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user