[WIP] Typing fixes (#2939)

* Update pyproject.toml

* Maybe fix typing for frozen lake and taxi

* Clean up pyproject a bit

* Reenable box2d pyright

* Disable box2d pyright

* Fix typing for toy text envs

* Narrow down pyright excludes for almost everything except mujoco

* Fix a wrapper test typing

* Remove env checker from excludes

* Remove redundant type hints that confuse pyright, and fix a typo.

* Remove test_spaces exclude

* Remove some redundant warnings

* Change pyright ignore to more specific

* Fix a private variable that was being explicitly exported. It is still exported for backwards compatibility, but internal code now uses a public variable.

* re-export the private variable

* Fix lunar_lander typing

* Remove lunar lander from excludes

* Small fix
This commit is contained in:
Ariel Kwiatkowski
2022-06-30 18:04:14 +02:00
committed by GitHub
parent 2017f3ed9e
commit 9c1d288a2d
11 changed files with 110 additions and 59 deletions

View File

@@ -220,7 +220,7 @@ class LunarLander(gym.Env, EzPickle):
self.isopen = True
self.world = Box2D.b2World(gravity=(0, gravity))
self.moon = None
self.lander = None
self.lander: Optional[Box2D.b2Body] = None
self.particles = []
self.prev_reward = None
@@ -335,7 +335,7 @@ class LunarLander(gym.Env, EzPickle):
self.moon.color2 = (0.0, 0.0, 0.0)
initial_y = VIEWPORT_H / SCALE
self.lander = self.world.CreateDynamicBody(
self.lander: Box2D.b2Body = self.world.CreateDynamicBody(
position=(VIEWPORT_W / SCALE / 2, initial_y),
angle=0.0,
fixtures=fixtureDef(
@@ -428,6 +428,7 @@ class LunarLander(gym.Env, EzPickle):
def step(self, action):
# Update wind
assert self.lander is not None, "You forgot to call reset()"
if self.enable_wind and not (
self.legs[0].ground_contact or self.legs[1].ground_contact
):
@@ -603,6 +604,10 @@ class LunarLander(gym.Env, EzPickle):
if self.clock is None:
self.clock = pygame.time.Clock()
assert (
self.screen is not None
), "Something went wrong with pygame, there is no screen to render"
self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
pygame.transform.scale(self.surf, (SCALE, SCALE))

View File

@@ -30,6 +30,27 @@ MAPS = {
}
# DFS to check that it's a valid path.
def is_valid(board: List[List[str]], max_size: int) -> bool:
    """Return whether a goal tile is reachable from the top-left cell.

    Runs an iterative depth-first search starting at ``(0, 0)``, moving in the
    four cardinal directions and never stepping onto a hole.

    Args:
        board: Grid indexed as ``board[row][col]``. ``"G"`` marks the goal and
            ``"H"`` a hole; every other value is treated as walkable.
        max_size: Side length of the grid. Cells whose row or column falls
            outside ``[0, max_size)`` are ignored.

    Returns:
        True if a ``"G"`` cell can be reached from ``(0, 0)``, otherwise False.
    """
    # The search below only inspects neighbours, so a goal sitting on the start
    # cell (e.g. a 1x1 board) would otherwise never be detected.
    if max_size > 0 and board[0][0] == "G":
        return True

    directions = ((1, 0), (0, 1), (-1, 0), (0, -1))
    frontier = [(0, 0)]
    discovered = set()
    while frontier:
        r, c = frontier.pop()
        if (r, c) in discovered:
            continue
        discovered.add((r, c))
        for x, y in directions:
            r_new = r + x
            c_new = c + y
            # Skip neighbours that fall off the board.
            if not (0 <= r_new < max_size and 0 <= c_new < max_size):
                continue
            tile = board[r_new][c_new]
            if tile == "G":
                return True
            if tile != "H":
                frontier.append((r_new, c_new))
    return False
def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
"""Generates a random valid map (one that has a path from start to goal)
@@ -41,34 +62,15 @@ def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
A random valid map
"""
valid = False
# DFS to check that it's a valid path.
def is_valid(res):
frontier, discovered = [], set()
frontier.append((0, 0))
while frontier:
r, c = frontier.pop()
if not (r, c) in discovered:
discovered.add((r, c))
directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
for x, y in directions:
r_new = r + x
c_new = c + y
if r_new < 0 or r_new >= size or c_new < 0 or c_new >= size:
continue
if res[r_new][c_new] == "G":
return True
if res[r_new][c_new] != "H":
frontier.append((r_new, c_new))
return False
board = [] # initialize to make pyright happy
while not valid:
p = min(1, p)
res = np.random.choice(["F", "H"], (size, size), p=[p, 1 - p])
res[0][0] = "S"
res[-1][-1] = "G"
valid = is_valid(res)
return ["".join(x) for x in res]
board = np.random.choice(["F", "H"], (size, size), p=[p, 1 - p])
board[0][0] = "S"
board[-1][-1] = "G"
valid = is_valid(board, size)
return ["".join(x) for x in board]
class FrozenLakeEnv(Env):
@@ -296,6 +298,11 @@ class FrozenLakeEnv(Env):
self.window_surface = pygame.display.set_mode(self.window_size)
elif mode in {"rgb_array", "single_rgb_array"}:
self.window_surface = pygame.Surface(self.window_size)
assert (
self.window_surface is not None
), "Something went wrong with pygame. This should never happen."
if self.clock is None:
self.clock = pygame.time.Clock()
if self.hole_img is None:
@@ -349,6 +356,7 @@ class FrozenLakeEnv(Env):
start_img = pygame.transform.scale(self.start_img, (small_cell_w, small_cell_h))
desc = self.desc.tolist()
assert isinstance(desc, list), f"desc should be a list or an array, got {desc}"
for y in range(self.nrow):
for x in range(self.ncol):
rect = (x * cell_width, y * cell_height, cell_width, cell_height)

View File

@@ -307,6 +307,10 @@ class TaxiEnv(Env):
self.window = pygame.display.set_mode(WINDOW_SIZE)
elif mode in {"rgb_array", "single_rgb_array"}:
self.window = pygame.Surface(WINDOW_SIZE)
assert (
self.window is not None
), "Something went wrong with pygame. This should never happen."
if self.clock is None:
self.clock = pygame.time.Clock()
if self.taxi_imgs is None:

View File

@@ -6,7 +6,7 @@ from gym.vector.utils.shared_memory import (
read_from_shared_memory,
write_to_shared_memory,
)
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space, iterate
from gym.vector.utils.spaces import BaseGymSpaces, _BaseGymSpaces, batch_space, iterate
__all__ = [
"CloudpickleWrapper",
@@ -16,6 +16,7 @@ __all__ = [
"create_shared_memory",
"read_from_shared_memory",
"write_to_shared_memory",
"BaseGymSpaces",
"_BaseGymSpaces",
"batch_space",
"iterate",

View File

@@ -9,8 +9,9 @@ import numpy as np
from gym.error import CustomSpaceError
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
_BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
_BaseGymSpaces = BaseGymSpaces
__all__ = ["BaseGymSpaces", "_BaseGymSpaces", "batch_space", "iterate"]
@singledispatch
@@ -179,7 +180,7 @@ def _iterate_tuple(space, items):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
isinstance(subspace, Space)
and (not isinstance(subspace, _BaseGymSpaces + (Tuple, Dict)))
and (not isinstance(subspace, BaseGymSpaces + (Tuple, Dict)))
for subspace in space.spaces
):
return iter(items)

View File

@@ -1,8 +1,6 @@
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
from typing import Tuple, Union
import gym
from gym.core import ActType, ObsType
from gym.core import ActType
from gym.utils.passive_env_checker import (
check_action_space,
check_observation_space,
@@ -32,7 +30,7 @@ class PassiveEnvChecker(gym.Wrapper):
self.checked_step = False
self.checked_render = False
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def step(self, action: ActType):
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self.checked_step is False:
self.checked_step = True
@@ -40,7 +38,7 @@ class PassiveEnvChecker(gym.Wrapper):
else:
return self.env.step(action)
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
def reset(self, **kwargs):
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self.checked_reset is False:
self.checked_reset = True

View File

@@ -9,15 +9,40 @@ exclude = [
"**/node_modules",
"**/__pycache__",
"gym/envs/box2d/**",
# "gym/envs/classic_control/**",
"gym/envs/box2d/bipedal_walker.py",
"gym/envs/box2d/car_racing.py",
"gym/spaces/graph.py",
"gym/envs/mujoco/**",
"gym/envs/toy_text/**",
"gym/spaces/**",
"gym/utils/**",
"gym/vector/**",
"gym/wrappers/**",
"tests/**"
"gym/utils/play.py",
"gym/vector/async_vector_env.py",
"gym/vector/utils/__init__.py",
"gym/wrappers/atari_preprocessing.py",
"gym/wrappers/gray_scale_observation.py",
"gym/wrappers/human_rendering.py",
"gym/wrappers/normalize.py",
"gym/wrappers/pixel_observation.py",
"gym/wrappers/record_video.py",
"gym/wrappers/monitoring/video_recorder.py",
"gym/wrappers/resize_observation.py",
"tests/envs/test_env_implementation.py",
"tests/utils/test_play.py",
"tests/vector/test_async_vector_env.py",
"tests/vector/test_shared_memory.py",
"tests/vector/test_spaces.py",
"tests/vector/test_sync_vector_env.py",
"tests/vector/test_vector_env.py",
"tests/wrappers/test_gray_scale_observation.py",
"tests/wrappers/test_order_enforcing.py",
"tests/wrappers/test_record_episode_statistics.py",
"tests/wrappers/test_resize_observation.py",
"tests/wrappers/test_time_aware_observation.py",
"tests/wrappers/test_video_recorder.py",
]
strict = [
@@ -41,4 +66,5 @@ reportPrivateUsage = "warning"
reportUntypedFunctionDecorator = "none"
reportMissingTypeStubs = false
reportUnboundVariable = "warning"
reportGeneralTypeIssues ="none"
reportGeneralTypeIssues = "none"
reportInvalidTypeVarUse = "none"

View File

@@ -797,9 +797,11 @@ def test_space_legacy_state_pickling():
space.__setstate__(legacy_state)
assert space.shape == legacy_state["shape"]
assert space._shape == legacy_state["shape"]
assert space._shape == legacy_state["shape"] # pyright: reportPrivateUsage=false
assert space.np_random == legacy_state["np_random"]
assert space._np_random == legacy_state["np_random"]
assert (
space._np_random == legacy_state["np_random"]
) # pyright: reportPrivateUsage=false
assert space.n == 3
assert space.dtype == legacy_state["dtype"]

View File

@@ -5,7 +5,7 @@ import pytest
from gym.spaces import Dict, Tuple
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
from gym.vector.utils.spaces import _BaseGymSpaces
from gym.vector.utils.spaces import BaseGymSpaces
from tests.vector.utils import spaces
@@ -57,7 +57,7 @@ def test_concatenate(space):
)
def test_create_empty_array(space, n):
def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == (n,) + space.shape
@@ -87,7 +87,7 @@ def test_create_empty_array(space, n):
)
def test_create_empty_array_zeros(space, n):
def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == (n,) + space.shape
@@ -117,7 +117,7 @@ def test_create_empty_array_zeros(space, n):
)
def test_create_empty_array_none_shape_ones(space):
def assert_nested_type(arr, space):
if isinstance(space, _BaseGymSpaces):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == space.shape

View File

@@ -13,7 +13,7 @@ from gym.vector.utils.shared_memory import (
read_from_shared_memory,
write_to_shared_memory,
)
from gym.vector.utils.spaces import _BaseGymSpaces
from gym.vector.utils.spaces import BaseGymSpaces
from tests.vector.utils import custom_spaces, spaces
expected_types = [
@@ -147,7 +147,7 @@ def test_read_from_shared_memory(space):
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
)
elif isinstance(space, _BaseGymSpaces):
elif isinstance(space, BaseGymSpaces):
assert isinstance(lhs, np.ndarray)
assert lhs.shape == ((n,) + space.shape)
assert lhs.dtype == space.dtype

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple
import numpy as np
import pytest
@@ -9,7 +9,9 @@ from gym.wrappers.filter_observation import FilterObservation
class FakeEnvironment(gym.Env):
def __init__(self, render_mode=None, observation_keys=("state")):
def __init__(
self, render_mode=None, observation_keys: Tuple[str, ...] = ("state",)
):
self.observation_space = spaces.Dict(
{
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
@@ -23,10 +25,16 @@ class FakeEnvironment(gym.Env):
image_shape = (32, 32, 3)
return np.zeros(image_shape, dtype=np.uint8)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed)
observation = self.observation_space.sample()
return observation
return observation if not return_info else (observation, {})
def step(self, action):
del action
@@ -80,7 +88,5 @@ class TestFilterObservation:
):
env = FakeEnvironment(observation_keys=("key1", "key2"))
ValueError
with pytest.raises(error_type, match=error_match):
FilterObservation(env, filter_keys=filter_keys)