Fix reading shared memory for Tuple and Dict spaces (#941)

This commit is contained in:
Mark Towers
2024-02-27 16:27:34 +00:00
committed by GitHub
parent b3f0361f91
commit 36598b939b
2 changed files with 9 additions and 4 deletions

View File

@@ -148,20 +148,25 @@ def _read_base_from_shared_memory(
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
return tuple(
subspace_samples = tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
return tuple(zip(*subspace_samples))
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
return OrderedDict(
subspace_samples = OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
for (key, subspace) in space.spaces.items()
]
)
return tuple(
OrderedDict({key: subspace_samples[key][i] for key in space.keys()})
for i in range(n)
)
@read_from_shared_memory.register(Text)