Re-enable reportGeneralTypeIssues typing rule (#1391)

This commit is contained in:
James Mochizuki-Freeman
2025-06-07 10:31:31 -04:00
committed by GitHub
parent f20e3f4845
commit 428493e584
16 changed files with 118 additions and 70 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any
from typing import Any, Generic, TypeAlias
import jax
import jax.numpy as jnp
@@ -10,17 +10,20 @@ import jax.random as jrng
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv, ObsType, StateType
from gymnasium.utils import seeding
from gymnasium.vector import AutoresetMode
from gymnasium.vector.utils import batch_space
class FunctionalJaxEnv(gym.Env):
PRNGKeyType: TypeAlias = jax.Array
class FunctionalJaxEnv(gym.Env, Generic[StateType]):
"""A conversion layer for jax-based environments."""
state: StateType
rng: jrng.PRNGKey
rng: PRNGKeyType
def __init__(
self,
@@ -98,15 +101,17 @@ class FunctionalJaxEnv(gym.Env):
self.render_state = None
class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
class FunctionalJaxVectorEnv(
gym.vector.VectorEnv[ObsType, ActType, Any], Generic[ObsType, ActType, StateType]
):
"""A vector env implementation for functional Jax envs."""
state: StateType
rng: jrng.PRNGKey
rng: PRNGKeyType
def __init__(
self,
func_env: FuncEnv,
func_env: FuncEnv[StateType, ObsType, ActType, Any, Any, Any, Any],
num_envs: int,
max_episode_steps: int = 0,
metadata: dict[str, Any] | None = None,

View File

@@ -2,22 +2,23 @@
from __future__ import annotations
from typing import Any, Tuple
from typing import Any, Tuple, TypeAlias
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode
PRNGKeyType: TypeAlias = jax.Array
StateType: TypeAlias = jax.Array
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
@@ -43,14 +44,16 @@ class CartPoleParams:
class CartPoleFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
):
"""Cartpole but in jax and functional."""
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)
def initial(self, rng: PRNGKey, params: CartPoleParams = CartPoleParams):
def initial(
self, rng: PRNGKeyType, params: CartPoleParams = CartPoleParams
) -> StateType:
"""Initial state generation."""
return jax.random.uniform(
key=rng, minval=-params.x_init, maxval=params.x_init, shape=(4,)
@@ -58,7 +61,7 @@ class CartPoleFunctional(
def transition(
self,
state: jax.Array,
state: StateType,
action: int | jax.Array,
rng: None = None,
params: CartPoleParams = CartPoleParams,
@@ -90,13 +93,13 @@ class CartPoleFunctional(
return state
def observation(
self, state: jax.Array, rng: Any, params: CartPoleParams = CartPoleParams
self, state: StateType, rng: Any, params: CartPoleParams = CartPoleParams
) -> jax.Array:
"""Cartpole observation."""
return state
def terminal(
self, state: jax.Array, rng: Any, params: CartPoleParams = CartPoleParams
self, state: StateType, rng: Any, params: CartPoleParams = CartPoleParams
) -> jax.Array:
"""Checks if the state is terminal."""
x, _, theta, _ = state

View File

@@ -3,22 +3,23 @@
from __future__ import annotations
from os import path
from typing import Any, Optional, Tuple
from typing import Any, Optional, Tuple, TypeAlias
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode
PRNGKeyType: TypeAlias = jax.Array
StateType: TypeAlias = jax.Array
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
@@ -37,7 +38,7 @@ class PendulumParams:
class PendulumFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, PendulumParams]
FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, PendulumParams]
):
"""Pendulum but in jax and functional structure."""
@@ -46,18 +47,20 @@ class PendulumFunctional(
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)
def initial(self, rng: PRNGKey, params: PendulumParams = PendulumParams):
def initial(
self, rng: PRNGKeyType, params: PendulumParams = PendulumParams
) -> StateType:
"""Initial state generation."""
high = jnp.array([params.high_x, params.high_y])
return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)
def transition(
self,
state: jax.Array,
state: StateType,
action: int | jax.Array,
rng: None = None,
params: PendulumParams = PendulumParams,
) -> jax.Array:
) -> StateType:
"""Pendulum transition."""
th, thdot = state # th := theta
u = action
@@ -77,7 +80,7 @@ class PendulumFunctional(
return new_state
def observation(
self, state: jax.Array, rng: Any, params: PendulumParams = PendulumParams
self, state: StateType, rng: Any, params: PendulumParams = PendulumParams
) -> jax.Array:
"""Generates an observation based on the state."""
theta, thetadot = state

View File

@@ -2,14 +2,13 @@
import math
import os
from typing import NamedTuple, Optional, Tuple, Union
from typing import NamedTuple, Optional, Tuple, TypeAlias, Union
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax import random
from jax.random import PRNGKey
from gymnasium import spaces
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
@@ -20,6 +19,7 @@ from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering
PRNGKeyType: TypeAlias = jax.Array
RenderStateType = Tuple["pygame.Surface", str, int] # type: ignore # noqa: F821
@@ -168,7 +168,7 @@ class BlackJackParams:
class BlackjackFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, BlackJackParams]
FuncEnv[EnvState, jax.Array, int, float, bool, RenderStateType, BlackJackParams]
):
"""Blackjack is a card game where the goal is to beat the dealer by obtaining cards that sum to closer to 21 (without going over 21) than the dealer's cards.
@@ -247,9 +247,9 @@ class BlackjackFunctional(
self,
state: EnvState,
action: Union[int, jax.Array],
key: PRNGKey,
key: PRNGKeyType,
params: BlackJackParams = BlackJackParams,
):
) -> EnvState:
"""The blackjack environment's state transition function."""
env_state = jax.lax.cond(action, take, notake, (state, key))
@@ -273,7 +273,9 @@ class BlackjackFunctional(
return new_state
def initial(self, rng: PRNGKey, params: BlackJackParams = BlackJackParams):
def initial(
self, rng: PRNGKeyType, params: BlackJackParams = BlackJackParams
) -> EnvState:
"""Blackjack initial observation function."""
player_hand = jnp.zeros(21)
dealer_hand = jnp.zeros(21)
@@ -293,7 +295,10 @@ class BlackjackFunctional(
return state
def observation(
self, state: EnvState, rng: PRNGKey, params: BlackJackParams = BlackJackParams
self,
state: EnvState,
rng: PRNGKeyType,
params: BlackJackParams = BlackJackParams,
) -> jax.Array:
"""Blackjack observation."""
return jnp.array(
@@ -306,7 +311,10 @@ class BlackjackFunctional(
)
def terminal(
self, state: EnvState, rng: PRNGKey, params: BlackJackParams = BlackJackParams
self,
state: EnvState,
rng: PRNGKeyType,
params: BlackJackParams = BlackJackParams,
) -> jax.Array:
"""Determines if a particular Blackjack observation is terminal."""
return (state.done) > 0
@@ -315,8 +323,8 @@ class BlackjackFunctional(
self,
state: EnvState,
action: ActType,
next_state: StateType,
rng: PRNGKey,
next_state: EnvState,
rng: PRNGKeyType,
params: BlackJackParams = BlackJackParams,
) -> jax.Array:
"""Calculates reward from a state."""

View File

@@ -3,17 +3,16 @@
from __future__ import annotations
from os import path
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
from gymnasium import spaces
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional import ActType, FuncEnv
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering
@@ -26,7 +25,7 @@ if TYPE_CHECKING:
class RenderStateType(NamedTuple):
"""A named tuple which contains the full render state of the Cliffwalking Env. This is static during the episode."""
screen: pygame.surface
screen: pygame.Surface
shape: tuple[int, int]
nS: int
cell_size: tuple[int, int]
@@ -47,11 +46,14 @@ class RenderStateType(NamedTuple):
class EnvState(NamedTuple):
"""A named tuple which contains the full state of the Cliffwalking game."""
player_position: jnp.array
player_position: jax.Array
last_action: int
fallen: bool
PRNGKeyType: TypeAlias = jax.Array
def fell_off(player_position):
"""Checks to see if the player_position means the player has fallen off the cliff."""
return (
@@ -62,7 +64,7 @@ def fell_off(player_position):
class CliffWalkingFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, None]
FuncEnv[EnvState, jax.Array, int, float, bool, RenderStateType, None]
):
"""Cliff walking involves crossing a gridworld from start to goal while avoiding falling off a cliff.
@@ -144,9 +146,9 @@ class CliffWalkingFunctional(
self,
state: EnvState,
action: int | jax.Array,
key: PRNGKey,
key: PRNGKeyType,
params: None = None,
):
) -> EnvState:
"""The Cliffwalking environment's state transition function."""
new_position = state.player_position
@@ -182,14 +184,14 @@ class CliffWalkingFunctional(
return new_state
def initial(self, rng: PRNGKey, params: None = None) -> EnvState:
def initial(self, rng: PRNGKeyType, params: None = None) -> EnvState:
"""Cliffwalking initial observation function."""
player_position = jnp.array([3, 0])
state = EnvState(player_position=player_position, last_action=-1, fallen=False)
return state
def observation(self, state: EnvState, params: None = None) -> int:
def observation(self, state: EnvState, params: None = None) -> jax.Array:
"""Cliffwalking observation."""
return jnp.array(
state.player_position[0] * 12 + state.player_position[1]
@@ -203,7 +205,7 @@ class CliffWalkingFunctional(
self,
state: EnvState,
action: ActType,
next_state: StateType,
next_state: EnvState,
params: None = None,
) -> jax.Array:
"""Calculates reward from a state."""
@@ -296,7 +298,7 @@ class CliffWalkingFunctional(
)
def render_image(
self, state: StateType, render_state: RenderStateType, params: None = None
self, state: EnvState, render_state: RenderStateType, params: None = None
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an image from a state."""
try:

View File

@@ -39,13 +39,17 @@ class MultiBinary(Space[NDArray[np.int8]]):
or some sort of sequence (tuple, list or np.ndarray) if there are multiple axes.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
"""
if isinstance(n, (Sequence, np.ndarray)):
self.n = input_n = tuple(int(i) for i in n)
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
else:
if isinstance(n, int):
self.n = n = int(n)
input_n = (n,)
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
elif isinstance(n, (Sequence, np.ndarray)):
self.n = input_n = tuple(int(i) for i in n)
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
else:
raise ValueError(
f"Expected n to be an int or a sequence of ints, actual type: {type(n)}"
)
super().__init__(input_n, np.int8, seed)

View File

@@ -271,6 +271,8 @@ def play(
key_code_to_action = {}
for key_combination, action in keys_to_action.items():
if isinstance(key_combination, int):
key_combination = (key_combination,)
key_code = tuple(
sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
)

View File

@@ -250,7 +250,7 @@ class AsyncVectorEnv(VectorEnv):
def reset(
self,
*,
seed: int | list[int] | None = None,
seed: int | list[int | None] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets all sub-environments in parallel and return a batch of concatenated observations and info.
@@ -267,7 +267,7 @@ class AsyncVectorEnv(VectorEnv):
def reset_async(
self,
seed: int | list[int] | None = None,
seed: int | list[int | None] | None = None,
options: dict | None = None,
):
"""Send calls to the :obj:`reset` methods of the sub-environments.
@@ -719,7 +719,7 @@ class AsyncVectorEnv(VectorEnv):
def _async_worker(
index: int,
env_fn: callable,
env_fn: Callable,
pipe: Connection,
parent_pipe: Connection,
shared_memory: SynchronizedArray | dict[str, Any] | tuple[Any, ...],

View File

@@ -164,7 +164,7 @@ class SyncVectorEnv(VectorEnv):
def reset(
self,
*,
seed: int | list[int] | None = None,
seed: int | list[int | None] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets each of the sub-environments and concatenate the results together.

View File

@@ -10,6 +10,7 @@
from __future__ import annotations
import typing
from collections.abc import Callable
from copy import deepcopy
from functools import singledispatch
from typing import Any, Iterable, Iterator
@@ -428,7 +429,7 @@ def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any,
@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: callable = np.zeros
space: Space, n: int = 1, fn: Callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.

View File

@@ -33,9 +33,9 @@ __all__ = [
class AutoresetMode(Enum):
"""Enum representing the different autoreset modes, next step, same step and disabled."""
NEXT_STEP: str = "NextStep"
SAME_STEP: str = "SameStep"
DISABLED: str = "Disabled"
NEXT_STEP = "NextStep"
SAME_STEP = "SameStep"
DISABLED = "Disabled"
class VectorEnv(Generic[ObsType, ActType, ArrayType]):

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
import gc
import os
from copy import deepcopy
from typing import Any, Callable, List, SupportsFloat
from typing import Any, Callable, Generic, List, SupportsFloat
import numpy as np
@@ -32,7 +32,9 @@ __all__ = [
class RenderCollection(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
gym.Wrapper[ObsType, ActType, ObsType, ActType],
Generic[ObsType, ActType, RenderFrame],
gym.utils.RecordConstructorArgs,
):
"""Collect rendered frames of an environment such that ``render`` returns a ``list[RenderedFrame]``.
@@ -159,7 +161,9 @@ class RenderCollection(
class RecordVideo(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
gym.Wrapper[ObsType, ActType, ObsType, ActType],
Generic[ObsType, ActType, RenderFrame],
gym.utils.RecordConstructorArgs,
):
"""Records videos of environment episodes using the environment's render function.

View File

@@ -91,7 +91,7 @@ class NormalizeReward(
gym.Wrapper.__init__(self, env)
self.return_rms = RunningMeanStd(shape=())
self.discounted_reward: np.array = np.array([0.0])
self.discounted_reward = np.array([0.0])
self.gamma = gamma
self.epsilon = epsilon
self._update_running_mean = True

View File

@@ -82,7 +82,7 @@ class NormalizeReward(VectorWrapper, gym.utils.RecordConstructorArgs):
VectorWrapper.__init__(self, env)
self.return_rms = RunningMeanStd(shape=())
self.accumulated_reward: np.array = np.zeros((self.num_envs,), dtype=np.float32)
self.accumulated_reward = np.zeros((self.num_envs,), dtype=np.float32)
self.gamma = gamma
self.epsilon = epsilon
self._update_running_mean = True

View File

@@ -151,7 +151,6 @@ reportMissingTypeStubs = false
# For warning and error, will raise an error when
reportInvalidTypeVarUse = "none"
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
reportAttributeAccessIssue = "none" # pyright provides false positives
reportArgumentType = "none" # pyright provides false positives

View File

@@ -48,11 +48,15 @@ class PlayStatus:
self.last_observation = obs_tp1
def dummy_keys_to_action():
def dummy_keys_to_action() -> dict[tuple[int], int]:
return {(RELEVANT_KEY_1,): 0, (RELEVANT_KEY_2,): 1}
def dummy_keys_to_action_str():
def dummy_keys_to_action_int() -> dict[int, int]:
return {RELEVANT_KEY_1: 0, RELEVANT_KEY_2: 1}
def dummy_keys_to_action_str() -> dict[str, int]:
"""{'a': 0, 'd': 1}"""
return {chr(RELEVANT_KEY_1): 0, chr(RELEVANT_KEY_2): 1}
@@ -147,7 +151,7 @@ def test_play_loop_real_env():
# If apply_wrapper is true, we provide keys_to_action through the environment. key_type selects whether the
# keys_to_action dictionary is keyed by strings ("str"), ints ("int"), or tuples of ints ("tuple")
for apply_wrapper, str_keys in product([False, True], [False, True]):
for apply_wrapper, key_type in product([False, True], ["str", "int", "tuple"]):
# set of key events to inject into the play loop as callback
callback_events = [
Event(KEYDOWN, {"key": RELEVANT_KEY_1}),
@@ -178,15 +182,28 @@ def test_play_loop_real_env():
env = gym.make(ENV, render_mode="rgb_array", disable_env_checker=True)
env.reset(seed=SEED)
keys_to_action = (
dummy_keys_to_action_str() if str_keys else dummy_keys_to_action()
)
if key_type == "tuple":
keys_to_action = dummy_keys_to_action()
elif key_type == "str":
keys_to_action = dummy_keys_to_action_str()
elif key_type == "int":
keys_to_action = dummy_keys_to_action_int()
else:
assert False
# first action is 0 because at the first iteration
# we can not inject a callback event into play()
obs, _, _, _, _ = env.step(0)
for e in keydown_events:
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
if key_type == "tuple":
action = keys_to_action[(e.key,)]
elif key_type == "str":
action = keys_to_action[chr(e.key)]
elif key_type == "int":
action = keys_to_action[e.key]
else:
assert False
obs, _, _, _, _ = env.step(action)
env_play = gym.make(ENV, render_mode="rgb_array", disable_env_checker=True)