mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Add explicit error messages when unflatten discrete and multidiscrete fail (#267)
This commit is contained in:
@@ -271,7 +271,13 @@ def _unflatten_box_multibinary(
|
||||
|
||||
@unflatten.register(Discrete)
|
||||
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
|
||||
return space.start + np.nonzero(x)[0][0]
|
||||
nonzero = np.nonzero(x)
|
||||
if len(nonzero[0]) == 0:
|
||||
raise ValueError(
|
||||
f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
|
||||
"Not all valid samples in a flattened space can be unflattened."
|
||||
)
|
||||
return space.start + nonzero[0][0]
|
||||
|
||||
|
||||
@unflatten.register(MultiDiscrete)
|
||||
@@ -280,8 +286,13 @@ def _unflatten_multidiscrete(
|
||||
) -> NDArray[np.integer[Any]]:
|
||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||
|
||||
(indices,) = cast(type(offsets[:-1]), np.nonzero(x))
|
||||
nonzero = np.nonzero(x)
|
||||
if len(nonzero[0]) == 0:
|
||||
raise ValueError(
|
||||
f"{x} is not a concatenation of one-hot encoded vectors and can not be unflattened to space {space}. "
|
||||
"Not all valid samples in a flattened space can be unflattened."
|
||||
)
|
||||
(indices,) = cast(type(offsets[:-1]), nonzero)
|
||||
return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)
|
||||
|
||||
|
||||
|
@@ -135,3 +135,15 @@ def test_flatten_roundtripping(space):
|
||||
|
||||
for original, roundtripped in zip(samples, unflattened_samples):
|
||||
assert data_equivalence(original, roundtripped)
|
||||
|
||||
|
||||
def test_unflatten_discrete_error():
|
||||
value = np.array([0])
|
||||
with pytest.raises(ValueError):
|
||||
utils.unflatten(gym.spaces.Discrete(1), value)
|
||||
|
||||
|
||||
def test_unflatten_multidiscrete_error():
|
||||
value = np.array([0, 0])
|
||||
with pytest.raises(ValueError):
|
||||
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)
|
||||
|
Reference in New Issue
Block a user