Updates the environment checker (#2864)

* Updated testing requirements based off extra["testing"]

* Updated setup to check that the version is valid, added the `testing` and `all` dependency groups, and collected the requirements from requirements.txt to keep everything standardized.

* Updated requirements.txt based on the current minimum gym requirements.txt to work

* Updated requirements.txt based on the current minimum gym requirements.txt to work

* Updated test_requirements.txt based on the current gym full testing requirements

* Pre-commit updates

* Add integer check for the `n` parameter

* The type of self.spaces is an Iterable which is absorbed by the tuple.

* Simplifies the environment checker to two files, env_checker.py and passive_env_checker.py with a new wrapper env_checker.py

* Adds the passive environment checker on `gym.make`

* Ignore the `check_env` warn parameter

* Ignore the `check_env` warn parameter

* Use the `data_equivalence` function

* Revert rewrite setup.py changes

* Remove smart formatting for 3.6 support

* Fixed `check_action_space` and `check_observation_space`

* Added disable_env_checker to vector.make such that env_checker would only run on the first environment created.

* Removing check that different seeds would produce different initialising states

* Use the unwrapped environment np_random

* Fixed vector environment creator
This commit is contained in:
Mark Towers
2022-06-06 16:21:45 +01:00
committed by GitHub
parent 9fa7ede1e3
commit 134de4a713
9 changed files with 596 additions and 475 deletions

View File

@@ -23,8 +23,8 @@ from typing import (
import numpy as np
from gym.envs.__relocated__ import internal_env_relocation_map
from gym.utils.env_checker import check_env
from gym.wrappers import AutoResetWrapper, OrderEnforcing, TimeLimit
from gym.wrappers.env_checker import PassiveEnvChecker
if sys.version_info < (3, 10):
import importlib_metadata as metadata # type: ignore
@@ -43,7 +43,15 @@ ENV_ID_RE = re.compile(
)
def load(name: str) -> type:
def load(name: str) -> callable:
"""Loads an environment with name and returns an environment creation function
Args:
name: The environment name
Returns:
Calls the environment constructor
"""
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
fn = getattr(mod, attr_name)
@@ -519,11 +527,6 @@ def make(
) -> Env:
"""Create an environment according to the given ID.
Warnings:
In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to valid
if they follow the gym API. To disable this feature, set parameter `disable_env_checker=True`.
Args:
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
@@ -578,49 +581,44 @@ def make(
_kwargs = spec_.kwargs.copy()
_kwargs.update(kwargs)
# TODO: add a minimal env checker on initialization
if spec_.entry_point is None:
raise error.Error(f"{spec_.id} registered but entry_point is not specified")
elif callable(spec_.entry_point):
cls = spec_.entry_point
env_creator = spec_.entry_point
else:
# Assume it's a string
cls = load(spec_.entry_point)
env_creator = load(spec_.entry_point)
env = cls(**_kwargs)
env = env_creator(**_kwargs)
# Copies the environment creation specification and kwargs to add to the environment specification details
spec_ = copy.deepcopy(spec_)
spec_.kwargs = _kwargs
env.unwrapped.spec = spec_
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False:
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if spec_.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps)
# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env)
if not disable_env_checker:
try:
check_env(env)
except Exception as e:
logger.warn(
f"Env check failed with the following message: {e}\n"
f"You can set `disable_env_checker=True` to disable this check."
)
return env
def spec(env_id: str) -> EnvSpec:
"""
Retrieve the spec for the given environment from the global registry.
"""
"""Retrieve the spec for the given environment from the global registry."""
spec_ = registry.get(env_id)
if spec_ is None:
ns, name, version = parse_env_id(env_id)

View File

@@ -33,6 +33,7 @@ class Discrete(Space[int]):
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
start (int): The smallest element of this space.
"""
assert isinstance(n, (int, np.integer))
assert n > 0, "n (counts) have to be positive"
assert isinstance(start, (int, np.integer))
self.n = int(n)

View File

@@ -33,9 +33,8 @@ class Tuple(Space[tuple], Sequence):
spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
"""
spaces = tuple(spaces)
self.spaces = spaces
for space in spaces:
self.spaces = tuple(spaces)
for space in self.spaces:
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gym.Space"

View File

@@ -8,367 +8,110 @@ It also uses some warnings/assertions from the PettingZoo repository hosted on G
(https://github.com/PettingZoo-Team/PettingZoo)
Original Author: J K Terry
This was rewritten and split into "env_checker.py" and "passive_env_checker.py" for invasive and passive environment checking
Original Author: Mark Towers
These projects are covered by the MIT License.
"""
import inspect
from typing import Optional, Union
from copy import deepcopy
import numpy as np
import gym
from gym import logger
from gym.spaces import Box, Dict, Discrete, Space, Tuple
from gym import error, logger
from gym.utils.passive_env_checker import (
check_action_space,
check_observation_space,
passive_env_reset_check,
passive_env_step_check,
)
def _is_numpy_array_space(space: Space) -> bool:
"""Checks if a space can be represented as a single numpy array (e.g. Dict and Tuple spaces return False).
def data_equivalence(data_1, data_2) -> bool:
"""Assert equality between data 1 and 2, i.e observations, actions, info.
Args:
space: The space to check
data_1: data structure 1
data_2: data structure 2
Returns:
Returns False if the provided space is not representable as a single numpy array
If observation 1 and 2 are equivalent
"""
return not isinstance(space, (Dict, Tuple))
def _check_image_input(observation_space: Box, key: str = ""):
"""Check whether an observation space of type :class:`Box` adheres to general standards for spaces that represent images.
It will check that:
- The datatype is ``np.uint8``
- The lower bound is 0 across all dimensions
- The upper bound is 255 across all dimensions
Args:
observation_space: The observation space to check
key: The observation shape key for warning
"""
if observation_space.dtype != np.uint8:
logger.warn(
f"It seems that your observation {key} is an image but the `dtype` "
"of your observation_space is not `np.uint8`. "
"If your observation is not an image, we recommend you to flatten the observation "
"to have only a 1D vector"
)
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
logger.warn(
f"It seems that your observation space {key} is an image but the "
"upper and lower bounds are not in [0, 255]. "
"Generally, CNN policies assume observations are within that range, "
"so you may encounter an issue if the observation values are not."
)
def _check_nan(env: gym.Env, check_inf: bool = True):
"""Check if the environment observation, reward are NaN and Inf.
Args:
env: The environment to check
check_inf: Checks if the observation is infinity
"""
for _ in range(10):
action = env.action_space.sample()
observation, reward, done, _ = env.step(action)
if done:
env.reset()
if np.any(np.isnan(observation)):
logger.warn("Encountered NaN value in observations.")
if np.any(np.isnan(reward)):
logger.warn("Encountered NaN value in rewards.")
if check_inf and np.any(np.isinf(observation)):
logger.warn("Encountered inf value in observations.")
if check_inf and np.any(np.isinf(reward)):
logger.warn("Encountered inf value in rewards.")
def _check_obs(
obs: Union[tuple, dict, np.ndarray, int],
observation_space: Space,
method_name: str,
):
"""Check that the observation returned by the environment correspond to the declared one.
Args:
obs: The observation to check
observation_space: The observation space of the observation
method_name: The method name that generated the observation
"""
if not isinstance(observation_space, Tuple):
assert not isinstance(
obs, tuple
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
if isinstance(observation_space, Discrete):
assert isinstance(
obs, int
), f"The observation returned by `{method_name}()` method must be an int"
elif _is_numpy_array_space(observation_space):
assert isinstance(
obs, np.ndarray
), f"The observation returned by `{method_name}()` method must be a numpy array"
assert observation_space.contains(
obs
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
def _check_box_obs(observation_space: Box, key: str = ""):
"""Check that the observation space is correctly formatted when dealing with a :class:`Box` space.
In particular, it checks:
- that the dimensions are big enough when it is an image, and that the type matches
- that the observation has an expected shape (warn the user if not)
Args:
observation_space: Checks if the Box observation space
key: The observation key
"""
# If image, check the low and high values, the type and the number of channels
# and the shape (minimal value)
if len(observation_space.shape) == 3:
_check_image_input(observation_space)
if len(observation_space.shape) not in [1, 3]:
logger.warn(
f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). "
"We recommend you to flatten the observation "
"to have only a 1D vector or use a custom policy to properly process the data."
)
if np.any(np.equal(observation_space.low, -np.inf)):
logger.warn(
"Agent's minimum observation space value is -infinity. This is probably too low."
)
if np.any(np.equal(observation_space.high, np.inf)):
logger.warn(
"Agent's maxmimum observation space value is infinity. This is probably too high"
)
if np.any(np.equal(observation_space.low, observation_space.high)):
logger.warn("Agent's maximum and minimum observation space values are equal")
if np.any(np.greater(observation_space.low, observation_space.high)):
assert False, "Agent's minimum observation value is greater than it's maximum"
if observation_space.low.shape != observation_space.shape:
assert (
False
), "Agent's observation_space.low and observation_space have different shapes"
if observation_space.high.shape != observation_space.shape:
assert (
False
), "Agent's observation_space.high and observation_space have different shapes"
def _check_box_action(action_space: Box):
"""Checks that a :class:`Box` action space is defined in a sensible way.
Args:
action_space: A box action space
"""
if np.any(np.equal(action_space.low, -np.inf)):
logger.warn(
"Agent's minimum action space value is -infinity. This is probably too low."
)
if np.any(np.equal(action_space.high, np.inf)):
logger.warn(
"Agent's maximum action space value is infinity. This is probably too high"
)
if np.any(np.equal(action_space.low, action_space.high)):
logger.warn("Agent's maximum and minimum action space values are equal")
if np.any(np.greater(action_space.low, action_space.high)):
assert False, "Agent's minimum action value is greater than it's maximum"
if action_space.low.shape != action_space.shape:
assert False, "Agent's action_space.low and action_space have different shapes"
if action_space.high.shape != action_space.shape:
assert False, "Agent's action_space.high and action_space have different shapes"
def _check_normalized_action(action_space: Box):
"""Checks that a box action space is normalized.
Args:
action_space: A box action space
"""
if (
np.any(np.abs(action_space.low) != np.abs(action_space.high))
or np.any(np.abs(action_space.low) > 1)
or np.any(np.abs(action_space.high) > 1)
):
logger.warn(
"We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
"cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
)
def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space):
"""Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods.
Args:
env: The environment
observation_space: The environment's observation space
action_space: The environment's action space
Raises:
AssertionError: If the ``observation_space`` is :class:`Dict` and
keys from :meth:`Env.reset` are not in the observation space
"""
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
obs = env.reset()
if isinstance(observation_space, Dict):
assert isinstance(
obs, dict
), "The observation returned by `reset()` must be a dictionary"
for key in observation_space.spaces.keys():
try:
_check_obs(obs[key], observation_space.spaces[key], "reset")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e))
else:
_check_obs(obs, observation_space, "reset")
# Sample a random action
action = action_space.sample()
data = env.step(action)
assert (
len(data) == 4
), "The `step()` method must return four values: obs, reward, done, info"
# Unpack
obs, reward, done, info = data
if isinstance(observation_space, Dict):
assert isinstance(
obs, dict
), "The observation returned by `step()` must be a dictionary"
for key in observation_space.spaces.keys():
try:
_check_obs(obs[key], observation_space.spaces[key], "step")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e))
else:
_check_obs(obs, observation_space, "step")
# We also allow int because the reward will be cast to float
assert isinstance(
reward, (float, int, np.float32)
), "The reward returned by `step()` must be a float"
assert isinstance(done, bool), "The `done` signal must be a boolean"
assert isinstance(
info, dict
), "The `info` returned by `step()` must be a python dictionary"
def _check_spaces(env: gym.Env):
"""Check that the observation and action spaces are defined and inherit from :class:`gym.spaces.Space`.
Args:
env: The environment's observation and action space to check
"""
# Helper to link to the code, because gym has no proper documentation
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
assert hasattr(env, "observation_space"), (
"You must specify an observation space (cf gym.spaces)" + gym_spaces
)
assert hasattr(env, "action_space"), (
"You must specify an action space (cf gym.spaces)" + gym_spaces
)
assert isinstance(env.observation_space, Space), (
"The observation space must inherit from gym.spaces" + gym_spaces
)
assert isinstance(env.action_space, Space), (
"The action space must inherit from gym.spaces" + gym_spaces
)
# Check render cannot be covered by CI
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False):
"""Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.
Args:
env: The environment to check
warn: Whether to output additional warnings
headless: Whether to disable render modes that require a graphical interface. False by default.
"""
render_modes = env.metadata.get("render_modes")
if render_modes is None:
if warn:
logger.warn(
"No render modes was declared in the environment "
" (env.metadata['render_modes'] is None or not defined), "
"you may have trouble when calling `.render()`"
if type(data_1) == type(data_2):
if isinstance(data_1, dict):
return data_1.keys() == data_2.keys() and all(
data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
)
render_fps = env.metadata.get("render_fps")
# We only require `render_fps` if rendering is actually implemented
if render_fps is None and render_modes is not None and len(render_modes) > 0:
if warn:
logger.warn(
"No render fps was declared in the environment "
" (env.metadata['render_fps'] is None or not defined), "
"rendering may occur at inconsistent fps"
elif isinstance(data_1, tuple):
return len(data_1) == len(data_2) and all(
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
elif isinstance(data_1, np.ndarray):
return np.all(data_1 == data_2)
else:
return data_1 == data_2
else:
# Don't check render mode that require a
# graphical interface (useful for CI)
if headless and "human" in render_modes:
render_modes.remove("human")
# Check all declared render modes
for render_mode in render_modes:
env.render(mode=render_mode)
env.close()
return False
def _check_reset_seed(env: gym.Env, seed: Optional[int] = None):
def check_reset_seed(env: gym.Env):
"""Check that the environment can be reset with a seed.
Args:
env: The environment to check
seed: The optional seed to use
Raises:
AssertionError: The environment cannot be reset with a random seed,
even though `seed` or `kwargs` appear in the signature.
"""
signature = inspect.signature(env.reset)
assert (
"seed" in signature.parameters or "kwargs" in signature.parameters
), "The environment cannot be reset with a random seed. This behavior will be deprecated in the future."
if "seed" in signature.parameters or "kwargs" in signature.parameters:
try:
obs_1 = env.reset(seed=123)
assert obs_1 in env.observation_space
obs_2 = env.reset(seed=123)
assert obs_2 in env.observation_space
assert data_equivalence(obs_1, obs_2)
seed_123_rng = deepcopy(env.unwrapped.np_random)
try:
env.reset(seed=seed)
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. "
f"The error was: {e}"
)
# Note: for some environment, they may initialise at the same state, therefore we cannot check the obs_1 != obs_3
obs_4 = env.reset(seed=None)
assert obs_4 in env.observation_space
if env.unwrapped.np_random is None:
logger.warn(
"Resetting the environment did not result in seeding its random number generator. "
"This is likely due to not calling `super().reset(seed=seed)` in the `reset` method. "
"If you do not use the python-level random number generator, this is not a problem."
)
assert (
env.unwrapped.np_random.bit_generator.state
!= seed_123_rng.bit_generator.state
)
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
"This should never happen, please report this issue. "
f"The error was: {e}"
)
seed_param = signature.parameters.get("seed")
# Check the default value is None
if seed_param is not None and seed_param.default is not None:
logger.warn(
"The default seed argument in reset should be `None`, "
"otherwise the environment will by default always be deterministic"
if env.unwrapped.np_random is None:
logger.warn(
"Resetting the environment did not result in seeding its random number generator. "
"This is likely due to not calling `super().reset(seed=seed)` in the `reset` method. "
"If you do not use the python-level random number generator, this is not a problem."
)
seed_param = signature.parameters.get("seed")
# Check the default value is None
if seed_param is not None and seed_param.default is not None:
logger.warn(
"The default seed argument in reset should be `None`, "
"otherwise the environment will by default always be deterministic"
)
else:
# Reached when `reset` accepts neither `seed` nor `**kwargs`, so the message must name `seed`.
raise error.Error(
    "The `reset` method does not provide the `seed` keyword argument"
)
def _check_reset_info(env: gym.Env):
def check_reset_info(env: gym.Env):
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
Args:
@@ -379,29 +122,29 @@ def _check_reset_info(env: gym.Env):
even though `return_info` or `kwargs` appear in the signature.
"""
signature = inspect.signature(env.reset)
assert (
"return_info" in signature.parameters or "kwargs" in signature.parameters
), "The `reset` method does not provide the `return_info` keyword argument"
try:
result = env.reset(return_info=True)
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. "
f"The error was: {e}"
if "return_info" in signature.parameters or "kwargs" in signature.parameters:
try:
result = env.reset(return_info=True)
assert (
len(result) == 2
), "Calling the reset method with `return_info=True` did not return a 2-tuple"
obs, info = result
assert isinstance(
info, dict
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. "
f"The error was: {e}"
)
else:
raise error.Error(
"The `reset` method does not provide the `return_info` keyword argument"
)
assert (
len(result) == 2
), "Calling the reset method with `return_info=True` did not return a 2-tuple"
obs, info = result
assert isinstance(
info, dict
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
def _check_reset_options(env: gym.Env):
def check_reset_options(env: gym.Env):
"""Check that the environment can be reset with options.
Args:
@@ -412,73 +155,94 @@ def _check_reset_options(env: gym.Env):
even though `options` or `kwargs` appear in the signature.
"""
signature = inspect.signature(env.reset)
assert (
"options" in signature.parameters or "kwargs" in signature.parameters
), "The environment cannot be reset with options. This behavior will be deprecated in the future."
try:
env.reset(options={})
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with options, even though `options` or `kwargs` "
"appear in the signature. This should never happen, please report this issue. "
f"The error was: {e}"
if "options" in signature.parameters or "kwargs" in signature.parameters:
try:
env.reset(options={})
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with options, even though `options` or `kwargs` appear in the signature. "
"This should never happen, please report this issue. "
f"The error was: {e}"
)
else:
raise error.Error(
"The `reset` method does not provide the `options` keyword argument"
)
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True):
# Check render cannot be covered by CI
def check_render(env: gym.Env, headless: bool = False):
    """Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.

    Every declared render mode is rendered once and the environment is then closed.

    Args:
        env: The environment to check
        headless: Whether to disable render modes that require a graphical interface. False by default.
    """
    render_modes = env.metadata.get("render_modes")
    if render_modes is None:
        logger.warn(
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`"
        )

    render_fps = env.metadata.get("render_fps")
    # NOTE(review): this warns whenever `render_fps` is missing, even if no render
    # modes are declared - confirm that is intended.
    if render_fps is None:
        logger.warn(
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps"
        )

    if render_modes is not None:
        # Build a filtered copy instead of calling `render_modes.remove("human")`:
        # the list lives in `env.metadata`, so mutating it would alter the
        # environment's declared modes (and would fail for tuples).
        # "human" requires a graphical interface, so it is skipped when headless (useful for CI).
        modes_to_check = [
            mode for mode in render_modes if not (headless and mode == "human")
        ]

        # Check all remaining declared render modes
        for mode in modes_to_check:
            env.render(mode=mode)
        env.close()
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
"""Check that an environment follows Gym API.
This is an invasive function that calls the environment's reset and step.
This is particularly useful when using a custom environment.
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
Please take a look at https://www.gymlibrary.ml/content/environment_creation/
for more information about the API.
It also optionally check that the environment is compatible with Stable-Baselines.
Args:
env: The Gym environment that will be checked
warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines
warn: Ignored
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
"""
if warn is not None:
logger.warn("`check_env` warn parameter is now ignored.")
assert isinstance(
env, gym.Env
), "Your environment must inherit from the gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py"
), "Your environment must inherit from the gym.Env class https://www.gymlibrary.ml/content/environment_creation/"
# ============= Check the spaces (observation and action) ================
_check_spaces(env)
# Define aliases for convenience
observation_space = env.observation_space
action_space = env.action_space
# Validate each space with its matching checker: the action space with
# `check_action_space` and the observation space with `check_observation_space`.
assert hasattr(
    env, "action_space"
), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/"
check_action_space(env.action_space)
assert hasattr(
    env, "observation_space"
), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
check_observation_space(env.observation_space)
# Warn the user if needed.
# A warning means that the environment may run but not work properly with popular RL libraries.
if warn:
obs_spaces = (
observation_space.spaces
if isinstance(observation_space, Dict)
else {"": observation_space}
)
for key, space in obs_spaces.items():
if isinstance(space, Box):
_check_box_obs(space, key)
# Check for the action space, it may lead to hard-to-debug issues
if isinstance(action_space, Box):
_check_box_action(action_space)
_check_normalized_action(action_space)
# ==== Check the reset method ====
check_reset_seed(env)
check_reset_options(env)
check_reset_info(env)
# ============ Check the returned values ===============
_check_returned_values(env, observation_space, action_space)
passive_env_reset_check(env)
passive_env_step_check(env, env.action_space.sample())
# ==== Check the render method and the declared render modes ====
if not skip_render_check:
_check_render(env, warn=warn) # pragma: no cover
# The check only works with numpy arrays
if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space):
_check_nan(env)
# ==== Check the reset method ====
_check_reset_seed(env)
_check_reset_seed(env, seed=0)
_check_reset_options(env)
_check_reset_info(env)
check_render(env)

View File

@@ -0,0 +1,310 @@
"""A set of functions for passively checking environment implementations."""
import inspect
import numpy as np
from gym import error, logger, spaces
def _check_box_observation_space(observation_space: spaces.Box):
    """Sanity-check a :class:`Box` observation space.

    Unusual but usable definitions only produce warnings; inconsistent bounds raise.

    Args:
        observation_space: The box observation space to inspect

    Raises:
        AssertionError: If any lower bound is greater than its upper bound, or if
            ``low``/``high`` do not have the same shape as the space
    """
    n_dims = len(observation_space.shape)

    # A three-dimensional box is treated as an image
    if n_dims == 3:
        if observation_space.dtype != np.uint8:
            logger.warn(
                f"It seems that your observation space ({observation_space}) is an image but the `dtype` of your observation_space is not `np.uint8`. "
                "If your observation is not an image, we recommend you to flatten the observation to have only a 1D vector"
            )
        # TODO: should this use np.all?
        lower_not_zero = np.any(observation_space.low != 0)
        if lower_not_zero or np.any(observation_space.high != 255):
            logger.warn(
                "It seems that your observation space is an image but the upper and lower bounds are not in [0, 255]. "
                "Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not."
            )

    if n_dims not in [1, 3]:
        logger.warn(
            "Your observation space has an unconventional shape (neither an image, nor a 1D vector). "
            "We recommend you to flatten the observation to have only a 1D vector or use a custom policy to properly process the data. "
            f"Observation space={observation_space}"
        )

    # Soft checks on the bounds: warn only
    if np.any(np.equal(observation_space.low, -np.inf)):
        logger.warn(
            "Agent's minimum observation space value is -infinity. This is probably too low."
        )
    if np.any(np.equal(observation_space.high, np.inf)):
        logger.warn(
            "Agent's maximum observation space value is infinity. This is probably too high."
        )
    if np.any(np.equal(observation_space.low, observation_space.high)):
        logger.warn("Agent's maximum and minimum observation space values are equal")

    # Hard checks on the bounds: raise on the first violation
    if np.any(np.greater(observation_space.low, observation_space.high)):
        raise AssertionError(
            "Agent's minimum observation value is greater than it's maximum"
        )

    expected_shape = observation_space.shape
    if observation_space.low.shape != expected_shape:
        raise AssertionError(
            "Agent's observation_space.low and observation_space have different shapes"
        )
    if observation_space.high.shape != expected_shape:
        raise AssertionError(
            "Agent's observation_space.high and observation_space have different shapes"
        )
def _check_box_action_space(action_space: spaces.Box):
    """Sanity-check a :class:`Box` action space.

    Unusual but usable definitions only produce warnings; inconsistent bounds raise.

    Args:
        action_space: The box action space to inspect

    Raises:
        AssertionError: If any lower bound is greater than its upper bound, or if
            ``low``/``high`` do not have the same shape as the space
    """
    # Soft checks on the bounds: warn only
    if np.any(np.equal(action_space.low, -np.inf)):
        logger.warn(
            "Agent's minimum action space value is -infinity. This is probably too low."
        )
    if np.any(np.equal(action_space.high, np.inf)):
        logger.warn(
            "Agent's maximum action space value is infinity. This is probably too high"
        )
    if np.any(np.equal(action_space.low, action_space.high)):
        logger.warn("Agent's maximum and minimum action space values are equal")

    # Hard checks on the bounds: raise on the first violation
    if np.any(np.greater(action_space.low, action_space.high)):
        raise AssertionError(
            "Agent's minimum action value is greater than it's maximum"
        )

    expected_shape = action_space.shape
    if action_space.low.shape != expected_shape:
        raise AssertionError(
            "Agent's action_space.low and action_space have different shapes"
        )
    if action_space.high.shape != expected_shape:
        raise AssertionError(
            "Agent's action_space.high and action_space have different shapes"
        )

    # Warn unless the bounds are symmetric and their magnitudes stay within 1
    low_magnitude = np.abs(action_space.low)
    high_magnitude = np.abs(action_space.high)
    is_asymmetric = np.any(low_magnitude != high_magnitude)
    exceeds_unit_range = np.any(low_magnitude > 1) or np.any(high_magnitude > 1)
    if is_asymmetric or exceeds_unit_range:
        logger.warn(
            "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
            "https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"  # TODO Add to gymlibrary.ml?
        )
def _check_obs(obs, observation_space: spaces.Space, method_name: str):
    """Check that the observation returned by the environment correspond to the declared one.

    The observation must be contained in the space and have the Python type
    expected for that kind of space. ``Tuple`` and ``Dict`` spaces are checked
    recursively, element by element.

    Args:
        obs: The observation to check
        observation_space: The observation space of the observation
        method_name: The method name that generated the observation

    Raises:
        AssertionError: If the observation is not contained in the space or has an unexpected type
    """
    pre = f"The observation returned by the `{method_name}()` method"

    assert observation_space.contains(
        obs
    ), f"{pre} is not contained with the observation space ({observation_space})"

    if isinstance(observation_space, spaces.Discrete):
        # NumPy integer scalars are accepted as well as Python ints, so an
        # integer observation is not rejected merely for being a NumPy type.
        assert isinstance(
            obs, (int, np.integer)
        ), f"The observation returned by `{method_name}()` method must be an int, actually {type(obs)}"
    elif isinstance(
        observation_space, (spaces.Box, spaces.MultiBinary, spaces.MultiDiscrete)
    ):
        assert isinstance(
            obs, np.ndarray
        ), f"The observation returned by `{method_name}()` method must be a numpy array, actually {type(obs)}"
    elif isinstance(observation_space, spaces.Tuple):
        assert isinstance(
            obs, tuple
        ), f"The observation returned by the `{method_name}()` method must be a tuple, actually {type(obs)}"
        # Recurse into each element with its matching subspace
        for sub_obs, sub_space in zip(obs, observation_space.spaces):
            _check_obs(sub_obs, sub_space, method_name)
    elif isinstance(observation_space, spaces.Dict):
        assert isinstance(
            obs, dict
        ), f"The observation returned by the `{method_name}()` method must be a dict, actually {type(obs)}"
        # Recurse into each declared key with its matching subspace
        for space_key in observation_space.keys():
            _check_obs(obs[space_key], observation_space[space_key], method_name)
def check_observation_space(observation_space):
    """Passively validate an observation space without touching the environment.

    Box spaces are delegated to `_check_box_observation_space`; Discrete,
    MultiDiscrete and MultiBinary spaces must have strictly positive sizes;
    Tuple and Dict spaces must be non-empty and every sub-space is checked
    recursively.

    Args:
        observation_space: The observation space to validate

    Raises:
        AssertionError: If the object is not a `spaces.Space` or one of the checks fails.
    """
    if not isinstance(observation_space, spaces.Space):
        raise AssertionError(
            f"Observation space ({observation_space}) does not inherit from gym.spaces.Space"
        )

    if isinstance(observation_space, spaces.Box):
        # Box-specific validation lives in a dedicated helper
        _check_box_observation_space(observation_space)
        return

    if isinstance(observation_space, spaces.Discrete):
        assert (
            observation_space.n > 0
        ), f"There are no available discrete observations, n={observation_space.n}"
        return

    if isinstance(observation_space, spaces.MultiDiscrete):
        all_positive = np.all(observation_space.nvec > 0)
        assert (
            all_positive
        ), f"All dimensions of multi-discrete must be greater than 0, {observation_space.nvec}"
        return

    if isinstance(observation_space, spaces.MultiBinary):
        all_positive = np.all(np.asarray(observation_space.shape) > 0)
        assert (
            all_positive
        ), f"All dimensions of multi-binary must be greater than 0, {observation_space.shape}"
        return

    if isinstance(observation_space, spaces.Tuple):
        assert (
            len(observation_space.spaces) > 0
        ), "An empty Tuple observation space is not allowed."
        for member_space in observation_space.spaces:
            check_observation_space(member_space)
        return

    if isinstance(observation_space, spaces.Dict):
        assert (
            len(observation_space.spaces.keys()) > 0
        ), "An empty Dict observation space is not allowed."
        for member_space in observation_space.values():
            check_observation_space(member_space)
def check_action_space(action_space):
    """Passively validate an action space without touching the environment.

    Box spaces are delegated to `_check_box_action_space`; Discrete,
    MultiDiscrete and MultiBinary spaces must have strictly positive sizes;
    Tuple and Dict spaces must be non-empty and every sub-space is checked
    recursively.

    Args:
        action_space: The action space to validate

    Raises:
        AssertionError: If the object is not a `spaces.Space` or one of the checks fails.
    """
    if not isinstance(action_space, spaces.Space):
        raise AssertionError(
            f"Action space ({action_space}) does not inherit from gym.spaces.Space"
        )

    if isinstance(action_space, spaces.Box):
        # Box-specific validation lives in a dedicated helper
        _check_box_action_space(action_space)
        return

    if isinstance(action_space, spaces.Discrete):
        assert (
            action_space.n > 0
        ), f"There are no available discrete actions, n={action_space.n}"
        return

    if isinstance(action_space, spaces.MultiDiscrete):
        all_positive = np.all(action_space.nvec > 0)
        assert (
            all_positive
        ), f"All dimensions of multi-discrete must be greater than 0, {action_space.nvec}"
        return

    if isinstance(action_space, spaces.MultiBinary):
        all_positive = np.all(np.asarray(action_space.shape) > 0)
        assert (
            all_positive
        ), f"All dimensions of multi-binary must be greater than 0, {action_space.shape}"
        return

    if isinstance(action_space, spaces.Tuple):
        assert (
            len(action_space.spaces) > 0
        ), "An empty Tuple action space is not allowed."
        for member_space in action_space.spaces:
            check_action_space(member_space)
        return

    if isinstance(action_space, spaces.Dict):
        assert (
            len(action_space.spaces.keys()) > 0
        ), "An empty Dict action space is not allowed."
        for member_space in action_space.values():
            check_action_space(member_space)
def passive_env_reset_check(env, **kwargs):
    """Passively check `Env.reset`: inspect its signature and returned data, then return the data unchanged.

    Warnings are logged when the `reset` signature does not expose `seed`,
    `return_info` or `options` (neither by name nor through a catch-all
    `kwargs` parameter), or when `seed` has a default other than `None`.
    The value returned by `env.reset(**kwargs)` is then validated against
    `env.observation_space`.

    Args:
        env: The environment whose `reset` is called and checked
        **kwargs: Keyword arguments forwarded to `env.reset`

    Returns:
        Whatever `env.reset(**kwargs)` returned, unmodified.

    Raises:
        AssertionError: If the observation does not match the observation space, or if
            `return_info=True` was passed and the returned info is not a dictionary.
    """
    signature = inspect.signature(env.reset)
    parameters = signature.parameters
    # A catch-all `kwargs` parameter can receive any of the keyword arguments below,
    # so a warning is only warranted when neither the named parameter nor `kwargs` exists.
    accepts_kwargs = "kwargs" in parameters

    if "seed" not in parameters and not accepts_kwargs:
        logger.warn(
            "Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator. "
        )
    else:
        # `seed` may be absent here when it is only reachable through `kwargs`
        seed_param = parameters.get("seed")
        # Check the default value is None
        if seed_param is not None and seed_param.default is not None:
            logger.warn(
                "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic"
            )

    if "return_info" not in parameters and not accepts_kwargs:
        logger.warn(
            "Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
        )

    if "options" not in parameters and not accepts_kwargs:
        logger.warn(
            "Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
        )

    # Checks the result of env.reset with kwargs
    result = env.reset(**kwargs)
    if kwargs.get("return_info") is True:
        # With `return_info=True` the result is an (observation, info) pair
        obs, info = result
        _check_obs(obs, env.observation_space, "reset")
        assert isinstance(
            info, dict
        ), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actually {type(info)}"
    else:
        # Otherwise the result is the observation itself
        _check_obs(result, env.observation_space, "reset")

    return result
def passive_env_step_check(env, action):
    """Passively check `Env.step`: validate the returned data, then return it unchanged.

    Both the four element `(obs, reward, done, info)` and the five element
    `(obs, reward, terminated, truncated, info)` return formats are accepted.

    Args:
        env: The environment whose `step` is called and checked
        action: The action forwarded to `env.step`

    Returns:
        Whatever `env.step(action)` returned, unmodified.

    Raises:
        AssertionError: If a returned element has an unexpected type, the observation does not
            match the observation space, or both `terminated` and `truncated` are true.
        error.Error: If `env.step` does not return four or five elements.
    """
    result = env.step(action)
    if len(result) == 4:
        obs, reward, done, info = result

        assert isinstance(done, bool), "The `done` signal must be a boolean"
    elif len(result) == 5:
        obs, reward, terminated, truncated, info = result

        assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
        assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
        assert (
            terminated is False or truncated is False
        ), "Only `terminated` or `truncated` can be true, not both."
    else:
        raise error.Error(
            f"Expected `Env.step` to return a four or five elements, actually {len(result)} elements returned."
        )

    _check_obs(obs, env.observation_space, "step")

    # Observations may be dicts or (nested) tuples, on which `np.isnan` / `np.isinf`
    # cannot be called directly, so the leaves are inspected one by one.
    if _any_leaf(obs, np.isnan):
        logger.warn("Encountered NaN value in observations.")
    if _any_leaf(obs, np.isinf):
        logger.warn("Encountered inf value in observations.")

    # NumPy scalar rewards of any integer or floating width are accepted,
    # in addition to the builtin numeric types.
    assert isinstance(
        reward, (float, int, np.floating, np.integer)
    ), "The reward returned by `step()` must be a float"
    if np.any(np.isnan(reward)):
        logger.warn("Encountered NaN value in rewards.")
    if np.any(np.isinf(reward)):
        logger.warn("Encountered inf value in rewards.")

    assert isinstance(
        info, dict
    ), "The `info` returned by `step()` must be a python dictionary"

    return result


def _any_leaf(data, predicate) -> bool:
    """Return True if `predicate` flags any numeric leaf of a possibly nested observation."""
    if isinstance(data, dict):
        return any(_any_leaf(value, predicate) for value in data.values())
    if isinstance(data, (tuple, list)):
        return any(_any_leaf(value, predicate) for value in data)
    try:
        return bool(np.any(predicate(data)))
    except TypeError:
        # Non-numeric leaves (e.g. strings or object arrays) cannot hold NaN/inf values
        return False
def passive_env_render_check(env, **kwargs):
    """Passively check `Env.render`: warn about missing render metadata, then render.

    A warning is logged for each of `render_modes` and `render_fps` that is
    missing from (or `None` in) `env.metadata`.

    Args:
        env: The environment whose metadata is inspected and whose `render` is called
        **kwargs: Keyword arguments forwarded to `env.render`

    Returns:
        Whatever `env.render(**kwargs)` returned, unmodified.
    """
    # (metadata key, warning logged when that key is missing or None), in reporting order
    metadata_warnings = (
        (
            "render_modes",
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), "
            "you may have trouble when calling `.render()`",
        ),
        (
            "render_fps",
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), "
            "rendering may occur at inconsistent fps",
        ),
    )

    for metadata_key, warning in metadata_warnings:
        if env.metadata.get(metadata_key) is None:
            logger.warn(warning)

    return env.render(**kwargs)

View File

@@ -1,6 +1,7 @@
"""Module for vector environments."""
from typing import Iterable, List, Optional, Union
import gym
from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv
from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
@@ -13,6 +14,7 @@ def make(
num_envs: int = 1,
asynchronous: bool = True,
wrappers: Optional[Union[callable, List[callable]]] = None,
disable_env_checker: bool = False,
**kwargs,
) -> VectorEnv:
"""Create a vectorized environment from multiple copies of an environment, from its id.
@@ -32,26 +34,36 @@ def make(
num_envs: Number of copies of the environment.
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
disable_env_checker: If ``True``, the environment checker is disabled for every environment; if ``False``, the checker is only run on the first environment created.
**kwargs: Keywords arguments applied during gym.make
Returns:
The vectorized environment.
"""
from gym.envs import make as make_
def _make_env():
env = make_(id, **kwargs)
if wrappers is not None:
if callable(wrappers):
env = wrappers(env)
elif isinstance(wrappers, Iterable) and all(
[callable(w) for w in wrappers]
):
for wrapper in wrappers:
env = wrapper(env)
else:
raise NotImplementedError
return env
def create_env(_disable_env_checker):
"""Creates an environment that can enable or disable the environment checker."""
env_fns = [_make_env for _ in range(num_envs)]
def _make_env():
env = gym.envs.registration.make(
id, disable_env_checker=_disable_env_checker, **kwargs
)
if wrappers is not None:
if callable(wrappers):
env = wrappers(env)
elif isinstance(wrappers, Iterable) and all(
[callable(w) for w in wrappers]
):
for wrapper in wrappers:
env = wrapper(env)
else:
raise NotImplementedError
return env
return _make_env
# `create_env` receives the value forwarded as `disable_env_checker`: the checker stays
# enabled for the first environment only, and for none when `disable_env_checker` is True.
env_fns = [
    create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
]
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)

View File

@@ -0,0 +1,57 @@
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
from typing import Tuple, Union
import gym
from gym.core import ActType, ObsType
from gym.utils.passive_env_checker import (
check_action_space,
check_observation_space,
passive_env_render_check,
passive_env_reset_check,
passive_env_step_check,
)
class PassiveEnvChecker(gym.Wrapper):
    """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gym API.

    The action and observation spaces are checked once at construction time.
    The `step`, `reset` and `render` checks each run on the first call of the
    respective method only; later calls go straight to the wrapped environment.
    """

    def __init__(self, env):
        """Initialises the wrapper with the environments, run the observation and action space tests.

        Args:
            env: The environment to wrap and check

        Raises:
            AssertionError: If the environment lacks an action or observation space,
                or if one of the space checks fails.
        """
        super().__init__(env)

        assert hasattr(
            env, "action_space"
        ), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/"
        check_action_space(env.action_space)
        assert hasattr(
            env, "observation_space"
        ), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
        check_observation_space(env.observation_space)

        # Flags recording whether the one-off check of each method has already run
        self.checked_reset = False
        self.checked_step = False
        self.checked_render = False

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Steps through the environment that on the first call will run the `passive_env_step_check`."""
        if self.checked_step is False:
            # Set the flag before checking so a failing check is not repeated on the next call
            self.checked_step = True
            return passive_env_step_check(self.env, action)
        else:
            return self.env.step(action)

    def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Resets the environment that on the first call will run the `passive_env_reset_check`."""
        if self.checked_reset is False:
            self.checked_reset = True
            return passive_env_reset_check(self.env, **kwargs)
        else:
            return self.env.reset(**kwargs)

    def render(self, **kwargs):
        """Renders the environment that on the first call will run the `passive_env_render_check`."""
        if self.checked_render is False:
            self.checked_render = True
            return passive_env_render_check(self.env, **kwargs)
        else:
            return self.env.render(**kwargs)

View File

@@ -1,8 +1,8 @@
"""Test environment determinism by performing a rollout."""
import numpy as np
import pytest
from gym.utils.env_checker import data_equivalence
from tests.envs.spec_list import spec_list
@@ -36,23 +36,25 @@ def test_env(spec):
env1.action_space.seed(SEED)
env2.action_space.seed(SEED)
assert_equals(initial_observation1, initial_observation2)
assert data_equivalence(
initial_observation1, initial_observation2
), f"Initial Observations 1 and 2 are not equivalent. initial obs 1={initial_observation1}, initial obs 2={initial_observation2}"
for i in range(NUM_STEPS):
action1 = env1.action_space.sample()
action2 = env2.action_space.sample()
try:
assert_equals(action1, action2)
assert data_equivalence(
action1, action2
), f"Action 1 and 2 are not equivalent. action 1={action1}, action 2={action2}"
except AssertionError:
print("env1.action_space=", env1.action_space)
print("env2.action_space=", env2.action_space)
print("action_samples1=", action1)
print("action_samples2=", action2)
print(f"[{i}] action_sample1: {action1}, action_sample2: {action2}")
print(f"env 1 action space={env1.action_space}")
print(f"env 2 action space={env2.action_space}")
print(f"[{i}] action sample 1={action1}, action sample 2={action2}")
raise
# Don't check rollout equality if it's a a nondeterministic
# Don't check rollout equality if it's a nondeterministic
# environment.
if spec.nondeterministic:
return
@@ -60,14 +62,18 @@ def test_env(spec):
obs1, rew1, done1, info1 = env1.step(action1)
obs2, rew2, done2, info2 = env2.step(action2)
assert_equals(obs1, obs2, f"[{i}] ")
assert data_equivalence(
obs1, obs2
), f"Observation 1 and 2 are not equivalent. obs 1={obs1}, obs 2={obs2}"
assert env1.observation_space.contains(obs1)
assert env2.observation_space.contains(obs2)
assert rew1 == rew2, f"[{i}] r1: {rew1}, r2: {rew2}"
assert done1 == done2, f"[{i}] d1: {done1}, d2: {done2}"
assert_equals(info1, info2, f"[{i}] ")
assert rew1 == rew2, f"[{i}] reward1: {rew1}, reward2: {rew2}"
assert done1 == done2, f"[{i}] done1: {done1}, done2: {done2}"
assert data_equivalence(
info1, info2
), f"Info 1 and 2 are not equivalent. info 1={info1}, info 2={info2}"
if done1: # done2 verified in previous assertion
env1.reset(seed=SEED)
@@ -75,29 +81,3 @@ def test_env(spec):
env1.close()
env2.close()
def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`.

    Dicts are compared key by key (same keys in the same order), tuples element
    by element (same length required), numpy arrays with
    `np.testing.assert_array_equal`, and everything else with `==`.

    Args:
        a: first data structure
        b: second data structure
        prefix: prefix for failed assertion messages; `None` means no prefix

    Raises:
        AssertionError: If the two structures differ in type, layout or value.
    """
    # Avoid rendering the literal text "None" in front of every message
    prefix = "" if prefix is None else prefix

    assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
        for key in a:
            assert_equals(a[key], b[key], prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b)
    elif isinstance(a, tuple):
        # `zip` stops at the shorter tuple, so the lengths have to be compared explicitly
        assert len(a) == len(b), f"{prefix}Tuple lengths differ: {a} and {b}"
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b, f"{prefix}Values differ: {a} and {b}"

View File

@@ -22,7 +22,7 @@ def test_env(spec):
env = spec.make()
# Test if env adheres to Gym API
check_env(env, warn=True, skip_render_check=True)
check_env(env, skip_render_check=True)
# Check that dtype is explicitly declared for gym.Box spaces
for warning_msg in warnings: