Files
Gymnasium/tests/wrappers/test_gray_scale_observation.py

27 lines
914 B
Python
Raw Normal View History

import pytest
import gym
from gym import spaces
2022-07-23 15:38:52 +01:00
from gym.wrappers import GrayScaleObservation
2021-07-29 02:26:34 +02:00
2022-07-23 15:38:52 +01:00
@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim):
    """Check the observation space and output of ``GrayScaleObservation``.

    The unwrapped environment must expose a 3-dimensional ``Box`` whose last
    axis has size 3.  After wrapping, the observation space must be a ``Box``
    that is either 3-dimensional with a trailing axis of size 1
    (``keep_dim=True``) or 2-dimensional (``keep_dim=False``), and the
    observation returned by ``reset()`` must belong to that space.

    Args:
        env_id: Id of the environment to build with ``gym.make``.
        keep_dim: Forwarded to ``GrayScaleObservation``; selects whether the
            trailing channel axis is retained.
    """
    rgb_env = gym.make(env_id, disable_env_checker=True)
    try:
        # Preconditions on the raw environment: a 3-D Box with 3 channels last.
        assert isinstance(rgb_env.observation_space, spaces.Box)
        assert len(rgb_env.observation_space.shape) == 3
        assert rgb_env.observation_space.shape[-1] == 3

        wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
        assert isinstance(wrapped_env.observation_space, spaces.Box)
        if keep_dim:
            assert len(wrapped_env.observation_space.shape) == 3
            assert wrapped_env.observation_space.shape[-1] == 1
        else:
            assert len(wrapped_env.observation_space.shape) == 2

        # NOTE(review): this relies on reset() returning the bare observation;
        # confirm against the gym version this test suite is pinned to.
        wrapped_obs = wrapped_env.reset()
        assert wrapped_obs in wrapped_env.observation_space
    finally:
        # Close the base env even when an assertion fails (or the wrapper
        # could not be constructed) so its resources are always released.
        rgb_env.close()