Files
Gymnasium/tests/wrappers/flatten_test.py
Ariel Kwiatkowski 947b857bd4 Test refactoring (#2427)
* Move tests to root with automatic PyCharm import refactoring. This will likely fail some tests

* Changed entry point for a registration test env.

* Move a stray lunar_lander test to tests/envs/...

* black

* Change the version from which importlib_metadata is replaced with importlib.metadata. Also require installing importlib_metadata for Python 3.8.

???????????

* Undo last commit
2021-09-28 19:53:30 -04:00

97 lines
3.1 KiB
Python

"""Tests for the flatten observation wrapper."""
from collections import OrderedDict
import numpy as np
import pytest
import gym
from gym.spaces import Box, Dict, unflatten, flatten
from gym.wrappers import FlattenObservation
class FakeEnvironment(gym.Env):
    """Stub environment whose only behaviour is sampling observations.

    ``reset`` draws a fresh sample from the configured observation space,
    remembers it as ``self.observation`` so a test can compare it with what
    a wrapper returned, and returns that same sample.
    """

    def __init__(self, observation_space):
        self.observation_space = observation_space

    def reset(self):
        sample = self.observation_space.sample()
        # Keep a handle on the raw (unwrapped) observation for later checks.
        self.observation = sample
        return sample
# Parametrisation table shared by the tests below. Each entry is a pair
# ``(observation_space, ordered_values)``:
#
# * The first two spaces wrap an ``OrderedDict`` of constant boxes
#   (``low == high``) whose constants are 0, 1, 2 in insertion order, so
#   flattening in insertion order yields an already-sorted vector. The two
#   entries use different key orders, so a flattening that followed
#   alphabetical key order instead of insertion order would be caught.
# * The last space wraps a plain ``dict`` of boxes bounded by [-1, 1]; no
#   ordering property is asserted for it (``ordered_values`` is False).
#
# Keyword arguments keep their order when passed to ``OrderedDict`` (PEP 468),
# so the key order below is exactly the insertion order.
OBSERVATION_SPACES = (
    (
        Dict(
            OrderedDict(
                key1=Box(low=0, high=0, shape=(2, 3), dtype=np.float32),
                key2=Box(low=1, high=1, shape=(), dtype=np.float32),
                key3=Box(low=2, high=2, shape=(2,), dtype=np.float32),
            )
        ),
        True,
    ),
    (
        Dict(
            OrderedDict(
                key2=Box(low=0, high=0, shape=(), dtype=np.float32),
                key3=Box(low=1, high=1, shape=(2,), dtype=np.float32),
                key1=Box(low=2, high=2, shape=(2, 3), dtype=np.float32),
            )
        ),
        True,
    ),
    (
        Dict(
            dict(
                key1=Box(low=-1, high=1, shape=(2, 3), dtype=np.float32),
                key2=Box(low=-1, high=1, shape=(), dtype=np.float32),
                key3=Box(low=-1, high=1, shape=(2,), dtype=np.float32),
            )
        ),
        False,
    ),
)
class TestFlattenEnvironment:
    """Round-trip and ordering tests for ``FlattenObservation`` and the
    ``flatten``/``unflatten`` helpers, parametrised over ``OBSERVATION_SPACES``.
    """

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flattened_environment(self, observation_space, ordered_values):
        """Check observations produced through the ``FlattenObservation`` wrapper.

        The flat observation returned by the wrapped ``reset`` must unflatten
        back to the raw sample recorded by ``FakeEnvironment``, and (for the
        ordered spaces) its values must appear in key insertion order.
        """
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(env)
        flattened = wrapped_env.reset()
        # Unflatten against the inner env's space: that is the original
        # structured space the observation was sampled from.
        unflattened = unflatten(env.observation_space, flattened)
        original = env.observation
        self._check_observations(original, flattened, unflattened, ordered_values)

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flatten_unflatten(self, observation_space, ordered_values):
        """Check ``flatten`` and ``unflatten`` directly, without a wrapper."""
        original = observation_space.sample()
        flattened = flatten(observation_space, original)
        unflattened = unflatten(observation_space, flattened)
        self._check_observations(original, flattened, unflattened, ordered_values)

    @staticmethod
    def _check_observations(original, flattened, unflattened, ordered_values):
        """Assert that flattening and unflattening ``original`` is lossless.

        Args:
            original: the raw dict observation that was flattened.
            flattened: the flat array produced from ``original``.
            unflattened: the result of unflattening ``flattened``.
            ordered_values: when true, the space's constant values increase in
                key insertion order (see ``OBSERVATION_SPACES``), so key order
                and the sortedness of ``flattened`` are checked as well.
        """
        # A flattened observation must be a one-dimensional vector.
        assert np.ndim(flattened) == 1

        # unflatten(flatten(original)) == original
        assert set(unflattened.keys()) == set(original.keys())
        for key, value in original.items():
            np.testing.assert_allclose(unflattened[key], value)

        if ordered_values:
            # A set comparison ignores ordering; the ordered cases exist
            # precisely to check it, so compare the key sequences too.
            assert list(unflattened.keys()) == list(original.keys())
            # Values must be flattened in insertion order; with increasing
            # per-key constants that means the flat vector is already sorted.
            np.testing.assert_allclose(np.sort(flattened), flattened)