mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 23:12:46 +00:00
re-enable and fix atari preprocessing wrappers pixel test (#1652)
This commit is contained in:
@@ -94,8 +94,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
# NoopReset
|
# NoopReset
|
||||||
self.env.reset(**kwargs)
|
self.env.reset(**kwargs)
|
||||||
noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
|
noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
|
||||||
assert noops > 0
|
|
||||||
for _ in range(noops):
|
for _ in range(noops):
|
||||||
_, _, done, _ = self.env.step(0)
|
_, _, done, _ = self.env.step(0)
|
||||||
if done:
|
if done:
|
||||||
|
@@ -1,24 +1,25 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
import pytest
|
|
||||||
from gym.wrappers import AtariPreprocessing
|
from gym.wrappers import AtariPreprocessing
|
||||||
try:
|
import pytest
|
||||||
import atari_py
|
pytest.importorskip('atari_py')
|
||||||
except ImportError:
|
|
||||||
atari_py = None
|
|
||||||
|
|
||||||
@pytest.mark.skipif(atari_py is None, reason='Only run this test when atari_py is installed')
|
|
||||||
def test_atari_preprocessing():
|
def test_atari_preprocessing():
|
||||||
import cv2
|
import cv2
|
||||||
env1 = gym.make('PongNoFrameskip-v0')
|
env_fn = lambda: gym.make('PongNoFrameskip-v4')
|
||||||
env2 = AtariPreprocessing(env1, screen_size=84, grayscale_obs=True)
|
env1 = env_fn()
|
||||||
env3 = AtariPreprocessing(env1, screen_size=84, grayscale_obs=False)
|
env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0)
|
||||||
obs1 = env1.reset()
|
env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0)
|
||||||
assert obs1.shape == (210, 160, 3)
|
env1.reset()
|
||||||
|
# take these steps to imitate actions of FireReset logic
|
||||||
|
env1.step(1)
|
||||||
|
obs1 = env1.step(2)[0]
|
||||||
obs2 = env2.reset()
|
obs2 = env2.reset()
|
||||||
assert obs2.shape == (84, 84)
|
|
||||||
obs3 = env3.reset()
|
obs3 = env3.reset()
|
||||||
|
assert obs1.shape == (210, 160, 3)
|
||||||
|
assert obs2.shape == (84, 84)
|
||||||
assert obs3.shape == (84, 84, 3)
|
assert obs3.shape == (84, 84, 3)
|
||||||
# TODO peterz - figure out why assertions below are faliing and fix
|
np.testing.assert_allclose(obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA))
|
||||||
# np.testing.assert_allclose(obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA))
|
obs3_gray = cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY)
|
||||||
# np.testing.assert_allclose(obs2, cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY))
|
# the edges of the numbers do not render quite the same in the grayscale, so we ignore them
|
||||||
|
np.testing.assert_allclose(obs2[10:], obs3_gray[10:])
|
||||||
|
Reference in New Issue
Block a user