Fix reading shared memory for Tuple and Dict spaces (#941)

This commit is contained in:
Mark Towers
2024-02-27 16:27:34 +00:00
committed by GitHub
parent b3f0361f91
commit 36598b939b
2 changed files with 9 additions and 4 deletions

View File

@@ -148,20 +148,25 @@ def _read_base_from_shared_memory(
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
return tuple(
subspace_samples = tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
return tuple(zip(*subspace_samples))
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
return OrderedDict(
subspace_samples = OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
for (key, subspace) in space.spaces.items()
]
)
return tuple(
OrderedDict({key: subspace_samples[key][i] for key in space.keys()})
for i in range(n)
)
@read_from_shared_memory.register(Text)

View File

@@ -22,7 +22,7 @@ from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
def test_shared_memory_create_read_write(space, num, ctx):
"""Test the shared memory functions, create, read and write for all of the testing spaces."""
"""Test the shared memory functions, create, read and write for all testing spaces."""
if ctx not in mp.get_all_start_methods():
pytest.skip(
f"Multiprocessing start method {ctx} not available on this platform."
@@ -41,7 +41,7 @@ def test_shared_memory_create_read_write(space, num, ctx):
read_samples = read_from_shared_memory(space, shared_memory, n=num)
for read_sample, sample in zip(read_samples, samples):
data_equivalence(read_sample, sample)
assert data_equivalence(read_sample, sample)
def test_custom_space():