Files
Gymnasium/tests/wrappers/test_frame_stack.py
Andrea PIERRÉ e913bc81b8 Improve pre-commit workflow (#2602)
* feat: add `isort` to `pre-commit`

* ci: skip `__init__.py` file for `isort`

* ci: make `isort` mandatory in lint pipeline

* docs: add a section on Git hooks

* ci: check isort diff

* fix: isort from master branch

* docs: add pre-commit badge

* ci: update black + bandit versions

* feat: add PR template

* refactor: PR template

* ci: remove bandit

* docs: add Black badge

* ci: try to remove all `|| true` statements

* ci: remove lint_python job

- Remove `lint_python` CI job
- Move `pyupgrade` job to `pre-commit` workflow

* fix: avoid messing with typing

* docs: add a note on running `pre-commit` manually

* ci: apply `pre-commit` to the whole codebase
2022-03-31 15:50:38 -04:00

50 lines
1.2 KiB
Python

from collections import deque

import pytest

pytest.importorskip("gym.envs.atari")

import numpy as np

import gym
from gym.wrappers import FrameStack

try:
    import lz4
except ImportError:
    lz4 = None
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "Pong-v0"])
@pytest.mark.parametrize("num_stack", [2, 3, 4])
@pytest.mark.parametrize(
    "lz4_compress",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                lz4 is None, reason="Need lz4 to run tests with compression"
            ),
        ),
        False,
    ],
)
def test_frame_stack(env_id, num_stack, lz4_compress):
    """Check that ``FrameStack`` stacks observations in lockstep with a twin env.

    A ``FrameStack``-wrapped env and an identically seeded, unwrapped ``dup``
    env receive the same action sequence. After the reset and after every step,
    the stacked observation must equal the ``num_stack`` most recent
    observations of ``dup`` (oldest first). The first observation after a reset
    fills every slot of the stack. The wrapper must also leave ``done``
    unchanged.

    Args:
        env_id: Registered Gym environment id to build both envs from.
        num_stack: Number of frames ``FrameStack`` keeps.
        lz4_compress: Whether ``FrameStack`` stores its frames lz4-compressed.
    """
    env = gym.make(env_id)
    shape = env.observation_space.shape
    env = FrameStack(env, num_stack, lz4_compress)
    dup = gym.make(env_id)
    try:
        assert env.observation_space.shape == (num_stack,) + shape
        assert env.observation_space.dtype == env.env.observation_space.dtype

        obs = env.reset(seed=0)
        dup_obs = dup.reset(seed=0)
        # Expected stack contents: the reset observation repeated num_stack
        # times, then rolled forward one observation per step.
        history = deque([dup_obs] * num_stack, maxlen=num_stack)
        assert len(obs) == num_stack
        assert np.allclose(np.asarray(obs), np.stack(list(history)))

        for _ in range(num_stack**2):
            action = env.action_space.sample()
            dup_obs, _, dup_done, _ = dup.step(action)
            obs, _, done, _ = env.step(action)
            history.append(dup_obs)

            assert done == dup_done
            assert len(obs) == num_stack
            # Comparing the whole stack (not just obs[-1]) also catches
            # ordering and padding bugs.
            assert np.allclose(np.asarray(obs), np.stack(list(history)))

            if done:
                # Stepping a finished episode without reset is undefined in
                # the Gym API; stop instead of relying on it.
                break
    finally:
        env.close()
        dup.close()