mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 13:32:03 +00:00
Fix reading shared memory for Tuple
and Dict
spaces (#941)
This commit is contained in:
@@ -148,20 +148,25 @@ def _read_base_from_shared_memory(
|
||||
|
||||
@read_from_shared_memory.register(Tuple)
|
||||
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
|
||||
return tuple(
|
||||
subspace_samples = tuple(
|
||||
read_from_shared_memory(subspace, memory, n=n)
|
||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
||||
)
|
||||
return tuple(zip(*subspace_samples))
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Dict)
|
||||
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
|
||||
return OrderedDict(
|
||||
subspace_samples = OrderedDict(
|
||||
[
|
||||
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
)
|
||||
return tuple(
|
||||
OrderedDict({key: subspace_samples[key][i] for key in space.keys()})
|
||||
for i in range(n)
|
||||
)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Text)
|
||||
|
@@ -22,7 +22,7 @@ from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
||||
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
|
||||
)
|
||||
def test_shared_memory_create_read_write(space, num, ctx):
|
||||
"""Test the shared memory functions, create, read and write for all of the testing spaces."""
|
||||
"""Test the shared memory functions, create, read and write for all testing spaces."""
|
||||
if ctx not in mp.get_all_start_methods():
|
||||
pytest.skip(
|
||||
f"Multiprocessing start method {ctx} not available on this platform."
|
||||
@@ -41,7 +41,7 @@ def test_shared_memory_create_read_write(space, num, ctx):
|
||||
|
||||
read_samples = read_from_shared_memory(space, shared_memory, n=num)
|
||||
for read_sample, sample in zip(read_samples, samples):
|
||||
data_equivalence(read_sample, sample)
|
||||
assert data_equivalence(read_sample, sample)
|
||||
|
||||
|
||||
def test_custom_space():
|
||||
|
Reference in New Issue
Block a user