mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 22:11:25 +00:00
Add a OneOf space for exclusive unions (#812)
Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
fd4ae52045
commit
2b2e853536
@@ -66,7 +66,8 @@ Often environment spaces require joining fundamental spaces together for vectori
|
||||
* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
|
||||
* :class:`Tuple` - Supports a tuple of subspaces, used for a fixed number of ordered spaces
|
||||
* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
|
||||
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values.
|
||||
* :class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values
|
||||
* :class:`OneOf` - Supports optional action spaces such that an action can be one of N possible subspaces
|
||||
```
|
||||
|
||||
## Utility functions
|
||||
|
@@ -21,4 +21,9 @@
|
||||
|
||||
.. automethod:: gymnasium.spaces.Graph.sample
|
||||
.. automethod:: gymnasium.spaces.Graph.seed
|
||||
|
||||
.. autoclass:: gymnasium.spaces.OneOf
|
||||
|
||||
.. automethod:: gymnasium.spaces.OneOf.sample
|
||||
.. automethod:: gymnasium.spaces.OneOf.seed
|
||||
```
|
||||
|
@@ -16,6 +16,7 @@ from gymnasium.spaces.discrete import Discrete
|
||||
from gymnasium.spaces.graph import Graph, GraphInstance
|
||||
from gymnasium.spaces.multi_binary import MultiBinary
|
||||
from gymnasium.spaces.multi_discrete import MultiDiscrete
|
||||
from gymnasium.spaces.oneof import OneOf
|
||||
from gymnasium.spaces.sequence import Sequence
|
||||
from gymnasium.spaces.space import Space
|
||||
from gymnasium.spaces.text import Text
|
||||
@@ -38,6 +39,7 @@ __all__ = [
|
||||
"Tuple",
|
||||
"Sequence",
|
||||
"Dict",
|
||||
"OneOf",
|
||||
# util functions (there are more utility functions in vector/utils/spaces.py)
|
||||
"flatdim",
|
||||
"flatten_space",
|
||||
|
158
gymnasium/spaces/oneof.py
Normal file
158
gymnasium/spaces/oneof.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Implementation of a space that represents an exclusive union (direct sum) of other spaces."""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.spaces.space import Space
|
||||
|
||||
|
||||
class OneOf(Space[Any]):
    """An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances.

    Elements of this space are elements of exactly one of the constituent spaces,
    represented as a ``(subspace_index, subspace_sample)`` pair.

    Example:
        >>> from gymnasium.spaces import OneOf, Box, Discrete
        >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
        >>> observation_space.sample()  # the first element is the space index (Box in this case) and the second element is the sample from Box
        (1, array([-0.3991573 ,  0.21649833], dtype=float32))
        >>> observation_space.sample()  # this time the Discrete space was sampled as index=0
        (0, 0)
        >>> observation_space[0]
        Discrete(2)
        >>> observation_space[1]
        Box(-1.0, 1.0, (2,), float32)
        >>> len(observation_space)
        2
    """

    def __init__(
        self,
        spaces: Iterable[Space[Any]],
        seed: int | typing.Sequence[int] | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`OneOf` space.

        The generated instance represents the exclusive union of ``spaces``: every element
        belongs to exactly one of :math:`\text{spaces}[0], ..., \text{spaces}[-1]`.

        Args:
            spaces (Iterable[Space]): The candidate subspaces; at least one is required.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
        """
        self.spaces = tuple(spaces)
        assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
        for space in self.spaces:
            assert isinstance(
                space, Space
            ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
        # A OneOf space has no single shape or dtype of its own.
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return all(space.is_np_flattenable for space in self.spaces)

    def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
        """Seed the PRNG of this space and all subspaces.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - This space and all the subspaces will use a random initial seed
        * ``Int`` - The integer is used to seed the :class:`OneOf` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
        * ``List`` - ``len(spaces) + 1`` values: the first seeds the :class:`OneOf` space itself (which picks the subspace when sampling), the remaining ones seed the subspaces in order, e.g. ``[42, 54, ...]``.

        Args:
            seed: An optional list of ints or int to seed the (sub-)spaces.

        Returns:
            The list of seeds used, starting with the seed of this space followed by the subspaces' seeds.

        Raises:
            TypeError: If ``seed`` is not a sequence, an int or ``None``.
        """
        if isinstance(seed, collections.abc.Sequence):
            assert (
                len(seed) == len(self.spaces) + 1
            ), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
            seeds = super().seed(seed[0])
            # seed[0] belongs to this space; the subspaces consume the remaining entries.
            for subseed, space in zip(seed[1:], self.spaces):
                seeds += space.seed(subseed)
        elif isinstance(seed, int):
            seeds = super().seed(seed)
            # Derive one sub-seed per subspace from the freshly seeded generator.
            subseeds = self.np_random.integers(
                np.iinfo(np.int32).max, size=len(self.spaces)
            )
            for subspace, subseed in zip(self.spaces, subseeds):
                seeds += subspace.seed(int(subseed))
        elif seed is None:
            seeds = super().seed(None)
            for space in self.spaces:
                seeds += space.seed(None)
        else:
            raise TypeError(
                f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
            )

        return seeds

    def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
        """Generates a single random sample inside this space.

        A subspace is first selected at random using this space's generator, then a sample
        is drawn from that subspace only.

        Args:
            mask: An optional tuple of optional masks, one per subspace (same length as the
                number of subspaces); only the mask of the selected subspace is used.

        Returns:
            A ``(subspace_index, subspace_sample)`` tuple.
        """
        subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
        subspace = self.spaces[subspace_idx]
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expected type of mask is tuple, actual type: {type(mask)}"
            assert len(mask) == len(
                self.spaces
            ), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"

            mask = mask[subspace_idx]

        return subspace_idx, subspace.sample(mask=mask)

    def contains(self, x: tuple[int, Any]) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` is a member when it is a ``(subspace_index, value)`` tuple whose index selects
        an existing subspace and whose value is contained in that subspace. Malformed inputs
        yield ``False`` rather than raising.
        """
        # Validate the structure before unpacking so arbitrary objects cannot raise here.
        if not isinstance(x, tuple) or len(x) != 2:
            return False

        idx, value = x
        if not isinstance(idx, (int, np.integer)):
            return False
        if not 0 <= idx < len(self.spaces):
            return False

        return self.spaces[idx].contains(value)

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"

    def to_jsonable(
        self, sample_n: typing.Sequence[tuple[int, Any]]
    ) -> list[list[Any]]:
        """Convert a batch of samples from this space to a JSONable data type.

        Each sample becomes ``[subspace_index, jsonable_subspace_sample]``.
        """
        return [
            [int(i), self.spaces[i].to_jsonable([subsample])[0]]
            for (i, subsample) in sample_n
        ]

    def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
        """Convert a JSONable data type to a batch of samples from this space.

        Each entry is expected to be ``[subspace_index, jsonable_subspace_sample]``.
        """
        return [
            (space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
            for space_idx, jsonable_sample in sample_n
        ]

    def __getitem__(self, index: int) -> Space[Any]:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of subspaces that an element of this space can be drawn from."""
        return len(self.spaces)

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, OneOf) and self.spaces == other.spaces
|
@@ -23,6 +23,7 @@ from gymnasium.spaces import (
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
OneOf,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
@@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int:
|
||||
return space.max_length
|
||||
|
||||
|
||||
@flatdim.register(OneOf)
def _flatdim_oneof(space: OneOf) -> int:
    """Return the flattened size of a ``OneOf`` space.

    One slot stores the subspace index; the remainder is sized for the largest subspace.
    """
    largest_subspace_dim = max(map(flatdim, space.spaces))
    return 1 + largest_subspace_dim
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
FlatType = Union[
|
||||
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
|
||||
@@ -256,6 +262,22 @@ def _flatten_sequence(
|
||||
return tuple(flatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@flatten.register(OneOf)
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
    """Flatten a ``OneOf`` sample into ``[index, flattened subspace sample, padding...]``.

    The flattened subspace sample is padded up to the size of the largest subspace by
    repeating its first element, so every sample of ``space`` flattens to the same length.
    """
    subspace_index, subspace_sample = x
    flattened = flatten(space.spaces[subspace_index], subspace_sample)

    # ``flatdim(space)`` counts the leading index slot, which is not part of the payload.
    payload_size = flatdim(space) - 1
    missing = payload_size - flattened.size
    if missing > 0:
        filler = np.full(missing, flattened[0], dtype=flattened.dtype)
        flattened = np.concatenate([flattened, filler])

    return np.concatenate([[subspace_index], flattened])
|
||||
|
||||
|
||||
@singledispatch
|
||||
def unflatten(space: Space[T], x: FlatType) -> T:
|
||||
"""Unflatten a data point from a space.
|
||||
@@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
|
||||
return tuple(unflatten(space.feature_space, item) for item in x)
|
||||
|
||||
|
||||
@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
    """Rebuild a ``(subspace_index, sample)`` pair from a flattened ``OneOf`` sample.

    The first entry of ``x`` is the subspace index; the following ``flatdim(subspace)``
    entries hold the flattened sample and any trailing padding is ignored.
    """
    subspace_index = int(x[0])
    chosen_space = space.spaces[subspace_index]

    # Slice off the index slot and drop the padding added during flattening.
    payload_end = 1 + flatdim(chosen_space)
    payload = x[1:payload_end]

    return subspace_index, unflatten(chosen_space, payload)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
"""Flatten a space into a space that is as flat as possible.
|
||||
@@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box:
|
||||
@flatten_space.register(Sequence)
|
||||
def _flatten_space_sequence(space: Sequence) -> Sequence:
|
||||
return Sequence(flatten_space(space.feature_space), stack=space.stack)
|
||||
|
||||
|
||||
@flatten_space.register(OneOf)
def _flatten_space_oneof(space: OneOf) -> Box:
    """Flatten a ``OneOf`` space into a 1-D :class:`Box`.

    The first dimension is bounded by the valid subspace indices; every other dimension
    uses the loosest low/high bound found across all flattened subspaces.
    """
    subspaces = space.spaces
    total_dim = max(flatdim(s) for s in subspaces) + 1
    payload_dim = total_dim - 1

    # NOTE(review): reads ``.low`` / ``.high`` of each flattened subspace, so this assumes
    # every subspace flattens to a Box-like space — confirm for non-flattenable subspaces.
    flattened_subspaces = [flatten_space(s) for s in subspaces]
    lows = np.array([np.min(flat.low) for flat in flattened_subspaces])
    highs = np.array([np.max(flat.high) for flat in flattened_subspaces])

    loosest_low = np.min(lows)
    loosest_high = np.max(highs)

    index_low, index_high = 0, len(subspaces) - 1
    low = np.concatenate([[index_low], np.full(payload_dim, loosest_low)])
    high = np.concatenate([[index_high], np.full(payload_dim, loosest_high)])

    dtype = np.result_type(*[s.dtype for s in subspaces if hasattr(s, "dtype")])
    return Box(low=low, high=high, shape=(total_dim,), dtype=dtype)
|
||||
|
@@ -17,6 +17,7 @@ from gymnasium.spaces import (
|
||||
Graph,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
OneOf,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
@@ -93,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
|
||||
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
|
||||
|
||||
|
||||
@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
    """Create the shared memory for ``n`` samples of a ``OneOf`` space.

    Args:
        space: The ``OneOf`` space to allocate memory for.
        n: Number of environments (samples) the memory must hold.
        ctx: The multiprocessing context used to allocate the arrays.

    Returns:
        A tuple whose first entry is an int32 array holding the selected subspace index of
        each sample, followed by one shared-memory object per subspace (in subspace order).
    """
    # Use the array typecode, mirroring how the Text space allocates its int32 buffer.
    index_memory = ctx.Array(np.dtype(np.int32).char, n)
    # Every subspace needs room for all ``n`` samples, allocated from the same context.
    subspace_memories = tuple(
        create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
    )
    return (index_memory,) + subspace_memories
|
||||
|
||||
|
||||
@create_shared_memory.register(Graph)
|
||||
@create_shared_memory.register(Sequence)
|
||||
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
|
||||
@@ -170,7 +176,9 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Text)
|
||||
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
|
||||
def _read_text_from_shared_memory(
|
||||
space: Text, shared_memory, n: int = 1
|
||||
) -> tuple[str, ...]:
|
||||
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
|
||||
(n, space.max_length)
|
||||
)
|
||||
@@ -187,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup
|
||||
)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
    space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
    """Read ``n`` ``OneOf`` samples back from shared memory.

    Args:
        space: The ``OneOf`` space the memory was created for.
        shared_memory: Tuple of the index array followed by one memory object per subspace.
        n: Number of environments (samples) stored in the memory.

    Returns:
        A tuple with one ``(subspace_index, subspace_sample)`` pair per environment.
    """
    # The index buffer is int32; ``space.dtype`` is None for OneOf and must not be used here.
    sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
    all_subspace_samples = tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for memory, subspace in zip(shared_memory[1:], space.spaces)
    )
    # For each environment, pick that environment's entry from the batch of the subspace
    # its stored index refers to.
    return tuple(
        (int(subspace_index), all_subspace_samples[subspace_index][env_index])
        for env_index, subspace_index in enumerate(sample_indexes)
    )
|
||||
|
||||
|
||||
@singledispatch
|
||||
def write_to_shared_memory(
|
||||
space: Space,
|
||||
@@ -258,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
|
||||
destination[index * size : (index + 1) * size],
|
||||
flatten(space, values),
|
||||
)
|
||||
|
||||
|
||||
@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
    space: OneOf, index: int, values: tuple[Any, ...], shared_memory
):
    """Write one ``OneOf`` sample into shared memory at position ``index``.

    Args:
        space: The ``OneOf`` space the memory was created for.
        index: Position (environment index) to write to.
        values: The sample as a ``(subspace_index, subspace_sample)`` pair.
        shared_memory: Tuple of the index array followed by one memory object per subspace.
    """
    subspace_index, subspace_value = values

    destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
    np.copyto(destination[index : index + 1], subspace_index)

    # Only the selected subspace's memory receives the value; the sample is not valid
    # data for the other subspaces, so their memories are left untouched.
    write_to_shared_memory(
        space.spaces[subspace_index],
        index,
        subspace_value,
        shared_memory[1 + subspace_index],
    )
|
||||
|
@@ -23,6 +23,7 @@ from gymnasium.spaces import (
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
OneOf,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
@@ -124,8 +125,9 @@ def _batch_space_dict(space: Dict, n: int = 1):
|
||||
@batch_space.register(Graph)
|
||||
@batch_space.register(Text)
|
||||
@batch_space.register(Sequence)
|
||||
@batch_space.register(OneOf)
|
||||
@batch_space.register(Space)
|
||||
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1):
|
||||
def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
|
||||
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
|
||||
# Which is an issue if you are sampling actions of both the original space and the batched space
|
||||
batched_space = Tuple(
|
||||
@@ -297,6 +299,7 @@ def _concatenate_dict(
|
||||
@concatenate.register(Text)
|
||||
@concatenate.register(Sequence)
|
||||
@concatenate.register(Space)
|
||||
@concatenate.register(OneOf)
|
||||
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
|
||||
return tuple(items)
|
||||
|
||||
@@ -336,7 +339,7 @@ def create_empty_array(
|
||||
)
|
||||
|
||||
|
||||
# It is possible for the some of the Box low to be greater than 0, then array is not in space
|
||||
# It is possible for some of the Box low to be greater than 0, then array is not in space
|
||||
@create_empty_array.register(Box)
|
||||
# If the Discrete start > 0 or start + length < 0 then array is not in space
|
||||
@create_empty_array.register(Discrete)
|
||||
@@ -402,6 +405,11 @@ def _create_empty_array_sequence(
|
||||
return tuple(tuple() for _ in range(n))
|
||||
|
||||
|
||||
@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
    """Return ``n`` empty tuples as placeholders for ``OneOf`` samples.

    ``space`` and ``fn`` are accepted for signature compatibility and are not used.
    """
    placeholders = [() for _ in range(n)]
    return tuple(placeholders)
|
||||
|
||||
|
||||
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
    """Fallback for spaces without a dedicated implementation: no array is created."""
    return None
|
||||
|
@@ -14,6 +14,7 @@ from gymnasium.spaces import (
|
||||
GraphInstance,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
OneOf,
|
||||
Sequence,
|
||||
Text,
|
||||
Tuple,
|
||||
@@ -145,3 +146,8 @@ def _create_graph_zero_array(space: Graph):
|
||||
edges = np.expand_dims(create_zero_array(space.edge_space), axis=0)
|
||||
edge_links = np.zeros((1, 2), dtype=np.int64)
|
||||
return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)
|
||||
|
||||
|
||||
@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
    """Return the zero sample of a ``OneOf`` space: index 0 with the first subspace's zero array."""
    first_subspace = space.spaces[0]
    return (0, create_zero_array(first_subspace))
|
||||
|
72
tests/spaces/test_oneof.py
Normal file
72
tests/spaces/test_oneof.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
|
||||
|
||||
def test_oneof_inheritance():
    """Check that a OneOf space exposes its subspaces through len, indexing and iteration."""
    subspaces = [Discrete(5), Box(-1, 1, shape=(3,)), MultiBinary(2)]
    oneof_space = OneOf(subspaces)

    assert len(oneof_space) == len(subspaces)

    # Indexing returns the subspaces in construction order.
    for position, expected in enumerate(subspaces):
        assert oneof_space[position] == expected

    # Iteration yields only the constituent subspaces.
    assert all(subspace in subspaces for subspace in oneof_space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "spaces, seed, expected_len",
    [
        ([Discrete(5), Box(-1, 1, shape=(3,))], None, 3),
        ([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3),
        ([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3),
    ],
)
def test_oneof_seeds(spaces, seed, expected_len):
    """Check the seeds returned by ``OneOf.seed`` and that re-seeding reproduces samples."""
    oneof_space = OneOf(spaces)
    seeds = oneof_space.seed(seed)
    assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
    assert len(seeds) == expected_len

    sample1 = oneof_space.sample()

    seeds2 = oneof_space.seed(seed)
    sample2 = oneof_space.sample()

    # ``None`` requests fresh random seeds, so reproducibility only holds for explicit seeds.
    if seed is not None:
        assert data_equivalence(seeds, seeds2)
        assert data_equivalence(sample1, sample2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "spaces_fn",
    [
        lambda: OneOf(["abc"]),
        lambda: OneOf([Box(0, 1), "abc"]),
        lambda: OneOf("abc"),
    ],
)
def test_bad_oneof_calls(spaces_fn):
    """Constructing a OneOf from anything that is not a ``Space`` must fail its assertion."""
    # Construction is deferred into the lambda so the failure happens inside ``raises``.
    with pytest.raises(AssertionError):
        spaces_fn()
|
||||
|
||||
|
||||
def test_oneof_contains():
    """Valid ``(index, value)`` pairs must be reported as members of the space."""
    space = OneOf([Box(0, 1), Box(-1, 0, (2,))])

    members = [
        (0, np.array([0.5], dtype=np.float32)),
        (1, np.array([-0.5, -0.5], dtype=np.float32)),
    ]
    for member in members:
        assert member in space
|
||||
|
||||
|
||||
def test_bad_oneof_seed():
    """Seeding with an unsupported type (a float) must raise a descriptive ``TypeError``."""
    space = OneOf([Box(0, 1), Box(0, 1)])
    expected_message = "Expected seed type: list, tuple, int or None, actual type: <class 'float'>"

    with pytest.raises(TypeError, match=expected_message):
        space.seed(0.0)
|
@@ -534,6 +534,7 @@ SPACE_KWARGS = [
|
||||
{"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
|
||||
{"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
|
||||
{"space": Discrete(4)}, # Sequence
|
||||
{"spaces": (Discrete(3), Discrete(5))}, # OneOf
|
||||
]
|
||||
assert len(SPACE_CLS) == len(SPACE_KWARGS)
|
||||
|
||||
|
@@ -54,6 +54,11 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
# OneOf
|
||||
4,
|
||||
5,
|
||||
]
|
||||
|
||||
|
||||
@@ -106,7 +111,8 @@ def test_flatten_space(space):
|
||||
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
|
||||
def test_flatten(space):
|
||||
"""Test that a flattened sample have the `flatdim` shape."""
|
||||
flattened_sample = utils.flatten(space, space.sample())
|
||||
sample = space.sample()
|
||||
flattened_sample = utils.flatten(space, sample)
|
||||
|
||||
if space.is_np_flattenable:
|
||||
assert isinstance(flattened_sample, np.ndarray)
|
||||
|
@@ -9,6 +9,7 @@ from gymnasium.spaces import (
|
||||
Graph,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
OneOf,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
@@ -108,6 +109,9 @@ TESTING_COMPOSITE_SPACES = [
|
||||
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
|
||||
Sequence(Box(low=0.0, high=1.0), stack=True),
|
||||
Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True),
|
||||
# OneOf spaces
|
||||
OneOf([Discrete(3), Box(low=0.0, high=1.0)]),
|
||||
OneOf([MultiBinary(2), MultiDiscrete([2, 2])]),
|
||||
]
|
||||
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]
|
||||
|
||||
|
Reference in New Issue
Block a user