2022-05-20 14:49:30 +01:00
|
|
|
"""A set of functions for checking an environment's details.
|
|
|
|
|
2021-08-12 12:35:09 -05:00
|
|
|
This file is originally from the Stable Baselines3 repository hosted on GitHub
|
|
|
|
(https://github.com/DLR-RM/stable-baselines3/)
|
|
|
|
Original Author: Antonin Raffin
|
|
|
|
|
|
|
|
It also uses some warnings/assertions from the PettingZoo repository hosted on GitHub
|
|
|
|
(https://github.com/PettingZoo-Team/PettingZoo)
|
2021-12-21 14:05:40 -05:00
|
|
|
Original Author: J K Terry
|
2021-08-12 12:35:09 -05:00
|
|
|
|
2022-06-06 16:21:45 +01:00
|
|
|
This was rewritten and split into "env_checker.py" and "passive_env_checker.py" for invasive and passive environment checking
|
|
|
|
Original Author: Mark Towers
|
|
|
|
|
2021-08-12 12:35:09 -05:00
|
|
|
These projects are covered by the MIT License.
|
|
|
|
"""
|
|
|
|
|
2022-01-19 23:28:59 +01:00
|
|
|
import inspect
|
2022-06-06 16:21:45 +01:00
|
|
|
from copy import deepcopy
|
2021-08-12 12:35:09 -05:00
|
|
|
|
|
|
|
import numpy as np
|
2022-03-31 12:50:38 -07:00
|
|
|
|
|
|
|
import gym
|
2022-06-06 16:21:45 +01:00
|
|
|
from gym import error, logger
|
|
|
|
from gym.utils.passive_env_checker import (
|
|
|
|
check_action_space,
|
|
|
|
check_observation_space,
|
|
|
|
passive_env_reset_check,
|
|
|
|
passive_env_step_check,
|
|
|
|
)
|
2021-08-12 12:35:09 -05:00
|
|
|
|
|
|
|
|
2022-06-06 16:21:45 +01:00
|
|
|
def data_equivalence(data_1, data_2) -> bool:
    """Check whether data 1 and 2 are equivalent, e.g. observations, actions, info.

    Dictionaries and tuples are compared recursively, element by element.
    Numpy arrays must have the same shape and identical elements. Any other
    value falls back to ``==``. Values of different types are never equivalent.

    Args:
        data_1: data structure 1
        data_2: data structure 2

    Returns:
        If data 1 and 2 are equivalent
    """
    # Exact type match (not isinstance) so that e.g. an int and a float, or a
    # tuple and a list, are reported as different.
    if type(data_1) != type(data_2):
        return False
    if isinstance(data_1, dict):
        return data_1.keys() == data_2.keys() and all(
            data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
        )
    if isinstance(data_1, tuple):
        return len(data_1) == len(data_2) and all(
            data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
        )
    if isinstance(data_1, np.ndarray):
        # Compare shapes first: `==` on arrays broadcasts, so e.g. [1] and
        # [1, 1, 1] would otherwise be considered equal, and incompatible
        # shapes can raise or warn instead of simply returning False.
        return data_1.shape == data_2.shape and bool(np.array_equal(data_1, data_2))
    return data_1 == data_2
