mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
61 lines
2.4 KiB
Python
61 lines
2.4 KiB
Python
"""Test suite for DiscretizeObservation wrapper."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gymnasium.spaces import Box, Discrete
|
|
from gymnasium.wrappers import DiscretizeObservation
|
|
from tests.testing_env import GenericTestEnv
|
|
|
|
|
|
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
|
|
def test_discretize_observation_space_uniformity(dimensions):
|
|
"""Tests that the Box observation space is discretized uniformly."""
|
|
env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
|
|
n_bins = 7
|
|
env = DiscretizeObservation(env, n_bins)
|
|
env_obs = np.meshgrid(*(np.linspace(0, 99, n_bins) for _ in range(dimensions)))
|
|
env_obs = np.concatenate([o.flatten()[None] for o in env_obs], 0).T
|
|
env_obs_discretized = np.sort([env.observation(e) for e in env_obs])
|
|
assert env_obs.shape[0] == env.observation_space.n
|
|
assert np.all(env_obs_discretized == np.arange(env.observation_space.n))
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"dimensions, bins, multidiscrete",
|
|
[
|
|
(1, 3, False),
|
|
(2, (3, 4), False),
|
|
(3, (3, 4, 5), False),
|
|
(1, 3, True),
|
|
(2, (3, 4), True),
|
|
(3, (3, 4, 5), True),
|
|
],
|
|
)
|
|
def test_revert_discretize_observation_space(dimensions, bins, multidiscrete):
|
|
"""Tests that the observation is discretized correctly within the bins."""
|
|
env = GenericTestEnv(observation_space=Box(0, 99, shape=(dimensions,)))
|
|
env_discrete = DiscretizeObservation(env, bins, multidiscrete)
|
|
for i in range(1000):
|
|
obs, _ = env.reset(seed=i)
|
|
obs_discrete, _ = env_discrete.reset(seed=i)
|
|
obs_reverted_min, obs_reverted_max = env_discrete.revert_observation(
|
|
obs_discrete,
|
|
)
|
|
assert np.all(obs >= obs_reverted_min) and np.all(obs <= obs_reverted_max)
|
|
|
|
|
|
@pytest.mark.parametrize("high, low", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
|
|
def test_discretize_observation_bounds(high, low):
|
|
"""Tests the discretize observation wrapper with spaces that should raise an error."""
|
|
with pytest.raises((ValueError,)):
|
|
DiscretizeObservation(
|
|
GenericTestEnv(observation_space=Box(low, high, shape=(1,)))
|
|
)
|
|
|
|
|
|
def test_discretize_observation_dtype():
|
|
"""Tests the discretize observation wrapper with spaces that should raise an error."""
|
|
with pytest.raises((TypeError,)):
|
|
DiscretizeObservation(GenericTestEnv(observation_space=Discrete(10)))
|