Bug fix, add tests for environment checker and passive environment checker wrapper (#2903)

This commit is contained in:
Mark Towers
2022-07-11 02:45:24 +01:00
committed by GitHub
parent 907b1b20dd
commit 015b31fa76
20 changed files with 1365 additions and 402 deletions

View File

@@ -327,8 +327,7 @@ class Wrapper(Env[ObsType, ActType]):
if not self.new_step_api: if not self.new_step_api:
deprecation( deprecation(
"Initializing wrapper in old step API which returns one bool instead of two. " "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
) )
def __getattr__(self, name): def __getattr__(self, name):

View File

@@ -18,6 +18,16 @@ try:
except ImportError: except ImportError:
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`") raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
try:
# pygame is required to use the environment (reset and step) even without a render mode,
# so it is imported unconditionally at module level rather than lazily inside the render code.
import pygame
from pygame import gfxdraw
except ImportError:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gym[box2d]`"
)
STATE_W = 96 # less than Atari 160x192 STATE_W = 96 # less than Atari 160x192
STATE_H = 96 STATE_H = 96
@@ -556,15 +566,8 @@ class CarRacing(gym.Env, EzPickle):
def _render(self, mode: str = "human"): def _render(self, mode: str = "human"):
assert mode in self.metadata["render_modes"] assert mode in self.metadata["render_modes"]
try:
import pygame
except ImportError:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gym[box2d]`"
)
pygame.font.init() pygame.font.init()
if self.screen is None and mode == "human": if self.screen is None and mode == "human":
pygame.init() pygame.init()
pygame.display.init() pygame.display.init()
@@ -661,8 +664,6 @@ class CarRacing(gym.Env, EzPickle):
self._draw_colored_polygon(self.surf, poly, color, zoom, translation, angle) self._draw_colored_polygon(self.surf, poly, color, zoom, translation, angle)
def _render_indicators(self, W, H): def _render_indicators(self, W, H):
import pygame
s = W / 40.0 s = W / 40.0
h = H / 40.0 h = H / 40.0
color = (0, 0, 0) color = (0, 0, 0)
@@ -733,9 +734,6 @@ class CarRacing(gym.Env, EzPickle):
def _draw_colored_polygon( def _draw_colored_polygon(
self, surface, poly, color, zoom, translation, angle, clip=True self, surface, poly, color, zoom, translation, angle, clip=True
): ):
import pygame
from pygame import gfxdraw
poly = [pygame.math.Vector2(c).rotate_rad(angle) for c in poly] poly = [pygame.math.Vector2(c).rotate_rad(angle) for c in poly]
poly = [ poly = [
(c[0] * zoom + translation[0], c[1] * zoom + translation[1]) for c in poly (c[0] * zoom + translation[0], c[1] * zoom + translation[1]) for c in poly
@@ -754,8 +752,6 @@ class CarRacing(gym.Env, EzPickle):
gfxdraw.filled_polygon(self.surf, poly, color) gfxdraw.filled_polygon(self.surf, poly, color)
def _create_image_array(self, screen, size): def _create_image_array(self, screen, size):
import pygame
scaled_screen = pygame.transform.smoothscale(screen, size) scaled_screen = pygame.transform.smoothscale(screen, size)
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(scaled_screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(scaled_screen)), axes=(1, 0, 2)
@@ -763,8 +759,6 @@ class CarRacing(gym.Env, EzPickle):
def close(self): def close(self):
if self.screen is not None: if self.screen is not None:
import pygame
pygame.display.quit() pygame.display.quit()
self.isopen = False self.isopen = False
pygame.quit() pygame.quit()
@@ -772,7 +766,6 @@ class CarRacing(gym.Env, EzPickle):
if __name__ == "__main__": if __name__ == "__main__":
a = np.array([0.0, 0.0, 0.0]) a = np.array([0.0, 0.0, 0.0])
import pygame
def register_input(): def register_input():
for event in pygame.event.get(): for event in pygame.event.get():

View File

@@ -692,6 +692,7 @@ class LunarLander(gym.Env, EzPickle):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
if mode == "human": if mode == "human":
assert self.screen is not None
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])

View File

@@ -2,7 +2,7 @@
Utility functions used for classic control environments. Utility functions used for classic control environments.
""" """
from typing import Optional, SupportsFloat, Union from typing import Optional, SupportsFloat, Tuple
def verify_number_and_cast(x: SupportsFloat) -> float: def verify_number_and_cast(x: SupportsFloat) -> float:
@@ -10,35 +10,37 @@ def verify_number_and_cast(x: SupportsFloat) -> float:
try: try:
x = float(x) x = float(x)
except (ValueError, TypeError): except (ValueError, TypeError):
raise ValueError(f"Your input must support being cast to a float: {x}") raise ValueError(f"An option ({x}) could not be converted to a float.")
return x return x
def maybe_parse_reset_bounds( def maybe_parse_reset_bounds(
options: Optional[dict], default_low: float, default_high: float options: Optional[dict], default_low: float, default_high: float
) -> Union[float, float]: ) -> Tuple[float, float]:
""" """
This function can be called during a reset() to customize the sampling This function can be called during a reset() to customize the sampling
ranges for setting the initial state distributions. ranges for setting the initial state distributions.
Args: Args:
options: (Optional) options passed in to reset(). options: Options passed in to reset().
default_low: Default lower limit to use, if none specified in options. default_low: Default lower limit to use, if none specified in options.
default_high: Default upper limit to use, if none specified in options. default_high: Default upper limit to use, if none specified in options.
limit_low: Lowest allowable value for user-specified lower limit.
limit_high: Highest allowable value for user-specified higher limit.
Returns: Returns:
Lower and higher limits. Tuple of the lower and upper limits.
""" """
if options is None: if options is None:
return default_low, default_high return default_low, default_high
low = options.get("low") if "low" in options else default_low low = options.get("low") if "low" in options else default_low
high = options.get("high") if "high" in options else default_high high = options.get("high") if "high" in options else default_high
# We expect only numerical inputs. # We expect only numerical inputs.
low = verify_number_and_cast(low) low = verify_number_and_cast(low)
high = verify_number_and_cast(high) high = verify_number_and_cast(high)
if low > high: if low > high:
raise ValueError("Lower bound must be lower than higher bound.") raise ValueError(
f"Lower bound ({low}) must be lower than higher bound ({high})."
)
return low, high return low, high

View File

@@ -117,14 +117,32 @@ def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
@dataclass @dataclass
class EnvSpec: class EnvSpec:
"""A specification for creating environments with `gym.make`.
* id: The string used to create the environment with `gym.make`
* entry_point: The location of the environment to create from
* reward_threshold: The reward threshold for completing the environment.
* nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
* max_episode_steps: The max number of steps that the environment can take before truncation
* order_enforce: If to enforce the order of `reset` before `step` and `render` functions
* autoreset: If to automatically reset the environment on episode end
* disable_env_checker: If to disable the environment checker wrapper by default in `gym.make`
* kwargs: Additional keyword arguments passed to the environments through `gym.make`
"""
id: str id: str
entry_point: Optional[Union[Callable, str]] = field(default=None) entry_point: Optional[Union[Callable, str]] = field(default=None)
reward_threshold: Optional[float] = field(default=None) reward_threshold: Optional[float] = field(default=None)
nondeterministic: bool = field(default=False) nondeterministic: bool = field(default=False)
# Wrappers
max_episode_steps: Optional[int] = field(default=None) max_episode_steps: Optional[int] = field(default=None)
order_enforce: bool = field(default=True) order_enforce: bool = field(default=True)
autoreset: bool = field(default=False) autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
new_step_api: bool = field(default=False) new_step_api: bool = field(default=False)
# Environment arguments
kwargs: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict)
namespace: Optional[str] = field(init=False) namespace: Optional[str] = field(init=False)
@@ -530,7 +548,7 @@ def make(
max_episode_steps: Optional[int] = None, max_episode_steps: Optional[int] = None,
autoreset: bool = False, autoreset: bool = False,
new_step_api: bool = False, new_step_api: bool = False,
disable_env_checker: bool = False, disable_env_checker: Optional[bool] = None,
**kwargs, **kwargs,
) -> Env: ) -> Env:
"""Create an environment according to the given ID. """Create an environment according to the given ID.
@@ -540,7 +558,8 @@ def make(
max_episode_steps: Maximum length of an episode (TimeLimit wrapper). max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0 new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
disable_env_checker: If to disable the environment checker disable_env_checker: If to run the env checker, None will default to the environment `spec.disable_env_checker`
(which is False by default, i.e. the checker runs), otherwise will run according to the parameter (True = not run, False = run)
kwargs: Additional arguments to pass to the environment constructor. kwargs: Additional arguments to pass to the environment constructor.
Returns: Returns:
@@ -641,16 +660,15 @@ def make(
else: else:
raise e raise e
if apply_human_rendering:
env = HumanRendering(env)
# Copies the environment creation specification and kwargs to add to the environment specification details # Copies the environment creation specification and kwargs to add to the environment specification details
spec_ = copy.deepcopy(spec_) spec_ = copy.deepcopy(spec_)
spec_.kwargs = _kwargs spec_.kwargs = _kwargs
env.unwrapped.spec = spec_ env.unwrapped.spec = spec_
# Run the environment checker as the lowest level wrapper # Run the environment checker as the lowest level wrapper
if disable_env_checker is False: if disable_env_checker is False or (
disable_env_checker is None and spec_.disable_env_checker is False
):
env = PassiveEnvChecker(env) env = PassiveEnvChecker(env)
env = StepAPICompatibility(env, new_step_api) env = StepAPICompatibility(env, new_step_api)
@@ -669,6 +687,10 @@ def make(
if autoreset: if autoreset:
env = AutoResetWrapper(env, new_step_api) env = AutoResetWrapper(env, new_step_api)
# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)
return env return env

View File

@@ -20,12 +20,13 @@ from copy import deepcopy
import numpy as np import numpy as np
import gym import gym
from gym import error, logger from gym import logger, spaces
from gym.utils.passive_env_checker import ( from gym.utils.passive_env_checker import (
check_action_space, check_action_space,
check_observation_space, check_observation_space,
passive_env_reset_check, env_render_passive_checker,
passive_env_step_check, env_reset_passive_checker,
env_step_passive_checker,
) )
@@ -67,47 +68,61 @@ def check_reset_seed(env: gym.Env):
even though `seed` or `kwargs` appear in the signature. even though `seed` or `kwargs` appear in the signature.
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
if "seed" in signature.parameters or "kwargs" in signature.parameters: if "seed" in signature.parameters or (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
try: try:
obs_1 = env.reset(seed=123) obs_1 = env.reset(seed=123)
assert obs_1 in env.observation_space
obs_2 = env.reset(seed=123)
assert obs_2 in env.observation_space
assert data_equivalence(obs_1, obs_2)
seed_123_rng = deepcopy(env.unwrapped.np_random)
# Note: for some environment, they may initialise at the same state, therefore we cannot check the obs_1 != obs_3
obs_4 = env.reset(seed=None)
assert obs_4 in env.observation_space
assert ( assert (
env.unwrapped.np_random.bit_generator.state obs_1 in env.observation_space
!= seed_123_rng.bit_generator.state ), "The observation returned by `env.reset(seed=123)` is not within the observation space."
assert (
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
is not None
), "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`."
seed_123_rng = deepcopy(
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
) )
obs_2 = env.reset(seed=123)
assert (
obs_2 in env.observation_space
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
if env.spec is not None and env.spec.nondeterministic is False:
assert data_equivalence(
obs_1, obs_2
), "Using `env.reset(seed=123)` is non-deterministic as the observations are not equivalent."
assert (
env.unwrapped._np_random.bit_generator.state # pyright: ignore [reportPrivateUsage]
== seed_123_rng.bit_generator.state
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."
obs_3 = env.reset(seed=456)
assert (
obs_3 in env.observation_space
), "The observation returned by `env.reset(seed=456)` is not within the observation space."
assert (
env.unwrapped._np_random.bit_generator.state # pyright: ignore [reportPrivateUsage]
!= seed_123_rng.bit_generator.state
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`."
except TypeError as e: except TypeError as e:
raise AssertionError( raise AssertionError(
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. " "The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
"This should never happen, please report this issue. " f"This should never happen, please report this issue. The error was: {e}"
f"The error was: {e}"
)
if env.unwrapped.np_random is None:
logger.warn(
"Resetting the environment did not result in seeding its random number generator. "
"This is likely due to not calling `super().reset(seed=seed)` in the `reset` method. "
"If you do not use the python-level random number generator, this is not a problem."
) )
seed_param = signature.parameters.get("seed") seed_param = signature.parameters.get("seed")
# Check the default value is None # Check the default value is None
if seed_param is not None and seed_param.default is not None: if seed_param is not None and seed_param.default is not None:
logger.warn( logger.warn(
"The default seed argument in reset should be `None`, " "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. "
"otherwise the environment will by default always be deterministic" f"Actual default: {seed_param.default}"
) )
else: else:
raise error.Error( raise gym.error.Error(
"The `reset` method does not provide the `return_info` keyword argument" "The `reset` method does not provide a `seed` or `**kwargs` keyword argument."
) )
@@ -122,25 +137,39 @@ def check_reset_info(env: gym.Env):
even though `return_info` or `kwargs` appear in the signature. even though `return_info` or `kwargs` appear in the signature.
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
if "return_info" in signature.parameters or "kwargs" in signature.parameters: if "return_info" in signature.parameters or (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
try: try:
obs = env.reset(return_info=False)
assert (
obs in env.observation_space
), "The value returned by `env.reset(return_info=True)` is not within the observation space."
result = env.reset(return_info=True) result = env.reset(return_info=True)
assert isinstance(
result, tuple
), f"Calling the reset method with `return_info=True` did not return a tuple, actual type: {type(result)}"
assert ( assert (
len(result) == 2 len(result) == 2
), "Calling the reset method with `return_info=True` did not return a 2-tuple" ), f"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: {len(result)}"
obs, info = result obs, info = result
assert (
obs in env.observation_space
), "The first element returned by `env.reset(return_info=True)` is not within the observation space."
assert isinstance( assert isinstance(
info, dict info, dict
), "The second element returned by `env.reset(return_info=True)` was not a dictionary" ), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
except TypeError as e: except TypeError as e:
raise AssertionError( raise AssertionError(
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` " "The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` appear in the signature. "
"appear in the signature. This should never happen, please report this issue. " f"This should never happen, please report this issue. The error was: {e}"
f"The error was: {e}"
) )
else: else:
raise error.Error( raise gym.error.Error(
"The `reset` method does not provide the `return_info` keyword argument" "The `reset` method does not provide a `return_info` or `**kwargs` keyword argument."
) )
@@ -155,57 +184,62 @@ def check_reset_options(env: gym.Env):
even though `options` or `kwargs` appear in the signature. even though `options` or `kwargs` appear in the signature.
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
if "options" in signature.parameters or "kwargs" in signature.parameters: if "options" in signature.parameters or (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
try: try:
env.reset(options={}) env.reset(options={})
except TypeError as e: except TypeError as e:
raise AssertionError( raise AssertionError(
"The environment cannot be reset with options, even though `options` or `kwargs` appear in the signature. " "The environment cannot be reset with options, even though `options` or `**kwargs` appear in the signature. "
"This should never happen, please report this issue. " f"This should never happen, please report this issue. The error was: {e}"
f"The error was: {e}"
) )
else: else:
raise error.Error( raise gym.error.Error(
"The `reset` method does not provide the `options` keyword argument" "The `reset` method does not provide an `options` or `**kwargs` keyword argument."
) )
def check_render(env: gym.Env, warn: bool = True): def check_space_limit(space, space_type: str):
"""Check the declared render modes/fps of the environment. """Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
if isinstance(space, spaces.Box):
Args: if np.any(np.equal(space.low, -np.inf)):
env: The environment to check
warn: Whether to output additional warnings
"""
render_modes = env.metadata.get("render_modes")
if render_modes is None:
if warn:
logger.warn( logger.warn(
"No render modes was declared in the environment " f"A Box {space_type} space minimum value is -infinity. This is probably too low."
" (env.metadata['render_modes'] is None or not defined), " )
"you may have trouble when calling `.render()`" if np.any(np.equal(space.high, np.inf)):
logger.warn(
f"A Box {space_type} space maximum value is -infinity. This is probably too high."
) )
render_fps = env.metadata.get("render_fps") # Check that the Box space is normalized
# We only require `render_fps` if rendering is actually implemented if space_type == "action":
if render_fps is None and render_modes is not None and len(render_modes) > 0: if len(space.shape) == 1: # for vector boxes
if warn: if (
logger.warn( np.any(
"No render fps was declared in the environment " np.logical_and(
" (env.metadata['render_fps'] is None or not defined), " space.low != np.zeros_like(space.low),
"rendering may occur at inconsistent fps" np.abs(space.low) != np.abs(space.high),
) )
)
if warn: or np.any(space.low < -1)
if not hasattr(env, "render_mode"): # TODO: raise an error with gym 1.0 or np.any(space.high > 1)
logger.warn("Environments must define render_mode attribute.") ):
elif env.render_mode is not None and env.render_mode not in render_modes: # todo - Add to gymlibrary.ml?
logger.warn( logger.warn(
"The environment was initialized successfully with an unsupported render mode." "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). "
) "See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information."
)
elif isinstance(space, spaces.Tuple):
for subspace in space.spaces:
check_space_limit(subspace, space_type)
elif isinstance(space, spaces.Dict):
for subspace in space.values():
check_space_limit(subspace, space_type)
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True): def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
"""Check that an environment follows Gym API. """Check that an environment follows Gym API.
This is an invasive function that calls the environment's reset and step. This is an invasive function that calls the environment's reset and step.
@@ -220,21 +254,29 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI) skip_render_check: Whether to skip the checks for the render method. False by default (set to True to skip, e.g. for the CI)
""" """
if warn is not None: if warn is not None:
logger.warn("`check_env` warn parameter is now ignored.") logger.warn("`check_env(warn=...)` parameter is now ignored.")
assert isinstance( assert isinstance(
env, gym.Env env, gym.Env
), "Your environment must inherit from the gym.Env class https://www.gymlibrary.ml/content/environment_creation/" ), "The environment must inherit from the gym.Env class. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
if env.unwrapped is not env:
logger.warn(
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
)
# ============= Check the spaces (observation and action) ================ # ============= Check the spaces (observation and action) ================
assert hasattr( assert hasattr(
env, "action_space" env, "action_space"
), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/" ), "The environment must specify an action space. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
check_action_space(env.action_space) check_action_space(env.action_space)
check_space_limit(env.action_space, "action")
assert hasattr( assert hasattr(
env, "observation_space" env, "observation_space"
), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/" ), "The environment must specify an observation space. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
check_observation_space(env.observation_space) check_observation_space(env.observation_space)
check_space_limit(env.observation_space, "observation")
# ==== Check the reset method ==== # ==== Check the reset method ====
check_reset_seed(env) check_reset_seed(env)
@@ -242,9 +284,12 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
check_reset_info(env) check_reset_info(env)
# ============ Check the returned values =============== # ============ Check the returned values ===============
passive_env_reset_check(env) env_reset_passive_checker(env)
passive_env_step_check(env, env.action_space.sample()) env_step_passive_checker(env, env.action_space.sample())
# ==== Check the render method and the declared render modes ==== # ==== Check the render method and the declared render modes ====
if not skip_render_check: if not skip_render_check:
check_render(env) if env.render_mode is not None:
env_render_passive_checker(env)
# todo: recreate the environment with each of the other render modes to check that every mode works

View File

@@ -1,10 +1,11 @@
"""A set of functions for passively checking environment implementations.""" """A set of functions for passively checking environment implementations."""
import inspect import inspect
from functools import partial
from typing import Callable
import numpy as np import numpy as np
from gym import error, logger, spaces from gym import Space, error, logger, spaces
from gym.logger import deprecation
def _check_box_observation_space(observation_space: spaces.Box): def _check_box_observation_space(observation_space: spaces.Box):
@@ -17,47 +18,33 @@ def _check_box_observation_space(observation_space: spaces.Box):
if len(observation_space.shape) == 3: if len(observation_space.shape) == 3:
if observation_space.dtype != np.uint8: if observation_space.dtype != np.uint8:
logger.warn( logger.warn(
f"It seems that your observation space ({observation_space}) is an image but the `dtype` of your observation_space is not `np.uint8`. " f"It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: {observation_space.dtype}. "
"If your observation is not an image, we recommend you to flatten the observation to have only a 1D vector" "If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector."
) )
if np.any(observation_space.low != 0) or np.any( if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
observation_space.high != 255
): # todo np.all?
logger.warn( logger.warn(
"It seems that your observation space is an image but the upper and lower bounds are not in [0, 255]. " "It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. "
"Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not." "Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not."
) )
if len(observation_space.shape) not in [1, 3]: if len(observation_space.shape) not in [1, 3]:
logger.warn( logger.warn(
"Your observation space has an unconventional shape (neither an image, nor a 1D vector). " "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). "
"We recommend you to flatten the observation to have only a 1D vector or use a custom policy to properly process the data. " "We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. "
f"Observation space={observation_space}" f"Actual observation shape: {observation_space.shape}"
) )
if np.any(np.equal(observation_space.low, -np.inf)): assert (
logger.warn( observation_space.low.shape == observation_space.shape
"Agent's minimum observation space value is -infinity. This is probably too low." ), f"The Box observation space shape and low shape have different shapes, low shape: {observation_space.low.shape}, box shape: {observation_space.shape}"
) assert (
if np.any(np.equal(observation_space.high, np.inf)): observation_space.high.shape == observation_space.shape
logger.warn( ), f"The Box observation space shape and high shape have have different shapes, high shape: {observation_space.high.shape}, box shape: {observation_space.shape}"
"Agent's maximum observation space value is infinity. This is probably too high."
)
if np.any(np.equal(observation_space.low, observation_space.high)): if np.any(observation_space.low == observation_space.high):
logger.warn("Agent's maximum and minimum observation space values are equal") logger.warn("A Box observation space maximum and minimum values are equal.")
if np.any(np.greater(observation_space.low, observation_space.high)): elif np.any(observation_space.high < observation_space.low):
raise AssertionError( logger.warn("A Box observation space low value is greater than a high value.")
"Agent's minimum observation value is greater than it's maximum"
)
if observation_space.low.shape != observation_space.shape:
raise AssertionError(
"Agent's observation_space.low and observation_space have different shapes"
)
if observation_space.high.shape != observation_space.shape:
raise AssertionError(
"Agent's observation_space.high and observation_space have different shapes"
)
def _check_box_action_space(action_space: spaces.Box): def _check_box_action_space(action_space: spaces.Box):
@@ -66,42 +53,73 @@ def _check_box_action_space(action_space: spaces.Box):
Args: Args:
action_space: A box action space action_space: A box action space
""" """
if np.any(np.equal(action_space.low, -np.inf)): assert (
logger.warn( action_space.low.shape == action_space.shape
"Agent's minimum action space value is -infinity. This is probably too low." ), f"The Box action space shape and low shape have have different shapes, low shape: {action_space.low.shape}, box shape: {action_space.shape}"
) assert (
if np.any(np.equal(action_space.high, np.inf)): action_space.high.shape == action_space.shape
logger.warn( ), f"The Box action space shape and high shape have different shapes, high shape: {action_space.high.shape}, box shape: {action_space.shape}"
"Agent's maximum action space value is infinity. This is probably too high"
)
if np.any(np.equal(action_space.low, action_space.high)):
logger.warn("Agent's maximum and minimum action space values are equal")
if np.any(np.greater(action_space.low, action_space.high)):
raise AssertionError(
"Agent's minimum action value is greater than it's maximum"
)
if action_space.low.shape != action_space.shape:
raise AssertionError(
"Agent's action_space.low and action_space have different shapes"
)
if action_space.high.shape != action_space.shape:
raise AssertionError(
"Agent's action_space.high and action_space have different shapes"
)
# Check that the Box space is normalized if np.any(action_space.low == action_space.high):
if ( logger.warn("A Box action space maximum and minimum values are equal.")
np.any(np.abs(action_space.low) != np.abs(action_space.high)) elif np.any(action_space.high < action_space.low):
or np.any(np.abs(action_space.low) > 1) logger.warn("A Box action space low value is greater than a high value.")
or np.any(np.abs(action_space.high) > 1)
):
logger.warn(
"We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
"https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" # TODO Add to gymlibrary.ml?
)
def _check_obs(obs, observation_space: spaces.Space, method_name: str): def check_space(
space: Space, space_type: str, check_box_space_fn: Callable[[spaces.Box], None]
):
"""A passive check of the environment action space that should not affect the environment."""
if not isinstance(space, spaces.Space):
raise AssertionError(
f"{space_type} space does not inherit from `gym.spaces.Space`, actual type: {type(space)}"
)
elif isinstance(space, spaces.Box):
check_box_space_fn(space)
elif isinstance(space, spaces.Discrete):
assert (
0 < space.n
), f"Discrete {space_type} space's number of elements must be positive, actual number of elements: {space.n}"
assert (
space.shape == ()
), f"Discrete {space_type} space's shape should be empty, actual shape: {space.shape}"
elif isinstance(space, spaces.MultiDiscrete):
assert (
space.shape == space.nvec.shape
), f"Multi-discrete {space_type} space's shape must be equal to the nvec shape, space shape: {space.shape}, nvec shape: {space.nvec.shape}"
assert np.all(
0 < space.nvec
), f"Multi-discrete {space_type} space's all nvec elements must be greater than 0, actual nvec: {space.nvec}"
elif isinstance(space, spaces.MultiBinary):
assert np.all(
0 < np.asarray(space.shape)
), f"Multi-binary {space_type} space's all shape elements must be greater than 0, actual shape: {space.shape}"
elif isinstance(space, spaces.Tuple):
assert 0 < len(
space.spaces
), f"An empty Tuple {space_type} space is not allowed."
for subspace in space.spaces:
check_space(subspace, space_type, check_box_space_fn)
elif isinstance(space, spaces.Dict):
assert 0 < len(
space.spaces.keys()
), f"An empty Dict {space_type} space is not allowed."
for subspace in space.values():
check_space(subspace, space_type, check_box_space_fn)
# Observation-space flavour of `check_space`: error messages are labelled with
# "observation" and Box spaces are handed to the observation-specific Box checker.
check_observation_space = partial(
    check_space,
    check_box_space_fn=_check_box_observation_space,
    space_type="observation",
)

# Action-space flavour of `check_space`: error messages are labelled with
# "action" and Box spaces are handed to the action-specific Box checker.
check_action_space = partial(
    check_space,
    check_box_space_fn=_check_box_action_space,
    space_type="action",
)
def check_obs(obs, observation_space: spaces.Space, method_name: str):
"""Check that the observation returned by the environment correspond to the declared one. """Check that the observation returned by the environment correspond to the declared one.
Args: Args:
@@ -109,213 +127,202 @@ def _check_obs(obs, observation_space: spaces.Space, method_name: str):
observation_space: The observation space of the observation observation_space: The observation space of the observation
method_name: The method name that generated the observation method_name: The method name that generated the observation
""" """
pre = f"The observation returned by the `{method_name}()` method" pre = f"The obs returned by the `{method_name}()` method"
assert observation_space.contains(
obs
), f"{pre} is not contained with the observation space ({observation_space})"
if isinstance(observation_space, spaces.Discrete): if isinstance(observation_space, spaces.Discrete):
assert isinstance( if not isinstance(obs, (np.int64, int)):
obs, int logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
), f"The observation returned by `{method_name}()` method must be an int, actually {type(obs)}"
elif isinstance(
observation_space, (spaces.Box, spaces.MultiBinary, spaces.MultiDiscrete)
):
assert isinstance(
obs, np.ndarray
), f"The observation returned by `{method_name}()` method must be a numpy array, actually {type(obs)}"
elif isinstance(observation_space, spaces.Tuple):
assert isinstance(
obs, tuple
), f"The observation returned by the `{method_name}()` method must be a tuple, actually {type(obs)}"
for sub_obs, sub_space in zip(obs, observation_space.spaces):
_check_obs(sub_obs, sub_space, method_name)
elif isinstance(observation_space, spaces.Dict):
assert isinstance(
obs, dict
), f"The observation returned by the `{method_name}()` method must be a dict, actually {type(obs)}"
for space_key in observation_space.keys():
_check_obs(obs[space_key], observation_space[space_key], method_name)
def check_observation_space(observation_space):
"""A passive check of the environment observation space that should not affect the environment."""
if not isinstance(observation_space, spaces.Space):
raise AssertionError(
f"Observation space ({observation_space}) does not inherit from gym.spaces.Space"
)
elif isinstance(observation_space, spaces.Box): elif isinstance(observation_space, spaces.Box):
# Check if the box is an image (shape is 3 elements and the last element is 1 or 3) if observation_space.shape != ():
_check_box_observation_space(observation_space) if not isinstance(obs, np.ndarray):
elif isinstance(observation_space, spaces.Discrete): logger.warn(
assert ( f"{pre} was expecting a numpy array, actual type: {type(obs)}"
observation_space.n > 0 )
), f"There are no available discrete observations, n={observation_space.n}" elif obs.dtype != observation_space.dtype:
elif isinstance(observation_space, spaces.MultiDiscrete): logger.warn(
assert np.all( f"{pre} was expecting numpy array dtype to be {observation_space.dtype}, actual type: {obs.dtype}"
observation_space.nvec > 0 )
), f"All dimensions of multi-discrete must be greater than 0, {observation_space.nvec}" elif isinstance(observation_space, (spaces.MultiBinary, spaces.MultiDiscrete)):
elif isinstance(observation_space, spaces.MultiBinary): if not isinstance(obs, np.ndarray):
assert np.all( logger.warn(f"{pre} was expecting a numpy array, actual type: {type(obs)}")
np.asarray(observation_space.shape) > 0
), f"All dimensions of multi-binary must be greater than 0, {observation_space.shape}"
elif isinstance(observation_space, spaces.Tuple): elif isinstance(observation_space, spaces.Tuple):
assert ( if not isinstance(obs, tuple):
len(observation_space.spaces) > 0 logger.warn(f"{pre} was expecting a tuple, actual type: {type(obs)}")
), "An empty Tuple observation space is not allowed." assert len(obs) == len(
for subspace in observation_space.spaces: observation_space.spaces
check_observation_space(subspace) ), f"{pre} length is not same as the observation space length, obs length: {len(obs)}, space length: {len(observation_space.spaces)}"
for sub_obs, sub_space in zip(obs, observation_space.spaces):
check_obs(sub_obs, sub_space, method_name)
elif isinstance(observation_space, spaces.Dict): elif isinstance(observation_space, spaces.Dict):
assert isinstance(obs, dict), f"{pre} must be a dict, actual type: {type(obs)}"
assert ( assert (
len(observation_space.spaces.keys()) > 0 obs.keys() == observation_space.spaces.keys()
), "An empty Dict observation space is not allowed." ), f"{pre} observation keys is not same as the observation space keys, obs keys: {list(obs.keys())}, space keys: {list(observation_space.spaces.keys())}"
for subspace in observation_space.values(): for space_key in observation_space.spaces.keys():
check_observation_space(subspace) check_obs(obs[space_key], observation_space[space_key], method_name)
try:
if obs not in observation_space:
logger.warn(f"{pre} is not within the observation space.")
except Exception as e:
logger.warn(f"{pre} is not within the observation space with exception: {e}")
def check_action_space(action_space): def env_reset_passive_checker(env, **kwargs):
"""A passive check of the environment action space that should not affect the environment."""
if not isinstance(action_space, spaces.Space):
raise AssertionError(
f"Action space ({action_space}) does not inherit from gym.spaces.Space"
)
elif isinstance(action_space, spaces.Box):
_check_box_action_space(action_space)
elif isinstance(action_space, spaces.Discrete):
assert (
action_space.n > 0
), f"There are no available discrete actions, n={action_space.n}"
elif isinstance(action_space, spaces.MultiDiscrete):
assert np.all(
action_space.nvec > 0
), f"All dimensions of multi-discrete must be greater than 0, {action_space.nvec}"
elif isinstance(action_space, spaces.MultiBinary):
assert np.all(
np.asarray(action_space.shape) > 0
), f"All dimensions of multi-binary must be greater than 0, {action_space.shape}"
elif isinstance(action_space, spaces.Tuple):
assert (
len(action_space.spaces) > 0
), "An empty Tuple action space is not allowed."
for subspace in action_space.spaces:
check_action_space(subspace)
elif isinstance(action_space, spaces.Dict):
assert (
len(action_space.spaces.keys()) > 0
), "An empty Dict action space is not allowed."
for subspace in action_space.values():
check_action_space(subspace)
def passive_env_reset_check(env, **kwargs):
"""A passive check of the `Env.reset` function investigating the returning reset information and returning the data unchanged.""" """A passive check of the `Env.reset` function investigating the returning reset information and returning the data unchanged."""
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
if "seed" not in signature.parameters or "kwargs" in signature.parameters: if "seed" not in signature.parameters and "kwargs" not in signature.parameters:
logger.warn( logger.warn(
"Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator. " "Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator."
) )
else: else:
seed_param = signature.parameters.get("seed") seed_param = signature.parameters.get("seed")
# Check the default value is None # Check the default value is None
if seed_param is not None and seed_param.default is not None: if seed_param is not None and seed_param.default is not None:
logger.warn( logger.warn(
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic" "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. "
f"Actual default: {seed_param}"
) )
if "return_info" not in signature.parameters or "kwargs" in signature.parameters: if "return_info" not in signature.parameters and not (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
logger.warn( logger.warn(
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting." "Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
) )
if "options" not in signature.parameters or "kwargs" in signature.parameters: if "options" not in signature.parameters and "kwargs" not in signature.parameters:
logger.warn( logger.warn(
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information." "Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
) )
# Checks the result of env.reset with kwargs # Checks the result of env.reset with kwargs
result = env.reset(**kwargs) result = env.reset(**kwargs)
if "return_info" in kwargs and kwargs["return_info"] is True: if kwargs.get("return_info", False) is True:
assert isinstance(
result, tuple
), f"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: {type(result)}"
assert (
len(result) == 2
), f"The length of the result returned by `env.reset(return_info=True)` is not 2, actual length: {len(result)}"
obs, info = result obs, info = result
_check_obs(obs, env.observation_space, "reset")
assert isinstance( assert isinstance(
info, dict info, dict
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actually {type(info)}" ), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
else: else:
obs = result obs = result
_check_obs(obs, env.observation_space, "reset")
check_obs(obs, env.observation_space, "reset")
return result return result
def passive_env_step_check(env, action): def env_step_passive_checker(env, action):
"""A passive check for the environment step, investigating the returning data then returning the data unchanged.""" """A passive check for the environment step, investigating the returning data then returning the data unchanged."""
# We don't check the action as for some environments then out-of-bounds values can be given
result = env.step(action) result = env.step(action)
assert isinstance(
result, tuple
), f"Expects step result to be a tuple, actual type: {type(result)}"
if len(result) == 4: if len(result) == 4:
deprecation( logger.deprecation(
"Core environment is written in old step API which returns one bool instead of two. " "Core environment is written in old step API which returns one bool instead of two. "
"It is recommended to rewrite the environment with new step API. " "It is recommended to rewrite the environment with new step API. "
) )
obs, reward, done, info = result obs, reward, done, info = result
assert isinstance( if not isinstance(done, (bool, np.bool8)):
done, bool logger.warn(
), f"The `done` signal is of type `{type(done)}` must be a boolean" f"Expects `done` signal to be a boolean, actual type: {type(done)}"
)
elif len(result) == 5: elif len(result) == 5:
obs, reward, terminated, truncated, info = result obs, reward, terminated, truncated, info = result
assert isinstance( # np.bool is actual python bool not np boolean type, therefore bool_ or bool8
terminated, bool if not isinstance(terminated, (bool, np.bool8)):
), f"The `terminated` signal is of type `{type(terminated)}`. It must be a boolean" logger.warn(
assert isinstance( f"Expects `terminated` signal to be a boolean, actual type: {type(terminated)}"
truncated, bool )
), f"The `truncated` signal of type `{type(truncated)}`. It must be a boolean." if not isinstance(truncated, (bool, np.bool8)):
assert ( logger.warn(
terminated is False or truncated is False f"Expects `truncated` signal to be a boolean, actual type: {type(truncated)}"
), "Only `terminated` or `truncated` can be true, not both." )
else: else:
raise error.Error( raise error.Error(
f"Expected `Env.step` to return a four or five elements, actually {len(result)} elements returned." f"Expected `Env.step` to return a four or five element tuple, actual number of elements returned: {len(result)}."
) )
_check_obs(obs, env.observation_space, "step") check_obs(obs, env.observation_space, "step")
if np.any(np.isnan(obs)):
logger.warn("Encountered NaN value in observations.")
if np.any(np.isinf(obs)):
logger.warn("Encountered inf value in observations.")
assert isinstance( if not (
reward, (float, int, np.floating, np.integer) np.issubdtype(type(reward), np.integer)
), "The reward returned by `step()` must be a float" or np.issubdtype(type(reward), np.floating)
if np.any(np.isnan(reward)): ):
logger.warn("Encountered NaN value in rewards.") logger.warn(
if np.any(np.isinf(reward)): f"The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: {type(reward)}"
logger.warn("Encountered inf value in rewards.") )
else:
if np.isnan(reward):
logger.warn("The reward is a NaN value.")
if np.isinf(reward):
logger.warn("The reward is an inf value.")
assert isinstance( assert isinstance(
info, dict info, dict
), "The `info` returned by `step()` must be a python dictionary" ), f"The `info` returned by `step()` must be a python dictionary, actual type: {type(info)}"
return result return result
def passive_env_render_check(env, *args, **kwargs): def env_render_passive_checker(env, *args, **kwargs):
"""A passive check of the `Env.render` that the declared render modes/fps in the metadata of the environment is decleared.""" """A passive check of the `Env.render` that the declared render modes/fps in the metadata of the environment is declared."""
render_modes = env.metadata.get("render_modes") render_modes = env.metadata.get("render_modes")
if render_modes is None: if render_modes is None:
logger.warn( logger.warn(
"No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), " "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`."
"you may have trouble when calling `.render()`"
) )
else:
if not isinstance(render_modes, (list, tuple)):
logger.warn(
f"Expects the render_modes to be a sequence (i.e. list, tuple), actual type: {type(render_modes)}"
)
elif not all(isinstance(mode, str) for mode in render_modes):
logger.warn(
f"Expects all render modes to be strings, actual types: {[type(mode) for mode in render_modes]}"
)
render_fps = env.metadata.get("render_fps") render_fps = env.metadata.get("render_fps")
# We only require `render_fps` if rendering is actually implemented # We only require `render_fps` if rendering is actually implemented
if render_fps is None: if len(render_modes) > 0:
logger.warn( if render_fps is None:
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), " logger.warn(
"rendering may occur at inconsistent fps" "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps."
) )
else:
if not (
np.issubdtype(type(render_fps), np.integer)
or np.issubdtype(type(render_fps), np.floating)
):
logger.warn(
f"Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: {type(render_fps)}"
)
else:
assert (
render_fps > 0
), f"Expects the `env.metadata['render_fps']` to be greater than zero, actual value: {render_fps}"
return env.render(*args, **kwargs) # env.render is now an attribute with default None
if len(render_modes) == 0:
assert (
env.render_mode is None
), f"With no render_modes, expects the Env.render_mode to be None, actual value: {env.render_mode}"
else:
assert env.render_mode is None or env.render_mode in render_modes, (
"The environment was initialized successfully however with an unsupported render mode. "
f"Render mode: {env.render_mode}, modes: {render_modes}"
)
result = env.render(*args, **kwargs)
# TODO: Check that the result is correct
return result

