Files
Gymnasium/tests/envs/test_registration.py

136 lines
3.6 KiB
Python
Raw Normal View History

2016-04-27 08:00:58 -07:00
# -*- coding: utf-8 -*-
import pytest

import gym
from gym import error, envs
from gym.envs import registration
from gym.envs.classic_control import cartpole
from gym.envs.registration import EnvSpec, EnvSpecTree
2016-04-27 08:00:58 -07:00
2021-07-29 02:26:34 +02:00
class ArgumentEnv(gym.Env):
    """Minimal environment that just records its three constructor arguments.

    The registration tests use it to check which keyword arguments reach the
    constructor when registered kwargs and ``make`` kwargs are combined.
    """

    def __init__(self, arg1, arg2, arg3):
        # Keep each argument verbatim so tests can inspect what was passed in.
        self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3
2021-07-29 02:26:34 +02:00
# Register ArgumentEnv with registered defaults for arg1 and arg2 only.
# ArgumentEnv.__init__ also requires arg3, so callers of make() must supply it.
gym.register(
    id="test.ArgumentEnv-v0",
    entry_point="tests.envs.test_registration:ArgumentEnv",
    kwargs=dict(arg1="arg1", arg2="arg2"),
)
2021-07-29 02:26:34 +02:00
2016-04-27 08:00:58 -07:00
def test_make():
    """``envs.make`` on a built-in id yields an env with matching spec and class."""
    cartpole_env = envs.make("CartPole-v0")
    # The spec attached to the env carries the id that was requested.
    assert cartpole_env.spec.id == "CartPole-v0"
    # Check the innermost env, since the returned object may be wrapped.
    assert isinstance(cartpole_env.unwrapped, cartpole.CartPoleEnv)
2016-04-27 08:00:58 -07:00
2021-07-29 02:26:34 +02:00
def test_make_with_kwargs():
    """Keyword arguments given to ``make`` override/extend the registered kwargs."""
    made = envs.make(
        "test.ArgumentEnv-v0", arg2="override_arg2", arg3="override_arg3"
    )
    assert made.spec.id == "test.ArgumentEnv-v0"
    assert isinstance(made.unwrapped, ArgumentEnv)

    # arg1 keeps its registered default, arg2 is overridden by make(),
    # and arg3 (not registered) comes solely from the make() call.
    expected = {"arg1": "arg1", "arg2": "override_arg2", "arg3": "override_arg3"}
    for attr_name, expected_value in expected.items():
        assert getattr(made, attr_name) == expected_value
2016-05-06 22:26:40 -07:00
def test_make_deprecated():
    """Making the deprecated id ``Humanoid-v0`` must raise a gym ``error.Error``."""
    # pytest.raises replaces the hand-rolled try/except/else + ``assert False``
    # and reports a clear "DID NOT RAISE" message on failure. The accepted
    # exception type (error.Error and its subclasses) is unchanged.
    with pytest.raises(error.Error):
        envs.make("Humanoid-v0")
2021-07-29 02:26:34 +02:00
2016-04-27 08:00:58 -07:00
def test_spec():
    """``envs.spec`` returns the spec registered under the requested id."""
    requested_id = "CartPole-v0"
    assert envs.spec(requested_id).id == requested_id
2016-04-27 08:00:58 -07:00
def test_spec_with_kwargs():
    """Keyword arguments passed to ``gym.make`` are recorded on the env's spec."""
    env = gym.make("FrozenLake-v1", map_name="8x8")
    recorded_kwargs = env.spec._kwargs
    assert recorded_kwargs["map_name"] == "8x8"
2016-04-27 08:00:58 -07:00
def test_missing_lookup():
    """Looking up ids that are not registered raises the expected registry errors."""
    registry = registration.EnvRegistry()
    registry.register(id="Test-v0", entry_point=None)
    registry.register(id="Test-v15", entry_point=None)
    registry.register(id="Test-v9", entry_point=None)
    registry.register(id="Other-v100", entry_point=None)

    # Known name "Test", unregistered version below the latest (v15):
    # expected to be reported as deprecated.
    with pytest.raises(error.DeprecatedEnv):
        registry.spec("Test-v1")  # must match an env name but not the version above

    # Known name, version above every registered one: expected to be unregistered.
    with pytest.raises(error.UnregisteredEnv):
        registry.spec("Test-v1000")

    # Name that was never registered at all.
    with pytest.raises(error.UnregisteredEnv):
        registry.spec("Unknown-v1")
2021-07-29 02:26:34 +02:00
2016-04-27 08:00:58 -07:00
def test_malformed_lookup():
    """An id wrapped in typographic quotes is rejected as a malformed environment ID."""
    registry = registration.EnvRegistry()
    # ``match`` does a regex search on str(exception); the phrase has no regex
    # metacharacters, so this is the same substring check as before, and pytest
    # reports the actual message on mismatch.
    with pytest.raises(error.Error, match="malformed environment ID"):
        registry.spec("“Breakout-v0”")
def test_env_spec_tree():
    """EnvSpecTree stores specs as tree[namespace][name][version] and prunes empty subtrees."""
    spec_tree = EnvSpecTree()

    # Add with namespace
    spec = EnvSpec("test/Test-v0")
    spec_tree["test/Test-v0"] = spec
    assert spec_tree.tree.keys() == {"test"}
    assert spec_tree.tree["test"].keys() == {"Test"}
    assert spec_tree.tree["test"]["Test"].keys() == {"0"}
    assert spec_tree.tree["test"]["Test"]["0"] == spec
    assert spec_tree["test/Test-v0"] == spec

    # Add without namespace: stored under the ``None`` namespace key
    spec = EnvSpec("Test-v0")
    spec_tree["Test-v0"] = spec
    assert spec_tree.tree.keys() == {"test", None}
    assert spec_tree.tree[None].keys() == {"Test"}
    assert spec_tree.tree[None]["Test"].keys() == {"0"}
    assert spec_tree.tree[None]["Test"]["0"] == spec
    # Mirror the namespaced read-back: lookup by id must also resolve here.
    assert spec_tree["Test-v0"] == spec

    # Delete last version deletes entire subtree
    del spec_tree["test/Test-v0"]
    assert spec_tree.tree.keys() == {None}

    # Append second version for same name
    spec_v1 = EnvSpec("Test-v1")
    spec_tree["Test-v1"] = spec_v1
    assert spec_tree.tree.keys() == {None}
    assert spec_tree.tree[None].keys() == {"Test"}
    assert spec_tree.tree[None]["Test"].keys() == {"0", "1"}
    assert spec_tree["Test-v1"] == spec_v1

    # Deleting one version leaves other
    del spec_tree["Test-v0"]
    assert spec_tree.tree.keys() == {None}
    assert spec_tree.tree[None].keys() == {"Test"}
    assert spec_tree.tree[None]["Test"].keys() == {"1"}
    assert spec_tree["Test-v1"] == spec_v1