mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 22:11:25 +00:00
Updates the environment checker (#2864)
* Updated testing requirements based off extra["testing"]
* Updated setup to check the version is valid, added testing and all dependency groups and collects the requirements from requirements.txt to keep everything standardized.
* Updated requirements.txt based on the current minimum gym requirements.txt to work
* Updated requirements.txt based on the current minimum gym requirements.txt to work
* Updated test_requirements.txt based on the current gym full testing requirements
* Pre-commit updates
* Add integer check for the `n` parameter
* The type of self.spaces is an Iterable which is absorbed by the tuple.
* Simplifies the environment checker to two files, env_checker.py and passive_env_checker.py with a new wrapper env_checker.py
* Adds the passive environment checker on `gym.make`
* Ignore the `check_env` warn parameter
* Ignore the `check_env` warn parameter
* Use the `data_equivalence` function
* Revert rewrite setup.py changes
* Remove smart formatting for 3.6 support
* Fixed `check_action_space` and `check_observation_space`
* Added disable_env_checker to vector.make such that env_checker would only run on the first environment created.
* Removing check that different seeds would produce different initialising states
* Use the unwrapped environment np_random
* Fixed vector environment creator
This commit is contained in:
@@ -23,8 +23,8 @@ from typing import (
|
||||
import numpy as np
|
||||
|
||||
from gym.envs.__relocated__ import internal_env_relocation_map
|
||||
from gym.utils.env_checker import check_env
|
||||
from gym.wrappers import AutoResetWrapper, OrderEnforcing, TimeLimit
|
||||
from gym.wrappers.env_checker import PassiveEnvChecker
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
import importlib_metadata as metadata # type: ignore
|
||||
@@ -43,7 +43,15 @@ ENV_ID_RE = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def load(name: str) -> type:
|
||||
def load(name: str) -> callable:
|
||||
"""Loads an environment with name and returns an environment creation function
|
||||
|
||||
Args:
|
||||
name: The environment name
|
||||
|
||||
Returns:
|
||||
Calls the environment constructor
|
||||
"""
|
||||
mod_name, attr_name = name.split(":")
|
||||
mod = importlib.import_module(mod_name)
|
||||
fn = getattr(mod, attr_name)
|
||||
@@ -519,11 +527,6 @@ def make(
|
||||
) -> Env:
|
||||
"""Create an environment according to the given ID.
|
||||
|
||||
Warnings:
|
||||
In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
|
||||
This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to valid
|
||||
if they follow the gym API. To disable this feature, set parameter `disable_env_checker=True`.
|
||||
|
||||
Args:
|
||||
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
|
||||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
||||
@@ -578,49 +581,44 @@ def make(
|
||||
_kwargs = spec_.kwargs.copy()
|
||||
_kwargs.update(kwargs)
|
||||
|
||||
# TODO: add a minimal env checker on initialization
|
||||
if spec_.entry_point is None:
|
||||
raise error.Error(f"{spec_.id} registered but entry_point is not specified")
|
||||
elif callable(spec_.entry_point):
|
||||
cls = spec_.entry_point
|
||||
env_creator = spec_.entry_point
|
||||
else:
|
||||
# Assume it's a string
|
||||
cls = load(spec_.entry_point)
|
||||
env_creator = load(spec_.entry_point)
|
||||
|
||||
env = cls(**_kwargs)
|
||||
env = env_creator(**_kwargs)
|
||||
|
||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||
spec_ = copy.deepcopy(spec_)
|
||||
spec_.kwargs = _kwargs
|
||||
|
||||
env.unwrapped.spec = spec_
|
||||
|
||||
# Run the environment checker as the lowest level wrapper
|
||||
if disable_env_checker is False:
|
||||
env = PassiveEnvChecker(env)
|
||||
|
||||
# Add the order enforcing wrapper
|
||||
if spec_.order_enforce:
|
||||
env = OrderEnforcing(env)
|
||||
|
||||
# Add the time limit wrapper
|
||||
if max_episode_steps is not None:
|
||||
env = TimeLimit(env, max_episode_steps)
|
||||
elif spec_.max_episode_steps is not None:
|
||||
env = TimeLimit(env, spec_.max_episode_steps)
|
||||
|
||||
# Add the autoreset wrapper
|
||||
if autoreset:
|
||||
env = AutoResetWrapper(env)
|
||||
|
||||
if not disable_env_checker:
|
||||
try:
|
||||
check_env(env)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
f"Env check failed with the following message: {e}\n"
|
||||
f"You can set `disable_env_checker=True` to disable this check."
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def spec(env_id: str) -> EnvSpec:
|
||||
"""
|
||||
Retrieve the spec for the given environment from the global registry.
|
||||
"""
|
||||
"""Retrieve the spec for the given environment from the global registry."""
|
||||
spec_ = registry.get(env_id)
|
||||
if spec_ is None:
|
||||
ns, name, version = parse_env_id(env_id)
|
||||
|
@@ -33,6 +33,7 @@ class Discrete(Space[int]):
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
|
||||
start (int): The smallest element of this space.
|
||||
"""
|
||||
assert isinstance(n, (int, np.integer))
|
||||
assert n > 0, "n (counts) have to be positive"
|
||||
assert isinstance(start, (int, np.integer))
|
||||
self.n = int(n)
|
||||
|
@@ -33,9 +33,8 @@ class Tuple(Space[tuple], Sequence):
|
||||
spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
|
||||
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
|
||||
"""
|
||||
spaces = tuple(spaces)
|
||||
self.spaces = spaces
|
||||
for space in spaces:
|
||||
self.spaces = tuple(spaces)
|
||||
for space in self.spaces:
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), "Elements of the tuple must be instances of gym.Space"
|
||||
|
@@ -8,367 +8,110 @@ It also uses some warnings/assertions from the PettingZoo repository hosted on G
|
||||
(https://github.com/PettingZoo-Team/PettingZoo)
|
||||
Original Author: J K Terry
|
||||
|
||||
This was rewritten and split into "env_checker.py" and "passive_env_checker.py" for invasive and passive environment checking
|
||||
Original Author: Mark Towers
|
||||
|
||||
These projects are covered by the MIT License.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import Optional, Union
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym import logger
|
||||
from gym.spaces import Box, Dict, Discrete, Space, Tuple
|
||||
from gym import error, logger
|
||||
from gym.utils.passive_env_checker import (
|
||||
check_action_space,
|
||||
check_observation_space,
|
||||
passive_env_reset_check,
|
||||
passive_env_step_check,
|
||||
)
|
||||
|
||||
|
||||
def _is_numpy_array_space(space: Space) -> bool:
|
||||
"""Checks if a space can be represented as a single numpy array (e.g. Dict and Tuple spaces return False).
|
||||
def data_equivalence(data_1, data_2) -> bool:
|
||||
"""Assert equality between data 1 and 2, i.e observations, actions, info.
|
||||
|
||||
Args:
|
||||
space: The space to check
|
||||
data_1: data structure 1
|
||||
data_2: data structure 2
|
||||
|
||||
Returns:
|
||||
Returns False if the provided space is not representable as a single numpy array
|
||||
If observation 1 and 2 are equivalent
|
||||
"""
|
||||
return not isinstance(space, (Dict, Tuple))
|
||||
|
||||
|
||||
def _check_image_input(observation_space: Box, key: str = ""):
|
||||
"""Check whether an observation space of type :class:`Box` adheres to general standards for spaces that represent images.
|
||||
|
||||
It will check that:
|
||||
- The datatype is ``np.uint8``
|
||||
- The lower bound is 0 across all dimensions
|
||||
- The upper bound is 255 across all dimensions
|
||||
|
||||
Args:
|
||||
observation_space: The observation space to check
|
||||
key: The observation shape key for warning
|
||||
"""
|
||||
if observation_space.dtype != np.uint8:
|
||||
logger.warn(
|
||||
f"It seems that your observation {key} is an image but the `dtype` "
|
||||
"of your observation_space is not `np.uint8`. "
|
||||
"If your observation is not an image, we recommend you to flatten the observation "
|
||||
"to have only a 1D vector"
|
||||
)
|
||||
|
||||
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
|
||||
logger.warn(
|
||||
f"It seems that your observation space {key} is an image but the "
|
||||
"upper and lower bounds are not in [0, 255]. "
|
||||
"Generally, CNN policies assume observations are within that range, "
|
||||
"so you may encounter an issue if the observation values are not."
|
||||
)
|
||||
|
||||
|
||||
def _check_nan(env: gym.Env, check_inf: bool = True):
|
||||
"""Check if the environment observation, reward are NaN and Inf.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
check_inf: Checks if the observation is infinity
|
||||
"""
|
||||
for _ in range(10):
|
||||
action = env.action_space.sample()
|
||||
observation, reward, done, _ = env.step(action)
|
||||
|
||||
if done:
|
||||
env.reset()
|
||||
|
||||
if np.any(np.isnan(observation)):
|
||||
logger.warn("Encountered NaN value in observations.")
|
||||
if np.any(np.isnan(reward)):
|
||||
logger.warn("Encountered NaN value in rewards.")
|
||||
if check_inf and np.any(np.isinf(observation)):
|
||||
logger.warn("Encountered inf value in observations.")
|
||||
if check_inf and np.any(np.isinf(reward)):
|
||||
logger.warn("Encountered inf value in rewards.")
|
||||
|
||||
|
||||
def _check_obs(
|
||||
obs: Union[tuple, dict, np.ndarray, int],
|
||||
observation_space: Space,
|
||||
method_name: str,
|
||||
):
|
||||
"""Check that the observation returned by the environment correspond to the declared one.
|
||||
|
||||
Args:
|
||||
obs: The observation to check
|
||||
observation_space: The observation space of the observation
|
||||
method_name: The method name that generated the observation
|
||||
"""
|
||||
if not isinstance(observation_space, Tuple):
|
||||
assert not isinstance(
|
||||
obs, tuple
|
||||
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
|
||||
|
||||
if isinstance(observation_space, Discrete):
|
||||
assert isinstance(
|
||||
obs, int
|
||||
), f"The observation returned by `{method_name}()` method must be an int"
|
||||
elif _is_numpy_array_space(observation_space):
|
||||
assert isinstance(
|
||||
obs, np.ndarray
|
||||
), f"The observation returned by `{method_name}()` method must be a numpy array"
|
||||
|
||||
assert observation_space.contains(
|
||||
obs
|
||||
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
|
||||
|
||||
|
||||
def _check_box_obs(observation_space: Box, key: str = ""):
|
||||
"""Check that the observation space is correctly formatted when dealing with a :class:`Box` space.
|
||||
|
||||
In particular, it checks:
|
||||
- that the dimensions are big enough when it is an image, and that the type matches
|
||||
- that the observation has an expected shape (warn the user if not)
|
||||
|
||||
Args:
|
||||
observation_space: Checks if the Box observation space
|
||||
key: The observation key
|
||||
"""
|
||||
# If image, check the low and high values, the type and the number of channels
|
||||
# and the shape (minimal value)
|
||||
if len(observation_space.shape) == 3:
|
||||
_check_image_input(observation_space)
|
||||
|
||||
if len(observation_space.shape) not in [1, 3]:
|
||||
logger.warn(
|
||||
f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). "
|
||||
"We recommend you to flatten the observation "
|
||||
"to have only a 1D vector or use a custom policy to properly process the data."
|
||||
)
|
||||
|
||||
if np.any(np.equal(observation_space.low, -np.inf)):
|
||||
logger.warn(
|
||||
"Agent's minimum observation space value is -infinity. This is probably too low."
|
||||
)
|
||||
if np.any(np.equal(observation_space.high, np.inf)):
|
||||
logger.warn(
|
||||
"Agent's maxmimum observation space value is infinity. This is probably too high"
|
||||
)
|
||||
if np.any(np.equal(observation_space.low, observation_space.high)):
|
||||
logger.warn("Agent's maximum and minimum observation space values are equal")
|
||||
if np.any(np.greater(observation_space.low, observation_space.high)):
|
||||
assert False, "Agent's minimum observation value is greater than it's maximum"
|
||||
if observation_space.low.shape != observation_space.shape:
|
||||
assert (
|
||||
False
|
||||
), "Agent's observation_space.low and observation_space have different shapes"
|
||||
if observation_space.high.shape != observation_space.shape:
|
||||
assert (
|
||||
False
|
||||
), "Agent's observation_space.high and observation_space have different shapes"
|
||||
|
||||
|
||||
def _check_box_action(action_space: Box):
|
||||
"""Checks that a :class:`Box` action space is defined in a sensible way.
|
||||
|
||||
Args:
|
||||
action_space: A box action space
|
||||
"""
|
||||
if np.any(np.equal(action_space.low, -np.inf)):
|
||||
logger.warn(
|
||||
"Agent's minimum action space value is -infinity. This is probably too low."
|
||||
)
|
||||
if np.any(np.equal(action_space.high, np.inf)):
|
||||
logger.warn(
|
||||
"Agent's maximum action space value is infinity. This is probably too high"
|
||||
)
|
||||
if np.any(np.equal(action_space.low, action_space.high)):
|
||||
logger.warn("Agent's maximum and minimum action space values are equal")
|
||||
if np.any(np.greater(action_space.low, action_space.high)):
|
||||
assert False, "Agent's minimum action value is greater than it's maximum"
|
||||
if action_space.low.shape != action_space.shape:
|
||||
assert False, "Agent's action_space.low and action_space have different shapes"
|
||||
if action_space.high.shape != action_space.shape:
|
||||
assert False, "Agent's action_space.high and action_space have different shapes"
|
||||
|
||||
|
||||
def _check_normalized_action(action_space: Box):
|
||||
"""Checks that a box action space is normalized.
|
||||
|
||||
Args:
|
||||
action_space: A box action space
|
||||
"""
|
||||
if (
|
||||
np.any(np.abs(action_space.low) != np.abs(action_space.high))
|
||||
or np.any(np.abs(action_space.low) > 1)
|
||||
or np.any(np.abs(action_space.high) > 1)
|
||||
):
|
||||
logger.warn(
|
||||
"We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
|
||||
"cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
|
||||
)
|
||||
|
||||
|
||||
def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space):
|
||||
"""Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods.
|
||||
|
||||
Args:
|
||||
env: The environment
|
||||
observation_space: The environment's observation space
|
||||
action_space: The environment's action space
|
||||
|
||||
Raises:
|
||||
AssertionError: If the ``observation_space`` is :class:`Dict` and
|
||||
keys from :meth:`Env.reset` are not in the observation space
|
||||
"""
|
||||
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
|
||||
obs = env.reset()
|
||||
|
||||
if isinstance(observation_space, Dict):
|
||||
assert isinstance(
|
||||
obs, dict
|
||||
), "The observation returned by `reset()` must be a dictionary"
|
||||
for key in observation_space.spaces.keys():
|
||||
try:
|
||||
_check_obs(obs[key], observation_space.spaces[key], "reset")
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"Error while checking key={key}: " + str(e))
|
||||
else:
|
||||
_check_obs(obs, observation_space, "reset")
|
||||
|
||||
# Sample a random action
|
||||
action = action_space.sample()
|
||||
data = env.step(action)
|
||||
|
||||
assert (
|
||||
len(data) == 4
|
||||
), "The `step()` method must return four values: obs, reward, done, info"
|
||||
|
||||
# Unpack
|
||||
obs, reward, done, info = data
|
||||
|
||||
if isinstance(observation_space, Dict):
|
||||
assert isinstance(
|
||||
obs, dict
|
||||
), "The observation returned by `step()` must be a dictionary"
|
||||
for key in observation_space.spaces.keys():
|
||||
try:
|
||||
_check_obs(obs[key], observation_space.spaces[key], "step")
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"Error while checking key={key}: " + str(e))
|
||||
|
||||
else:
|
||||
_check_obs(obs, observation_space, "step")
|
||||
|
||||
# We also allow int because the reward will be cast to float
|
||||
assert isinstance(
|
||||
reward, (float, int, np.float32)
|
||||
), "The reward returned by `step()` must be a float"
|
||||
assert isinstance(done, bool), "The `done` signal must be a boolean"
|
||||
assert isinstance(
|
||||
info, dict
|
||||
), "The `info` returned by `step()` must be a python dictionary"
|
||||
|
||||
|
||||
def _check_spaces(env: gym.Env):
|
||||
"""Check that the observation and action spaces are defined and inherit from :class:`gym.spaces.Space`.
|
||||
|
||||
Args:
|
||||
env: The environment's observation and action space to check
|
||||
"""
|
||||
# Helper to link to the code, because gym has no proper documentation
|
||||
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
|
||||
|
||||
assert hasattr(env, "observation_space"), (
|
||||
"You must specify an observation space (cf gym.spaces)" + gym_spaces
|
||||
)
|
||||
assert hasattr(env, "action_space"), (
|
||||
"You must specify an action space (cf gym.spaces)" + gym_spaces
|
||||
)
|
||||
|
||||
assert isinstance(env.observation_space, Space), (
|
||||
"The observation space must inherit from gym.spaces" + gym_spaces
|
||||
)
|
||||
assert isinstance(env.action_space, Space), (
|
||||
"The action space must inherit from gym.spaces" + gym_spaces
|
||||
)
|
||||
|
||||
|
||||
# Check render cannot be covered by CI
|
||||
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False):
|
||||
"""Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
warn: Whether to output additional warnings
|
||||
headless: Whether to disable render modes that require a graphical interface. False by default.
|
||||
"""
|
||||
render_modes = env.metadata.get("render_modes")
|
||||
if render_modes is None:
|
||||
if warn:
|
||||
logger.warn(
|
||||
"No render modes was declared in the environment "
|
||||
" (env.metadata['render_modes'] is None or not defined), "
|
||||
"you may have trouble when calling `.render()`"
|
||||
if type(data_1) == type(data_2):
|
||||
if isinstance(data_1, dict):
|
||||
return data_1.keys() == data_2.keys() and all(
|
||||
data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
|
||||
)
|
||||
|
||||
render_fps = env.metadata.get("render_fps")
|
||||
# We only require `render_fps` if rendering is actually implemented
|
||||
if render_fps is None and render_modes is not None and len(render_modes) > 0:
|
||||
if warn:
|
||||
logger.warn(
|
||||
"No render fps was declared in the environment "
|
||||
" (env.metadata['render_fps'] is None or not defined), "
|
||||
"rendering may occur at inconsistent fps"
|
||||
elif isinstance(data_1, tuple):
|
||||
return len(data_1) == len(data_2) and all(
|
||||
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
|
||||
)
|
||||
|
||||
elif isinstance(data_1, np.ndarray):
|
||||
return np.all(data_1 == data_2)
|
||||
else:
|
||||
return data_1 == data_2
|
||||
else:
|
||||
# Don't check render mode that require a
|
||||
# graphical interface (useful for CI)
|
||||
if headless and "human" in render_modes:
|
||||
render_modes.remove("human")
|
||||
# Check all declared render modes
|
||||
for render_mode in render_modes:
|
||||
env.render(mode=render_mode)
|
||||
env.close()
|
||||
return False
|
||||
|
||||
|
||||
def check_reset_seed(env: gym.Env):
    """Check that the environment can be reset with a seed.

    The environment is reset twice with ``seed=123`` and the two observations
    must be equivalent (``data_equivalence``). It is then reset with
    ``seed=None`` and the state of ``env.unwrapped.np_random`` must differ from
    the state captured after the seeded reset.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with a random seed,
            even though `seed` or `kwargs` appear in the signature, or one of
            the observation / RNG-state checks fails.
        error.Error: The `reset` method accepts neither `seed` nor `kwargs`.
    """
    signature = inspect.signature(env.reset)
    if "seed" not in signature.parameters and "kwargs" not in signature.parameters:
        # This check is about the `seed` keyword; the message previously named
        # `return_info`, which was a copy-paste slip from `check_reset_info`.
        raise error.Error(
            "The `reset` method does not provide the `seed` keyword argument"
        )

    try:
        obs_1 = env.reset(seed=123)
        assert (
            obs_1 in env.observation_space
        ), "The observation returned by `env.reset(seed=123)` is not within the observation space"
        obs_2 = env.reset(seed=123)
        assert (
            obs_2 in env.observation_space
        ), "The observation returned by `env.reset(seed=123)` is not within the observation space"
        assert data_equivalence(
            obs_1, obs_2
        ), "`env.reset(seed=123)` is not deterministic as the observations are not equivalent"
        # Snapshot the RNG so the unseeded reset below can be compared against it.
        seed_123_rng = deepcopy(env.unwrapped.np_random)

        # Note: for some environment, they may initialise at the same state, therefore we cannot check the obs_1 != obs_3
        obs_4 = env.reset(seed=None)
        assert (
            obs_4 in env.observation_space
        ), "The observation returned by `env.reset(seed=None)` is not within the observation space"

        unseeded_rng = env.unwrapped.np_random
        if unseeded_rng is None or seed_123_rng is None:
            # Warn instead of dereferencing `None.bit_generator`; the warning
            # used to sit after the comparison and could never be reached.
            logger.warn(
                "Resetting the environment did not result in seeding its random number generator. "
                "This is likely due to not calling `super().reset(seed=seed)` in the `reset` method. "
                "If you do not use the python-level random number generator, this is not a problem."
            )
        else:
            assert (
                unseeded_rng.bit_generator.state != seed_123_rng.bit_generator.state
            ), "`env.reset(seed=None)` left the random number generator in the same state as `env.reset(seed=123)`"
    except TypeError as e:
        raise AssertionError(
            "The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
            "This should never happen, please report this issue. "
            f"The error was: {e}"
        ) from e

    seed_param = signature.parameters.get("seed")
    # Check the default value is None
    if seed_param is not None and seed_param.default is not None:
        logger.warn(
            "The default seed argument in reset should be `None`, "
            "otherwise the environment will by default always be deterministic"
        )
|
||||
|
||||
|
||||
def _check_reset_info(env: gym.Env):
|
||||
def check_reset_info(env: gym.Env):
|
||||
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
|
||||
|
||||
Args:
|
||||
@@ -379,29 +122,29 @@ def _check_reset_info(env: gym.Env):
|
||||
even though `return_info` or `kwargs` appear in the signature.
|
||||
"""
|
||||
signature = inspect.signature(env.reset)
|
||||
assert (
|
||||
"return_info" in signature.parameters or "kwargs" in signature.parameters
|
||||
), "The `reset` method does not provide the `return_info` keyword argument"
|
||||
|
||||
try:
|
||||
result = env.reset(return_info=True)
|
||||
except TypeError as e:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
|
||||
"appear in the signature. This should never happen, please report this issue. "
|
||||
f"The error was: {e}"
|
||||
if "return_info" in signature.parameters or "kwargs" in signature.parameters:
|
||||
try:
|
||||
result = env.reset(return_info=True)
|
||||
assert (
|
||||
len(result) == 2
|
||||
), "Calling the reset method with `return_info=True` did not return a 2-tuple"
|
||||
obs, info = result
|
||||
assert isinstance(
|
||||
info, dict
|
||||
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
|
||||
except TypeError as e:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
|
||||
"appear in the signature. This should never happen, please report this issue. "
|
||||
f"The error was: {e}"
|
||||
)
|
||||
else:
|
||||
raise error.Error(
|
||||
"The `reset` method does not provide the `return_info` keyword argument"
|
||||
)
|
||||
assert (
|
||||
len(result) == 2
|
||||
), "Calling the reset method with `return_info=True` did not return a 2-tuple"
|
||||
|
||||
obs, info = result
|
||||
assert isinstance(
|
||||
info, dict
|
||||
), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
|
||||
|
||||
|
||||
def _check_reset_options(env: gym.Env):
|
||||
def check_reset_options(env: gym.Env):
|
||||
"""Check that the environment can be reset with options.
|
||||
|
||||
Args:
|
||||
@@ -412,73 +155,94 @@ def _check_reset_options(env: gym.Env):
|
||||
even though `options` or `kwargs` appear in the signature.
|
||||
"""
|
||||
signature = inspect.signature(env.reset)
|
||||
assert (
|
||||
"options" in signature.parameters or "kwargs" in signature.parameters
|
||||
), "The environment cannot be reset with options. This behavior will be deprecated in the future."
|
||||
|
||||
try:
|
||||
env.reset(options={})
|
||||
except TypeError as e:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with options, even though `options` or `kwargs` "
|
||||
"appear in the signature. This should never happen, please report this issue. "
|
||||
f"The error was: {e}"
|
||||
if "options" in signature.parameters or "kwargs" in signature.parameters:
|
||||
try:
|
||||
env.reset(options={})
|
||||
except TypeError as e:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with options, even though `options` or `kwargs` appear in the signature. "
|
||||
"This should never happen, please report this issue. "
|
||||
f"The error was: {e}"
|
||||
)
|
||||
else:
|
||||
raise error.Error(
|
||||
"The `reset` method does not provide the `options` keyword argument"
|
||||
)
|
||||
|
||||
|
||||
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True):
|
||||
# Check render cannot be covered by CI
|
||||
def check_render(env: gym.Env, headless: bool = False):
    """Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.

    Every mode listed in ``env.metadata["render_modes"]`` is passed to
    ``env.render(mode=...)`` and the environment is then closed. The metadata
    itself is never modified.

    Args:
        env: The environment to check
        headless: Whether to disable render modes that require a graphical interface. False by default.
    """
    render_modes = env.metadata.get("render_modes")
    if render_modes is None:
        logger.warn(
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`"
        )

    render_fps = env.metadata.get("render_fps")
    # A missing `render_fps` is reported even when no render mode is declared.
    if render_fps is None:
        logger.warn(
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps"
        )

    if render_modes is not None:
        # Build a filtered copy rather than calling `render_modes.remove("human")`:
        # the list lives in `env.metadata`, so removing in place would drop the
        # "human" mode from the environment's own metadata. Don't check render
        # modes that require a graphical interface when headless (useful for CI).
        modes_to_check = [
            mode for mode in render_modes if not (headless and mode == "human")
        ]

        # Check all declared render modes
        for mode in modes_to_check:
            env.render(mode=mode)
        env.close()
|
||||
|
||||
|
||||
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
    """Check that an environment follows Gym API.

    This is an invasive function that calls the environment's reset and step.

    This is particularly useful when using a custom environment.
    Please take a look at https://www.gymlibrary.ml/content/environment_creation/
    for more information about the API.

    Args:
        env: The Gym environment that will be checked
        warn: Ignored
        skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
    """
    if warn is not None:
        logger.warn("`check_env` warn parameter is now ignored.")

    assert isinstance(
        env, gym.Env
    ), "Your environment must inherit from the gym.Env class https://www.gymlibrary.ml/content/environment_creation/"

    # ============= Check the spaces (observation and action) ================
    # Each checker receives the space it is named after; the two arguments
    # were previously swapped.
    assert hasattr(
        env, "action_space"
    ), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/"
    check_action_space(env.action_space)
    assert hasattr(
        env, "observation_space"
    ), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
    check_observation_space(env.observation_space)

    # ==== Check the reset method ====
    check_reset_seed(env)
    check_reset_options(env)
    check_reset_info(env)

    # ============ Check the returned values ===============
    passive_env_reset_check(env)
    passive_env_step_check(env, env.action_space.sample())

    # ==== Check the render method and the declared render modes ====
    if not skip_render_check:
        check_render(env)
|
||||
|
310
gym/utils/passive_env_checker.py
Normal file
310
gym/utils/passive_env_checker.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""A set of functions for passively checking environment implementations."""
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym import error, logger, spaces
|
||||
|
||||
|
||||
def _check_box_observation_space(observation_space: spaces.Box):
    """Sanity-check the definition of a :class:`Box` observation space.

    Suspicious-but-legal definitions (image-like boxes with an unexpected dtype
    or range, unconventional ranks, unbounded or degenerate bounds) only emit
    warnings; inconsistent definitions raise.

    Args:
        observation_space: A box observation space

    Raises:
        AssertionError: If any lower bound exceeds its upper bound, or if the
            shape of ``low``/``high`` differs from the shape of the space.
    """
    low = observation_space.low
    high = observation_space.high
    shape = observation_space.shape
    rank = len(shape)

    # A rank-3 box is treated as an image.
    if rank == 3:
        if observation_space.dtype != np.uint8:
            logger.warn(
                f"It seems that your observation space ({observation_space}) is an image but the `dtype` of your observation_space is not `np.uint8`. "
                "If your observation is not an image, we recommend you to flatten the observation to have only a 1D vector"
            )
        bounds_are_pixel_range = not np.any(low != 0) and not np.any(high != 255)
        if not bounds_are_pixel_range:
            logger.warn(
                "It seems that your observation space is an image but the upper and lower bounds are not in [0, 255]. "
                "Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not."
            )

    if rank not in (1, 3):
        logger.warn(
            "Your observation space has an unconventional shape (neither an image, nor a 1D vector). "
            "We recommend you to flatten the observation to have only a 1D vector or use a custom policy to properly process the data. "
            f"Observation space={observation_space}"
        )

    # Unbounded dimensions.
    if np.any(low == -np.inf):
        logger.warn(
            "Agent's minimum observation space value is -infinity. This is probably too low."
        )
    if np.any(high == np.inf):
        logger.warn(
            "Agent's maximum observation space value is infinity. This is probably too high."
        )

    # Degenerate or inverted bounds.
    if np.any(low == high):
        logger.warn("Agent's maximum and minimum observation space values are equal")
    if np.any(low > high):
        raise AssertionError(
            "Agent's minimum observation value is greater than it's maximum"
        )

    # The bounds must have exactly the shape of the space itself.
    if low.shape != shape:
        raise AssertionError(
            "Agent's observation_space.low and observation_space have different shapes"
        )
    if high.shape != shape:
        raise AssertionError(
            "Agent's observation_space.high and observation_space have different shapes"
        )
|
||||
|
||||
|
||||
def _check_box_action_space(action_space: spaces.Box):
    """Sanity-check the definition of a :class:`Box` action space.

    Unbounded, degenerate, or non-normalized bounds only emit warnings;
    inconsistent definitions raise.

    Args:
        action_space: A box action space

    Raises:
        AssertionError: If any lower bound exceeds its upper bound, or if the
            shape of ``low``/``high`` differs from the shape of the space.
    """
    low = action_space.low
    high = action_space.high
    shape = action_space.shape

    if np.any(low == -np.inf):
        logger.warn(
            "Agent's minimum action space value is -infinity. This is probably too low."
        )
    if np.any(high == np.inf):
        logger.warn(
            "Agent's maximum action space value is infinity. This is probably too high"
        )
    if np.any(low == high):
        logger.warn("Agent's maximum and minimum action space values are equal")
    if np.any(low > high):
        raise AssertionError(
            "Agent's minimum action value is greater than it's maximum"
        )
    if low.shape != shape:
        raise AssertionError(
            "Agent's action_space.low and action_space have different shapes"
        )
    if high.shape != shape:
        raise AssertionError(
            "Agent's action_space.high and action_space have different shapes"
        )

    # A normalized box is symmetric around zero and stays within [-1, 1].
    abs_low = np.abs(low)
    abs_high = np.abs(high)
    is_asymmetric = np.any(abs_low != abs_high)
    exceeds_unit_range = np.any(abs_low > 1) or np.any(abs_high > 1)
    if is_asymmetric or exceeds_unit_range:
        logger.warn(
            "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
            "https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"  # TODO Add to gymlibrary.ml?
        )
|
||||
|
||||
|
||||
def _check_obs(obs, observation_space: spaces.Space, method_name: str):
    """Check that the observation returned by the environment correspond to the declared one.

    The observation must be contained in ``observation_space`` and must use the
    Python/NumPy container type that matches the kind of space. Tuple and Dict
    spaces are checked recursively, element by element.

    Args:
        obs: The observation to check
        observation_space: The observation space of the observation
        method_name: The method name that generated the observation

    Raises:
        AssertionError: If the observation is not contained in the space or
            has an unexpected Python type.
    """
    pre = f"The observation returned by the `{method_name}()` method"

    assert observation_space.contains(
        obs
    ), f"{pre} is not contained with the observation space ({observation_space})"

    if isinstance(observation_space, spaces.Discrete):
        # NumPy integer scalars (e.g. the result of indexing an integer array)
        # are valid discrete observations, not only built-in ``int``.
        assert isinstance(
            obs, (int, np.integer)
        ), f"The observation returned by `{method_name}()` method must be an int, actually {type(obs)}"
    elif isinstance(
        observation_space, (spaces.Box, spaces.MultiBinary, spaces.MultiDiscrete)
    ):
        assert isinstance(
            obs, np.ndarray
        ), f"The observation returned by `{method_name}()` method must be a numpy array, actually {type(obs)}"
    elif isinstance(observation_space, spaces.Tuple):
        assert isinstance(
            obs, tuple
        ), f"The observation returned by the `{method_name}()` method must be a tuple, actually {type(obs)}"
        for sub_obs, sub_space in zip(obs, observation_space.spaces):
            _check_obs(sub_obs, sub_space, method_name)
    elif isinstance(observation_space, spaces.Dict):
        assert isinstance(
            obs, dict
        ), f"The observation returned by the `{method_name}()` method must be a dict, actually {type(obs)}"
        for space_key in observation_space.keys():
            _check_obs(obs[space_key], observation_space[space_key], method_name)
|
||||
|
||||
|
||||
def check_observation_space(observation_space):
    """Passively validate an environment's observation space.

    The space is only inspected, never modified. Composite spaces (Tuple and
    Dict) must be non-empty and every sub-space is validated recursively.

    Raises:
        AssertionError: If the object is not a ``gym.spaces.Space`` or the
            space definition is invalid.
    """
    # Guard clause: anything that is not a Space is rejected outright.
    if not isinstance(observation_space, spaces.Space):
        raise AssertionError(
            f"Observation space ({observation_space}) does not inherit from gym.spaces.Space"
        )

    if isinstance(observation_space, spaces.Box):
        _check_box_observation_space(observation_space)
    elif isinstance(observation_space, spaces.Discrete):
        n = observation_space.n
        assert n > 0, f"There are no available discrete observations, n={n}"
    elif isinstance(observation_space, spaces.MultiDiscrete):
        nvec = observation_space.nvec
        assert np.all(
            nvec > 0
        ), f"All dimensions of multi-discrete must be greater than 0, {nvec}"
    elif isinstance(observation_space, spaces.MultiBinary):
        shape = observation_space.shape
        assert np.all(
            np.asarray(shape) > 0
        ), f"All dimensions of multi-binary must be greater than 0, {shape}"
    elif isinstance(observation_space, spaces.Tuple):
        members = observation_space.spaces
        assert len(members) > 0, "An empty Tuple observation space is not allowed."
        for member in members:
            check_observation_space(member)
    elif isinstance(observation_space, spaces.Dict):
        assert (
            len(observation_space.spaces.keys()) > 0
        ), "An empty Dict observation space is not allowed."
        for member in observation_space.values():
            check_observation_space(member)
|
||||
|
||||
|
||||
def check_action_space(action_space):
    """Passively validate an environment's action space.

    The space is only inspected, never modified. Composite spaces (Tuple and
    Dict) must be non-empty and every sub-space is validated recursively.

    Raises:
        AssertionError: If the object is not a ``gym.spaces.Space`` or the
            space definition is invalid.
    """
    # Guard clause: anything that is not a Space is rejected outright.
    if not isinstance(action_space, spaces.Space):
        raise AssertionError(
            f"Action space ({action_space}) does not inherit from gym.spaces.Space"
        )

    if isinstance(action_space, spaces.Box):
        _check_box_action_space(action_space)
    elif isinstance(action_space, spaces.Discrete):
        n = action_space.n
        assert n > 0, f"There are no available discrete actions, n={n}"
    elif isinstance(action_space, spaces.MultiDiscrete):
        nvec = action_space.nvec
        assert np.all(
            nvec > 0
        ), f"All dimensions of multi-discrete must be greater than 0, {nvec}"
    elif isinstance(action_space, spaces.MultiBinary):
        shape = action_space.shape
        assert np.all(
            np.asarray(shape) > 0
        ), f"All dimensions of multi-binary must be greater than 0, {shape}"
    elif isinstance(action_space, spaces.Tuple):
        members = action_space.spaces
        assert len(members) > 0, "An empty Tuple action space is not allowed."
        for member in members:
            check_action_space(member)
    elif isinstance(action_space, spaces.Dict):
        assert (
            len(action_space.spaces.keys()) > 0
        ), "An empty Dict action space is not allowed."
        for member in action_space.values():
            check_action_space(member)
|
||||
|
||||
|
||||
def passive_env_reset_check(env, **kwargs):
    """A passive check of the `Env.reset` function investigating the returning reset information and returning the data unchanged.

    The signature of ``env.reset`` is inspected for the ``seed``,
    ``return_info`` and ``options`` parameters. A parameter counts as supported
    when it is declared by name or when ``reset`` accepts arbitrary keyword
    arguments (``**kwargs``), so a warning is emitted only when neither holds.
    ``env.reset(**kwargs)`` is then called and its result validated.

    Args:
        env: The environment to check
        **kwargs: Keyword arguments forwarded unchanged to ``env.reset``

    Returns:
        Whatever ``env.reset(**kwargs)`` returned, unmodified.

    Raises:
        AssertionError: If the returned observation does not match
            ``env.observation_space`` or, with ``return_info=True``, the
            returned info is not a dict.
    """
    parameters = inspect.signature(env.reset).parameters
    # `**anything` lets callers pass seed/return_info/options through, whatever
    # the catch-all parameter happens to be called.
    accepts_var_kwargs = any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters.values()
    )

    if "seed" not in parameters and not accepts_var_kwargs:
        logger.warn(
            "Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator. "
        )
    else:
        seed_param = parameters.get("seed")
        # Check the default value is None
        if seed_param is not None and seed_param.default is not None:
            logger.warn(
                "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic"
            )

    if "return_info" not in parameters and not accepts_var_kwargs:
        logger.warn(
            "Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
        )

    if "options" not in parameters and not accepts_var_kwargs:
        logger.warn(
            "Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
        )

    # Checks the result of env.reset with kwargs
    result = env.reset(**kwargs)
    if kwargs.get("return_info") is True:
        obs, info = result
        _check_obs(obs, env.observation_space, "reset")
        assert isinstance(
            info, dict
        ), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actually {type(info)}"
    else:
        _check_obs(result, env.observation_space, "reset")

    return result
|
||||
|
||||
|
||||
def _numeric_leaves(value):
    """Yield every numeric array contained in a (possibly nested) observation."""
    if isinstance(value, dict):
        for item in value.values():
            yield from _numeric_leaves(item)
    elif isinstance(value, (tuple, list)):
        for item in value:
            yield from _numeric_leaves(item)
    else:
        array = np.asarray(value)
        # Skip strings, objects and booleans: NaN/inf is meaningless for them.
        if np.issubdtype(array.dtype, np.number):
            yield array


def passive_env_step_check(env, action):
    """A passive check for the environment step, investigating the returning data then returning the data unchanged.

    Both the four element ``(obs, reward, done, info)`` and the five element
    ``(obs, reward, terminated, truncated, info)`` step results are accepted.

    Args:
        env: The environment to step
        action: The action passed to ``env.step``

    Returns:
        Whatever ``env.step(action)`` returned, unmodified.

    Raises:
        AssertionError: If an element of the step result has an unexpected type
            or the observation does not match ``env.observation_space``.
        gym.error.Error: If the step result has neither four nor five elements.
    """
    result = env.step(action)
    # NumPy booleans (e.g. the result of an array comparison) are as valid as
    # built-in ones for the episode-end signals.
    bool_types = (bool, np.bool_)

    if len(result) == 4:
        obs, reward, done, info = result

        assert isinstance(done, bool_types), "The `done` signal must be a boolean"
    elif len(result) == 5:
        obs, reward, terminated, truncated, info = result

        assert isinstance(
            terminated, bool_types
        ), "The `terminated` signal must be a boolean"
        assert isinstance(
            truncated, bool_types
        ), "The `truncated` signal must be a boolean"
        # Truthiness rather than identity, so that `np.bool_` values work too.
        assert not (
            terminated and truncated
        ), "Only `terminated` or `truncated` can be true, not both."
    else:
        raise error.Error(
            f"Expected `Env.step` to return a four or five elements, actually {len(result)} elements returned."
        )

    _check_obs(obs, env.observation_space, "step")
    # Dict/Tuple observations cannot be handed to `np.isnan` directly, so the
    # check is applied to each numeric leaf instead.
    leaves = list(_numeric_leaves(obs))
    if any(np.any(np.isnan(leaf)) for leaf in leaves):
        logger.warn("Encountered NaN value in observations.")
    if any(np.any(np.isinf(leaf)) for leaf in leaves):
        logger.warn("Encountered inf value in observations.")

    assert isinstance(
        reward, (float, int, np.floating, np.integer)
    ), "The reward returned by `step()` must be a float"
    if np.any(np.isnan(reward)):
        logger.warn("Encountered NaN value in rewards.")
    if np.any(np.isinf(reward)):
        logger.warn("Encountered inf value in rewards.")

    assert isinstance(
        info, dict
    ), "The `info` returned by `step()` must be a python dictionary"

    return result
|
||||
|
||||
|
||||
def passive_env_render_check(env, **kwargs):
    """Passively check the render metadata of an environment, then render it.

    A warning is emitted when ``env.metadata`` declares no ``render_modes`` or
    no ``render_fps``. The call is then forwarded to ``env.render(**kwargs)``
    and its result returned unchanged.
    """
    metadata = env.metadata

    if metadata.get("render_modes") is None:
        logger.warn(
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), "
            "you may have trouble when calling `.render()`"
        )

    # We only require `render_fps` if rendering is actually implemented
    if metadata.get("render_fps") is None:
        logger.warn(
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), "
            "rendering may occur at inconsistent fps"
        )

    return env.render(**kwargs)