View File

@@ -14,7 +14,7 @@ def make(
num_envs: int = 1, num_envs: int = 1,
asynchronous: bool = True, asynchronous: bool = True,
wrappers: Optional[Union[callable, List[callable]]] = None, wrappers: Optional[Union[callable, List[callable]]] = None,
disable_env_checker: bool = False, disable_env_checker: Optional[bool] = None,
new_step_api: bool = False, new_step_api: bool = False,
**kwargs, **kwargs,
) -> VectorEnv: ) -> VectorEnv:
@@ -43,8 +43,10 @@ def make(
The vectorized environment. The vectorized environment.
""" """
def create_env(_disable_env_checker): def create_env(env_num: int):
"""Creates an environment that can enable or disable the environment checker.""" """Creates an environment that can enable or disable the environment checker."""
# If the env_num > 0 then disable the environment checker otherwise use the parameter
_disable_env_checker = True if env_num > 0 else disable_env_checker
def _make_env(): def _make_env():
env = gym.envs.registration.make( env = gym.envs.registration.make(

View File

@@ -1 +1 @@
VERSION = "0.24.1" VERSION = "0.25.0"

View File

@@ -4,9 +4,9 @@ from gym.core import ActType
from gym.utils.passive_env_checker import ( from gym.utils.passive_env_checker import (
check_action_space, check_action_space,
check_observation_space, check_observation_space,
passive_env_render_check, env_render_passive_checker,
passive_env_reset_check, env_reset_passive_checker,
passive_env_step_check, env_step_passive_checker,
) )
@@ -19,11 +19,11 @@ class PassiveEnvChecker(gym.Wrapper):
assert hasattr( assert hasattr(
env, "action_space" env, "action_space"
), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/" ), "The environment must specify an action space. https://www.gymlibrary.ml/content/environment_creation/"
check_action_space(env.action_space) check_action_space(env.action_space)
assert hasattr( assert hasattr(
env, "observation_space" env, "observation_space"
), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/" ), "The environment must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
check_observation_space(env.observation_space) check_observation_space(env.observation_space)
self.checked_reset = False self.checked_reset = False
@@ -34,7 +34,7 @@ class PassiveEnvChecker(gym.Wrapper):
"""Steps through the environment that on the first call will run the `passive_env_step_check`.""" """Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self.checked_step is False: if self.checked_step is False:
self.checked_step = True self.checked_step = True
return passive_env_step_check(self.env, action) return env_step_passive_checker(self.env, action)
else: else:
return self.env.step(action) return self.env.step(action)
@@ -42,7 +42,7 @@ class PassiveEnvChecker(gym.Wrapper):
"""Resets the environment that on the first call will run the `passive_env_reset_check`.""" """Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self.checked_reset is False: if self.checked_reset is False:
self.checked_reset = True self.checked_reset = True
return passive_env_reset_check(self.env, **kwargs) return env_reset_passive_checker(self.env, **kwargs)
else: else:
return self.env.reset(**kwargs) return self.env.reset(**kwargs)
@@ -50,6 +50,6 @@ class PassiveEnvChecker(gym.Wrapper):
"""Renders the environment that on the first call will run the `passive_env_render_check`.""" """Renders the environment that on the first call will run the `passive_env_render_check`."""
if self.checked_render is False: if self.checked_render is False:
self.checked_render = True self.checked_render = True
return passive_env_render_check(self.env, *args, **kwargs) return env_render_passive_checker(self.env, *args, **kwargs)
else: else:
return self.env.render(*args, **kwargs) return self.env.render(*args, **kwargs)

View File

@@ -37,8 +37,7 @@ class StepAPICompatibility(gym.Wrapper):
self.new_step_api = new_step_api self.new_step_api = new_step_api
if not self.new_step_api: if not self.new_step_api:
deprecation( deprecation(
"Initializing environment in old step API which returns one bool instead of two. " "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
) )
def step(self, action): def step(self, action):

View File

@@ -168,17 +168,6 @@ def test_customizable_resets(env_name: str, low_high: Optional[list]):
env.step(env.action_space.sample()) env.step(env.action_space.sample())
@pytest.mark.parametrize(
"env_name", ["CartPole-v1", "MountainCar-v0", "MountainCarContinuous-v0"]
)
@pytest.mark.parametrize("low_high", [(-10.0, -9.0), (np.array(-10.0), np.array(-9.0))])
def test_customizable_out_of_bounds_resets(env_name: str, low_high: Optional[list]):
env = gym.make(env_name)
low, high = low_high
with pytest.raises(AssertionError):
env.reset(options={"low": low, "high": high})
# We test Pendulum separately, as the parameters are handled differently. # We test Pendulum separately, as the parameters are handled differently.
@pytest.mark.parametrize( @pytest.mark.parametrize(
"low_high", "low_high",
@@ -221,4 +210,6 @@ def test_invalid_customizable_resets(env_name: str, low_high: list):
env = gym.make(env_name) env = gym.make(env_name)
low, high = low_high low, high = low_high
with pytest.raises(ValueError): with pytest.raises(ValueError):
# match=re.escape(f"Lower bound ({low}) must be lower than higher bound ({high}).")
# match=f"An option ({x}) could not be converted to a float."
env.reset(options={"low": low, "high": high}) env.reset(options={"low": low, "high": high})

View File

@@ -1,23 +1,51 @@
import pytest import pytest
import gym
from gym.envs.registration import EnvSpec from gym.envs.registration import EnvSpec
from gym.utils.env_checker import check_env from gym.utils.env_checker import check_env
from tests.envs.utils import all_testing_env_specs, assert_equals, gym_testing_env_specs from tests.envs.utils import all_testing_env_specs, assert_equals, gym_testing_env_specs
# This runs a smoketest on each official registered env. We may want # This runs a smoketest on each official registered env. We may want
# to try also running environments which are not officially registered # to try also running environments which are not officially registered envs.
# envs. PASSIVE_CHECK_IGNORE_WARNING = [
f"\x1b[33mWARN: {message}\x1b[0m"
for message in [
"This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).",
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
]
]
# Warning texts that `test_envs_pass_env_checker` tolerates; any other recorded
# warning makes that test fail.
# Every entry is wrapped as "\x1b[33mWARN: <message>\x1b[0m" (ANSI yellow / reset
# escape codes) so that it compares equal to the complete recorded warning string.
# NOTE(review): presumably this mirrors the colouring applied by gym's logger - confirm.
CHECK_ENV_IGNORE_WARNINGS = [
    f"\x1b[33mWARN: {message}\x1b[0m"
    for message in [
        "This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).",
        "A Box observation space minimum value is -infinity. This is probably too low.",
        "A Box observation space maximum value is -infinity. This is probably too high.",
        "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
        "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
        "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
    ]
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_spec", gym_testing_env_specs, ids=[spec.id for spec in gym_testing_env_specs] "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
) )
def test_run_env_checker(env_spec: EnvSpec): def test_envs_pass_env_checker(spec):
"""Runs the gym environment checker on the environment spec that calls the `reset`, `step` and `render`.""" """Check that all environments pass the environment checker with no warnings other than the expected."""
env = env_spec.make(disable_env_checker=True) with pytest.warns(None) as warnings:
check_env(env, skip_render_check=False) env = spec.make(disable_env_checker=True).unwrapped
check_env(env)
env.close() env.close()
for warning in warnings.list:
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
print()
print(warning.message.args[0])
print(CHECK_ENV_IGNORE_WARNINGS[-1])
raise gym.error.Error(f"Unexpected warning: {warning.message}")
# Note that this precludes running this test in multiple threads. # Note that this precludes running this test in multiple threads.
@@ -90,3 +118,5 @@ def test_render_modes(spec):
new_env.reset() new_env.reset()
new_env.step(new_env.action_space.sample()) new_env.step(new_env.action_space.sample())
new_env.render() new_env.render()
new_env.close()
env.close()

View File

@@ -1,6 +1,7 @@
"""Tests that gym.make works as expected.""" """Tests that gym.make works as expected."""
import re import re
from copy import deepcopy
import numpy as np import numpy as np
import pytest import pytest
@@ -9,6 +10,7 @@ import gym
from gym.envs.classic_control import cartpole from gym.envs.classic_control import cartpole
from gym.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit from gym.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit
from gym.wrappers.env_checker import PassiveEnvChecker from gym.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_env_specs from tests.envs.utils import all_testing_env_specs
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
from tests.wrappers.utils import has_wrapper from tests.wrappers.utils import has_wrapper
@@ -100,18 +102,49 @@ def test_gym_make_autoreset():
def test_make_disable_env_checker(): def test_make_disable_env_checker():
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.""" """Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`."""
env = gym.make("CartPole-v1") spec = deepcopy(gym.spec("CartPole-v1"))
# Test with spec disable env checker
spec.disable_env_checker = False
env = gym.make(spec)
assert has_wrapper(env, PassiveEnvChecker) assert has_wrapper(env, PassiveEnvChecker)
env.close() env.close()
env = gym.make("CartPole-v1", disable_env_checker=False) # Test with overwritten spec using make disable env checker
assert has_wrapper(env, PassiveEnvChecker) assert spec.disable_env_checker is False
env.close() env = gym.make(spec, disable_env_checker=True)
env = gym.make("CartPole-v1", disable_env_checker=True)
assert has_wrapper(env, PassiveEnvChecker) is False assert has_wrapper(env, PassiveEnvChecker) is False
env.close() env.close()
# Test with spec enabled disable env checker
spec.disable_env_checker = True
env = gym.make(spec)
assert has_wrapper(env, PassiveEnvChecker) is False
env.close()
# Test with overwritten spec using make disable env checker
assert spec.disable_env_checker is True
env = gym.make(spec, disable_env_checker=False)
assert has_wrapper(env, PassiveEnvChecker)
env.close()
@pytest.mark.parametrize(
    "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_passive_checker_wrapper_warnings(spec):
    """Check that the passive environment checker emits no unexpected warnings.

    The environment is made with the default settings (env checker enabled),
    then reset and stepped once while all warnings are recorded. Every recorded
    warning must be listed in `PASSIVE_CHECK_IGNORE_WARNING`.

    Args:
        spec: The environment spec to make and exercise.

    Raises:
        gym.error.Error: If a warning outside of the ignore list was emitted.
    """
    # Imported locally so the test does not rely on the module-level import list.
    import warnings

    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by pytest 8;
    # the standard library recorder captures the same warnings.
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Record every warning, including repeats that the default filters would hide.
        warnings.simplefilter("always")

        env = gym.make(spec)  # disable_env_checker=False
        env.reset()
        env.step(env.action_space.sample())
        # todo, add check for render, bugged due to mujoco v2/3 and v4 envs
        env.close()

    for warning in caught_warnings:
        if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
            raise gym.error.Error(f"Unexpected warning: {warning.message}")
def test_make_order_enforcing(): def test_make_order_enforcing():
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper.""" """Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""

View File

@@ -9,12 +9,17 @@ from gym.envs.registration import EnvSpec
def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
    """Tries to make the environment showing if it is possible.

    Warning the environments have no wrappers, including time limit and order enforcing.

    Args:
        env_spec: The environment specification to instantiate.

    Returns:
        The unwrapped environment, or ``None`` when the spec's entry point is not from
        ``gym.envs`` or when making it raises ``ImportError``.
    """
    # To avoid issues with registered environments during testing, we check that the spec entry points are from gym.envs.
    entry_point = env_spec.entry_point
    if not isinstance(entry_point, str):
        # A callable entry point cannot be searched with `in`; fall back to the module it was defined in.
        entry_point = getattr(entry_point, "__module__", None) or ""

    if "gym.envs." not in entry_point:
        return None

    try:
        return env_spec.make(disable_env_checker=True).unwrapped
    except ImportError as e:
        logger.warn(f"Not testing {env_spec.id} due to error: {e}")
        return None
# Tries to make all environment to test with # Tries to make all environment to test with

81
tests/testing_env.py Normal file
View File

@@ -0,0 +1,81 @@
"""Provides a generic testing environment for use in tests with custom reset, step and render functions."""
import types
from typing import List, Optional, Tuple, Union
import gym
from gym import spaces
from gym.core import ActType, ObsType
from gym.envs.registration import EnvSpec
def basic_reset_fn(
    self,
    *,
    seed: Optional[int] = None,
    return_info: bool = False,
    options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
    """Default reset: seed the env and its observation space, then return a sampled observation.

    With ``return_info=True`` the sample is paired with an info dict that echoes ``options`` back
    under the ``"options"`` key.
    """
    super(GenericTestEnv, self).reset(seed=seed)
    self.observation_space.seed(seed)

    observation = self.observation_space.sample()
    if not return_info:
        return observation
    return observation, {"options": options}
def basic_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
    """Default step: ignore ``action`` and return a sampled observation with a neutral result.

    The reward is ``0``, both ``terminated`` and ``truncated`` are ``False`` and the info dict is empty.
    """
    observation = self.observation_space.sample()
    return observation, 0, False, False, {}
def basic_render_fn(self):
    """Default render: does nothing and returns ``None``."""
    return None
# todo: change all testing environment to this generic class
class GenericTestEnv(gym.Env):
    """Configurable test environment whose spaces and API functions are supplied by the caller.

    A space passed as ``None`` is not assigned on the instance. A function passed as ``None`` is
    not bound, which leaves the class-level method in place; for ``reset`` and ``step`` that is a
    stub raising ``NotImplementedError``.
    """

    def __init__(
        self,
        action_space: spaces.Space = spaces.Box(0, 1, (1,)),
        observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
        reset_fn: callable = basic_reset_fn,
        step_fn: callable = basic_step_fn,
        render_fn: callable = basic_render_fn,
        render_modes: Optional[List[str]] = None,
        render_fps: Optional[int] = None,
        render_mode: Optional[str] = None,
        spec: EnvSpec = EnvSpec("TestingEnv-v0"),
    ):
        # `render_fps` is only recorded in the metadata when a truthy value was given.
        metadata = {"render_modes": render_modes}
        if render_fps:
            metadata["render_fps"] = render_fps
        self.metadata = metadata

        self.render_mode = render_mode
        self.spec = spec

        # Attach only the spaces that were actually provided.
        provided_spaces = (
            ("observation_space", observation_space),
            ("action_space", action_space),
        )
        for space_name, space in provided_spaces:
            if space is not None:
                setattr(self, space_name, space)

        # Bind each supplied function to this instance so that it receives `self`.
        provided_fns = (("reset", reset_fn), ("step", step_fn), ("render", render_fn))
        for method_name, fn in provided_fns:
            if fn is not None:
                setattr(self, method_name, types.MethodType(fn, self))

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Stub used when no ``reset_fn`` was bound; always raises ``NotImplementedError``."""
        # If you need a default working reset function, use `basic_reset_fn` above
        raise NotImplementedError("TestingEnv reset_fn is not set")

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Stub used when no ``step_fn`` was bound; always raises ``NotImplementedError``."""
        raise NotImplementedError("TestingEnv step_fn is not set")

View File

@@ -1,41 +1,229 @@
from typing import Optional """Tests that the `env_checker` runs as expects and all errors are possible."""
import re
import numpy as np import numpy as np
import pytest import pytest
import gym import gym
from gym.spaces import Box, Dict, Discrete from gym import spaces
from gym.utils.env_checker import check_env from gym.utils.env_checker import (
check_env,
check_reset_info,
check_reset_options,
check_reset_seed,
)
from tests.testing_env import GenericTestEnv
class ActionDictTestEnv(gym.Env): @pytest.mark.parametrize(
action_space = Dict({"position": Discrete(1), "velocity": Discrete(1)}) "env",
observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32) [
gym.make("CartPole-v1", disable_env_checker=True).unwrapped,
gym.make("MountainCar-v0", disable_env_checker=True).unwrapped,
GenericTestEnv(
observation_space=spaces.Dict(
a=spaces.Discrete(10), b=spaces.Box(np.zeros(2), np.ones(2))
)
),
GenericTestEnv(
observation_space=spaces.Tuple(
[spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
)
),
GenericTestEnv(
observation_space=spaces.Dict(
a=spaces.Tuple(
[spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
),
b=spaces.Box(np.zeros(2), np.ones(2)),
)
),
],
)
def test_no_error_warnings(env):
"""A full version of this test with all gym envs is run in tests/envs/test_envs.py."""
with pytest.warns(None) as warnings:
check_env(env)
def __init__(self, render_mode: Optional[str] = None): assert len(warnings) == 0, [warning.message for warning in warnings]
self.render_mode = render_mode
def step(self, action):
observation = np.array([1.0, 1.5, 0.5])
reward = 1
done = True
return observation, reward, done
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
return np.array([1.0, 1.5, 0.5])
def render(self, mode: Optional[str] = "human"):
pass
def test_check_env_dict_action(): def _no_super_reset(self, seed=None, return_info=False, options=None):
# Environment.step() only returns 3 values: obs, reward, done. Not info! self.np_random.random() # generates a new prng
test_env = ActionDictTestEnv() # generate seed deterministic result
self.observation_space.seed(0)
return self.observation_space.sample()
with pytest.raises(AssertionError) as errorinfo:
check_env(env=test_env, warn=True) def _super_reset_fixed(self, seed=None, return_info=False, options=None):
assert ( # Call super that ignores the seed passed, use fixed seed
str(errorinfo.value) super(GenericTestEnv, self).reset(seed=1)
== "The `step()` method must return four values: obs, reward, done, info" # deterministic output
) self.observation_space._np_random = self.np_random
return self.observation_space.sample()
def _reset_default_seed(
    self: GenericTestEnv, seed="Error", return_info=False, options=None
):
    """Reset function whose ``seed`` default is the string ``"Error"`` instead of ``None``.

    It forwards ``seed`` to the parent reset, points the observation space at the environment's
    random generator and returns a sample from that space.
    """
    super(GenericTestEnv, self).reset(seed=seed)
    env_rng = self.np_random
    self.observation_space._np_random = env_rng  # pyright: ignore [reportPrivateUsage]
    observation = self.observation_space.sample()
    return observation
@pytest.mark.parametrize(
"test,func,message",
[
[
gym.error.Error,
lambda self: self.observation_space.sample(),
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
],
[
AssertionError,
lambda self, seed, *_: self.observation_space.sample(),
"Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
],
[
AssertionError,
_no_super_reset,
"Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`.",
],
[
AssertionError,
_super_reset_fixed,
"Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`.",
],
[
UserWarning,
_reset_default_seed,
"The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. Actual default: Error",
],
],
)
def test_check_reset_seed(test, func: callable, message: str):
    """Run ``check_reset_seed`` on an env using ``func`` as reset and expect ``test`` with ``message``.

    ``UserWarning`` cases must match the message wrapped in the logger's colour codes; any other
    ``test`` is treated as an exception type whose text must equal ``message``.
    """
    if test is UserWarning:
        expectation = pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        )
    else:
        expectation = pytest.raises(test, match=f"^{re.escape(message)}$")

    with expectation:
        check_reset_seed(GenericTestEnv(reset_fn=func))
def _reset_var_keyword_kwargs(self, kwargs):
    """Reset function taking a plain positional ``kwargs`` parameter (not ``**kwargs``); returns a sampled observation."""
    return self.observation_space.sample()
def _reset_return_info_type(self, seed=None, return_info=False, options=None):
    """Reset function that returns a list (not a tuple) when ``return_info`` is truthy."""
    if not return_info:
        return self.observation_space.sample()
    return [1, 2]
def _reset_return_info_length(self, seed=None, return_info=False, options=None):
    """Reset function that returns a 3-tuple when ``return_info`` is truthy."""
    if not return_info:
        return self.observation_space.sample()
    return 1, 2, 3
def _return_info_obs_outside(self, seed=None, return_info=False, options=None):
    """Reset function that, with ``return_info``, returns a sample shifted by the space's ``high`` plus an empty info dict."""
    sampled = self.observation_space.sample()
    if not return_info:
        return sampled
    shifted = sampled + self.observation_space.high
    return shifted, {}
def _return_info_not_dict(self, seed=None, return_info=False, options=None):
    """Reset function that, with ``return_info``, pairs the sampled observation with a list instead of a dict."""
    sampled = self.observation_space.sample()
    if not return_info:
        return sampled
    return sampled, ["key", "value"]
@pytest.mark.parametrize(
"test,func,message",
[
[
gym.error.Error,
lambda self, *_: self.observation_space.sample(),
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
],
[
gym.error.Error,
_reset_var_keyword_kwargs,
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
],
[
AssertionError,
_reset_return_info_type,
"Calling the reset method with `return_info=True` did not return a tuple, actual type: <class 'list'>",
],
[
AssertionError,
_reset_return_info_length,
"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: 3",
],
[
AssertionError,
_return_info_obs_outside,
"The first element returned by `env.reset(return_info=True)` is not within the observation space.",
],
[
AssertionError,
_return_info_not_dict,
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'list'>",
],
],
)
def test_check_reset_info(test, func: callable, message: str):
    """Run ``check_reset_info`` on an env using ``func`` as reset and expect ``test`` with ``message``.

    ``UserWarning`` cases must match the message wrapped in the logger's colour codes; any other
    ``test`` is treated as an exception type whose text must equal ``message``.
    """
    if test is UserWarning:
        expectation = pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        )
    else:
        expectation = pytest.raises(test, match=f"^{re.escape(message)}$")

    with expectation:
        check_reset_info(GenericTestEnv(reset_fn=func))
