Files
Gymnasium/tests/envs/test_compatibility.py

128 lines
3.8 KiB
Python
Raw Normal View History

from typing import Any, Dict, Optional, Tuple
import numpy as np
import gymnasium
2022-09-08 10:10:07 +01:00
from gymnasium.spaces import Discrete
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv
class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
    """Minimal old-API environment that derives from ``LegacyEnv`` directly.

    ``reset`` returns a bare observation, ``step`` returns a 4-tuple and
    ``render`` takes a ``mode`` argument, i.e. the signatures the
    compatibility wrapper is meant to adapt.
    """

    observation_space = Discrete(1)
    action_space = Discrete(1)
    metadata = {"render.modes": ["human", "rgb_array"]}

    def __init__(self):
        """Nothing to initialise; the spaces are class attributes."""

    def seed(self, seed=None):
        """Accept a seed and ignore it."""

    def reset(self):
        """Return the single possible observation, without an info dict."""
        return 0

    def step(self, action):
        """Return a fixed old-style ``(obs, reward, done, info)`` tuple."""
        return 0, 0, False, {}

    def render(self, mode="human"):
        """Return an all-zero 1x1 RGB frame for ``"rgb_array"``, else ``None``."""
        if mode != "rgb_array":
            # "human" (and any other mode) produces no frame.
            return None
        return np.zeros((1, 1, 3), dtype=np.uint8)

    def close(self):
        """No resources to release."""
class LegacyEnvImplicit(gymnasium.Env):
    """Minimal old-API environment that does NOT derive from ``LegacyEnv``.

    It only provides methods with the old signatures; ``test_implicit``
    checks that it still passes ``isinstance(..., LegacyEnv)``.
    """

    observation_space = Discrete(1)
    action_space = Discrete(1)
    metadata = {"render.modes": ["human", "rgb_array"]}

    def __init__(self):
        """Nothing to initialise; the spaces are class attributes."""

    def seed(self, seed: Optional[int] = None):
        """Accept a seed and ignore it."""

    def reset(self):  # type: ignore
        """Return the single possible observation, without an info dict."""
        return 0  # type: ignore

    def step(self, action: Any) -> Tuple[int, float, bool, Dict]:
        """Return a fixed old-style ``(obs, reward, done, info)`` tuple."""
        return 0, 0.0, False, {}

    def render(self, mode: Optional[str] = "human") -> Any:
        """Return an all-zero 1x1 RGB frame for ``"rgb_array"``, else ``None``."""
        if mode != "rgb_array":
            # "human" (and any other mode) produces no frame.
            return None
        return np.zeros((1, 1, 3), dtype=np.uint8)

    def close(self):
        """No resources to release."""
def test_explicit():
    """Wrap ``LegacyEnvExplicit`` in ``EnvCompatibility`` and check the new-style results."""
    legacy = LegacyEnvExplicit()
    assert isinstance(legacy, LegacyEnv)

    env = EnvCompatibility(legacy, render_mode="rgb_array")
    assert env.observation_space == Discrete(1)
    assert env.action_space == Discrete(1)

    # reset is expected to yield (obs, info), with or without seed/options.
    plain_reset = env.reset()
    assert plain_reset == (0, {})
    seeded_reset = env.reset(seed=0, options={"some": "option"})
    assert seeded_reset == (0, {})

    # step is expected to yield the 5-tuple form.
    transition = env.step(0)
    assert transition == (0, 0, False, False, {})

    frame = env.render()
    assert frame.shape == (1, 1, 3)
    env.close()
def test_implicit():
    """Wrap ``LegacyEnvImplicit`` in ``EnvCompatibility`` and check the new-style results."""
    legacy = LegacyEnvImplicit()
    # The class does not inherit from LegacyEnv, yet must still be recognised as one.
    assert isinstance(legacy, LegacyEnv)

    env = EnvCompatibility(legacy, render_mode="rgb_array")
    assert env.observation_space == Discrete(1)
    assert env.action_space == Discrete(1)

    # reset is expected to yield (obs, info), with or without seed/options.
    plain_reset = env.reset()
    assert plain_reset == (0, {})
    seeded_reset = env.reset(seed=0, options={"some": "option"})
    assert seeded_reset == (0, {})

    # step is expected to yield the 5-tuple form.
    transition = env.step(0)
    assert transition == (0, 0, False, False, {})

    frame = env.render()
    assert frame.shape == (1, 1, 3)
    env.close()
def test_make_compatibility_in_spec():
    """Register a legacy env with ``apply_api_compatibility=True`` and check ``make``.

    The environment is closed and the registry entry removed in ``finally``
    blocks, so a failing assertion cannot leave ``"LegacyTestEnv-v0"``
    registered for later tests.
    """
    env_id = "LegacyTestEnv-v0"
    gymnasium.register(
        id=env_id,
        entry_point=LegacyEnvExplicit,
        apply_api_compatibility=True,
    )
    try:
        env = gymnasium.make(env_id, render_mode="rgb_array")
        try:
            assert env.observation_space == Discrete(1)
            assert env.action_space == Discrete(1)
            assert env.reset() == (0, {})
            assert env.reset(seed=0, options={"some": "option"}) == (0, {})
            assert env.step(0) == (0, 0, False, False, {})
            img = env.render()
            assert isinstance(img, np.ndarray)
            assert img.shape == (1, 1, 3)  # type: ignore
        finally:
            env.close()
    finally:
        # Always undo the registration, even when an assertion above failed.
        del gymnasium.envs.registration.registry[env_id]
def test_make_compatibility_in_make():
    """Register a plain legacy env and pass ``apply_api_compatibility=True`` to ``make``.

    The environment is closed and the registry entry removed in ``finally``
    blocks, so a failing assertion cannot leave ``"LegacyTestEnv-v0"``
    registered for later tests.
    """
    env_id = "LegacyTestEnv-v0"
    gymnasium.register(id=env_id, entry_point=LegacyEnvExplicit)
    try:
        env = gymnasium.make(
            env_id, apply_api_compatibility=True, render_mode="rgb_array"
        )
        try:
            assert env.observation_space == Discrete(1)
            assert env.action_space == Discrete(1)
            assert env.reset() == (0, {})
            assert env.reset(seed=0, options={"some": "option"}) == (0, {})
            assert env.step(0) == (0, 0, False, False, {})
            img = env.render()
            assert isinstance(img, np.ndarray)
            assert img.shape == (1, 1, 3)  # type: ignore
        finally:
            env.close()
    finally:
        # Always undo the registration, even when an assertion above failed.
        del gymnasium.envs.registration.registry[env_id]