Update the type hinting for core.py (#39)
@@ -1,16 +1,7 @@
 """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    SupportsFloat,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar

 import numpy as np
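
The import switch above leans on ``from __future__ import annotations`` (PEP 563): annotations become strings that are never evaluated at import time, so the PEP 604 ``X | None`` unions and PEP 585 built-in generics like ``dict[str, Any]`` used throughout this commit work on Python 3.7+, not just 3.9/3.10. A minimal sketch of the idea, separate from the commit itself::

    from __future__ import annotations

    from typing import Any


    # The annotation below is stored as a string, so it parses fine on 3.8.
    def describe(config: dict[str, Any] | None = None) -> str:
        return "default" if config is None else str(config)


    print(describe())  # default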
@@ -55,20 +46,22 @@ class Env(Generic[ObsType, ActType]):
     """

     # Set this in SOME subclasses
-    metadata: Dict[str, Any] = {"render_modes": []}
+    metadata: dict[str, Any] = {"render_modes": []}
     # define render_mode if your environment supports rendering
-    render_mode: Optional[str] = None
+    render_mode: str | None = None
     reward_range = (-float("inf"), float("inf"))
-    spec: "EnvSpec" = None
+    spec: EnvSpec | None = None

     # Set these in ALL subclasses
     action_space: spaces.Space[ActType]
     observation_space: spaces.Space[ObsType]

     # Created
-    _np_random: Optional[np.random.Generator] = None
+    _np_random: np.random.Generator | None = None

-    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
+    def step(
+        self, action: ActType
+    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
         """Run one timestep of the environment's dynamics using the agent actions.

         When the end of an episode is reached (``terminated or truncated``), it is necessary to call :meth:`reset` to
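
With the attributes typed this way, a conforming subclass only has to declare the two required spaces and implement :meth:`step`/:meth:`reset`; a minimal sketch (assuming ``gymnasium`` is installed — the environment itself is made up)::

    from __future__ import annotations

    from typing import Any, SupportsFloat

    import gymnasium as gym
    from gymnasium import spaces


    class ConstantEnv(gym.Env[int, int]):
        # Required in ALL subclasses.
        observation_space = spaces.Discrete(1)
        action_space = spaces.Discrete(2)

        def step(
            self, action: int
        ) -> tuple[int, SupportsFloat, bool, bool, dict[str, Any]]:
            return 0, 1.0, False, False, {}

        def reset(
            self, *, seed: int | None = None, options: dict[str, Any] | None = None
        ) -> tuple[int, dict[str, Any]]:
            super().reset(seed=seed)
            return 0, {}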
@@ -86,7 +79,7 @@ class Env(Generic[ObsType, ActType]):
         Returns:
             observation (ObsType): An element of the environment's :attr:`observation_space` as the next observation due to the agent actions.
                 An example is a numpy array containing the positions and velocities of the pole in CartPole.
-            reward (float): The reward as a result of taking the action.
+            reward (SupportsFloat): The reward as a result of taking the action.
             terminated (bool): Whether the agent reaches the terminal state (as defined under the MDP of the task)
                 which can be positive or negative. An example is reaching the goal state or moving into the lava from
                 the Sutton and Barton, Gridworld. If true, the user needs to call :meth:`reset`.
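
``SupportsFloat`` admits anything implementing ``__float__`` — plain ``int``/``float`` as well as NumPy scalars — which is why it fits a reward better than a bare ``float``. In miniature::

    from typing import SupportsFloat

    import numpy as np


    def as_scalar(reward: SupportsFloat) -> float:
        # int, float, np.float32 and np.int64 all implement __float__.
        return float(reward)


    print(as_scalar(1), as_scalar(0.5), as_scalar(np.float32(0.25)))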
@@ -109,9 +102,9 @@ class Env(Generic[ObsType, ActType]):
     def reset(
         self,
         *,
-        seed: Optional[int] = None,
-        options: Optional[dict] = None,
-    ) -> Tuple[ObsType, dict]:
+        seed: int | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[ObsType, dict[str, Any]]:  # type: ignore
         """Resets the environment to an initial internal state, returning an initial observation and info.

         This method generates a new starting state often with some randomness to ensure that the agent explores the
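
Because of the bare ``*``, ``seed`` and ``options`` are keyword-only; a usage sketch (CartPole is just an example id)::

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    obs, info = env.reset(seed=123, options=None)
    # env.reset(123) would raise TypeError: the arguments must be named.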
@@ -149,7 +142,7 @@ class Env(Generic[ObsType, ActType]):
         if seed is not None:
             self._np_random, seed = seeding.np_random(seed)

-    def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
+    def render(self) -> RenderFrame | list[RenderFrame] | None:
         """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment.

         The environment's :attr:`metadata` render modes (`env.metadata["render_modes"]`) should contain the possible

@@ -191,8 +184,8 @@ class Env(Generic[ObsType, ActType]):
         pass

     @property
-    def unwrapped(self) -> "Env":
-        """Returns the base non-wrapped environment (i.e., removes all wrappers).
+    def unwrapped(self) -> Env[ObsType, ActType]:
+        """Returns the base non-wrapped environment.

         Returns:
             Env: The base non-wrapped :class:`gymnasium.Env` instance
@@ -229,14 +222,18 @@ class Env(Generic[ObsType, ActType]):
|
||||
"""Support with-statement for the environment."""
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args: Any):
|
||||
"""Support with-statement for the environment and closes the environment."""
|
||||
self.close()
|
||||
# propagate exception
|
||||
return False
|
||||
|
||||
|
||||
class Wrapper(Env[ObsType, ActType]):
|
||||
WrapperObsType = TypeVar("WrapperObsType")
|
||||
WrapperActType = TypeVar("WrapperActType")
|
||||
|
||||
|
||||
class Wrapper(Env[WrapperObsType, WrapperActType]):
|
||||
"""Wraps a :class:`gymnasium.Env` to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
||||
|
||||
This class is the base class of all wrappers to change the behavior of the underlying environment allowing
|
||||
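
The two new type variables let a wrapper expose observation/action types different from the wrapped environment's. The same pattern in isolation, independent of Gymnasium::

    from typing import Generic, TypeVar

    ObsType = TypeVar("ObsType")
    ActType = TypeVar("ActType")
    WrapperObsType = TypeVar("WrapperObsType")
    WrapperActType = TypeVar("WrapperActType")


    class Env(Generic[ObsType, ActType]):
        pass


    class Wrapper(Env[WrapperObsType, WrapperActType]):
        # To a type checker a Wrapper[str, str] is an Env[str, str],
        # whatever the inner environment's ObsType/ActType are.
        def __init__(self, env: Env[ObsType, ActType]):
            self.env = env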
@@ -293,7 +290,7 @@ class Wrapper(Env[ObsType, ActType]):
     Don't forget to call ``super().__init__(env)``
     """

-    def __init__(self, env: Env):
+    def __init__(self, env: Env[ObsType, ActType]):
         """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

         Args:

@@ -301,73 +298,81 @@ class Wrapper(Env[ObsType, ActType]):
         """
         self.env = env

-        self._action_space: Optional[spaces.Space] = None
-        self._observation_space: Optional[spaces.Space] = None
-        self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
-        self._metadata: Optional[dict] = None
+        self._action_space: spaces.Space[WrapperActType] | None = None
+        self._observation_space: spaces.Space[WrapperObsType] | None = None
+        self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
+        self._metadata: dict[str, Any] | None = None

-    def __getattr__(self, name):
+    def __getattr__(self, name: str):
         """Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
-        if name.startswith("_"):
+        if name == "_np_random":
+            raise AttributeError(
+                "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
+            )
+        elif name.startswith("_"):
             raise AttributeError(f"accessing private attribute '{name}' is prohibited")
         return getattr(self.env, name)

     @property
-    def spec(self):
+    def spec(self) -> EnvSpec | None:
         """Returns the :attr:`Env` :attr:`spec` attribute."""
         return self.env.spec

     @classmethod
-    def class_name(cls):
+    def class_name(cls) -> str:
         """Returns the class name of the wrapper."""
         return cls.__name__

     @property
-    def action_space(self) -> spaces.Space[ActType]:
+    def action_space(
+        self,
+    ) -> spaces.Space[ActType] | spaces.Space[WrapperActType]:
         """Return the :attr:`Env` :attr:`action_space` unless overwritten then the wrapper :attr:`action_space` is used."""
         if self._action_space is None:
             return self.env.action_space
         return self._action_space

     @action_space.setter
-    def action_space(self, space: spaces.Space):
+    def action_space(self, space: spaces.Space[WrapperActType]):
         self._action_space = space

     @property
-    def observation_space(self) -> spaces.Space:
+    def observation_space(
+        self,
+    ) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]:
         """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used."""
         if self._observation_space is None:
             return self.env.observation_space
         return self._observation_space

     @observation_space.setter
-    def observation_space(self, space: spaces.Space):
+    def observation_space(self, space: spaces.Space[WrapperObsType]):
         self._observation_space = space

     @property
-    def reward_range(self) -> Tuple[SupportsFloat, SupportsFloat]:
+    def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
         """Return the :attr:`Env` :attr:`reward_range` unless overwritten then the wrapper :attr:`reward_range` is used."""
         if self._reward_range is None:
             return self.env.reward_range
         return self._reward_range

     @reward_range.setter
-    def reward_range(self, value: Tuple[SupportsFloat, SupportsFloat]):
+    def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
         self._reward_range = value

     @property
-    def metadata(self) -> dict:
+    def metadata(self) -> dict[str, Any]:
         """Returns the :attr:`Env` :attr:`metadata`."""
         if self._metadata is None:
             return self.env.metadata
         return self._metadata

     @metadata.setter
-    def metadata(self, value):
+    def metadata(self, value: dict[str, Any]):
         self._metadata = value

     @property
-    def render_mode(self) -> Optional[str]:
+    def render_mode(self) -> str | None:
         """Returns the :attr:`Env` :attr:`render_mode`."""
         return self.env.render_mode
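
Because ``__getattr__`` only fires when normal attribute lookup fails, public attributes transparently fall through to the wrapped environment, while the two guards above block private ones. A usage sketch (requires gymnasium's classic-control envs)::

    import gymnasium as gym

    env = gym.make("CartPole-v1")  # returns a wrapped environment
    print(env.gravity)  # 9.8 — found on the unwrapped CartPoleEnv via __getattr__
    try:
        env._np_random  # blocked by this commit's guard
    except AttributeError as err:
        print(err)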
@@ -377,28 +382,34 @@ class Wrapper(Env[ObsType, ActType]):
         return self.env.np_random

     @np_random.setter
-    def np_random(self, value):
+    def np_random(self, value: np.random.Generator):
         self.env.np_random = value

+    @property
+    def _np_random(self):
+        """This code will never be run due to __getattr__ being called prior this.
+
+        It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable.
+        """
+        raise AttributeError(
+            "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
+        )
+
-    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
+    def step(
+        self, action: WrapperActType
+    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
         """Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
         return self.env.step(action)

-    def reset(self, **kwargs) -> Tuple[ObsType, dict]:
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[WrapperObsType, dict[str, Any]]:
         """Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""
-        return self.env.reset(**kwargs)
+        return self.env.reset(seed=seed, options=options)

-    def render(
-        self, *args, **kwargs
-    ) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
+    def render(self) -> RenderFrame | list[RenderFrame] | None:
         """Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data."""
-        return self.env.render(*args, **kwargs)
+        return self.env.render()

     def close(self):
         """Closes the wrapper and :attr:`env`."""
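
Note that :attr:`np_random` on a wrapper reads and writes the innermost environment's generator, so every layer sees the same object; a quick check::

    import gymnasium as gym

    env = gym.make("CartPole-v1")  # e.g. TimeLimit(OrderEnforcing(CartPoleEnv))
    env.reset(seed=7)
    assert env.np_random is env.unwrapped.np_random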
@@ -413,12 +424,12 @@ class Wrapper(Env[ObsType, ActType]):
         return str(self)

     @property
-    def unwrapped(self) -> Env:
+    def unwrapped(self) -> Env[ObsType, ActType]:
         """Returns the base environment of the wrapper."""
         return self.env.unwrapped


-class ObservationWrapper(Wrapper):
+class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
     """Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`.

     If you would like to apply a function to only the observation before

@@ -433,7 +444,7 @@ class ObservationWrapper(Wrapper):
     ``observation["target_position"] - observation["agent_position"]``. For this, you could implement an
     observation wrapper like this::

-        class RelativePosition(gymnasium.ObservationWrapper):
+        class RelativePosition(gym.ObservationWrapper):
             def __init__(self, env):
                 super().__init__(env)
                 self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf)
@@ -445,17 +456,25 @@ class ObservationWrapper(Wrapper):
     index of the timestep to the observation.
     """

-    def reset(self, **kwargs):
+    def __init__(self, env: Env[ObsType, ActType]):
+        """Constructor for the observation wrapper."""
+        super().__init__(env)
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[WrapperObsType, dict[str, Any]]:
         """Modifies the :attr:`env` after calling :meth:`reset`, returning a modified observation using :meth:`self.observation`."""
-        obs, info = self.env.reset(**kwargs)
+        obs, info = self.env.reset(seed=seed, options=options)
         return self.observation(obs), info

-    def step(self, action):
+    def step(
+        self, action: ActType
+    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
         """Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations."""
         observation, reward, terminated, truncated, info = self.env.step(action)
         return self.observation(observation), reward, terminated, truncated, info

-    def observation(self, observation):
+    def observation(self, observation: ObsType) -> WrapperObsType:
         """Returns a modified observation.

         Args:
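
Only :meth:`observation` needs implementing; ``reset`` and ``step`` route through it automatically. A self-contained miniature (the wrapper name is made up)::

    import gymnasium as gym
    import numpy as np


    class NegateObs(gym.ObservationWrapper):
        def observation(self, observation: np.ndarray) -> np.ndarray:
            # Applied to the output of both reset() and step().
            return -observation


    env = NegateObs(gym.make("CartPole-v1"))
    obs, info = env.reset(seed=0)
    print(obs)  # the CartPole observation with its sign flipped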
@@ -467,7 +486,7 @@ class ObservationWrapper(Wrapper):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RewardWrapper(Wrapper):
|
||||
class RewardWrapper(Wrapper[ObsType, ActType]):
|
||||
"""Superclass of wrappers that can modify the returning reward from a step.
|
||||
|
||||
If you would like to apply a function to the reward that is returned by the base environment before
|
||||
@@ -480,23 +499,29 @@ class RewardWrapper(Wrapper):
|
||||
because it is intrinsic), we want to clip the reward to a range to gain some numerical stability.
|
||||
To do that, we could, for instance, implement the following wrapper::
|
||||
|
||||
class ClipReward(gymnasium.RewardWrapper):
|
||||
class ClipReward(gym.RewardWrapper):
|
||||
def __init__(self, env, min_reward, max_reward):
|
||||
super().__init__(env)
|
||||
self.min_reward = min_reward
|
||||
self.max_reward = max_reward
|
||||
self.reward_range = (min_reward, max_reward)
|
||||
|
||||
def reward(self, reward):
|
||||
return np.clip(reward, self.min_reward, self.max_reward)
|
||||
def reward(self, r: SupportsFloat) -> SupportsFloat:
|
||||
return np.clip(r, self.min_reward, self.max_reward)
|
||||
"""
|
||||
|
||||
def step(self, action):
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
"""Constructor for the Reward wrapper."""
|
||||
super().__init__(env)
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Modifies the :attr:`env` :meth:`step` reward using :meth:`self.reward`."""
|
||||
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||
return observation, self.reward(reward), terminated, truncated, info
|
||||
|
||||
def reward(self, reward):
|
||||
def reward(self, reward: SupportsFloat) -> SupportsFloat:
|
||||
"""Returns a modified environment ``reward``.
|
||||
|
||||
Args:
|
||||
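
A usage sketch, assuming the docstring's ``ClipReward`` has been copied out as real code (CartPole's constant +1 step reward makes the clipping visible)::

    import gymnasium as gym

    env = ClipReward(gym.make("CartPole-v1"), -0.5, 0.5)
    env.reset(seed=0)
    _, reward, _, _, _ = env.step(env.action_space.sample())
    print(reward)  # 0.5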
@@ -508,7 +533,7 @@ class RewardWrapper(Wrapper):
         raise NotImplementedError


-class ActionWrapper(Wrapper):
+class ActionWrapper(Wrapper[ObsType, WrapperActType]):
     """Superclass of wrappers that can modify the action before :meth:`env.step`.

     If you would like to apply a function to the action before passing it to the base environment,

@@ -521,7 +546,7 @@ class ActionWrapper(Wrapper):
     Let’s say you have an environment with action space of type :class:`gymnasium.spaces.Box`, but you would only like
     to use a finite subset of actions. Then, you might want to implement the following wrapper::

-        class DiscreteActions(gymnasium.ActionWrapper):
+        class DiscreteActions(gym.ActionWrapper):
             def __init__(self, env, disc_to_cont):
                 super().__init__(env)
                 self.disc_to_cont = disc_to_cont

@@ -531,7 +556,7 @@ class ActionWrapper(Wrapper):
                 return self.disc_to_cont[act]

         if __name__ == "__main__":
-            env = gymnasium.make("LunarLanderContinuous-v2")
+            env = gym.make("LunarLanderContinuous-v2")
             wrapped_env = DiscreteActions(env, [np.array([1,0]), np.array([-1,0]),
                                                 np.array([0,1]), np.array([0,-1])])
             print(wrapped_env.action_space) #Discrete(4)

@@ -539,11 +564,17 @@ class ActionWrapper(Wrapper):
     Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction` for clipping and rescaling actions.
     """

-    def step(self, action):
+    def __init__(self, env: Env[ObsType, ActType]):
+        """Constructor for the action wrapper."""
+        super().__init__(env)
+
+    def step(
+        self, action: WrapperActType
+    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
         """Runs the :attr:`env` :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
         return self.env.step(self.action(action))

-    def action(self, action):
+    def action(self, action: WrapperActType) -> ActType:
         """Returns a modified action before :meth:`env.step` is called.

         Args:
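
Note the direction of :meth:`action`: it converts the wrapper's action type into the inner environment's (``WrapperActType -> ActType``). A self-contained variant of the docstring example on an environment that needs no extra dependencies (the wrapper name is illustrative)::

    import gymnasium as gym
    import numpy as np
    from gymnasium import spaces


    class DiscreteToBox(gym.ActionWrapper):
        def __init__(self, env, choices):
            super().__init__(env)
            self.choices = choices
            self.action_space = spaces.Discrete(len(choices))

        def action(self, action: int) -> np.ndarray:
            return self.choices[action]


    env = DiscreteToBox(
        gym.make("Pendulum-v1"), [np.array([-2.0]), np.array([0.0]), np.array([2.0])]
    )
    print(env.action_space)  # Discrete(3)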
@@ -609,6 +609,7 @@ class BipedalWalker(gym.Env, EzPickle):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -577,6 +577,7 @@ class CarRacing(gym.Env, EzPickle):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -606,6 +606,7 @@ class LunarLander(gym.Env, EzPickle):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -282,6 +282,7 @@ class AcrobotEnv(Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -208,6 +208,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -194,6 +194,7 @@ class Continuous_MountainCarEnv(gym.Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -169,6 +169,7 @@ class MountainCarEnv(gym.Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -168,6 +168,7 @@ class PendulumEnv(gym.Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -194,6 +194,7 @@ class BlackjackEnv(gym.Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -165,6 +165,7 @@ class CliffWalkingEnv(Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -270,6 +270,7 @@ class FrozenLakeEnv(Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "

@@ -281,6 +281,7 @@ class TaxiEnv(Env):

     def render(self):
         if self.render_mode is None:
+            assert self.spec is not None
             gym.logger.warn(
                 "You are calling render method without specifying any render mode. "
                 "You can specify the render_mode at initialization, "
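
All twelve hunks add the same line: ``spec`` is now typed ``EnvSpec | None``, so the assert narrows it before the warning text (which interpolates ``self.spec.id``) is built. The narrowing pattern in isolation (the stub class here only stands in for Gymnasium's real ``EnvSpec``)::

    from __future__ import annotations


    class EnvSpec:
        id = "CartPole-v1"


    spec: EnvSpec | None = EnvSpec()
    assert spec is not None  # narrows EnvSpec | None to EnvSpec
    print(spec.id)  # the type checker is now satisfied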
@@ -72,6 +72,7 @@ class PlayableGame:
         elif hasattr(self.env.unwrapped, "get_keys_to_action"):
             keys_to_action = self.env.unwrapped.get_keys_to_action()
         else:
+            assert self.env.spec is not None
             raise MissingKeysToAction(
                 f"{self.env.spec.id} does not have explicit key to action mapping, "
                 "please specify one manually"

@@ -230,6 +231,7 @@ def play(
     elif hasattr(env.unwrapped, "get_keys_to_action"):
         keys_to_action = env.unwrapped.get_keys_to_action()
     else:
+        assert env.spec is not None
         raise MissingKeysToAction(
             f"{env.spec.id} does not have explicit key to action mapping, "
             "please specify one manually"
@@ -69,7 +69,8 @@ class AtariPreprocessing(gym.Wrapper):
         assert noop_max >= 0
         if frame_skip > 1:
             if (
-                "NoFrameskip" not in env.spec.id
+                env.spec is not None
+                and "NoFrameskip" not in env.spec.id
                 and getattr(env.unwrapped, "_frameskip", None) != 1
             ):
                 raise ValueError(
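
Clause order carries the safety here: ``and`` short-circuits, so ``env.spec.id`` is only evaluated once ``env.spec is not None`` has passed. In miniature::

    spec = None
    # The second clause is never evaluated when the first is False.
    if spec is not None and "NoFrameskip" not in spec.id:
        print("unreachable for spec = None")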
@@ -30,6 +30,7 @@ class TimeLimit(gym.Wrapper):
         """
         super().__init__(env)
         if max_episode_steps is None and self.env.spec is not None:
+            assert env.spec is not None
             max_episode_steps = env.spec.max_episode_steps
         if self.env.spec is not None:
             self.env.spec.max_episode_steps = max_episode_steps
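
Taken together: with no explicit ``max_episode_steps`` the wrapper falls back to the registered spec's value (the assert only narrows the type for the checker), and any effective value is written back onto the spec. A usage sketch, assuming behavior as of this commit::

    import gymnasium as gym

    env = gym.make("CartPole-v1", max_episode_steps=100)
    assert env.spec is not None
    print(env.spec.max_episode_steps)  # 100, written back by TimeLimit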
@@ -57,7 +57,9 @@ DISCRETE_ENVS = list(


 @pytest.mark.parametrize(
-    "env", DISCRETE_ENVS, ids=[env.spec.id for env in DISCRETE_ENVS]
+    "env",
+    DISCRETE_ENVS,
+    ids=[env.spec.id for env in DISCRETE_ENVS if env.spec is not None],
 )
 def test_discrete_actions_out_of_bound(env: gym.Env):
     """Test out of bound actions in Discrete action_space.

@@ -87,7 +89,9 @@ BOX_ENVS = list(
 OOB_VALUE = 100


-@pytest.mark.parametrize("env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS])
+@pytest.mark.parametrize(
+    "env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS if env.spec is not None]
+)
 def test_box_actions_out_of_bound(env: gym.Env):
     """Test out of bound actions in Box action_space.

@@ -100,6 +104,7 @@ def test_box_actions_out_of_bound(env: gym.Env):
     """
     env.reset(seed=42)

+    assert env.spec is not None
     oob_env = gym.make(env.spec.id, disable_env_checker=True)
     oob_env.reset(seed=42)
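
pytest requires ``ids`` to line up one-to-one with the argvalues, so the added ``if env.spec is not None`` filters rely on every initialised environment actually having a spec; they exist to satisfy the type checker, not to drop entries. The basic shape of parametrize ids, for reference::

    import pytest

    CASES = [1, 2, 3]


    @pytest.mark.parametrize("n", CASES, ids=[f"case-{n}" for n in CASES])
    def test_positive(n):
        assert n > 0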
@@ -192,7 +192,7 @@ def test_render_modes(spec):
 @pytest.mark.parametrize(
     "env",
     all_testing_initialised_envs,
-    ids=[env.spec.id for env in all_testing_initialised_envs],
+    ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
 )
 def test_pickle_env(env: gym.Env):
     pickled_env = pickle.loads(pickle.dumps(env))

@@ -53,6 +53,7 @@ gym.register(

 def test_make():
     env = gym.make("CartPole-v1", disable_env_checker=True)
+    assert env.spec is not None
     assert env.spec.id == "CartPole-v1"
     assert isinstance(env.unwrapped, cartpole.CartPoleEnv)
     env.close()

@@ -73,6 +74,7 @@ def test_make_max_episode_steps():
     # Default, uses the spec's
     env = gym.make("CartPole-v1", disable_env_checker=True)
     assert has_wrapper(env, TimeLimit)
+    assert env.spec is not None
     assert (
         env.spec.max_episode_steps == gym.envs.registry["CartPole-v1"].max_episode_steps
     )

@@ -81,6 +83,7 @@ def test_make_max_episode_steps():
     # Custom max episode steps
     env = gym.make("CartPole-v1", max_episode_steps=100, disable_env_checker=True)
     assert has_wrapper(env, TimeLimit)
+    assert env.spec is not None
     assert env.spec.max_episode_steps == 100
     env.close()

@@ -297,6 +300,7 @@ def test_make_kwargs():
         arg3="override_arg3",
         disable_env_checker=True,
     )
+    assert env.spec is not None
     assert env.spec.id == "test.ArgumentEnv-v0"
     assert isinstance(env.unwrapped, ArgumentEnv)
     assert env.arg1 == "arg1"

@@ -183,6 +183,7 @@ def test_make_latest_versioned_env(register_testing_envs):
     env = gym.make(
         "MyAwesomeNamespace/MyAwesomeVersionedEnv", disable_env_checker=True
     )
+    assert env.spec is not None
     assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"

@@ -16,6 +16,7 @@ def test_spec():
 def test_spec_kwargs():
     map_name_value = "8x8"
     env = gym.make("FrozenLake-v1", map_name=map_name_value)
+    assert env.spec is not None
     assert env.spec.kwargs["map_name"] == map_name_value
@@ -1,13 +1,27 @@
-from typing import Optional
+"""Checks that the core Gymnasium API is implemented as expected."""
+import re
+from typing import Any, Dict, Optional, SupportsFloat, Tuple

 import numpy as np
 import pytest

-from gymnasium import core, spaces
+from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper, spaces
+from gymnasium.core import (
+    ActionWrapper,
+    ActType,
+    ObsType,
+    WrapperActType,
+    WrapperObsType,
+)
+from gymnasium.spaces import Box
 from gymnasium.utils import seeding
 from gymnasium.wrappers import OrderEnforcing, TimeLimit
+from tests.testing_env import GenericTestEnv
+
+# ==== Old testing code


-class ArgumentEnv(core.Env):
+class ArgumentEnv(Env):
     observation_space = spaces.Box(low=0, high=1, shape=(1,))
     action_space = spaces.Box(low=0, high=1, shape=(1,))
     calls = 0

@@ -17,7 +31,7 @@ class ArgumentEnv(core.Env):
         self.arg = arg


-class UnittestEnv(core.Env):
+class UnittestEnv(Env):
     observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
     action_space = spaces.Discrete(3)

@@ -30,7 +44,7 @@ class UnittestEnv(core.Env):
         return (observation, 0.0, False, {})


-class UnknownSpacesEnv(core.Env):
+class UnknownSpacesEnv(Env):
     """This environment defines its observation & action spaces only
     after the first call to reset. Although this pattern is sometimes
     necessary when implementing a new environment (e.g. if it depends

@@ -50,7 +64,7 @@ class UnknownSpacesEnv(core.Env):
         return (observation, 0.0, False, {})


-class OldStyleEnv(core.Env):
+class OldStyleEnv(Env):
     """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)"""

     def __init__(self):

@@ -64,7 +78,7 @@ class OldStyleEnv(core.Env):
         return 0, 0, False, {}


-class NewPropertyWrapper(core.Wrapper):
+class NewPropertyWrapper(Wrapper):
     def __init__(
         self,
         env,
@@ -137,3 +151,133 @@ def test_compatibility_with_old_style_env():
     env = TimeLimit(env)
     obs = env.reset()
     assert obs == 0
+
+
+# ==== New testing code
+
+
+class ExampleEnv(Env):
+    def __init__(self):
+        self.observation_space = Box(0, 1)
+        self.action_space = Box(0, 1)
+
+    def step(
+        self, action: ActType
+    ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
+        return 0, 0, False, False, {}
+
+    def reset(
+        self,
+        *,
+        seed: Optional[int] = None,
+        options: Optional[dict] = None,
+    ) -> Tuple[ObsType, dict]:
+        return 0, {}
+
+
+def test_gymnasium_env():
+    env = ExampleEnv()
+
+    assert env.metadata == {"render_modes": []}
+    assert env.render_mode is None
+    assert env.reward_range == (-float("inf"), float("inf"))
+    assert env.spec is None
+    assert env._np_random is None  # pyright: ignore [reportPrivateUsage]
+
+
+class ExampleWrapper(Wrapper):
+    def __init__(self, env: Env[ObsType, ActType]):
+        super().__init__(env)
+
+        self.new_reward = 3
+
+    def reset(
+        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
+    ) -> Tuple[WrapperObsType, Dict[str, Any]]:
+        return super().reset(seed=seed, options=options)
+
+    def step(
+        self, action: WrapperActType
+    ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]:
+        obs, reward, termination, truncation, info = self.env.step(action)
+        return obs, self.new_reward, termination, truncation, info
+
+    def access_hidden_np_random(self):
+        """This should raise an error when called as wrappers should not access their own `_np_random` instances and should use the unwrapped environments."""
+        return self._np_random
+
+
+def test_gymnasium_wrapper():
+    env = ExampleEnv()
+    wrapper_env = ExampleWrapper(env)
+
+    assert env.metadata == wrapper_env.metadata
+    wrapper_env.metadata = {"render_modes": ["rgb_array"]}
+    assert env.metadata != wrapper_env.metadata
+
+    assert env.render_mode == wrapper_env.render_mode
+
+    assert env.reward_range == wrapper_env.reward_range
+    wrapper_env.reward_range = (-1.0, 1.0)
+    assert env.reward_range != wrapper_env.reward_range
+
+    assert env.spec == wrapper_env.spec
+
+    env.observation_space = Box(0, 1)
+    env.action_space = Box(0, 1)
+    assert env.observation_space == wrapper_env.observation_space
+    assert env.action_space == wrapper_env.action_space
+    wrapper_env.observation_space = Box(1, 2)
+    wrapper_env.action_space = Box(1, 2)
+    assert env.observation_space != wrapper_env.observation_space
+    assert env.action_space != wrapper_env.action_space
+
+    wrapper_env.np_random, _ = seeding.np_random()
+    assert (
+        env._np_random  # pyright: ignore [reportPrivateUsage]
+        is env.np_random
+        is wrapper_env.np_random
+    )
+    assert 0 <= wrapper_env.np_random.uniform() <= 1
+    with pytest.raises(
+        AttributeError,
+        match=re.escape(
+            "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
+        ),
+    ):
+        print(wrapper_env.access_hidden_np_random())
+
+
+class ExampleRewardWrapper(RewardWrapper):
+    def reward(self, reward: SupportsFloat) -> SupportsFloat:
+        return 1
+
+
+class ExampleObservationWrapper(ObservationWrapper):
+    def observation(self, observation: ObsType) -> ObsType:
+        return np.array([1])
+
+
+class ExampleActionWrapper(ActionWrapper):
+    def action(self, action: ActType) -> ActType:
+        return np.array([1])
+
+
+def test_wrapper_types():
+    env = GenericTestEnv()
+
+    reward_env = ExampleRewardWrapper(env)
+    reward_env.reset()
+    _, reward, _, _, _ = reward_env.step(0)
+    assert reward == 1
+
+    observation_env = ExampleObservationWrapper(env)
+    obs, _ = observation_env.reset()
+    assert obs == np.array([1])
+    obs, _, _, _, _ = observation_env.step(0)
+    assert obs == np.array([1])
+
+    env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {}))
+    action_env = ExampleActionWrapper(env)
+    obs, _, _, _, _ = action_env.step(0)
+    assert obs == np.array([1])
@@ -38,6 +38,7 @@ def test_vector_make_wrappers():

     sub_env = env.envs[0]
     assert isinstance(sub_env, gym.Env)
+    assert sub_env.spec is not None
     if sub_env.spec.order_enforce:
         assert has_wrapper(sub_env, OrderEnforcing)
     if sub_env.spec.max_episode_steps is not None:

@@ -14,7 +14,7 @@ from tests.testing_env import GenericTestEnv
 @pytest.mark.parametrize(
     "env",
     all_testing_initialised_envs,
-    ids=[env.spec.id for env in all_testing_initialised_envs],
+    ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
 )
 def test_passive_checker_wrapper_warnings(env):
     with warnings.catch_warnings(record=True) as caught_warnings:

@@ -16,6 +16,7 @@ def test_record_episode_statistics(env_id, deque_size):
     assert env.episode_returns is not None and env.episode_lengths is not None
     assert env.episode_returns[0] == 0.0
     assert env.episode_lengths[0] == 0
+    assert env.spec is not None
     for t in range(env.spec.max_episode_steps):
         _, _, terminated, truncated, info = env.step(env.action_space.sample())
         if terminated or truncated: