Update the type hinting for core.py (#39)

Mark Towers
2022-11-12 10:21:24 +00:00
committed by GitHub
parent f8ea4df0b8
commit 31025e391b
25 changed files with 286 additions and 82 deletions

View File

@@ -1,16 +1,7 @@
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
SupportsFloat,
Tuple,
TypeVar,
Union,
)
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
import numpy as np
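Note: the new `from __future__ import annotations` import (PEP 563) stores every annotation as a string at runtime, which is why the PEP 604 `X | None` unions and builtin generics such as `dict[str, Any]` used throughout this diff parse on Python 3.7+ even though they only became runtime-valid in 3.10. A minimal sketch of the idea (the `lookup` function is illustrative, not part of this commit):

from __future__ import annotations

from typing import Any


def lookup(table: dict[str, Any], key: str) -> Any | None:
    # Under PEP 563 this annotation is never evaluated at runtime,
    # so `dict[str, Any]` and `Any | None` are fine on Python 3.7+.
    return table.get(key)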
@@ -55,20 +46,22 @@ class Env(Generic[ObsType, ActType]):
"""
# Set this in SOME subclasses
metadata: Dict[str, Any] = {"render_modes": []}
metadata: dict[str, Any] = {"render_modes": []}
# define render_mode if your environment supports rendering
render_mode: Optional[str] = None
render_mode: str | None = None
reward_range = (-float("inf"), float("inf"))
spec: "EnvSpec" = None
spec: EnvSpec | None = None
# Set these in ALL subclasses
action_space: spaces.Space[ActType]
observation_space: spaces.Space[ObsType]
# Created
_np_random: Optional[np.random.Generator] = None
_np_random: np.random.Generator | None = None
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Run one timestep of the environment's dynamics using the agent actions.
When the end of an episode is reached (``terminated or truncated``), it is necessary to call :meth:`reset` to
@@ -86,7 +79,7 @@ class Env(Generic[ObsType, ActType]):
Returns:
observation (ObsType): An element of the environment's :attr:`observation_space` as the next observation due to the agent actions.
An example is a numpy array containing the positions and velocities of the pole in CartPole.
reward (float): The reward as a result of taking the action.
reward (SupportsFloat): The reward as a result of taking the action.
terminated (bool): Whether the agent reaches the terminal state (as defined under the MDP of the task),
which can be positive or negative. An example is reaching the goal state or moving into the lava in
the Sutton and Barto Gridworld. If true, the user needs to call :meth:`reset`.
@@ -109,9 +102,9 @@ class Env(Generic[ObsType, ActType]):
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]:
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
"""Resets the environment to an initial internal state, returning an initial observation and info.
This method generates a new starting state often with some randomness to ensure that the agent explores the
@@ -149,7 +142,7 @@ class Env(Generic[ObsType, ActType]):
if seed is not None:
self._np_random, seed = seeding.np_random(seed)
def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment.
The environment's :attr:`metadata` render modes (`env.metadata["render_modes"]`) should contain the possible
@@ -191,8 +184,8 @@ class Env(Generic[ObsType, ActType]):
pass
@property
def unwrapped(self) -> "Env":
"""Returns the base non-wrapped environment (i.e., removes all wrappers).
def unwrapped(self) -> Env[ObsType, ActType]:
"""Returns the base non-wrapped environment.
Returns:
Env: The base non-wrapped :class:`gymnasium.Env` instance
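Taken together, the retyped `step`, `reset` and `unwrapped` give environment authors a precise template. A minimal conforming environment as a sketch (the `ConstantEnv` class and its spaces are hypothetical, not part of this commit):

from __future__ import annotations

from typing import Any, SupportsFloat

import gymnasium as gym
from gymnasium import spaces


class ConstantEnv(gym.Env[int, int]):
    # Hypothetical single-step environment, used only to show the signatures.
    observation_space = spaces.Discrete(1)
    action_space = spaces.Discrete(2)

    def step(
        self, action: int
    ) -> tuple[int, SupportsFloat, bool, bool, dict[str, Any]]:
        # observation, reward, terminated, truncated, info
        return 0, 1.0, True, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[int, dict[str, Any]]:
        super().reset(seed=seed)  # seeds self.np_random when a seed is given
        return 0, {}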
@@ -229,14 +222,18 @@ class Env(Generic[ObsType, ActType]):
"""Support with-statement for the environment."""
return self
def __exit__(self, *args):
def __exit__(self, *args: Any):
"""Support with-statement for the environment and closes the environment."""
self.close()
# propagate exception
return False
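Since `__exit__` returns False, exceptions raised inside a `with` block still propagate; the environment is simply closed on the way out:

import gymnasium as gym

with gym.make("CartPole-v1") as env:
    env.reset(seed=0)
    env.step(env.action_space.sample())
# env.close() has run here, even if the body raised.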
class Wrapper(Env[ObsType, ActType]):
WrapperObsType = TypeVar("WrapperObsType")
WrapperActType = TypeVar("WrapperActType")
class Wrapper(Env[WrapperObsType, WrapperActType]):
"""Wraps a :class:`gymnasium.Env` to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
This class is the base class of all wrappers to change the behavior of the underlying environment allowing
@@ -293,7 +290,7 @@ class Wrapper(Env[ObsType, ActType]):
Don't forget to call ``super().__init__(env)``
"""
def __init__(self, env: Env):
def __init__(self, env: Env[ObsType, ActType]):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
Args:
@@ -301,73 +298,81 @@ class Wrapper(Env[ObsType, ActType]):
"""
self.env = env
self._action_space: Optional[spaces.Space] = None
self._observation_space: Optional[spaces.Space] = None
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
self._action_space: spaces.Space[WrapperActType] | None = None
self._observation_space: spaces.Space[WrapperObsType] | None = None
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
self._metadata: dict[str, Any] | None = None
def __getattr__(self, name):
def __getattr__(self, name: str):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name.startswith("_"):
if name == "_np_random":
raise AttributeError(
"Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
)
elif name.startswith("_"):
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
return getattr(self.env, name)
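In effect, reading `_np_random` through a wrapper now fails with a pointed message instead of the generic private-attribute error:

import gymnasium as gym

env = gym.make("CartPole-v1")  # returns a chain of wrappers around CartPoleEnv
env.np_random                  # fine: a property that delegates to the inner env
env.unwrapped._np_random       # fine: bypasses the wrapper chain entirely
# env._np_random               # AttributeError: "Can't access `_np_random` of a wrapper, ..."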
@property
def spec(self):
def spec(self) -> EnvSpec | None:
"""Returns the :attr:`Env` :attr:`spec` attribute."""
return self.env.spec
@classmethod
def class_name(cls):
def class_name(cls) -> str:
"""Returns the class name of the wrapper."""
return cls.__name__
@property
def action_space(self) -> spaces.Space[ActType]:
def action_space(
self,
) -> spaces.Space[ActType] | spaces.Space[WrapperActType]:
"""Return the :attr:`Env` :attr:`action_space` unless overwritten then the wrapper :attr:`action_space` is used."""
if self._action_space is None:
return self.env.action_space
return self._action_space
@action_space.setter
def action_space(self, space: spaces.Space):
def action_space(self, space: spaces.Space[WrapperActType]):
self._action_space = space
@property
def observation_space(self) -> spaces.Space:
def observation_space(
self,
) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]:
"""Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used."""
if self._observation_space is None:
return self.env.observation_space
return self._observation_space
@observation_space.setter
def observation_space(self, space: spaces.Space):
def observation_space(self, space: spaces.Space[WrapperObsType]):
self._observation_space = space
@property
def reward_range(self) -> Tuple[SupportsFloat, SupportsFloat]:
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
"""Return the :attr:`Env` :attr:`reward_range` unless overwritten then the wrapper :attr:`reward_range` is used."""
if self._reward_range is None:
return self.env.reward_range
return self._reward_range
@reward_range.setter
def reward_range(self, value: Tuple[SupportsFloat, SupportsFloat]):
def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
self._reward_range = value
@property
def metadata(self) -> dict:
def metadata(self) -> dict[str, Any]:
"""Returns the :attr:`Env` :attr:`metadata`."""
if self._metadata is None:
return self.env.metadata
return self._metadata
@metadata.setter
def metadata(self, value):
def metadata(self, value: dict[str, Any]):
self._metadata = value
@property
def render_mode(self) -> Optional[str]:
def render_mode(self) -> str | None:
"""Returns the :attr:`Env` :attr:`render_mode`."""
return self.env.render_mode
@@ -377,28 +382,34 @@ class Wrapper(Env[ObsType, ActType]):
return self.env.np_random
@np_random.setter
def np_random(self, value):
def np_random(self, value: np.random.Generator):
self.env.np_random = value
@property
def _np_random(self):
"""This code will never be run due to __getattr__ being called prior this.
It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable.
"""
raise AttributeError(
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.step(action)
def reset(self, **kwargs) -> Tuple[ObsType, dict]:
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.reset(**kwargs)
return self.env.reset(seed=seed, options=options)
def render(
self, *args, **kwargs
) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.render(*args, **kwargs)
return self.env.render()
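The two new TypeVars let a wrapper declare the types it exposes (`WrapperObsType`, `WrapperActType`) separately from the inner environment's types. A sketch against the two-parameter `Wrapper` defined in this commit (the `ObsToList` class is illustrative and assumes the inner environment emits numpy array observations):

from __future__ import annotations

from typing import Any, SupportsFloat

import gymnasium as gym


class ObsToList(gym.Wrapper[list, int]):
    # Illustrative wrapper: the inner env yields np.ndarray observations (ObsType),
    # the wrapper re-exposes them as plain lists (WrapperObsType).
    def step(
        self, action: int
    ) -> tuple[list, SupportsFloat, bool, bool, dict[str, Any]]:
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs.tolist(), reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[list, dict[str, Any]]:
        obs, info = self.env.reset(seed=seed, options=options)
        return obs.tolist(), info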
def close(self):
"""Closes the wrapper and :attr:`env`."""
@@ -413,12 +424,12 @@ class Wrapper(Env[ObsType, ActType]):
return str(self)
@property
def unwrapped(self) -> Env:
def unwrapped(self) -> Env[ObsType, ActType]:
"""Returns the base environment of the wrapper."""
return self.env.unwrapped
class ObservationWrapper(Wrapper):
class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
"""Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`.
If you would like to apply a function to only the observation before
@@ -433,7 +444,7 @@ class ObservationWrapper(Wrapper):
``observation["target_position"] - observation["agent_position"]``. For this, you could implement an
observation wrapper like this::
class RelativePosition(gymnasium.ObservationWrapper):
class RelativePosition(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf)
@@ -445,17 +456,25 @@ class ObservationWrapper(Wrapper):
index of the timestep to the observation.
"""
def reset(self, **kwargs):
def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the observation wrapper."""
super().__init__(env)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Modifies the :attr:`env` after calling :meth:`reset`, returning a modified observation using :meth:`self.observation`."""
obs, info = self.env.reset(**kwargs)
obs, info = self.env.reset(seed=seed, options=options)
return self.observation(obs), info
def step(self, action):
def step(
self, action: ActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations."""
observation, reward, terminated, truncated, info = self.env.step(action)
return self.observation(observation), reward, terminated, truncated, info
def observation(self, observation):
def observation(self, observation: ObsType) -> WrapperObsType:
"""Returns a modified observation.
Args:
@@ -467,7 +486,7 @@ class ObservationWrapper(Wrapper):
raise NotImplementedError
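The `RelativePosition` docstring example, completed into a self-contained sketch (the Dict observation keys are assumed from the docstring above):

import numpy as np

import gymnasium as gym
from gymnasium.spaces import Box


class RelativePosition(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf)

    def observation(self, obs):
        # Assumes the wrapped env emits Dict observations with these two keys.
        return obs["target_position"] - obs["agent_position"]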
class RewardWrapper(Wrapper):
class RewardWrapper(Wrapper[ObsType, ActType]):
"""Superclass of wrappers that can modify the returning reward from a step.
If you would like to apply a function to the reward that is returned by the base environment before
@@ -480,23 +499,29 @@ class RewardWrapper(Wrapper):
because it is intrinsic), we want to clip the reward to a range to gain some numerical stability.
To do that, we could, for instance, implement the following wrapper::
class ClipReward(gymnasium.RewardWrapper):
class ClipReward(gym.RewardWrapper):
def __init__(self, env, min_reward, max_reward):
super().__init__(env)
self.min_reward = min_reward
self.max_reward = max_reward
self.reward_range = (min_reward, max_reward)
def reward(self, reward):
return np.clip(reward, self.min_reward, self.max_reward)
def reward(self, r: SupportsFloat) -> SupportsFloat:
return np.clip(r, self.min_reward, self.max_reward)
"""
def step(self, action):
def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the Reward wrapper."""
super().__init__(env)
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Modifies the :attr:`env` :meth:`step` reward using :meth:`self.reward`."""
observation, reward, terminated, truncated, info = self.env.step(action)
return observation, self.reward(reward), terminated, truncated, info
def reward(self, reward):
def reward(self, reward: SupportsFloat) -> SupportsFloat:
"""Returns a modified environment ``reward``.
Args:
@@ -508,7 +533,7 @@ class RewardWrapper(Wrapper):
raise NotImplementedError
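Likewise, the `ClipReward` docstring example as a self-contained sketch:

import numpy as np

import gymnasium as gym


class ClipReward(gym.RewardWrapper):
    def __init__(self, env, min_reward, max_reward):
        super().__init__(env)
        self.min_reward = min_reward
        self.max_reward = max_reward
        self.reward_range = (min_reward, max_reward)

    def reward(self, r):
        # Clamp every reward into [min_reward, max_reward].
        return np.clip(r, self.min_reward, self.max_reward)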
class ActionWrapper(Wrapper):
class ActionWrapper(Wrapper[ObsType, WrapperActType]):
"""Superclass of wrappers that can modify the action before :meth:`env.step`.
If you would like to apply a function to the action before passing it to the base environment,
@@ -521,7 +546,7 @@ class ActionWrapper(Wrapper):
Let's say you have an environment with an action space of type :class:`gymnasium.spaces.Box`, but you would only like
to use a finite subset of actions. Then, you might want to implement the following wrapper::
class DiscreteActions(gymnasium.ActionWrapper):
class DiscreteActions(gym.ActionWrapper):
def __init__(self, env, disc_to_cont):
super().__init__(env)
self.disc_to_cont = disc_to_cont
@@ -531,7 +556,7 @@ class ActionWrapper(Wrapper):
return self.disc_to_cont[act]
if __name__ == "__main__":
env = gymnasium.make("LunarLanderContinuous-v2")
env = gym.make("LunarLanderContinuous-v2")
wrapped_env = DiscreteActions(env, [np.array([1,0]), np.array([-1,0]),
np.array([0,1]), np.array([0,-1])])
print(wrapped_env.action_space) #Discrete(4)
@@ -539,11 +564,17 @@ class ActionWrapper(Wrapper):
Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction` for clipping and rescaling actions.
"""
def step(self, action):
def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the action wrapper."""
super().__init__(env)
def step(
self, action: WrapperActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Runs the :attr:`env` :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
return self.env.step(self.action(action))
def action(self, action):
def action(self, action: WrapperActType) -> ActType:
"""Returns a modified action before :meth:`env.step` is called.
Args:
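And the `DiscreteActions` docstring example as a self-contained sketch (`LunarLanderContinuous-v2` needs the box2d extra installed):

import numpy as np

import gymnasium as gym


class DiscreteActions(gym.ActionWrapper):
    def __init__(self, env, disc_to_cont):
        super().__init__(env)
        self.disc_to_cont = disc_to_cont
        self.action_space = gym.spaces.Discrete(len(disc_to_cont))

    def action(self, act):
        # Map the discrete index back to the underlying continuous action.
        return self.disc_to_cont[act]


env = DiscreteActions(
    gym.make("LunarLanderContinuous-v2"),
    [np.array([1, 0]), np.array([-1, 0]), np.array([0, 1]), np.array([0, -1])],
)
print(env.action_space)  # Discrete(4)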

View File

@@ -609,6 +609,7 @@ class BipedalWalker(gym.Env, EzPickle):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -577,6 +577,7 @@ class CarRacing(gym.Env, EzPickle):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -606,6 +606,7 @@ class LunarLander(gym.Env, EzPickle):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -282,6 +282,7 @@ class AcrobotEnv(Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -208,6 +208,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -194,6 +194,7 @@ class Continuous_MountainCarEnv(gym.Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -169,6 +169,7 @@ class MountainCarEnv(gym.Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -168,6 +168,7 @@ class PendulumEnv(gym.Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -194,6 +194,7 @@ class BlackjackEnv(gym.Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -165,6 +165,7 @@ class CliffWalkingEnv(Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -270,6 +270,7 @@ class FrozenLakeEnv(Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -281,6 +281,7 @@ class TaxiEnv(Env):
def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "

View File

@@ -72,6 +72,7 @@ class PlayableGame:
elif hasattr(self.env.unwrapped, "get_keys_to_action"):
keys_to_action = self.env.unwrapped.get_keys_to_action()
else:
assert self.env.spec is not None
raise MissingKeysToAction(
f"{self.env.spec.id} does not have explicit key to action mapping, "
"please specify one manually"
@@ -230,6 +231,7 @@ def play(
elif hasattr(env.unwrapped, "get_keys_to_action"):
keys_to_action = env.unwrapped.get_keys_to_action()
else:
assert env.spec is not None
raise MissingKeysToAction(
f"{env.spec.id} does not have explicit key to action mapping, "
"please specify one manually"

View File

@@ -69,7 +69,8 @@ class AtariPreprocessing(gym.Wrapper):
assert noop_max >= 0
if frame_skip > 1:
if (
"NoFrameskip" not in env.spec.id
env.spec is not None
and "NoFrameskip" not in env.spec.id
and getattr(env.unwrapped, "_frameskip", None) != 1
):
raise ValueError(
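In `AtariPreprocessing` the rewritten condition relies on short-circuit evaluation: `env.spec is not None` runs first, so the `"NoFrameskip" not in env.spec.id` test can never dereference None. The idiom in isolation (the helper name is hypothetical):

def frameskip_conflicts(spec) -> bool:
    # `spec is not None` is evaluated first, so `spec.id` on the right of
    # `and` can never dereference None.
    return spec is not None and "NoFrameskip" not in spec.id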

View File

@@ -30,6 +30,7 @@ class TimeLimit(gym.Wrapper):
"""
super().__init__(env)
if max_episode_steps is None and self.env.spec is not None:
assert env.spec is not None
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
self.env.spec.max_episode_steps = max_episode_steps
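Here the assert can never fire at runtime: after `super().__init__(env)`, `self.env` is the same object as `env`, so the enclosing `if` has already established that `env.spec` is not None; the assert only narrows the type for the checker. Usage is unchanged:

import gymnasium as gym
from gymnasium.wrappers import TimeLimit

env = TimeLimit(gym.make("CartPole-v1"), max_episode_steps=50)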

View File

@@ -57,7 +57,9 @@ DISCRETE_ENVS = list(
@pytest.mark.parametrize(
"env", DISCRETE_ENVS, ids=[env.spec.id for env in DISCRETE_ENVS]
"env",
DISCRETE_ENVS,
ids=[env.spec.id for env in DISCRETE_ENVS if env.spec is not None],
)
def test_discrete_actions_out_of_bound(env: gym.Env):
"""Test out of bound actions in Discrete action_space.
@@ -87,7 +89,9 @@ BOX_ENVS = list(
OOB_VALUE = 100
@pytest.mark.parametrize("env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS])
@pytest.mark.parametrize(
"env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS if env.spec is not None]
)
def test_box_actions_out_of_bound(env: gym.Env):
"""Test out of bound actions in Box action_space.
@@ -100,6 +104,7 @@ def test_box_actions_out_of_bound(env: gym.Env):
"""
env.reset(seed=42)
assert env.spec is not None
oob_env = gym.make(env.spec.id, disable_env_checker=True)
oob_env.reset(seed=42)

View File

@@ -192,7 +192,7 @@ def test_render_modes(spec):
@pytest.mark.parametrize(
"env",
all_testing_initialised_envs,
ids=[env.spec.id for env in all_testing_initialised_envs],
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
)
def test_pickle_env(env: gym.Env):
pickled_env = pickle.loads(pickle.dumps(env))

View File

@@ -53,6 +53,7 @@ gym.register(
def test_make():
env = gym.make("CartPole-v1", disable_env_checker=True)
assert env.spec is not None
assert env.spec.id == "CartPole-v1"
assert isinstance(env.unwrapped, cartpole.CartPoleEnv)
env.close()
@@ -73,6 +74,7 @@ def test_make_max_episode_steps():
# Default, uses the spec's
env = gym.make("CartPole-v1", disable_env_checker=True)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert (
env.spec.max_episode_steps == gym.envs.registry["CartPole-v1"].max_episode_steps
)
@@ -81,6 +83,7 @@ def test_make_max_episode_steps():
# Custom max episode steps
env = gym.make("CartPole-v1", max_episode_steps=100, disable_env_checker=True)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == 100
env.close()
@@ -297,6 +300,7 @@ def test_make_kwargs():
arg3="override_arg3",
disable_env_checker=True,
)
assert env.spec is not None
assert env.spec.id == "test.ArgumentEnv-v0"
assert isinstance(env.unwrapped, ArgumentEnv)
assert env.arg1 == "arg1"

View File

@@ -183,6 +183,7 @@ def test_make_latest_versioned_env(register_testing_envs):
env = gym.make(
"MyAwesomeNamespace/MyAwesomeVersionedEnv", disable_env_checker=True
)
assert env.spec is not None
assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"

View File

@@ -16,6 +16,7 @@ def test_spec():
def test_spec_kwargs():
map_name_value = "8x8"
env = gym.make("FrozenLake-v1", map_name=map_name_value)
assert env.spec is not None
assert env.spec.kwargs["map_name"] == map_name_value

View File

@@ -1,13 +1,27 @@
from typing import Optional
"""Checks that the core Gymnasium API is implemented as expected."""
import re
from typing import Any, Dict, Optional, SupportsFloat, Tuple
import numpy as np
import pytest
from gymnasium import core, spaces
from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper, spaces
from gymnasium.core import (
ActionWrapper,
ActType,
ObsType,
WrapperActType,
WrapperObsType,
)
from gymnasium.spaces import Box
from gymnasium.utils import seeding
from gymnasium.wrappers import OrderEnforcing, TimeLimit
from tests.testing_env import GenericTestEnv
# ==== Old testing code
class ArgumentEnv(core.Env):
class ArgumentEnv(Env):
observation_space = spaces.Box(low=0, high=1, shape=(1,))
action_space = spaces.Box(low=0, high=1, shape=(1,))
calls = 0
@@ -17,7 +31,7 @@ class ArgumentEnv(core.Env):
self.arg = arg
class UnittestEnv(core.Env):
class UnittestEnv(Env):
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
action_space = spaces.Discrete(3)
@@ -30,7 +44,7 @@ class UnittestEnv(core.Env):
return (observation, 0.0, False, {})
class UnknownSpacesEnv(core.Env):
class UnknownSpacesEnv(Env):
"""This environment defines its observation & action spaces only
after the first call to reset. Although this pattern is sometimes
necessary when implementing a new environment (e.g. if it depends
@@ -50,7 +64,7 @@ class UnknownSpacesEnv(core.Env):
return (observation, 0.0, False, {})
class OldStyleEnv(core.Env):
class OldStyleEnv(Env):
"""This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)"""
def __init__(self):
@@ -64,7 +78,7 @@ class OldStyleEnv(core.Env):
return 0, 0, False, {}
class NewPropertyWrapper(core.Wrapper):
class NewPropertyWrapper(Wrapper):
def __init__(
self,
env,
@@ -137,3 +151,133 @@ def test_compatibility_with_old_style_env():
env = TimeLimit(env)
obs = env.reset()
assert obs == 0
# ==== New testing code
class ExampleEnv(Env):
def __init__(self):
self.observation_space = Box(0, 1)
self.action_space = Box(0, 1)
def step(
self, action: ActType
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
return 0, 0, False, False, {}
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]:
return 0, {}
def test_gymnasium_env():
env = ExampleEnv()
assert env.metadata == {"render_modes": []}
assert env.render_mode is None
assert env.reward_range == (-float("inf"), float("inf"))
assert env.spec is None
assert env._np_random is None # pyright: ignore [reportPrivateUsage]
class ExampleWrapper(Wrapper):
def __init__(self, env: Env[ObsType, ActType]):
super().__init__(env)
self.new_reward = 3
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[WrapperObsType, Dict[str, Any]]:
return super().reset(seed=seed, options=options)
def step(
self, action: WrapperActType
) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]:
obs, reward, termination, truncation, info = self.env.step(action)
return obs, self.new_reward, termination, truncation, info
def access_hidden_np_random(self):
"""This should raise an error when called as wrappers should not access their own `_np_random` instances and should use the unwrapped environments."""
return self._np_random
def test_gymnasium_wrapper():
env = ExampleEnv()
wrapper_env = ExampleWrapper(env)
assert env.metadata == wrapper_env.metadata
wrapper_env.metadata = {"render_modes": ["rgb_array"]}
assert env.metadata != wrapper_env.metadata
assert env.render_mode == wrapper_env.render_mode
assert env.reward_range == wrapper_env.reward_range
wrapper_env.reward_range = (-1.0, 1.0)
assert env.reward_range != wrapper_env.reward_range
assert env.spec == wrapper_env.spec
env.observation_space = Box(0, 1)
env.action_space = Box(0, 1)
assert env.observation_space == wrapper_env.observation_space
assert env.action_space == wrapper_env.action_space
wrapper_env.observation_space = Box(1, 2)
wrapper_env.action_space = Box(1, 2)
assert env.observation_space != wrapper_env.observation_space
assert env.action_space != wrapper_env.action_space
wrapper_env.np_random, _ = seeding.np_random()
assert (
env._np_random # pyright: ignore [reportPrivateUsage]
is env.np_random
is wrapper_env.np_random
)
assert 0 <= wrapper_env.np_random.uniform() <= 1
with pytest.raises(
AttributeError,
match=re.escape(
"Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
),
):
print(wrapper_env.access_hidden_np_random())
class ExampleRewardWrapper(RewardWrapper):
def reward(self, reward: SupportsFloat) -> SupportsFloat:
return 1
class ExampleObservationWrapper(ObservationWrapper):
def observation(self, observation: ObsType) -> ObsType:
return np.array([1])
class ExampleActionWrapper(ActionWrapper):
def action(self, action: ActType) -> ActType:
return np.array([1])
def test_wrapper_types():
env = GenericTestEnv()
reward_env = ExampleRewardWrapper(env)
reward_env.reset()
_, reward, _, _, _ = reward_env.step(0)
assert reward == 1
observation_env = ExampleObservationWrapper(env)
obs, _ = observation_env.reset()
assert obs == np.array([1])
obs, _, _, _, _ = observation_env.step(0)
assert obs == np.array([1])
env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {}))
action_env = ExampleActionWrapper(env)
obs, _, _, _, _ = action_env.step(0)
assert obs == np.array([1])

View File

@@ -38,6 +38,7 @@ def test_vector_make_wrappers():
sub_env = env.envs[0]
assert isinstance(sub_env, gym.Env)
assert sub_env.spec is not None
if sub_env.spec.order_enforce:
assert has_wrapper(sub_env, OrderEnforcing)
if sub_env.spec.max_episode_steps is not None:

View File

@@ -14,7 +14,7 @@ from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize(
"env",
all_testing_initialised_envs,
ids=[env.spec.id for env in all_testing_initialised_envs],
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
)
def test_passive_checker_wrapper_warnings(env):
with warnings.catch_warnings(record=True) as caught_warnings:

View File

@@ -16,6 +16,7 @@ def test_record_episode_statistics(env_id, deque_size):
assert env.episode_returns is not None and env.episode_lengths is not None
assert env.episode_returns[0] == 0.0
assert env.episode_lengths[0] == 0
assert env.spec is not None
for t in range(env.spec.max_episode_steps):
_, _, terminated, truncated, info = env.step(env.action_space.sample())
if terminated or truncated: