Files
Gymnasium/tests/envs/registration/test_make.py
2023-03-13 11:10:28 +00:00

583 lines
20 KiB
Python

"""Tests that `gym.make` works as expected."""
from __future__ import annotations
import re
import warnings
import numpy as np
import pytest
import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.error import NameNotFound
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.wrappers import (
AutoResetWrapper,
HumanRendering,
OrderEnforcing,
TimeLimit,
)
from gymnasium.wrappers.env_checker import PassiveEnvChecker
from tests.envs.registration.utils_envs import ArgumentEnv
from tests.envs.utils import all_testing_env_specs
from tests.testing_env import GenericTestEnv, old_reset_func, old_step_func
from tests.wrappers.utils import has_wrapper
# Tests
# * basic example
# * parameters (equivalent for str and EnvSpec)
# 1. max_episode_steps
# 2. autoreset
# 3. apply_api_compatibility
# 4. disable_env_checker
# * rendering
# 1. render_mode
# 2. HumanRendering
# 3. RenderCollection
# * make kwargs
# * make import module
# * make env spec additional wrappers
# * env_id str errors
def test_no_arguments(env_id: str = "CartPole-v1"):
    """Check that `gym.make` gives matching environments for an id string and its `EnvSpec`.

    Each construction route must yield a `CartPoleEnv` whose spec carries the
    requested id, and the specs produced by both routes must compare equal.
    """

    def _checked_spec(made_env):
        # Shared assertions for both construction routes; hands back the spec.
        assert made_env.spec is not None
        assert made_env.spec.id == env_id
        assert isinstance(made_env.unwrapped, CartPoleEnv)
        return made_env.spec

    spec_via_id = _checked_spec(gym.make(env_id))
    spec_via_spec = _checked_spec(gym.make(gym.spec(env_id)))
    assert spec_via_id == spec_via_spec
def test_max_episode_steps(register_parameter_envs):
    """Exercise the `max_episode_steps` argument of `gym.make`.

    When the spec defines a step limit, `TimeLimit` is applied with the spec's
    value unless `gym.make` overrides it. When the spec has no limit, `TimeLimit`
    only appears once a limit is passed explicitly.
    """
    for target in ("CartPole-v1", gym.spec("CartPole-v1")):
        spec = gym.spec(target) if isinstance(target, str) else target

        # Default: inherit the limit recorded in the spec.
        default_env = gym.make(target)
        assert has_wrapper(default_env, TimeLimit)
        assert default_env.spec is not None
        assert default_env.spec.max_episode_steps == spec.max_episode_steps

        # Override with a value that differs from the spec's own limit.
        assert spec.max_episode_steps != 100
        custom_env = gym.make(target, max_episode_steps=100)
        assert has_wrapper(custom_env, TimeLimit)
        assert custom_env.spec is not None
        assert custom_env.spec.max_episode_steps == 100, target

    for target in ("NoMaxEpisodeStepsEnv-v0", gym.spec("NoMaxEpisodeStepsEnv-v0")):
        spec = gym.spec(target) if isinstance(target, str) else target

        # The registered spec carries no step limit, so no TimeLimit by default.
        assert spec.max_episode_steps is None
        unlimited_env = gym.make(target)
        assert unlimited_env.spec is not None
        assert unlimited_env.spec.max_episode_steps is None
        assert has_wrapper(unlimited_env, TimeLimit) is False

        # Supplying a limit explicitly adds the TimeLimit wrapper.
        limited_env = gym.make(target, max_episode_steps=100)
        assert limited_env.spec is not None
        assert limited_env.spec.max_episode_steps == 100
        assert has_wrapper(limited_env, TimeLimit)
