Add a OneOf space for exclusive unions (#812)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2024-03-11 13:30:50 +01:00
committed by GitHub
parent fd4ae52045
commit 2b2e853536
12 changed files with 353 additions and 5 deletions

View File

@@ -66,7 +66,8 @@ Often environment spaces require joining fundamental spaces together for vectori
* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces * :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
* :class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces * :class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces
* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions * :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values. * :class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values
* :class:`OneOf` - Supports optional action spaces such that an action can be one of N possible subspaces
``` ```
## Utility functions ## Utility functions

View File

@@ -21,4 +21,9 @@
.. automethod:: gymnasium.spaces.Graph.sample .. automethod:: gymnasium.spaces.Graph.sample
.. automethod:: gymnasium.spaces.Graph.seed .. automethod:: gymnasium.spaces.Graph.seed
.. autoclass:: gymnasium.spaces.OneOf
.. automethod:: gymnasium.spaces.OneOf.sample
.. automethod:: gymnasium.spaces.OneOf.seed
``` ```

View File

@@ -16,6 +16,7 @@ from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.graph import Graph, GraphInstance from gymnasium.spaces.graph import Graph, GraphInstance
from gymnasium.spaces.multi_binary import MultiBinary from gymnasium.spaces.multi_binary import MultiBinary
from gymnasium.spaces.multi_discrete import MultiDiscrete from gymnasium.spaces.multi_discrete import MultiDiscrete
from gymnasium.spaces.oneof import OneOf
from gymnasium.spaces.sequence import Sequence from gymnasium.spaces.sequence import Sequence
from gymnasium.spaces.space import Space from gymnasium.spaces.space import Space
from gymnasium.spaces.text import Text from gymnasium.spaces.text import Text
@@ -38,6 +39,7 @@ __all__ = [
"Tuple", "Tuple",
"Sequence", "Sequence",
"Dict", "Dict",
"OneOf",
# util functions (there are more utility functions in vector/utils/spaces.py) # util functions (there are more utility functions in vector/utils/spaces.py)
"flatdim", "flatdim",
"flatten_space", "flatten_space",

158
gymnasium/spaces/oneof.py Normal file
View File

@@ -0,0 +1,158 @@
"""Implementation of a space that represents the exclusive union (direct sum) of other spaces."""
from __future__ import annotations
import collections.abc
import typing
from typing import Any, Iterable
import numpy as np
from gymnasium.spaces.space import Space
class OneOf(Space[Any]):
    """An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances.

    Elements of this space are elements of one of the constituent spaces,
    represented as a ``(subspace_index, subspace_sample)`` pair.

    Example:
        >>> from gymnasium.spaces import OneOf, Box, Discrete
        >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
        >>> observation_space.sample()  # the first element is the space index (Box in this case) and the second element is the sample from Box
        (1, array([-0.3991573 ,  0.21649833], dtype=float32))
        >>> observation_space.sample()  # this time the Discrete space was sampled as index=0
        (0, 0)
        >>> observation_space[0]
        Discrete(2)
        >>> observation_space[1]
        Box(-1.0, 1.0, (2,), float32)
        >>> len(observation_space)
        2
    """

    def __init__(
        self,
        spaces: Iterable[Space[Any]],
        seed: int | typing.Sequence[int] | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`OneOf` space.

        The generated instance represents the exclusive union of ``spaces``:
        every element belongs to exactly one of the given subspaces.

        Args:
            spaces (Iterable[Space]): The subspaces an element may be drawn from. Must be non-empty.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
        """
        self.spaces = tuple(spaces)
        assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
        for space in self.spaces:
            assert isinstance(
                space, Space
            ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
        # A `OneOf` has no single shape or dtype of its own.
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return all(space.is_np_flattenable for space in self.spaces)

    def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
        """Seed the PRNG of this space and all subspaces.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - The space and all the subspaces will use a random initial seed
        * ``Int`` - The integer is used to seed the :class:`OneOf` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
        * ``List`` - ``len(spaces) + 1`` values: the first seeds the :class:`OneOf` space itself and the remaining ones seed the subspaces in order, e.g. ``[42, 54, ...]``.

        Args:
            seed: An optional list of ints or int to seed the (sub-)spaces.

        Returns:
            The seed of this space followed by the seeds of every subspace.

        Raises:
            TypeError: If ``seed`` is not a sequence, an int or ``None``.
        """
        if isinstance(seed, collections.abc.Sequence):
            assert len(seed) == len(self.spaces) + 1, (
                "Expects the number of seeds to equal the number of subspaces plus one "
                "(the first seed is for the `OneOf` space itself). "
                f"Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
            )
            seeds = super().seed(seed[0])
            # seed[0] was consumed by this space; the rest map one-to-one onto the subspaces.
            for subseed, space in zip(seed[1:], self.spaces):
                seeds += space.seed(subseed)
        elif isinstance(seed, int):
            seeds = super().seed(seed)
            # Derive the subspace seeds from this space's freshly seeded generator.
            subseeds = self.np_random.integers(
                np.iinfo(np.int32).max, size=len(self.spaces)
            )
            for subspace, subseed in zip(self.spaces, subseeds):
                seeds += subspace.seed(int(subseed))
        elif seed is None:
            seeds = super().seed(None)
            for space in self.spaces:
                seeds += space.seed(None)
        else:
            raise TypeError(
                f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
            )

        return seeds

    def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
        """Generates a single random sample inside this space.

        A subspace is picked at random with this space's generator, then a
        sample is drawn from that subspace only.

        Args:
            mask: An optional tuple of optional masks for each of the subspace's samples,
                expects the same number of masks as spaces. Only the mask of the
                selected subspace is used.

        Returns:
            Tuple of the selected subspace index and the sample drawn from that subspace
        """
        subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
        subspace = self.spaces[subspace_idx]
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expected type of mask is tuple, actual type: {type(mask)}"
            assert len(mask) == len(
                self.spaces
            ), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"

            mask = mask[subspace_idx]

        return subspace_idx, subspace.sample(mask=mask)

    def contains(self, x: tuple[int, Any]) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` is a member when it is a ``(index, value)`` tuple whose index selects
        an existing subspace and whose value is contained in that subspace. Any
        other input yields ``False`` rather than raising.
        """
        # Validate the structure before unpacking so malformed inputs do not raise.
        if not (isinstance(x, tuple) and len(x) == 2):
            return False
        idx, value = x
        if not isinstance(idx, (int, np.integer)):
            return False
        if not 0 <= idx < len(self.spaces):
            return False
        return self.spaces[idx].contains(value)

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"

    def to_jsonable(
        self, sample_n: typing.Sequence[tuple[int, Any]]
    ) -> list[list[Any]]:
        """Convert a batch of samples from this space to a JSONable data type.

        Each sample becomes ``[index, jsonable_value]``, with the value converted
        by the subspace that ``index`` selects.
        """
        return [
            [int(i), self.spaces[i].to_jsonable([subsample])[0]]
            for (i, subsample) in sample_n
        ]

    def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
        """Convert a JSONable data type to a batch of samples from this space.

        Inverse of :meth:`to_jsonable`: each ``[index, jsonable_value]`` entry is
        decoded by the subspace that ``index`` selects.
        """
        return [
            (space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
            for space_idx, jsonable_sample in sample_n
        ]

    def __getitem__(self, index: int) -> Space[Any]:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of subspaces that an element can be drawn from."""
        return len(self.spaces)

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, OneOf) and self.spaces == other.spaces

View File

@@ -23,6 +23,7 @@ from gymnasium.spaces import (
GraphInstance, GraphInstance,
MultiBinary, MultiBinary,
MultiDiscrete, MultiDiscrete,
OneOf,
Sequence, Sequence,
Space, Space,
Text, Text,
@@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int:
return space.max_length return space.max_length
@flatdim.register(OneOf)
def _flatdim_oneof(space: OneOf) -> int:
    """Return the flat size of a `OneOf`: one index slot plus the widest subspace."""
    widest_subspace = max(flatdim(subspace) for subspace in space.spaces)
    return widest_subspace + 1
T = TypeVar("T") T = TypeVar("T")
FlatType = Union[ FlatType = Union[
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
@@ -256,6 +262,22 @@ def _flatten_sequence(
return tuple(flatten(space.feature_space, item) for item in x) return tuple(flatten(space.feature_space, item) for item in x)
@flatten.register(OneOf)
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
    """Flatten a `OneOf` sample into ``[index, *flattened_value, *padding]``.

    The flattened value is padded up to the widest subspace by repeating its
    first entry, so every sample has the length ``flatdim(space)``.
    """
    subspace_index, subspace_sample = x
    flattened = flatten(space.spaces[subspace_index], subspace_sample)

    # Width available for the value once the leading index slot is removed.
    payload_width = flatdim(space) - 1
    missing = payload_width - flattened.size
    if missing > 0:
        filler = np.full(missing, flattened[0], dtype=flattened.dtype)
        flattened = np.concatenate([flattened, filler])

    return np.concatenate([[subspace_index], flattened])
@singledispatch @singledispatch
def unflatten(space: Space[T], x: FlatType) -> T: def unflatten(space: Space[T], x: FlatType) -> T:
"""Unflatten a data point from a space. """Unflatten a data point from a space.
@@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
return tuple(unflatten(space.feature_space, item) for item in x) return tuple(unflatten(space.feature_space, item) for item in x)
@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
    """Rebuild an ``(index, value)`` sample from a flattened `OneOf` array.

    The first entry selects the subspace; only as many of the following
    entries as that subspace's flat size are used, the rest is padding.
    """
    subspace_index = int(x[0])
    chosen_subspace = space.spaces[subspace_index]

    payload_end = 1 + flatdim(chosen_subspace)
    return subspace_index, unflatten(chosen_subspace, x[1:payload_end])
@singledispatch @singledispatch
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
"""Flatten a space into a space that is as flat as possible. """Flatten a space into a space that is as flat as possible.
@@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence) @flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence: def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space), stack=space.stack) return Sequence(flatten_space(space.feature_space), stack=space.stack)
@flatten_space.register(OneOf)
def _flatten_space_oneof(space: OneOf) -> Box:
    """Flatten a `OneOf` space into a single `Box`.

    The first dimension holds the subspace index (``0 .. len(spaces) - 1``);
    the remaining dimensions share the loosest bounds found across all
    flattened subspaces and are as wide as the widest subspace.
    """
    payload_width = max(flatdim(subspace) for subspace in space.spaces)
    flat_subspaces = [flatten_space(subspace) for subspace in space.spaces]

    # Loosest bounds over every flattened subspace.
    overall_low = np.min(np.array([np.min(flat.low) for flat in flat_subspaces]))
    overall_high = np.max(np.array([np.max(flat.high) for flat in flat_subspaces]))

    low = np.concatenate([[0], np.full(payload_width, overall_low)])
    high = np.concatenate(
        [[len(space.spaces) - 1], np.full(payload_width, overall_high)]
    )

    dtype = np.result_type(
        *[subspace.dtype for subspace in space.spaces if hasattr(subspace, "dtype")]
    )
    return Box(low=low, high=high, shape=(payload_width + 1,), dtype=dtype)

View File

@@ -17,6 +17,7 @@ from gymnasium.spaces import (
Graph, Graph,
MultiBinary, MultiBinary,
MultiDiscrete, MultiDiscrete,
OneOf,
Sequence, Sequence,
Space, Space,
Text, Text,
@@ -93,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length) return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
    """Create the shared memory for a batch of `OneOf` samples.

    Args:
        space: The `OneOf` space to allocate memory for.
        n: Number of environments (samples) the memory must hold.
        ctx: The multiprocessing context used to allocate the arrays.

    Returns:
        A tuple whose first element is an int32 array of ``n`` subspace indexes,
        followed by one shared-memory object per subspace, each sized for ``n`` samples.
    """
    # Use a ctypes typecode, as the `Text` implementation does, rather than a NumPy scalar type.
    index_memory = ctx.Array(np.dtype(np.int32).char, n)
    # Forward `n` and `ctx` so each subspace buffer matches the batch size and context.
    return (index_memory,) + tuple(
        create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
    )
@create_shared_memory.register(Graph) @create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence) @create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp): def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
@@ -170,7 +176,9 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
@read_from_shared_memory.register(Text) @read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]: def _read_text_from_shared_memory(
space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape( data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length) (n, space.max_length)
) )
@@ -187,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup
) )
@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
    space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
    """Read a batch of `OneOf` samples back from shared memory.

    Args:
        space: The `OneOf` space the memory was created for.
        shared_memory: Tuple of the index array followed by one memory object per subspace.
        n: Number of environments stored in the memory.

    Returns:
        One ``(subspace_index, value)`` pair per environment.
    """
    # The index buffer is written as int32; `space.dtype` is None for `OneOf`.
    sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)

    subspace_samples = tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for (memory, subspace) in zip(shared_memory[1:], space.spaces)
    )

    # For each environment pick the entry of the subspace it selected.
    # NOTE(review): assumes each subspace's batched read is indexable by environment position.
    return tuple(
        (int(subspace_idx), subspace_samples[subspace_idx][env_idx])
        for env_idx, subspace_idx in enumerate(sample_indexes)
    )
@singledispatch @singledispatch
def write_to_shared_memory( def write_to_shared_memory(
space: Space, space: Space,
@@ -258,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
destination[index * size : (index + 1) * size], destination[index * size : (index + 1) * size],
flatten(space, values), flatten(space, values),
) )
@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
    space: OneOf, index: int, values: tuple[Any, ...], shared_memory
):
    """Write one `OneOf` sample into shared memory at environment slot ``index``.

    Args:
        space: The `OneOf` space the memory was created for.
        index: The environment slot to write to.
        values: The ``(subspace_index, value)`` sample to store.
        shared_memory: Tuple of the index array followed by one memory object per subspace.
    """
    subspace_idx, subspace_value = values

    destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
    np.copyto(destination[index : index + 1], subspace_idx)

    # Only the selected subspace's buffer is updated; the value belongs to that
    # subspace alone and need not be valid for the others.
    write_to_shared_memory(
        space.spaces[subspace_idx],
        index,
        subspace_value,
        shared_memory[1 + subspace_idx],
    )

View File

@@ -23,6 +23,7 @@ from gymnasium.spaces import (
GraphInstance, GraphInstance,
MultiBinary, MultiBinary,
MultiDiscrete, MultiDiscrete,
OneOf,
Sequence, Sequence,
Space, Space,
Text, Text,
@@ -124,8 +125,9 @@ def _batch_space_dict(space: Dict, n: int = 1):
@batch_space.register(Graph) @batch_space.register(Graph)
@batch_space.register(Text) @batch_space.register(Text)
@batch_space.register(Sequence) @batch_space.register(Sequence)
@batch_space.register(OneOf)
@batch_space.register(Space) @batch_space.register(Space)
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1): def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random # Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space # Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple( batched_space = Tuple(
@@ -297,6 +299,7 @@ def _concatenate_dict(
@concatenate.register(Text) @concatenate.register(Text)
@concatenate.register(Sequence) @concatenate.register(Sequence)
@concatenate.register(Space) @concatenate.register(Space)
@concatenate.register(OneOf)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]: def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
return tuple(items) return tuple(items)
@@ -336,7 +339,7 @@ def create_empty_array(
) )
# It is possible for the some of the Box low to be greater than 0, then array is not in space # It is possible for some of the Box low to be greater than 0, then array is not in space
@create_empty_array.register(Box) @create_empty_array.register(Box)
# If the Discrete start > 0 or start + length < 0 then array is not in space # If the Discrete start > 0 or start + length < 0 then array is not in space
@create_empty_array.register(Discrete) @create_empty_array.register(Discrete)
@@ -402,6 +405,11 @@ def _create_empty_array_sequence(
return tuple(tuple() for _ in range(n)) return tuple(tuple() for _ in range(n))
@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
    """Return ``n`` empty tuples as the placeholder batch for a `OneOf` space.

    ``space`` and ``fn`` are accepted for dispatch compatibility and are not used.
    """
    return ((),) * n
@create_empty_array.register(Space) @create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros): def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None return None

View File

@@ -14,6 +14,7 @@ from gymnasium.spaces import (
GraphInstance, GraphInstance,
MultiBinary, MultiBinary,
MultiDiscrete, MultiDiscrete,
OneOf,
Sequence, Sequence,
Text, Text,
Tuple, Tuple,
@@ -145,3 +146,8 @@ def _create_graph_zero_array(space: Graph):
edges = np.expand_dims(create_zero_array(space.edge_space), axis=0) edges = np.expand_dims(create_zero_array(space.edge_space), axis=0)
edge_links = np.zeros((1, 2), dtype=np.int64) edge_links = np.zeros((1, 2), dtype=np.int64)
return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links) return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)
@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
    """Return the zero sample of a `OneOf`: index 0 paired with the first subspace's zero array."""
    first_subspace = space.spaces[0]
    return (0, create_zero_array(first_subspace))

View File

@@ -0,0 +1,72 @@
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
from gymnasium.utils.env_checker import data_equivalence
def test_oneof_inheritance():
    """Check that `OneOf` exposes its subspaces through len, indexing and iteration."""
    subspaces = [Discrete(5), Box(-1, 1, shape=(3,)), MultiBinary(2)]
    union_space = OneOf(subspaces)

    assert len(union_space) == len(subspaces)

    # Indexing returns the subspaces in construction order
    for position in range(len(union_space)):
        expected = subspaces[position]
        assert union_space[position] == expected

    # Iterating yields only the configured subspaces
    for member in union_space:
        assert member in subspaces
@pytest.mark.parametrize(
    "spaces, seed, expected_len",
    [
        ([Discrete(5), Box(-1, 1, shape=(3,))], None, 3),
        ([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3),
        ([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3),
    ],
)
def test_oneof_seeds(spaces, seed, expected_len):
    """Check the seeds returned by `OneOf.seed` and that re-seeding reproduces samples."""
    oneof_space = OneOf(spaces)
    seeds = oneof_space.seed(seed)
    assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
    assert len(seeds) == expected_len

    sample1 = oneof_space.sample()

    seeds2 = oneof_space.seed(seed)
    sample2 = oneof_space.sample()

    # `data_equivalence` only returns a bool, so its result has to be asserted.
    # With `seed=None` every call draws fresh entropy, so reproducibility is
    # only expected for explicit seeds.
    if seed is not None:
        assert data_equivalence(seeds, seeds2)
        assert data_equivalence(sample1, sample2)
@pytest.mark.parametrize(
    "spaces_fn",
    (
        # a single non-space element
        lambda: OneOf(["abc"]),
        # a valid space followed by a non-space element
        lambda: OneOf([Box(0, 1), "abc"]),
        # a string is iterated character by character, none of which is a space
        lambda: OneOf("abc"),
    ),
)
def test_bad_oneof_calls(spaces_fn):
    """Constructing a `OneOf` from non-space elements must raise an AssertionError."""
    with pytest.raises(AssertionError):
        spaces_fn()
def test_oneof_contains():
    """Check that valid ``(index, value)`` pairs are members of the `OneOf` space."""
    union_space = OneOf([Box(0, 1), Box(-1, 0, (2,))])

    first_member = (0, np.array([0.5], dtype=np.float32))
    second_member = (1, np.array([-0.5, -0.5], dtype=np.float32))

    assert first_member in union_space
    assert second_member in union_space
def test_bad_oneof_seed():
    """Seeding a `OneOf` with an unsupported type (a float) must raise a TypeError."""
    expected_message = "Expected seed type: list, tuple, int or None, actual type: <class 'float'>"
    space = OneOf([Box(0, 1), Box(0, 1)])

    with pytest.raises(TypeError, match=expected_message):
        space.seed(0.0)

View File

@@ -534,6 +534,7 @@ SPACE_KWARGS = [
{"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict {"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
{"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph {"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
{"space": Discrete(4)}, # Sequence {"space": Discrete(4)}, # Sequence
{"spaces": (Discrete(3), Discrete(5))}, # OneOf
] ]
assert len(SPACE_CLS) == len(SPACE_KWARGS) assert len(SPACE_CLS) == len(SPACE_KWARGS)

View File

@@ -54,6 +54,11 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
None, None,
None, None,
None, None,
None,
None,
# OneOf
4,
5,
] ]
@@ -106,7 +111,8 @@ def test_flatten_space(space):
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flatten(space): def test_flatten(space):
"""Test that a flattened sample have the `flatdim` shape.""" """Test that a flattened sample have the `flatdim` shape."""
flattened_sample = utils.flatten(space, space.sample()) sample = space.sample()
flattened_sample = utils.flatten(space, sample)
if space.is_np_flattenable: if space.is_np_flattenable:
assert isinstance(flattened_sample, np.ndarray) assert isinstance(flattened_sample, np.ndarray)

View File

@@ -9,6 +9,7 @@ from gymnasium.spaces import (
Graph, Graph,
MultiBinary, MultiBinary,
MultiDiscrete, MultiDiscrete,
OneOf,
Sequence, Sequence,
Space, Space,
Text, Text,
@@ -108,6 +109,9 @@ TESTING_COMPOSITE_SPACES = [
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))), Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
Sequence(Box(low=0.0, high=1.0), stack=True), Sequence(Box(low=0.0, high=1.0), stack=True),
Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True), Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True),
# OneOf spaces
OneOf([Discrete(3), Box(low=0.0, high=1.0)]),
OneOf([MultiBinary(2), MultiDiscrete([2, 2])]),
] ]
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES] TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]