def test_check_reset_options():
    """``check_reset_options`` raises ``gym.error.Error`` when reset accepts neither ``options`` nor ``**kwargs``."""
    expected_error = re.escape(
        "The `reset` method does not provide an `options` or `**kwargs` keyword argument"
    )
    with pytest.raises(gym.error.Error, match=expected_error):
        check_reset_options(GenericTestEnv(reset_fn=lambda self: 0))
@pytest.mark.parametrize(
"env,message",
[
[
"Error",
"The environment must inherit from the gym.Env class. See https://www.gymlibrary.ml/content/environment_creation/ for more info.",
],
[
GenericTestEnv(action_space=None),
"The environment must specify an action space. See https://www.gymlibrary.ml/content/environment_creation/ for more info.",
],
[
GenericTestEnv(observation_space=None),
"The environment must specify an observation space. See https://www.gymlibrary.ml/content/environment_creation/ for more info.",
],
],
)
def test_check_env(env: gym.Env, message: str):
    """``check_env`` must raise an ``AssertionError`` whose text is exactly ``message`` for ``env``."""
    expected_error = f"^{re.escape(message)}$"
    with pytest.raises(AssertionError, match=expected_error):
        check_env(env)

View File

@@ -0,0 +1,461 @@
import re
from typing import Dict, Union
import numpy as np
import pytest
import gym
from gym import spaces
from gym.utils.passive_env_checker import (
check_action_space,
check_obs,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)
from tests.testing_env import GenericTestEnv
def _modify_space(space: spaces.Space, attribute: str, value):
    """Set ``attribute`` on ``space`` to ``value`` in place and return the same space object."""
    setattr(space, attribute, value)
    return space
