mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 21:06:59 +00:00
Fix all warnings in tests/spaces (#1396)
This commit is contained in:
@@ -48,24 +48,25 @@ def test_dtype_errors(dtype, error, message):
|
||||
Box(low=0, high=1, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"box, expected_shape",
|
||||
[
|
||||
# Test with same 1-dim low and high shape
|
||||
(Box(low=np.zeros(2), high=np.ones(2)), (2,)),
|
||||
# Test with same multi-dim low and high shape
|
||||
(Box(low=np.zeros((2, 1)), high=np.ones((2, 1))), (2, 1)),
|
||||
# Test with scalar low high and different shape
|
||||
(Box(low=0, high=1, shape=(5, 2)), (5, 2)),
|
||||
(Box(low=0, high=1), (1,)), # Test with int and int
|
||||
(Box(low=0.0, high=1.0), (1,)), # Test with float and float
|
||||
(Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
|
||||
(Box(low=0.0, high=1), (1,)), # Test with float and int
|
||||
(Box(low=0, high=np.int32(1)), (1,)), # Test with python int and numpy int32
|
||||
(Box(low=0, high=np.ones(3)), (3,)), # Test with array and scalar
|
||||
(Box(low=np.zeros(3), high=1.0), (3,)), # Test with array and scalar
|
||||
],
|
||||
)
|
||||
def _shape_inference_params():
|
||||
# Test with same 1-dim low and high shape
|
||||
yield Box(low=np.zeros(2), high=np.ones(2), dtype=np.float64), (2,)
|
||||
# Test with same multi-dim low and high shape
|
||||
yield Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.float64), (2, 1)
|
||||
# Test with scalar low high and different shape
|
||||
yield Box(low=0, high=1, shape=(5, 2)), (5, 2)
|
||||
yield Box(low=0, high=1), (1,) # Test with int and int
|
||||
yield Box(low=0.0, high=1.0), (1,) # Test with float and float
|
||||
yield Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)
|
||||
yield Box(low=0.0, high=1), (1,) # Test with float and int
|
||||
# Test with python int and numpy int32
|
||||
yield Box(low=0, high=np.int32(1)), (1,)
|
||||
# Test with array and scalar
|
||||
yield Box(low=0, high=np.ones(3), dtype=np.float64), (3,)
|
||||
yield Box(low=np.zeros(3), high=1.0, dtype=np.float64), (3,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("box, expected_shape", _shape_inference_params())
|
||||
def test_shape_inference(box, expected_shape):
|
||||
"""Test that the shape inference is as expected."""
|
||||
assert box.shape == expected_shape
|
||||
@@ -136,7 +137,7 @@ def test_shape_inference(box, expected_shape):
|
||||
def test_shape_errors(low, high, shape, error_type, message):
|
||||
"""Test errors due to shape mismatch."""
|
||||
with pytest.raises(error_type, match=f"^{re.escape(message)}$"):
|
||||
Box(low=low, high=high, shape=shape)
|
||||
Box(low=low, high=high, shape=shape, dtype=np.float64)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -260,6 +261,7 @@ def test_invalid_low_high(low, high, dtype):
|
||||
)
|
||||
def test_valid_low_high(low, high, dtype):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always", UserWarning)
|
||||
space = Box(low=low, high=high, dtype=dtype)
|
||||
assert space.dtype == dtype
|
||||
assert space.low.dtype == dtype
|
||||
@@ -294,24 +296,32 @@ def test_contains_dtype():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lowhighshape",
|
||||
"low, high, shape",
|
||||
[
|
||||
dict(low=0, high=np.inf, shape=(2,)),
|
||||
dict(low=-np.inf, high=0, shape=(2,)),
|
||||
dict(low=-np.inf, high=np.inf, shape=(2,)),
|
||||
dict(low=0, high=np.inf, shape=(2, 3)),
|
||||
dict(low=-np.inf, high=0, shape=(2, 3)),
|
||||
dict(low=-np.inf, high=np.inf, shape=(2, 3)),
|
||||
dict(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf])),
|
||||
(0, np.inf, (2,)),
|
||||
(-np.inf, 0, (2,)),
|
||||
(-np.inf, np.inf, (2,)),
|
||||
(0, np.inf, (2, 3)),
|
||||
(-np.inf, 0, (2, 3)),
|
||||
(-np.inf, np.inf, (2, 3)),
|
||||
(np.array([-np.inf, 0]), np.array([0.0, np.inf]), None),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
|
||||
def test_infinite_space(lowhighshape, dtype):
|
||||
def test_infinite_space(low, high, shape, dtype):
|
||||
"""
|
||||
To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
|
||||
are both modified within the init, we check for infinite when we know it's not 0
|
||||
"""
|
||||
space = Box(**lowhighshape, dtype=dtype)
|
||||
# Emits a warning for lowering the last example
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always", UserWarning)
|
||||
space = Box(low=low, high=high, shape=shape, dtype=dtype)
|
||||
|
||||
# Check if only the expected precision warning is emitted
|
||||
for warn in caught_warnings:
|
||||
if "precision lowered by casting to float32" not in warn.message.args[0]:
|
||||
raise Exception(warn)
|
||||
|
||||
assert np.all(space.low < space.high)
|
||||
|
||||
|
Reference in New Issue
Block a user