mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 14:26:33 +00:00
Re-enable reportGeneralTypeIssues
typing rule (#1391)
This commit is contained in:
committed by
GitHub
parent
f20e3f4845
commit
428493e584
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeAlias
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -10,17 +10,20 @@ import jax.random as jrng
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, ObsType, StateType
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.vector import AutoresetMode
|
||||
from gymnasium.vector.utils import batch_space
|
||||
|
||||
|
||||
class FunctionalJaxEnv(gym.Env):
|
||||
PRNGKeyType: TypeAlias = jax.Array
|
||||
|
||||
|
||||
class FunctionalJaxEnv(gym.Env, Generic[StateType]):
|
||||
"""A conversion layer for jax-based environments."""
|
||||
|
||||
state: StateType
|
||||
rng: jrng.PRNGKey
|
||||
rng: PRNGKeyType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -98,15 +101,17 @@ class FunctionalJaxEnv(gym.Env):
|
||||
self.render_state = None
|
||||
|
||||
|
||||
class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
class FunctionalJaxVectorEnv(
|
||||
gym.vector.VectorEnv[ObsType, ActType, Any], Generic[ObsType, ActType, StateType]
|
||||
):
|
||||
"""A vector env implementation for functional Jax envs."""
|
||||
|
||||
state: StateType
|
||||
rng: jrng.PRNGKey
|
||||
rng: PRNGKeyType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func_env: FuncEnv,
|
||||
func_env: FuncEnv[StateType, ObsType, ActType, Any, Any, Any, Any],
|
||||
num_envs: int,
|
||||
max_episode_steps: int = 0,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
|
@@ -2,22 +2,23 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Tuple, TypeAlias
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax import struct
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv
|
||||
from gymnasium.utils import EzPickle
|
||||
from gymnasium.vector import AutoresetMode
|
||||
|
||||
|
||||
PRNGKeyType: TypeAlias = jax.Array
|
||||
StateType: TypeAlias = jax.Array
|
||||
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
|
||||
|
||||
|
||||
@@ -43,14 +44,16 @@ class CartPoleParams:
|
||||
|
||||
|
||||
class CartPoleFunctional(
|
||||
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
|
||||
FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
|
||||
):
|
||||
"""Cartpole but in jax and functional."""
|
||||
|
||||
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
|
||||
action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def initial(self, rng: PRNGKey, params: CartPoleParams = CartPoleParams):
|
||||
def initial(
|
||||
self, rng: PRNGKeyType, params: CartPoleParams = CartPoleParams
|
||||
) -> StateType:
|
||||
"""Initial state generation."""
|
||||
return jax.random.uniform(
|
||||
key=rng, minval=-params.x_init, maxval=params.x_init, shape=(4,)
|
||||
@@ -58,7 +61,7 @@ class CartPoleFunctional(
|
||||
|
||||
def transition(
|
||||
self,
|
||||
state: jax.Array,
|
||||
state: StateType,
|
||||
action: int | jax.Array,
|
||||
rng: None = None,
|
||||
params: CartPoleParams = CartPoleParams,
|
||||
@@ -90,13 +93,13 @@ class CartPoleFunctional(
|
||||
return state
|
||||
|
||||
def observation(
|
||||
self, state: jax.Array, rng: Any, params: CartPoleParams = CartPoleParams
|
||||
self, state: StateType, rng: Any, params: CartPoleParams = CartPoleParams
|
||||
) -> jax.Array:
|
||||
"""Cartpole observation."""
|
||||
return state
|
||||
|
||||
def terminal(
|
||||
self, state: jax.Array, rng: Any, params: CartPoleParams = CartPoleParams
|
||||
self, state: StateType, rng: Any, params: CartPoleParams = CartPoleParams
|
||||
) -> jax.Array:
|
||||
"""Checks if the state is terminal."""
|
||||
x, _, theta, _ = state
|
||||
|
@@ -3,22 +3,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from os import path
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, Optional, Tuple, TypeAlias
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax import struct
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv
|
||||
from gymnasium.utils import EzPickle
|
||||
from gymnasium.vector import AutoresetMode
|
||||
|
||||
|
||||
PRNGKeyType: TypeAlias = jax.Array
|
||||
StateType: TypeAlias = jax.Array
|
||||
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
|
||||
|
||||
|
||||
@@ -37,7 +38,7 @@ class PendulumParams:
|
||||
|
||||
|
||||
class PendulumFunctional(
|
||||
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, PendulumParams]
|
||||
FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, PendulumParams]
|
||||
):
|
||||
"""Pendulum but in jax and functional structure."""
|
||||
|
||||
@@ -46,18 +47,20 @@ class PendulumFunctional(
|
||||
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
|
||||
action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)
|
||||
|
||||
def initial(self, rng: PRNGKey, params: PendulumParams = PendulumParams):
|
||||
def initial(
|
||||
self, rng: PRNGKeyType, params: PendulumParams = PendulumParams
|
||||
) -> StateType:
|
||||
"""Initial state generation."""
|
||||
high = jnp.array([params.high_x, params.high_y])
|
||||
return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)
|
||||
|
||||
def transition(
|
||||
self,
|
||||
state: jax.Array,
|
||||
state: StateType,
|
||||
action: int | jax.Array,
|
||||
rng: None = None,
|
||||
params: PendulumParams = PendulumParams,
|
||||
) -> jax.Array:
|
||||
) -> StateType:
|
||||
"""Pendulum transition."""
|
||||
th, thdot = state # th := theta
|
||||
u = action
|
||||
@@ -77,7 +80,7 @@ class PendulumFunctional(
|
||||
return new_state
|
||||
|
||||
def observation(
|
||||
self, state: jax.Array, rng: Any, params: PendulumParams = PendulumParams
|
||||
self, state: StateType, rng: Any, params: PendulumParams = PendulumParams
|
||||
) -> jax.Array:
|
||||
"""Generates an observation based on the state."""
|
||||
theta, thetadot = state
|
||||
|
@@ -2,14 +2,13 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import NamedTuple, Optional, Tuple, Union
|
||||
from typing import NamedTuple, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax import struct
|
||||
from jax import random
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
|
||||
@@ -20,6 +19,7 @@ from gymnasium.vector import AutoresetMode
|
||||
from gymnasium.wrappers import HumanRendering
|
||||
|
||||
|
||||
PRNGKeyType: TypeAlias = jax.Array
|
||||
RenderStateType = Tuple["pygame.Surface", str, int] # type: ignore # noqa: F821
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ class BlackJackParams:
|
||||
|
||||
|
||||
class BlackjackFunctional(
|
||||
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, BlackJackParams]
|
||||
FuncEnv[EnvState, jax.Array, int, float, bool, RenderStateType, BlackJackParams]
|
||||
):
|
||||
"""Blackjack is a card game where the goal is to beat the dealer by obtaining cards that sum to closer to 21 (without going over 21) than the dealers cards.
|
||||
|
||||
@@ -247,9 +247,9 @@ class BlackjackFunctional(
|
||||
self,
|
||||
state: EnvState,
|
||||
action: Union[int, jax.Array],
|
||||
key: PRNGKey,
|
||||
key: PRNGKeyType,
|
||||
params: BlackJackParams = BlackJackParams,
|
||||
):
|
||||
) -> EnvState:
|
||||
"""The blackjack environment's state transition function."""
|
||||
env_state = jax.lax.cond(action, take, notake, (state, key))
|
||||
|
||||
@@ -273,7 +273,9 @@ class BlackjackFunctional(
|
||||
|
||||
return new_state
|
||||
|
||||
def initial(self, rng: PRNGKey, params: BlackJackParams = BlackJackParams):
|
||||
def initial(
|
||||
self, rng: PRNGKeyType, params: BlackJackParams = BlackJackParams
|
||||
) -> EnvState:
|
||||
"""Blackjack initial observataion function."""
|
||||
player_hand = jnp.zeros(21)
|
||||
dealer_hand = jnp.zeros(21)
|
||||
@@ -293,7 +295,10 @@ class BlackjackFunctional(
|
||||
return state
|
||||
|
||||
def observation(
|
||||
self, state: EnvState, rng: PRNGKey, params: BlackJackParams = BlackJackParams
|
||||
self,
|
||||
state: EnvState,
|
||||
rng: PRNGKeyType,
|
||||
params: BlackJackParams = BlackJackParams,
|
||||
) -> jax.Array:
|
||||
"""Blackjack observation."""
|
||||
return jnp.array(
|
||||
@@ -306,7 +311,10 @@ class BlackjackFunctional(
|
||||
)
|
||||
|
||||
def terminal(
|
||||
self, state: EnvState, rng: PRNGKey, params: BlackJackParams = BlackJackParams
|
||||
self,
|
||||
state: EnvState,
|
||||
rng: PRNGKeyType,
|
||||
params: BlackJackParams = BlackJackParams,
|
||||
) -> jax.Array:
|
||||
"""Determines if a particular Blackjack observation is terminal."""
|
||||
return (state.done) > 0
|
||||
@@ -315,8 +323,8 @@ class BlackjackFunctional(
|
||||
self,
|
||||
state: EnvState,
|
||||
action: ActType,
|
||||
next_state: StateType,
|
||||
rng: PRNGKey,
|
||||
next_state: EnvState,
|
||||
rng: PRNGKeyType,
|
||||
params: BlackJackParams = BlackJackParams,
|
||||
) -> jax.Array:
|
||||
"""Calculates reward from a state."""
|
||||
|
@@ -3,17 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from os import path
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv
|
||||
from gymnasium.utils import EzPickle
|
||||
from gymnasium.vector import AutoresetMode
|
||||
from gymnasium.wrappers import HumanRendering
|
||||
@@ -26,7 +25,7 @@ if TYPE_CHECKING:
|
||||
class RenderStateType(NamedTuple):
|
||||
"""A named tuple which contains the full render state of the Cliffwalking Env. This is static during the episode."""
|
||||
|
||||
screen: pygame.surface
|
||||
screen: pygame.Surface
|
||||
shape: tuple[int, int]
|
||||
nS: int
|
||||
cell_size: tuple[int, int]
|
||||
@@ -47,11 +46,14 @@ class RenderStateType(NamedTuple):
|
||||
class EnvState(NamedTuple):
|
||||
"""A named tuple which contains the full state of the Cliffwalking game."""
|
||||
|
||||
player_position: jnp.array
|
||||
player_position: jax.Array
|
||||
last_action: int
|
||||
fallen: bool
|
||||
|
||||
|
||||
PRNGKeyType: TypeAlias = jax.Array
|
||||
|
||||
|
||||
def fell_off(player_position):
|
||||
"""Checks to see if the player_position means the player has fallen of the cliff."""
|
||||
return (
|
||||
@@ -62,7 +64,7 @@ def fell_off(player_position):
|
||||
|
||||
|
||||
class CliffWalkingFunctional(
|
||||
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType, None]
|
||||
FuncEnv[EnvState, jax.Array, int, float, bool, RenderStateType, None]
|
||||
):
|
||||
"""Cliff walking involves crossing a gridworld from start to goal while avoiding falling off a cliff.
|
||||
|
||||
@@ -144,9 +146,9 @@ class CliffWalkingFunctional(
|
||||
self,
|
||||
state: EnvState,
|
||||
action: int | jax.Array,
|
||||
key: PRNGKey,
|
||||
key: PRNGKeyType,
|
||||
params: None = None,
|
||||
):
|
||||
) -> EnvState:
|
||||
"""The Cliffwalking environment's state transition function."""
|
||||
new_position = state.player_position
|
||||
|
||||
@@ -182,14 +184,14 @@ class CliffWalkingFunctional(
|
||||
|
||||
return new_state
|
||||
|
||||
def initial(self, rng: PRNGKey, params: None = None) -> EnvState:
|
||||
def initial(self, rng: PRNGKeyType, params: None = None) -> EnvState:
|
||||
"""Cliffwalking initial observation function."""
|
||||
player_position = jnp.array([3, 0])
|
||||
|
||||
state = EnvState(player_position=player_position, last_action=-1, fallen=False)
|
||||
return state
|
||||
|
||||
def observation(self, state: EnvState, params: None = None) -> int:
|
||||
def observation(self, state: EnvState, params: None = None) -> jax.Array:
|
||||
"""Cliffwalking observation."""
|
||||
return jnp.array(
|
||||
state.player_position[0] * 12 + state.player_position[1]
|
||||
@@ -203,7 +205,7 @@ class CliffWalkingFunctional(
|
||||
self,
|
||||
state: EnvState,
|
||||
action: ActType,
|
||||
next_state: StateType,
|
||||
next_state: EnvState,
|
||||
params: None = None,
|
||||
) -> jax.Array:
|
||||
"""Calculates reward from a state."""
|
||||
@@ -296,7 +298,7 @@ class CliffWalkingFunctional(
|
||||
)
|
||||
|
||||
def render_image(
|
||||
self, state: StateType, render_state: RenderStateType, params: None = None
|
||||
self, state: EnvState, render_state: RenderStateType, params: None = None
|
||||
) -> tuple[RenderStateType, np.ndarray]:
|
||||
"""Renders an image from a state."""
|
||||
try:
|
||||
|
@@ -39,13 +39,17 @@ class MultiBinary(Space[NDArray[np.int8]]):
|
||||
or some sort of sequence (tuple, list or np.ndarray) if there are multiple axes.
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||
"""
|
||||
if isinstance(n, (Sequence, np.ndarray)):
|
||||
self.n = input_n = tuple(int(i) for i in n)
|
||||
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
|
||||
else:
|
||||
if isinstance(n, int):
|
||||
self.n = n = int(n)
|
||||
input_n = (n,)
|
||||
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
|
||||
elif isinstance(n, (Sequence, np.ndarray)):
|
||||
self.n = input_n = tuple(int(i) for i in n)
|
||||
assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected n to be an int or a sequence of ints, actual type: {type(n)}"
|
||||
)
|
||||
|
||||
super().__init__(input_n, np.int8, seed)
|
||||
|
||||
|
@@ -271,6 +271,8 @@ def play(
|
||||
|
||||
key_code_to_action = {}
|
||||
for key_combination, action in keys_to_action.items():
|
||||
if isinstance(key_combination, int):
|
||||
key_combination = (key_combination,)
|
||||
key_code = tuple(
|
||||
sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
|
||||
)
|
||||
|
@@ -250,7 +250,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
seed: int | list[int | None] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets all sub-environments in parallel and return a batch of concatenated observations and info.
|
||||
@@ -267,7 +267,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: int | list[int] | None = None,
|
||||
seed: int | list[int | None] | None = None,
|
||||
options: dict | None = None,
|
||||
):
|
||||
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
||||
@@ -719,7 +719,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
def _async_worker(
|
||||
index: int,
|
||||
env_fn: callable,
|
||||
env_fn: Callable,
|
||||
pipe: Connection,
|
||||
parent_pipe: Connection,
|
||||
shared_memory: SynchronizedArray | dict[str, Any] | tuple[Any, ...],
|
||||
|
@@ -164,7 +164,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
seed: int | list[int | None] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets each of the sub-environments and concatenate the results together.
|
||||
|
@@ -10,6 +10,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Any, Iterable, Iterator
|
||||
@@ -428,7 +429,7 @@ def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any,
|
||||
|
||||
@singledispatch
|
||||
def create_empty_array(
|
||||
space: Space, n: int = 1, fn: callable = np.zeros
|
||||
space: Space, n: int = 1, fn: Callable = np.zeros
|
||||
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
|
||||
"""Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
|
||||
|
||||
|
@@ -33,9 +33,9 @@ __all__ = [
|
||||
class AutoresetMode(Enum):
|
||||
"""Enum representing the different autoreset modes, next step, same step and disabled."""
|
||||
|
||||
NEXT_STEP: str = "NextStep"
|
||||
SAME_STEP: str = "SameStep"
|
||||
DISABLED: str = "Disabled"
|
||||
NEXT_STEP = "NextStep"
|
||||
SAME_STEP = "SameStep"
|
||||
DISABLED = "Disabled"
|
||||
|
||||
|
||||
class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
|
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, SupportsFloat
|
||||
from typing import Any, Callable, Generic, List, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -32,7 +32,9 @@ __all__ = [
|
||||
|
||||
|
||||
class RenderCollection(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType],
|
||||
Generic[ObsType, ActType, RenderFrame],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``.
|
||||
|
||||
@@ -159,7 +161,9 @@ class RenderCollection(
|
||||
|
||||
|
||||
class RecordVideo(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType],
|
||||
Generic[ObsType, ActType, RenderFrame],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Records videos of environment episodes using the environment's render function.
|
||||
|
||||
|
@@ -91,7 +91,7 @@ class NormalizeReward(
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.return_rms = RunningMeanStd(shape=())
|
||||
self.discounted_reward: np.array = np.array([0.0])
|
||||
self.discounted_reward = np.array([0.0])
|
||||
self.gamma = gamma
|
||||
self.epsilon = epsilon
|
||||
self._update_running_mean = True
|
||||
|
@@ -82,7 +82,7 @@ class NormalizeReward(VectorWrapper, gym.utils.RecordConstructorArgs):
|
||||
VectorWrapper.__init__(self, env)
|
||||
|
||||
self.return_rms = RunningMeanStd(shape=())
|
||||
self.accumulated_reward: np.array = np.zeros((self.num_envs,), dtype=np.float32)
|
||||
self.accumulated_reward = np.zeros((self.num_envs,), dtype=np.float32)
|
||||
self.gamma = gamma
|
||||
self.epsilon = epsilon
|
||||
self._update_running_mean = True
|
||||
|
@@ -151,7 +151,6 @@ reportMissingTypeStubs = false
|
||||
# For warning and error, will raise an error when
|
||||
reportInvalidTypeVarUse = "none"
|
||||
|
||||
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
|
||||
reportAttributeAccessIssue = "none" # pyright provides false positives
|
||||
reportArgumentType = "none" # pyright provides false positives
|
||||
|
||||
|
@@ -48,11 +48,15 @@ class PlayStatus:
|
||||
self.last_observation = obs_tp1
|
||||
|
||||
|
||||
def dummy_keys_to_action():
|
||||
def dummy_keys_to_action() -> dict[tuple[int], int]:
|
||||
return {(RELEVANT_KEY_1,): 0, (RELEVANT_KEY_2,): 1}
|
||||
|
||||
|
||||
def dummy_keys_to_action_str():
|
||||
def dummy_keys_to_action_int() -> dict[int, int]:
|
||||
return {RELEVANT_KEY_1: 0, RELEVANT_KEY_2: 1}
|
||||
|
||||
|
||||
def dummy_keys_to_action_str() -> dict[str, int]:
|
||||
"""{'a': 0, 'd': 1}"""
|
||||
return {chr(RELEVANT_KEY_1): 0, chr(RELEVANT_KEY_2): 1}
|
||||
|
||||
@@ -147,7 +151,7 @@ def test_play_loop_real_env():
|
||||
|
||||
# If apply_wrapper is true, we provide keys_to_action through the environment. If str_keys is true, the
|
||||
# keys_to_action dictionary will have strings as keys
|
||||
for apply_wrapper, str_keys in product([False, True], [False, True]):
|
||||
for apply_wrapper, key_type in product([False, True], ["str", "int", "tuple"]):
|
||||
# set of key events to inject into the play loop as callback
|
||||
callback_events = [
|
||||
Event(KEYDOWN, {"key": RELEVANT_KEY_1}),
|
||||
@@ -178,15 +182,28 @@ def test_play_loop_real_env():
|
||||
|
||||
env = gym.make(ENV, render_mode="rgb_array", disable_env_checker=True)
|
||||
env.reset(seed=SEED)
|
||||
keys_to_action = (
|
||||
dummy_keys_to_action_str() if str_keys else dummy_keys_to_action()
|
||||
)
|
||||
|
||||
if key_type == "tuple":
|
||||
keys_to_action = dummy_keys_to_action()
|
||||
elif key_type == "str":
|
||||
keys_to_action = dummy_keys_to_action_str()
|
||||
elif key_type == "int":
|
||||
keys_to_action = dummy_keys_to_action_int()
|
||||
else:
|
||||
assert False
|
||||
|
||||
# first action is 0 because at the first iteration
|
||||
# we can not inject a callback event into play()
|
||||
obs, _, _, _, _ = env.step(0)
|
||||
for e in keydown_events:
|
||||
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
||||
if key_type == "tuple":
|
||||
action = keys_to_action[(e.key,)]
|
||||
elif key_type == "str":
|
||||
action = keys_to_action[chr(e.key)]
|
||||
elif key_type == "int":
|
||||
action = keys_to_action[e.key]
|
||||
else:
|
||||
assert False
|
||||
obs, _, _, _, _ = env.step(action)
|
||||
|
||||
env_play = gym.make(ENV, render_mode="rgb_array", disable_env_checker=True)
|
||||
|
Reference in New Issue
Block a user