2019-08-23 23:04:11 +02:00
|
|
|
import pytest
|
|
|
|
pytest.importorskip("atari_py")
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import gym
|
|
|
|
from gym.wrappers import FrameStack
|
|
|
|
# lz4 is an optional dependency: ``lz4`` stays ``None`` when it cannot be
# imported, and the compressed-frame test cases check for that.
lz4 = None
try:
    import lz4
except ImportError:
    pass
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('env_id', ['CartPole-v1', 'Pendulum-v0', 'Pong-v0'])
@pytest.mark.parametrize('num_stack', [2, 3, 4])
@pytest.mark.parametrize('lz4_compress', [
    pytest.param(True, marks=pytest.mark.skipif(lz4 is None, reason="Need lz4 to run tests with compression")),
    False
])
def test_frame_stack(env_id, num_stack, lz4_compress):
    """Check that FrameStack stacks and shifts frames, with and without lz4.

    - After ``reset`` every slot of the stack holds the same initial frame.
    - After one ``step`` only the newest slot differs from the others.
    - After a second ``step`` the previously newest frame has moved one
      slot toward the front.

    The underlying environment is closed in a ``finally`` block so that it
    is released even when an assertion fails. This presumably matters most
    for the Atari-backed ``Pong-v0``.
    """
    base_env = gym.make(env_id)
    try:
        shape = base_env.observation_space.shape
        env = FrameStack(base_env, num_stack, lz4_compress=lz4_compress)
        assert env.observation_space.shape == (num_stack,) + shape
        assert env.observation_space.dtype == env.env.observation_space.dtype

        # reset: the stack is filled with copies of the initial observation.
        obs = np.asarray(env.reset())
        assert obs.shape == (num_stack,) + shape
        for i in range(1, num_stack):
            assert np.allclose(obs[i - 1], obs[i])

        # First step: all but the newest slot still hold the initial frame.
        obs, _, _, _ = env.step(env.action_space.sample())
        obs = np.asarray(obs)
        assert obs.shape == (num_stack,) + shape
        for i in range(1, num_stack - 1):
            assert np.allclose(obs[i - 1], obs[i])
        assert not np.allclose(obs[-1], obs[-2])

        # Second step: the frame that was newest must shift back one slot.
        # Copy it so the check cannot depend on how the array was produced.
        prev_newest = obs[-1].copy()
        obs, _, _, _ = env.step(env.action_space.sample())
        assert len(obs) == num_stack
        obs = np.asarray(obs)
        assert np.allclose(obs[-2], prev_newest)
    finally:
        base_env.close()
|