Files
Gymnasium/tests/envs/test_registration.py
Andrea PIERRÉ e913bc81b8 Improve pre-commit workflow (#2602)
* feat: add `isort` to `pre-commit`

* ci: skip `__init__.py` file for `isort`

* ci: make `isort` mandatory in lint pipeline

* docs: add a section on Git hooks

* ci: check isort diff

* fix: isort from master branch

* docs: add pre-commit badge

* ci: update black + bandit versions

* feat: add PR template

* refactor: PR template

* ci: remove bandit

* docs: add Black badge

* ci: try to remove all `|| true` statements

* ci: remove lint_python job

- Remove `lint_python` CI job
- Move `pyupgrade` job to `pre-commit` workflow

* fix: avoid messing with typing

* docs: add a note on running `pre-commit` manually

* ci: apply `pre-commit` to the whole codebase
2022-03-31 15:50:38 -04:00

328 lines
9.6 KiB
Python

import pytest
import gym
from gym import envs, error
from gym.envs import registration
from gym.envs.classic_control import cartpole
from gym.envs.registration import EnvSpec, EnvSpecTree
class ArgumentEnv(gym.Env):
    """Bare-bones environment that only records its constructor arguments.

    The tests in this module use it to check which keyword arguments end up
    on the environment instance.
    """

    def __init__(self, arg1, arg2, arg3):
        # Keep each argument verbatim so tests can inspect what was passed.
        self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3
# Register ArgumentEnv under a test ID with values for arg1 and arg2 only.
# No value for arg3 is given here; test_make_with_kwargs passes it (and
# overrides arg2) at make() time.
gym.register(
    entry_point="tests.envs.test_registration:ArgumentEnv",
    id="test.ArgumentEnv-v0",
    kwargs={"arg1": "arg1", "arg2": "arg2"},
)
@pytest.fixture(scope="function")
def register_some_envs():
    """Register a handful of namespaced environments for the duration of one test.

    Setup registers ``MyAwesomeNamespace/MyAwesomeVersionedEnv`` at versions
    1, 3 and 5 plus the unversioned ``MyAwesomeNamespace/MyAwesomeUnversionedEnv``,
    all pointing at ``ArgumentEnv``. After the test has run, exactly those
    IDs are deleted again from the global registry.
    """
    namespace = "MyAwesomeNamespace"
    versioned_name = "MyAwesomeVersionedEnv"
    unversioned_name = "MyAwesomeUnversionedEnv"
    versions = [1, 3, 5]

    # Single entry point for every registration. The unversioned environment
    # previously pointed at "tests.env..." (missing "s"), which does not match
    # the module path used by the other registrations in this file.
    entry_point = "tests.envs.test_registration:ArgumentEnv"
    env_kwargs = {
        "arg1": "arg1",
        "arg2": "arg2",
        "arg3": "arg3",
    }

    # Build the IDs once so setup and teardown cannot drift apart.
    versioned_ids = [
        f"{namespace}/{versioned_name}-v{version}" for version in versions
    ]
    unversioned_id = f"{namespace}/{unversioned_name}"

    for env_id in versioned_ids:
        # Pass a fresh dict per registration, as the original code did.
        gym.register(id=env_id, entry_point=entry_point, kwargs=dict(env_kwargs))
    gym.register(id=unversioned_id, entry_point=entry_point, kwargs=dict(env_kwargs))

    yield

    # Teardown: remove everything registered above.
    for env_id in versioned_ids:
        del gym.envs.registry.env_specs[env_id]
    del gym.envs.registry.env_specs[unversioned_id]
def test_make():
    """Making ``CartPole-v1`` yields a ``CartPoleEnv`` whose spec carries that ID."""
    env_id = "CartPole-v1"
    made_env = envs.make(env_id)
    assert made_env.spec.id == env_id
    assert isinstance(made_env.unwrapped, cartpole.CartPoleEnv)
@pytest.mark.parametrize(
    "env_id, namespace, name, version",
    [
        (
            "MyAwesomeNamespace/MyAwesomeEnv-v0",
            "MyAwesomeNamespace",
            "MyAwesomeEnv",
            0,
        ),
        ("MyAwesomeEnv-v0", None, "MyAwesomeEnv", 0),
        ("MyAwesomeEnv", None, "MyAwesomeEnv", None),
        ("MyAwesomeEnv-vfinal-v0", None, "MyAwesomeEnv-vfinal", 0),
        ("MyAwesomeEnv-vfinal", None, "MyAwesomeEnv-vfinal", None),
        ("MyAwesomeEnv--", None, "MyAwesomeEnv--", None),
        ("MyAwesomeEnv-v", None, "MyAwesomeEnv-v", None),
    ],
)
def test_register(env_id, namespace, name, version):
    """Registering ``env_id`` files it under the expected namespace, name and version.

    Each case gives an ID together with the namespace / name / version keys
    under which the spec is expected to appear in the registry tree.
    """
    envs.register(env_id)
    try:
        assert gym.envs.spec(env_id).id == env_id
        assert version in gym.envs.registry.env_specs.tree[namespace][name].keys()
    finally:
        # Always unregister, even when an assertion above fails, so a failing
        # case does not leave its environment behind in the global registry.
        del gym.envs.registry.env_specs[env_id]
@pytest.mark.parametrize(
    "env_id",
    [
        "“Breakout-v0”",
        "MyNotSoAwesomeEnv-vNone\n",
        "MyNamespace///MyNotSoAwesomeEnv-vNone",
    ],
)
def test_register_error(env_id):
    """Registering each of these IDs fails with a "Malformed environment ID" error."""
    expected_message = "Malformed environment ID"
    with pytest.raises(error.Error, match=expected_message):
        envs.register(env_id)
@pytest.mark.parametrize(
    "env_id_input, env_id_suggested",
    [
        ("cartpole-v1", "CartPole"),
        ("blackjack-v1", "Blackjack"),
        ("Blackjock-v1", "Blackjack"),
        ("mountaincarcontinuous-v0", "MountainCarContinuous"),
        ("taxi-v3", "Taxi"),
        ("taxi-v30", "Taxi"),
        ("MyAwesomeNamspce/MyAwesomeVersionedEnv-v1", "MyAwesomeNamespace"),
        ("MyAwesomeNamspce/MyAwesomeUnversionedEnv", "MyAwesomeNamespace"),
        ("MyAwesomeNamespace/MyAwesomeUnversioneEnv", "MyAwesomeUnversionedEnv"),
        ("MyAwesomeNamespace/MyAwesomeVersioneEnv", "MyAwesomeVersionedEnv"),
    ],
)
def test_env_suggestions(register_some_envs, env_id_input, env_id_suggested):
    """A misspelled ID raises ``UnregisteredEnv`` with a "Did you mean" hint.

    The ``register_some_envs`` fixture provides the ``MyAwesomeNamespace``
    environments that the namespaced cases refer to.
    """
    # The pattern is handed to pytest as a regular expression.
    suggestion_pattern = f"Did you mean: `{env_id_suggested}` ?"
    with pytest.raises(error.UnregisteredEnv, match=suggestion_pattern):
        envs.make(env_id_input)
@pytest.mark.parametrize(
    "env_id_input, suggested_versions, default_version",
    [
        ("CartPole-v12", "`v0`, `v1`", False),
        ("Blackjack-v10", "`v1`", False),
        ("MountainCarContinuous-v100", "`v0`", False),
        ("Taxi-v30", "`v3`", False),
        ("MyAwesomeNamespace/MyAwesomeVersionedEnv-v6", "`v1`, `v3`, `v5`", False),
        ("MyAwesomeNamespace/MyAwesomeUnversionedEnv-v6", "", True),
    ],
)
def test_env_version_suggestions(
    register_some_envs, env_id_input, suggested_versions, default_version
):
    """Asking for a version that is not registered raises an error with a hint.

    When ``default_version`` is true, the expected error is ``DeprecatedEnv``
    mentioning the default version. Otherwise it is ``UnregisteredEnv``
    listing ``suggested_versions``.
    """
    # Decide on the expected exception and message first, then run one check.
    if default_version:
        expected_error = error.DeprecatedEnv
        match_str = "provides the default version"
    else:
        expected_error = error.UnregisteredEnv
        match_str = f"versioned environments: \\[ {suggested_versions} \\]"

    with pytest.raises(expected_error, match=match_str):
        envs.make(env_id_input)
def test_make_with_kwargs():
    """Keyword arguments given to ``make`` override or extend the registered ones."""
    env = envs.make("test.ArgumentEnv-v0", arg2="override_arg2", arg3="override_arg3")
    assert env.spec.id == "test.ArgumentEnv-v0"
    assert isinstance(env.unwrapped, ArgumentEnv)

    # arg1 comes from the registration; arg2 and arg3 come from the make() call.
    expected_args = (
        ("arg1", "arg1"),
        ("arg2", "override_arg2"),
        ("arg3", "override_arg3"),
    )
    for attribute, expected_value in expected_args:
        assert getattr(env, attribute) == expected_value
@pytest.mark.filterwarnings(
    "ignore:.*The environment Humanoid-v0 is out of date. You should consider upgrading to "
    "version `v3` with the environment ID `Humanoid-v3`.*"
)
def test_make_deprecated():
    """Making the out-of-date ``Humanoid-v0`` raises a gym ``Error``.

    The out-of-date warning for that ID is filtered so it does not clutter
    the test output.
    """
    # pytest.raises reports "DID NOT RAISE" on failure, where the previous
    # try/except/else with a bare ``assert False`` gave no explanation.
    with pytest.raises(error.Error):
        envs.make("Humanoid-v0")
def test_spec():
    """Looking up the spec for ``CartPole-v1`` returns a spec with that ID."""
    env_id = "CartPole-v1"
    assert envs.spec(env_id).id == env_id
def test_spec_with_kwargs():
    """A keyword argument passed to ``make`` is recorded in the env's spec kwargs."""
    requested_map = "8x8"
    frozen_lake = gym.make("FrozenLake-v1", map_name=requested_map)
    assert frozen_lake.spec.kwargs["map_name"] == requested_map
def test_missing_lookup():
    """Looking up IDs that are not registered raises the appropriate error.

    Works on a private ``EnvRegistry`` holding ``Test`` at v0, v9 and v15 and
    ``Other`` at v100:

    * ``Test-v1``    -> ``DeprecatedEnv``
    * ``Test-v1000`` -> ``UnregisteredEnv``
    * ``Unknown-v1`` -> ``UnregisteredEnv``
    """
    registry = registration.EnvRegistry()
    registry.register(id="Test-v0", entry_point=None)
    registry.register(id="Test-v15", entry_point=None)
    registry.register(id="Test-v9", entry_point=None)
    registry.register(id="Other-v100", entry_point=None)

    # pytest.raises replaces the former try/except/else + ``assert False``
    # blocks: same checks, but a failure now says which exception was missing.
    with pytest.raises(error.DeprecatedEnv):
        registry.spec("Test-v1")  # must match an env name but not the version above

    with pytest.raises(error.UnregisteredEnv):
        registry.spec("Test-v1000")

    with pytest.raises(error.UnregisteredEnv):
        registry.spec("Unknown-v1")
def test_malformed_lookup():
    """Looking up a malformed ID raises an ``Error`` saying the ID is malformed."""
    registry = registration.EnvRegistry()
    # Same pattern as test_register_error: let pytest check both the exception
    # type and its message, and report the actual message on a mismatch.
    with pytest.raises(error.Error, match="Malformed environment ID"):
        registry.spec("“Breakout-v0”")
def test_versioned_lookups():
    """With only ``test/Test-v5`` registered, other versions fail to resolve.

    v9 raises ``VersionNotFound``, v4 raises ``DeprecatedEnv``, and v5 itself
    resolves to a truthy spec.
    """
    fresh_registry = registration.EnvRegistry()
    fresh_registry.register("test/Test-v5")

    with pytest.raises(error.VersionNotFound):
        fresh_registry.spec("test/Test-v9")

    with pytest.raises(error.DeprecatedEnv):
        fresh_registry.spec("test/Test-v4")

    found_spec = fresh_registry.spec("test/Test-v5")
    assert found_spec
def test_default_lookups():
    """With only the unversioned ``test/Test`` registered, ``test/Test-v0`` is rejected."""
    fresh_registry = registration.EnvRegistry()
    fresh_registry.register("test/Test")

    with pytest.raises(error.DeprecatedEnv):
        fresh_registry.spec("test/Test-v0")

    # The unversioned (default) ID itself must resolve without raising.
    fresh_registry.spec("test/Test")
def test_env_spec_tree():
    """Walk an ``EnvSpecTree`` through inserts and deletes and check its layout.

    The nested ``tree`` mapping is checked at each step as
    namespace -> name -> version, with ``None`` standing in for a missing
    namespace or version. The final assertion pins the text rendering.
    """
    tree_under_test = EnvSpecTree()

    # Insert a spec that has a namespace.
    namespaced_spec = EnvSpec("test/Test-v0")
    tree_under_test["test/Test-v0"] = namespaced_spec
    assert tree_under_test.tree.keys() == {"test"}
    assert tree_under_test.tree["test"].keys() == {"Test"}
    assert tree_under_test.tree["test"]["Test"].keys() == {0}
    assert tree_under_test.tree["test"]["Test"][0] == namespaced_spec
    assert tree_under_test["test/Test-v0"] == namespaced_spec

    # Insert a spec that has no namespace: it is filed under the None key.
    plain_spec = EnvSpec("Test-v0")
    tree_under_test["Test-v0"] = plain_spec
    assert tree_under_test.tree.keys() == {"test", None}
    assert tree_under_test.tree[None].keys() == {"Test"}
    assert tree_under_test.tree[None]["Test"].keys() == {0}
    assert tree_under_test.tree[None]["Test"][0] == plain_spec

    # Removing the only version under a namespace removes the whole subtree.
    del tree_under_test["test/Test-v0"]
    assert tree_under_test.tree.keys() == {None}

    # A second version of an existing name sits next to the first one.
    tree_under_test["Test-v1"] = EnvSpec("Test-v1")
    assert tree_under_test.tree.keys() == {None}
    assert tree_under_test.tree[None].keys() == {"Test"}
    assert tree_under_test.tree[None]["Test"].keys() == {0, 1}

    # Removing one of two versions keeps the other.
    del tree_under_test["Test-v0"]
    assert tree_under_test.tree.keys() == {None}
    assert tree_under_test.tree[None].keys() == {"Test"}
    assert tree_under_test.tree[None]["Test"].keys() == {1}

    # Insert a spec that has no version: it is filed under the None version key.
    unversioned_id = "MyAwesomeEnv"
    unversioned_spec = EnvSpec(unversioned_id)
    tree_under_test[unversioned_id] = unversioned_spec
    assert tree_under_test.tree.keys() == {None}
    assert unversioned_id in tree_under_test.tree[None].keys()
    assert tree_under_test.tree[None][unversioned_id].keys() == {None}
    assert tree_under_test.tree[None][unversioned_id][None] == unversioned_spec

    expected_rendering = "├──Test: [ v1 ]\n" + f"└──{unversioned_id}: [ ]\n"
    assert repr(tree_under_test) == expected_rendering
def test_register_versioned_unversioned():
    """A versioned and an unversioned ID for one name cannot both be registered.

    Checked in both orders: versioned first, then unversioned, and the other
    way round. In each case the second registration must raise
    ``RegistrationError``.
    """
    versioned_env = "Test/MyEnv-v0"
    unversioned_env = "Test/MyEnv"

    # Order 1: versioned first, then the unversioned ID is rejected.
    envs.register(versioned_env)
    assert gym.envs.spec(versioned_env).id == versioned_env
    with pytest.raises(error.RegistrationError):
        envs.register(unversioned_env)

    # Reset before trying the opposite order.
    del gym.envs.registry.env_specs[versioned_env]

    # Order 2: unversioned first (this emits a UserWarning), then the
    # versioned ID is rejected.
    with pytest.warns(UserWarning):
        envs.register(unversioned_env)
    assert gym.envs.spec(unversioned_env).id == unversioned_env
    with pytest.raises(error.RegistrationError):
        envs.register(versioned_env)

    # Final cleanup removes both IDs.
    # NOTE(review): this relies on the rejected versioned registration still
    # being present in the registry at this point; confirm against the
    # registry implementation before changing it.
    for leftover_id in (versioned_env, unversioned_env):
        del gym.envs.registry.env_specs[leftover_id]
def test_return_latest_versioned_env(register_some_envs):
    """Making a versioned env without a version warns and returns version 5.

    Versions 1, 3 and 5 are registered by the ``register_some_envs`` fixture.
    """
    with pytest.warns(UserWarning):
        latest_env = envs.make("MyAwesomeNamespace/MyAwesomeVersionedEnv")
    resolved_id = latest_env.spec.id
    assert resolved_id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5"