Files
Gymnasium/tests/wrappers/test_discretize_observation.py

61 lines
2.4 KiB
Python
Raw Permalink Normal View History

"""Test suite for DiscretizeObservation wrapper."""
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import DiscretizeObservation
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_observation_space_uniformity(dimensions):
"""Tests that the Box observation space is discretized uniformly."""
env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
n_bins = 7
env = DiscretizeObservation(env, n_bins)
env_obs = np.meshgrid(*(np.linspace(0, 99, n_bins) for _ in range(dimensions)))
env_obs = np.concatenate([o.flatten()[None] for o in env_obs], 0).T
env_obs_discretized = np.sort([env.observation(e) for e in env_obs])
assert env_obs.shape[0] == env.observation_space.n
assert np.all(env_obs_discretized == np.arange(env.observation_space.n))
@pytest.mark.parametrize(
"dimensions, bins, multidiscrete",
[
(1, 3, False),
(2, (3, 4), False),
(3, (3, 4, 5), False),
(1, 3, True),
(2, (3, 4), True),
(3, (3, 4, 5), True),
],
)
def test_revert_discretize_observation_space(dimensions, bins, multidiscrete):
"""Tests that the observation is discretized correctly within the bins."""
env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
env_discrete = DiscretizeObservation(env, bins, multidiscrete)
for i in range(1000):
obs, _ = env.reset(seed=i)
obs_discrete, _ = env_discrete.reset(seed=i)
obs_reverted_min, obs_reverted_max = env_discrete.revert_observation(
obs_discrete,
)
assert np.all(obs >= obs_reverted_min) and np.all(obs <= obs_reverted_max)
@pytest.mark.parametrize("high, low", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_observation_bounds(high, low):
"""Tests the discretize observation wrapper with spaces that should raise an error."""
with pytest.raises((ValueError,)):
DiscretizeObservation(
GenericTestEnv(observation_space=Box(low, high, shape=(1,)))
)
def test_discretize_observation_dtype():
"""Tests the discretize observation wrapper with spaces that should raise an error."""
with pytest.raises((TypeError,)):
DiscretizeObservation(GenericTestEnv(observation_space=Discrete(10)))