Fix all warnings in tests/spaces (#1396)

This commit is contained in:
Martin Schuck
2025-06-08 00:55:12 +02:00
committed by GitHub
parent 433d7af1f9
commit 461f478db9
3 changed files with 81 additions and 62 deletions

View File

@@ -48,24 +48,25 @@ def test_dtype_errors(dtype, error, message):
Box(low=0, high=1, dtype=dtype)
@pytest.mark.parametrize(
"box, expected_shape",
[
# Test with same 1-dim low and high shape
(Box(low=np.zeros(2), high=np.ones(2)), (2,)),
# Test with same multi-dim low and high shape
(Box(low=np.zeros((2, 1)), high=np.ones((2, 1))), (2, 1)),
# Test with scalar low high and different shape
(Box(low=0, high=1, shape=(5, 2)), (5, 2)),
(Box(low=0, high=1), (1,)), # Test with int and int
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
(Box(low=0.0, high=1), (1,)), # Test with float and int
(Box(low=0, high=np.int32(1)), (1,)), # Test with python int and numpy int32
(Box(low=0, high=np.ones(3)), (3,)), # Test with array and scalar
(Box(low=np.zeros(3), high=1.0), (3,)), # Test with array and scalar
],
)
def _shape_inference_params():
# Test with same 1-dim low and high shape
yield Box(low=np.zeros(2), high=np.ones(2), dtype=np.float64), (2,)
# Test with same multi-dim low and high shape
yield Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.float64), (2, 1)
# Test with scalar low high and different shape
yield Box(low=0, high=1, shape=(5, 2)), (5, 2)
yield Box(low=0, high=1), (1,) # Test with int and int
yield Box(low=0.0, high=1.0), (1,) # Test with float and float
yield Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)
yield Box(low=0.0, high=1), (1,) # Test with float and int
# Test with python int and numpy int32
yield Box(low=0, high=np.int32(1)), (1,)
# Test with array and scalar
yield Box(low=0, high=np.ones(3), dtype=np.float64), (3,)
yield Box(low=np.zeros(3), high=1.0, dtype=np.float64), (3,)
@pytest.mark.parametrize("box, expected_shape", _shape_inference_params())
def test_shape_inference(box, expected_shape):
"""Test that the shape inference is as expected."""
assert box.shape == expected_shape
@@ -136,7 +137,7 @@ def test_shape_inference(box, expected_shape):
def test_shape_errors(low, high, shape, error_type, message):
"""Test errors due to shape mismatch."""
with pytest.raises(error_type, match=f"^{re.escape(message)}$"):
Box(low=low, high=high, shape=shape)
Box(low=low, high=high, shape=shape, dtype=np.float64)
@pytest.mark.parametrize(
@@ -260,6 +261,7 @@ def test_invalid_low_high(low, high, dtype):
)
def test_valid_low_high(low, high, dtype):
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", UserWarning)
space = Box(low=low, high=high, dtype=dtype)
assert space.dtype == dtype
assert space.low.dtype == dtype
@@ -294,24 +296,32 @@ def test_contains_dtype():
@pytest.mark.parametrize(
"lowhighshape",
"low, high, shape",
[
dict(low=0, high=np.inf, shape=(2,)),
dict(low=-np.inf, high=0, shape=(2,)),
dict(low=-np.inf, high=np.inf, shape=(2,)),
dict(low=0, high=np.inf, shape=(2, 3)),
dict(low=-np.inf, high=0, shape=(2, 3)),
dict(low=-np.inf, high=np.inf, shape=(2, 3)),
dict(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf])),
(0, np.inf, (2,)),
(-np.inf, 0, (2,)),
(-np.inf, np.inf, (2,)),
(0, np.inf, (2, 3)),
(-np.inf, 0, (2, 3)),
(-np.inf, np.inf, (2, 3)),
(np.array([-np.inf, 0]), np.array([0.0, np.inf]), None),
],
)
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_infinite_space(lowhighshape, dtype):
def test_infinite_space(low, high, shape, dtype):
"""
To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
are both modified within the init, we check for infinite when we know it's not 0
"""
space = Box(**lowhighshape, dtype=dtype)
# Emits a warning for lowering the last example
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", UserWarning)
space = Box(low=low, high=high, shape=shape, dtype=dtype)
# Check if only the expected precision warning is emitted
for warn in caught_warnings:
if "precision lowered by casting to float32" not in warn.message.args[0]:
raise Exception(warn)
assert np.all(space.low < space.high)