@pytest.mark.parametrize(
"test,space,message",
[
[
AssertionError,
"error",
"observation space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
],
# ===== Check box observation space ====
[
UserWarning,
spaces.Box(np.zeros((5, 5, 1)), 255 * np.ones((5, 5, 1)), dtype=np.int32),
"It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: int32. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.",
],
[
UserWarning,
spaces.Box(np.ones((2, 2, 1)), 255 * np.ones((2, 2, 1)), dtype=np.uint8),
"It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
],
[
UserWarning,
spaces.Box(np.zeros((5, 5, 1)), np.ones((5, 5, 1)), dtype=np.uint8),
"It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
],
[
UserWarning,
spaces.Box(np.zeros((5, 5)), np.ones((5, 5))),
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (5, 5)",
],
[
UserWarning,
spaces.Box(np.zeros(5), np.zeros(5)),
"A Box observation space maximum and minimum values are equal.",
],
[
UserWarning,
spaces.Box(np.ones(5), np.zeros(5)),
"A Box observation space low value is greater than a high value.",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
"The Box observation space shape and low shape have different shapes, low shape: (3,), box shape: (2,)",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
"The Box observation space shape and high shape have have different shapes, high shape: (3,), box shape: (2,)",
],
# ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
[
AssertionError,
_modify_space(spaces.Discrete(5), "n", -1),
"Discrete observation space's number of elements must be positive, actual number of elements: -1",
],
[
AssertionError,
_modify_space(spaces.MultiDiscrete([2, 2]), "nvec", np.array([2, -1])),
"Multi-discrete observation space's all nvec elements must be greater than 0, actual nvec: [ 2 -1]",
],
[
AssertionError,
_modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, 1, 2)),
"Multi-discrete observation space's shape must be equal to the nvec shape, space shape: (2, 1, 2), nvec shape: (2,)",
],
[
AssertionError,
_modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
"Multi-binary observation space's all shape elements must be greater than 0, actual shape: (2, -1)",
],
[
AssertionError,
spaces.Tuple([]),
"An empty Tuple observation space is not allowed.",
],
[
AssertionError,
spaces.Dict(),
"An empty Dict observation space is not allowed.",
],
],
)
def test_check_observation_space(test, space, message: str):
    """Tests the check observation space.

    Args:
        test: ``UserWarning`` for warning cases, otherwise the exception type expected to be raised.
        space: The object passed to ``check_observation_space``.
        message: The exact warning or error text expected.
    """
    # Imported locally so this block does not depend on the module's import section.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_observation_space(space)
    else:
        # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer pytest releases;
        # record warnings with the standard library to assert that none were emitted.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_observation_space(space)
        assert len(caught_warnings) == 0
