mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
97 lines
3.0 KiB
Python
97 lines
3.0 KiB
Python
"""Test suite for RenderObservation wrapper."""
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gymnasium import spaces
|
|
from gymnasium.wrappers import AddRenderObservation
|
|
from tests.testing_env import GenericTestEnv
|
|
|
|
|
|
# Key of the non-rendered ("state") observation entry in wrapped Dict observations.
# NOTE(review): presumably mirrors AddRenderObservation's default `obs_key` — confirm.
STATE_KEY: str = "state"
|
|
|
|
|
|
def image_render_func(self):
    """Render callback handed to ``GenericTestEnv`` via ``render_func``.

    ``self`` is accepted for the callback signature but never read
    (presumably it receives the environment instance — confirm in GenericTestEnv).

    Returns:
        An all-zero ``np.uint8`` frame of shape ``(32, 32, 3)``.
    """
    height = width = 32
    channels = 3
    frame = np.zeros((height, width, channels), dtype=np.uint8)
    return frame
|
|
|
|
|
|
@pytest.mark.parametrize("pixels_only", (True, False))
def test_dict_observation(pixels_only, pixel_key="rgb"):
    """Check AddRenderObservation on an environment with a Dict observation space.

    With ``render_only=True`` the wrapped observation is expected to be just the
    rendered frame (a ``Box`` space). Otherwise the frame is expected under
    ``pixel_key``, next to the original Dict entries. In both cases the rendered
    observation must agree with its space and with the frame from
    ``image_render_func``.

    ``pixel_key`` has a default value, so pytest does not treat it as a fixture.
    """
    env = GenericTestEnv(
        observation_space=spaces.Dict(
            state=spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
        ),
        render_mode="rgb_array",
        render_func=image_render_func,
    )
    # Make sure we are testing the right environment for the test.
    assert isinstance(env.observation_space, spaces.Dict)

    # The wrapper should only add one observation.
    wrapped_env = AddRenderObservation(
        env,
        render_key=pixel_key,
        render_only=pixels_only,
    )
    obs, _ = wrapped_env.reset()

    if pixels_only:
        assert isinstance(wrapped_env.observation_space, spaces.Box)
        assert isinstance(obs, np.ndarray)

        pixel_space = wrapped_env.observation_space
        rendered_obs = obs
    else:
        assert isinstance(wrapped_env.observation_space, spaces.Dict)

        # NOTE(review): this order holds because "rgb" sorts before "state";
        # presumably spaces.Dict sorts its keys — confirm if pixel_key changes.
        expected_keys = [pixel_key] + list(env.observation_space.spaces.keys())
        assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys

        assert isinstance(obs, dict)
        pixel_space = wrapped_env.observation_space[pixel_key]
        rendered_obs = obs[pixel_key]

    # Check that the added space item is consistent with the added observation,
    # and that both match the frame produced by the env's render function.
    expected_frame = image_render_func(env)
    assert rendered_obs.shape == pixel_space.shape == expected_frame.shape
    assert rendered_obs.dtype == pixel_space.dtype == np.uint8
|
|
|
|
|
|
@pytest.mark.parametrize("pixels_only", (True, False))
def test_single_array_observation(pixels_only):
    """Check AddRenderObservation on an environment with a single Box observation.

    With ``render_only=True`` the wrapped observation is expected to be just the
    rendered frame. Otherwise the wrapper is expected to build a Dict holding the
    frame under ``pixel_key`` and the original observation under ``STATE_KEY``.
    """
    pixel_key = "depth"

    env = GenericTestEnv(
        observation_space=spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32),
        render_mode="rgb_array",
        render_func=image_render_func,
    )
    assert isinstance(env.observation_space, spaces.Box)

    # The wrapper should only add one observation.
    wrapped_env = AddRenderObservation(
        env,
        render_key=pixel_key,
        render_only=pixels_only,
    )
    obs, _ = wrapped_env.reset()

    if pixels_only:
        assert isinstance(wrapped_env.observation_space, spaces.Box)
        assert isinstance(obs, np.ndarray)

        pixel_space = wrapped_env.observation_space
        rendered_obs = obs
    else:
        assert isinstance(wrapped_env.observation_space, spaces.Dict)

        # NOTE(review): assumes the wrapper's default obs_key equals STATE_KEY and
        # that spaces.Dict orders keys so "depth" precedes "state" — confirm.
        expected_keys = [pixel_key, STATE_KEY]
        assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys

        assert isinstance(obs, dict)
        # The original observation should be forwarded under the state key.
        assert obs[STATE_KEY] in env.observation_space

        pixel_space = wrapped_env.observation_space[pixel_key]
        rendered_obs = obs[pixel_key]

    # Check that the added space item is consistent with the added observation,
    # and that both match the frame produced by the env's render function.
    expected_frame = image_render_func(env)
    assert rendered_obs.shape == pixel_space.shape == expected_frame.shape
    assert rendered_obs.dtype == pixel_space.dtype == np.uint8
|