mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-26 00:07:41 +00:00
Add explicit error messages when unflatten discrete and multidiscrete fail (#267)
This commit is contained in:
@@ -271,7 +271,13 @@ def _unflatten_box_multibinary(
|
|||||||
|
|
||||||
@unflatten.register(Discrete)
|
@unflatten.register(Discrete)
|
||||||
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
|
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
|
||||||
return space.start + np.nonzero(x)[0][0]
|
nonzero = np.nonzero(x)
|
||||||
|
if len(nonzero[0]) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
|
||||||
|
"Not all valid samples in a flattened space can be unflattened."
|
||||||
|
)
|
||||||
|
return space.start + nonzero[0][0]
|
||||||
|
|
||||||
|
|
||||||
@unflatten.register(MultiDiscrete)
|
@unflatten.register(MultiDiscrete)
|
||||||
@@ -280,8 +286,13 @@ def _unflatten_multidiscrete(
|
|||||||
) -> NDArray[np.integer[Any]]:
|
) -> NDArray[np.integer[Any]]:
|
||||||
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
|
||||||
offsets[1:] = np.cumsum(space.nvec.flatten())
|
offsets[1:] = np.cumsum(space.nvec.flatten())
|
||||||
|
nonzero = np.nonzero(x)
|
||||||
(indices,) = cast(type(offsets[:-1]), np.nonzero(x))
|
if len(nonzero[0]) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"{x} is not a concatenation of one-hot encoded vectors and can not be unflattened to space {space}. "
|
||||||
|
"Not all valid samples in a flattened space can be unflattened."
|
||||||
|
)
|
||||||
|
(indices,) = cast(type(offsets[:-1]), nonzero)
|
||||||
return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)
|
return np.asarray(indices - offsets[:-1], dtype=space.dtype).reshape(space.shape)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -135,3 +135,15 @@ def test_flatten_roundtripping(space):
|
|||||||
|
|
||||||
for original, roundtripped in zip(samples, unflattened_samples):
|
for original, roundtripped in zip(samples, unflattened_samples):
|
||||||
assert data_equivalence(original, roundtripped)
|
assert data_equivalence(original, roundtripped)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unflatten_discrete_error():
|
||||||
|
value = np.array([0])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
utils.unflatten(gym.spaces.Discrete(1), value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unflatten_multidiscrete_error():
|
||||||
|
value = np.array([0, 0])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value)
|
||||||
|
Reference in New Issue
Block a user