mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 01:50:19 +00:00
[WIP] Typing fixes (#2939)
* Update pyproject.toml * Maybe fix typing for frozen lake and taxi * Clean up pyproject a bit * Reenable box2d pyright * Disable box2d pyright * Fix typing for toy text envs * Narrow down pyright excludes for almost everything except mujoco * Fix a wrapper test typing * Remove env checker from excludes * Remove redundant type hints which mess up pyright. And a typo fix. * Remove test_spaces exclude * Remove some redundant warnings * Change pyright ignore to more specific * Fix a weird private variable which gets explicitly exported for whatever reason. It's still exported for backwards compatibility, but internal code now uses a public variable. * re-export the private variable * Fix lunar_lander typing * Remove lunar lander from excludes * Small fix
This commit is contained in:
committed by
GitHub
parent
2017f3ed9e
commit
9c1d288a2d
@@ -220,7 +220,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.isopen = True
|
||||
self.world = Box2D.b2World(gravity=(0, gravity))
|
||||
self.moon = None
|
||||
self.lander = None
|
||||
self.lander: Optional[Box2D.b2Body] = None
|
||||
self.particles = []
|
||||
|
||||
self.prev_reward = None
|
||||
@@ -335,7 +335,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.moon.color2 = (0.0, 0.0, 0.0)
|
||||
|
||||
initial_y = VIEWPORT_H / SCALE
|
||||
self.lander = self.world.CreateDynamicBody(
|
||||
self.lander: Box2D.b2Body = self.world.CreateDynamicBody(
|
||||
position=(VIEWPORT_W / SCALE / 2, initial_y),
|
||||
angle=0.0,
|
||||
fixtures=fixtureDef(
|
||||
@@ -428,6 +428,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
|
||||
def step(self, action):
|
||||
# Update wind
|
||||
assert self.lander is not None, "You forgot to call reset()"
|
||||
if self.enable_wind and not (
|
||||
self.legs[0].ground_contact or self.legs[1].ground_contact
|
||||
):
|
||||
@@ -603,6 +604,10 @@ class LunarLander(gym.Env, EzPickle):
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
assert (
|
||||
self.screen is not None
|
||||
), "Something went wrong with pygame, there is no screen to render"
|
||||
|
||||
self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
|
||||
|
||||
pygame.transform.scale(self.surf, (SCALE, SCALE))
|
||||
|
@@ -30,6 +30,27 @@ MAPS = {
|
||||
}
|
||||
|
||||
|
||||
# DFS to check that it's a valid path.
|
||||
def is_valid(board: List[List[str]], max_size: int) -> bool:
|
||||
frontier, discovered = [], set()
|
||||
frontier.append((0, 0))
|
||||
while frontier:
|
||||
r, c = frontier.pop()
|
||||
if not (r, c) in discovered:
|
||||
discovered.add((r, c))
|
||||
directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
|
||||
for x, y in directions:
|
||||
r_new = r + x
|
||||
c_new = c + y
|
||||
if r_new < 0 or r_new >= max_size or c_new < 0 or c_new >= max_size:
|
||||
continue
|
||||
if board[r_new][c_new] == "G":
|
||||
return True
|
||||
if board[r_new][c_new] != "H":
|
||||
frontier.append((r_new, c_new))
|
||||
return False
|
||||
|
||||
|
||||
def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
|
||||
"""Generates a random valid map (one that has a path from start to goal)
|
||||
|
||||
@@ -41,34 +62,15 @@ def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
|
||||
A random valid map
|
||||
"""
|
||||
valid = False
|
||||
|
||||
# DFS to check that it's a valid path.
|
||||
def is_valid(res):
|
||||
frontier, discovered = [], set()
|
||||
frontier.append((0, 0))
|
||||
while frontier:
|
||||
r, c = frontier.pop()
|
||||
if not (r, c) in discovered:
|
||||
discovered.add((r, c))
|
||||
directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
|
||||
for x, y in directions:
|
||||
r_new = r + x
|
||||
c_new = c + y
|
||||
if r_new < 0 or r_new >= size or c_new < 0 or c_new >= size:
|
||||
continue
|
||||
if res[r_new][c_new] == "G":
|
||||
return True
|
||||
if res[r_new][c_new] != "H":
|
||||
frontier.append((r_new, c_new))
|
||||
return False
|
||||
board = [] # initialize to make pyright happy
|
||||
|
||||
while not valid:
|
||||
p = min(1, p)
|
||||
res = np.random.choice(["F", "H"], (size, size), p=[p, 1 - p])
|
||||
res[0][0] = "S"
|
||||
res[-1][-1] = "G"
|
||||
valid = is_valid(res)
|
||||
return ["".join(x) for x in res]
|
||||
board = np.random.choice(["F", "H"], (size, size), p=[p, 1 - p])
|
||||
board[0][0] = "S"
|
||||
board[-1][-1] = "G"
|
||||
valid = is_valid(board, size)
|
||||
return ["".join(x) for x in board]
|
||||
|
||||
|
||||
class FrozenLakeEnv(Env):
|
||||
@@ -296,6 +298,11 @@ class FrozenLakeEnv(Env):
|
||||
self.window_surface = pygame.display.set_mode(self.window_size)
|
||||
elif mode in {"rgb_array", "single_rgb_array"}:
|
||||
self.window_surface = pygame.Surface(self.window_size)
|
||||
|
||||
assert (
|
||||
self.window_surface is not None
|
||||
), "Something went wrong with pygame. This should never happen."
|
||||
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
if self.hole_img is None:
|
||||
@@ -349,6 +356,7 @@ class FrozenLakeEnv(Env):
|
||||
start_img = pygame.transform.scale(self.start_img, (small_cell_w, small_cell_h))
|
||||
|
||||
desc = self.desc.tolist()
|
||||
assert isinstance(desc, list), f"desc should be a list or an array, got {desc}"
|
||||
for y in range(self.nrow):
|
||||
for x in range(self.ncol):
|
||||
rect = (x * cell_width, y * cell_height, cell_width, cell_height)
|
||||
|
@@ -307,6 +307,10 @@ class TaxiEnv(Env):
|
||||
self.window = pygame.display.set_mode(WINDOW_SIZE)
|
||||
elif mode in {"rgb_array", "single_rgb_array"}:
|
||||
self.window = pygame.Surface(WINDOW_SIZE)
|
||||
|
||||
assert (
|
||||
self.window is not None
|
||||
), "Something went wrong with pygame. This should never happen."
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
if self.taxi_imgs is None:
|
||||
|
@@ -6,7 +6,7 @@ from gym.vector.utils.shared_memory import (
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space, iterate
|
||||
from gym.vector.utils.spaces import BaseGymSpaces, _BaseGymSpaces, batch_space, iterate
|
||||
|
||||
__all__ = [
|
||||
"CloudpickleWrapper",
|
||||
@@ -16,6 +16,7 @@ __all__ = [
|
||||
"create_shared_memory",
|
||||
"read_from_shared_memory",
|
||||
"write_to_shared_memory",
|
||||
"BaseGymSpaces",
|
||||
"_BaseGymSpaces",
|
||||
"batch_space",
|
||||
"iterate",
|
||||
|
@@ -9,8 +9,9 @@ import numpy as np
|
||||
from gym.error import CustomSpaceError
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
||||
|
||||
_BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
|
||||
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
|
||||
BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
|
||||
_BaseGymSpaces = BaseGymSpaces
|
||||
__all__ = ["BaseGymSpaces", "_BaseGymSpaces", "batch_space", "iterate"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -179,7 +180,7 @@ def _iterate_tuple(space, items):
|
||||
# If this is a tuple of custom subspaces only, then simply iterate over items
|
||||
if all(
|
||||
isinstance(subspace, Space)
|
||||
and (not isinstance(subspace, _BaseGymSpaces + (Tuple, Dict)))
|
||||
and (not isinstance(subspace, BaseGymSpaces + (Tuple, Dict)))
|
||||
for subspace in space.spaces
|
||||
):
|
||||
return iter(items)
|
||||
|
@@ -1,8 +1,6 @@
|
||||
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
|
||||
from typing import Tuple, Union
|
||||
|
||||
import gym
|
||||
from gym.core import ActType, ObsType
|
||||
from gym.core import ActType
|
||||
from gym.utils.passive_env_checker import (
|
||||
check_action_space,
|
||||
check_observation_space,
|
||||
@@ -32,7 +30,7 @@ class PassiveEnvChecker(gym.Wrapper):
|
||||
self.checked_step = False
|
||||
self.checked_render = False
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||
def step(self, action: ActType):
|
||||
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
|
||||
if self.checked_step is False:
|
||||
self.checked_step = True
|
||||
@@ -40,7 +38,7 @@ class PassiveEnvChecker(gym.Wrapper):
|
||||
else:
|
||||
return self.env.step(action)
|
||||
|
||||
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
|
||||
if self.checked_reset is False:
|
||||
self.checked_reset = True
|
||||
|
@@ -9,15 +9,40 @@ exclude = [
|
||||
"**/node_modules",
|
||||
"**/__pycache__",
|
||||
|
||||
"gym/envs/box2d/**",
|
||||
# "gym/envs/classic_control/**",
|
||||
"gym/envs/box2d/bipedal_walker.py",
|
||||
"gym/envs/box2d/car_racing.py",
|
||||
|
||||
"gym/spaces/graph.py",
|
||||
|
||||
"gym/envs/mujoco/**",
|
||||
"gym/envs/toy_text/**",
|
||||
"gym/spaces/**",
|
||||
"gym/utils/**",
|
||||
"gym/vector/**",
|
||||
"gym/wrappers/**",
|
||||
"tests/**"
|
||||
"gym/utils/play.py",
|
||||
|
||||
"gym/vector/async_vector_env.py",
|
||||
"gym/vector/utils/__init__.py",
|
||||
|
||||
"gym/wrappers/atari_preprocessing.py",
|
||||
"gym/wrappers/gray_scale_observation.py",
|
||||
"gym/wrappers/human_rendering.py",
|
||||
"gym/wrappers/normalize.py",
|
||||
"gym/wrappers/pixel_observation.py",
|
||||
"gym/wrappers/record_video.py",
|
||||
"gym/wrappers/monitoring/video_recorder.py",
|
||||
"gym/wrappers/resize_observation.py",
|
||||
|
||||
"tests/envs/test_env_implementation.py",
|
||||
"tests/utils/test_play.py",
|
||||
"tests/vector/test_async_vector_env.py",
|
||||
"tests/vector/test_shared_memory.py",
|
||||
"tests/vector/test_spaces.py",
|
||||
"tests/vector/test_sync_vector_env.py",
|
||||
"tests/vector/test_vector_env.py",
|
||||
"tests/wrappers/test_gray_scale_observation.py",
|
||||
"tests/wrappers/test_order_enforcing.py",
|
||||
"tests/wrappers/test_record_episode_statistics.py",
|
||||
"tests/wrappers/test_resize_observation.py",
|
||||
"tests/wrappers/test_time_aware_observation.py",
|
||||
"tests/wrappers/test_video_recorder.py",
|
||||
|
||||
]
|
||||
|
||||
strict = [
|
||||
@@ -41,4 +66,5 @@ reportPrivateUsage = "warning"
|
||||
reportUntypedFunctionDecorator = "none"
|
||||
reportMissingTypeStubs = false
|
||||
reportUnboundVariable = "warning"
|
||||
reportGeneralTypeIssues ="none"
|
||||
reportGeneralTypeIssues = "none"
|
||||
reportInvalidTypeVarUse = "none"
|
@@ -797,9 +797,11 @@ def test_space_legacy_state_pickling():
|
||||
space.__setstate__(legacy_state)
|
||||
|
||||
assert space.shape == legacy_state["shape"]
|
||||
assert space._shape == legacy_state["shape"]
|
||||
assert space._shape == legacy_state["shape"] # pyright: reportPrivateUsage=false
|
||||
assert space.np_random == legacy_state["np_random"]
|
||||
assert space._np_random == legacy_state["np_random"]
|
||||
assert (
|
||||
space._np_random == legacy_state["np_random"]
|
||||
) # pyright: reportPrivateUsage=false
|
||||
assert space.n == 3
|
||||
assert space.dtype == legacy_state["dtype"]
|
||||
|
||||
|
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from gym.spaces import Dict, Tuple
|
||||
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces
|
||||
from gym.vector.utils.spaces import BaseGymSpaces
|
||||
from tests.vector.utils import spaces
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ def test_concatenate(space):
|
||||
)
|
||||
def test_create_empty_array(space, n):
|
||||
def assert_nested_type(arr, space, n):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
if isinstance(space, BaseGymSpaces):
|
||||
assert isinstance(arr, np.ndarray)
|
||||
assert arr.dtype == space.dtype
|
||||
assert arr.shape == (n,) + space.shape
|
||||
@@ -87,7 +87,7 @@ def test_create_empty_array(space, n):
|
||||
)
|
||||
def test_create_empty_array_zeros(space, n):
|
||||
def assert_nested_type(arr, space, n):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
if isinstance(space, BaseGymSpaces):
|
||||
assert isinstance(arr, np.ndarray)
|
||||
assert arr.dtype == space.dtype
|
||||
assert arr.shape == (n,) + space.shape
|
||||
@@ -117,7 +117,7 @@ def test_create_empty_array_zeros(space, n):
|
||||
)
|
||||
def test_create_empty_array_none_shape_ones(space):
|
||||
def assert_nested_type(arr, space):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
if isinstance(space, BaseGymSpaces):
|
||||
assert isinstance(arr, np.ndarray)
|
||||
assert arr.dtype == space.dtype
|
||||
assert arr.shape == space.shape
|
||||
|
@@ -13,7 +13,7 @@ from gym.vector.utils.shared_memory import (
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces
|
||||
from gym.vector.utils.spaces import BaseGymSpaces
|
||||
from tests.vector.utils import custom_spaces, spaces
|
||||
|
||||
expected_types = [
|
||||
@@ -147,7 +147,7 @@ def test_read_from_shared_memory(space):
|
||||
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
|
||||
)
|
||||
|
||||
elif isinstance(space, _BaseGymSpaces):
|
||||
elif isinstance(space, BaseGymSpaces):
|
||||
assert isinstance(lhs, np.ndarray)
|
||||
assert lhs.shape == ((n,) + space.shape)
|
||||
assert lhs.dtype == space.dtype
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -9,7 +9,9 @@ from gym.wrappers.filter_observation import FilterObservation
|
||||
|
||||
|
||||
class FakeEnvironment(gym.Env):
|
||||
def __init__(self, render_mode=None, observation_keys=("state")):
|
||||
def __init__(
|
||||
self, render_mode=None, observation_keys: Tuple[str, ...] = ("state",)
|
||||
):
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
|
||||
@@ -23,10 +25,16 @@ class FakeEnvironment(gym.Env):
|
||||
image_shape = (32, 32, 3)
|
||||
return np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
return observation
|
||||
return observation if not return_info else (observation, {})
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
@@ -80,7 +88,5 @@ class TestFilterObservation:
|
||||
):
|
||||
env = FakeEnvironment(observation_keys=("key1", "key2"))
|
||||
|
||||
ValueError
|
||||
|
||||
with pytest.raises(error_type, match=error_match):
|
||||
FilterObservation(env, filter_keys=filter_keys)
|
||||
|
Reference in New Issue
Block a user