2019-06-08 01:01:35 +02:00
|
|
|
import pytest
|
|
|
|
|
2022-09-16 23:41:27 +01:00
|
|
|
import gymnasium as gym
|
2022-09-08 10:10:07 +01:00
|
|
|
from gymnasium import spaces
|
|
|
|
from gymnasium.wrappers import ResizeObservation
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2019-06-08 01:01:35 +02:00
|
|
|
|
2022-07-23 15:38:52 +01:00
|
|
|
@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
    """Check that ``ResizeObservation`` resizes both the space and the observations.

    Args:
        env_id: Id of the environment to build with ``gym.make``.
        shape: Target size handed to the wrapper. The test expects an ``int``
            to yield a square ``(shape, shape)`` result and a 2-element
            sequence (tuple or list) to be used as given.
    """
    # An int is expected to expand to a square; sequences are normalised to a
    # tuple so they compare equal to the ``shape`` tuples below.
    if isinstance(shape, int):
        expected_hw = (shape, shape)
    else:
        expected_hw = tuple(shape)

    env = gym.make(env_id, disable_env_checker=True)
    try:
        env = ResizeObservation(env, shape)

        assert isinstance(env.observation_space, spaces.Box)
        # The trailing (channel) dimension must be left untouched by the resize.
        assert env.observation_space.shape[-1] == 3

        obs, _ = env.reset()

        assert env.observation_space.shape[:2] == expected_hw
        assert obs.shape == expected_hw + (3,)
    finally:
        # Always release the environment, even when an assertion fails, so
        # that parametrized runs do not leak resources from one case to the next.
        env.close()
|