Update the experimental vector shared memory util functions (#339)

This commit is contained in:
Mark Towers
2023-02-20 16:02:12 +00:00
committed by GitHub
parent 31277a8f5b
commit b3685f51a2
4 changed files with 173 additions and 186 deletions

View File

@@ -5,6 +5,7 @@ import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Any
import numpy as np
@@ -13,10 +14,14 @@ from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
flatten,
)
@@ -24,7 +29,9 @@ __all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_m
@singledispatch
def create_shared_memory(space: Space, n: int = 1, ctx=mp) -> dict | tuple | mp.Array:
def create_shared_memory(
space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
"""Create a shared memory object, to be shared across processes.
This eventually contains the observations from the vectorized environment.
@@ -40,17 +47,24 @@ def create_shared_memory(space: Space, n: int = 1, ctx=mp) -> dict | tuple | mp.
Raises:
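CustomSpaceError: Space is a :class:`gymnasium.Space` instance with no registered `create_shared_memory` function
TypeError: Space is not a :class:`gymnasium.Space` instance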
"""
raise CustomSpaceError(
f"Cannot create a shared memory for space with type `{type(space)}`. "
"`create_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
def _create_base_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
assert space.dtype is not None
dtype = space.dtype.char
if dtype in "?":
dtype = c_bool
@@ -58,14 +72,14 @@ def _create_base_shared_memory(space, n: int = 1, ctx=mp):
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space, n=1, ctx=mp):
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
@@ -74,10 +88,23 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
)
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
    """Allocate the shared array backing a batch of ``n`` :class:`Text` observations.

    Every observation is given ``space.max_length`` int32 slots, so the array
    holds ``n * space.max_length`` elements in total.

    Args:
        space: The text space whose ``max_length`` sizes each observation.
        n: Number of observations the array must hold.
        ctx: Object providing ``Array`` (the ``multiprocessing`` module by default).

    Returns:
        The array created by ``ctx.Array``.
    """
    int32_typecode = np.dtype(np.int32).char
    total_slots = n * space.max_length
    return ctx.Array(int32_typecode, total_slots)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
    """Refuse to create shared memory for :class:`Graph` and :class:`Sequence` spaces.

    Args:
        space: The graph or sequence space that was requested.
        n: Unused; accepted to match the other registered implementations.
        ctx: Unused; accepted to match the other registered implementations.

    Raises:
        TypeError: Always, with a message naming the offending space.
    """
    message = f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
    raise TypeError(message)
@singledispatch
def read_from_shared_memory(
space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict | tuple | np.ndarray:
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
"""Read the batch of observations from shared memory as a numpy array.
.. note::
@@ -97,24 +124,30 @@ def read_from_shared_memory(
Raises:
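CustomSpaceError: Space is a :class:`gymnasium.Space` instance with no registered `read_from_shared_memory` function
TypeError: Space is not a :class:`gymnasium.Space` instance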
"""
raise CustomSpaceError(
f"Cannot read from a shared memory for space with type `{type(space)}`. "
"`read_from_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
def _read_base_from_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
return tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
@@ -122,7 +155,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
return OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
@@ -131,12 +164,30 @@ def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
)
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(
    space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
    """Decode a batch of strings from the int32 shared memory of a :class:`Text` space.

    Args:
        space: The text space providing ``max_length``, ``character_list`` and
            ``character_set``.
        shared_memory: Object whose ``get_obj()`` exposes a buffer that is read as
            ``n * space.max_length`` int32 values.
        n: Number of rows (observations) stored in the buffer.

    Returns:
        A tuple with one decoded string per row of the buffer, i.e. ``n`` strings.
    """
    data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
        (n, space.max_length)
    )
    # Loop invariants: look these up once instead of once per stored character.
    characters = space.character_list
    num_characters = len(space.character_set)
    # Values at or beyond the character-set size are skipped rather than decoded
    # (presumably padding for strings shorter than ``max_length`` -- TODO confirm
    # against how ``flatten`` encodes ``Text`` samples).
    return tuple(
        "".join(characters[val] for val in values if val < num_characters)
        for values in data
    )
@singledispatch
def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: dict | tuple | mp.Array,
shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
):
"""Write the observation of a single environment into shared memory.
@@ -150,17 +201,26 @@ def write_to_shared_memory(
Raises:
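CustomSpaceError: Space is a :class:`gymnasium.Space` instance with no registered `write_to_shared_memory` function
TypeError: Space is not a :class:`gymnasium.Space` instance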
"""
raise CustomSpaceError(
f"Cannot write to a shared memory for space with type `{type(space)}`. "
"`write_to_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(space, index, value, shared_memory):
def _write_base_to_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary,
index: int,
value,
shared_memory,
):
size = int(np.prod(space.shape))
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
np.copyto(
@@ -170,12 +230,26 @@ def _write_base_to_shared_memory(space, index, value, shared_memory):
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(space, index, values, shared_memory):
def _write_tuple_to_shared_memory(
space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
for value, memory, subspace in zip(values, shared_memory, space.spaces):
write_to_shared_memory(subspace, index, value, memory)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(space, index, values, shared_memory):
def _write_dict_to_shared_memory(
space: Dict, index: int, values: dict[str, Any], shared_memory
):
for key, subspace in space.spaces.items():
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
    """Write one :class:`Text` observation into its slot of the shared memory.

    Args:
        space: The text space; ``max_length`` gives the width of each slot.
        index: Position of the observation within the batch.
        values: The string to store; it is encoded with ``flatten(space, values)``.
        shared_memory: Object whose ``get_obj()`` exposes the int32 buffer to write to.
    """
    width = space.max_length
    start = index * width
    buffer = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
    encoded = flatten(space, values)
    # The slot for observation ``index`` spans ``width`` consecutive int32 values.
    np.copyto(buffer[start : start + width], encoded)

View File

@@ -39,7 +39,7 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
for space in self.spaces:
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gym.Space"
), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
super().__init__(None, None, seed) # type: ignore
@property

View File

@@ -1,187 +1,93 @@
"""Tests `gymnasium.experimental.vector.utils.shared_memory functions."""
import multiprocessing as mp
from collections import OrderedDict
from multiprocessing import Array, Process
from multiprocessing.sharedctypes import SynchronizedArray
import re
import numpy as np
import pytest
from gymnasium import Space
from gymnasium.error import CustomSpaceError
from gymnasium.experimental.vector.utils import (
create_shared_memory,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.spaces import Dict, Tuple
from gymnasium.vector.utils import BaseGymSpaces
from tests.experimental.vector.testing_utils import custom_spaces, spaces
from gymnasium.utils.env_checker import data_equivalence
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
expected_types = [
Array("d", 1),
Array("f", 1),
Array("f", 3),
Array("f", 4),
Array("B", 1),
Array("B", 32 * 32 * 3),
Array("i", 1),
Array("i", 1),
(Array("i", 1), Array("i", 1)),
(Array("i", 1), Array("f", 2)),
Array("B", 3),
Array("B", 19),
OrderedDict([("position", Array("i", 1)), ("velocity", Array("f", 1))]),
OrderedDict(
[
("position", OrderedDict([("x", Array("i", 1)), ("y", Array("i", 1))])),
("velocity", (Array("i", 1), Array("B", 1))),
]
),
]
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space,expected_type",
list(zip(spaces, expected_types)),
ids=[space.__class__.__name__ for space in spaces],
)
@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)
@pytest.mark.parametrize("num", [1, 8])
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
def test_create_shared_memory(space, expected_type, n, ctx):
"""Tests the `create_shared_memory` function with a number of spaces."""
def test_shared_memory_create_read_write(space, num, ctx):
"""Test the shared memory functions, create, read and write for all of the testing spaces."""
ctx = mp if ctx is None else mp.get_context(ctx)
samples = [space.sample() for _ in range(num)]
def assert_nested_type(lhs, rhs, n):
assert type(lhs) == type(rhs)
if isinstance(lhs, (list, tuple)):
assert len(lhs) == len(rhs)
for lhs_, rhs_ in zip(lhs, rhs):
assert_nested_type(lhs_, rhs_, n)
try:
shared_memory = create_shared_memory(space, n=num, ctx=ctx)
except TypeError:
return
elif isinstance(lhs, (dict, OrderedDict)):
assert set(lhs.keys()) ^ set(rhs.keys()) == set()
for key in lhs.keys():
assert_nested_type(lhs[key], rhs[key], n)
for i, sample in enumerate(samples):
write_to_shared_memory(space, i, sample, shared_memory)
elif isinstance(lhs, SynchronizedArray):
# Assert the length of the array
assert len(lhs[:]) == n * len(rhs[:])
# Assert the data type
assert isinstance(lhs[0], type(rhs[0]))
else:
raise TypeError(f"Got unknown type `{type(lhs)}`.")
ctx = mp if (ctx is None) else mp.get_context(ctx)
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
assert_nested_type(shared_memory, expected_type, n=n)
read_samples = read_from_shared_memory(space, shared_memory, n=num)
for read_sample, sample in zip(read_samples, samples):
data_equivalence(read_sample, sample)
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
@pytest.mark.parametrize("space", custom_spaces)
def test_create_shared_memory_custom_space(n, ctx, space):
"""Tests the `create_shared_memory` function with a custom space."""
ctx = mp if (ctx is None) else mp.get_context(ctx)
with pytest.raises(CustomSpaceError):
create_shared_memory(space, n=n, ctx=ctx)
def test_custom_space():
    """Check that each shared memory function rejects an unregistered ``Space``.

    A bare ``Space()`` instance is passed to ``create_shared_memory``,
    ``read_from_shared_memory`` and ``write_to_shared_memory``; each call must
    raise ``CustomSpaceError`` with the exact message for that function.
    """
    create_message = "Space of type `<class 'gymnasium.spaces.space.Space'>` doesn't have an registered `create_shared_memory` function. Register `<class 'gymnasium.spaces.space.Space'>` for `create_shared_memory` to support it."
    read_message = "Space of type `<class 'gymnasium.spaces.space.Space'>` doesn't have an registered `read_from_shared_memory` function. Register `<class 'gymnasium.spaces.space.Space'>` for `read_from_shared_memory` to support it."
    write_message = "Space of type `<class 'gymnasium.spaces.space.Space'>` doesn't have an registered `write_to_shared_memory` function. Register `<class 'gymnasium.spaces.space.Space'>` for `write_to_shared_memory` to support it."

    # The messages contain regex metacharacters, so they are escaped before matching.
    with pytest.raises(CustomSpaceError, match=re.escape(create_message)):
        create_shared_memory(Space())

    with pytest.raises(CustomSpaceError, match=re.escape(read_message)):
        read_from_shared_memory(Space(), None, 1)

    with pytest.raises(CustomSpaceError, match=re.escape(write_message)):
        write_to_shared_memory(Space(), 1, None, None)
def _write_shared_memory(space, i, shared_memory, sample):
write_to_shared_memory(space, i, sample, shared_memory)
def test_non_space():
"""Test the use of non-space types on the shared memory functions."""
with pytest.raises(
TypeError,
match=re.escape(
"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: <class 'str'>, space"
),
):
create_shared_memory("space")
with pytest.raises(
TypeError,
match=re.escape(
"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: <class 'str'>, space"
),
):
read_from_shared_memory("space", None, 1)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_write_to_shared_memory(space):
"""Tests `write_to_shared_memory` function with a list of spaces."""
def assert_nested_equal(lhs, rhs):
assert isinstance(rhs, list)
if isinstance(lhs, (list, tuple)):
for i in range(len(lhs)):
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs])
elif isinstance(lhs, (dict, OrderedDict)):
for key in lhs.keys():
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs])
elif isinstance(lhs, SynchronizedArray):
assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())
else:
raise TypeError(f"Got unknown type `{type(lhs)}`.")
shared_memory_n8 = create_shared_memory(space, n=8)
samples = [space.sample() for _ in range(8)]
processes = [
Process(
target=_write_shared_memory, args=(space, i, shared_memory_n8, samples[i])
)
for i in range(8)
]
for process in processes:
process.start()
for process in processes:
process.join()
assert_nested_equal(shared_memory_n8, samples)
def _process_write(space, i, shared_memory, sample):
write_to_shared_memory(space, i, sample, shared_memory)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_read_from_shared_memory(space):
"""Tests `read_from_shared_memory` function with list of spaces."""
def assert_nested_equal(lhs, rhs, space, n):
assert isinstance(rhs, list)
if isinstance(space, Tuple):
assert isinstance(lhs, tuple)
for i in range(len(lhs)):
assert_nested_equal(
lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n
)
elif isinstance(space, Dict):
assert isinstance(lhs, OrderedDict)
for key in lhs.keys():
assert_nested_equal(
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
)
elif isinstance(space, BaseGymSpaces):
assert isinstance(lhs, np.ndarray)
assert lhs.shape == ((n,) + space.shape)
assert lhs.dtype == space.dtype
assert np.all(lhs == np.stack(rhs, axis=0))
else:
raise TypeError(f"Got unknown type `{type(space)}`")
shared_memory_n8 = create_shared_memory(space, n=8)
memory_view_n8 = read_from_shared_memory(space, shared_memory_n8, n=8)
samples = [space.sample() for _ in range(8)]
processes = [
Process(target=_process_write, args=(space, i, shared_memory_n8, samples[i]))
for i in range(8)
]
for process in processes:
process.start()
for process in processes:
process.join()
assert_nested_equal(memory_view_n8, samples, space, n=8)
with pytest.raises(
TypeError,
match=re.escape(
"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: <class 'str'>, space"
),
):
write_to_shared_memory("space", 1, None, None)

View File

@@ -111,3 +111,10 @@ TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS + TESTING_COMPOSITE_SPACES_IDS
# Spaces built directly from the bare `Space` base class: one on its own, one
# wrapped in a `Tuple` and one wrapped in a `Dict`.
# NOTE(review): presumably these stand in for user-defined spaces in tests -- confirm against callers.
CUSTOM_SPACES = [
Space(),
Tuple([Space(), Space(), Space()]),
Dict(a=Space(), b=Space()),
]
# String form of each custom space, in the same order as `CUSTOM_SPACES`.
CUSTOM_SPACES_IDS = [f"{space}" for space in CUSTOM_SPACES]