Files
Gymnasium/tests/wrappers/nested_dict_test.py
Mark Towers bf093c6890 Update the flake8 pre-commit ignores (#2778)
* Remove additional ignores from flake8

* Remove all unused imports

* Remove all unused imports

* Update flake8 and pyupgrade

* F841, removed unused variables

* E731, removed lambda assignment to variables

* Remove E731, F403, F405, F524

* Remove E722, bare exceptions

* Remove E712, compare variable == True or == False to is True or is False

* Remove E402, module level import not at top of file

* Added --per-file-ignores

* Add --per-file-ignores removing E741, E302 and E704

* Add E741, do not use variables named ‘l’, ‘O’, or ‘I’ to ignore issues in classic control

* Fixed issues for pytest==6.2

* Remove unnecessary # noqa

* Edit comment with the removal of E302

* Added warnings and declared module, attr for pyright type hinting

* Remove unused import

* Removed flake8 E302

* Updated flake8 from 3.9.2 to 4.0.1

* Remove unused variable
2022-04-26 11:18:37 -04:00

121 lines
3.9 KiB
Python

"""Tests for the filter observation wrapper."""
from typing import Optional
import numpy as np
import pytest
import gym
from gym.spaces import Box, Dict, Tuple
from gym.wrappers import FilterObservation, FlattenObservation
class FakeEnvironment(gym.Env):
    """Stub environment that emits random samples from a supplied observation space.

    It exists only so observation wrappers can be exercised against arbitrary,
    possibly nested, observation spaces.
    """

    def __init__(self, observation_space):
        """Store ``observation_space`` and expose its top-level keys as ``obs_keys``.

        ``observation_space`` must provide a ``spaces`` mapping (e.g. a ``Dict``
        space). ``obs_keys`` is that mapping's ``keys()`` view.
        """
        self.observation_space = observation_space
        self.obs_keys = self.observation_space.spaces.keys()
        self.action_space = Box(shape=(1,), low=-1, high=1, dtype=np.float32)

    def render(self, width=32, height=32, *args, **kwargs):
        """Return an all-zero ``uint8`` image of shape ``(height, width, 3)``.

        Any additional positional or keyword arguments are accepted and ignored.
        """
        del args, kwargs
        frame_shape = (height, width, 3)
        return np.zeros(frame_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Seed the base environment and return a freshly sampled observation.

        ``options`` is accepted for API compatibility but is not used.
        """
        super().reset(seed=seed)
        return self.observation_space.sample()

    def step(self, action):
        """Ignore ``action`` and return ``(observation, 0.0, False, {})``.

        The observation is a new random sample from the observation space.
        """
        del action
        next_observation = self.observation_space.sample()
        return next_observation, 0.0, False, {}
def _unit_box(shape):
    """Build a fresh float32 ``Box`` of the given ``shape`` with bounds [-1, 1]."""
    return Box(shape=shape, low=-1, high=1, dtype=np.float32)


# Pairs of (observation space, expected flattened shape). In every case the
# expected size equals the total number of scalar elements across all leaf
# ``Box`` spaces, however deeply they are nested in ``Dict`` / ``Tuple`` spaces.
NESTED_DICT_TEST_CASES = (
    # Dict nested inside a Dict: 2 + (2 + 2) = 6.
    (
        Dict(
            {
                "key1": _unit_box((2,)),
                "key2": Dict(
                    {
                        "subkey1": _unit_box((2,)),
                        "subkey2": _unit_box((2,)),
                    }
                ),
            }
        ),
        (6,),
    ),
    # Flat Dict mixing a 2-D, a scalar and a 1-D Box: 2*3 + 1 + 2 = 9.
    (
        Dict(
            {
                "key1": _unit_box((2, 3)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (9,),
    ),
    # Two-element Tuple inside a Dict: (2 + 2) + 1 + 2 = 7.
    (
        Dict(
            {
                "key1": Tuple((_unit_box((2,)), _unit_box((2,)))),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (7,),
    ),
    # Single-element Tuple inside a Dict: 2 + 1 + 2 = 5.
    (
        Dict(
            {
                "key1": Tuple((_unit_box((2,)),)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (5,),
    ),
    # Dict inside a Tuple inside a Dict: 2 + 1 + 2 = 5.
    (
        Dict(
            {
                "key1": Tuple((Dict({"key9": _unit_box((2,))}),)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (5,),
    ),
)
class TestNestedDictWrapper:
    """Check that ``FlattenObservation`` handles nested ``Dict``/``Tuple`` spaces.

    Each test wraps a ``FakeEnvironment`` in ``FilterObservation`` (keeping every
    top-level key of its observation space) and then in ``FlattenObservation``.
    """

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_size(self, observation_space, flat_shape):
        """The flattened observation space has the expected shape and float32 dtype."""
        env = FakeEnvironment(observation_space=observation_space)
        # Make sure we are testing the right environment for the test.
        assert isinstance(env.observation_space, Dict)

        wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
        assert wrapped_env.observation_space.shape == flat_shape
        assert wrapped_env.observation_space.dtype == np.float32

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_ravel(self, observation_space, flat_shape):
        """Observations returned by ``reset`` are flattened to the expected shape."""
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
        obs = wrapped_env.reset()

        assert obs.shape == wrapped_env.observation_space.shape
        # Compare against the case's expected shape directly as well, so that a
        # wrongly-sized flattened space cannot mask a wrongly-sized observation.
        assert obs.shape == flat_shape
        # The flattened sample must also lie inside the flattened space.
        assert wrapped_env.observation_space.contains(obs)