2019-08-23 23:04:11 +02:00
|
|
|
import pytest
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2021-09-22 17:11:21 -06:00
|
|
|
pytest.importorskip("gym.envs.atari")
|
2019-08-23 23:04:11 +02:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import gym
|
|
|
|
from gym.wrappers import FrameStack
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2019-08-23 23:04:11 +02:00
|
|
|
try:
|
|
|
|
import lz4
|
|
|
|
except ImportError:
|
|
|
|
lz4 = None
|
|
|
|
|
|
|
|
|
2021-09-25 20:00:28 +02:00
|
|
|
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "Pong-v0"])
@pytest.mark.parametrize("num_stack", [2, 3, 4])
@pytest.mark.parametrize(
    "lz4_compress",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                lz4 is None, reason="Need lz4 to run tests with compression"
            ),
        ),
        False,
    ],
)
def test_frame_stack(env_id, num_stack, lz4_compress):
    """Check that ``FrameStack`` keeps the last ``num_stack`` observations.

    A ``FrameStack``-wrapped env and an unwrapped duplicate of the same env id
    are reset with the same seed and driven with identical actions. The newest
    stacked frame must always equal the duplicate's current observation.

    Args:
        env_id: gym environment id to build.
        num_stack: number of frames the wrapper stacks.
        lz4_compress: whether ``FrameStack`` stores frames lz4-compressed.
    """
    base_env = gym.make(env_id)
    shape = base_env.observation_space.shape
    env = FrameStack(base_env, num_stack, lz4_compress)
    dup = gym.make(env_id)
    try:
        # The wrapper prepends a stack axis and keeps the wrapped env's dtype.
        assert env.observation_space.shape == (num_stack,) + shape
        assert env.observation_space.dtype == env.env.observation_space.dtype

        obs = env.reset(seed=0)
        dup_obs = dup.reset(seed=0)
        assert len(obs) == num_stack
        # On reset every stacked slot is filled with the first observation.
        for i in range(num_stack):
            assert np.allclose(obs[i], dup_obs)

        for _ in range(num_stack ** 2):
            action = env.action_space.sample()
            dup_obs, _, dup_done, _ = dup.step(action)
            obs, _, done, _ = env.step(action)
            assert np.allclose(obs[-1], dup_obs)
            assert done == dup_done
            if done:
                # Stepping a finished episode is undefined, so start a new one.
                # Both envs were seeded identically, so unseeded resets stay in
                # lockstep.
                obs = env.reset()
                dup_obs = dup.reset()
                assert np.allclose(obs[-1], dup_obs)

        assert len(obs) == num_stack
    finally:
        # Release simulator resources (e.g. the ALE instance behind Pong).
        # Closing the wrapper also closes base_env.
        env.close()
        dup.close()
|