Add explicit error messages when unflatten discrete and multidiscrete fail (#267)

This commit is contained in:
Pierre Mardon
2023-01-18 18:32:54 +01:00
committed by GitHub
parent bb368fe75f
commit 6ba886abce
2 changed files with 26 additions and 3 deletions

View File

@@ -135,3 +135,15 @@ def test_flatten_roundtripping(space):
for original, roundtripped in zip(samples, unflattened_samples):
assert data_equivalence(original, roundtripped)
def test_unflatten_discrete_error():
value = np.array([0])
with pytest.raises(ValueError):
utils.unflatten(gym.spaces.Discrete(1), value)
def test_unflatten_multidiscrete_error():
value = np.array([0, 0])
with pytest.raises(ValueError):
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)