Files
Gymnasium/tests/wrappers/test_resize_observation.py

27 lines
818 B
Python
Raw Normal View History

import pytest
import gym
from gym import spaces
from gym.wrappers import ResizeObservation
# Every test in this module builds Atari environments; skip the whole module
# (rather than erroring) when the Atari environment package is not importable.
pytest.importorskip("gym.envs.atari")
@pytest.mark.parametrize(
    "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
    """Check that ResizeObservation resizes both the space and real observations.

    ``shape`` is either a single int, meaning a square ``(shape, shape)``
    target, or a 2-element sequence giving the first two target dimensions.
    In both cases the trailing channel dimension (3) must be preserved.
    """
    # A bare int means a square target; any sequence is compared as a tuple
    # so that list and tuple parametrizations are checked identically.
    expected = (shape, shape) if isinstance(shape, int) else tuple(shape)

    env = gym.make(env_id, disable_env_checker=True)
    try:
        # Rebinding inside the try means the finally closes the wrapper when
        # wrapping succeeds, and still closes the base env if it raises.
        env = ResizeObservation(env, shape)

        assert isinstance(env.observation_space, spaces.Box)
        assert env.observation_space.shape[-1] == 3
        assert env.observation_space.shape[:2] == expected

        # NOTE(review): assumes reset() returns only the observation, as in the
        # gym version this file targets; newer APIs return (obs, info).
        obs = env.reset()
        assert obs.shape == expected + (3,)
    finally:
        # Release emulator resources even when an assertion fails.
        env.close()