Add a OneOf space for exclusive unions (#812)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2024-03-11 13:30:50 +01:00
committed by GitHub
parent fd4ae52045
commit 2b2e853536
12 changed files with 353 additions and 5 deletions

View File

@@ -66,7 +66,8 @@ Often environment spaces require joining fundamental spaces together for vectori
* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
* :class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces
* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values.
* :class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values
* :class:`OneOf` - Supports optional action spaces such that an action can be one of N possible subspaces
```
## Utility functions

View File

@@ -21,4 +21,9 @@
.. automethod:: gymnasium.spaces.Graph.sample
.. automethod:: gymnasium.spaces.Graph.seed
.. autoclass:: gymnasium.spaces.OneOf
.. automethod:: gymnasium.spaces.OneOf.sample
.. automethod:: gymnasium.spaces.OneOf.seed
```

View File

@@ -16,6 +16,7 @@ from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.graph import Graph, GraphInstance
from gymnasium.spaces.multi_binary import MultiBinary
from gymnasium.spaces.multi_discrete import MultiDiscrete
from gymnasium.spaces.oneof import OneOf
from gymnasium.spaces.sequence import Sequence
from gymnasium.spaces.space import Space
from gymnasium.spaces.text import Text
@@ -38,6 +39,7 @@ __all__ = [
"Tuple",
"Sequence",
"Dict",
"OneOf",
# util functions (there are more utility functions in vector/utils/spaces.py)
"flatdim",
"flatten_space",

158
gymnasium/spaces/oneof.py Normal file
View File

@@ -0,0 +1,158 @@
"""Implementation of a space that represents the cartesian product of other spaces."""
from __future__ import annotations
import collections.abc
import typing
from typing import Any, Iterable
import numpy as np
from gymnasium.spaces.space import Space
class OneOf(Space[Any]):
"""An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances.
Elements of this space are elements of one of the constituent spaces.
Example:
>>> from gymnasium.spaces import OneOf, Box, Discrete
>>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
>>> observation_space.sample() # the first element is the space index (Box in this case) and the second element is the sample from Box
(1, array([-0.3991573 , 0.21649833], dtype=float32))
>>> observation_space.sample() # this time the Discrete space was sampled as index=0
(0, 0)
>>> observation_space[0]
Discrete(2)
>>> observation_space[1]
Box(-1.0, 1.0, (2,), float32)
>>> len(observation_space)
2
"""
def __init__(
self,
spaces: Iterable[Space[Any]],
seed: int | typing.Sequence[int] | np.random.Generator | None = None,
):
r"""Constructor of :class:`OneOf` space.
The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`.
Args:
spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
"""
self.spaces = tuple(spaces)
assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
for space in self.spaces:
assert isinstance(
space, Space
), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
super().__init__(None, None, seed)
@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)
def seed(self, seed: int | typing.Sequence[int] | None = None) -> list[int]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
"""
if isinstance(seed, collections.abc.Sequence):
assert (
len(seed) == len(self.spaces) + 1
), f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
seeds = super().seed(seed[0])
for subseed, space in zip(seed, self.spaces):
seeds += space.seed(subseed)
elif isinstance(seed, int):
seeds = super().seed(seed)
subseeds = self.np_random.integers(
np.iinfo(np.int32).max, size=len(self.spaces)
)
for subspace, subseed in zip(self.spaces, subseeds):
seeds += subspace.seed(int(subseed))
elif seed is None:
seeds = super().seed(None)
for space in self.spaces:
seeds += space.seed(None)
else:
raise TypeError(
f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
)
return seeds
def sample(self, mask: tuple[Any | None, ...] | None = None) -> tuple[int, Any]:
"""Generates a single random sample inside this space.
This method draws independent samples from the subspaces.
Args:
mask: An optional tuple of optional masks for each of the subspace's samples,
expects the same number of masks as spaces
Returns:
Tuple of the subspace's samples
"""
subspace_idx = int(self.np_random.integers(0, len(self.spaces)))
subspace = self.spaces[subspace_idx]
if mask is not None:
assert isinstance(
mask, tuple
), f"Expected type of mask is tuple, actual type: {type(mask)}"
assert len(mask) == len(
self.spaces
), f"Expected length of mask is {len(self.spaces)}, actual length: {len(mask)}"
mask = mask[subspace_idx]
return subspace_idx, subspace.sample(mask=mask)
def contains(self, x: tuple[int, Any]) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
(idx, value) = x
return isinstance(x, tuple) and self.spaces[idx].contains(value)
def __repr__(self) -> str:
"""Gives a string representation of this space."""
return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"
def to_jsonable(
self, sample_n: typing.Sequence[tuple[int, Any]]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return [
[int(i), self.spaces[i].to_jsonable([subsample])[0]]
for (i, subsample) in sample_n
]
def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [
(space_idx, self.spaces[space_idx].from_jsonable([jsonable_sample])[0])
for space_idx, jsonable_sample in sample_n
]
def __getitem__(self, index: int) -> Space[Any]:
"""Get the subspace at specific `index`."""
return self.spaces[index]
def __len__(self) -> int:
"""Get the number of subspaces that are involved in the cartesian product."""
return len(self.spaces)
def __eq__(self, other: Any) -> bool:
"""Check whether ``other`` is equivalent to this instance."""
return isinstance(other, OneOf) and self.spaces == other.spaces

View File

@@ -23,6 +23,7 @@ from gymnasium.spaces import (
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
@@ -104,6 +105,11 @@ def _flatdim_text(space: Text) -> int:
return space.max_length
@flatdim.register(OneOf)
def _flatdim_oneof(space: OneOf) -> int:
return 1 + max(flatdim(s) for s in space.spaces)
T = TypeVar("T")
FlatType = Union[
NDArray[Any], typing.Dict[str, Any], typing.Tuple[Any, ...], GraphInstance
@@ -256,6 +262,22 @@ def _flatten_sequence(
return tuple(flatten(space.feature_space, item) for item in x)
@flatten.register(OneOf)
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
idx, sample = x
sub_space = space.spaces[idx]
flat_sample = flatten(sub_space, sample)
max_flatdim = flatdim(space) - 1 # Don't include the index
if flat_sample.size < max_flatdim:
padding = np.full(
max_flatdim - flat_sample.size, flat_sample[0], dtype=flat_sample.dtype
)
flat_sample = np.concatenate([flat_sample, padding])
return np.concatenate([[idx], flat_sample])
@singledispatch
def unflatten(space: Space[T], x: FlatType) -> T:
"""Unflatten a data point from a space.
@@ -399,6 +421,17 @@ def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...]
return tuple(unflatten(space.feature_space, item) for item in x)
@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
idx = int(x[0])
sub_space = space.spaces[idx]
original_size = flatdim(sub_space)
trimmed_sample = x[1 : 1 + original_size]
return idx, unflatten(sub_space, trimmed_sample)
@singledispatch
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
"""Flatten a space into a space that is as flat as possible.
@@ -525,3 +558,21 @@ def _flatten_space_text(space: Text) -> Box:
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
return Sequence(flatten_space(space.feature_space), stack=space.stack)
@flatten_space.register(OneOf)
def _flatten_space_oneof(space: OneOf) -> Box:
num_subspaces = len(space.spaces)
max_flatdim = max(flatdim(s) for s in space.spaces) + 1
lows = np.array([np.min(flatten_space(s).low) for s in space.spaces])
highs = np.array([np.max(flatten_space(s).high) for s in space.spaces])
overall_low = np.min(lows)
overall_high = np.max(highs)
low = np.concatenate([[0], np.full(max_flatdim - 1, overall_low)])
high = np.concatenate([[num_subspaces - 1], np.full(max_flatdim - 1, overall_high)])
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)

View File

@@ -17,6 +17,7 @@ from gymnasium.spaces import (
Graph,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
@@ -93,6 +94,11 @@ def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
return (ctx.Array(np.int32, n),) + _create_tuple_shared_memory(space)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
@@ -170,7 +176,9 @@ def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
def _read_text_from_shared_memory(
space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
@@ -187,6 +195,21 @@ def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tup
)
@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
sample_indexes = np.frombuffer(shared_memory[0].get_obj(), dtype=space.dtype)
subspace_samples = tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory[1:], space.spaces)
)
return tuple(
(index, sample[index])
for index, sample in zip(sample_indexes, subspace_samples)
)
@singledispatch
def write_to_shared_memory(
space: Space,
@@ -258,3 +281,14 @@ def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_me
destination[index * size : (index + 1) * size],
flatten(space, values),
)
@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
space: OneOf, index: int, values: tuple[Any, ...], shared_memory
):
destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int32)
np.copyto(destination[index : index + 1], values[0])
for value, memory, subspace in zip(values[1], shared_memory[1:], space.spaces):
write_to_shared_memory(subspace, index, value, memory)

View File

@@ -23,6 +23,7 @@ from gymnasium.spaces import (
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
@@ -124,8 +125,9 @@ def _batch_space_dict(space: Dict, n: int = 1):
@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(OneOf)
@batch_space.register(Space)
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1):
def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple(
@@ -297,6 +299,7 @@ def _concatenate_dict(
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
@concatenate.register(OneOf)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
return tuple(items)
@@ -336,7 +339,7 @@ def create_empty_array(
)
# It is possible for the some of the Box low to be greater than 0, then array is not in space
# It is possible for some of the Box low to be greater than 0, then array is not in space
@create_empty_array.register(Box)
# If the Discrete start > 0 or start + length < 0 then array is not in space
@create_empty_array.register(Discrete)
@@ -402,6 +405,11 @@ def _create_empty_array_sequence(
return tuple(tuple() for _ in range(n))
@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
return tuple(tuple() for _ in range(n))
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None

View File

@@ -14,6 +14,7 @@ from gymnasium.spaces import (
GraphInstance,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Text,
Tuple,
@@ -145,3 +146,8 @@ def _create_graph_zero_array(space: Graph):
edges = np.expand_dims(create_zero_array(space.edge_space), axis=0)
edge_links = np.zeros((1, 2), dtype=np.int64)
return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)
@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
return 0, create_zero_array(space.spaces[0])

View File

@@ -0,0 +1,72 @@
import numpy as np
import pytest
from gymnasium.spaces import Box, Discrete, MultiBinary, OneOf
from gymnasium.utils.env_checker import data_equivalence
def test_oneof_inheritance():
"""Tests that OneOf space properly inherits and implements required methods."""
spaces = [Discrete(5), Box(-1, 1, shape=(3,)), MultiBinary(2)]
oneof_space = OneOf(spaces)
assert len(oneof_space) == len(spaces)
# Test indexing
for i in range(len(oneof_space)):
assert oneof_space[i] == spaces[i]
# Test iterable
for space in oneof_space:
assert space in spaces
@pytest.mark.parametrize(
"spaces, seed, expected_len",
[
([Discrete(5), Box(-1, 1, shape=(3,))], None, 3),
([Discrete(5), Box(-1, 1, shape=(3,))], 123, 3),
([Discrete(5), Box(-1, 1, shape=(3,))], [123, 456, 789], 3),
],
)
def test_oneof_seeds(spaces, seed, expected_len):
oneof_space = OneOf(spaces)
seeds = oneof_space.seed(seed)
assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
assert len(seeds) == expected_len
sample1 = oneof_space.sample()
seeds2 = oneof_space.seed(seed)
sample2 = oneof_space.sample()
data_equivalence(seeds, seeds2)
data_equivalence(sample1, sample2)
@pytest.mark.parametrize(
"spaces_fn",
[
lambda: OneOf(["abc"]),
lambda: OneOf([Box(0, 1), "abc"]),
lambda: OneOf("abc"),
],
)
def test_bad_oneof_calls(spaces_fn):
with pytest.raises(AssertionError):
spaces_fn()
def test_oneof_contains():
space = OneOf([Box(0, 1), Box(-1, 0, (2,))])
assert (0, np.array([0.5], dtype=np.float32)) in space
assert (1, np.array([-0.5, -0.5], dtype=np.float32)) in space
def test_bad_oneof_seed():
space = OneOf([Box(0, 1), Box(0, 1)])
with pytest.raises(
TypeError,
match="Expected seed type: list, tuple, int or None, actual type: <class 'float'>",
):
space.seed(0.0)

View File

@@ -534,6 +534,7 @@ SPACE_KWARGS = [
{"spaces": {"a": Discrete(3), "b": Discrete(2)}}, # Dict
{"node_space": Discrete(4), "edge_space": Discrete(3)}, # Graph
{"space": Discrete(4)}, # Sequence
{"spaces": (Discrete(3), Discrete(5))}, # OneOf
]
assert len(SPACE_CLS) == len(SPACE_KWARGS)

View File

@@ -54,6 +54,11 @@ TESTING_SPACES_EXPECTED_FLATDIMS = [
None,
None,
None,
None,
None,
# OneOf
4,
5,
]
@@ -106,7 +111,8 @@ def test_flatten_space(space):
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
def test_flatten(space):
"""Test that a flattened sample have the `flatdim` shape."""
flattened_sample = utils.flatten(space, space.sample())
sample = space.sample()
flattened_sample = utils.flatten(space, sample)
if space.is_np_flattenable:
assert isinstance(flattened_sample, np.ndarray)

View File

@@ -9,6 +9,7 @@ from gymnasium.spaces import (
Graph,
MultiBinary,
MultiDiscrete,
OneOf,
Sequence,
Space,
Text,
@@ -108,6 +109,9 @@ TESTING_COMPOSITE_SPACES = [
Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
Sequence(Box(low=0.0, high=1.0), stack=True),
Sequence(Dict({"a": Box(0, 1, (3,)), "b": Discrete(5)}), stack=True),
# OneOf spaces
OneOf([Discrete(3), Box(low=0.0, high=1.0)]),
OneOf([MultiBinary(2), MultiDiscrete([2, 2])]),
]
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]