Fix space utils for Discrete with non-zero start (#2645)

* Fix flatten utils to handle Discrete.start

* Fix vector space utils to handle Discrete.start

* More granular dispatch in vector utils (one registered `batch_space` function per space type instead of a single `isinstance` chain)

* Fix Box including the high end of the interval (Box bounds are inclusive, so the batched Discrete upper bound is `start + n - 1`)
This commit is contained in:
Tristan Deleu
2022-03-04 15:17:16 -05:00
committed by GitHub
parent 108f32c743
commit e671aa168c
6 changed files with 50 additions and 30 deletions

View File

@@ -78,7 +78,7 @@ def _flatten_box_multibinary(space, x) -> np.ndarray:
@flatten.register(Discrete) @flatten.register(Discrete)
def _flatten_discrete(space, x) -> np.ndarray: def _flatten_discrete(space, x) -> np.ndarray:
onehot = np.zeros(space.n, dtype=space.dtype) onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x] = 1 onehot[x - space.start] = 1
return onehot return onehot
@@ -124,7 +124,7 @@ def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.nd
@unflatten.register(Discrete) @unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int: def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
return int(np.nonzero(x)[0][0]) return int(space.start + np.nonzero(x)[0][0])
@unflatten.register(MultiDiscrete) @unflatten.register(MultiDiscrete)

View File

@@ -43,37 +43,44 @@ def batch_space(space, n=1):
@batch_space.register(Box) @batch_space.register(Box)
@batch_space.register(Discrete) def _batch_space_box(space, n=1):
@batch_space.register(MultiDiscrete)
@batch_space.register(MultiBinary)
def batch_space_base(space, n=1):
if isinstance(space, Box):
repeats = tuple([n] + [1] * space.low.ndim) repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats) low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype) return Box(low=low, high=high, dtype=space.dtype)
elif isinstance(space, Discrete):
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
elif isinstance(space, MultiDiscrete): @batch_space.register(Discrete)
def _batch_space_discrete(space, n=1):
if space.start == 0:
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
else:
return Box(
low=space.start,
high=space.start + space.n - 1,
shape=(n,),
dtype=space.dtype,
)
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space, n=1):
repeats = tuple([n] + [1] * space.nvec.ndim) repeats = tuple([n] + [1] * space.nvec.ndim)
high = np.tile(space.nvec, repeats) - 1 high = np.tile(space.nvec, repeats) - 1
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype) return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
elif isinstance(space, MultiBinary):
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
else: @batch_space.register(MultiBinary)
raise ValueError(f"Space type `{type(space)}` is not supported.") def _batch_space_multibinary(space, n=1):
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
@batch_space.register(Tuple) @batch_space.register(Tuple)
def batch_space_tuple(space, n=1): def _batch_space_tuple(space, n=1):
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces)) return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
@batch_space.register(Dict) @batch_space.register(Dict)
def batch_space_dict(space, n=1): def _batch_space_dict(space, n=1):
return Dict( return Dict(
OrderedDict( OrderedDict(
[ [
@@ -85,7 +92,7 @@ def batch_space_dict(space, n=1):
@batch_space.register(Space) @batch_space.register(Space)
def batch_space_custom(space, n=1): def _batch_space_custom(space, n=1):
return Tuple(tuple(space for _ in range(n))) return Tuple(tuple(space for _ in range(n)))
@@ -130,14 +137,14 @@ def iterate(space, items):
@iterate.register(Discrete) @iterate.register(Discrete)
def iterate_discrete(space, items): def _iterate_discrete(space, items):
raise TypeError("Unable to iterate over a space of type `Discrete`.") raise TypeError("Unable to iterate over a space of type `Discrete`.")
@iterate.register(Box) @iterate.register(Box)
@iterate.register(MultiDiscrete) @iterate.register(MultiDiscrete)
@iterate.register(MultiBinary) @iterate.register(MultiBinary)
def iterate_base(space, items): def _iterate_base(space, items):
try: try:
return iter(items) return iter(items)
except TypeError: except TypeError:
@@ -145,7 +152,7 @@ def iterate_base(space, items):
@iterate.register(Tuple) @iterate.register(Tuple)
def iterate_tuple(space, items): def _iterate_tuple(space, items):
# If this is a tuple of custom subspaces only, then simply iterate over items # If this is a tuple of custom subspaces only, then simply iterate over items
if all( if all(
isinstance(subspace, Space) isinstance(subspace, Space)
@@ -160,7 +167,7 @@ def iterate_tuple(space, items):
@iterate.register(Dict) @iterate.register(Dict)
def iterate_dict(space, items): def _iterate_dict(space, items):
keys, values = zip( keys, values = zip(
*[ *[
(key, iterate(subspace, items[key])) (key, iterate(subspace, items[key]))
@@ -172,7 +179,7 @@ def iterate_dict(space, items):
@iterate.register(Space) @iterate.register(Space)
def iterate_custom(space, items): def _iterate_custom(space, items):
raise CustomSpaceError( raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} " f"Unable to iterate over {items}, since {space} "
"is a custom `gym.Space` instance (i.e. not one of " "is a custom `gym.Space` instance (i.e. not one of "

View File

@@ -28,9 +28,11 @@ spaces = [
), ),
} }
), ),
Discrete(3, start=2),
Discrete(8, start=-5),
] ]
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7] flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims)) @pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
@@ -123,6 +125,8 @@ expected_flattened_dtypes = [
np.int64, np.int64,
np.int8, np.int8,
np.float64, np.float64,
np.int64,
np.int64,
] ]
@@ -187,6 +191,8 @@ samples = [
OrderedDict( OrderedDict(
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))] [("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
), ),
3,
-2,
] ]
@@ -200,6 +206,8 @@ expected_flattened_samples = [
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64), np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8), np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64), np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
np.array([0, 1, 0], dtype=np.int64),
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
] ]
@@ -243,6 +251,8 @@ expected_flattened_spaces = [
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64), high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64, dtype=np.float64,
), ),
Box(low=0, high=1, shape=(3,), dtype=np.int64),
Box(low=0, high=1, shape=(8,), dtype=np.int64),
] ]

View File

@@ -26,6 +26,7 @@ expected_types = [
Array("B", 1), Array("B", 1),
Array("B", 32 * 32 * 3), Array("B", 32 * 32 * 3),
Array("i", 1), Array("i", 1),
Array("i", 1),
(Array("i", 1), Array("i", 1)), (Array("i", 1), Array("i", 1)),
(Array("i", 1), Array("f", 2)), (Array("i", 1), Array("f", 2)),
Array("B", 3), Array("B", 3),

View File

@@ -33,6 +33,7 @@ expected_batch_spaces_4 = [
Box(low=0, high=255, shape=(4,), dtype=np.uint8), Box(low=0, high=255, shape=(4,), dtype=np.uint8),
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8), Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
MultiDiscrete([2, 2, 2, 2]), MultiDiscrete([2, 2, 2, 2]),
Box(low=-2, high=2, shape=(4,), dtype=np.int64),
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))), Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
Tuple( Tuple(
( (

View File

@@ -18,6 +18,7 @@ spaces = [
Box(low=0, high=255, shape=(), dtype=np.uint8), Box(low=0, high=255, shape=(), dtype=np.uint8),
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
Discrete(2), Discrete(2),
Discrete(5, start=-2),
Tuple((Discrete(3), Discrete(5))), Tuple((Discrete(3), Discrete(5))),
Tuple( Tuple(
( (