Files
Gymnasium/tests/wrappers/test_resize_observation.py
Ariel Kwiatkowski 947b857bd4 Test refactoring (#2427)
* Move tests to root with automatic PyCharm import refactoring. This will likely fail some tests

* Changed entry point for a registration test env.

* Move a stray lunar_lander test to tests/envs/...

* black

* Change the version from which importlib_metadata is replaced with importlib.metadata. This also requires installing importlib_metadata for Python 3.8 now.

???????????

* Undo last commit
2021-09-28 19:53:30 -04:00

25 lines
712 B
Python

import pytest
import gym
from gym.wrappers import ResizeObservation
# Skip this whole test module when the "gym.envs.atari" module cannot be
# imported; both parametrized environment ids below are Atari games.
pytest.importorskip("gym.envs.atari")
@pytest.mark.parametrize(
    "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
    """Check that ResizeObservation reshapes the space and the observations.

    Args:
        env_id: Id of the environment to build with ``gym.make``.
        shape: Target size handed to the wrapper. An int is expected to
            yield a square ``(shape, shape)`` result; a tuple or list is
            expected to be used as the first two dimensions verbatim.

    The test asserts that the wrapped observation space keeps a trailing
    dimension of 3, and that both the space and the observation returned by
    ``reset()`` have the requested leading two dimensions.
    """
    # Normalise the parametrized value once so a single pair of assertions
    # covers both the int and the sequence cases.
    expected_dims = (shape, shape) if isinstance(shape, int) else tuple(shape)

    env = gym.make(env_id)
    try:
        env = ResizeObservation(env, shape)
        assert env.observation_space.shape[-1] == 3

        obs = env.reset()
        assert env.observation_space.shape[:2] == expected_dims
        assert obs.shape == expected_dims + (3,)
    finally:
        # Always release the environment, even when an assertion fails, so
        # the parametrized cases do not accumulate open environments.
        env.close()