Fix all warnings in tests/spaces (#1396)

This commit is contained in:
Martin Schuck
2025-06-08 00:55:12 +02:00
committed by GitHub
parent 433d7af1f9
commit 461f478db9
3 changed files with 81 additions and 62 deletions

View File

@@ -48,24 +48,25 @@ def test_dtype_errors(dtype, error, message):
Box(low=0, high=1, dtype=dtype)
@pytest.mark.parametrize(
"box, expected_shape",
[
# Test with same 1-dim low and high shape
(Box(low=np.zeros(2), high=np.ones(2)), (2,)),
# Test with same multi-dim low and high shape
(Box(low=np.zeros((2, 1)), high=np.ones((2, 1))), (2, 1)),
# Test with scalar low high and different shape
(Box(low=0, high=1, shape=(5, 2)), (5, 2)),
(Box(low=0, high=1), (1,)), # Test with int and int
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
(Box(low=0.0, high=1), (1,)), # Test with float and int
(Box(low=0, high=np.int32(1)), (1,)), # Test with python int and numpy int32
(Box(low=0, high=np.ones(3)), (3,)), # Test with array and scalar
(Box(low=np.zeros(3), high=1.0), (3,)), # Test with array and scalar
],
)
def _shape_inference_params():
# Test with same 1-dim low and high shape
yield Box(low=np.zeros(2), high=np.ones(2), dtype=np.float64), (2,)
# Test with same multi-dim low and high shape
yield Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.float64), (2, 1)
# Test with scalar low high and different shape
yield Box(low=0, high=1, shape=(5, 2)), (5, 2)
yield Box(low=0, high=1), (1,) # Test with int and int
yield Box(low=0.0, high=1.0), (1,) # Test with float and float
yield Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)
yield Box(low=0.0, high=1), (1,) # Test with float and int
# Test with python int and numpy int32
yield Box(low=0, high=np.int32(1)), (1,)
# Test with array and scalar
yield Box(low=0, high=np.ones(3), dtype=np.float64), (3,)
yield Box(low=np.zeros(3), high=1.0, dtype=np.float64), (3,)
@pytest.mark.parametrize("box, expected_shape", _shape_inference_params())
def test_shape_inference(box, expected_shape):
"""Test that the shape inference is as expected."""
assert box.shape == expected_shape
@@ -136,7 +137,7 @@ def test_shape_inference(box, expected_shape):
def test_shape_errors(low, high, shape, error_type, message):
"""Test errors due to shape mismatch."""
with pytest.raises(error_type, match=f"^{re.escape(message)}$"):
Box(low=low, high=high, shape=shape)
Box(low=low, high=high, shape=shape, dtype=np.float64)
@pytest.mark.parametrize(
@@ -260,6 +261,7 @@ def test_invalid_low_high(low, high, dtype):
)
def test_valid_low_high(low, high, dtype):
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", UserWarning)
space = Box(low=low, high=high, dtype=dtype)
assert space.dtype == dtype
assert space.low.dtype == dtype
@@ -294,24 +296,32 @@ def test_contains_dtype():
@pytest.mark.parametrize(
"lowhighshape",
"low, high, shape",
[
dict(low=0, high=np.inf, shape=(2,)),
dict(low=-np.inf, high=0, shape=(2,)),
dict(low=-np.inf, high=np.inf, shape=(2,)),
dict(low=0, high=np.inf, shape=(2, 3)),
dict(low=-np.inf, high=0, shape=(2, 3)),
dict(low=-np.inf, high=np.inf, shape=(2, 3)),
dict(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf])),
(0, np.inf, (2,)),
(-np.inf, 0, (2,)),
(-np.inf, np.inf, (2,)),
(0, np.inf, (2, 3)),
(-np.inf, 0, (2, 3)),
(-np.inf, np.inf, (2, 3)),
(np.array([-np.inf, 0]), np.array([0.0, np.inf]), None),
],
)
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_infinite_space(lowhighshape, dtype):
def test_infinite_space(low, high, shape, dtype):
"""
To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
are both modified within the init, we check for infinite when we know it's not 0
"""
space = Box(**lowhighshape, dtype=dtype)
# Emits a warning for lowering the last example
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", UserWarning)
space = Box(low=low, high=high, shape=shape, dtype=dtype)
# Check if only the expected precision warning is emitted
for warn in caught_warnings:
if "precision lowered by casting to float32" not in warn.message.args[0]:
raise Exception(warn)
assert np.all(space.low < space.high)