def test_autorest(register_parameter_envs):
    """Exercise the `autoreset` argument of `gym.make`.

    Without the argument, the spec's `autoreset` flag decides whether
    `AutoResetWrapper` is applied. Passing `True` or `False` overrides the spec.
    """
    candidates = [
        "CartPole-v1",
        gym.spec("CartPole-v1"),
        "AutoresetEnv-v0",
        gym.spec("AutoresetEnv-v0"),
    ]
    for target in candidates:
        spec = gym.spec(target) if isinstance(target, str) else target

        # No override: follow the spec.
        spec_env = gym.make(target)
        assert spec_env.spec is not None
        assert spec_env.spec.autoreset == spec.autoreset
        assert has_wrapper(spec_env, AutoResetWrapper) is spec.autoreset

        # Forced on.
        forced_on = gym.make(target, autoreset=True)
        assert has_wrapper(forced_on, AutoResetWrapper)
        assert forced_on.spec is not None
        assert forced_on.spec.autoreset is True

        # Forced off.
        forced_off = gym.make(target, autoreset=False)
        assert has_wrapper(forced_off, AutoResetWrapper) is False
        assert forced_off.spec is not None
        assert forced_off.spec.autoreset is False
@pytest.mark.parametrize(
    "registration_disabled, make_disabled, if_disabled",
    [
        [False, False, False],
        [False, True, True],
        [True, False, False],
        [True, True, True],
        [False, None, False],
        [True, None, True],
    ],
)
def test_disable_env_checker(
    registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
):
    """Test how `gym.make` decides whether to apply the `PassiveEnvChecker` wrapper.

    Precedence, as encoded in the parameter table:
      1. If `gym.make(..., disable_env_checker=...)` is a bool, it decides.
      2. If it is `None`, the `gym.register(..., disable_env_checker=...)` value decides.

    `if_disabled` is the expected outcome: the checker wrapper must be present
    exactly when it is False. Both the string-id and `EnvSpec` forms of
    `gym.make` are checked.
    """
    gym.register(
        "DisableEnvCheckerEnv-v0",
        lambda: GenericTestEnv(),
        disable_env_checker=registration_disabled,
    )
    try:
        # Make from the string id.
        env = gym.make("DisableEnvCheckerEnv-v0", disable_env_checker=make_disabled)
        assert has_wrapper(env, PassiveEnvChecker) is not if_disabled

        # Make from the registered EnvSpec; the outcome must be the same.
        env_spec = gym.spec("DisableEnvCheckerEnv-v0")
        env = gym.make(env_spec, disable_env_checker=make_disabled)
        assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
    finally:
        # Always unregister, so a failing case cannot leak its registration
        # into the next parametrization.
        del gym.registry["DisableEnvCheckerEnv-v0"]
def test_apply_api_compatibility(register_parameter_envs):
    """Exercise the `apply_api_compatibility` argument of `gym.make`.

    The test environments use old-style step/reset functions and a three-step
    time limit. With compatibility applied, `step` must return a 5-tuple and
    the third step must truncate rather than terminate.
    """

    def _assert_compat_applied(compat_env):
        # The compatibility layer is the unwrapped env, and the time limit of 3
        # must surface as truncation on the third step.
        assert isinstance(compat_env.unwrapped, gym.wrappers.EnvCompatibility)
        compat_env.reset()
        assert len(compat_env.step(compat_env.action_space.sample())) == 5
        compat_env.step(compat_env.action_space.sample())
        _, _, terminated, truncated, _ = compat_env.step(
            compat_env.action_space.sample()
        )
        assert terminated is False and truncated is True

    # Compatibility enabled at registration time.
    for target in ("EnabledApplyApiComp-v0", gym.spec("EnabledApplyApiComp-v0")):
        _assert_compat_applied(gym.make(target))

    # Compatibility disabled at registration time...
    for target in ("DisabledApplyApiComp-v0", gym.spec("DisabledApplyApiComp-v0")):
        raw_env = gym.make(target)
        assert isinstance(raw_env.unwrapped, gym.wrappers.EnvCompatibility) is False
        raw_env.reset()
        # ...so the old step output cannot be unpacked as the new 5-tuple.
        with pytest.raises(
            ValueError,
            match=re.escape("not enough values to unpack (expected 5, got 4)"),
        ):
            raw_env.step(raw_env.action_space.sample())

        assert raw_env.spec is not None
        assert raw_env.spec.apply_api_compatibility is False

        # ...and re-enabling it through `gym.make` restores the new API.
        _assert_compat_applied(gym.make(target, apply_api_compatibility=True))
