2022-03-31 12:50:38 -07:00
|
|
|
import pytest
|
2019-08-23 23:47:07 +02:00
|
|
|
|
2022-09-08 10:10:07 +01:00
|
|
|
import gymnasium
|
|
|
|
from gymnasium import spaces
|
|
|
|
from gymnasium.wrappers import GrayScaleObservation
|
2019-08-23 23:47:07 +02:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2022-07-23 15:38:52 +01:00
|
|
|
@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim):
    """Check that ``GrayScaleObservation`` reduces an RGB ``Box`` space to one channel.

    The wrapped observation space must still be a ``Box``. With
    ``keep_dim=True`` it keeps a trailing channel axis of size 1; otherwise
    the channel axis is dropped and the space is 2-D. The observation
    returned by ``reset`` must lie inside the new space.

    Args:
        env_id: Gymnasium id of an environment with RGB image observations.
        keep_dim: Forwarded to ``GrayScaleObservation``.
    """
    rgb_env = gymnasium.make(env_id, disable_env_checker=True)
    try:
        # Precondition: the source environment emits 3-D frames whose last
        # axis has 3 channels.
        rgb_space = rgb_env.observation_space
        assert isinstance(rgb_space, spaces.Box)
        assert len(rgb_space.shape) == 3
        assert rgb_space.shape[-1] == 3

        wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
        gray_space = wrapped_env.observation_space
        assert isinstance(gray_space, spaces.Box)
        if keep_dim:
            assert len(gray_space.shape) == 3
            assert gray_space.shape[-1] == 1
        else:
            assert len(gray_space.shape) == 2

        # Converting to grayscale must not change the spatial dimensions.
        assert gray_space.shape[:2] == rgb_space.shape[:2]

        wrapped_obs, _ = wrapped_env.reset()
        assert wrapped_obs in gray_space
    finally:
        # Release the environment even when an assertion above fails.
        rgb_env.close()
|