Update vector space utility functions for all spaces (#223)

This commit is contained in:
Mark Towers
2023-02-22 15:05:58 +00:00
committed by GitHub
parent 761bb2e033
commit a7d9146b1d
8 changed files with 293 additions and 408 deletions

View File

@@ -0,0 +1,12 @@
---
title: Vector Utility functions
---
# Utility functions for vectorisation
```{eval-rst}
.. autofunction:: gymnasium.experimental.vector.utils.batch_space
.. autofunction:: gymnasium.experimental.vector.utils.concatenate
.. autofunction:: gymnasium.experimental.vector.utils.iterate
.. autofunction:: gymnasium.experimental.vector.utils.create_empty_array
```

View File

@@ -1,30 +1,42 @@
"""Utility functions for gymnasium spaces: `batch_space` and `iterator`."""
"""Space-based utility functions for vector environments.
- ``batch_space``: Create a (batched) space, containing multiple copies of a single space.
- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
- ``iterate``: Iterate over the elements of a (batched) space and items.
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
"""
from __future__ import annotations
from collections import OrderedDict
from copy import deepcopy
from functools import singledispatch
from typing import Callable, Iterable, Iterator
from typing import Any, Iterable, Iterator
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.logger import warn
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
)
from gymnasium.spaces.space import T_cov
__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"]
@singledispatch
def batch_space(space: Space, n: int = 1) -> Space:
def batch_space(space: Space[Any], n: int = 1) -> Space[Any]:
"""Create a (batched) space, containing multiple copies of a single space.
Args:
@@ -35,11 +47,11 @@ def batch_space(space: Space, n: int = 1) -> Space:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Raises:
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
TypeError: Cannot batch a space that does not have a registered function.
Example::
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
@@ -47,20 +59,20 @@ def batch_space(space: Space, n: int = 1) -> Space:
>>> batch_space(space, n=5)
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
"""
raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gymnasium.Space` instance."
raise TypeError(
f"The space provided to `batch_space` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@batch_space.register(Box)
def _batch_space_box(space, n=1):
def _batch_space_box(space: Box, n: int = 1):
repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype, seed=deepcopy(space.np_random))
@batch_space.register(Discrete)
def _batch_space_discrete(space, n=1):
def _batch_space_discrete(space: Discrete, n=1):
if space.start == 0:
return MultiDiscrete(
np.full((n,), space.n, dtype=space.dtype),
@@ -78,7 +90,7 @@ def _batch_space_discrete(space, n=1):
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space, n=1):
def _batch_space_multidiscrete(space: MultiDiscrete, n=1):
repeats = tuple([n] + [1] * space.nvec.ndim)
high = np.tile(space.nvec, repeats) - 1
return Box(
@@ -90,7 +102,7 @@ def _batch_space_multidiscrete(space, n=1):
@batch_space.register(MultiBinary)
def _batch_space_multibinary(space, n=1):
def _batch_space_multibinary(space: MultiBinary, n=1):
return Box(
low=0,
high=1,
@@ -101,7 +113,7 @@ def _batch_space_multibinary(space, n=1):
@batch_space.register(Tuple)
def _batch_space_tuple(space, n=1):
def _batch_space_tuple(space: Tuple, n=1):
return Tuple(
tuple(batch_space(subspace, n=n) for subspace in space.spaces),
seed=deepcopy(space.np_random),
@@ -109,32 +121,31 @@ def _batch_space_tuple(space, n=1):
@batch_space.register(Dict)
def _batch_space_dict(space, n=1):
def _batch_space_dict(space: Dict, n: int = 1):
return Dict(
OrderedDict(
[
(key, batch_space(subspace, n=n))
for (key, subspace) in space.spaces.items()
]
),
{key: batch_space(subspace, n=n) for key, subspace in space.items()},
seed=deepcopy(space.np_random),
)
@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(Space)
def _batch_space_custom(space, n=1):
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1):
# Without the deepcopy, space.np_random would be the same object as batched_space.spaces[0].np_random,
# which is an issue if you sample actions from both the original space and the batched space
batched_space = Tuple(
tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random)
)
new_seeds = list(map(int, batched_space.np_random.integers(0, 1e8, n)))
space_rng = deepcopy(space.np_random)
new_seeds = list(map(int, space_rng.integers(0, 1e8, n)))
batched_space.seed(new_seeds)
return batched_space
@singledispatch
def iterate(space: Space, items) -> Iterator:
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
"""Iterate over the elements of a (batched) space.
Args:
@@ -164,9 +175,14 @@ def iterate(space: Space, items) -> Iterator:
...
StopIteration
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `iterate` function. Register `{type(space)}` for `iterate` to support it."
)
else:
raise TypeError(
f"The space provided to `iterate` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@iterate.register(Discrete)
@@ -177,7 +193,7 @@ def _iterate_discrete(space, items):
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def _iterate_base(space, items):
def _iterate_base(space: Box | MultiDiscrete | MultiBinary, items: np.ndarray):
try:
return iter(items)
except TypeError as e:
@@ -187,22 +203,26 @@ def _iterate_base(space, items):
@iterate.register(Tuple)
def _iterate_tuple(space, items):
def _iterate_tuple(space: Tuple, items: tuple[Any, ...]):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
isinstance(subspace, Space)
and (not isinstance(subspace, (Box, Discrete, MultiDiscrete, Tuple, Dict)))
for subspace in space.spaces
):
return iter(items)
if all(type(subspace) in iterate.registry for subspace in space):
return zip(*[iterate(subspace, items[i]) for i, subspace in enumerate(space)])
return zip(
*[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
)
try:
return iter(items)
except Exception as e:
unregistered_spaces = [
type(subspace)
for subspace in space
if type(subspace) not in iterate.registry
]
raise CustomSpaceError(
f"Could not iterate through {space} as no custom iterate function is registered for {unregistered_spaces} and `iter(items)` raised the following error: {e}."
) from e
@iterate.register(Dict)
def _iterate_dict(space, items):
def _iterate_dict(space: Dict, items: dict[str, Any]):
keys, values = zip(
*[
(key, iterate(subspace, items[key]))
@@ -210,22 +230,13 @@ def _iterate_dict(space, items):
]
)
for item in zip(*values):
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])
@iterate.register(Space)
def _iterate_custom(space, items):
raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} "
"is a custom `gymnasium.Space` instance (i.e. not one of "
"`Box`, `Dict`, etc...)."
)
yield OrderedDict({key: value for key, value in zip(keys, item)})
@singledispatch
def concatenate(
space: Space, items: Iterable, out: tuple | dict | np.ndarray
) -> tuple | dict | np.ndarray:
space: Space, items: Iterable, out: tuple[Any, ...] | dict[str, Any] | np.ndarray
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Concatenate multiple samples from space into a single object.
Args:
@@ -237,7 +248,7 @@ def concatenate(
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
TypeError: The space provided is not a gymnasium Space instance
Example:
>>> from gymnasium.spaces import Box
@@ -249,8 +260,8 @@ def concatenate(
array([[0.77395606, 0.43887845, 0.85859793],
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
raise TypeError(
f"The space provided to `concatenate` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@@ -258,12 +269,18 @@ def concatenate(
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(space, items, out):
def _concatenate_base(
space: Box | Discrete | MultiDiscrete | MultiBinary,
items: Iterable,
out: np.ndarray,
) -> np.ndarray:
return np.stack(items, axis=0, out=out)
@concatenate.register(Tuple)
def _concatenate_tuple(space, items, out):
def _concatenate_tuple(
space: Tuple, items: Iterable, out: tuple[Any, ...]
) -> tuple[Any, ...]:
return tuple(
concatenate(subspace, [item[i] for item in items], out[i])
for (i, subspace) in enumerate(space.spaces)
@@ -271,25 +288,36 @@ def _concatenate_tuple(space, items, out):
@concatenate.register(Dict)
def _concatenate_dict(space, items, out):
def _concatenate_dict(
space: Dict, items: Iterable, out: dict[str, Any]
) -> dict[str, Any]:
return OrderedDict(
[
(key, concatenate(subspace, [item[key] for item in items], out[key]))
for (key, subspace) in space.spaces.items()
]
{
key: concatenate(subspace, [item[key] for item in items], out[key])
for key, subspace in space.items()
}
)
@concatenate.register(Graph)
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
def _concatenate_custom(space, items, out):
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
if out is not None:
warn(
f"For `vector.utils.concatenate({type(space)}, ...)`, `out` is not None ({out}) however the value is ignored."
)
return tuple(items)
@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros
) -> tuple | dict | np.ndarray:
"""Create an empty (possibly nested) numpy array.
space: Space, n: int = 1, fn: callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
In most cases, the array will be contained within the batched space, however, this is not guaranteed.
Args:
space: Observation space of a single environment in the vectorized environment.
@@ -313,8 +341,8 @@ def create_empty_array(
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
raise TypeError(
f"The space provided to `create_empty_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@@ -322,26 +350,66 @@ def create_empty_array(
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_base(space, n=1, fn=np.zeros):
shape = space.shape if (n is None) else (n,) + space.shape
return fn(shape, dtype=space.dtype)
def _create_empty_array_multi(space: Box, n: int = 1, fn=np.zeros) -> np.ndarray:
return fn((n,) + space.shape, dtype=space.dtype)
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space, n=1, fn=np.zeros):
def _create_empty_array_tuple(space: Tuple, n: int = 1, fn=np.zeros) -> tuple[Any, ...]:
return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
@create_empty_array.register(Dict)
def _create_empty_array_dict(space, n=1, fn=np.zeros):
def _create_empty_array_dict(space: Dict, n: int = 1, fn=np.zeros) -> dict[str, Any]:
return OrderedDict(
[
(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()
]
{
key: create_empty_array(subspace, n=n, fn=fn)
for key, subspace in space.items()
}
)
@create_empty_array.register(Graph)
def _create_empty_array_graph(
    space: Graph, n: int = 1, fn=np.zeros
) -> tuple[GraphInstance, ...]:
    """Create ``n`` placeholder ``GraphInstance`` objects for a ``Graph`` space.

    Args:
        space: The graph space whose ``node_space`` (and optional ``edge_space``)
            determine the shape and dtype of the placeholder arrays.
        n: Number of graph instances to create.
        fn: Array factory called as ``fn(shape, dtype=...)``.

    Returns:
        A tuple of ``n`` graph instances. Each holds a single-node array and, when
        the space has an ``edge_space``, a single-edge array plus a ``(1, 2)``
        ``int64`` edge-link array; otherwise ``edges`` and ``edge_links`` are ``None``.
    """
    instances = []
    for _ in range(n):
        # Every instance gets freshly created arrays (no sharing between instances).
        node_array = fn((1,) + space.node_space.shape, dtype=space.node_space.dtype)
        if space.edge_space is not None:
            edge_array = fn(
                (1,) + space.edge_space.shape, dtype=space.edge_space.dtype
            )
            link_array = fn((1, 2), dtype=np.int64)
        else:
            edge_array = None
            link_array = None
        instances.append(
            GraphInstance(nodes=node_array, edges=edge_array, edge_links=link_array)
        )
    return tuple(instances)
@create_empty_array.register(Text)
def _create_empty_array_text(space: Text, n: int = 1, fn=np.zeros) -> tuple[str, ...]:
    """Create ``n`` placeholder strings for a ``Text`` space.

    Each placeholder is the first entry of ``space.characters`` repeated
    ``space.min_length`` times. ``fn`` is accepted for signature parity with the
    other ``create_empty_array`` implementations and is not used.
    """
    placeholders = []
    for _ in range(n):
        placeholders.append(space.characters[0] * space.min_length)
    return tuple(placeholders)
@create_empty_array.register(Sequence)
def _create_empty_array_sequence(
    space: Sequence, n: int = 1, fn=np.zeros
) -> tuple[Any, ...]:
    """Create ``n`` placeholder entries for a ``Sequence`` space.

    Args:
        space: The sequence space; ``space.stack`` selects the placeholder form.
        n: Number of placeholder entries to create.
        fn: Array factory forwarded to ``create_empty_array`` for the feature space.

    Returns:
        A tuple of length ``n``. For a stacked sequence each entry is the result of
        ``create_empty_array(space.feature_space, n=1, fn=fn)``; otherwise each
        entry is an empty tuple.
    """
    if not space.stack:
        return tuple(() for _ in range(n))

    stacked_placeholders = [
        create_empty_array(space.feature_space, n=1, fn=fn) for _ in range(n)
    ]
    return tuple(stacked_placeholders)
@create_empty_array.register(Space)
def _create_empty_array_custom(space: Space, n: int = 1, fn=np.zeros) -> None:
    """Fallback ``create_empty_array`` implementation registered for the base ``Space`` type.

    No array is pre-allocated: ``space``, ``n`` and ``fn`` are ignored and ``None``
    is always returned.
    """
    return None

View File

@@ -46,7 +46,7 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
self.feature_space = space
self.stack = stack
if self.stack:
self.batched_feature_space: Space = gym.vector.utils.batch_space(
self.stacked_feature_space: Space = gym.vector.utils.batch_space(
self.feature_space, 1
)
@@ -141,7 +141,7 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
if self.stack:
return all(
item in self.feature_space
for item in gym.vector.utils.iterate(self.batched_feature_space, x)
for item in gym.vector.utils.iterate(self.stacked_feature_space, x)
)
else:
return isinstance(x, tuple) and all(
@@ -157,14 +157,14 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
if self.stack:
return self.batched_feature_space.to_jsonable(sample_n)
return self.stacked_feature_space.to_jsonable(sample_n)
else:
return [self.feature_space.to_jsonable(sample) for sample in sample_n]
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...] | Any]:
"""Convert a JSONable data type to a batch of samples from this space."""
if self.stack:
return self.batched_feature_space.from_jsonable(sample_n)
return self.stacked_feature_space.from_jsonable(sample_n)
else:
return [
tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n

View File

@@ -243,7 +243,7 @@ def _flatten_sequence(
space: Sequence, x: tuple[Any, ...] | Any
) -> tuple[Any, ...] | Any:
if space.stack:
samples_iters = gym.vector.utils.iterate(space.batched_feature_space, x)
samples_iters = gym.vector.utils.iterate(space.stacked_feature_space, x)
flattened_samples = [
flatten(space.feature_space, sample) for sample in samples_iters
]

View File

@@ -1,341 +1,85 @@
"""Testing `gymnasium.experimental.vector.utils.space_utils` functions."""
import copy
from collections import OrderedDict
import re
from typing import Iterable
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from gymnasium import Space
from gymnasium.error import CustomSpaceError
from gymnasium.experimental.vector.utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
from gymnasium.spaces import Box, Dict, MultiDiscrete, Space, Tuple
from tests.experimental.vector.testing_utils import (
BaseGymSpaces,
CustomSpace,
assert_rng_equal,
custom_spaces,
spaces,
)
from gymnasium.spaces import Tuple
from gymnasium.utils.env_checker import data_equivalence
from tests.experimental.vector.utils.utils import is_rng_equal
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_concatenate(space):
"""Tests the `concatenate` functions with list of spaces."""
def assert_type(lhs, rhs, n):
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
if np.isscalar(rhs[0]):
assert isinstance(lhs, np.ndarray)
assert all([np.isscalar(rhs[i]) for i in range(n)])
else:
assert all([isinstance(rhs[i], type(lhs)) for i in range(n)])
def assert_nested_equal(lhs, rhs, n):
assert isinstance(rhs, list)
assert (n > 0) and (len(rhs) == n)
assert_type(lhs, rhs, n)
if isinstance(lhs, np.ndarray):
assert lhs.shape[0] == n
for i in range(n):
assert np.all(lhs[i] == rhs[i])
elif isinstance(lhs, tuple):
for i in range(len(lhs)):
rhs_T_i = [rhs[j][i] for j in range(n)]
assert_nested_equal(lhs[i], rhs_T_i, n)
elif isinstance(lhs, OrderedDict):
for key in lhs.keys():
rhs_T_key = [rhs[j][key] for j in range(n)]
assert_nested_equal(lhs[key], rhs_T_key, n)
else:
raise TypeError(f"Got unknown type `{type(lhs)}`.")
samples = [space.sample() for _ in range(8)]
array = create_empty_array(space, n=8)
concatenated = concatenate(space, samples, array)
assert np.all(concatenated == array)
assert_nested_equal(array, samples, n=8)
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array(space, n):
"""Test `create_empty_array` function with list of spaces and different `n` values."""
def assert_nested_type(arr, space, n):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == (n,) + space.shape
elif isinstance(space, Tuple):
assert isinstance(arr, tuple)
assert len(arr) == len(space.spaces)
for i in range(len(arr)):
assert_nested_type(arr[i], space.spaces[i], n)
elif isinstance(space, Dict):
assert isinstance(arr, OrderedDict)
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
for key in arr.keys():
assert_nested_type(arr[key], space.spaces[key], n)
else:
raise TypeError(f"Got unknown type `{type(arr)}`.")
array = create_empty_array(space, n=n, fn=np.empty)
assert_nested_type(array, space, n=n)
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_zeros(space, n):
"""Test `create_empty_array` with a list of spaces and different `n`."""
def assert_nested_type(arr, space, n):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == (n,) + space.shape
assert np.all(arr == 0)
elif isinstance(space, Tuple):
assert isinstance(arr, tuple)
assert len(arr) == len(space.spaces)
for i in range(len(arr)):
assert_nested_type(arr[i], space.spaces[i], n)
elif isinstance(space, Dict):
assert isinstance(arr, OrderedDict)
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
for key in arr.keys():
assert_nested_type(arr[key], space.spaces[key], n)
else:
raise TypeError(f"Got unknown type `{type(arr)}`.")
array = create_empty_array(space, n=n, fn=np.zeros)
assert_nested_type(array, space, n=n)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_none_shape_ones(space):
"""Tests `create_empty_array` with ``None`` space."""
def assert_nested_type(arr, space):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == space.shape
assert np.all(arr == 1)
elif isinstance(space, Tuple):
assert isinstance(arr, tuple)
assert len(arr) == len(space.spaces)
for i in range(len(arr)):
assert_nested_type(arr[i], space.spaces[i])
elif isinstance(space, Dict):
assert isinstance(arr, OrderedDict)
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
for key in arr.keys():
assert_nested_type(arr[key], space.spaces[key])
else:
raise TypeError(f"Got unknown type `{type(arr)}`.")
array = create_empty_array(space, n=None, fn=np.ones)
assert_nested_type(array, space)
expected_batch_spaces_4 = [
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float64),
Box(
low=np.array(
[[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
),
high=np.array(
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
),
dtype=np.float64,
),
Box(
low=np.array(
[
[[-1.0, 0.0], [0.0, -1.0]],
[[-1.0, 0.0], [0.0, -1.0]],
[[-1.0, 0.0], [0.0, -1]],
[[-1.0, 0.0], [0.0, -1.0]],
]
),
high=np.ones((4, 2, 2)),
dtype=np.float64,
),
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
MultiDiscrete([2, 2, 2, 2]),
Box(low=-2, high=2, shape=(4,), dtype=np.int64),
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
Tuple(
(
MultiDiscrete([7, 7, 7, 7]),
Box(
low=np.array([[0.0, -1.0], [0.0, -1.0], [0.0, -1.0], [0.0, -1]]),
high=np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
dtype=np.float64,
),
)
),
Box(
low=np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]),
high=np.array([[10, 12, 16], [10, 12, 16], [10, 12, 16], [10, 12, 16]]),
dtype=np.int64,
),
Box(low=0, high=1, shape=(4, 19), dtype=np.int8),
Dict(
{
"position": MultiDiscrete([23, 23, 23, 23]),
"velocity": Box(low=0.0, high=1.0, shape=(4, 1), dtype=np.float64),
}
),
Dict(
{
"position": Dict(
{
"x": MultiDiscrete([29, 29, 29, 29]),
"y": MultiDiscrete([31, 31, 31, 31]),
}
),
"velocity": Tuple(
(
MultiDiscrete([37, 37, 37, 37]),
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
)
),
}
),
]
expected_custom_batch_spaces_4 = [
Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
Tuple(
(
Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
)
),
]
@pytest.mark.parametrize(
"space,expected_batch_space_4",
list(zip(spaces, expected_batch_spaces_4)),
ids=[space.__class__.__name__ for space in spaces],
)
def test_batch_space(space, expected_batch_space_4):
"""Tests `batch_space` with the expected spaces."""
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4
@pytest.mark.parametrize(
"space,expected_batch_space_4",
list(zip(custom_spaces, expected_custom_batch_spaces_4)),
ids=[space.__class__.__name__ for space in custom_spaces],
)
def test_batch_space_custom_space(space, expected_batch_space_4):
"""Tests `batch_space` for custom spaces with the expected batch spaces."""
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4
@pytest.mark.parametrize(
"space,batched_space",
list(zip(spaces, expected_batch_spaces_4)),
ids=[space.__class__.__name__ for space in spaces],
)
def test_iterate(space, batched_space):
"""Test `iterate` function with list of spaces and expected batch space."""
items = batched_space.sample()
iterator = iterate(batched_space, items)
i = 0
for i, item in enumerate(iterator):
assert item in space
assert i == 3
@pytest.mark.parametrize(
"space,batched_space",
list(zip(custom_spaces, expected_custom_batch_spaces_4)),
ids=[space.__class__.__name__ for space in custom_spaces],
)
def test_iterate_custom_space(space, batched_space):
"""Test iterating over a custom space."""
items = batched_space.sample()
iterator = iterate(batched_space, items)
i = 0
for i, item in enumerate(iterator):
assert item in space
assert i == 3
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]])
@pytest.mark.parametrize(
"base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_rng_different_at_each_index(space: Space, n: int, base_seed: int):
"""Tests that the rng values produced at each index are different to prevent if the rng is copied for each subspace."""
space.seed(base_seed)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 4], ids=[f"n={n}" for n in [1, 4]])
def test_batch_space_concatenate_iterate_create_empty_array(space: Space, n: int):
"""Test all space_utils functions using them together."""
# Batch the space and create a sample
batched_space = batch_space(space, n)
assert space.np_random is not batched_space.np_random
assert_rng_equal(space.np_random, batched_space.np_random)
assert isinstance(batched_space, Space)
batched_sample = batched_space.sample()
sample = list(iterate(batched_space, batched_sample))
assert not all(np.all(element == sample[0]) for element in sample), sample
assert batched_sample in batched_space
# Check the batched samples are within the original space
iterated_samples = iterate(batched_space, batched_sample)
assert isinstance(iterated_samples, Iterable)
unbatched_samples = list(iterated_samples)
assert len(unbatched_samples) == n
assert all(item in space for item in unbatched_samples)
# Create an empty array and check that space is within the batch space
array = create_empty_array(space, n)
# We do not check that the generated array is within the batched_space.
# assert array in batched_space
unbatched_array = list(iterate(batched_space, array))
assert len(unbatched_array) == n
# assert all(item in space for item in unbatched_array)
# Generate samples from the original space and concatenate using array into a single object
space_samples = [space.sample() for _ in range(n)]
assert all(item in space for item in space_samples)
concatenated_samples_array = concatenate(space, space_samples, array)
# `concatenate` does not necessarily use the out object as the returned object
# assert out is concatenated_samples_array
assert concatenated_samples_array in batched_space
# Iterate over the samples and check that the concatenated samples == original samples
iterated_samples = iterate(batched_space, concatenated_samples_array)
assert isinstance(iterated_samples, Iterable)
unbatched_samples = list(iterated_samples)
assert len(unbatched_samples) == n
for unbatched_sample, original_sample in zip(unbatched_samples, space_samples):
assert data_equivalence(unbatched_sample, original_sample)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]])
@pytest.mark.parametrize(
"base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_deterministic(space: Space, n: int, base_seed: int):
def test_batch_space_deterministic(space: Space, n: int, base_seed: int):
"""Tests the batched spaces are deterministic by using a copied version."""
# Copy the spaces and check that the np_random are not reference equal
space_a = space
space_a.seed(base_seed)
space_b = copy.deepcopy(space_a)
assert_rng_equal(space_a.np_random, space_b.np_random)
is_rng_equal(space_a.np_random, space_b.np_random)
assert space_a.np_random is not space_b.np_random
# Batch the spaces and check that the np_random are not reference equal
space_a_batched = batch_space(space_a, n)
space_b_batched = batch_space(space_b, n)
assert_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
is_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
assert space_a_batched.np_random is not space_b_batched.np_random
# Check that the batched space's rng is not reference equal to the original space's rng
assert space_a.np_random is not space_a_batched.np_random
@@ -348,9 +92,65 @@ def test_deterministic(space: Space, n: int, base_seed: int):
iterate(space_a_batched, space_a_batched_sample),
iterate(space_b_batched, space_b_batched_sample),
):
if isinstance(a_sample, tuple):
assert len(a_sample) == len(b_sample)
for a_subsample, b_subsample in zip(a_sample, b_sample):
assert_array_equal(a_subsample, b_subsample)
else:
assert_array_equal(a_sample, b_sample)
assert data_equivalence(a_sample, b_sample)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]])
@pytest.mark.parametrize(
    "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_batch_space_different_samples(space: Space, n: int, base_seed: int):
    """Test that a batched sample does not contain the same value at every index.

    This guards against the rng being copied unchanged into each sub-space of the
    batched space.
    """
    space.seed(base_seed)

    batched_space = batch_space(space, n)
    assert space.np_random is not batched_space.np_random
    # `is_rng_equal` only returns a bool; without `assert` the comparison is a no-op.
    assert is_rng_equal(space.np_random, batched_space.np_random)

    batched_sample = batched_space.sample()
    unbatched_samples = list(iterate(batched_space, batched_sample))

    assert len(unbatched_samples) == n
    assert all(item in space for item in unbatched_samples)
    # At least one element must differ from the first one.
    assert not all(
        data_equivalence(element, unbatched_samples[0]) for element in unbatched_samples
    ), unbatched_samples
@pytest.mark.parametrize(
    "func, n_args",
    [(batch_space, 1), (concatenate, 2), (iterate, 1), (create_empty_array, 2)],
)
def test_non_space(func, n_args):
    """Check that each vector utility raises ``TypeError`` when given a non-space object.

    The string ``"space"`` is passed in place of a space, followed by ``n_args``
    ``None`` placeholders for the remaining positional arguments.
    """
    func_name = func.__name__
    expected_pattern = re.escape(
        f"The space provided to `{func_name}` is not a gymnasium Space instance, type: <class 'str'>, space"
    )
    placeholder_args = [None] * n_args

    with pytest.raises(TypeError, match=expected_pattern):
        func("space", *placeholder_args)
def test_custom_space():
    """Exercise the vector space utilities with a ``CustomSpace`` instance.

    ``batch_space``, ``concatenate`` and ``create_empty_array`` are expected to
    return a result, while ``iterate`` is expected to raise ``CustomSpaceError``.
    """
    space = CustomSpace()

    # batch_space: the custom space is wrapped in a Tuple of copies
    batched = batch_space(space, n=2)
    assert batched == Tuple([space, space])

    # iterate: raises because no `iterate` function is registered for the type
    iterate_error = re.escape(
        "Space of type `<class 'tests.spaces.utils.CustomSpace'>` doesn't have an registered `iterate` function. Register `<class 'tests.spaces.utils.CustomSpace'>` for `iterate` to support it."
    )
    with pytest.raises(CustomSpaceError, match=iterate_error):
        iterate(space, None)

    # concatenate: the items are returned as a tuple
    joined = concatenate(space, (None, None), out=None)
    assert joined == (None, None)

    # create_empty_array: nothing is pre-allocated
    placeholder = create_empty_array(space)
    assert placeholder is None

View File

@@ -0,0 +1,7 @@
"""Utility functions for testing the vector utility functions."""
import numpy as np
def is_rng_equal(rng_1: np.random.Generator, rng_2: np.random.Generator) -> bool:
    """Return whether two random number generators are in the same state.

    This helper does not assert anything itself; wrap the call in ``assert`` if a
    mismatch should fail a test.

    Args:
        rng_1: The first generator to compare.
        rng_2: The second generator to compare.

    Returns:
        ``True`` if the ``bit_generator.state`` of both generators compare equal,
        otherwise ``False``.
    """
    return rng_1.bit_generator.state == rng_2.bit_generator.state

View File

@@ -2,22 +2,19 @@ from functools import partial
import pytest
from gymnasium import Space
from gymnasium.spaces import utils
TESTING_SPACE = Space()
from tests.spaces.utils import TESTING_CUSTOM_SPACE
@pytest.mark.parametrize(
"func",
[
TESTING_SPACE.sample,
partial(TESTING_SPACE.contains, None),
partial(utils.flatdim, TESTING_SPACE),
partial(utils.flatten, TESTING_SPACE, None),
partial(utils.flatten_space, TESTING_SPACE),
partial(utils.unflatten, TESTING_SPACE, None),
TESTING_CUSTOM_SPACE.sample,
partial(TESTING_CUSTOM_SPACE.contains, None),
partial(utils.flatdim, TESTING_CUSTOM_SPACE),
partial(utils.flatten, TESTING_CUSTOM_SPACE, None),
partial(utils.flatten_space, TESTING_CUSTOM_SPACE),
partial(utils.unflatten, TESTING_CUSTOM_SPACE, None),
],
)
def test_not_implemented_errors(func):

View File

@@ -112,9 +112,10 @@ TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS + TESTING_COMPOSITE_SPACES_IDS
CUSTOM_SPACES = [
Space(),
Tuple([Space(), Space(), Space()]),
Dict(a=Space(), b=Space()),
]
CUSTOM_SPACES_IDS = [f"{space}" for space in CUSTOM_SPACES]
class CustomSpace(Space):
    """Minimal ``Space`` subclass used as a custom space in the tests."""

    def __eq__(self, o: object) -> bool:
        """Return ``True`` if ``o`` is an instance of ``CustomSpace``."""
        return isinstance(o, CustomSpace)
TESTING_CUSTOM_SPACE = CustomSpace()