|
2021-08-12 12:35:09 -05:00
|
|
|
|
|
|
|
|
2022-06-06 16:21:45 +01:00
|
|
|
def check_reset_seed(env: gym.Env):
    """Check that the environment can be reset with a seed.

    Resetting twice with ``seed=123`` must yield equivalent observations that
    lie in the observation space. A subsequent reset with ``seed=None`` must
    also yield a valid observation and must change the state of
    ``env.unwrapped.np_random`` away from the ``seed=123`` state.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with a random seed,
            even though `seed` or `kwargs` appear in the signature, or the
            seeded resets do not behave as described above.
        gym.error.Error: Neither `seed` nor `kwargs` appear in the signature
            of `reset`.
    """
    signature = inspect.signature(env.reset)
    if "seed" in signature.parameters or "kwargs" in signature.parameters:
        try:
            obs_1 = env.reset(seed=123)
            assert (
                obs_1 in env.observation_space
            ), "The observation returned by `env.reset(seed=123)` is not within the observation space"
            obs_2 = env.reset(seed=123)
            assert (
                obs_2 in env.observation_space
            ), "The observation returned by `env.reset(seed=123)` is not within the observation space"
            assert data_equivalence(
                obs_1, obs_2
            ), "Using `env.reset(seed=123)` is non-deterministic as the observations are not equivalent"
            seed_123_rng = deepcopy(env.unwrapped.np_random)

            # Note: some environments may initialise at the same state regardless
            # of the seed, therefore we cannot check that obs_1 != obs_4 and
            # compare the RNG states instead.
            obs_4 = env.reset(seed=None)
            assert (
                obs_4 in env.observation_space
            ), "The observation returned by `env.reset(seed=None)` is not within the observation space"

            # Only compare generator states when `np_random` is set; an unset
            # generator is reported by the warning below rather than crashing
            # here with an AttributeError on `None.bit_generator`.
            if seed_123_rng is not None and env.unwrapped.np_random is not None:
                assert (
                    env.unwrapped.np_random.bit_generator.state
                    != seed_123_rng.bit_generator.state
                ), "Resetting with `seed=None` did not change the state of `env.unwrapped.np_random` after resetting with `seed=123`"
        except TypeError as e:
            raise AssertionError(
                "The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
                "This should never happen, please report this issue. "
                f"The error was: {e}"
            ) from e

        if env.unwrapped.np_random is None:
            logger.warn(
                "Resetting the environment did not result in seeding its random number generator. "
                "This is likely due to not calling `super().reset(seed=seed)` in the `reset` method. "
                "If you do not use the python-level random number generator, this is not a problem."
            )

        seed_param = signature.parameters.get("seed")
        # Check the default value is None
        if seed_param is not None and seed_param.default is not None:
            logger.warn(
                "The default seed argument in reset should be `None`, "
                "otherwise the environment will by default always be deterministic"
            )
    else:
        raise error.Error(
            "The `reset` method does not provide the `seed` keyword argument"
        )
|
|
|
|
|
|
|
|
|
2022-06-06 16:21:45 +01:00
|
|
|
def check_reset_info(env: gym.Env):
    """Checks that :meth:`reset` supports the ``return_info`` keyword.

    ``env.reset(return_info=True)`` must return a 2-tuple ``(obs, info)``
    where ``obs`` lies in the observation space and ``info`` is a dictionary.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with `return_info=True`,
            even though `return_info` or `kwargs` appear in the signature, or
            the returned value does not have the form described above.
        gym.error.Error: Neither `return_info` nor `kwargs` appear in the
            signature of `reset`.
    """
    signature = inspect.signature(env.reset)
    if "return_info" in signature.parameters or "kwargs" in signature.parameters:
        try:
            result = env.reset(return_info=True)
        except TypeError as e:
            raise AssertionError(
                "The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` "
                "appear in the signature. This should never happen, please report this issue. "
                f"The error was: {e}"
            ) from e

        # Validate the result outside the `try`: a malformed return value (e.g.
        # a bare observation without `len`) must be reported as such, not as a
        # failure to accept `return_info`.
        assert (
            isinstance(result, tuple) and len(result) == 2
        ), "Calling the reset method with `return_info=True` did not return a 2-tuple"
        obs, info = result
        assert (
            obs in env.observation_space
        ), "The first element returned by `env.reset(return_info=True)` is not within the observation space"
        assert isinstance(
            info, dict
        ), "The second element returned by `env.reset(return_info=True)` was not a dictionary"
    else:
        raise error.Error(
            "The `reset` method does not provide the `return_info` keyword argument"
        )
