Files
Gymnasium/tests/wrappers/nested_dict_test.py
Ariel Kwiatkowski 925823661d Add options to the signature of env.reset (#2515)
* First find/replace, now tests

* Fixes to the vector env

* Make seed keyword only in wrappers

* (try to) fix the bug with old environments using new wrappers (with the seed keyword)

* black

* Change **kwargs to options, try to make it work; black

* Add OrderEnforcing wrapper to wrapper exports
Add a test for compatibility with old (pybullet-like) envs

* Add OrderEnforcing wrapper to wrapper exports
Add a test for compatibility with old (pybullet-like) envs
black

* Update the env checker

* Update the env checker

* Update the env checker to use inspect (might fail tests, let's see)

* Allow the signature to include kwargs in env_checker

* Minor fix
2022-01-19 17:28:59 -05:00

121 lines
3.9 KiB
Python

"""Tests for the filter observation wrapper."""
from typing import Optional
import pytest
import numpy as np
import gym
from gym.spaces import Dict, Box, Discrete, Tuple
from gym.wrappers import FilterObservation, FlattenObservation
class FakeEnvironment(gym.Env):
    """Stub environment whose observations are sampled from a supplied space.

    The observation space is expected to expose a ``spaces`` mapping (as a
    ``Dict`` space does); its keys are stored in ``obs_keys`` so callers can
    hand them to a filtering wrapper.
    """

    def __init__(self, observation_space):
        """Store the observation space, its keys and a 1-D float32 action space."""
        self.observation_space = observation_space
        self.obs_keys = self.observation_space.spaces.keys()
        self.action_space = Box(shape=(1,), low=-1, high=1, dtype=np.float32)

    def render(self, width=32, height=32, *args, **kwargs):
        """Return an all-zero uint8 image of shape ``(height, width, 3)``."""
        # Extra positional/keyword arguments are accepted but deliberately ignored.
        del args, kwargs
        return np.zeros((height, width, 3), dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Forward ``seed`` to the base class and return a sampled observation.

        ``options`` is accepted for signature compatibility and is not used.
        """
        super().reset(seed=seed)
        return self.observation_space.sample()

    def step(self, action):
        """Ignore ``action`` and return ``(observation, 0.0, False, {})``."""
        del action
        return self.observation_space.sample(), 0.0, False, {}
def _unit_box(shape):
    """Return a float32 ``Box`` with bounds -1 and 1 and the given shape."""
    return Box(shape=shape, low=-1, high=1, dtype=np.float32)


# Each entry pairs a (possibly nested) Dict observation space with the shape
# the tests expect from the flattened version of that space.
NESTED_DICT_TEST_CASES = (
    # Dict nested inside a Dict.
    (
        Dict(
            {
                "key1": _unit_box((2,)),
                "key2": Dict(
                    {
                        "subkey1": _unit_box((2,)),
                        "subkey2": _unit_box((2,)),
                    }
                ),
            }
        ),
        (6,),
    ),
    # Flat Dict mixing a 2-D box, a scalar box and a 1-D box.
    (
        Dict(
            {
                "key1": _unit_box((2, 3)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (9,),
    ),
    # Tuple of two boxes nested inside a Dict.
    (
        Dict(
            {
                "key1": Tuple((_unit_box((2,)), _unit_box((2,)))),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (7,),
    ),
    # Single-element Tuple nested inside a Dict.
    (
        Dict(
            {
                "key1": Tuple((_unit_box((2,)),)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (5,),
    ),
    # Dict inside a Tuple inside a Dict.
    (
        Dict(
            {
                "key1": Tuple((Dict({"key9": _unit_box((2,))}),)),
                "key2": _unit_box(()),
                "key3": _unit_box((2,)),
            }
        ),
        (5,),
    ),
)
class TestNestedDictWrapper:
    """Checks ``FlattenObservation`` stacked on ``FilterObservation`` for nested spaces."""

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_size(self, observation_space, flat_shape):
        """The wrapped observation space has the expected shape and float32 dtype."""
        base_env = FakeEnvironment(observation_space=observation_space)

        # Guard: the test is only meaningful when the env exposes a Dict space.
        assert isinstance(base_env.observation_space, Dict)

        filtered_env = FilterObservation(base_env, base_env.obs_keys)
        flat_space = FlattenObservation(filtered_env).observation_space
        assert flat_space.shape == flat_shape
        assert flat_space.dtype == np.float32

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_ravel(self, observation_space, flat_shape):
        """An observation from ``reset`` has the same shape as the wrapped space."""
        base_env = FakeEnvironment(observation_space=observation_space)
        filtered_env = FilterObservation(base_env, base_env.obs_keys)
        wrapped_env = FlattenObservation(filtered_env)

        first_obs = wrapped_env.reset()
        assert first_obs.shape == wrapped_env.observation_space.shape