# Mirror of https://github.com/Farama-Foundation/Gymnasium.git
# (synced 2025-07-31 22:04:31 +00:00; 164 lines, 5.5 KiB, Python)
import pickle
|
|
import re
|
|
import warnings
|
|
|
|
import pytest
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.envs.registration import EnvSpec
|
|
from gymnasium.utils.env_checker import check_env, data_equivalence
|
|
from tests.envs.utils import all_testing_env_specs, all_testing_initialised_envs
|
|
|
|
|
|
# This runs a smoke test on each officially registered env. We may also want
# to try running environments which are not officially registered envs.

# Regular expressions for warnings that are tolerated while an environment is
# made, reset and stepped. The patterns are joined with "|" and applied with
# `re.search` to the string form of each recorded warning.
PASSIVE_CHECK_IGNORE_WARNING = [
    # "Out of date" notice, wrapped in the ANSI colour codes \x1b[33m ... \x1b[0m.
    # `\d+` (rather than a single `\d`) also accepts multi-digit versions such as `v10`.
    r"\x1b\[33mWARN: The environment (.*?) is out of date\. You should consider upgrading to version `v(\d+)`\.\x1b\[0m",
]
|
|
|
|
|
|
# Exact warning texts that `test_all_env_api` tolerates from `check_env`.
# Each message is wrapped in the ANSI colour codes \x1b[33m ... \x1b[0m so it
# can be compared verbatim against the first argument of a recorded warning.
CHECK_ENV_IGNORE_WARNINGS = [
    "\x1b[33mWARN: " + text + "\x1b[0m"
    for text in (
        "A Box observation space minimum value is -infinity. This is probably too low.",
        "A Box observation space maximum value is infinity. This is probably too high.",
        "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
    )
]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spec",
    all_testing_env_specs,
    ids=[spec.id for spec in all_testing_env_specs],
)
def test_all_env_api(spec):
    """Check that all environments pass the environment checker with no warnings other than the expected.

    The unwrapped environment built from ``spec`` is passed to ``check_env``
    (render checks skipped) while warnings are recorded. Any recorded warning
    whose first argument is not listed in ``CHECK_ENV_IGNORE_WARNINGS`` fails
    the test.

    Args:
        spec: The environment specification to build and check.

    Raises:
        gym.error.Error: If an unexpected warning was recorded.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        env = spec.make().unwrapped
        try:
            check_env(env, skip_render_check=True)
        finally:
            # Always release the environment, even when the checker raises.
            env.close()

    for warning in caught_warnings:
        if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
            raise gym.error.Error(f"Unexpected warning: {warning.message}")
|
|
|
|
|
|
@pytest.mark.parametrize(
    "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_all_env_passive_env_checker(spec):
    """Check that making, resetting and stepping an environment emits no unexpected warnings.

    The environment is created through ``gym.make`` (presumably with the
    passive environment checker enabled by default -- confirm against
    ``gym.make``), reset once and stepped once with a sampled action while
    warnings are recorded. Every recorded warning must match one of the
    patterns in ``PASSIVE_CHECK_IGNORE_WARNING``.

    Args:
        spec: The environment specification whose ``id`` is passed to ``gym.make``.

    Raises:
        ValueError: If a recorded warning matches none of the ignored patterns.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        env = gym.make(spec.id)
        try:
            env.reset()
            env.step(env.action_space.sample())
        finally:
            # Always release the environment, even when reset/step raises.
            env.close()

    passive_check_pattern = re.compile("|".join(PASSIVE_CHECK_IGNORE_WARNING))

    for warning in caught_warnings:
        if not passive_check_pattern.search(str(warning.message)):
            raise ValueError(f"Unexpected warning: {warning.message}")
