Files
Gymnasium/tests/test_core.py
2025-01-17 22:46:45 +00:00

269 lines
8.7 KiB
Python

"""Checks that the core Gymnasium API is implemented as expected."""
from __future__ import annotations
import re
from typing import Any, SupportsFloat
import numpy as np
import pytest
import gymnasium as gym
from gymnasium import ActionWrapper, Env, ObservationWrapper, RewardWrapper, Wrapper
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.spaces import Box
from gymnasium.utils import seeding
from gymnasium.utils.seeding import np_random
from gymnasium.wrappers import OrderEnforcing
from tests.testing_env import GenericTestEnv
class ExampleEnv(Env):
    """Minimal concrete :class:`Env` used as the base environment in these tests.

    Both spaces are ``Box(0, 1)``; ``step`` and ``reset`` return fixed values.
    """

    def __init__(self):
        """Create the environment with ``Box(0, 1)`` observation and action spaces."""
        self.observation_space = Box(0, 1)
        self.action_space = Box(0, 1)

    def step(
        self, action: ActType
    ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and return a constant transition that never ends."""
        observation, reward = 0, 0
        terminated = truncated = False
        return observation, reward, terminated, truncated, {}

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[ObsType, dict]:
        """Delegate seeding to ``Env.reset`` and return a fixed observation with empty info."""
        super().reset(options=options, seed=seed)
        info: dict = {}
        return 0, info
@pytest.fixture
def example_env():
    """Provide a freshly constructed :class:`ExampleEnv` to each requesting test."""
    env = ExampleEnv()
    return env
def test_example_env(example_env):
    """Check the default attributes of a newly constructed environment."""
    env = example_env
    assert env.metadata == {"render_modes": []}
    assert env.render_mode is None
    assert env.spec is None
    # no RNG has been created yet on a fresh environment
    assert env._np_random is None  # pyright: ignore [reportPrivateUsage]
class ExampleWrapper(Wrapper):
    """Pass-through wrapper that replaces every step reward with ``new_reward``."""

    def __init__(self, env: Env[ObsType, ActType]):
        """Wrap ``env`` and set the constant reward that :meth:`step` reports."""
        super().__init__(env)
        self.new_reward = 3

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset via ``Wrapper.reset``, forwarding ``seed`` and ``options`` unchanged."""
        return super().reset(options=options, seed=seed)

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
        """Step the wrapped environment, discarding its reward in favour of ``new_reward``."""
        observation, _discarded_reward, terminated, truncated, info = self.env.step(
            action
        )
        return observation, self.new_reward, terminated, truncated, info

    def access_hidden_np_random(self):
        """Read the wrapper's own ``_np_random`` attribute.

        Wrappers are expected to use the unwrapped environment's RNG, so this
        access should raise ``AttributeError`` (exercised by ``test_example_wrapper``).
        """
        return self._np_random
def test_example_wrapper(example_env):
    """Check attribute forwarding and shadowing between a wrapper and its inner env."""
    base_env = example_env
    wrapped = ExampleWrapper(base_env)

    # metadata matches the inner env until the wrapper assigns its own value
    assert base_env.metadata == wrapped.metadata
    wrapped.metadata = {"render_modes": ["rgb_array"]}
    assert base_env.metadata != wrapped.metadata

    assert base_env.render_mode == wrapped.render_mode
    assert base_env.spec == wrapped.spec

    # spaces track the inner env until they are overridden on the wrapper
    base_env.observation_space = Box(0, 1)
    base_env.action_space = Box(0, 1)
    assert base_env.observation_space == wrapped.observation_space
    assert base_env.action_space == wrapped.action_space
    wrapped.observation_space = Box(1, 2)
    wrapped.action_space = Box(1, 2)
    assert base_env.observation_space != wrapped.observation_space
    assert base_env.action_space != wrapped.action_space

    # setting np_random on the wrapper must land on the inner env
    generator, _ = seeding.np_random()
    wrapped.np_random = generator
    assert (
        base_env._np_random  # pyright: ignore [reportPrivateUsage]
        is base_env.np_random
        is wrapped.np_random
    )
    assert 0 <= wrapped.np_random.uniform() <= 1

    expected_message = re.escape(
        "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
    )
    with pytest.raises(AttributeError, match=expected_message):
        _ = wrapped.access_hidden_np_random()
class ExampleRewardWrapper(RewardWrapper):
    """Reward wrapper that replaces every reward with the constant ``1``."""

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Ignore the incoming ``reward`` and return ``1``."""
        constant_reward = 1
        return constant_reward
class ExampleObservationWrapper(ObservationWrapper):
    """Observation wrapper that replaces every observation with ``[1]``."""

    def observation(self, observation: ObsType) -> ObsType:
        """Discard ``observation`` and return a new one-element array ``[1]``."""
        return np.asarray([1])
class ExampleActionWrapper(ActionWrapper):
    """Action wrapper that replaces every action with ``[1]``."""

    def action(self, action: ActType) -> ActType:
        """Discard ``action`` and return a new one-element array ``[1]``."""
        return np.asarray([1])
def test_reward_observation_action_wrapper():
    """Tests the observation, action and reward wrapper examples.

    Observations are compared with ``np.array_equal``, which also checks shape.
    A bare ``obs == np.array([1])`` broadcasts, so a scalar ``1`` or a ``[[1]]``
    result would wrongly pass.
    """
    expected = np.array([1])

    env = GenericTestEnv()
    reward_env = ExampleRewardWrapper(env)
    reward_env.reset()
    _, reward, _, _, _ = reward_env.step(0)
    assert reward == 1

    observation_env = ExampleObservationWrapper(env)
    obs, _ = observation_env.reset()
    assert np.array_equal(obs, expected)
    obs, _, _, _, _ = observation_env.step(0)
    assert np.array_equal(obs, expected)

    # this step function echoes the (wrapped) action back as the observation,
    # so the observation reveals what ExampleActionWrapper passed through
    env = GenericTestEnv(step_func=lambda self, action: (action, 0, False, False, {}))
    action_env = ExampleActionWrapper(env)
    obs, _, _, _, _ = action_env.step(0)
    assert np.array_equal(obs, expected)
def test_get_set_wrapper_attr():
    """Test ``has_wrapper_attr``, ``get_wrapper_attr`` and ``set_wrapper_attr``.

    Uses a ``gym.make("CartPole-v1")`` stack and covers four kinds of attribute:
    one that lives only on the unwrapped env (``gravity``), one that lives only on
    an intermediate wrapper (``_disable_render_order_enforcing``), one set on the
    top-level wrapper, and one that does not exist before it is set.
    """
    env = gym.make("CartPole-v1")
    try:
        assert env is not env.unwrapped

        # Test get_wrapper_attr: ``gravity`` is not reachable via plain attribute access
        with pytest.raises(AttributeError):
            _ = env.gravity
        assert env.unwrapped.gravity is not None
        assert env.has_wrapper_attr("gravity")
        assert env.get_wrapper_attr("gravity") is not None

        with pytest.raises(AttributeError):
            _ = env.unknown_attr
        assert env.has_wrapper_attr("unknown_attr") is False
        with pytest.raises(AttributeError):
            env.get_wrapper_attr("unknown_attr")

        # Test set_wrapper_attr
        env.set_wrapper_attr("gravity", 10.0)
        with pytest.raises(AttributeError):
            _ = env.gravity  # checks the top level wrapper hasn't been updated
        assert env.unwrapped.gravity == 10.0
        assert env.get_wrapper_attr("gravity") == 10.0

        # a direct assignment on the top-level wrapper shadows the inner value
        env.gravity = 5.0
        assert env.gravity == 5.0
        assert env.get_wrapper_attr("gravity") == 5.0
        assert env.env.get_wrapper_attr("gravity") == 10.0

        # Test with OrderEnforcing (intermediate wrapper)
        assert not isinstance(env, OrderEnforcing)
        # show that the base and top level objects don't contain the attribute
        with pytest.raises(AttributeError):
            _ = env._disable_render_order_enforcing
        with pytest.raises(AttributeError):
            _ = env.unwrapped._disable_render_order_enforcing
        assert env.has_wrapper_attr("_disable_render_order_enforcing")
        assert env.get_wrapper_attr("_disable_render_order_enforcing") is False

        env.set_wrapper_attr("_disable_render_order_enforcing", True)
        with pytest.raises(AttributeError):
            _ = env._disable_render_order_enforcing
        with pytest.raises(AttributeError):
            _ = env.unwrapped._disable_render_order_enforcing
        assert env.get_wrapper_attr("_disable_render_order_enforcing") is True

        # Test with top-most wrapper
        env.MY_ATTRIBUTE_1 = True
        assert env.get_wrapper_attr("MY_ATTRIBUTE_1") is True
        env.set_wrapper_attr("MY_ATTRIBUTE_1", False)
        assert env.get_wrapper_attr("MY_ATTRIBUTE_1") is False

        # Test with non-existing attribute
        env.set_wrapper_attr("MY_ATTRIBUTE_2", True)
        assert env.MY_ATTRIBUTE_2 is True
    finally:
        # release the environment even if an assertion above fails
        env.close()
class TestRandomSeeding:
    """Tests for the ``np_random_seed`` property on environments and wrappers."""

    @staticmethod
    def test_nonempty_seed_retrieved_when_not_set(example_env):
        """Reading the seed before any explicit seeding still yields an ``int``."""
        assert example_env.np_random_seed is not None
        assert isinstance(example_env.np_random_seed, int)

    @staticmethod
    def test_seed_set_at_reset_and_retrieved(example_env):
        """A seed passed to ``reset`` is reported, and survives an unseeded reset."""
        fixed_seed = 42
        for reset_seed in (fixed_seed, None):
            example_env.reset(seed=reset_seed)
            # seed=None must leave the previously stored seed untouched
            assert example_env.np_random_seed == fixed_seed

    @staticmethod
    def test_seed_cannot_be_set_directly(example_env):
        """``np_random_seed`` is read-only."""
        with pytest.raises(AttributeError):
            example_env.np_random_seed = 42

    @staticmethod
    def test_negative_seed_retrieved_when_seed_unknown(example_env):
        """Assigning a generator directly reports the seed as ``-1`` (unknown)."""
        generator = np_random()[0]
        example_env.np_random = generator
        assert example_env.np_random_seed == -1

    @staticmethod
    def test_seeding_works_in_wrapped_envs(example_env):
        """Seed reporting behaves the same through a wrapper as on the base env."""
        fixed_seed = 42
        wrapped = ExampleWrapper(example_env)
        for reset_seed in (fixed_seed, None):
            wrapped.reset(seed=reset_seed)
            # seed=None must leave the previously stored seed untouched
            assert wrapped.np_random_seed == fixed_seed

        # assigning a generator directly makes the seed unknown
        generator = np_random()[0]
        wrapped.np_random = generator
        assert wrapped.np_random_seed == -1