re-enable and fix atari preprocessing wrappers pixel test (#1652)

This commit is contained in:
pzhokhov
2019-08-23 15:45:55 -07:00
committed by GitHub
parent a0264410f2
commit fd5314780d
2 changed files with 17 additions and 17 deletions

View File

@@ -94,8 +94,7 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
# NoopReset # NoopReset
self.env.reset(**kwargs) self.env.reset(**kwargs)
noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
assert noops > 0
for _ in range(noops): for _ in range(noops):
_, _, done, _ = self.env.step(0) _, _, done, _ = self.env.step(0)
if done: if done:

View File

@@ -1,24 +1,25 @@
import numpy as np import numpy as np
import gym import gym
import pytest
from gym.wrappers import AtariPreprocessing from gym.wrappers import AtariPreprocessing
try: import pytest
import atari_py pytest.importorskip('atari_py')
except ImportError:
atari_py = None
@pytest.mark.skipif(atari_py is None, reason='Only run this test when atari_py is installed')
def test_atari_preprocessing(): def test_atari_preprocessing():
import cv2 import cv2
env1 = gym.make('PongNoFrameskip-v0') env_fn = lambda: gym.make('PongNoFrameskip-v4')
env2 = AtariPreprocessing(env1, screen_size=84, grayscale_obs=True) env1 = env_fn()
env3 = AtariPreprocessing(env1, screen_size=84, grayscale_obs=False) env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0)
obs1 = env1.reset() env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0)
assert obs1.shape == (210, 160, 3) env1.reset()
# take these steps to imitate actions of FireReset logic
env1.step(1)
obs1 = env1.step(2)[0]
obs2 = env2.reset() obs2 = env2.reset()
assert obs2.shape == (84, 84)
obs3 = env3.reset() obs3 = env3.reset()
assert obs1.shape == (210, 160, 3)
assert obs2.shape == (84, 84)
assert obs3.shape == (84, 84, 3) assert obs3.shape == (84, 84, 3)
# TODO peterz - figure out why assertions below are faliing and fix np.testing.assert_allclose(obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA))
# np.testing.assert_allclose(obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA)) obs3_gray = cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY)
# np.testing.assert_allclose(obs2, cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY)) # the edges of the numbers do not render quite the same in the grayscale, so we ignore them
np.testing.assert_allclose(obs2[10:], obs3_gray[10:])