Files
Gymnasium/tests/wrappers/test_discretize_action.py

57 lines
2.3 KiB
Python
Raw Permalink Normal View History

"""Test suite for DiscretizeAction wrapper."""
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import DiscretizeAction
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize("dimensions", [1, 2, 3, 5])
def test_discretize_action_space_uniformity(dimensions):
    """Check that an evenly spaced grid of Box actions covers every discrete action once.

    A grid with ``n_bins`` evenly spaced values per dimension is built over the
    ``[0, 99]`` Box. Each grid point is mapped back through ``revert_action``;
    the sorted results must be exactly ``0 .. action_space.n - 1``.
    """
    n_bins = 7
    base_env = GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,)))
    wrapped_env = DiscretizeAction(base_env, n_bins)

    # One evenly spaced axis per action dimension, expanded into a full grid.
    axis_values = [np.linspace(0, 99, n_bins) for _ in range(dimensions)]
    grids = np.meshgrid(*axis_values)

    # Rows are grid points, columns are action dimensions.
    grid_points = np.vstack([grid.flatten() for grid in grids]).T

    reverted = np.sort([wrapped_env.revert_action(point) for point in grid_points])
    n_actions = wrapped_env.action_space.n

    assert grid_points.shape[0] == n_actions
    assert np.all(reverted == np.arange(n_actions))
@pytest.mark.parametrize(
    "dimensions, bins, multidiscrete",
    [
        # Same cases, same order: every (dimensions, bins) pair first without
        # and then with the multidiscrete flag.
        (n_dims, n_bins, use_multidiscrete)
        for use_multidiscrete in (False, True)
        for n_dims, n_bins in ((1, 3), (2, (3, 4)), (3, (3, 4, 5)))
    ],
)
def test_revert_discretize_action_space(dimensions, bins, multidiscrete):
    """Check that discrete -> continuous -> discrete returns the sampled action.

    Each sampled discrete action is converted with ``action``; the result must
    lie inside the wrapped env's Box space, and ``revert_action`` must map it
    back to the action that was sampled.
    """
    base_env = GenericTestEnv(action_space=Box(0, 99, shape=(dimensions,)))
    wrapped_env = DiscretizeAction(base_env, bins, multidiscrete)

    for _ in range(1000):
        sampled = wrapped_env.action_space.sample()
        continuous = wrapped_env.action(sampled)
        assert base_env.action_space.contains(continuous)

        recovered = wrapped_env.revert_action(continuous)
        assert np.all(recovered == sampled)
@pytest.mark.parametrize("low, high", [(0, np.inf), (-np.inf, np.inf), (-np.inf, 0)])
def test_discretize_action_bounds(low, high):
    """Tests that wrapping a Box with an infinite bound raises a ValueError.

    Args:
        low: Lower bound of the Box action space (may be ``-np.inf``).
        high: Upper bound of the Box action space (may be ``np.inf``).
    """
    # Build the env outside ``pytest.raises`` so that an error raised while
    # constructing the Box or the env cannot satisfy the assertion below.
    env = GenericTestEnv(action_space=Box(low, high, shape=(1,)))

    # ``bins`` is passed explicitly, as in the other tests of this module, so
    # the only thing left to fail is the wrapper's handling of the bounds.
    with pytest.raises(ValueError):
        DiscretizeAction(env, 3)
def test_discretize_action_dtype():
    """Tests that wrapping an env whose action space is not a Box raises a TypeError."""
    # Build the env outside ``pytest.raises`` so only the wrapper construction
    # is under test.
    env = GenericTestEnv(action_space=Discrete(10))

    # ``bins`` is passed explicitly, as in the other tests of this module, so
    # the TypeError cannot come from a missing constructor argument.
    with pytest.raises(TypeError):
        DiscretizeAction(env, 3)