Update the experimental vector shared memory util functions (#339)

This commit is contained in:
Mark Towers
2023-02-20 16:02:12 +00:00
committed by GitHub
parent 31277a8f5b
commit b3685f51a2
4 changed files with 173 additions and 186 deletions

View File

@@ -5,6 +5,7 @@ import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Any
import numpy as np
@@ -13,10 +14,14 @@ from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
flatten,
)
@@ -24,7 +29,9 @@ __all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_m
@singledispatch
def create_shared_memory(space: Space, n: int = 1, ctx=mp) -> dict | tuple | mp.Array:
def create_shared_memory(
space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
"""Create a shared memory object, to be shared across processes.
This eventually contains the observations from the vectorized environment.
@@ -40,17 +47,24 @@ def create_shared_memory(space: Space, n: int = 1, ctx=mp) -> dict | tuple | mp.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
f"Cannot create a shared memory for space with type `{type(space)}`. "
"`create_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
def _create_base_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
assert space.dtype is not None
dtype = space.dtype.char
if dtype in "?":
dtype = c_bool
@@ -58,14 +72,14 @@ def _create_base_shared_memory(space, n: int = 1, ctx=mp):
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space, n=1, ctx=mp):
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
@@ -74,10 +88,23 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
)
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
raise TypeError(
f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
)
@singledispatch
def read_from_shared_memory(
space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict | tuple | np.ndarray:
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
"""Read the batch of observations from shared memory as a numpy array.
..notes::
@@ -97,24 +124,30 @@ def read_from_shared_memory(
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
f"Cannot read from a shared memory for space with type `{type(space)}`. "
"`read_from_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
def _read_base_from_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
return tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
@@ -122,7 +155,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
return OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
@@ -131,12 +164,30 @@ def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
)
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
return tuple(
"".join(
[
space.character_list[val]
for val in values
if val < len(space.character_set)
]
)
for values in data
)
@singledispatch
def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: dict | tuple | mp.Array,
shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
):
"""Write the observation of a single environment into shared memory.
@@ -150,17 +201,26 @@ def write_to_shared_memory(
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
f"Cannot write to a shared memory for space with type `{type(space)}`. "
"`write_to_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(space, index, value, shared_memory):
def _write_base_to_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary,
index: int,
value,
shared_memory,
):
size = int(np.prod(space.shape))
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
np.copyto(
@@ -170,12 +230,26 @@ def _write_base_to_shared_memory(space, index, value, shared_memory):
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(space, index, values, shared_memory):
def _write_tuple_to_shared_memory(
space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
for value, memory, subspace in zip(values, shared_memory, space.spaces):
write_to_shared_memory(subspace, index, value, memory)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(space, index, values, shared_memory):
def _write_dict_to_shared_memory(
space: Dict, index: int, values: dict[str, Any], shared_memory
):
for key, subspace in space.spaces.items():
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
size = space.max_length
destination = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
np.copyto(
destination[index * size : (index + 1) * size],
flatten(space, values),
)