|
|
|
|
|
|
|
|
|
2022-06-06 16:21:45 +01:00
|
|
|
def check_reset_options(env: gym.Env):
    """Verify that :meth:`reset` accepts an ``options`` keyword argument.

    When the signature of ``reset`` declares ``options`` or ``**kwargs``, the
    environment is reset once with an empty ``options`` dictionary.

    Args:
        env: The environment to check

    Raises:
        AssertionError: ``reset(options={})`` raised a ``TypeError`` although
            `options` or `kwargs` appear in the signature.
        gym.error.Error: Neither `options` nor `kwargs` appear in the
            signature of ``reset``.
    """
    reset_params = inspect.signature(env.reset).parameters
    accepts_options = "options" in reset_params or "kwargs" in reset_params
    if not accepts_options:
        raise error.Error(
            "The `reset` method does not provide the `options` keyword argument"
        )

    try:
        env.reset(options={})
    except TypeError as exc:
        raise AssertionError(
            "The environment cannot be reset with options, even though `options` or `kwargs` appear in the signature. "
            "This should never happen, please report this issue. "
            f"The error was: {exc}"
        )
|
2021-08-12 12:35:09 -05:00
|
|
|
|
|
|
|
|
|
|
|
# The render check cannot be covered by CI (some modes need a graphical interface), so `check_env` skips it by default
|
2022-06-06 16:21:45 +01:00
|
|
|
def check_render(env: gym.Env, headless: bool = False):
    """Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.

    Every mode listed in ``env.metadata["render_modes"]`` is rendered once
    (skipping ``"human"`` when ``headless``), then the environment is closed.
    ``env.metadata`` itself is left unmodified.

    Args:
        env: The environment to check
        headless: Whether to disable render modes that require a graphical interface. False by default.
    """
    render_modes = env.metadata.get("render_modes")
    if render_modes is None:
        logger.warn(
            "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`"
        )

    render_fps = env.metadata.get("render_fps")
    # We only require `render_fps` if rendering is actually implemented,
    # i.e. at least one render mode is declared.
    if render_fps is None and render_modes:
        logger.warn(
            "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps"
        )

    if render_modes is not None:
        # Don't check render modes that require a graphical interface (useful
        # for CI). Build a new list instead of calling `.remove()` so the
        # environment's metadata (which may be shared across instances) is
        # not mutated, and tuple-valued metadata is handled too.
        modes_to_check = [
            mode for mode in render_modes if not (headless and mode == "human")
        ]

        # Check all declared render modes
        for mode in modes_to_check:
            env.render(mode=mode)
        env.close()
|
|
|
|
|
|
|
|
|
2022-06-06 16:21:45 +01:00
|
|
|
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):
    """Check that an environment follows Gym API.

    This is an invasive function that calls the environment's reset and step.

    This is particularly useful when using a custom environment.
    Please take a look at https://www.gymlibrary.ml/content/environment_creation/
    for more information about the API.

    Args:
        env: The Gym environment that will be checked
        warn: Ignored (a warning is logged if it is not None)
        skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
    """
    if warn is not None:
        logger.warn("`check_env` warn parameter is now ignored.")

    assert isinstance(
        env, gym.Env
    ), "Your environment must inherit from the gym.Env class https://www.gymlibrary.ml/content/environment_creation/"

    # ============= Check the spaces (observation and action) ================
    # Each space is validated by its own matching checker.
    assert hasattr(
        env, "action_space"
    ), "You must specify a action space. https://www.gymlibrary.ml/content/environment_creation/"
    check_action_space(env.action_space)
    assert hasattr(
        env, "observation_space"
    ), "You must specify an observation space. https://www.gymlibrary.ml/content/environment_creation/"
    check_observation_space(env.observation_space)

    # ==== Check the reset method ====
    check_reset_seed(env)
    check_reset_options(env)
    check_reset_info(env)

    # ============ Check the returned values ===============
    passive_env_reset_check(env)
    passive_env_step_check(env, env.action_space.sample())

    # ==== Check the render method and the declared render modes ====
    if not skip_render_check:
        check_render(env)
|