@pytest.mark.parametrize(
"test,space,message",
[
[
AssertionError,
"error",
"action space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
],
# ===== Check box observation space ====
[
UserWarning,
spaces.Box(np.zeros(5), np.zeros(5)),
"A Box action space maximum and minimum values are equal.",
],
[
UserWarning,
spaces.Box(np.ones(5), np.zeros(5)),
"A Box action space low value is greater than a high value.",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
"The Box action space shape and low shape have have different shapes, low shape: (3,), box shape: (2,)",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
"The Box action space shape and high shape have different shapes, high shape: (3,), box shape: (2,)",
],
# ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
[
AssertionError,
_modify_space(spaces.Discrete(5), "n", -1),
"Discrete action space's number of elements must be positive, actual number of elements: -1",
],
[
AssertionError,
_modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, -1)),
"Multi-discrete action space's shape must be equal to the nvec shape, space shape: (2, -1), nvec shape: (2,)",
],
[
AssertionError,
_modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
"Multi-binary action space's all shape elements must be greater than 0, actual shape: (2, -1)",
],
[
AssertionError,
spaces.Tuple([]),
"An empty Tuple action space is not allowed.",
],
[AssertionError, spaces.Dict(), "An empty Dict action space is not allowed."],
],
)
def test_check_action_space(
    test: Union[UserWarning, type], space: spaces.Space, message: str
):
    """Tests the check action space function.

    Args:
        test: ``UserWarning`` for warning cases, otherwise the exception type expected to be raised.
        space: The object passed to ``check_action_space``.
        message: The exact warning or error text expected.
    """
    # Imported locally so this block does not depend on the module's import section.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_action_space(space)
    else:
        # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer pytest releases;
        # record warnings with the standard library to assert that none were emitted.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_action_space(space)
        assert len(caught_warnings) == 0