|
|
|
|
|
|
# Seed and rollout length used by `test_env_determinism_rollout`.
# Note that this precludes running this test in multiple threads; however,
# multithreading is probably already ruled out by some environments.
SEED: int = 0
NUM_STEPS: int = 50
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_spec",
    all_testing_env_specs,
    ids=[env.id for env in all_testing_env_specs],
)
def test_env_determinism_rollout(env_spec: EnvSpec):
    """Run a rollout with two environments and assert equality.

    This test runs a rollout of NUM_STEPS steps with two environments
    initialized with the same seed and asserts that:

    - the observations after the first reset are the same
    - the observations are contained in the observation space
    - obs, reward, terminated, truncated and info are equal between the two envs

    Actions are sampled only from the first environment's seeded action space
    and applied to both environments; the determinism of action sampling is
    not evaluated. Specs flagged as nondeterministic are skipped.

    Args:
        env_spec: The environment specification to instantiate twice.
    """
    # Don't check rollout equality if it's a nondeterministic environment.
    if env_spec.nondeterministic is True:
        pytest.skip(f"Skipping {env_spec.id} as it is non-deterministic")

    env_1 = env_spec.make(disable_env_checker=True)
    env_2 = env_spec.make(disable_env_checker=True)

    try:
        # Environments flagged as "jax" in their metadata are wrapped with
        # JaxToNumpy (presumably so their outputs can be compared as NumPy data).
        if env_1.metadata.get("jax", False):
            env_1 = gym.wrappers.JaxToNumpy(env_1)
            env_2 = gym.wrappers.JaxToNumpy(env_2)

        # Only the initial observations are compared; the reset infos are not.
        initial_obs_1, _ = env_1.reset(seed=SEED)
        initial_obs_2, _ = env_2.reset(seed=SEED)
        assert data_equivalence(initial_obs_1, initial_obs_2, exact=True)

        env_1.action_space.seed(SEED)

        for time_step in range(NUM_STEPS):
            # We don't evaluate the determinism of actions
            action = env_1.action_space.sample()

            obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
            obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)

            assert data_equivalence(
                obs_1, obs_2, exact=True
            ), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}"
            assert env_1.observation_space.contains(
                obs_1
            )  # obs_2 verified by previous assertion

            assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
            assert (
                terminated_1 == terminated_2
            ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
            assert (
                truncated_1 == truncated_2
            ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
            assert data_equivalence(
                info_1, info_2, exact=True
            ), f"[{time_step}] info_1={info_1}, info_2={info_2}"

            if (
                terminated_1 or truncated_1
            ):  # terminated_2, truncated_2 verified by previous assertion
                # Start a fresh episode in both environments from the same seed.
                env_1.reset(seed=SEED)
                env_2.reset(seed=SEED)
    finally:
        # Close both environments even when an assertion above fails.
        env_1.close()
        env_2.close()
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env",
    all_testing_initialised_envs,
    # One id per parameter: a `None` entry makes pytest auto-generate the id,
    # so the id list can never be shorter than the parameter list.
    ids=[
        env.spec.id if env.spec is not None else None
        for env in all_testing_initialised_envs
    ],
)
def test_pickle_env(env: gym.Env):
    """Check that an environment behaves the same after a pickle round trip.

    The environment is reset with a fixed seed and stepped once with a sampled
    action. A copy obtained via ``pickle.dumps`` / ``pickle.loads`` is then
    reset with the same seed and stepped with the same action, and the reset
    and step results of both are compared with ``data_equivalence``.

    Args:
        env: An already initialised environment; it is closed by this test.
    """
    if env.metadata.get("jax", False):
        env = gym.wrappers.JaxToNumpy(env)

    pickled_env = None
    try:
        action = env.action_space.sample()

        env_reset = env.reset(seed=123)
        env_step = env.step(action)

        # The pickled bytes are produced right here, so loading them is safe.
        pickled_env = pickle.loads(pickle.dumps(env))
        pickle_reset = pickled_env.reset(seed=123)
        pickle_step = pickled_env.step(action)

        assert data_equivalence(env_reset, pickle_reset)
        assert data_equivalence(env_step, pickle_step)
    finally:
        # Close both environments even when a step or an assertion fails.
        env.close()
        if pickled_env is not None:
            pickled_env.close()
|