mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Bug fix, add tests for environment checker and passive environment checker wrapper (#2903)
This commit is contained in:
@@ -327,8 +327,7 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
|
|
||||||
if not self.new_step_api:
|
if not self.new_step_api:
|
||||||
deprecation(
|
deprecation(
|
||||||
"Initializing wrapper in old step API which returns one bool instead of two. "
|
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
|
||||||
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
|
@@ -18,6 +18,16 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
|
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# pygame is necessary for using the environment (reset and step) even without a render mode;
|
||||||
|
# therefore, pygame is a necessary import for the environment.
|
||||||
|
import pygame
|
||||||
|
from pygame import gfxdraw
|
||||||
|
except ImportError:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"pygame is not installed, run `pip install gym[box2d]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
STATE_W = 96 # less than Atari 160x192
|
STATE_W = 96 # less than Atari 160x192
|
||||||
STATE_H = 96
|
STATE_H = 96
|
||||||
@@ -556,15 +566,8 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
|
|
||||||
def _render(self, mode: str = "human"):
|
def _render(self, mode: str = "human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
try:
|
|
||||||
import pygame
|
|
||||||
except ImportError:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"pygame is not installed, run `pip install gym[box2d]`"
|
|
||||||
)
|
|
||||||
|
|
||||||
pygame.font.init()
|
pygame.font.init()
|
||||||
|
|
||||||
if self.screen is None and mode == "human":
|
if self.screen is None and mode == "human":
|
||||||
pygame.init()
|
pygame.init()
|
||||||
pygame.display.init()
|
pygame.display.init()
|
||||||
@@ -661,8 +664,6 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
self._draw_colored_polygon(self.surf, poly, color, zoom, translation, angle)
|
self._draw_colored_polygon(self.surf, poly, color, zoom, translation, angle)
|
||||||
|
|
||||||
def _render_indicators(self, W, H):
|
def _render_indicators(self, W, H):
|
||||||
import pygame
|
|
||||||
|
|
||||||
s = W / 40.0
|
s = W / 40.0
|
||||||
h = H / 40.0
|
h = H / 40.0
|
||||||
color = (0, 0, 0)
|
color = (0, 0, 0)
|
||||||
@@ -733,9 +734,6 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
def _draw_colored_polygon(
|
def _draw_colored_polygon(
|
||||||
self, surface, poly, color, zoom, translation, angle, clip=True
|
self, surface, poly, color, zoom, translation, angle, clip=True
|
||||||
):
|
):
|
||||||
import pygame
|
|
||||||
from pygame import gfxdraw
|
|
||||||
|
|
||||||
poly = [pygame.math.Vector2(c).rotate_rad(angle) for c in poly]
|
poly = [pygame.math.Vector2(c).rotate_rad(angle) for c in poly]
|
||||||
poly = [
|
poly = [
|
||||||
(c[0] * zoom + translation[0], c[1] * zoom + translation[1]) for c in poly
|
(c[0] * zoom + translation[0], c[1] * zoom + translation[1]) for c in poly
|
||||||
@@ -754,8 +752,6 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
gfxdraw.filled_polygon(self.surf, poly, color)
|
gfxdraw.filled_polygon(self.surf, poly, color)
|
||||||
|
|
||||||
def _create_image_array(self, screen, size):
|
def _create_image_array(self, screen, size):
|
||||||
import pygame
|
|
||||||
|
|
||||||
scaled_screen = pygame.transform.smoothscale(screen, size)
|
scaled_screen = pygame.transform.smoothscale(screen, size)
|
||||||
return np.transpose(
|
return np.transpose(
|
||||||
np.array(pygame.surfarray.pixels3d(scaled_screen)), axes=(1, 0, 2)
|
np.array(pygame.surfarray.pixels3d(scaled_screen)), axes=(1, 0, 2)
|
||||||
@@ -763,8 +759,6 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.screen is not None:
|
if self.screen is not None:
|
||||||
import pygame
|
|
||||||
|
|
||||||
pygame.display.quit()
|
pygame.display.quit()
|
||||||
self.isopen = False
|
self.isopen = False
|
||||||
pygame.quit()
|
pygame.quit()
|
||||||
@@ -772,7 +766,6 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
a = np.array([0.0, 0.0, 0.0])
|
a = np.array([0.0, 0.0, 0.0])
|
||||||
import pygame
|
|
||||||
|
|
||||||
def register_input():
|
def register_input():
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
|
@@ -692,6 +692,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||||
|
|
||||||
if mode == "human":
|
if mode == "human":
|
||||||
|
assert self.screen is not None
|
||||||
self.screen.blit(self.surf, (0, 0))
|
self.screen.blit(self.surf, (0, 0))
|
||||||
pygame.event.pump()
|
pygame.event.pump()
|
||||||
self.clock.tick(self.metadata["render_fps"])
|
self.clock.tick(self.metadata["render_fps"])
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
Utility functions used for classic control environments.
|
Utility functions used for classic control environments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, SupportsFloat, Union
|
from typing import Optional, SupportsFloat, Tuple
|
||||||
|
|
||||||
|
|
||||||
def verify_number_and_cast(x: SupportsFloat) -> float:
|
def verify_number_and_cast(x: SupportsFloat) -> float:
|
||||||
@@ -10,35 +10,37 @@ def verify_number_and_cast(x: SupportsFloat) -> float:
|
|||||||
try:
|
try:
|
||||||
x = float(x)
|
x = float(x)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
raise ValueError(f"Your input must support being cast to a float: {x}")
|
raise ValueError(f"An option ({x}) could not be converted to a float.")
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def maybe_parse_reset_bounds(
|
def maybe_parse_reset_bounds(
|
||||||
options: Optional[dict], default_low: float, default_high: float
|
options: Optional[dict], default_low: float, default_high: float
|
||||||
) -> Union[float, float]:
|
) -> Tuple[float, float]:
|
||||||
"""
|
"""
|
||||||
This function can be called during a reset() to customize the sampling
|
This function can be called during a reset() to customize the sampling
|
||||||
ranges for setting the initial state distributions.
|
ranges for setting the initial state distributions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
options: (Optional) options passed in to reset().
|
options: Options passed in to reset().
|
||||||
default_low: Default lower limit to use, if none specified in options.
|
default_low: Default lower limit to use, if none specified in options.
|
||||||
default_high: Default upper limit to use, if none specified in options.
|
default_high: Default upper limit to use, if none specified in options.
|
||||||
limit_low: Lowest allowable value for user-specified lower limit.
|
|
||||||
limit_high: Highest allowable value for user-specified higher limit.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Lower and higher limits.
|
Tuple of the lower and upper limits.
|
||||||
"""
|
"""
|
||||||
if options is None:
|
if options is None:
|
||||||
return default_low, default_high
|
return default_low, default_high
|
||||||
|
|
||||||
low = options.get("low") if "low" in options else default_low
|
low = options.get("low") if "low" in options else default_low
|
||||||
high = options.get("high") if "high" in options else default_high
|
high = options.get("high") if "high" in options else default_high
|
||||||
|
|
||||||
# We expect only numerical inputs.
|
# We expect only numerical inputs.
|
||||||
low = verify_number_and_cast(low)
|
low = verify_number_and_cast(low)
|
||||||
high = verify_number_and_cast(high)
|
high = verify_number_and_cast(high)
|
||||||
if low > high:
|
if low > high:
|
||||||
raise ValueError("Lower bound must be lower than higher bound.")
|
raise ValueError(
|
||||||
|
f"Lower bound ({low}) must be lower than higher bound ({high})."
|
||||||
|
)
|
||||||
|
|
||||||
return low, high
|
return low, high
|
||||||
|
@@ -117,14 +117,32 @@ def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnvSpec:
|
class EnvSpec:
|
||||||
|
"""A specification for creating environments with `gym.make`.
|
||||||
|
|
||||||
|
* id: The string used to create the environment with `gym.make`
|
||||||
|
* entry_point: The location of the environment to create from
|
||||||
|
* reward_threshold: The reward threshold for completing the environment.
|
||||||
|
* nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
|
||||||
|
* max_episode_steps: The max number of steps that the environment can take before truncation
|
||||||
|
* order_enforce: If to enforce the order of `reset` before `step` and `render` functions
|
||||||
|
* autoreset: If to automatically reset the environment on episode end
|
||||||
|
* disable_env_checker: If to disable the environment checker wrapper by default in `gym.make`
|
||||||
|
* kwargs: Additional keyword arguments passed to the environments through `gym.make`
|
||||||
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
entry_point: Optional[Union[Callable, str]] = field(default=None)
|
entry_point: Optional[Union[Callable, str]] = field(default=None)
|
||||||
reward_threshold: Optional[float] = field(default=None)
|
reward_threshold: Optional[float] = field(default=None)
|
||||||
nondeterministic: bool = field(default=False)
|
nondeterministic: bool = field(default=False)
|
||||||
|
|
||||||
|
# Wrappers
|
||||||
max_episode_steps: Optional[int] = field(default=None)
|
max_episode_steps: Optional[int] = field(default=None)
|
||||||
order_enforce: bool = field(default=True)
|
order_enforce: bool = field(default=True)
|
||||||
autoreset: bool = field(default=False)
|
autoreset: bool = field(default=False)
|
||||||
|
disable_env_checker: bool = field(default=False)
|
||||||
new_step_api: bool = field(default=False)
|
new_step_api: bool = field(default=False)
|
||||||
|
|
||||||
|
# Environment arguments
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
namespace: Optional[str] = field(init=False)
|
namespace: Optional[str] = field(init=False)
|
||||||
@@ -530,7 +548,7 @@ def make(
|
|||||||
max_episode_steps: Optional[int] = None,
|
max_episode_steps: Optional[int] = None,
|
||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
new_step_api: bool = False,
|
new_step_api: bool = False,
|
||||||
disable_env_checker: bool = False,
|
disable_env_checker: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Env:
|
) -> Env:
|
||||||
"""Create an environment according to the given ID.
|
"""Create an environment according to the given ID.
|
||||||
@@ -540,7 +558,8 @@ def make(
|
|||||||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
||||||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
||||||
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
|
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
|
||||||
disable_env_checker: If to disable the environment checker
|
disable_env_checker: Whether to disable the env checker; None will default to the environment `spec.disable_env_checker`
|
||||||
|
(which is False by default, i.e. the checker runs), otherwise will run according to the parameter (True = not run, False = run)
|
||||||
kwargs: Additional arguments to pass to the environment constructor.
|
kwargs: Additional arguments to pass to the environment constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -641,16 +660,15 @@ def make(
|
|||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if apply_human_rendering:
|
|
||||||
env = HumanRendering(env)
|
|
||||||
|
|
||||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||||
spec_ = copy.deepcopy(spec_)
|
spec_ = copy.deepcopy(spec_)
|
||||||
spec_.kwargs = _kwargs
|
spec_.kwargs = _kwargs
|
||||||
env.unwrapped.spec = spec_
|
env.unwrapped.spec = spec_
|
||||||
|
|
||||||
# Run the environment checker as the lowest level wrapper
|
# Run the environment checker as the lowest level wrapper
|
||||||
if disable_env_checker is False:
|
if disable_env_checker is False or (
|
||||||
|
disable_env_checker is None and spec_.disable_env_checker is False
|
||||||
|
):
|
||||||
env = PassiveEnvChecker(env)
|
env = PassiveEnvChecker(env)
|
||||||
|
|
||||||
env = StepAPICompatibility(env, new_step_api)
|
env = StepAPICompatibility(env, new_step_api)
|
||||||
@@ -669,6 +687,10 @@ def make(
|
|||||||
if autoreset:
|
if autoreset:
|
||||||
env = AutoResetWrapper(env, new_step_api)
|
env = AutoResetWrapper(env, new_step_api)
|
||||||
|
|
||||||
|
# Add human rendering wrapper
|
||||||
|
if apply_human_rendering:
|
||||||
|
env = HumanRendering(env)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
@@ -20,12 +20,13 @@ from copy import deepcopy
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import error, logger
|
from gym import logger, spaces
|
||||||
from gym.utils.passive_env_checker import (
|
from gym.utils.passive_env_checker import (
|
||||||
check_action_space,
|
check_action_space,
|
||||||
check_observation_space,
|
check_observation_space,
|
||||||
passive_env_reset_check,
|
env_render_passive_checker,
|
||||||
passive_env_step_check,
|
env_reset_passive_checker,
|
||||||
|
env_step_passive_checker,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -67,47 +68,61 @@ def check_reset_seed(env: gym.Env):
|
|||||||
even though `seed` or `kwargs` appear in the signature.
|
even though `seed` or `kwargs` appear in the signature.
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
if "seed" in signature.parameters or "kwargs" in signature.parameters:
|
if "seed" in signature.parameters or (
|
||||||
|
"kwargs" in signature.parameters
|
||||||
|
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
obs_1 = env.reset(seed=123)
|
obs_1 = env.reset(seed=123)
|
||||||
assert obs_1 in env.observation_space
|
|
||||||
obs_2 = env.reset(seed=123)
|
|
||||||
assert obs_2 in env.observation_space
|
|
||||||
assert data_equivalence(obs_1, obs_2)
|
|
||||||
seed_123_rng = deepcopy(env.unwrapped.np_random)
|
|
||||||
|
|
||||||
# Note: for some environment, they may initialise at the same state, therefore we cannot check the obs_1 != obs_3
|
|
||||||
obs_4 = env.reset(seed=None)
|
|
||||||
assert obs_4 in env.observation_space
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
env.unwrapped.np_random.bit_generator.state
|
obs_1 in env.observation_space
|
||||||
!= seed_123_rng.bit_generator.state
|
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
||||||
|
assert (
|
||||||
|
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
|
||||||
|
is not None
|
||||||
|
), "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`."
|
||||||
|
seed_123_rng = deepcopy(
|
||||||
|
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
obs_2 = env.reset(seed=123)
|
||||||
|
assert (
|
||||||
|
obs_2 in env.observation_space
|
||||||
|
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
||||||
|
if env.spec is not None and env.spec.nondeterministic is False:
|
||||||
|
assert data_equivalence(
|
||||||
|
obs_1, obs_2
|
||||||
|
), "Using `env.reset(seed=123)` is non-deterministic as the observations are not equivalent."
|
||||||
|
assert (
|
||||||
|
env.unwrapped._np_random.bit_generator.state # pyright: ignore [reportPrivateUsage]
|
||||||
|
== seed_123_rng.bit_generator.state
|
||||||
|
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."
|
||||||
|
|
||||||
|
obs_3 = env.reset(seed=456)
|
||||||
|
assert (
|
||||||
|
obs_3 in env.observation_space
|
||||||
|
), "The observation returned by `env.reset(seed=456)` is not within the observation space."
|
||||||
|
assert (
|
||||||
|
env.unwrapped._np_random.bit_generator.state # pyright: ignore [reportPrivateUsage]
|
||||||
|
!= seed_123_rng.bit_generator.state
|
||||||
|
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`."
|
||||||
|
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
|
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
|
||||||
"This should never happen, please report this issue. "
|
f"This should never happen, please report this issue. The error was: {e}"
|
||||||
f"The error was: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if env.unwrapped.np_random is None:
|
|
||||||
logger.warn(
|
|
||||||
"Resetting the environment did not result in seeding its random number generator. "
|
|
||||||
"This is likely due to not calling `super().reset(seed=seed)` in the `reset` method. "
|
|
||||||
"If you do not use the python-level random number generator, this is not a problem."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
seed_param = signature.parameters.get("seed")
|
seed_param = signature.parameters.get("seed")
|
||||||
# Check the default value is None
|
# Check the default value is None
|
||||||
if seed_param is not None and seed_param.default is not None:
|
if seed_param is not None and seed_param.default is not None:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"The default seed argument in reset should be `None`, "
|
"The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. "
|
||||||
"otherwise the environment will by default always be deterministic"
|
f"Actual default: {seed_param.default}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise gym.error.Error(
|
||||||
"The `reset` method does not provide the `return_info` keyword argument"
|
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -122,25 +137,39 @@ def check_reset_info(env: gym.Env):
|
|||||||
even though `return_info` or `kwargs` appear in the signature.
|
even though `return_info` or `kwargs` appear in the signature.
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
if "return_info" in signature.parameters or "kwargs" in signature.parameters:
|
if "return_info" in signature.parameters or (
|
||||||
|
"kwargs" in signature.parameters
|
||||||
|
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
|
obs = env.reset(return_info=False)
|
||||||
|
assert (
|
||||||
|
obs in env.observation_space
|
||||||
|
), "The value returned by `env.reset(return_info=True)` is not within the observation space."
|
||||||
|
|
||||||
result = env.reset(return_info=True)
|
result = env.reset(return_info=True)
|
||||||
|
assert isinstance(
|
||||||
|
result, tuple
|
||||||
|
), f"Calling the reset method with `return_info=True` did not return a tuple, actual type: {type(result)}"
|
||||||
assert (
|
assert (
|
||||||
len(result) == 2
|
len(result) == 2
|
||||||
), "Calling the reset method with `return_info=True` did not return a 2-tuple"
|
), f"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: {len(result)}"
|
||||||
|
|
||||||
obs, info = result
|
obs, info = result
|
||||||
|
assert (
|
||||||
|
obs in env.observation_space
|
||||||
|
), "The first element returned by `env.reset(return_info=True)` is not within the observation space."
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
info, dict
|
info, dict
|
||||||
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
|
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
|
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` appear in the signature. "
|
||||||
"appear in the signature. This should never happen, please report this issue. "
|
f"This should never happen, please report this issue. The error was: {e}"
|
||||||
f"The error was: {e}"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise gym.error.Error(
|
||||||
"The `reset` method does not provide the `return_info` keyword argument"
|
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -155,57 +184,62 @@ def check_reset_options(env: gym.Env):
|
|||||||
even though `options` or `kwargs` appear in the signature.
|
even though `options` or `kwargs` appear in the signature.
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
if "options" in signature.parameters or "kwargs" in signature.parameters:
|
if "options" in signature.parameters or (
|
||||||
|
"kwargs" in signature.parameters
|
||||||
|
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
env.reset(options={})
|
env.reset(options={})
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"The environment cannot be reset with options, even though `options` or `kwargs` appear in the signature. "
|
"The environment cannot be reset with options, even though `options` or `**kwargs` appear in the signature. "
|
||||||
"This should never happen, please report this issue. "
|
f"This should never happen, please report this issue. The error was: {e}"
|
||||||
f"The error was: {e}"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise gym.error.Error(
|
||||||
"The `reset` method does not provide the `options` keyword argument"
|
"The `reset` method does not provide an `options` or `**kwargs` keyword argument."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_render(env: gym.Env, warn: bool = True):
|
def check_space_limit(space, space_type: str):
|
||||||
"""Check the declared render modes/fps of the environment.
|
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
|
||||||
|
if isinstance(space, spaces.Box):
|
||||||
Args:
|
if np.any(np.equal(space.low, -np.inf)):
|
||||||
env: The environment to check
|
|
||||||
warn: Whether to output additional warnings
|
|
||||||
"""
|
|
||||||
render_modes = env.metadata.get("render_modes")
|
|
||||||
if render_modes is None:
|
|
||||||
if warn:
|
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"No render modes was declared in the environment "
|
f"A Box {space_type} space minimum value is -infinity. This is probably too low."
|
||||||
" (env.metadata['render_modes'] is None or not defined), "
|
|
||||||
"you may have trouble when calling `.render()`"
|
|
||||||
)
|
)
|
||||||
|
if np.any(np.equal(space.high, np.inf)):
|
||||||
render_fps = env.metadata.get("render_fps")
|
|
||||||
# We only require `render_fps` if rendering is actually implemented
|
|
||||||
if render_fps is None and render_modes is not None and len(render_modes) > 0:
|
|
||||||
if warn:
|
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"No render fps was declared in the environment "
|
f"A Box {space_type} space maximum value is -infinity. This is probably too high."
|
||||||
" (env.metadata['render_fps'] is None or not defined), "
|
|
||||||
"rendering may occur at inconsistent fps"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if warn:
|
# Check that the Box space is normalized
|
||||||
if not hasattr(env, "render_mode"): # TODO: raise an error with gym 1.0
|
if space_type == "action":
|
||||||
logger.warn("Environments must define render_mode attribute.")
|
if len(space.shape) == 1: # for vector boxes
|
||||||
elif env.render_mode is not None and env.render_mode not in render_modes:
|
if (
|
||||||
|
np.any(
|
||||||
|
np.logical_and(
|
||||||
|
space.low != np.zeros_like(space.low),
|
||||||
|
np.abs(space.low) != np.abs(space.high),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
or np.any(space.low < -1)
|
||||||
|
or np.any(space.high > 1)
|
||||||
|
):
|
||||||
|
# todo - Add to gymlibrary.ml?
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"The environment was initialized successfully with an unsupported render mode."
|
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). "
|
||||||
|
"See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information."
|
||||||
)
|
)
|
||||||
|
elif isinstance(space, spaces.Tuple):
|
||||||
|
for subspace in space.spaces:
|
||||||
|
check_space_limit(subspace, space_type)
|
||||||
|
elif isinstance(space, spaces.Dict):
|
||||||
|
for subspace in space.values():
|
||||||
|
check_space_limit(subspace, space_type)
|
||||||
|
|
||||||
|
|
||||||
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
|
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
|
||||||
"""Check that an environment follows Gym API.
|
"""Check that an environment follows Gym API.
|
||||||
|
|
||||||
This is an invasive function that calls the environment's reset and step.
|
This is an invasive function that calls the environment's reset and step.
|
||||||
@@ -220,21 +254,29 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
|
|||||||
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
|
skip_render_check: Whether to skip the checks for the render method. False by default (set to True to skip, useful for the CI)
|
||||||
"""
|
"""
|
||||||
if warn is not None:
|
if warn is not None:
|
||||||
logger.warn("`check_env` warn parameter is now ignored.")
|
logger.warn("`check_env(warn=...)` parameter is now ignored.")
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
env, gym.Env
|
env, gym.Env
|
||||||
), "Your environment must inherit from the gym.Env class https://www.gymlibrary.ml/content/environment_creation/"
|
), "The environment must inherit from the gym.Env class. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
|
||||||
|
|
||||||
|
if env.unwrapped is not env:
|
||||||
|
logger.warn(
|
||||||
|
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
|
||||||
|
)
|
||||||
|
|
||||||
# ============= Check the spaces (observation and action) ================
|
# ============= Check the spaces (observation and action) ================
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
env, "action_space"
|
env, "action_space"
|
||||||
), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/"
|
), "The environment must specify an action space. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
|
||||||
check_action_space(env.action_space)
|
check_action_space(env.action_space)
|
||||||
|
check_space_limit(env.action_space, "action")
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
env, "observation_space"
|
env, "observation_space"
|
||||||
), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
|
), "The environment must specify an observation space. See https://www.gymlibrary.ml/content/environment_creation/ for more info."
|
||||||
check_observation_space(env.observation_space)
|
check_observation_space(env.observation_space)
|
||||||
|
check_space_limit(env.observation_space, "observation")
|
||||||
|
|
||||||
# ==== Check the reset method ====
|
# ==== Check the reset method ====
|
||||||
check_reset_seed(env)
|
check_reset_seed(env)
|
||||||
@@ -242,9 +284,12 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
|
|||||||
check_reset_info(env)
|
check_reset_info(env)
|
||||||
|
|
||||||
# ============ Check the returned values ===============
|
# ============ Check the returned values ===============
|
||||||
passive_env_reset_check(env)
|
env_reset_passive_checker(env)
|
||||||
passive_env_step_check(env, env.action_space.sample())
|
env_step_passive_checker(env, env.action_space.sample())
|
||||||
|
|
||||||
# ==== Check the render method and the declared render modes ====
|
# ==== Check the render method and the declared render modes ====
|
||||||
if not skip_render_check:
|
if not skip_render_check:
|
||||||
check_render(env)
|
if env.render_mode is not None:
|
||||||
|
env_render_passive_checker(env)
|
||||||
|
|
||||||
|
# todo: recreate the environment with a different render_mode to check that each one works
|
||||||
|
@@ -1,10 +1,11 @@
|
|||||||
"""A set of functions for passively checking environment implementations."""
|
"""A set of functions for passively checking environment implementations."""
|
||||||
import inspect
|
import inspect
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym import error, logger, spaces
|
from gym import Space, error, logger, spaces
|
||||||
from gym.logger import deprecation
|
|
||||||
|
|
||||||
|
|
||||||
def _check_box_observation_space(observation_space: spaces.Box):
|
def _check_box_observation_space(observation_space: spaces.Box):
|
||||||
@@ -17,47 +18,33 @@ def _check_box_observation_space(observation_space: spaces.Box):
|
|||||||
if len(observation_space.shape) == 3:
|
if len(observation_space.shape) == 3:
|
||||||
if observation_space.dtype != np.uint8:
|
if observation_space.dtype != np.uint8:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"It seems that your observation space ({observation_space}) is an image but the `dtype` of your observation_space is not `np.uint8`. "
|
f"It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: {observation_space.dtype}. "
|
||||||
"If your observation is not an image, we recommend you to flatten the observation to have only a 1D vector"
|
"If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector."
|
||||||
)
|
)
|
||||||
if np.any(observation_space.low != 0) or np.any(
|
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
|
||||||
observation_space.high != 255
|
|
||||||
): # todo np.all?
|
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"It seems that your observation space is an image but the upper and lower bounds are not in [0, 255]. "
|
"It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. "
|
||||||
"Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not."
|
"Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not."
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(observation_space.shape) not in [1, 3]:
|
if len(observation_space.shape) not in [1, 3]:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Your observation space has an unconventional shape (neither an image, nor a 1D vector). "
|
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). "
|
||||||
"We recommend you to flatten the observation to have only a 1D vector or use a custom policy to properly process the data. "
|
"We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. "
|
||||||
f"Observation space={observation_space}"
|
f"Actual observation shape: {observation_space.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if np.any(np.equal(observation_space.low, -np.inf)):
|
assert (
|
||||||
logger.warn(
|
observation_space.low.shape == observation_space.shape
|
||||||
"Agent's minimum observation space value is -infinity. This is probably too low."
|
), f"The Box observation space shape and low shape have different shapes, low shape: {observation_space.low.shape}, box shape: {observation_space.shape}"
|
||||||
)
|
assert (
|
||||||
if np.any(np.equal(observation_space.high, np.inf)):
|
observation_space.high.shape == observation_space.shape
|
||||||
logger.warn(
|
), f"The Box observation space shape and high shape have have different shapes, high shape: {observation_space.high.shape}, box shape: {observation_space.shape}"
|
||||||
"Agent's maximum observation space value is infinity. This is probably too high."
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.any(np.equal(observation_space.low, observation_space.high)):
|
if np.any(observation_space.low == observation_space.high):
|
||||||
logger.warn("Agent's maximum and minimum observation space values are equal")
|
logger.warn("A Box observation space maximum and minimum values are equal.")
|
||||||
if np.any(np.greater(observation_space.low, observation_space.high)):
|
elif np.any(observation_space.high < observation_space.low):
|
||||||
raise AssertionError(
|
logger.warn("A Box observation space low value is greater than a high value.")
|
||||||
"Agent's minimum observation value is greater than it's maximum"
|
|
||||||
)
|
|
||||||
if observation_space.low.shape != observation_space.shape:
|
|
||||||
raise AssertionError(
|
|
||||||
"Agent's observation_space.low and observation_space have different shapes"
|
|
||||||
)
|
|
||||||
if observation_space.high.shape != observation_space.shape:
|
|
||||||
raise AssertionError(
|
|
||||||
"Agent's observation_space.high and observation_space have different shapes"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_box_action_space(action_space: spaces.Box):
|
def _check_box_action_space(action_space: spaces.Box):
|
||||||
@@ -66,42 +53,73 @@ def _check_box_action_space(action_space: spaces.Box):
|
|||||||
Args:
|
Args:
|
||||||
action_space: A box action space
|
action_space: A box action space
|
||||||
"""
|
"""
|
||||||
if np.any(np.equal(action_space.low, -np.inf)):
|
assert (
|
||||||
logger.warn(
|
action_space.low.shape == action_space.shape
|
||||||
"Agent's minimum action space value is -infinity. This is probably too low."
|
), f"The Box action space shape and low shape have have different shapes, low shape: {action_space.low.shape}, box shape: {action_space.shape}"
|
||||||
)
|
assert (
|
||||||
if np.any(np.equal(action_space.high, np.inf)):
|
action_space.high.shape == action_space.shape
|
||||||
logger.warn(
|
), f"The Box action space shape and high shape have different shapes, high shape: {action_space.high.shape}, box shape: {action_space.shape}"
|
||||||
"Agent's maximum action space value is infinity. This is probably too high"
|
|
||||||
)
|
|
||||||
if np.any(np.equal(action_space.low, action_space.high)):
|
|
||||||
logger.warn("Agent's maximum and minimum action space values are equal")
|
|
||||||
if np.any(np.greater(action_space.low, action_space.high)):
|
|
||||||
raise AssertionError(
|
|
||||||
"Agent's minimum action value is greater than it's maximum"
|
|
||||||
)
|
|
||||||
if action_space.low.shape != action_space.shape:
|
|
||||||
raise AssertionError(
|
|
||||||
"Agent's action_space.low and action_space have different shapes"
|
|
||||||
)
|
|
||||||
if action_space.high.shape != action_space.shape:
|
|
||||||
raise AssertionError(
|
|
||||||
"Agent's action_space.high and action_space have different shapes"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that the Box space is normalized
|
if np.any(action_space.low == action_space.high):
|
||||||
if (
|
logger.warn("A Box action space maximum and minimum values are equal.")
|
||||||
np.any(np.abs(action_space.low) != np.abs(action_space.high))
|
elif np.any(action_space.high < action_space.low):
|
||||||
or np.any(np.abs(action_space.low) > 1)
|
logger.warn("A Box action space low value is greater than a high value.")
|
||||||
or np.any(np.abs(action_space.high) > 1)
|
|
||||||
|
|
||||||
|
def check_space(
    space: Space, space_type: str, check_box_space_fn: Callable[[spaces.Box], None]
):
    """Passively validate an observation or action space without touching the environment.

    Container spaces (``Tuple`` and ``Dict``) are validated recursively, so every
    nested ``Box`` is handed to ``check_box_space_fn``.

    Args:
        space: The space to validate.
        space_type: Label inserted into the error messages, e.g. ``"observation"`` or ``"action"``.
        check_box_space_fn: Callback applied to every ``Box`` (sub)space that is found.

    Raises:
        AssertionError: If ``space`` is not a ``gym.spaces.Space`` or one of the
            structural checks below fails.
    """
    if not isinstance(space, spaces.Space):
        raise AssertionError(
            f"{space_type} space does not inherit from `gym.spaces.Space`, actual type: {type(space)}"
        )

    elif isinstance(space, spaces.Box):
        # Box-specific checks differ between observations and actions, hence the callback
        check_box_space_fn(space)
    elif isinstance(space, spaces.Discrete):
        assert (
            0 < space.n
        ), f"Discrete {space_type} space's number of elements must be positive, actual number of elements: {space.n}"
        assert (
            space.shape == ()
        ), f"Discrete {space_type} space's shape should be empty, actual shape: {space.shape}"
    elif isinstance(space, spaces.MultiDiscrete):
        assert (
            space.shape == space.nvec.shape
        ), f"Multi-discrete {space_type} space's shape must be equal to the nvec shape, space shape: {space.shape}, nvec shape: {space.nvec.shape}"
        assert np.all(
            0 < space.nvec
        ), f"Multi-discrete {space_type} space's all nvec elements must be greater than 0, actual nvec: {space.nvec}"
    elif isinstance(space, spaces.MultiBinary):
        assert np.all(
            0 < np.asarray(space.shape)
        ), f"Multi-binary {space_type} space's all shape elements must be greater than 0, actual shape: {space.shape}"
    elif isinstance(space, spaces.Tuple):
        assert 0 < len(
            space.spaces
        ), f"An empty Tuple {space_type} space is not allowed."
        for subspace in space.spaces:
            check_space(subspace, space_type, check_box_space_fn)
    elif isinstance(space, spaces.Dict):
        assert 0 < len(
            space.spaces.keys()
        ), f"An empty Dict {space_type} space is not allowed."
        for subspace in space.values():
            check_space(subspace, space_type, check_box_space_fn)
|
||||||
|
|
||||||
|
|
||||||
|
check_observation_space = partial(
|
||||||
|
check_space,
|
||||||
|
space_type="observation",
|
||||||
|
check_box_space_fn=_check_box_observation_space,
|
||||||
|
)
|
||||||
|
check_action_space = partial(
|
||||||
|
check_space, space_type="action", check_box_space_fn=_check_box_action_space
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_obs(obs, observation_space: spaces.Space, method_name: str):
|
def check_obs(obs, observation_space: spaces.Space, method_name: str):
|
||||||
"""Check that the observation returned by the environment correspond to the declared one.
|
"""Check that the observation returned by the environment correspond to the declared one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -109,111 +127,50 @@ def _check_obs(obs, observation_space: spaces.Space, method_name: str):
|
|||||||
observation_space: The observation space of the observation
|
observation_space: The observation space of the observation
|
||||||
method_name: The method name that generated the observation
|
method_name: The method name that generated the observation
|
||||||
"""
|
"""
|
||||||
pre = f"The observation returned by the `{method_name}()` method"
|
pre = f"The obs returned by the `{method_name}()` method"
|
||||||
|
|
||||||
assert observation_space.contains(
|
|
||||||
obs
|
|
||||||
), f"{pre} is not contained with the observation space ({observation_space})"
|
|
||||||
|
|
||||||
if isinstance(observation_space, spaces.Discrete):
|
if isinstance(observation_space, spaces.Discrete):
|
||||||
assert isinstance(
|
if not isinstance(obs, (np.int64, int)):
|
||||||
obs, int
|
logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
|
||||||
), f"The observation returned by `{method_name}()` method must be an int, actually {type(obs)}"
|
|
||||||
elif isinstance(
|
|
||||||
observation_space, (spaces.Box, spaces.MultiBinary, spaces.MultiDiscrete)
|
|
||||||
):
|
|
||||||
assert isinstance(
|
|
||||||
obs, np.ndarray
|
|
||||||
), f"The observation returned by `{method_name}()` method must be a numpy array, actually {type(obs)}"
|
|
||||||
elif isinstance(observation_space, spaces.Tuple):
|
|
||||||
assert isinstance(
|
|
||||||
obs, tuple
|
|
||||||
), f"The observation returned by the `{method_name}()` method must be a tuple, actually {type(obs)}"
|
|
||||||
for sub_obs, sub_space in zip(obs, observation_space.spaces):
|
|
||||||
_check_obs(sub_obs, sub_space, method_name)
|
|
||||||
elif isinstance(observation_space, spaces.Dict):
|
|
||||||
assert isinstance(
|
|
||||||
obs, dict
|
|
||||||
), f"The observation returned by the `{method_name}()` method must be a dict, actually {type(obs)}"
|
|
||||||
for space_key in observation_space.keys():
|
|
||||||
_check_obs(obs[space_key], observation_space[space_key], method_name)
|
|
||||||
|
|
||||||
|
|
||||||
def check_observation_space(observation_space):
|
|
||||||
"""A passive check of the environment observation space that should not affect the environment."""
|
|
||||||
if not isinstance(observation_space, spaces.Space):
|
|
||||||
raise AssertionError(
|
|
||||||
f"Observation space ({observation_space}) does not inherit from gym.spaces.Space"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(observation_space, spaces.Box):
|
elif isinstance(observation_space, spaces.Box):
|
||||||
# Check if the box is an image (shape is 3 elements and the last element is 1 or 3)
|
if observation_space.shape != ():
|
||||||
_check_box_observation_space(observation_space)
|
if not isinstance(obs, np.ndarray):
|
||||||
elif isinstance(observation_space, spaces.Discrete):
|
logger.warn(
|
||||||
assert (
|
f"{pre} was expecting a numpy array, actual type: {type(obs)}"
|
||||||
observation_space.n > 0
|
|
||||||
), f"There are no available discrete observations, n={observation_space.n}"
|
|
||||||
elif isinstance(observation_space, spaces.MultiDiscrete):
|
|
||||||
assert np.all(
|
|
||||||
observation_space.nvec > 0
|
|
||||||
), f"All dimensions of multi-discrete must be greater than 0, {observation_space.nvec}"
|
|
||||||
elif isinstance(observation_space, spaces.MultiBinary):
|
|
||||||
assert np.all(
|
|
||||||
np.asarray(observation_space.shape) > 0
|
|
||||||
), f"All dimensions of multi-binary must be greater than 0, {observation_space.shape}"
|
|
||||||
elif isinstance(observation_space, spaces.Tuple):
|
|
||||||
assert (
|
|
||||||
len(observation_space.spaces) > 0
|
|
||||||
), "An empty Tuple observation space is not allowed."
|
|
||||||
for subspace in observation_space.spaces:
|
|
||||||
check_observation_space(subspace)
|
|
||||||
elif isinstance(observation_space, spaces.Dict):
|
|
||||||
assert (
|
|
||||||
len(observation_space.spaces.keys()) > 0
|
|
||||||
), "An empty Dict observation space is not allowed."
|
|
||||||
for subspace in observation_space.values():
|
|
||||||
check_observation_space(subspace)
|
|
||||||
|
|
||||||
|
|
||||||
def check_action_space(action_space):
|
|
||||||
"""A passive check of the environment action space that should not affect the environment."""
|
|
||||||
if not isinstance(action_space, spaces.Space):
|
|
||||||
raise AssertionError(
|
|
||||||
f"Action space ({action_space}) does not inherit from gym.spaces.Space"
|
|
||||||
)
|
)
|
||||||
|
elif obs.dtype != observation_space.dtype:
|
||||||
|
logger.warn(
|
||||||
|
f"{pre} was expecting numpy array dtype to be {observation_space.dtype}, actual type: {obs.dtype}"
|
||||||
|
)
|
||||||
|
elif isinstance(observation_space, (spaces.MultiBinary, spaces.MultiDiscrete)):
|
||||||
|
if not isinstance(obs, np.ndarray):
|
||||||
|
logger.warn(f"{pre} was expecting a numpy array, actual type: {type(obs)}")
|
||||||
|
elif isinstance(observation_space, spaces.Tuple):
|
||||||
|
if not isinstance(obs, tuple):
|
||||||
|
logger.warn(f"{pre} was expecting a tuple, actual type: {type(obs)}")
|
||||||
|
assert len(obs) == len(
|
||||||
|
observation_space.spaces
|
||||||
|
), f"{pre} length is not same as the observation space length, obs length: {len(obs)}, space length: {len(observation_space.spaces)}"
|
||||||
|
for sub_obs, sub_space in zip(obs, observation_space.spaces):
|
||||||
|
check_obs(sub_obs, sub_space, method_name)
|
||||||
|
elif isinstance(observation_space, spaces.Dict):
|
||||||
|
assert isinstance(obs, dict), f"{pre} must be a dict, actual type: {type(obs)}"
|
||||||
|
assert (
|
||||||
|
obs.keys() == observation_space.spaces.keys()
|
||||||
|
), f"{pre} observation keys is not same as the observation space keys, obs keys: {list(obs.keys())}, space keys: {list(observation_space.spaces.keys())}"
|
||||||
|
for space_key in observation_space.spaces.keys():
|
||||||
|
check_obs(obs[space_key], observation_space[space_key], method_name)
|
||||||
|
|
||||||
elif isinstance(action_space, spaces.Box):
|
try:
|
||||||
_check_box_action_space(action_space)
|
if obs not in observation_space:
|
||||||
elif isinstance(action_space, spaces.Discrete):
|
logger.warn(f"{pre} is not within the observation space.")
|
||||||
assert (
|
except Exception as e:
|
||||||
action_space.n > 0
|
logger.warn(f"{pre} is not within the observation space with exception: {e}")
|
||||||
), f"There are no available discrete actions, n={action_space.n}"
|
|
||||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
|
||||||
assert np.all(
|
|
||||||
action_space.nvec > 0
|
|
||||||
), f"All dimensions of multi-discrete must be greater than 0, {action_space.nvec}"
|
|
||||||
elif isinstance(action_space, spaces.MultiBinary):
|
|
||||||
assert np.all(
|
|
||||||
np.asarray(action_space.shape) > 0
|
|
||||||
), f"All dimensions of multi-binary must be greater than 0, {action_space.shape}"
|
|
||||||
elif isinstance(action_space, spaces.Tuple):
|
|
||||||
assert (
|
|
||||||
len(action_space.spaces) > 0
|
|
||||||
), "An empty Tuple action space is not allowed."
|
|
||||||
for subspace in action_space.spaces:
|
|
||||||
check_action_space(subspace)
|
|
||||||
elif isinstance(action_space, spaces.Dict):
|
|
||||||
assert (
|
|
||||||
len(action_space.spaces.keys()) > 0
|
|
||||||
), "An empty Dict action space is not allowed."
|
|
||||||
for subspace in action_space.values():
|
|
||||||
check_action_space(subspace)
|
|
||||||
|
|
||||||
|
|
||||||
def passive_env_reset_check(env, **kwargs):
|
def env_reset_passive_checker(env, **kwargs):
|
||||||
"""A passive check of the `Env.reset` function investigating the returning reset information and returning the data unchanged."""
|
"""A passive check of the `Env.reset` function investigating the returning reset information and returning the data unchanged."""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
if "seed" not in signature.parameters or "kwargs" in signature.parameters:
|
if "seed" not in signature.parameters and "kwargs" not in signature.parameters:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator."
|
"Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator."
|
||||||
)
|
)
|
||||||
@@ -222,100 +179,150 @@ def passive_env_reset_check(env, **kwargs):
|
|||||||
# Check the default value is None
|
# Check the default value is None
|
||||||
if seed_param is not None and seed_param.default is not None:
|
if seed_param is not None and seed_param.default is not None:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic"
|
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. "
|
||||||
|
f"Actual default: {seed_param}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if "return_info" not in signature.parameters or "kwargs" in signature.parameters:
|
if "return_info" not in signature.parameters and not (
|
||||||
|
"kwargs" in signature.parameters
|
||||||
|
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||||
|
):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
|
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "options" not in signature.parameters or "kwargs" in signature.parameters:
|
if "options" not in signature.parameters and "kwargs" not in signature.parameters:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
|
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checks the result of env.reset with kwargs
|
# Checks the result of env.reset with kwargs
|
||||||
result = env.reset(**kwargs)
|
result = env.reset(**kwargs)
|
||||||
if "return_info" in kwargs and kwargs["return_info"] is True:
|
if kwargs.get("return_info", False) is True:
|
||||||
|
assert isinstance(
|
||||||
|
result, tuple
|
||||||
|
), f"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: {type(result)}"
|
||||||
|
assert (
|
||||||
|
len(result) == 2
|
||||||
|
), f"The length of the result returned by `env.reset(return_info=True)` is not 2, actual length: {len(result)}"
|
||||||
obs, info = result
|
obs, info = result
|
||||||
_check_obs(obs, env.observation_space, "reset")
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
info, dict
|
info, dict
|
||||||
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actually {type(info)}"
|
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
|
||||||
else:
|
else:
|
||||||
obs = result
|
obs = result
|
||||||
_check_obs(obs, env.observation_space, "reset")
|
|
||||||
|
|
||||||
|
check_obs(obs, env.observation_space, "reset")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def passive_env_step_check(env, action):
|
def env_step_passive_checker(env, action):
    """Call ``env.step(action)`` once, sanity-check the returned data and hand it back unchanged.

    Type problems with the boolean signals and the reward are only reported
    through ``logger.warn``; structural problems raise.

    Args:
        env: The environment to step.
        action: The action forwarded to ``env.step``.

    Returns:
        The untouched value returned by ``env.step(action)``.

    Raises:
        AssertionError: If the step result is not a tuple or ``info`` is not a dict.
        error.Error: If the step result has neither four nor five elements.
    """
    # We don't check the action because, for some environments, out-of-bounds values can be given
    result = env.step(action)
    assert isinstance(
        result, tuple
    ), f"Expects step result to be a tuple, actual type: {type(result)}"
    if len(result) == 4:
        logger.deprecation(
            "Core environment is written in old step API which returns one bool instead of two. "
            "It is recommended to rewrite the environment with new step API. "
        )
        obs, reward, done, info = result

        # `np.bool_` is NumPy's boolean scalar type; the `np.bool8` alias was removed in NumPy 2.0
        if not isinstance(done, (bool, np.bool_)):
            logger.warn(
                f"Expects `done` signal to be a boolean, actual type: {type(done)}"
            )
    elif len(result) == 5:
        obs, reward, terminated, truncated, info = result

        # Accept both the Python `bool` and NumPy's boolean scalar type (`np.bool_`)
        if not isinstance(terminated, (bool, np.bool_)):
            logger.warn(
                f"Expects `terminated` signal to be a boolean, actual type: {type(terminated)}"
            )
        if not isinstance(truncated, (bool, np.bool_)):
            logger.warn(
                f"Expects `truncated` signal to be a boolean, actual type: {type(truncated)}"
            )
    else:
        raise error.Error(
            f"Expected `Env.step` to return a four or five element tuple, actual number of elements returned: {len(result)}."
        )

    check_obs(obs, env.observation_space, "step")

    if not (
        np.issubdtype(type(reward), np.integer)
        or np.issubdtype(type(reward), np.floating)
    ):
        logger.warn(
            f"The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: {type(reward)}"
        )
    else:
        # NaN / inf checks are only meaningful once the reward is known to be numeric
        if np.isnan(reward):
            logger.warn("The reward is a NaN value.")
        if np.isinf(reward):
            logger.warn("The reward is an inf value.")

    assert isinstance(
        info, dict
    ), f"The `info` returned by `step()` must be a python dictionary, actual type: {type(info)}"

    return result
|
||||||
|
|
||||||
|
|
||||||
def passive_env_render_check(env, *args, **kwargs):
|
def env_render_passive_checker(env, *args, **kwargs):
    """Check the render metadata declared by the environment, then call ``env.render``.

    Reads ``env.metadata["render_modes"]`` and ``env.metadata["render_fps"]`` and
    compares them with ``env.render_mode``. Type problems are reported through
    ``logger.warn``; an invalid fps value or an undeclared render mode raises.

    Args:
        env: The environment whose render metadata is checked.
        *args: Positional arguments forwarded to ``env.render``.
        **kwargs: Keyword arguments forwarded to ``env.render``.

    Returns:
        The untouched value returned by ``env.render(*args, **kwargs)``.

    Raises:
        AssertionError: If ``render_fps`` is numeric but not positive, or if
            ``env.render_mode`` is set to a mode that is not declared.
    """
    render_modes = env.metadata.get("render_modes")
    if render_modes is None:
        logger.warn(
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`."
        )
    else:
        # Everything below needs `len(render_modes)`, so it only runs when modes are declared
        if not isinstance(render_modes, (list, tuple)):
            logger.warn(
                f"Expects the render_modes to be a sequence (i.e. list, tuple), actual type: {type(render_modes)}"
            )
        elif not all(isinstance(mode, str) for mode in render_modes):
            logger.warn(
                f"Expects all render modes to be strings, actual types: {[type(mode) for mode in render_modes]}"
            )

        render_fps = env.metadata.get("render_fps")
        # We only require `render_fps` if rendering is actually implemented
        if len(render_modes) > 0:
            if render_fps is None:
                logger.warn(
                    "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps."
                )
            elif not (
                np.issubdtype(type(render_fps), np.integer)
                or np.issubdtype(type(render_fps), np.floating)
            ):
                logger.warn(
                    f"Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: {type(render_fps)}"
                )
            else:
                assert (
                    render_fps > 0
                ), f"Expects the `env.metadata['render_fps']` to be greater than zero, actual value: {render_fps}"

        # `env.render_mode` is an attribute that is expected to default to None
        if len(render_modes) == 0:
            assert (
                env.render_mode is None
            ), f"With no render_modes, expects the Env.render_mode to be None, actual value: {env.render_mode}"
        else:
            assert env.render_mode is None or env.render_mode in render_modes, (
                "The environment was initialized successfully however with an unsupported render mode. "
                f"Render mode: {env.render_mode}, modes: {render_modes}"
            )

    result = env.render(*args, **kwargs)

    # TODO: Check that the result is correct

    return result
|
||||||
|
@@ -14,7 +14,7 @@ def make(
|
|||||||
num_envs: int = 1,
|
num_envs: int = 1,
|
||||||
asynchronous: bool = True,
|
asynchronous: bool = True,
|
||||||
wrappers: Optional[Union[callable, List[callable]]] = None,
|
wrappers: Optional[Union[callable, List[callable]]] = None,
|
||||||
disable_env_checker: bool = False,
|
disable_env_checker: Optional[bool] = None,
|
||||||
new_step_api: bool = False,
|
new_step_api: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> VectorEnv:
|
) -> VectorEnv:
|
||||||
@@ -43,8 +43,10 @@ def make(
|
|||||||
The vectorized environment.
|
The vectorized environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_env(_disable_env_checker):
|
def create_env(env_num: int):
|
||||||
"""Creates an environment that can enable or disable the environment checker."""
|
"""Creates an environment that can enable or disable the environment checker."""
|
||||||
|
# If the env_num > 0 then disable the environment checker otherwise use the parameter
|
||||||
|
_disable_env_checker = True if env_num > 0 else disable_env_checker
|
||||||
|
|
||||||
def _make_env():
|
def _make_env():
|
||||||
env = gym.envs.registration.make(
|
env = gym.envs.registration.make(
|
||||||
|
@@ -1 +1 @@
|
|||||||
VERSION = "0.24.1"
|
VERSION = "0.25.0"
|
||||||
|
@@ -4,9 +4,9 @@ from gym.core import ActType
|
|||||||
from gym.utils.passive_env_checker import (
|
from gym.utils.passive_env_checker import (
|
||||||
check_action_space,
|
check_action_space,
|
||||||
check_observation_space,
|
check_observation_space,
|
||||||
passive_env_render_check,
|
env_render_passive_checker,
|
||||||
passive_env_reset_check,
|
env_reset_passive_checker,
|
||||||
passive_env_step_check,
|
env_step_passive_checker,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -19,11 +19,11 @@ class PassiveEnvChecker(gym.Wrapper):
|
|||||||
|
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
env, "action_space"
|
env, "action_space"
|
||||||
), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/"
|
), "The environment must specify an action space. https://www.gymlibrary.ml/content/environment_creation/"
|
||||||
check_action_space(env.action_space)
|
check_action_space(env.action_space)
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
env, "observation_space"
|
env, "observation_space"
|
||||||
), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
|
), "The environment must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
|
||||||
check_observation_space(env.observation_space)
|
check_observation_space(env.observation_space)
|
||||||
|
|
||||||
self.checked_reset = False
|
self.checked_reset = False
|
||||||
@@ -34,7 +34,7 @@ class PassiveEnvChecker(gym.Wrapper):
|
|||||||
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
|
"""Steps through the environment that on the first call will run the `env_step_passive_checker`."""
|
||||||
if self.checked_step is False:
|
if self.checked_step is False:
|
||||||
self.checked_step = True
|
self.checked_step = True
|
||||||
return passive_env_step_check(self.env, action)
|
return env_step_passive_checker(self.env, action)
|
||||||
else:
|
else:
|
||||||
return self.env.step(action)
|
return self.env.step(action)
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ class PassiveEnvChecker(gym.Wrapper):
|
|||||||
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
|
"""Resets the environment that on the first call will run the `env_reset_passive_checker`."""
|
||||||
if self.checked_reset is False:
|
if self.checked_reset is False:
|
||||||
self.checked_reset = True
|
self.checked_reset = True
|
||||||
return passive_env_reset_check(self.env, **kwargs)
|
return env_reset_passive_checker(self.env, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
@@ -50,6 +50,6 @@ class PassiveEnvChecker(gym.Wrapper):
|
|||||||
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
|
"""Renders the environment that on the first call will run the `env_render_passive_checker`."""
|
||||||
if self.checked_render is False:
|
if self.checked_render is False:
|
||||||
self.checked_render = True
|
self.checked_render = True
|
||||||
return passive_env_render_check(self.env, *args, **kwargs)
|
return env_render_passive_checker(self.env, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.env.render(*args, **kwargs)
|
return self.env.render(*args, **kwargs)
|
||||||
|
@@ -37,8 +37,7 @@ class StepAPICompatibility(gym.Wrapper):
|
|||||||
self.new_step_api = new_step_api
|
self.new_step_api = new_step_api
|
||||||
if not self.new_step_api:
|
if not self.new_step_api:
|
||||||
deprecation(
|
deprecation(
|
||||||
"Initializing environment in old step API which returns one bool instead of two. "
|
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
|
||||||
"It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. "
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
@@ -168,17 +168,6 @@ def test_customizable_resets(env_name: str, low_high: Optional[list]):
|
|||||||
env.step(env.action_space.sample())
|
env.step(env.action_space.sample())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"env_name", ["CartPole-v1", "MountainCar-v0", "MountainCarContinuous-v0"]
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("low_high", [(-10.0, -9.0), (np.array(-10.0), np.array(-9.0))])
|
|
||||||
def test_customizable_out_of_bounds_resets(env_name: str, low_high: Optional[list]):
|
|
||||||
env = gym.make(env_name)
|
|
||||||
low, high = low_high
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
env.reset(options={"low": low, "high": high})
|
|
||||||
|
|
||||||
|
|
||||||
# We test Pendulum separately, as the parameters are handled differently.
|
# We test Pendulum separately, as the parameters are handled differently.
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"low_high",
|
"low_high",
|
||||||
@@ -221,4 +210,6 @@ def test_invalid_customizable_resets(env_name: str, low_high: list):
|
|||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
low, high = low_high
|
low, high = low_high
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
# match=re.escape(f"Lower bound ({low}) must be lower than higher bound ({high}).")
|
||||||
|
# match=f"An option ({x}) could not be converted to a float."
|
||||||
env.reset(options={"low": low, "high": high})
|
env.reset(options={"low": low, "high": high})
|
||||||
|
@@ -1,24 +1,52 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import gym
|
||||||
from gym.envs.registration import EnvSpec
|
from gym.envs.registration import EnvSpec
|
||||||
from gym.utils.env_checker import check_env
|
from gym.utils.env_checker import check_env
|
||||||
from tests.envs.utils import all_testing_env_specs, assert_equals, gym_testing_env_specs
|
from tests.envs.utils import all_testing_env_specs, assert_equals, gym_testing_env_specs
|
||||||
|
|
||||||
# This runs a smoketest on each official registered env. We may want
|
# This runs a smoketest on each official registered env. We may want
|
||||||
# to try also running environments which are not officially registered
|
# to try also running environments which are not officially registered envs.
|
||||||
# envs.
|
PASSIVE_CHECK_IGNORE_WARNING = [
|
||||||
|
f"\x1b[33mWARN: {message}\x1b[0m"
|
||||||
|
for message in [
|
||||||
|
"This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).",
|
||||||
|
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
||||||
|
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
CHECK_ENV_IGNORE_WARNINGS = [
|
||||||
|
f"\x1b[33mWARN: {message}\x1b[0m"
|
||||||
|
for message in [
|
||||||
|
"This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).",
|
||||||
|
"A Box observation space minimum value is -infinity. This is probably too low.",
|
||||||
|
"A Box observation space maximum value is -infinity. This is probably too high.",
|
||||||
|
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
|
||||||
|
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
||||||
|
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_spec", gym_testing_env_specs, ids=[spec.id for spec in gym_testing_env_specs]
|
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
|
||||||
)
|
)
|
||||||
def test_run_env_checker(env_spec: EnvSpec):
|
def test_envs_pass_env_checker(spec):
|
||||||
"""Runs the gym environment checker on the environment spec that calls the `reset`, `step` and `render`."""
|
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
||||||
env = env_spec.make(disable_env_checker=True)
|
with pytest.warns(None) as warnings:
|
||||||
check_env(env, skip_render_check=False)
|
env = spec.make(disable_env_checker=True).unwrapped
|
||||||
|
check_env(env)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
for warning in warnings.list:
|
||||||
|
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
|
||||||
|
print()
|
||||||
|
print(warning.message.args[0])
|
||||||
|
print(CHECK_ENV_IGNORE_WARNINGS[-1])
|
||||||
|
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||||
|
|
||||||
|
|
||||||
# Note that this precludes running this test in multiple threads.
|
# Note that this precludes running this test in multiple threads.
|
||||||
# However, we probably already can't do multithreading due to some environments.
|
# However, we probably already can't do multithreading due to some environments.
|
||||||
@@ -90,3 +118,5 @@ def test_render_modes(spec):
|
|||||||
new_env.reset()
|
new_env.reset()
|
||||||
new_env.step(new_env.action_space.sample())
|
new_env.step(new_env.action_space.sample())
|
||||||
new_env.render()
|
new_env.render()
|
||||||
|
new_env.close()
|
||||||
|
env.close()
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
"""Tests that gym.make works as expected."""
|
"""Tests that gym.make works as expected."""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -9,6 +10,7 @@ import gym
|
|||||||
from gym.envs.classic_control import cartpole
|
from gym.envs.classic_control import cartpole
|
||||||
from gym.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit
|
from gym.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit
|
||||||
from gym.wrappers.env_checker import PassiveEnvChecker
|
from gym.wrappers.env_checker import PassiveEnvChecker
|
||||||
|
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
|
||||||
from tests.envs.utils import all_testing_env_specs
|
from tests.envs.utils import all_testing_env_specs
|
||||||
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
|
||||||
from tests.wrappers.utils import has_wrapper
|
from tests.wrappers.utils import has_wrapper
|
||||||
@@ -100,18 +102,49 @@ def test_gym_make_autoreset():
|
|||||||
|
|
||||||
def test_make_disable_env_checker():
|
def test_make_disable_env_checker():
|
||||||
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`."""
|
"""Tests that `gym.make` applies the `PassiveEnvChecker` wrapper only when the env checker is not disabled, i.e. `disable_env_checker=False`."""
|
||||||
env = gym.make("CartPole-v1")
|
spec = deepcopy(gym.spec("CartPole-v1"))
|
||||||
|
|
||||||
|
# Test with spec disable env checker
|
||||||
|
spec.disable_env_checker = False
|
||||||
|
env = gym.make(spec)
|
||||||
assert has_wrapper(env, PassiveEnvChecker)
|
assert has_wrapper(env, PassiveEnvChecker)
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=False)
|
# Test with overwritten spec using make disable env checker
|
||||||
assert has_wrapper(env, PassiveEnvChecker)
|
assert spec.disable_env_checker is False
|
||||||
env.close()
|
env = gym.make(spec, disable_env_checker=True)
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
|
||||||
assert has_wrapper(env, PassiveEnvChecker) is False
|
assert has_wrapper(env, PassiveEnvChecker) is False
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
# Test with spec enabled disable env checker
|
||||||
|
spec.disable_env_checker = True
|
||||||
|
env = gym.make(spec)
|
||||||
|
assert has_wrapper(env, PassiveEnvChecker) is False
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
# Test with overwritten spec using make disable env checker
|
||||||
|
assert spec.disable_env_checker is True
|
||||||
|
env = gym.make(spec, disable_env_checker=False)
|
||||||
|
assert has_wrapper(env, PassiveEnvChecker)
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
|
||||||
|
)
|
||||||
|
def test_passive_checker_wrapper_warnings(spec):
|
||||||
|
with pytest.warns(None) as warnings:
|
||||||
|
env = gym.make(spec) # disable_env_checker=False
|
||||||
|
env.reset()
|
||||||
|
env.step(env.action_space.sample())
|
||||||
|
# todo, add check for render, bugged due to mujoco v2/3 and v4 envs
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
for warning in warnings.list:
|
||||||
|
if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
|
||||||
|
raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||||
|
|
||||||
|
|
||||||
def test_make_order_enforcing():
|
def test_make_order_enforcing():
|
||||||
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
|
"""Checks that gym.make wraps the environment with the OrderEnforcing wrapper."""
|
||||||
|
@@ -9,7 +9,12 @@ from gym.envs.registration import EnvSpec
|
|||||||
|
|
||||||
|
|
||||||
def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
|
def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
|
||||||
"""Tries to make the environment showing if it is possible. Warning the environments have no wrappers, including time limit and order enforcing."""
|
"""Tries to make the environment showing if it is possible.
|
||||||
|
|
||||||
|
Warning: the environments have no wrappers, including time limit and order enforcing.
|
||||||
|
"""
|
||||||
|
# To avoid issues with registered environments during testing, we check that the spec entry points are from gym.envs.
|
||||||
|
if "gym.envs." in env_spec.entry_point:
|
||||||
try:
|
try:
|
||||||
return env_spec.make(disable_env_checker=True).unwrapped
|
return env_spec.make(disable_env_checker=True).unwrapped
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
81
tests/testing_env.py
Normal file
81
tests/testing_env.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Provides a generic testing environment for use in tests with custom reset, step and render functions."""
|
||||||
|
import types
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import gym
|
||||||
|
from gym import spaces
|
||||||
|
from gym.core import ActType, ObsType
|
||||||
|
from gym.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
|
def basic_reset_fn(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
|
"""A basic reset function that will pass the environment check by returning random observations sampled from the observation space."""
|
||||||
|
super(GenericTestEnv, self).reset(seed=seed)
|
||||||
|
self.observation_space.seed(seed)
|
||||||
|
if return_info:
|
||||||
|
return self.observation_space.sample(), {"options": options}
|
||||||
|
else:
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def basic_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||||
|
"""A basic step function that will pass the environment check by returning random observations sampled from the observation space."""
|
||||||
|
return self.observation_space.sample(), 0, False, False, {}
|
||||||
|
|
||||||
|
|
||||||
|
def basic_render_fn(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# todo: change all testing environment to this generic class
|
||||||
|
class GenericTestEnv(gym.Env):
|
||||||
|
"""A generic testing environment for use in tests where modified environments are required."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||||
|
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||||
|
reset_fn: callable = basic_reset_fn,
|
||||||
|
step_fn: callable = basic_step_fn,
|
||||||
|
render_fn: callable = basic_render_fn,
|
||||||
|
render_modes: Optional[List[str]] = None,
|
||||||
|
render_fps: Optional[int] = None,
|
||||||
|
render_mode: Optional[str] = None,
|
||||||
|
spec: EnvSpec = EnvSpec("TestingEnv-v0"),
|
||||||
|
):
|
||||||
|
self.metadata = {"render_modes": render_modes}
|
||||||
|
if render_fps:
|
||||||
|
self.metadata["render_fps"] = render_fps
|
||||||
|
self.render_mode = render_mode
|
||||||
|
self.spec = spec
|
||||||
|
|
||||||
|
if observation_space is not None:
|
||||||
|
self.observation_space = observation_space
|
||||||
|
if action_space is not None:
|
||||||
|
self.action_space = action_space
|
||||||
|
|
||||||
|
if reset_fn is not None:
|
||||||
|
self.reset = types.MethodType(reset_fn, self)
|
||||||
|
if step_fn is not None:
|
||||||
|
self.step = types.MethodType(step_fn, self)
|
||||||
|
if render_fn is not None:
|
||||||
|
self.render = types.MethodType(render_fn, self)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_info: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
|
# If you need a default working reset function, use `basic_reset_fn` above
|
||||||
|
raise NotImplementedError("TestingEnv reset_fn is not set")
|
||||||
|
|
||||||
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||||
|
raise NotImplementedError("TestingEnv step_fn is not set")
|
@@ -1,41 +1,229 @@
|
|||||||
from typing import Optional
|
"""Tests that the `env_checker` runs as expected and that all of its errors can be raised."""
|
||||||
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.spaces import Box, Dict, Discrete
|
from gym import spaces
|
||||||
from gym.utils.env_checker import check_env
|
from gym.utils.env_checker import (
|
||||||
|
check_env,
|
||||||
|
check_reset_info,
|
||||||
class ActionDictTestEnv(gym.Env):
|
check_reset_options,
|
||||||
action_space = Dict({"position": Discrete(1), "velocity": Discrete(1)})
|
check_reset_seed,
|
||||||
observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)
|
|
||||||
|
|
||||||
def __init__(self, render_mode: Optional[str] = None):
|
|
||||||
self.render_mode = render_mode
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
observation = np.array([1.0, 1.5, 0.5])
|
|
||||||
reward = 1
|
|
||||||
done = True
|
|
||||||
return observation, reward, done
|
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
|
||||||
super().reset(seed=seed)
|
|
||||||
return np.array([1.0, 1.5, 0.5])
|
|
||||||
|
|
||||||
def render(self, mode: Optional[str] = "human"):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_env_dict_action():
|
|
||||||
# Environment.step() only returns 3 values: obs, reward, done. Not info!
|
|
||||||
test_env = ActionDictTestEnv()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError) as errorinfo:
|
|
||||||
check_env(env=test_env, warn=True)
|
|
||||||
assert (
|
|
||||||
str(errorinfo.value)
|
|
||||||
== "The `step()` method must return four values: obs, reward, done, info"
|
|
||||||
)
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env",
|
||||||
|
[
|
||||||
|
gym.make("CartPole-v1", disable_env_checker=True).unwrapped,
|
||||||
|
gym.make("MountainCar-v0", disable_env_checker=True).unwrapped,
|
||||||
|
GenericTestEnv(
|
||||||
|
observation_space=spaces.Dict(
|
||||||
|
a=spaces.Discrete(10), b=spaces.Box(np.zeros(2), np.ones(2))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
GenericTestEnv(
|
||||||
|
observation_space=spaces.Tuple(
|
||||||
|
[spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
GenericTestEnv(
|
||||||
|
observation_space=spaces.Dict(
|
||||||
|
a=spaces.Tuple(
|
||||||
|
[spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
|
||||||
|
),
|
||||||
|
b=spaces.Box(np.zeros(2), np.ones(2)),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_no_error_warnings(env):
|
||||||
|
"""A full version of this test with all gym envs is run in tests/envs/test_envs.py."""
|
||||||
|
with pytest.warns(None) as warnings:
|
||||||
|
check_env(env)
|
||||||
|
|
||||||
|
assert len(warnings) == 0, [warning.message for warning in warnings]
|
||||||
|
|
||||||
|
|
||||||
|
def _no_super_reset(self, seed=None, return_info=False, options=None):
|
||||||
|
self.np_random.random() # generates a new prng
|
||||||
|
# generate seed deterministic result
|
||||||
|
self.observation_space.seed(0)
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def _super_reset_fixed(self, seed=None, return_info=False, options=None):
|
||||||
|
# Call super that ignores the seed passed, use fixed seed
|
||||||
|
super(GenericTestEnv, self).reset(seed=1)
|
||||||
|
# deterministic output
|
||||||
|
self.observation_space._np_random = self.np_random
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_default_seed(
|
||||||
|
self: GenericTestEnv, seed="Error", return_info=False, options=None
|
||||||
|
):
|
||||||
|
super(GenericTestEnv, self).reset(seed=seed)
|
||||||
|
self.observation_space._np_random = ( # pyright: ignore [reportPrivateUsage]
|
||||||
|
self.np_random
|
||||||
|
)
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test,func,message",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
gym.error.Error,
|
||||||
|
lambda self: self.observation_space.sample(),
|
||||||
|
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
lambda self, seed, *_: self.observation_space.sample(),
|
||||||
|
"Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_no_super_reset,
|
||||||
|
"Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_super_reset_fixed,
|
||||||
|
"Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
_reset_default_seed,
|
||||||
|
"The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. Actual default: Error",
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_reset_seed(test, func: callable, message: str):
|
||||||
|
"""Tests the check reset seed function works as expected."""
|
||||||
|
if test is UserWarning:
|
||||||
|
with pytest.warns(
|
||||||
|
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||||
|
):
|
||||||
|
check_reset_seed(GenericTestEnv(reset_fn=func))
|
||||||
|
else:
|
||||||
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
|
check_reset_seed(GenericTestEnv(reset_fn=func))
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_var_keyword_kwargs(self, kwargs):
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_return_info_type(self, seed=None, return_info=False, options=None):
|
||||||
|
if return_info:
|
||||||
|
return [1, 2]
|
||||||
|
else:
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_return_info_length(self, seed=None, return_info=False, options=None):
|
||||||
|
if return_info:
|
||||||
|
return 1, 2, 3
|
||||||
|
else:
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def _return_info_obs_outside(self, seed=None, return_info=False, options=None):
|
||||||
|
if return_info:
|
||||||
|
return self.observation_space.sample() + self.observation_space.high, {}
|
||||||
|
else:
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def _return_info_not_dict(self, seed=None, return_info=False, options=None):
|
||||||
|
if return_info:
|
||||||
|
return self.observation_space.sample(), ["key", "value"]
|
||||||
|
else:
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test,func,message",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
gym.error.Error,
|
||||||
|
lambda self, *_: self.observation_space.sample(),
|
||||||
|
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
gym.error.Error,
|
||||||
|
_reset_var_keyword_kwargs,
|
||||||
|
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_reset_return_info_type,
|
||||||
|
"Calling the reset method with `return_info=True` did not return a tuple, actual type: <class 'list'>",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_reset_return_info_length,
|
||||||
|
"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: 3",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_return_info_obs_outside,
|
||||||
|
"The first element returned by `env.reset(return_info=True)` is not within the observation space.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_return_info_not_dict,
|
||||||
|
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'list'>",
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_reset_info(test, func: callable, message: str):
|
||||||
|
"""Tests the check reset info function works as expected."""
|
||||||
|
if test is UserWarning:
|
||||||
|
with pytest.warns(
|
||||||
|
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||||
|
):
|
||||||
|
check_reset_info(GenericTestEnv(reset_fn=func))
|
||||||
|
else:
|
||||||
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
|
check_reset_info(GenericTestEnv(reset_fn=func))
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_reset_options():
|
||||||
|
"""Tests the check_reset_options function."""
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.Error,
|
||||||
|
match=re.escape(
|
||||||
|
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
check_reset_options(GenericTestEnv(reset_fn=lambda self: 0))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env,message",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
"Error",
|
||||||
|
"The environment must inherit from the gym.Env class. See https://www.gymlibrary.ml/content/environment_creation/ for more info.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
GenericTestEnv(action_space=None),
|
||||||
|
"The environment must specify an action space. See https://www.gymlibrary.ml/content/environment_creation/ for more info.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
GenericTestEnv(observation_space=None),
|
||||||
|
"The environment must specify an observation space. See https://www.gymlibrary.ml/content/environment_creation/ for more info.",
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_env(env: gym.Env, message: str):
|
||||||
|
"""Tests the check_env function works as expected."""
|
||||||
|
with pytest.raises(AssertionError, match=f"^{re.escape(message)}$"):
|
||||||
|
check_env(env)
|
||||||
|
461
tests/utils/test_passive_env_checker.py
Normal file
461
tests/utils/test_passive_env_checker.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
import re
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gym
|
||||||
|
from gym import spaces
|
||||||
|
from gym.utils.passive_env_checker import (
|
||||||
|
check_action_space,
|
||||||
|
check_obs,
|
||||||
|
check_observation_space,
|
||||||
|
env_render_passive_checker,
|
||||||
|
env_reset_passive_checker,
|
||||||
|
env_step_passive_checker,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def _modify_space(space: spaces.Space, attribute: str, value):
|
||||||
|
setattr(space, attribute, value)
|
||||||
|
return space
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test,space,message",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
"error",
|
||||||
|
"observation space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
|
||||||
|
],
|
||||||
|
# ===== Check box observation space ====
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
spaces.Box(np.zeros((5, 5, 1)), 255 * np.ones((5, 5, 1)), dtype=np.int32),
|
||||||
|
"It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: int32. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
spaces.Box(np.ones((2, 2, 1)), 255 * np.ones((2, 2, 1)), dtype=np.uint8),
|
||||||
|
"It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
spaces.Box(np.zeros((5, 5, 1)), np.ones((5, 5, 1)), dtype=np.uint8),
|
||||||
|
"It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
spaces.Box(np.zeros((5, 5)), np.ones((5, 5))),
|
||||||
|
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (5, 5)",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
spaces.Box(np.zeros(5), np.zeros(5)),
|
||||||
|
"A Box observation space maximum and minimum values are equal.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
spaces.Box(np.ones(5), np.zeros(5)),
|
||||||
|
"A Box observation space low value is greater than a high value.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
|
||||||
|
"The Box observation space shape and low shape have different shapes, low shape: (3,), box shape: (2,)",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
|
||||||
|
"The Box observation space shape and high shape have have different shapes, high shape: (3,), box shape: (2,)",
|
||||||
|
],
|
||||||
|
# ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_modify_space(spaces.Discrete(5), "n", -1),
|
||||||
|
"Discrete observation space's number of elements must be positive, actual number of elements: -1",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_modify_space(spaces.MultiDiscrete([2, 2]), "nvec", np.array([2, -1])),
|
||||||
|
"Multi-discrete observation space's all nvec elements must be greater than 0, actual nvec: [ 2 -1]",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, 1, 2)),
|
||||||
|
"Multi-discrete observation space's shape must be equal to the nvec shape, space shape: (2, 1, 2), nvec shape: (2,)",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
_modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
|
||||||
|
"Multi-binary observation space's all shape elements must be greater than 0, actual shape: (2, -1)",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
spaces.Tuple([]),
|
||||||
|
"An empty Tuple observation space is not allowed.",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
AssertionError,
|
||||||
|
spaces.Dict(),
|
||||||
|
"An empty Dict observation space is not allowed.",
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_observation_space(test, space, message: str):
|
||||||
|
"""Tests the check observation space."""
|
||||||
|
if test is UserWarning:
|
||||||
|
with pytest.warns(
|
||||||
|
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||||
|
):
|
||||||
|
check_observation_space(space)
|
||||||
|
else:
|
||||||
|
with pytest.warns(None) as warnings:
|
||||||
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
|
check_observation_space(space)
|
||||||
|
assert len(warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "test,space,message",
    [
        [
            AssertionError,
            "error",
            "action space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
        ],
        # ===== Check box action space ====
        [
            UserWarning,
            spaces.Box(np.zeros(5), np.zeros(5)),
            "A Box action space maximum and minimum values are equal.",
        ],
        [
            UserWarning,
            spaces.Box(np.ones(5), np.zeros(5)),
            "A Box action space low value is greater than a high value.",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
            "The Box action space shape and low shape have have different shapes, low shape: (3,), box shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
            "The Box action space shape and high shape have different shapes, high shape: (3,), box shape: (2,)",
        ],
        # ==== Other action spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
        [
            AssertionError,
            _modify_space(spaces.Discrete(5), "n", -1),
            "Discrete action space's number of elements must be positive, actual number of elements: -1",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiDiscrete([2, 2]), "_shape", (2, -1)),
            "Multi-discrete action space's shape must be equal to the nvec shape, space shape: (2, -1), nvec shape: (2,)",
        ],
        [
            AssertionError,
            _modify_space(spaces.MultiBinary((2, 2)), "_shape", (2, -1)),
            "Multi-binary action space's all shape elements must be greater than 0, actual shape: (2, -1)",
        ],
        [
            AssertionError,
            spaces.Tuple([]),
            "An empty Tuple action space is not allowed.",
        ],
        [AssertionError, spaces.Dict(), "An empty Dict action space is not allowed."],
    ],
)
def test_check_action_space(
    test: Union[UserWarning, type], space: spaces.Space, message: str
):
    """Check that `check_action_space` warns or raises with the expected message.

    Args:
        test: ``UserWarning`` when a warning is expected, otherwise the
            exception type that `check_action_space` is expected to raise.
        space: The (possibly malformed) action space handed to the checker.
        message: The exact text expected in the warning or exception.
    """
    # Function-scope import keeps this test self-contained.
    import warnings

    if test is UserWarning:
        # The expected text is wrapped in the ANSI colour codes spelled out in
        # the pattern below (yellow "WARN: " prefix, reset suffix).
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_action_space(space)
    else:
        # `pytest.warns(None)` is deprecated in pytest 7 and rejected by
        # pytest 8; record warnings with the standard library instead.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_action_space(space)
        # A failing assertion path must not emit any warning on the side.
        assert len(caught_warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "test,obs,obs_space,message",
    [
        [
            UserWarning,
            3,
            spaces.Discrete(2),
            "The obs returned by the `testing()` method is not within the observation space.",
        ],
        [
            UserWarning,
            np.uint8(0),
            spaces.Discrete(1),
            "The obs returned by the `testing()` method should be an int or np.int64, actual type: <class 'numpy.uint8'>",
        ],
        [
            UserWarning,
            [0, 1],
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            "The obs returned by the `testing()` method was expecting a tuple, actual type: <class 'list'>",
        ],
        [
            AssertionError,
            (1, 2, 3),
            spaces.Tuple([spaces.Discrete(1), spaces.Discrete(2)]),
            "The obs returned by the `testing()` method length is not same as the observation space length, obs length: 3, space length: 2",
        ],
        [
            AssertionError,
            {1, 2, 3},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            "The obs returned by the `testing()` method must be a dict, actual type: <class 'set'>",
        ],
        [
            AssertionError,
            {"a": 1, "c": 2},
            spaces.Dict(a=spaces.Discrete(1), b=spaces.Discrete(2)),
            "The obs returned by the `testing()` method observation keys is not same as the observation space keys, obs keys: ['a', 'c'], space keys: ['a', 'b']",
        ],
    ],
)
def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
    """Check that `check_obs` warns or raises with the expected message.

    Args:
        test: ``UserWarning`` when a warning is expected, otherwise the
            exception type that `check_obs` is expected to raise.
        obs: The observation handed to the checker.
        obs_space: The observation space the observation is checked against.
        message: The exact text expected in the warning or exception.
    """
    # Function-scope import keeps this test self-contained.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            check_obs(obs, obs_space, "testing")
    else:
        # `pytest.warns(None)` is deprecated in pytest 7 and rejected by
        # pytest 8; record warnings with the standard library instead.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                check_obs(obs, obs_space, "testing")
        # The raising path is expected to be warning-free.
        assert len(caught_warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_no_seed(self, return_info=False, options=None):
    """Reset stand-in whose signature deliberately has no `seed` parameter.

    Returns a sample drawn from `self.observation_space`.
    """
    sampled_obs = self.observation_space.sample()
    return sampled_obs
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_seed_default(self, seed="error", return_info=False, options=None):
    """Reset stand-in whose `seed` default is the string "error" rather than None.

    Returns a sample drawn from `self.observation_space`.
    """
    sampled_obs = self.observation_space.sample()
    return sampled_obs
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_no_return_info(self, seed=None, options=None):
    """Reset stand-in whose signature deliberately has no `return_info` parameter.

    Returns a sample drawn from `self.observation_space`.
    """
    sampled_obs = self.observation_space.sample()
    return sampled_obs
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_no_option(self, seed=None, return_info=False):
    """Reset stand-in whose signature deliberately has no `options` parameter.

    Returns a sample drawn from `self.observation_space`.
    """
    sampled_obs = self.observation_space.sample()
    return sampled_obs
|
||||||
|
|
||||||
|
|
||||||
|
def _make_reset_results(results):
    """Build a reset function that ignores its arguments and returns `results`.

    Args:
        results: The object the generated reset function hands back unchanged.

    Returns:
        A function with the signature `(self, seed=None, return_info=False,
        options=None)` that always returns `results`.
    """

    def _reset_result(self, seed=None, return_info=False, options=None):
        # The canned value is captured from the enclosing call.
        canned = results
        return canned

    return _reset_result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "test,func,message,kwargs",
    [
        [
            UserWarning,
            _reset_no_seed,
            "Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.",
            {},
        ],
        [
            UserWarning,
            _reset_seed_default,
            "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
            {},
        ],
        [
            UserWarning,
            _reset_no_return_info,
            "Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.",
            {},
        ],
        [
            UserWarning,
            _reset_no_option,
            "Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.",
            {},
        ],
        [
            AssertionError,
            _make_reset_results([0, {}]),
            "The result returned by `env.reset(return_info=True)` was not a tuple, actual type: <class 'list'>",
            {"return_info": True},
        ],
        [
            AssertionError,
            _make_reset_results((0, {1, 2})),
            "The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'set'>",
            {"return_info": True},
        ],
    ],
)
def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: Dict):
    """Check that `env_reset_passive_checker` warns or raises as expected.

    Args:
        test: ``UserWarning`` when a warning is expected, otherwise the
            exception type the checker is expected to raise.
        func: The reset implementation installed on a `GenericTestEnv`.
        message: The exact text expected in the warning or exception.
        kwargs: Extra keyword arguments forwarded to the checker.
    """
    # Function-scope import keeps this test self-contained.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
    else:
        # `pytest.warns(None)` is deprecated in pytest 7 and rejected by
        # pytest 8; record warnings with the standard library instead.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
        # The raising path is expected to be warning-free.
        assert len(caught_warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def _modified_step(
    self, obs=None, reward=0, terminated=False, truncated=False, info=None
):
    """Assemble a step result tuple from the supplied pieces.

    Args:
        self: Object exposing `observation_space.sample()`; used when `obs` is None.
        obs: Observation to return; a fresh sample is drawn when None.
        reward: Reward element of the result.
        terminated: Third element of the result.
        truncated: Fourth element of the result; pass None to get a
            four-element tuple without it.
        info: Info element of the result; an empty dict is used when None.

    Returns:
        `(obs, reward, terminated, info)` when `truncated` is None, otherwise
        `(obs, reward, terminated, truncated, info)`.
    """
    observation = self.observation_space.sample() if obs is None else obs
    step_info = {} if info is None else info

    if truncated is not None:
        return observation, reward, terminated, truncated, step_info
    # `truncated=None` selects the four-element form.
    return observation, reward, terminated, step_info
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "test,func,message",
    [
        [
            AssertionError,
            lambda self, _: "error",
            "Expects step result to be a tuple, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=None),
            "Expects `done` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, terminated="error", truncated=False),
            "Expects `terminated` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, truncated="error"),
            "Expects `truncated` signal to be a boolean, actual type: <class 'str'>",
        ],
        [
            gym.error.Error,
            lambda self, _: (1, 2, 3),
            "Expected `Env.step` to return a four or five element tuple, actual number of elements returned: 3.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward="error"),
            "The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: <class 'str'>",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.nan),
            "The reward is a NaN value.",
        ],
        [
            UserWarning,
            lambda self, _: _modified_step(self, reward=np.inf),
            "The reward is an inf value.",
        ],
        [
            AssertionError,
            lambda self, _: _modified_step(self, info="error"),
            "The `info` returned by `step()` must be a python dictionary, actual type: <class 'str'>",
        ],
    ],
)
def test_passive_env_step_checker(
    test: Union[UserWarning, type], func: callable, message: str
):
    """Check that `env_step_passive_checker` warns or raises as expected.

    Args:
        test: ``UserWarning`` when a warning is expected, otherwise the
            exception type the checker is expected to raise.
        func: The step implementation installed on a `GenericTestEnv`.
        message: The exact text expected in the warning or exception.
    """
    # Function-scope import keeps this test self-contained.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
    else:
        # `pytest.warns(None)` is deprecated in pytest 7 and rejected by
        # pytest 8; record warnings with the standard library instead.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_step_passive_checker(GenericTestEnv(step_fn=func), 0)
        # Show the recorded warnings in the failure output for easier debugging.
        assert len(caught_warnings) == 0, list(caught_warnings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "test,env,message",
    [
        [
            UserWarning,
            GenericTestEnv(render_modes=None),
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`.",
        ],
        [
            UserWarning,
            GenericTestEnv(render_modes="Testing mode"),
            "Expects the render_modes to be a sequence (i.e. list, tuple), actual type: <class 'str'>",
        ],
        [
            UserWarning,
            GenericTestEnv(render_modes=["Testing mode", 1], render_fps=1),
            "Expects all render modes to be strings, actual types: [<class 'str'>, <class 'int'>]",
        ],
        [
            UserWarning,
            GenericTestEnv(
                render_modes=["Testing mode"],
                render_fps=None,
                render_mode="Testing mode",
                render_fn=lambda self: 0,
            ),
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.",
        ],
        [
            UserWarning,
            GenericTestEnv(render_modes=["Testing mode"], render_fps="fps"),
            "Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: <class 'str'>",
        ],
        [
            AssertionError,
            GenericTestEnv(render_modes=[], render_fps=30, render_mode="Test"),
            "With no render_modes, expects the Env.render_mode to be None, actual value: Test",
        ],
        [
            AssertionError,
            GenericTestEnv(
                render_modes=["Testing mode"], render_fps=30, render_mode="Non mode"
            ),
            "The environment was initialized successfully however with an unsupported render mode. Render mode: Non mode, modes: ['Testing mode']",
        ],
    ],
)
def test_passive_render_checker(test, env: GenericTestEnv, message: str):
    """Check that `env_render_passive_checker` warns or raises as expected.

    Args:
        test: ``UserWarning`` when a warning is expected, otherwise the
            exception type the checker is expected to raise.
        env: The pre-built test environment handed to the checker.
        message: The exact text expected in the warning or exception.
    """
    # Function-scope import keeps this test self-contained.
    import warnings

    if test is UserWarning:
        with pytest.warns(
            UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
        ):
            env_render_passive_checker(env)
    else:
        # `pytest.warns(None)` is deprecated in pytest 7 and rejected by
        # pytest 8; record warnings with the standard library instead.
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            with pytest.raises(test, match=f"^{re.escape(message)}$"):
                env_render_passive_checker(env)
        # The raising path is expected to be warning-free.
        assert len(caught_warnings) == 0
|
@@ -85,6 +85,10 @@ def test_create_shared_memory_custom_space(n, ctx, space):
|
|||||||
create_shared_memory(space, n=n, ctx=ctx)
|
create_shared_memory(space, n=n, ctx=ctx)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_shared_memory(space, i, shared_memory, sample):
    """Forward to `write_to_shared_memory`, swapping the last two arguments.

    Used as the `Process` target in `test_write_to_shared_memory`.
    NOTE(review): presumably module-level so the target can be pickled by
    multiprocessing; confirm against the commit rationale.
    """
    # Argument order differs from this wrapper's own signature.
    write_to_shared_memory(space, i, sample, shared_memory)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
||||||
)
|
)
|
||||||
@@ -105,14 +109,14 @@ def test_write_to_shared_memory(space):
|
|||||||
else:
|
else:
|
||||||
raise TypeError(f"Got unknown type `{type(lhs)}`.")
|
raise TypeError(f"Got unknown type `{type(lhs)}`.")
|
||||||
|
|
||||||
def write(i, shared_memory, sample):
|
|
||||||
write_to_shared_memory(space, i, sample, shared_memory)
|
|
||||||
|
|
||||||
shared_memory_n8 = create_shared_memory(space, n=8)
|
shared_memory_n8 = create_shared_memory(space, n=8)
|
||||||
samples = [space.sample() for _ in range(8)]
|
samples = [space.sample() for _ in range(8)]
|
||||||
|
|
||||||
processes = [
|
processes = [
|
||||||
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
|
Process(
|
||||||
|
target=_write_shared_memory, args=(space, i, shared_memory_n8, samples[i])
|
||||||
|
)
|
||||||
|
for i in range(8)
|
||||||
]
|
]
|
||||||
|
|
||||||
for process in processes:
|
for process in processes:
|
||||||
|
100
tests/wrappers/test_passive_env_checker.py
Normal file
100
tests/wrappers/test_passive_env_checker.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gym
|
||||||
|
from gym.wrappers.env_checker import PassiveEnvChecker
|
||||||
|
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
|
||||||
|
from tests.envs.utils import all_testing_initialised_envs
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "env",
    all_testing_initialised_envs,
    ids=[env.spec.id for env in all_testing_initialised_envs],
)
def test_passive_checker_wrapper_warnings(env):
    """Run each registered test env through `PassiveEnvChecker` and reject unexpected warnings.

    Any warning whose message is not listed in `PASSIVE_CHECK_IGNORE_WARNING`
    fails the test with a `gym.error.Error`.
    """
    # Function-scope import keeps this test self-contained.
    import warnings

    # `pytest.warns(None)` is deprecated in pytest 7 and rejected by pytest 8;
    # record warnings with the standard library instead.
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        checker_env = PassiveEnvChecker(env)
        checker_env.reset()
        checker_env.step(checker_env.action_space.sample())
        # todo, add check for render, bugged due to mujoco v2/3 and v4 envs

        checker_env.close()

    for warning in caught_warnings:
        if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
            raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "env, message",
    [
        (
            GenericTestEnv(action_space=None),
            "The environment must specify an action space. https://www.gymlibrary.ml/content/environment_creation/",
        ),
        (
            GenericTestEnv(action_space="error"),
            "action space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
        ),
        (
            GenericTestEnv(observation_space=None),
            "The environment must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/",
        ),
        (
            GenericTestEnv(observation_space="error"),
            "observation space does not inherit from `gym.spaces.Space`, actual type: <class 'str'>",
        ),
    ],
)
def test_initialise_failures(env, message):
    """Wrapping an env with a missing or invalid space must raise `AssertionError`.

    Args:
        env: A `GenericTestEnv` built with a broken action or observation space.
        message: The exact assertion message expected from `PassiveEnvChecker`.
    """
    # Anchor the escaped message so the whole error text has to match.
    expected_pattern = f"^{re.escape(message)}$"
    with pytest.raises(AssertionError, match=expected_pattern):
        PassiveEnvChecker(env)

    env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_failure(self, seed=None, return_info=False, options=None):
    """Reset stand-in that always returns the float32 array `[-1.0]`.

    `test_api_failures` expects this value to trigger the "not within the
    observation space" warning.
    """
    bad_obs = np.array([-1.0], dtype=np.float32)
    return bad_obs
|
||||||
|
|
||||||
|
|
||||||
|
def _step_failure(self, action):
    """Step stand-in that returns the string "error" instead of a tuple."""
    bad_result = "error"
    return bad_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_failures():
    """Exercise the checker's first-call checks on reset, step and render.

    The wrapped env is built with a bad reset function, a bad step function
    and a non-sequence `render_modes`. Each `checked_*` flag must start as
    False and be truthy after the corresponding call has warned or raised.
    """
    broken_env = GenericTestEnv(
        reset_fn=_reset_failure, step_fn=_step_failure, render_modes="error"
    )
    env = PassiveEnvChecker(broken_env)

    # Nothing has been checked before the first API calls.
    assert env.checked_reset is False
    assert env.checked_step is False
    assert env.checked_render is False

    reset_pattern = re.escape(
        "The obs returned by the `reset()` method is not within the observation space"
    )
    with pytest.warns(UserWarning, match=reset_pattern):
        env.reset()
    assert env.checked_reset

    step_pattern = "Expects step result to be a tuple, actual type: <class 'str'>"
    with pytest.raises(AssertionError, match=step_pattern):
        env.step(env.action_space.sample())
    assert env.checked_step

    render_pattern = r"Expects the render_modes to be a sequence \(i\.e\. list, tuple\), actual type: <class 'str'>"
    with pytest.warns(UserWarning, match=render_pattern):
        env.render()
    assert env.checked_render

    env.close()
|
Reference in New Issue
Block a user