Add explicit error messages when unflatten discrete and multidiscrete fail (#267)

This commit is contained in:
Pierre Mardon
2023-01-18 18:32:54 +01:00
committed by GitHub
parent bb368fe75f
commit 6ba886abce
2 changed files with 26 additions and 3 deletions

View File

@@ -271,7 +271,13 @@ def _unflatten_box_multibinary(
@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
return space.start + np.nonzero(x)[0][0]
nonzero = np.nonzero(x)
if len(nonzero[0]) == 0:
raise ValueError(
f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
"Not all valid samples in a flattened space can be unflattened."
)
return space.start + nonzero[0][0]
@unflatten.register(MultiDiscrete)
@@ -280,8 +286,13 @@ def _unflatten_multidiscrete(
) -> NDArray[np.integer[Any]]:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten())
(indices,) = cast(type(offsets[:-1]), np.nonzero(x))
nonzero = np.nonzero(x)
if len(nonzero[0]) == 0:
raise ValueError(
f"{x} is not a concatenation of one-hot encoded vectors and can not be unflattened to space {space}. "
"Not all valid samples in a flattened space can be unflattened."
)
(indices,) = cast(type(offsets[:-1]), nonzero)
return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)

View File

@@ -135,3 +135,15 @@ def test_flatten_roundtripping(space):
for original, roundtripped in zip(samples, unflattened_samples):
assert data_equivalence(original, roundtripped)
def test_unflatten_discrete_error():
value = np.array([0])
with pytest.raises(ValueError):
utils.unflatten(gym.spaces.Discrete(1), value)
def test_unflatten_multidiscrete_error():
value = np.array([0, 0])
with pytest.raises(ValueError):
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)