Files
Gymnasium/tests/envs/registration/test_env_spec.py
2024-06-10 17:07:47 +01:00

243 lines
7.1 KiB
Python

"""Test for the `EnvSpec`, in particular, a full integration with `EnvSpec`."""
from __future__ import annotations
import re
from typing import Any
import dill as pickle
import pytest
import gymnasium as gym
from gymnasium.core import ObsType
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils.env_checker import check_env, data_equivalence
def test_full_integration():
# Create an environment to test with
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
# Generate the spec_stack
env_spec = env.spec
assert isinstance(env_spec, EnvSpec)
# additional_wrappers = (TimeAwareObservation, NormalizeReward)
assert len(env_spec.additional_wrappers) == 2
# env_spec.pprint()
# Serialize the spec_stack
env_spec_json = env_spec.to_json()
assert isinstance(env_spec_json, str)
# Deserialize the spec_stack
recreate_env_spec = EnvSpec.from_json(env_spec_json)
# recreate_env_spec.pprint()
assert env_spec.additional_wrappers == recreate_env_spec.additional_wrappers
assert recreate_env_spec == env_spec
# Recreate the environment using the spec_stack
recreated_env = gym.make(recreate_env_spec)
assert recreated_env.render_mode == "rgb_array"
assert isinstance(recreated_env, gym.wrappers.NormalizeReward)
assert recreated_env.gamma == 0.8
assert isinstance(recreated_env.env, gym.wrappers.TimeAwareObservation)
assert isinstance(recreated_env.unwrapped, CartPoleEnv)
obs, info = env.reset(seed=42)
recreated_obs, recreated_info = recreated_env.reset(seed=42)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(info, recreated_info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
(
recreated_obs,
recreated_reward,
recreated_terminated,
recreated_truncated,
recreated_info,
) = recreated_env.step(action)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(reward, recreated_reward)
assert data_equivalence(terminated, recreated_terminated)
assert data_equivalence(truncated, recreated_truncated)
assert data_equivalence(info, recreated_info)
# Test the pprint of the spec_stack
spec_stack_output = env_spec.pprint(disable_print=True)
json_spec_stack_output = env_spec.pprint(disable_print=True)
assert spec_stack_output == json_spec_stack_output
@pytest.mark.parametrize(
"env_spec",
[
gym.spec("CartPole-v1"),
gym.make("CartPole-v1").unwrapped.spec,
gym.make("CartPole-v1").spec,
gym.wrappers.NormalizeReward(gym.make("CartPole-v1")).spec,
],
)
def test_env_spec_to_from_json(env_spec: EnvSpec):
json_spec = env_spec.to_json()
recreated_env_spec = EnvSpec.from_json(json_spec)
assert env_spec == recreated_env_spec
def test_pickling_env_stack():
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.FlattenObservation(env)
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
pickled_env = pickle.loads(pickle.dumps(env))
obs, info = env.reset(seed=123)
pickled_obs, pickled_info = pickled_env.reset(seed=123)
assert data_equivalence(obs, pickled_obs)
assert data_equivalence(info, pickled_info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
(
pickled_obs,
pickled_reward,
pickled_terminated,
pickled_truncated,
pickled_info,
) = pickled_env.step(action)
assert data_equivalence(obs, pickled_obs)
assert data_equivalence(reward, pickled_reward)
assert data_equivalence(terminated, pickled_terminated)
assert data_equivalence(truncated, pickled_truncated)
assert data_equivalence(info, pickled_info)
env.close()
pickled_env.close()
# flake8: noqa
def test_env_spec_pprint():
env = gym.make("CartPole-v1")
env = gym.wrappers.TimeAwareObservation(env)
env_spec = env.spec
assert env_spec is not None
output = env_spec.pprint(disable_print=True)
assert (
output
== """id=CartPole-v1
reward_threshold=475.0
max_episode_steps=500
additional_wrappers=[
name=TimeAwareObservation, kwargs={'flatten': True, 'normalize_time': False, 'dict_time_key': 'time'}
]"""
)
output = env_spec.pprint(disable_print=True, include_entry_points=True)
assert (
output
== """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
max_episode_steps=500
additional_wrappers=[
name=TimeAwareObservation, entry_point=gymnasium.wrappers.stateful_observation:TimeAwareObservation, kwargs={'flatten': True, 'normalize_time': False, 'dict_time_key': 'time'}
]"""
)
output = env_spec.pprint(disable_print=True, print_all=True)
assert (
output
== """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
nondeterministic=False
max_episode_steps=500
order_enforce=True
disable_env_checker=False
additional_wrappers=[
name=TimeAwareObservation, kwargs={'flatten': True, 'normalize_time': False, 'dict_time_key': 'time'}
]"""
)
env_spec.additional_wrappers = ()
output = env_spec.pprint(disable_print=True)
assert (
output
== """id=CartPole-v1
reward_threshold=475.0
max_episode_steps=500"""
)
output = env_spec.pprint(disable_print=True, print_all=True)
assert (
output
== """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
nondeterministic=False
max_episode_steps=500
order_enforce=True
disable_env_checker=False
additional_wrappers=[]"""
)
class Unpickleable:
def __getstate__(self):
raise RuntimeError("Cannot pickle me!")
class EnvWithUnpickleableObj(gym.Env):
def __init__(self, unpickleable_obj):
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Discrete(2)
self.unpickleable_obj = unpickleable_obj
def step(self, action):
return self.observation_space.sample(), 0, False, False, {}
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
super().reset(seed=seed, options=options)
if seed is not None:
self.observation_space.seed(seed)
return self.observation_space.sample(), {}
def test_spec_with_unpickleable_object():
gym.register(
id="TestEnv-v0",
entry_point=EnvWithUnpickleableObj,
kwargs={},
)
env = gym.make("TestEnv-v0", unpickleable_obj=Unpickleable())
with pytest.warns(
UserWarning,
match=re.escape(
"An exception occurred (Cannot pickle me!) while copying the environment spec="
),
):
env.spec
check_env(env, skip_render_check=True)
env.close()
del gym.registry["TestEnv-v0"]