@pytest.mark.parametrize(
"test,obs,obs_space,message",
[
[
UserWarning,
3,
spaces.Discrete(2),
"The obs returned by the `testing()` method is not within the observation space.",
],
[
UserWarning,
np.uint8(0),
spaces.Discrete(1),
"The obs returned by the `testing()` method should be an int or np.int64, actual type: <class 'numpy.uint8'>",
],
[
UserWarning,
[0, 1],
spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
"The obs returned by the `testing()` method was expecting a tuple, actual type: <class 'list'>",
],
[
AssertionError,
(1, 2, 3),
spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
"The obs returned by the `testing()` method length is not same as the observation space length, obs length: 3, space length: 2",
],
[
AssertionError,
{1, 2, 3},
spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
"The obs returned by the `testing()` method must be a dict, actual type: <class 'set'>",
],
[
AssertionError,
{"a": 1, "c": 2},
spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
"The obs returned by the `testing()` method observation keys is not same as the observation space keys, obs keys: ['a', 'c'], space keys: ['a', 'b']",
],
],
)
def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
    """Tests the check observations function.

    Args:
        test: ``UserWarning`` for warning cases, otherwise the exception type expected to be raised.
        obs: The observation passed to ``check_obs``.
        obs_space: The observation space the observation is checked against.
        message: The exact warning or error text expected.
    """
    # Imported locally so this block does not depend on the module's import section.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_obs(obs, obs_space, "testing")
    else:
        # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer pytest releases;
        # record warnings with the standard library to assert that none were emitted.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_obs(obs, obs_space, "testing")
        assert len(caught_warnings) == 0
