Files
Gymnasium/tests/wrappers/test_atari_preprocessing.py
2022-12-05 19:14:56 +00:00

127 lines
4.0 KiB
Python

import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility
from tests.testing_env import GenericTestEnv, old_step_func
class AleTesting:
    """Stand-in for the ``ale`` object exposed by ale-py Atari environments.

    Every screen query overwrites the caller's buffer with a random sample from
    a fixed, seeded ``Box`` space. Tests built on it can therefore check shapes
    and dtypes, but not real frame contents.
    """

    # Sample spaces are class attributes, so they are shared by every instance:
    # a 210x160 grayscale screen and a 210x160x3 RGB screen, uint8 in [0, 255].
    grayscale_obs_space = Box(low=0, high=255, shape=(210, 160), dtype=np.uint8, seed=1)
    rgb_obs_space = Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1)

    @staticmethod
    def _fill(buffer, space):
        """Overwrite every element of ``buffer`` in place with one sample of ``space``."""
        frame = space.sample()
        buffer[...] = frame

    def lives(self) -> int:
        """Report the remaining number of lives; this stub always has exactly one."""
        return 1

    def getScreenGrayscale(self, buffer: np.ndarray):
        """Write a random grayscale frame into ``buffer`` in place."""
        self._fill(buffer, self.grayscale_obs_space)

    def getScreenRGB(self, buffer: np.ndarray):
        """Write a random RGB frame into ``buffer`` in place."""
        self._fill(buffer, self.rgb_obs_space)
class AtariTestingEnv(GenericTestEnv):
    """Minimal replica of an ale-py Atari environment for wrapper tests.

    Observations are 210x160x3 uint8 frames and there are three discrete
    actions. ``step`` is driven by ``old_step_func`` (the old step API, which
    is why the tests wrap it in ``StepAPICompatibility``). An ``AleTesting``
    stub is exposed as ``self.ale``.
    """

    # First entry must be "NOOP" (see ``get_action_meanings``).
    _ACTION_MEANINGS = ("NOOP", "UP", "DOWN")

    def __init__(self):
        frame_space = Box(
            low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1
        )
        actions = Discrete(3, seed=1)
        super().__init__(
            observation_space=frame_space,
            action_space=actions,
            step_func=old_step_func,
        )
        self.ale = AleTesting()

    def get_action_meanings(self):
        """Return a fresh list naming each action; index 0 is always 'NOOP'."""
        return list(self._ACTION_MEANINGS)
def _make_preprocessed_env(**overrides):
    """Wrap a fresh ``AtariTestingEnv`` in ``AtariPreprocessing`` (84px, no frame skip, no noops)."""
    return AtariPreprocessing(
        StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True),
        screen_size=84,
        frame_skip=1,
        noop_max=0,
        **overrides,
    )


@pytest.mark.parametrize(
    "env, obs_shape",
    [
        (AtariTestingEnv(), (210, 160, 3)),
        (_make_preprocessed_env(grayscale_obs=True), (84, 84)),
        (_make_preprocessed_env(grayscale_obs=False), (84, 84, 3)),
        (
            _make_preprocessed_env(grayscale_obs=True, grayscale_newaxis=True),
            (84, 84, 1),
        ),
    ],
)
def test_atari_preprocessing_grayscale(env, obs_shape):
    """Check the observation-space shape and that reset/step observations lie in that space."""
    assert env.observation_space.shape == obs_shape

    # The observations are random samples rather than real frames, so only
    # membership in the observation space can be checked, not the pixel values.
    # todo: update when ale-py is compatible with the ci
    # The compatibility wrapper is applied because ale-py uses the old step API.
    wrapped = StepAPICompatibility(env, output_truncation_bool=True)

    reset_obs, _ = wrapped.reset(seed=0)
    assert reset_obs in wrapped.observation_space

    step_obs, _reward, _terminated, _truncated, _info = wrapped.step(
        wrapped.action_space.sample()
    )
    assert step_obs in wrapped.observation_space

    wrapped.close()
@pytest.mark.parametrize("grayscale", [True, False])
@pytest.mark.parametrize("scaled", [True, False])
def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
    """Check that preprocessed observations stay within the expected value range.

    With ``scale_obs=True`` every element must lie in ``[0, 1]``; otherwise it
    must lie in ``[0, 255]``. The check covers the reset observation and the
    observations from at most ``max_test_steps`` calls to ``step``.

    Args:
        grayscale: forwarded to ``AtariPreprocessing(grayscale_obs=...)``.
        scaled: forwarded to ``AtariPreprocessing(scale_obs=...)``.
        max_test_steps: arbitrarily chosen upper bound on the number of
            ``step`` calls used to sample observations.
    """
    env = AtariPreprocessing(
        StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True),
        screen_size=84,
        grayscale_obs=grayscale,
        scale_obs=scaled,
        frame_skip=1,
        noop_max=0,
    )
    max_obs = 1 if scaled else 255

    # try/finally so the env is closed even when an assertion fails.
    try:
        obs, _ = env.reset()
        assert np.all(0 <= obs) and np.all(obs <= max_obs)

        terminated, truncated, step_i = False, False, 0
        # `<` (previously `<=`, which took max_test_steps + 1 steps) so that
        # at most `max_test_steps` steps are taken.
        while not (terminated or truncated) and step_i < max_test_steps:
            obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
            assert np.all(0 <= obs) and np.all(obs <= max_obs)
            step_i += 1
    finally:
        env.close()