Files
Gymnasium/gym/wrappers/test_frame_stack.py
2021-07-29 15:39:42 -04:00

51 lines
1.4 KiB
Python

import pytest
pytest.importorskip("atari_py")
import numpy as np
import gym
from gym.wrappers import FrameStack
try:
import lz4
except ImportError:
lz4 = None
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0", "Pong-v0"])
@pytest.mark.parametrize("num_stack", [2, 3, 4])
@pytest.mark.parametrize(
"lz4_compress",
[
pytest.param(
True,
marks=pytest.mark.skipif(
lz4 is None, reason="Need lz4 to run tests with compression"
),
),
False,
],
)
def test_frame_stack(env_id, num_stack, lz4_compress):
env = gym.make(env_id)
shape = env.observation_space.shape
env = FrameStack(env, num_stack, lz4_compress)
assert env.observation_space.shape == (num_stack,) + shape
assert env.observation_space.dtype == env.env.observation_space.dtype
obs = env.reset()
obs = np.asarray(obs)
assert obs.shape == (num_stack,) + shape
for i in range(1, num_stack):
assert np.allclose(obs[i - 1], obs[i])
obs, _, _, _ = env.step(env.action_space.sample())
obs = np.asarray(obs)
assert obs.shape == (num_stack,) + shape
for i in range(1, num_stack - 1):
assert np.allclose(obs[i - 1], obs[i])
assert not np.allclose(obs[-1], obs[-2])
obs, _, _, _ = env.step(env.action_space.sample())
assert len(obs) == num_stack