mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Fix space utils for Discrete with non-zero start (#2645)
* Fix flatten utils to handle Discrete.start * Fix vector space utils to handle Discrete.start * More granular dispatch in vector utils * Fix Box including the high end of the interval
This commit is contained in:
@@ -78,7 +78,7 @@ def _flatten_box_multibinary(space, x) -> np.ndarray:
|
||||
@flatten.register(Discrete)
|
||||
def _flatten_discrete(space, x) -> np.ndarray:
|
||||
onehot = np.zeros(space.n, dtype=space.dtype)
|
||||
onehot[x] = 1
|
||||
onehot[x - space.start] = 1
|
||||
return onehot
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.nd
|
||||
|
||||
@unflatten.register(Discrete)
|
||||
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
|
||||
return int(np.nonzero(x)[0][0])
|
||||
return int(space.start + np.nonzero(x)[0][0])
|
||||
|
||||
|
||||
@unflatten.register(MultiDiscrete)
|
||||
|
@@ -43,37 +43,44 @@ def batch_space(space, n=1):
|
||||
|
||||
|
||||
@batch_space.register(Box)
|
||||
@batch_space.register(Discrete)
|
||||
@batch_space.register(MultiDiscrete)
|
||||
@batch_space.register(MultiBinary)
|
||||
def batch_space_base(space, n=1):
|
||||
if isinstance(space, Box):
|
||||
def _batch_space_box(space, n=1):
|
||||
repeats = tuple([n] + [1] * space.low.ndim)
|
||||
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
|
||||
return Box(low=low, high=high, dtype=space.dtype)
|
||||
|
||||
elif isinstance(space, Discrete):
|
||||
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
|
||||
|
||||
elif isinstance(space, MultiDiscrete):
|
||||
@batch_space.register(Discrete)
|
||||
def _batch_space_discrete(space, n=1):
|
||||
if space.start == 0:
|
||||
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
|
||||
else:
|
||||
return Box(
|
||||
low=space.start,
|
||||
high=space.start + space.n - 1,
|
||||
shape=(n,),
|
||||
dtype=space.dtype,
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(MultiDiscrete)
|
||||
def _batch_space_multidiscrete(space, n=1):
|
||||
repeats = tuple([n] + [1] * space.nvec.ndim)
|
||||
high = np.tile(space.nvec, repeats) - 1
|
||||
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
|
||||
|
||||
elif isinstance(space, MultiBinary):
|
||||
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Space type `{type(space)}` is not supported.")
|
||||
@batch_space.register(MultiBinary)
|
||||
def _batch_space_multibinary(space, n=1):
|
||||
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
|
||||
|
||||
|
||||
@batch_space.register(Tuple)
|
||||
def batch_space_tuple(space, n=1):
|
||||
def _batch_space_tuple(space, n=1):
|
||||
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
|
||||
|
||||
|
||||
@batch_space.register(Dict)
|
||||
def batch_space_dict(space, n=1):
|
||||
def _batch_space_dict(space, n=1):
|
||||
return Dict(
|
||||
OrderedDict(
|
||||
[
|
||||
@@ -85,7 +92,7 @@ def batch_space_dict(space, n=1):
|
||||
|
||||
|
||||
@batch_space.register(Space)
|
||||
def batch_space_custom(space, n=1):
|
||||
def _batch_space_custom(space, n=1):
|
||||
return Tuple(tuple(space for _ in range(n)))
|
||||
|
||||
|
||||
@@ -130,14 +137,14 @@ def iterate(space, items):
|
||||
|
||||
|
||||
@iterate.register(Discrete)
|
||||
def iterate_discrete(space, items):
|
||||
def _iterate_discrete(space, items):
|
||||
raise TypeError("Unable to iterate over a space of type `Discrete`.")
|
||||
|
||||
|
||||
@iterate.register(Box)
|
||||
@iterate.register(MultiDiscrete)
|
||||
@iterate.register(MultiBinary)
|
||||
def iterate_base(space, items):
|
||||
def _iterate_base(space, items):
|
||||
try:
|
||||
return iter(items)
|
||||
except TypeError:
|
||||
@@ -145,7 +152,7 @@ def iterate_base(space, items):
|
||||
|
||||
|
||||
@iterate.register(Tuple)
|
||||
def iterate_tuple(space, items):
|
||||
def _iterate_tuple(space, items):
|
||||
# If this is a tuple of custom subspaces only, then simply iterate over items
|
||||
if all(
|
||||
isinstance(subspace, Space)
|
||||
@@ -160,7 +167,7 @@ def iterate_tuple(space, items):
|
||||
|
||||
|
||||
@iterate.register(Dict)
|
||||
def iterate_dict(space, items):
|
||||
def _iterate_dict(space, items):
|
||||
keys, values = zip(
|
||||
*[
|
||||
(key, iterate(subspace, items[key]))
|
||||
@@ -172,7 +179,7 @@ def iterate_dict(space, items):
|
||||
|
||||
|
||||
@iterate.register(Space)
|
||||
def iterate_custom(space, items):
|
||||
def _iterate_custom(space, items):
|
||||
raise CustomSpaceError(
|
||||
f"Unable to iterate over {items}, since {space} "
|
||||
"is a custom `gym.Space` instance (i.e. not one of "
|
||||
|
@@ -28,9 +28,11 @@ spaces = [
|
||||
),
|
||||
}
|
||||
),
|
||||
Discrete(3, start=2),
|
||||
Discrete(8, start=-5),
|
||||
]
|
||||
|
||||
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7]
|
||||
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
|
||||
@@ -123,6 +125,8 @@ expected_flattened_dtypes = [
|
||||
np.int64,
|
||||
np.int8,
|
||||
np.float64,
|
||||
np.int64,
|
||||
np.int64,
|
||||
]
|
||||
|
||||
|
||||
@@ -187,6 +191,8 @@ samples = [
|
||||
OrderedDict(
|
||||
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
|
||||
),
|
||||
3,
|
||||
-2,
|
||||
]
|
||||
|
||||
|
||||
@@ -200,6 +206,8 @@ expected_flattened_samples = [
|
||||
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
|
||||
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
|
||||
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
|
||||
np.array([0, 1, 0], dtype=np.int64),
|
||||
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
|
||||
]
|
||||
|
||||
|
||||
@@ -243,6 +251,8 @@ expected_flattened_spaces = [
|
||||
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
|
||||
dtype=np.float64,
|
||||
),
|
||||
Box(low=0, high=1, shape=(3,), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(8,), dtype=np.int64),
|
||||
]
|
||||
|
||||
|
||||
|
@@ -26,6 +26,7 @@ expected_types = [
|
||||
Array("B", 1),
|
||||
Array("B", 32 * 32 * 3),
|
||||
Array("i", 1),
|
||||
Array("i", 1),
|
||||
(Array("i", 1), Array("i", 1)),
|
||||
(Array("i", 1), Array("f", 2)),
|
||||
Array("B", 3),
|
||||
|
@@ -33,6 +33,7 @@ expected_batch_spaces_4 = [
|
||||
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
|
||||
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
|
||||
MultiDiscrete([2, 2, 2, 2]),
|
||||
Box(low=-2, high=2, shape=(4,), dtype=np.int64),
|
||||
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
|
||||
Tuple(
|
||||
(
|
||||
|
@@ -18,6 +18,7 @@ spaces = [
|
||||
Box(low=0, high=255, shape=(), dtype=np.uint8),
|
||||
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
||||
Discrete(2),
|
||||
Discrete(5, start=-2),
|
||||
Tuple((Discrete(3), Discrete(5))),
|
||||
Tuple(
|
||||
(
|
||||
|
Reference in New Issue
Block a user