2019-05-10 23:59:32 +02:00
|
|
|
import numpy as np
|
|
|
|
import gym
|
2019-05-10 17:53:05 -07:00
|
|
|
from gym.wrappers import AtariPreprocessing
|
2019-08-23 15:45:55 -07:00
|
|
|
import pytest
|
|
|
|
# Skip every test in this module when the atari_py backend cannot be imported.
pytest.importorskip(modname='atari_py')
|
2019-05-10 23:59:32 +02:00
|
|
|
|
|
|
|
def test_atari_preprocessing():
    """Check AtariPreprocessing observation shapes and pixel contents on Pong.

    A raw ``PongNoFrameskip-v4`` env is stepped by hand to imitate the
    wrapper's FireReset logic, and its (210, 160, 3) frame is compared with the
    reset observations of two AtariPreprocessing wrappers: an 84x84 grayscale
    one and an 84x84 RGB one. The RGB observation must equal an INTER_AREA
    resize of the raw frame, and the grayscale observation must equal the RGB
    observation converted to gray, ignoring the first 10 rows.
    """
    import contextlib

    # The expected frames are computed with OpenCV; skip (mirroring the
    # module-level atari_py guard) instead of erroring when it is missing.
    cv2 = pytest.importorskip('cv2')

    with contextlib.ExitStack() as stack:

        def env_fn():
            env = gym.make('PongNoFrameskip-v4')
            # Register the close immediately so every env created so far is
            # released even if a later construction step or assertion fails.
            stack.callback(env.close)
            return env

        env1 = env_fn()
        env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True,
                                  frame_skip=1, noop_max=0)
        env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False,
                                  frame_skip=1, noop_max=0)

        env1.reset()
        # Take these steps to imitate the actions of the FireReset logic.
        env1.step(1)
        obs1 = env1.step(2)[0]

        obs2 = env2.reset()
        obs3 = env3.reset()

        assert obs1.shape == (210, 160, 3)
        assert obs2.shape == (84, 84)
        assert obs3.shape == (84, 84, 3)

        np.testing.assert_allclose(
            obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA))

        obs3_gray = cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY)
        # The edges of the numbers do not render quite the same in the
        # grayscale observation, so the first 10 rows are ignored.
        np.testing.assert_allclose(obs2[10:], obs3_gray[10:])
|