def _reset_no_seed(self, return_info=False, options=None):
    """Reset function whose signature has no ``seed`` parameter; returns a sampled observation."""
    return self.observation_space.sample()
def _reset_seed_default(self, seed="error", return_info=False, options=None):
    """Reset function whose ``seed`` default is the string ``"error"`` instead of ``None``; returns a sampled observation."""
    return self.observation_space.sample()
def _reset_no_return_info(self, seed=None, options=None):
    """Reset function whose signature has no ``return_info`` parameter; returns a sampled observation."""
    return self.observation_space.sample()
def _reset_no_option(self, seed=None, return_info=False):
    """Reset function whose signature has no ``options`` parameter; returns a sampled observation."""
    return self.observation_space.sample()
def _make_reset_results(results):
    """Build a reset function that ignores its arguments and always returns ``results``."""

    def _reset_result(self, seed=None, return_info=False, options=None):
        # `results` is captured from the enclosing call and returned unchanged.
        return results

    return _reset_result
@pytest.mark.parametrize(
"test,func,message,kwargs",
[
[
UserWarning,
_reset_no_seed,
"Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.",
{},
],
[
UserWarning,
_reset_seed_default,
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
{},
],
[
UserWarning,
_reset_no_return_info,
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.",
{},
],
[
UserWarning,
_reset_no_option,
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.",
{},
],
[
AssertionError,
_make_reset_results([0, {}]),
"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: <class 'list'>",
{"return_info": True},
],
[
AssertionError,
_make_reset_results((0, {1, 2})),
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'set'>",
{"return_info": True},
],
],
)
def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: Dict):
    """Tests the passive env reset check

    Args:
        test: ``UserWarning`` for warning cases, otherwise the exception type expected to be raised.
        func: The reset function bound to the ``GenericTestEnv`` under test.
        message: The exact warning or error text expected.
        kwargs: Extra keyword arguments forwarded to ``env_reset_passive_checker``.
    """
    # Imported locally so this block does not depend on the module's import section.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
    else:
        # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer pytest releases;
        # record warnings with the standard library to assert that none were emitted.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
        assert len(caught_warnings) == 0
