Files
Gymnasium/tests/wrappers/flatten_test.py
Ariel Kwiatkowski 925823661d Add options to the signature of env.reset (#2515)
* First find/replace, now tests

* Fixes to the vector env

* Make seed keyword only in wrappers

* (try to) fix the bug with old environments using new wrappers (with the seed keyword)

* black

* Change **kwargs to options, try to make it work; black

* Add OrderEnforcing wrapper to wrapper exports
Add a test for compatibility with old (pybullet-like) envs

* Add OrderEnforcing wrapper to wrapper exports
Add a test for compatibility with old (pybullet-like) envs
black

* Update the env checker

* Update the env checker

* Update the env checker to use inspect (might fail tests, let's see)

* Allow the signature to include kwargs in env_checker

* Minor fix
2022-01-19 17:28:59 -05:00

99 lines
3.2 KiB
Python

"""Tests for the flatten observation wrapper."""
from collections import OrderedDict
from typing import Optional
import numpy as np
import pytest
import gym
from gym.spaces import Box, Dict, unflatten, flatten
from gym.wrappers import FlattenObservation
class FakeEnvironment(gym.Env):
    """Stub environment whose observations are random samples of a given space.

    The observation returned by the latest ``reset`` call is also stored on
    ``self.observation`` so tests can compare it with what a wrapper emitted.
    """

    def __init__(self, observation_space):
        self.observation_space = observation_space

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Forward ``seed`` to ``gym.Env.reset``, then sample and remember an observation.

        ``options`` is accepted only to match the ``reset`` signature; it is ignored.
        """
        super().reset(seed=seed)
        sample = self.observation_space.sample()
        self.observation = sample
        return sample
# Each entry is ``(observation_space, ordered_values)``.
# In the entries flagged True, every Box is pinned to a single value
# (low == high), and those values increase in insertion order (0, 1, 2).
# A flatten that preserves key order therefore yields a non-decreasing vector.
# The last entry is built from a plain dict of Boxes with random values,
# so only the round-trip is checked for it.
OBSERVATION_SPACES = (
    (
        Dict(
            OrderedDict(
                key1=Box(low=0, high=0, shape=(2, 3), dtype=np.float32),
                key2=Box(low=1, high=1, shape=(), dtype=np.float32),
                key3=Box(low=2, high=2, shape=(2,), dtype=np.float32),
            )
        ),
        True,
    ),
    (
        Dict(
            OrderedDict(
                key2=Box(low=0, high=0, shape=(), dtype=np.float32),
                key3=Box(low=1, high=1, shape=(2,), dtype=np.float32),
                key1=Box(low=2, high=2, shape=(2, 3), dtype=np.float32),
            )
        ),
        True,
    ),
    (
        Dict(
            {
                "key1": Box(low=-1, high=1, shape=(2, 3), dtype=np.float32),
                "key2": Box(low=-1, high=1, shape=(), dtype=np.float32),
                "key3": Box(low=-1, high=1, shape=(2,), dtype=np.float32),
            }
        ),
        False,
    ),
)
class TestFlattenEnvironment:
    """Round-trip and ordering checks for flattening Dict observation spaces."""

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flattened_environment(self, observation_space, ordered_values):
        """
        make sure that flattened observations occur in the order expected
        """
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(env)
        flattened = wrapped_env.reset()
        # Unflatten against the *inner* env's space, which still describes the
        # original (un-flattened) structure.
        unflattened = unflatten(env.observation_space, flattened)
        original = env.observation
        self._check_observations(original, flattened, unflattened, ordered_values)

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flatten_unflatten(self, observation_space, ordered_values):
        """
        test flatten and unflatten functions directly
        """
        original = observation_space.sample()
        flattened = flatten(observation_space, original)
        unflattened = unflatten(observation_space, flattened)
        self._check_observations(original, flattened, unflattened, ordered_values)

    @staticmethod
    def _check_observations(original, flattened, unflattened, ordered_values):
        """Assert the invariants shared by both flatten tests.

        Args:
            original: the dict-like observation before flattening.
            flattened: the result of flattening ``original``.
            unflattened: the result of unflattening ``flattened``.
            ordered_values: when True, also require ``flattened`` to be
                non-decreasing. For the fixtures flagged True, whose pinned
                values rise in key-insertion order, this checks that values
                were concatenated in the Dict's key order.
        """
        # A flattened observation must actually be flat (a 1-D vector).
        assert np.ndim(flattened) == 1
        # make sure that unflatten(flatten(original)) == original
        assert set(unflattened.keys()) == set(original.keys())
        for key, value in original.items():
            np.testing.assert_allclose(unflattened[key], value)
        if ordered_values:
            # make sure that the values were flattened in the order they
            # appeared in the OrderedDict
            np.testing.assert_allclose(np.sort(flattened), flattened)