View File

@@ -32,31 +32,19 @@ CHECK_ENV_IGNORE_WARNINGS = [
]
@pytest.mark.parametrize(
"env",
[
gym.make("CartPole-v1", disable_env_checker=True).unwrapped,
gym.make("MountainCar-v0", disable_env_checker=True).unwrapped,
GenericTestEnv(
observation_space=spaces.Dict(
a=spaces.Discrete(10), b=spaces.Box(np.zeros(2), np.ones(2))
)
),
GenericTestEnv(
observation_space=spaces.Tuple(
[spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
)
),
GenericTestEnv(
observation_space=spaces.Dict(
a=spaces.Tuple(
[spaces.Discrete(10), spaces.Box(np.zeros(2), np.ones(2))]
),
b=spaces.Box(np.zeros(2), np.ones(2)),
)
),
],
)
def _no_error_warnings_envs():
yield gym.make("CartPole-v1", disable_env_checker=True).unwrapped
yield gym.make("MountainCar-v0", disable_env_checker=True).unwrapped
space_a = spaces.Discrete(10)
space_b = spaces.Box(np.zeros(2, np.float32), np.ones(2, np.float32))
yield GenericTestEnv(observation_space=spaces.Dict(a=space_a, b=space_b))
yield GenericTestEnv(observation_space=spaces.Tuple([space_a, space_b]))
yield GenericTestEnv(
observation_space=spaces.Dict(a=spaces.Tuple([space_a, space_b]), b=space_b)
)
@pytest.mark.parametrize("env", _no_error_warnings_envs())
def test_no_error_warnings(env):
"""A full version of this test with all gymnasium envs is run in tests/envs/test_envs.py."""
with warnings.catch_warnings(record=True) as caught_warnings:

View File

@@ -34,17 +34,25 @@ def _modify_space(space: spaces.Space, attribute: str, value):
# ===== Check box observation space ====
[
UserWarning,
spaces.Box(np.zeros(5), np.zeros(5)),
spaces.Box(np.zeros(5, np.float32), np.zeros(5, np.float32)),
"A Box observation space maximum and minimum values are equal.",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
_modify_space(
spaces.Box(np.zeros(2, np.float32), np.ones(2, np.float32)),
"low",
np.zeros(3, np.float32),
),
"The Box observation space shape and low shape have different shapes, low shape: (3,), box shape: (2,)",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
_modify_space(
spaces.Box(np.zeros(2, np.float32), np.ones(2, np.float32)),
"high",
np.ones(3, np.float32),
),
"The Box observation space shape and high shape have have different shapes, high shape: (3,), box shape: (2,)",
],
# ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
@@ -105,17 +113,25 @@ def test_check_observation_space(test, space, message: str):
# ===== Check box observation space ====
[
UserWarning,
spaces.Box(np.zeros(5), np.zeros(5)),
spaces.Box(np.zeros(5, np.float32), np.zeros(5, np.float32)),
"A Box action space maximum and minimum values are equal.",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),
_modify_space(
spaces.Box(np.zeros(2, np.float32), np.ones(2, np.float32)),
"low",
np.zeros(3, np.float32),
),
"The Box action space shape and low shape have have different shapes, low shape: (3,), box shape: (2,)",
],
[
AssertionError,
_modify_space(spaces.Box(np.zeros(2), np.ones(2)), "high", np.ones(3)),
_modify_space(
spaces.Box(np.zeros(2, np.float32), np.ones(2, np.float32)),
"high",
np.ones(3, np.float32),
),
"The Box action space shape and high shape have different shapes, high shape: (3,), box shape: (2,)",
],
# ==== Other observation spaces (Discrete, MultiDiscrete, MultiBinary, Tuple, Dict)
@@ -359,7 +375,12 @@ def test_passive_env_step_checker(
with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
):
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="^\\x1b\\[33mWARN: Core environment is written in old step API *",
)
env_step_passive_checker(GenericTestEnv(step_func=func), 0)
else:
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(test, match=f"^{re.escape(message)}$"):
@@ -377,7 +398,7 @@ def test_passive_env_step_checker(
],
[
UserWarning,
GenericTestEnv(metadata={"render_modes": "Testing mode"}),
GenericTestEnv(metadata={"render_modes": "Testing mode", "render_fps": 1}),
"Expects the render_modes to be a sequence (i.e. list, tuple), actual type: <class 'str'>",
],
[