def test_order_enforcing(register_parameter_envs):
    """Check that `gym.make` applies `OrderEnforcing` according to `EnvSpec.order_enforce`.

    `CartPole-v1` is registered with order enforcement on, so it must come back
    wrapped. `OrderlessEnv-v0` is registered with `order_enforce=False`, so it
    must not.
    """
    # The registered testing specs keep order enforcement on. Asserting `False`
    # here would contradict the `CartPole-v1` check immediately below.
    assert all(spec.order_enforce is True for spec in all_testing_env_specs)

    for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]:
        env = gym.make(make_id)
        assert has_wrapper(env, OrderEnforcing)

    for make_id in ["OrderlessEnv-v0", gym.spec("OrderlessEnv-v0")]:
        env = gym.make(make_id)
        assert has_wrapper(env, OrderEnforcing) is False

    # There is no `make(..., order_enforcing=...)` so we don't test that
def test_make_with_render_mode():
    """Check that `gym.make(..., render_mode=...)` is reflected in `env.render_mode`.

    Covers `None`, a mode listed in the metadata (`"rgb_array"`), and a mode not
    listed there. The unlisted mode is still accepted but produces a warning.
    """
    none_env = gym.make("CartPole-v1", render_mode=None)
    assert none_env.render_mode is None
    none_env.close()
    assert "rgb_array" in none_env.metadata["render_modes"]

    rgb_env = gym.make("CartPole-v1", render_mode="rgb_array")
    assert rgb_env.render_mode == "rgb_array"
    rgb_env.close()
    assert "no-render-mode" not in rgb_env.metadata["render_modes"]

    # CartPole does not reject an unknown render_mode at construction, so the
    # mode is kept and a warning is emitted instead.
    expected_warning = re.escape(
        "\x1b[33mWARN: The environment is being initialised with render_mode='no-render-mode' that is not in the possible render_modes (['human', 'rgb_array']).\x1b[0m"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        odd_env = gym.make("CartPole-v1", render_mode="no-render-mode")
        assert odd_env.render_mode == "no-render-mode"
        odd_env.close()
def test_make_render_collection():
    """Check that `render_mode="rgb_array_list"` wraps the env in `RenderCollection`.

    The outer env reports the list mode, the base env renders in
    `"rgb_array"` mode, and `render()` returns a list of numpy arrays.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array_list")
    assert has_wrapper(env, gym.wrappers.RenderCollection)
    assert env.render_mode == "rgb_array_list"
    assert env.unwrapped.render_mode == "rgb_array"

    env.reset()
    frames = env.render()
    # `render` on the collection wrapper must hand back a list of frames.
    assert isinstance(frames, list)
    first_frame = frames[0]
    assert isinstance(first_frame, np.ndarray)
    env.close()
def test_make_human_rendering(register_rendering_testing_envs):
    """Check how `gym.make` handles `render_mode="human"` and related render options.

    Covers:
      * native human rendering (no `HumanRendering` wrapper is added),
      * the `HumanRendering` fallback and its warning,
      * the errors raised for `NoHumanRenderingOldAPI-v0`,
      * a `TypeError` caused by an unrelated bad keyword (`render=`),
      * an environment whose metadata lists no render modes.
    """
    # Make sure that native rendering is used when possible
    env = gym.make("CartPole-v1", render_mode="human")
    assert (
        has_wrapper(env, HumanRendering) is False
    )  # Should use native human-rendering
    assert env.render_mode == "human"
    env.close()

    with pytest.warns(
        UserWarning,
        match=re.escape(
            "You are trying to use 'human' rendering for an environment that doesn't natively support it. The HumanRendering wrapper is being applied to your environment."
        ),
    ):
        # Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
        env = gym.make("NoHumanRendering-v0", render_mode="human")
        assert has_wrapper(env, HumanRendering)
        assert env.render_mode == "human"
        env.close()

    # `NoHumanRenderingOldAPI-v0` is expected to reject the `render_mode`
    # keyword, so requesting a non-human mode surfaces that TypeError directly.
    with pytest.raises(
        TypeError, match=re.escape("got an unexpected keyword argument 'render_mode'")
    ):
        gym.make(
            "NoHumanRenderingOldAPI-v0",
            render_mode="rgb_array_list",
        )

    # Make sure that an additional error is thrown if a user tries to use the wrapper on an environment with the old API
    with warnings.catch_warnings(record=True):
        with pytest.raises(
            gym.error.Error,
            match=re.escape(
                "You passed render_mode='human' although NoHumanRenderingOldAPI-v0 doesn't implement human-rendering natively."
            ),
        ):
            gym.make("NoHumanRenderingOldAPI-v0", render_mode="human")

    # This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
    # your environment is using the old rendering API" is *not* triggered by a TypeError that originates from
    # a keyword that is not `render_mode`
    with pytest.raises(
        TypeError,
        match=re.escape("got an unexpected keyword argument 'render'"),
    ):
        gym.make("CarRacing-v2", render="human")

    # This test checks that a user can create an environment whose metadata does not list the render mode;
    # a warning is still expected because the requested mode is not among the (empty) render_modes.
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "\x1b[33mWARN: The environment is being initialised with render_mode='rgb_array' that is not in the possible render_modes ([]).\x1b[0m"
        ),
    ):
        gym.make("NoRenderModesMetadata-v0", render_mode="rgb_array")
def test_make_kwargs(register_kwargs_env):
    """Check that keyword arguments given to `gym.make` merge with the registered kwargs.

    Passed values override registered ones (`arg2`), add new ones (`arg3`), and
    leave the rest untouched (`arg1`). The merged mapping is recorded on the spec
    and reaches the environment's constructor.
    """
    overrides = {"arg2": "override_arg2", "arg3": "override_arg3"}
    env = gym.make("test.ArgumentEnv-v0", **overrides)
    expected_kwargs = {"arg1": "arg1", **overrides}

    assert env.spec is not None
    assert env.spec.id == "test.ArgumentEnv-v0"
    assert env.spec.kwargs == expected_kwargs
    assert isinstance(env.unwrapped, ArgumentEnv)
    for attr_name, attr_value in expected_kwargs.items():
        assert getattr(env, attr_name) == attr_value
    env.close()
def test_import_module_during_make():
    """Check that `gym.make("module:EnvId")` imports `module`, which registers `EnvId`."""
    registered_id = "RegisterDuringMake-v0"

    # The environment must not exist until its module is imported by `make`.
    assert registered_id not in gym.registry
    env = gym.make(
        "tests.envs.registration.utils_unregistered_env:RegisterDuringMake-v0"
    )
    assert registered_id in gym.registry

    from tests.envs.registration.utils_unregistered_env import RegisterDuringMakeEnv

    assert isinstance(env.unwrapped, RegisterDuringMakeEnv)
    env.close()
    del gym.registry[registered_id]
class NoRecordArgsWrapper(gym.ObservationWrapper):
    """Observation wrapper that does not inherit `RecordConstructorArgs`.

    Tests use it to check that `gym.make(env.spec)` refuses to recreate a
    wrapper stack containing a wrapper whose constructor arguments were not
    recorded.
    """

    def __init__(self, env: Env[ObsType, ActType]):
        """Wrap ``env`` without recording any constructor arguments."""
        super().__init__(env)

    def observation(self, observation: ObsType) -> WrapperObsType:
        """Ignore ``observation`` and return a random sample of the observation space."""
        random_observation = self.observation_space.sample()
        return random_observation
def test_make_with_env_spec():
    """Check that `gym.make(env.spec)` recreates environments built in different ways.

    Covers plain ids, externally applied wrappers, callable entry points,
    wrappers applied inside the entry point, wrappers supplied through
    `additional_wrappers`, and the error raised for a wrapper that does not
    record its constructor arguments. The temporary registrations are always
    removed, even when an assertion fails.
    """
    temporary_ids = (
        "CartPole-v2",
        "CartPole-v3",
        "CartPole-v4",
        "CartPole-v5",
        "CartPole-v6",
    )
    try:
        # make
        id_env = gym.make("CartPole-v1")
        spec_env = gym.make(gym.spec("CartPole-v1"))
        assert id_env.spec == spec_env.spec
        id_env.close()
        spec_env.close()

        # make with applied wrappers
        env_2 = gym.wrappers.NormalizeReward(
            gym.wrappers.TimeAwareObservation(
                gym.wrappers.FlattenObservation(
                    gym.make("CartPole-v1", render_mode="rgb_array")
                )
            ),
            gamma=0.8,
        )
        env_2_recreated = gym.make(env_2.spec)
        assert env_2.spec == env_2_recreated.spec
        env_2.close()
        env_2_recreated.close()

        # make with callable entry point
        gym.register("CartPole-v2", lambda: CartPoleEnv())
        env_3 = gym.make("CartPole-v2")
        assert isinstance(env_3.unwrapped, CartPoleEnv)
        env_3.close()

        # make with wrapper in env-creator
        gym.register(
            "CartPole-v3",
            lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()),
            disable_env_checker=True,
            order_enforce=False,
        )
        env_4 = gym.make(gym.spec("CartPole-v3"))
        assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
        assert isinstance(env_4.env, CartPoleEnv)
        env_4.close()

        # make with the wrapper supplied through the spec's `additional_wrappers`
        gym.register(
            "CartPole-v4",
            lambda: CartPoleEnv(),
            disable_env_checker=True,
            order_enforce=False,
            additional_wrappers=(gym.wrappers.TimeAwareObservation.wrapper_spec(),),
        )
        env_5 = gym.make(gym.spec("CartPole-v4"))
        assert isinstance(env_5, gym.wrappers.TimeAwareObservation)
        assert isinstance(env_5.env, CartPoleEnv)
        env_5.close()

        # make with a wrapper that does not record its constructor args
        env_6 = NoRecordArgsWrapper(gym.make("CartPole-v1"))
        with pytest.raises(
            ValueError,
            match=re.escape(
                "NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
            ),
        ):
            gym.make(env_6.spec)
        env_6.close()

        # the same kind of wrapper is fine when applied inside the entry point
        gym.register(
            "CartPole-v5",
            entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()),
            disable_env_checker=True,
            order_enforce=False,
        )
        env_7 = gym.make(gym.spec("CartPole-v5"))
        assert isinstance(env_7, NoRecordArgsWrapper)
        assert isinstance(env_7.unwrapped, CartPoleEnv)
        env_7.close()

        # NOTE(review): CartPole-v6 is registered but never made, so the
        # `additional_wrappers=(NoRecordArgsWrapper.wrapper_spec(),)` path is not
        # actually exercised here. TODO: decide the expected behaviour and assert it.
        gym.register(
            "CartPole-v6",
            entry_point=lambda: CartPoleEnv(),
            disable_env_checker=True,
            order_enforce=False,
            additional_wrappers=(NoRecordArgsWrapper.wrapper_spec(),),
        )
    finally:
        # Remove every temporary registration, tolerating ids that were never
        # reached because an earlier assertion failed.
        for env_id in temporary_ids:
            gym.registry.pop(env_id, None)
def test_make_with_env_spec_levels():
    """Test that we can recreate the environment at each 'level' of its wrapper stack.

    Starting from the outermost wrapper, `gym.make(env.spec)` must produce an
    environment with an equal spec. One wrapper is then peeled off and the check
    repeats until only the unwrapped environment remains. Every environment
    created here is closed.
    """
    outer_env = gym.wrappers.NormalizeReward(
        gym.wrappers.TimeAwareObservation(
            gym.wrappers.FlattenObservation(
                gym.make("CartPole-v1", render_mode="rgb_array")
            )
        ),
        gamma=0.8,
    )

    env = outer_env
    while env is not env.unwrapped:
        recreated_env = gym.make(env.spec)
        assert env.spec == recreated_env.spec
        recreated_env.close()
        env = env.env

    # Closing the outermost wrapper propagates down the whole stack.
    outer_env.close()
def test_wrapped_env_entry_point():
    """Check that an env whose entry point returns a wrapped env can be rebuilt from its spec.

    After extra wrappers are applied on top, `gym.make(env.spec)` must yield an
    environment that reproduces the original's reset and step outputs for the
    same seed and action. The temporary registration is always removed.
    """

    def _create_env():
        _env = gym.make("CartPole-v1", render_mode="rgb_array")
        _env = gym.wrappers.FlattenObservation(_env)
        return _env

    gym.register("TestingEnv-v0", entry_point=_create_env)
    try:
        env = gym.make("TestingEnv-v0")
        env = gym.wrappers.TimeAwareObservation(env)
        env = gym.wrappers.NormalizeReward(env, gamma=0.8)
        recreated_env = gym.make(env.spec)

        obs, info = env.reset(seed=42)
        recreated_obs, recreated_info = recreated_env.reset(seed=42)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(info, recreated_info)

        # Use one sampled action for both envs so their transitions are comparable.
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        (
            recreated_obs,
            recreated_reward,
            recreated_terminated,
            recreated_truncated,
            recreated_info,
        ) = recreated_env.step(action)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(reward, recreated_reward)
        assert data_equivalence(terminated, recreated_terminated)
        assert data_equivalence(truncated, recreated_truncated)
        assert data_equivalence(info, recreated_info)

        env.close()
        recreated_env.close()
    finally:
        # Unregister even on failure so later tests see a clean registry.
        del gym.registry["TestingEnv-v0"]
def test_make_errors():
    """Check the errors `gym.make` raises for a deprecated version and an unknown name."""
    # `Humanoid-v0` is an outdated version; the error names the version to use instead.
    deprecated_message = re.escape(
        "Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
    )
    with warnings.catch_warnings(record=True), pytest.raises(
        gym.error.Error, match=deprecated_message
    ):
        gym.make("Humanoid-v0")

    # A name that was never registered raises `NameNotFound`.
    unknown_message = re.escape("Environment `NonExistenceEnv` doesn't exist.")
    with pytest.raises(NameNotFound, match=unknown_message):
        gym.make("NonExistenceEnv-v0")
@pytest.fixture(scope="function")
def register_parameter_envs():
    """Register the environments used by the `gym.make` parameter tests, then remove them.

    Each id pairs a `GenericTestEnv` factory with non-default registration
    options: no step limit, autoreset, API compatibility on/off (with a
    three-step limit), and no order enforcement.
    """

    def _old_api_env():
        # Built from the old-style step/reset helpers.
        return GenericTestEnv(step_func=old_step_func, reset_func=old_reset_func)

    registrations = [
        ("NoMaxEpisodeStepsEnv-v0", lambda: GenericTestEnv(), {"max_episode_steps": None}),
        ("AutoresetEnv-v0", lambda: GenericTestEnv(), {"autoreset": True}),
        (
            "EnabledApplyApiComp-v0",
            _old_api_env,
            {"apply_api_compatibility": True, "max_episode_steps": 3},
        ),
        (
            "DisabledApplyApiComp-v0",
            _old_api_env,
            {"apply_api_compatibility": False, "max_episode_steps": 3},
        ),
        ("OrderlessEnv-v0", lambda: GenericTestEnv(), {"order_enforce": False}),
    ]

    for env_id, creator, options in registrations:
        gym.register(env_id, creator, **options)

    yield

    for env_id, _, _ in registrations:
        del gym.registry[env_id]
@pytest.fixture(scope="function")
def register_kwargs_env():
    """Register `test.ArgumentEnv-v0` with default kwargs for one test, then unregister it.

    `arg1` and `arg2` are the registered defaults; a test can override them, or
    pass extra kwargs, through `gym.make`.
    """
    gym.register(
        id="test.ArgumentEnv-v0",
        entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
        kwargs={
            "arg1": "arg1",
            "arg2": "arg2",
        },
    )
    yield
    # Teardown: without this, the id would stay registered for the rest of the
    # session, and every later use of this function-scoped fixture would
    # register it again over the existing entry.
    del gym.registry["test.ArgumentEnv-v0"]
@pytest.fixture(scope="function")
def register_rendering_testing_envs():
    """Register the rendering test environments for one test, then unregister them."""
    rendering_entry_points = {
        "NoHumanRendering-v0": "tests.envs.registration.utils_envs:NoHuman",
        "NoHumanRenderingOldAPI-v0": "tests.envs.registration.utils_envs:NoHumanOldAPI",
        "NoHumanRenderingNoRGB-v0": "tests.envs.registration.utils_envs:NoHumanNoRGB",
        "NoRenderModesMetadata-v0": "tests.envs.registration.utils_envs:NoRenderModesMetadata",
    }

    for env_id, entry_point in rendering_entry_points.items():
        gym.register(id=env_id, entry_point=entry_point)

    yield

    for env_id in rendering_entry_points:
        del gym.envs.registration.registry[env_id]