mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-31 02:06:08 +00:00
Fix space utils for Discrete with non-zero start (#2645)
* Fix flatten utils to handle Discrete.start
* Fix vector space utils to handle Discrete.start
* More granular dispatch in vector utils
* Fix Box including the high end of the interval
This commit is contained in:
@@ -78,7 +78,7 @@ def _flatten_box_multibinary(space, x) -> np.ndarray:
|
|||||||
@flatten.register(Discrete)
|
@flatten.register(Discrete)
|
||||||
def _flatten_discrete(space, x) -> np.ndarray:
|
def _flatten_discrete(space, x) -> np.ndarray:
|
||||||
onehot = np.zeros(space.n, dtype=space.dtype)
|
onehot = np.zeros(space.n, dtype=space.dtype)
|
||||||
onehot[x] = 1
|
onehot[x - space.start] = 1
|
||||||
return onehot
|
return onehot
|
||||||
|
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.nd
|
|||||||
|
|
||||||
@unflatten.register(Discrete)
|
@unflatten.register(Discrete)
|
||||||
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
|
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
|
||||||
return int(np.nonzero(x)[0][0])
|
return int(space.start + np.nonzero(x)[0][0])
|
||||||
|
|
||||||
|
|
||||||
@unflatten.register(MultiDiscrete)
|
@unflatten.register(MultiDiscrete)
|
||||||
|
@@ -43,37 +43,44 @@ def batch_space(space, n=1):
|
|||||||
|
|
||||||
|
|
||||||
@batch_space.register(Box)
|
@batch_space.register(Box)
|
||||||
|
def _batch_space_box(space, n=1):
|
||||||
|
repeats = tuple([n] + [1] * space.low.ndim)
|
||||||
|
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
|
||||||
|
return Box(low=low, high=high, dtype=space.dtype)
|
||||||
|
|
||||||
|
|
||||||
@batch_space.register(Discrete)
|
@batch_space.register(Discrete)
|
||||||
@batch_space.register(MultiDiscrete)
|
def _batch_space_discrete(space, n=1):
|
||||||
@batch_space.register(MultiBinary)
|
if space.start == 0:
|
||||||
def batch_space_base(space, n=1):
|
|
||||||
if isinstance(space, Box):
|
|
||||||
repeats = tuple([n] + [1] * space.low.ndim)
|
|
||||||
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
|
|
||||||
return Box(low=low, high=high, dtype=space.dtype)
|
|
||||||
|
|
||||||
elif isinstance(space, Discrete):
|
|
||||||
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
|
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
|
||||||
|
|
||||||
elif isinstance(space, MultiDiscrete):
|
|
||||||
repeats = tuple([n] + [1] * space.nvec.ndim)
|
|
||||||
high = np.tile(space.nvec, repeats) - 1
|
|
||||||
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
|
|
||||||
|
|
||||||
elif isinstance(space, MultiBinary):
|
|
||||||
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Space type `{type(space)}` is not supported.")
|
return Box(
|
||||||
|
low=space.start,
|
||||||
|
high=space.start + space.n - 1,
|
||||||
|
shape=(n,),
|
||||||
|
dtype=space.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_space.register(MultiDiscrete)
|
||||||
|
def _batch_space_multidiscrete(space, n=1):
|
||||||
|
repeats = tuple([n] + [1] * space.nvec.ndim)
|
||||||
|
high = np.tile(space.nvec, repeats) - 1
|
||||||
|
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@batch_space.register(MultiBinary)
|
||||||
|
def _batch_space_multibinary(space, n=1):
|
||||||
|
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
|
||||||
|
|
||||||
|
|
||||||
@batch_space.register(Tuple)
|
@batch_space.register(Tuple)
|
||||||
def batch_space_tuple(space, n=1):
|
def _batch_space_tuple(space, n=1):
|
||||||
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
|
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
|
||||||
|
|
||||||
|
|
||||||
@batch_space.register(Dict)
|
@batch_space.register(Dict)
|
||||||
def batch_space_dict(space, n=1):
|
def _batch_space_dict(space, n=1):
|
||||||
return Dict(
|
return Dict(
|
||||||
OrderedDict(
|
OrderedDict(
|
||||||
[
|
[
|
||||||
@@ -85,7 +92,7 @@ def batch_space_dict(space, n=1):
|
|||||||
|
|
||||||
|
|
||||||
@batch_space.register(Space)
|
@batch_space.register(Space)
|
||||||
def batch_space_custom(space, n=1):
|
def _batch_space_custom(space, n=1):
|
||||||
return Tuple(tuple(space for _ in range(n)))
|
return Tuple(tuple(space for _ in range(n)))
|
||||||
|
|
||||||
|
|
||||||
@@ -130,14 +137,14 @@ def iterate(space, items):
|
|||||||
|
|
||||||
|
|
||||||
@iterate.register(Discrete)
|
@iterate.register(Discrete)
|
||||||
def iterate_discrete(space, items):
|
def _iterate_discrete(space, items):
|
||||||
raise TypeError("Unable to iterate over a space of type `Discrete`.")
|
raise TypeError("Unable to iterate over a space of type `Discrete`.")
|
||||||
|
|
||||||
|
|
||||||
@iterate.register(Box)
|
@iterate.register(Box)
|
||||||
@iterate.register(MultiDiscrete)
|
@iterate.register(MultiDiscrete)
|
||||||
@iterate.register(MultiBinary)
|
@iterate.register(MultiBinary)
|
||||||
def iterate_base(space, items):
|
def _iterate_base(space, items):
|
||||||
try:
|
try:
|
||||||
return iter(items)
|
return iter(items)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
@@ -145,7 +152,7 @@ def iterate_base(space, items):
|
|||||||
|
|
||||||
|
|
||||||
@iterate.register(Tuple)
|
@iterate.register(Tuple)
|
||||||
def iterate_tuple(space, items):
|
def _iterate_tuple(space, items):
|
||||||
# If this is a tuple of custom subspaces only, then simply iterate over items
|
# If this is a tuple of custom subspaces only, then simply iterate over items
|
||||||
if all(
|
if all(
|
||||||
isinstance(subspace, Space)
|
isinstance(subspace, Space)
|
||||||
@@ -160,7 +167,7 @@ def iterate_tuple(space, items):
|
|||||||
|
|
||||||
|
|
||||||
@iterate.register(Dict)
|
@iterate.register(Dict)
|
||||||
def iterate_dict(space, items):
|
def _iterate_dict(space, items):
|
||||||
keys, values = zip(
|
keys, values = zip(
|
||||||
*[
|
*[
|
||||||
(key, iterate(subspace, items[key]))
|
(key, iterate(subspace, items[key]))
|
||||||
@@ -172,7 +179,7 @@ def iterate_dict(space, items):
|
|||||||
|
|
||||||
|
|
||||||
@iterate.register(Space)
|
@iterate.register(Space)
|
||||||
def iterate_custom(space, items):
|
def _iterate_custom(space, items):
|
||||||
raise CustomSpaceError(
|
raise CustomSpaceError(
|
||||||
f"Unable to iterate over {items}, since {space} "
|
f"Unable to iterate over {items}, since {space} "
|
||||||
"is a custom `gym.Space` instance (i.e. not one of "
|
"is a custom `gym.Space` instance (i.e. not one of "
|
||||||
|
@@ -28,9 +28,11 @@ spaces = [
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
Discrete(3, start=2),
|
||||||
|
Discrete(8, start=-5),
|
||||||
]
|
]
|
||||||
|
|
||||||
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7]
|
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
|
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
|
||||||
@@ -123,6 +125,8 @@ expected_flattened_dtypes = [
|
|||||||
np.int64,
|
np.int64,
|
||||||
np.int8,
|
np.int8,
|
||||||
np.float64,
|
np.float64,
|
||||||
|
np.int64,
|
||||||
|
np.int64,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -187,6 +191,8 @@ samples = [
|
|||||||
OrderedDict(
|
OrderedDict(
|
||||||
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
|
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
|
||||||
),
|
),
|
||||||
|
3,
|
||||||
|
-2,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -200,6 +206,8 @@ expected_flattened_samples = [
|
|||||||
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
|
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
|
||||||
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
|
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
|
||||||
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
|
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
|
||||||
|
np.array([0, 1, 0], dtype=np.int64),
|
||||||
|
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -243,6 +251,8 @@ expected_flattened_spaces = [
|
|||||||
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
|
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
|
||||||
dtype=np.float64,
|
dtype=np.float64,
|
||||||
),
|
),
|
||||||
|
Box(low=0, high=1, shape=(3,), dtype=np.int64),
|
||||||
|
Box(low=0, high=1, shape=(8,), dtype=np.int64),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@@ -26,6 +26,7 @@ expected_types = [
|
|||||||
Array("B", 1),
|
Array("B", 1),
|
||||||
Array("B", 32 * 32 * 3),
|
Array("B", 32 * 32 * 3),
|
||||||
Array("i", 1),
|
Array("i", 1),
|
||||||
|
Array("i", 1),
|
||||||
(Array("i", 1), Array("i", 1)),
|
(Array("i", 1), Array("i", 1)),
|
||||||
(Array("i", 1), Array("f", 2)),
|
(Array("i", 1), Array("f", 2)),
|
||||||
Array("B", 3),
|
Array("B", 3),
|
||||||
|
@@ -33,6 +33,7 @@ expected_batch_spaces_4 = [
|
|||||||
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
|
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
|
||||||
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
|
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
|
||||||
MultiDiscrete([2, 2, 2, 2]),
|
MultiDiscrete([2, 2, 2, 2]),
|
||||||
|
Box(low=-2, high=2, shape=(4,), dtype=np.int64),
|
||||||
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
|
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
|
||||||
Tuple(
|
Tuple(
|
||||||
(
|
(
|
||||||
|
@@ -18,6 +18,7 @@ spaces = [
|
|||||||
Box(low=0, high=255, shape=(), dtype=np.uint8),
|
Box(low=0, high=255, shape=(), dtype=np.uint8),
|
||||||
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
||||||
Discrete(2),
|
Discrete(2),
|
||||||
|
Discrete(5, start=-2),
|
||||||
Tuple((Discrete(3), Discrete(5))),
|
Tuple((Discrete(3), Discrete(5))),
|
||||||
Tuple(
|
Tuple(
|
||||||
(
|
(
|
||||||
|
Reference in New Issue
Block a user