Files
Gymnasium/tests/wrappers/test_atari_preprocessing.py

87 lines
2.3 KiB
Python
Raw Normal View History

"""Test suite for the AtariPreprocessing wrapper."""
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing
# Skip every test in this module when ``ale_py`` cannot be imported; the
# ``ALE/Pong-v5`` environment used by the tests below presumably depends on it.
# Must run before any test code touches an ALE environment.
pytest.importorskip("ale_py")
2022-07-23 15:38:52 +01:00
@pytest.mark.parametrize(
    "wrapper_kwargs, expected_obs_shape",
    [
        # ``None`` means the bare environment with no preprocessing wrapper.
        (None, (210, 160, 3)),
        (
            dict(screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0),
            (84, 84),
        ),
        (
            dict(screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0),
            (84, 84, 3),
        ),
        (
            dict(
                screen_size=84,
                grayscale_obs=True,
                frame_skip=1,
                noop_max=0,
                grayscale_newaxis=True,
            ),
            (84, 84, 1),
        ),
    ],
)
def test_atari_preprocessing_grayscale(wrapper_kwargs, expected_obs_shape):
    """Check the observation-space shape and that observations lie inside it.

    Args:
        wrapper_kwargs: Keyword arguments for ``AtariPreprocessing``, or
            ``None`` to test the unwrapped ``ALE/Pong-v5`` environment.
        expected_obs_shape: Expected ``env.observation_space.shape``.

    The environment is built inside the test rather than in the parametrize
    list. No emulator is then instantiated at collection time, and a failing
    constructor only fails its own case. It is always closed, even when an
    assertion fails.
    """
    env = gym.make("ALE/Pong-v5")
    try:
        # Wrapping happens inside ``try`` so the base env is still closed
        # if the wrapper constructor raises.
        if wrapper_kwargs is not None:
            env = AtariPreprocessing(env, **wrapper_kwargs)

        assert env.observation_space.shape == expected_obs_shape

        obs, _ = env.reset(seed=0)
        assert obs in env.observation_space

        obs, _, _, _, _ = env.step(env.action_space.sample())
        assert obs in env.observation_space
    finally:
        env.close()
@pytest.mark.parametrize("grayscale", [True, False])
@pytest.mark.parametrize("scaled", [True, False])
def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
    """Check that observations stay within the expected value range.

    With ``scale_obs=True`` every observation must lie in ``[0, 1]``;
    otherwise in ``[0, 255]``. The reset observation is checked, then up to
    ``max_test_steps`` random steps, stopping early if the episode ends.

    Args:
        grayscale: Passed to ``AtariPreprocessing`` as ``grayscale_obs``.
        scaled: Passed to ``AtariPreprocessing`` as ``scale_obs``.
        max_test_steps: Arbitrarily chosen upper bound on the number of
            environment steps to check.
    """
    env = AtariPreprocessing(
        gym.make("ALE/Pong-v5"),
        screen_size=84,
        grayscale_obs=grayscale,
        scale_obs=scaled,
        frame_skip=1,
        noop_max=0,
    )
    max_obs = 1 if scaled else 255
    try:
        # Seeded for reproducibility, matching test_atari_preprocessing_grayscale.
        obs, _ = env.reset(seed=0)
        assert np.all(0 <= obs) and np.all(obs <= max_obs)

        # At most ``max_test_steps`` steps (not ``max_test_steps + 1``).
        for _ in range(max_test_steps):
            obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
            assert np.all(0 <= obs) and np.all(obs <= max_obs)
            if terminated or truncated:
                break
    finally:
        env.close()