|
@@ -1,6 +1,7 @@
|
||||
"""Module for vector environments."""
|
||||
from typing import Iterable, List, Optional, Union
|
||||
|
||||
import gym
|
||||
from gym.vector.async_vector_env import AsyncVectorEnv
|
||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||
from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
|
||||
@@ -13,6 +14,7 @@ def make(
|
||||
num_envs: int = 1,
|
||||
asynchronous: bool = True,
|
||||
wrappers: Optional[Union[callable, List[callable]]] = None,
|
||||
disable_env_checker: bool = False,
|
||||
**kwargs,
|
||||
) -> VectorEnv:
|
||||
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||
@@ -32,26 +34,36 @@ def make(
|
||||
num_envs: Number of copies of the environment.
|
||||
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
|
||||
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
|
||||
disable_env_checker: Whether to disable the environment checker. If ``False`` (the default), the checker is only run on the first environment created; if ``True``, it is not run on any environment.
|
||||
**kwargs: Keywords arguments applied during gym.make
|
||||
|
||||
Returns:
|
||||
The vectorized environment.
|
||||
"""
|
||||
from gym.envs import make as make_
|
||||
|
||||
def _make_env():
|
||||
env = make_(id, **kwargs)
|
||||
if wrappers is not None:
|
||||
if callable(wrappers):
|
||||
env = wrappers(env)
|
||||
elif isinstance(wrappers, Iterable) and all(
|
||||
[callable(w) for w in wrappers]
|
||||
):
|
||||
for wrapper in wrappers:
|
||||
env = wrapper(env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
def create_env(_disable_env_checker):
|
||||
"""Creates an environment that can enable or disable the environment checker."""
|
||||
|
||||
env_fns = [_make_env for _ in range(num_envs)]
|
||||
def _make_env():
|
||||
env = gym.envs.registration.make(
|
||||
id, disable_env_checker=_disable_env_checker, **kwargs
|
||||
)
|
||||
if wrappers is not None:
|
||||
if callable(wrappers):
|
||||
env = wrappers(env)
|
||||
elif isinstance(wrappers, Iterable) and all(
|
||||
[callable(w) for w in wrappers]
|
||||
):
|
||||
for wrapper in wrappers:
|
||||
env = wrapper(env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
return _make_env
|
||||
|
||||
# `create_env` takes the value forwarded as `disable_env_checker` to `make`, so
# the checker must stay enabled (False) only for the first sub-environment and
# be disabled (True) for every other copy, or for all of them when the caller
# asked for it to be disabled.
env_fns = [
    create_env(disable_env_checker or env_num > 0)
    for env_num in range(num_envs)
]
|
||||
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
|
||||
|
57
gym/wrappers/env_checker.py
Normal file
57
gym/wrappers/env_checker.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
|
||||
from typing import Tuple, Union
|
||||
|
||||
import gym
|
||||
from gym.core import ActType, ObsType
|
||||
from gym.utils.passive_env_checker import (
|
||||
check_action_space,
|
||||
check_observation_space,
|
||||
passive_env_render_check,
|
||||
passive_env_reset_check,
|
||||
passive_env_step_check,
|
||||
)
|
||||
|
||||
|
||||
class PassiveEnvChecker(gym.Wrapper):
    """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gym API.

    The action and observation spaces are checked once, on construction. The
    ``step``, ``reset`` and ``render`` checks each run only on the first call
    of the respective method; later calls go straight to the wrapped
    environment.
    """

    def __init__(self, env):
        """Initialises the wrapper with the environments, run the observation and action space tests.

        Args:
            env: The environment to wrap and check

        Raises:
            AssertionError: If the environment lacks an action or observation
                space, or a space fails its check.
        """
        super().__init__(env)

        # Each space must be validated by its own checker: the Box checks and
        # the error messages differ between action and observation spaces.
        assert hasattr(
            env, "action_space"
        ), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/"
        check_action_space(env.action_space)
        assert hasattr(
            env, "observation_space"
        ), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
        check_observation_space(env.observation_space)

        # One-shot flags: each passive check runs on the first call only.
        self.checked_reset = False
        self.checked_step = False
        self.checked_render = False

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Steps through the environment that on the first call will run the `passive_env_step_check`."""
        if self.checked_step is False:
            self.checked_step = True
            return passive_env_step_check(self.env, action)
        else:
            return self.env.step(action)

    def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Resets the environment that on the first call will run the `passive_env_reset_check`."""
        if self.checked_reset is False:
            self.checked_reset = True
            return passive_env_reset_check(self.env, **kwargs)
        else:
            return self.env.reset(**kwargs)

    def render(self, **kwargs):
        """Renders the environment that on the first call will run the `passive_env_render_check`."""
        if self.checked_render is False:
            self.checked_render = True
            return passive_env_render_check(self.env, **kwargs)
        else:
            return self.env.render(**kwargs)
|
@@ -1,8 +1,8 @@
|
||||
"""Test environment determinism by performing a rollout."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.utils.env_checker import data_equivalence
|
||||
from tests.envs.spec_list import spec_list
|
||||
|
||||
|
||||
@@ -36,23 +36,25 @@ def test_env(spec):
|
||||
env1.action_space.seed(SEED)
|
||||
env2.action_space.seed(SEED)
|
||||
|
||||
assert_equals(initial_observation1, initial_observation2)
|
||||
assert data_equivalence(
|
||||
initial_observation1, initial_observation2
|
||||
), f"Initial Observations 1 and 2 are not equivalent. initial obs 1={initial_observation1}, initial obs 2={initial_observation2}"
|
||||
|
||||
for i in range(NUM_STEPS):
|
||||
action1 = env1.action_space.sample()
|
||||
action2 = env2.action_space.sample()
|
||||
|
||||
try:
|
||||
assert_equals(action1, action2)
|
||||
assert data_equivalence(
|
||||
action1, action2
|
||||
), f"Action 1 and 2 are not equivalent. action 1={action1}, action 2={action2}"
|
||||
except AssertionError:
|
||||
print("env1.action_space=", env1.action_space)
|
||||
print("env2.action_space=", env2.action_space)
|
||||
print("action_samples1=", action1)
|
||||
print("action_samples2=", action2)
|
||||
print(f"[{i}] action_sample1: {action1}, action_sample2: {action2}")
|
||||
print(f"env 1 action space={env1.action_space}")
|
||||
print(f"env 2 action space={env2.action_space}")
|
||||
print(f"[{i}] action sample 1={action1}, action sample 2={action2}")
|
||||
raise
|
||||
|
||||
# Don't check rollout equality if it's a a nondeterministic
|
||||
# Don't check rollout equality if it's a nondeterministic
|
||||
# environment.
|
||||
if spec.nondeterministic:
|
||||
return
|
||||
@@ -60,14 +62,18 @@ def test_env(spec):
|
||||
obs1, rew1, done1, info1 = env1.step(action1)
|
||||
obs2, rew2, done2, info2 = env2.step(action2)
|
||||
|
||||
assert_equals(obs1, obs2, f"[{i}] ")
|
||||
assert data_equivalence(
|
||||
obs1, obs2
|
||||
), f"Observation 1 and 2 are not equivalent. obs 1={obs1}, obs 2={obs2}"
|
||||
|
||||
assert env1.observation_space.contains(obs1)
|
||||
assert env2.observation_space.contains(obs2)
|
||||
|
||||
assert rew1 == rew2, f"[{i}] r1: {rew1}, r2: {rew2}"
|
||||
assert done1 == done2, f"[{i}] d1: {done1}, d2: {done2}"
|
||||
assert_equals(info1, info2, f"[{i}] ")
|
||||
assert rew1 == rew2, f"[{i}] reward1: {rew1}, reward2: {rew2}"
|
||||
assert done1 == done2, f"[{i}] done1: {done1}, done2: {done2}"
|
||||
assert data_equivalence(
|
||||
info1, info2
|
||||
), f"Info 1 and 2 are not equivalent. info 1={info1}, info 2={info2}"
|
||||
|
||||
if done1: # done2 verified in previous assertion
|
||||
env1.reset(seed=SEED)
|
||||
@@ -75,29 +81,3 @@ def test_env(spec):
|
||||
|
||||
env1.close()
|
||||
env2.close()
|
||||
|
||||
|
||||
def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`.

    Dicts and tuples are compared recursively, NumPy arrays with
    ``np.testing.assert_array_equal`` and everything else with ``==``.

    Args:
        a: first data structure
        b: second data structure
        prefix: prefix for failed assertion messages; ``None`` means no prefix

    Raises:
        AssertionError: If the two structures differ in type, keys, length or
            value.
    """
    # Without this, the default would be rendered as the literal text "None".
    prefix = "" if prefix is None else prefix

    assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"

        for key in a:
            assert_equals(a[key], b[key], prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b)
    elif isinstance(a, tuple):
        # `zip` stops at the shorter tuple, so the lengths are compared first.
        assert len(a) == len(b), f"{prefix}Tuple lengths differ: {a} and {b}"
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b, f"{prefix}Values differ: {a} and {b}"
|
||||
|
@@ -22,7 +22,7 @@ def test_env(spec):
|
||||
env = spec.make()
|
||||
|
||||
# Test if env adheres to Gym API
|
||||
check_env(env, warn=True, skip_render_check=True)
|
||||
check_env(env, skip_render_check=True)
|
||||
|
||||
# Check that dtype is explicitly declared for gym.Box spaces
|
||||
for warning_msg in warnings:
|
||||
|
Reference in New Issue
Block a user