Files
Gymnasium/tests/envs/test_registration.py
Ariel Kwiatkowski 00a60e6cc8 Rewriting of the registration mechanism (#2748)
* First version of the new registration

* Almost done

* Hopefully final commit

* Minor fixes

* Missing error

* Type fixes

* Type fixes

* Add some type hinting stuff

* Fix an error?

* Fix literal import

* Add a comment

* Add some docstrings

Remove old tests

* Add some docstrings, rename helper functions

* Rename a function

* Registration check fix

* Consistently use `register` instead of `envs.register` in tests

* Fix the malformed registration error message to not use a write-only format

* Change an error back to a warning when double-registering an environment
2022-04-21 14:41:15 -04:00

276 lines
7.6 KiB
Python

import pytest
import gym
from gym import envs, error
from gym.envs import register, registration, registry, spec
from gym.envs.classic_control import cartpole
from gym.envs.registration import EnvSpec
class ArgumentEnv(gym.Env):
    """Bare-bones environment that only records its constructor arguments.

    It is registered under test IDs in this module so the tests can check
    which keyword arguments end up being passed to the constructor.
    """

    def __init__(self, arg1, arg2, arg3):
        # Keep the arguments verbatim; tests read them back as attributes.
        self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3
# Only arg1 and arg2 get registered defaults. arg3 is deliberately left out,
# so it has to be supplied at make() time (test_make_with_kwargs does this).
gym.register(
    id="test.ArgumentEnv-v0",
    entry_point="tests.envs.test_registration:ArgumentEnv",
    kwargs={"arg1": "arg1", "arg2": "arg2"},
)
@pytest.fixture(scope="function")
def register_some_envs():
    """Temporarily register a small family of environments for one test.

    Registers ``MyAwesomeNamespace/MyAwesomeVersionedEnv`` at versions 1, 3
    and 5 plus the unversioned ``MyAwesomeNamespace/MyAwesomeUnversionedEnv``,
    yields to the test, then deletes every one of those IDs from the registry.
    """
    namespace = "MyAwesomeNamespace"
    versioned_name = "MyAwesomeVersionedEnv"
    unversioned_name = "MyAwesomeUnversionedEnv"
    versions = [1, 3, 5]

    # All environments share one entry point. The unversioned registration
    # previously pointed at the non-existent module "tests.env...".
    entry_point = "tests.envs.test_registration:ArgumentEnv"
    env_kwargs = {
        "arg1": "arg1",
        "arg2": "arg2",
        "arg3": "arg3",
    }

    # Build the ID list once so setup and teardown cannot drift apart.
    env_ids = [f"{namespace}/{versioned_name}-v{version}" for version in versions]
    env_ids.append(f"{namespace}/{unversioned_name}")

    for env_id in env_ids:
        # Hand each registration its own dict so specs never share state.
        gym.register(id=env_id, entry_point=entry_point, kwargs=dict(env_kwargs))

    yield

    for env_id in env_ids:
        del gym.envs.registry[env_id]
def test_make():
    """make() on a registered ID builds an env of the expected class."""
    made = envs.make("CartPole-v1")

    # The spec keeps the requested ID, and the innermost env is CartPole.
    assert made.spec.id == "CartPole-v1"
    assert isinstance(made.unwrapped, cartpole.CartPoleEnv)
@pytest.mark.parametrize(
    "env_id, namespace, name, version",
    [
        (
            "MyAwesomeNamespace/MyAwesomeEnv-v0",
            "MyAwesomeNamespace",
            "MyAwesomeEnv",
            0,
        ),
        ("MyAwesomeEnv-v0", None, "MyAwesomeEnv", 0),
        ("MyAwesomeEnv", None, "MyAwesomeEnv", None),
        ("MyAwesomeEnv-vfinal-v0", None, "MyAwesomeEnv-vfinal", 0),
        ("MyAwesomeEnv-vfinal", None, "MyAwesomeEnv-vfinal", None),
        ("MyAwesomeEnv--", None, "MyAwesomeEnv--", None),
        ("MyAwesomeEnv-v", None, "MyAwesomeEnv-v", None),
    ],
)
def test_register(env_id, namespace, name, version):
    """Registering a well-formed ID stores it under the expected registry key.

    Args:
        env_id: The ID passed to ``register``.
        namespace: Expected namespace part of the key, or ``None``.
        name: Expected name part of the key.
        version: Expected version number of the key, or ``None``.
    """
    register(env_id)
    try:
        assert gym.envs.spec(env_id).id == env_id

        # Rebuild the key from its expected parts and check it is present.
        full_name = name
        if namespace:
            full_name = f"{namespace}/{full_name}"
        if version is not None:
            full_name = f"{full_name}-v{version}"
        assert full_name in gym.envs.registry
    finally:
        # Always unregister, so a failing case cannot leak its ID into the
        # registry seen by later tests.
        del gym.envs.registry[env_id]
@pytest.mark.parametrize(
    "env_id",
    [
        "“Breakout-v0”",
        "MyNotSoAwesomeEnv-vNone\n",
        "MyNamespace///MyNotSoAwesomeEnv-vNone",
    ],
)
def test_register_error(env_id):
    """Registering a malformed ID is rejected with a descriptive error."""
    expected_message = "Malformed environment ID"
    with pytest.raises(error.Error, match=expected_message):
        register(env_id)
@pytest.mark.parametrize(
    "env_id_input, env_id_suggested",
    [
        ("cartpole-v1", "CartPole"),
        ("blackjack-v1", "Blackjack"),
        ("Blackjock-v1", "Blackjack"),
        ("mountaincarcontinuous-v0", "MountainCarContinuous"),
        ("taxi-v3", "Taxi"),
        ("taxi-v30", "Taxi"),
        ("MyAwesomeNamspce/MyAwesomeVersionedEnv-v1", "MyAwesomeNamespace"),
        ("MyAwesomeNamspce/MyAwesomeUnversionedEnv", "MyAwesomeNamespace"),
        ("MyAwesomeNamespace/MyAwesomeUnversioneEnv", "MyAwesomeUnversionedEnv"),
        ("MyAwesomeNamespace/MyAwesomeVersioneEnv", "MyAwesomeVersionedEnv"),
    ],
)
def test_env_suggestions(register_some_envs, env_id_input, env_id_suggested):
    """A misspelled ID raises UnregisteredEnv with the expected suggestion."""
    # `match` is applied as a regular expression to the error text.
    expected_hint = f"Did you mean: `{env_id_suggested}` ?"
    with pytest.raises(error.UnregisteredEnv, match=expected_hint):
        envs.make(env_id_input)
@pytest.mark.parametrize(
    "env_id_input, suggested_versions, default_version",
    [
        ("CartPole-v12", "`v0`, `v1`", False),
        ("Blackjack-v10", "`v1`", False),
        ("MountainCarContinuous-v100", "`v0`", False),
        ("Taxi-v30", "`v3`", False),
        ("MyAwesomeNamespace/MyAwesomeVersionedEnv-v6", "`v1`, `v3`, `v5`", False),
        ("MyAwesomeNamespace/MyAwesomeUnversionedEnv-v6", "", True),
    ],
)
def test_env_version_suggestions(
    register_some_envs, env_id_input, suggested_versions, default_version
):
    """Requesting an unavailable version raises an error naming the alternatives.

    When ``default_version`` is true the env is expected to exist only in
    unversioned form, which is reported as a DeprecatedEnv. Otherwise an
    UnregisteredEnv listing ``suggested_versions`` is expected.
    """
    # Pick the expected exception and message first, then make a single call.
    if default_version:
        expected_error = error.DeprecatedEnv
        match_str = "provides the default version"
    else:
        expected_error = error.UnregisteredEnv
        match_str = f"versioned environments: \\[ {suggested_versions} \\]"

    with pytest.raises(expected_error, match=match_str):
        envs.make(env_id_input)
def test_make_with_kwargs():
    """Keyword arguments given to make() override or extend the registered ones."""
    made = envs.make(
        "test.ArgumentEnv-v0",
        arg2="override_arg2",
        arg3="override_arg3",
    )

    assert made.spec.id == "test.ArgumentEnv-v0"
    assert isinstance(made.unwrapped, ArgumentEnv)

    # arg1 keeps its registered value; arg2 and arg3 come from the call.
    assert made.arg1 == "arg1"
    assert made.arg2 == "override_arg2"
    assert made.arg3 == "override_arg3"
@pytest.mark.filterwarnings(
    "ignore:.*The environment Humanoid-v0 is out of date. You should consider upgrading to "
    "version `v3` with the environment ID `Humanoid-v3`.*"
)
def test_make_deprecated():
    """Making the out-of-date ``Humanoid-v0`` ID must raise a gym error."""
    # pytest.raises replaces the manual try/except/else with `assert False`.
    # It matches the other tests in this module and reports a clear failure
    # if nothing is raised.
    with pytest.raises(error.Error):
        envs.make("Humanoid-v0")
def test_spec():
    """Looking up a registered ID returns a spec carrying that same ID."""
    env_id = "CartPole-v1"
    # Not named `spec`, so the imported `spec` helper is not shadowed.
    found = envs.spec(env_id)
    assert found.id == env_id
def test_spec_with_kwargs():
    """Keyword arguments passed to make() are recorded on the env's spec."""
    requested_map = "8x8"
    frozen_lake = gym.make("FrozenLake-v1", map_name=requested_map)
    assert frozen_lake.spec.kwargs["map_name"] == requested_map
def test_missing_lookup():
    """Lookups of versions or names that were never registered must fail.

    With ``Test1`` registered at v0, v15 and v9 (plus an unrelated
    ``Other1-v100``):

    * ``Test1-v1`` raises DeprecatedEnv,
    * ``Test1-v1000`` raises UnregisteredEnv,
    * ``Unknown1-v1`` raises UnregisteredEnv.
    """
    env_ids = ["Test1-v0", "Test1-v15", "Test1-v9", "Other1-v100"]
    for env_id in env_ids:
        register(id=env_id, entry_point=None)

    try:
        with pytest.raises(error.DeprecatedEnv):
            spec("Test1-v1")

        # Same pytest.raises idiom as above, replacing the manual
        # try/except/else with `assert False`.
        with pytest.raises(error.UnregisteredEnv):
            spec("Test1-v1000")

        with pytest.raises(error.UnregisteredEnv):
            spec("Unknown1-v1")
    finally:
        # Remove the temporary IDs so they do not leak into other tests.
        for env_id in env_ids:
            del gym.envs.registry[env_id]
def test_malformed_lookup():
    """Looking up a malformed ID raises an error that says so."""
    # Same check as test_register_error: pytest.raises verifies both the
    # exception type and that the message contains the expected text.
    with pytest.raises(error.Error, match="Malformed environment ID"):
        spec("“Breakout-v0”")
def test_versioned_lookups():
    """With only v5 registered, v9 is not found and v4 is deprecated."""
    env_id = "test/Test2-v5"
    register(env_id)
    try:
        with pytest.raises(error.VersionNotFound):
            spec("test/Test2-v9")

        with pytest.raises(error.DeprecatedEnv):
            spec("test/Test2-v4")

        # The exact registered version resolves normally.
        assert spec(env_id)
    finally:
        # Unregister so the temporary ID does not leak into other tests.
        del gym.envs.registry[env_id]
def test_default_lookups():
    """An env registered without a version rejects versioned lookups."""
    env_id = "test/Test3"
    register(env_id)
    try:
        with pytest.raises(error.DeprecatedEnv):
            spec("test/Test3-v0")

        # Looking up the default (unversioned) ID must not raise.
        spec(env_id)
    finally:
        # Unregister so the temporary ID does not leak into other tests.
        del gym.envs.registry[env_id]
def test_register_versioned_unversioned():
    """A versioned and an unversioned ID of one env cannot both be registered.

    Both orders are checked: versioned first and then unversioned, and the
    other way around. Whichever ID is registered second must be rejected
    with a RegistrationError.
    """
    versioned_id = "Test/MyEnv-v0"
    unversioned_id = "Test/MyEnv"

    orderings = (
        (versioned_id, unversioned_id),  # versioned then unversioned
        (unversioned_id, versioned_id),  # unversioned then versioned
    )
    for first_id, second_id in orderings:
        register(first_id)
        assert gym.envs.spec(first_id).id == first_id

        with pytest.raises(error.RegistrationError):
            register(second_id)

        # Clean up before the next ordering.
        del gym.envs.registry[first_id]
def test_return_latest_versioned_env(register_some_envs):
    """Making a versioned env without a version warns and yields v5.

    The fixture registers versions 1, 3 and 5 of the env, and the env that
    make() returns is expected to carry the v5 ID.
    """
    requested_id = "MyAwesomeNamespace/MyAwesomeVersionedEnv"
    with pytest.warns(UserWarning):
        latest = envs.make(requested_id)
    assert latest.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"