def _modified_step(
    self, obs=None, reward=0, terminated=False, truncated=False, info=None
):
    """Assemble a step result from the given parts, filling in missing ones.

    A missing ``obs`` is sampled from the observation space and a missing ``info`` becomes an
    empty dict. Passing ``truncated=None`` yields the four-element form
    ``(obs, reward, terminated, info)``; otherwise the five-element form is returned.
    """
    observation = self.observation_space.sample() if obs is None else obs
    details = {} if info is None else info

    if truncated is None:
        return observation, reward, terminated, details
    return observation, reward, terminated, truncated, details
@pytest.mark.parametrize(
"test,func,message",
[
[
AssertionError,
lambda self, _: "error",
"Expects step result to be a tuple, actual type: <class 'str'>",
],
[
UserWarning,
lambda self, _: _modified_step(self, terminated="error", truncated=None),
"Expects `done` signal to be a boolean, actual type: <class 'str'>",
],
[
UserWarning,
lambda self, _: _modified_step(self, terminated="error", truncated=False),
"Expects `terminated` signal to be a boolean, actual type: <class 'str'>",
],
[
UserWarning,
lambda self, _: _modified_step(self, truncated="error"),
"Expects `truncated` signal to be a boolean, actual type: <class 'str'>",
],
[
gym.error.Error,
lambda self, _: (1, 2, 3),
"Expected `Env.step` to return a four or five element tuple, actual number of elements returned: 3.",
],
[
UserWarning,
lambda self, _: _modified_step(self, reward="error"),
"The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: <class 'str'>",
],
[
UserWarning,
lambda self, _: _modified_step(self, reward=np.nan),
"The reward is a NaN value.",
],
[
UserWarning,
lambda self, _: _modified_step(self, reward=np.inf),
"The reward is an inf value.",
],
[
AssertionError,
lambda self, _: _modified_step(self, info="error"),
"The `info` returned by `step()` must be a python dictionary, actual type: <class 'str'>",
],
],
)
def test_passive_env_step_checker(
    test: Union[UserWarning, type], func: callable, message: str
):
    """Tests the passive env step checker.

    Args:
        test: ``UserWarning`` for warning cases, otherwise the exception type expected to be raised.
        func: The step function bound to the ``GenericTestEnv`` under test.
        message: The exact warning or error text expected.
    """
    # Imported locally so this block does not depend on the module's import section.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
    else:
        # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer pytest releases;
        # record warnings with the standard library to assert that none were emitted.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
        assert len(caught_warnings) == 0, list(caught_warnings)
@pytest.mark.parametrize(
"test,env,message",
[
[
UserWarning,
GenericTestEnv(render_modes=None),
"No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`.",
],
[
UserWarning,
GenericTestEnv(render_modes="Testing mode"),
"Expects the render_modes to be a sequence (i.e. list, tuple), actual type: <class 'str'>",
],
[
UserWarning,
GenericTestEnv(render_modes=["Testing mode", 1], render_fps=1),
"Expects all render modes to be strings, actual types: [<class 'str'>, <class 'int'>]",
],
[
UserWarning,
GenericTestEnv(
render_modes=["Testing mode"],
render_fps=None,
render_mode="Testing mode",
render_fn=lambda self: 0,
),
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
],
[
UserWarning,
GenericTestEnv(render_modes=["Testing mode"], render_fps="fps"),
"Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: <class 'str'>",
],
[
AssertionError,
GenericTestEnv(render_modes=[], render_fps=30, render_mode="Test"),
"With no render_modes, expects the Env.render_mode to be None, actual value: Test",
],
[
AssertionError,
GenericTestEnv(
render_modes=["Testing mode"], render_fps=30, render_mode="Non mode"
),
"The environment was initialized successfully however with an unsupported render mode. Render mode: Non mode, modes: ['Testing mode']",
],
],
)
def test_passive_render_checker(test, env: GenericTestEnv, message: str):
    """Tests the passive render checker.

    Args:
        test: ``UserWarning`` for warning cases, otherwise the exception type expected to be raised.
        env: The environment passed to ``env_render_passive_checker``.
        message: The exact warning or error text expected.
    """
    # Imported locally so this block does not depend on the module's import section.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_render_passive_checker(env)
    else:
        # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer pytest releases;
        # record warnings with the standard library to assert that none were emitted.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_render_passive_checker(env)
        assert len(caught_warnings) == 0

View File

@@ -85,6 +85,10 @@ def test_create_shared_memory_custom_space(n, ctx, space):
create_shared_memory(space, n=n, ctx=ctx) create_shared_memory(space, n=n, ctx=ctx)
def _write_shared_memory(space, i, shared_memory, sample):
    """Forward to ``write_to_shared_memory``, passing ``sample`` before ``shared_memory``."""
    # NOTE(review): presumably module-level so it can be used as a `Process` target — confirm.
    write_to_shared_memory(space, i, sample, shared_memory)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces] "space", spaces, ids=[space.__class__.__name__ for space in spaces]
) )
@@ -105,14 +109,14 @@ def test_write_to_shared_memory(space):
else: else:
raise TypeError(f"Got unknown type `{type(lhs)}`.") raise TypeError(f"Got unknown type `{type(lhs)}`.")
def write(i, shared_memory, sample):
write_to_shared_memory(space, i, sample, shared_memory)
shared_memory_n8 = create_shared_memory(space, n=8) shared_memory_n8 = create_shared_memory(space, n=8)
samples = [space.sample() for _ in range(8)] samples = [space.sample() for _ in range(8)]
processes = [ processes = [
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8) Process(
target=_write_shared_memory, args=(space, i, shared_memory_n8, samples[i])
)
for i in range(8)
] ]
for process in processes: for process in processes:

View File

@@ -0,0 +1,100 @@
import re
import numpy as np
import pytest
import gym
from gym.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_initialised_envs
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize(
"env",
all_testing_initialised_envs,
ids=[env.spec.id for env in all_testing_initialised_envs],
)
def test_passive_checker_wrapper_warnings(env):
    """Wrap ``env`` in ``PassiveEnvChecker`` and fail on any unexpected warning.

    The wrapped environment is reset, stepped once with a sampled action and closed while
    warnings are recorded; every recorded warning whose first argument is not listed in
    ``PASSIVE_CHECK_IGNORE_WARNING`` raises ``gym.error.Error``.
    """
    # Imported locally so this block does not depend on the module's import section.
    import warnings

    # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer pytest releases,
    # so the warnings are recorded with the standard library instead.
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Record every warning, including repeats that the default filters would hide.
        warnings.simplefilter("always")

        checker_env = PassiveEnvChecker(env)
        checker_env.reset()
        checker_env.step(checker_env.action_space.sample())
        # todo, add check for render, bugged due to mujoco v2/3 and v4 envs
        checker_env.close()

    for warning in caught_warnings:
        if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
            raise gym.error.Error(f"Unexpected warning: {warning.message}")
@pytest.mark.parametrize(
"env, message",
[
(
GenericTestEnv(action_space=None),
"The environment must specify an action space. https://www.gymlibrary.ml/content/environment_creation/",
),
(
GenericTestEnv(action_space="error"),
"action space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
),
(
GenericTestEnv(observation_space=None),
"The environment must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/",
),
(
GenericTestEnv(observation_space="error"),
"observation space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
),
],
)
def test_initialise_failures(env, message):
    """Wrapping ``env`` in ``PassiveEnvChecker`` must raise an ``AssertionError`` whose text is exactly ``message``."""
    expected_error = f"^{re.escape(message)}$"
    with pytest.raises(AssertionError, match=expected_error):
        PassiveEnvChecker(env)

    env.close()
def _reset_failure(self, seed=None, return_info=False, options=None):
    """Reset function that ignores its arguments and returns the float32 array ``[-1.0]``."""
    observation = np.array([-1.0], dtype=np.float32)
    return observation
def _step_failure(self, action):
    """Step function that ignores ``action`` and returns the string ``"error"`` instead of a tuple."""
    return "error"
def test_api_failures():
    """Exercise reset, step and render on a checker-wrapped env whose API functions misbehave.

    Each ``checked_*`` flag starts as ``False`` and must be truthy after the corresponding call,
    whether that call warned (reset, render) or raised (step).
    """
    faulty_env = GenericTestEnv(
        reset_fn=_reset_failure, step_fn=_step_failure, render_modes="error"
    )
    checked_env = PassiveEnvChecker(faulty_env)

    for flag in ("checked_reset", "checked_step", "checked_render"):
        assert getattr(checked_env, flag) is False

    reset_warning = re.escape(
        "The obs returned by the `reset()` method is not within the observation space"
    )
    with pytest.warns(UserWarning, match=reset_warning):
        checked_env.reset()
    assert checked_env.checked_reset

    step_error = "Expects step result to be a tuple, actual type: <class 'str'>"
    with pytest.raises(AssertionError, match=step_error):
        checked_env.step(checked_env.action_space.sample())
    assert checked_env.checked_step

    render_warning = r"Expects the render_modes to be a sequence \(i\.e\. list, tuple\), actual type: <class 'str'>"
    with pytest.warns(UserWarning, match=render_warning):
        checked_env.render()
    assert checked_env.checked_render

    checked_env.close()